model#
- class pointrix.model.base_model.BaseModel(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseModule
Base class for all models.
- Parameters:
cfg (Optional[Union[dict, DictConfig]]) – The configuration dictionary.
datapipeline (BaseDataPipeline) – The data pipeline which is used to initialize the point cloud.
device (str, optional) – The device to use, by default “cuda”.
- class Config(camera_model: dict = <factory>, point_cloud: dict = <factory>, renderer: dict = <factory>, lambda_ssim: float = 0.2)#
Bases:
object
- forward(batch=None, training=True, render=True, iteration=None) dict #
Forward pass of the model.
- Parameters:
batch (dict) – The batch of data.
- Returns:
The render results which will be the input of renderers.
- Return type:
dict
- get_loss_dict(render_results, batch) dict #
Get the loss dictionary.
- Parameters:
render_results (dict) – The render results, i.e. the output of the renderer.
batch (dict) – The batch of data which contains the ground truth images.
- Returns:
The loss dictionary, which contains the losses used for backpropagation.
- Return type:
dict
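A minimal sketch of how such a loss dictionary could be assembled. It assumes the weighting common in Gaussian-splatting training, (1 - lambda_ssim) * L1 + lambda_ssim * (1 - SSIM); the source does not state the formula Pointrix uses. The helper names, the dictionary keys, and the precomputed `ssim_value` are hypothetical.

```python
def l1_loss(pred, target):
    """Mean absolute error between two equally sized flat pixel lists."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def build_loss_dict(pred, target, ssim_value, lambda_ssim=0.2):
    """Blend an L1 term with a (1 - SSIM) term.

    Assumption: the weights follow (1 - lambda) * L1 + lambda * (1 - SSIM).
    The SSIM score is passed in precomputed to keep the sketch short.
    """
    l1 = l1_loss(pred, target)
    ssim_term = 1.0 - ssim_value
    total = (1.0 - lambda_ssim) * l1 + lambda_ssim * ssim_term
    # Hypothetical keys; only the total would be backpropagated.
    return {"loss": total, "L1_loss": l1, "ssim_loss": ssim_term}


loss_dict = build_loss_dict(
    pred=[0.5, 0.5, 0.5, 0.5],
    target=[0.0, 1.0, 0.5, 0.5],
    ssim_value=0.9,
)
```

Keeping the individual terms in the dictionary next to the total makes them easy to log separately.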
- get_metric_dict(render_results, batch) dict #
Get the metric dictionary.
- Parameters:
render_results (dict) – The render results, i.e. the output of the renderer.
batch (dict) – The batch of data which contains the ground truth images.
- Returns:
The metric dictionary which contains the metrics for evaluation.
- Return type:
dict
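The source does not list which metrics are returned. As an illustration only, the sketch below computes PSNR, a metric commonly reported for novel-view synthesis, on flat pixel lists; the function name and the `"psnr"` key are assumptions.

```python
import math


def psnr(pred, target, max_value=1.0):
    """Peak signal-to-noise ratio in dB for pixel values in [0, max_value]."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Hypothetical metric dictionary with a single entry.
metric_dict = {"psnr": psnr([0.5, 0.5], [0.4, 0.6])}
```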
- get_optimizer_dict(loss_dict, render_results, white_bg) dict #
Get the optimizer dictionary, which is passed to the optimizer to update the model.
- Parameters:
loss_dict (dict) – The loss dictionary.
render_results (dict) – The render results, i.e. the output of the renderer.
white_bg (bool) – The white background flag.
- load_ply(path)#
Load the ply model for point cloud.
- Parameters:
path (str) – The path of the ply file.
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)#
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. Default: False
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
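The effect of strict on the returned missing_keys and unexpected_keys can be illustrated with a simplified, pure-Python model of the key comparison. This is not the PyTorch implementation; `KeyReport`, `compare_keys`, and the example key names are hypothetical.

```python
from collections import namedtuple

# Hypothetical stand-in for the NamedTuple described above.
KeyReport = namedtuple("KeyReport", ["missing_keys", "unexpected_keys"])


def compare_keys(module_keys, state_dict, strict=True):
    """Report keys the module expects but lacks, and keys it does not know."""
    missing = [key for key in module_keys if key not in state_dict]
    unexpected = [key for key in state_dict if key not in module_keys]
    if strict and (missing or unexpected):
        # With strict=True any mismatch is treated as an error.
        raise RuntimeError(
            f"missing keys: {missing}, unexpected keys: {unexpected}"
        )
    return KeyReport(missing, unexpected)


report = compare_keys(
    module_keys=["point_cloud.position", "point_cloud.opacity"],
    state_dict={"point_cloud.position": [0.0, 0.0, 0.0], "renderer.step": 10},
    strict=False,
)
```

With strict=False the mismatches are only reported, which is useful when loading a checkpoint into a model whose set of attributes has changed.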
- class pointrix.model.camera.camera_model.CameraModel(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseObject
Camera class used in Pointrix
- class Config(name: str = 'CameraModel', enable_training: bool = False, scene_scale: float = 1.0)#
Bases:
object
- Parameters:
enable_training (bool) – Whether the camera is trainable
scene_scale (float) – The scale of the scene
- camera_centers(idx_list) Float[Tensor, 'C 3'] #
Get the camera center from the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
_camera_center
- Return type:
Float[Tensor, “C 3”]
Notes
property of the camera class
- extrinsic_matrices(idx_list) Float[Tensor, 'C 4 4'] #
Get the extrinsic matrix from the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
_extrinsic_matrix
- Return type:
Float[Tensor, “C 4 4”]
Notes
property of the camera class
- property image_height: int#
Get the image height from the cameras.
- Returns:
height – The image height.
- Return type:
int
Notes
property of the camera class
- property image_width: int#
Get the image width from the cameras.
- Returns:
width – The image width.
- Return type:
int
Notes
property of the camera class
- intrinsic_params(idx_list) Float[Tensor, 'C 4'] #
Get the intrinsic parameters of the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
intrinsic_params
- Return type:
Float[Tensor, “C 4”]
Notes
property of the camera class
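The source does not say what the four intrinsic parameters per camera are. The sketch below assumes the pinhole layout (fx, fy, cx, cy) and shows how such a tuple would map a point in camera coordinates to pixel coordinates; `project_point` is a hypothetical helper, not part of Pointrix.

```python
def project_point(point_cam, intrinsic_params):
    """Project a camera-space point with a pinhole model.

    Assumption: intrinsic_params is ordered (fx, fy, cx, cy).
    """
    fx, fy, cx, cy = intrinsic_params
    x, y, z = point_cam
    if z <= 0.0:
        raise ValueError("point must lie in front of the camera (z > 0)")
    # Perspective division followed by focal scaling and principal-point shift.
    return (fx * x / z + cx, fy * y / z + cy)


pixel = project_point((0.5, -0.25, 2.0), (100.0, 100.0, 64.0, 48.0))
```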
- rotation_matrices(idx_list) Float[Tensor, 'C 3 3'] #
Get the rotation matrix of the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
rotation_matrix
- Return type:
Float[Tensor, “C 3 3”]
- setup(camerasprior: CamerasPrior, device='cuda') None #
Setup the camera class
- Parameters:
camerasprior (CamerasPrior) – The camera priors
- translation_vectors(idx_list) Float[Tensor, 'C 3'] #
Get the translation vector from the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
_translation_vector
- Return type:
Float[Tensor, “C 3”]
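The rotation, translation, and camera-center accessors are related geometrically. Assuming the extrinsic matrix is a world-to-camera transform [R | t] (the source does not state the convention), the camera center in world coordinates is C = -Rᵀ t. A minimal sketch with plain nested lists in place of tensors:

```python
def camera_center(rotation, translation):
    """Return C = -R^T t for a world-to-camera pose (assumed convention)."""
    return [
        -sum(rotation[row][col] * translation[row] for row in range(3))
        for col in range(3)
    ]


# Rotation of 90 degrees about the z axis and an arbitrary translation.
R = [
    [0, -1, 0],
    [1, 0, 0],
    [0, 0, 1],
]
t = [1, 2, 3]
center = camera_center(R, t)

# Sanity check: mapping the center into camera space gives the origin.
origin = [
    sum(R[row][col] * center[col] for col in range(3)) + t[row]
    for row in range(3)
]
```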
- class pointrix.model.point_cloud.gaussian_points.GaussianPointCloud(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
PointCloud
A class for Gaussian point cloud.
- Parameters:
PointCloud (PointCloud) – The point cloud for initialisation.
- class Config(point_cloud_type: str = '', initializer: dict = <factory>, trainable: bool = True, unwarp_prefix: str = 'point_cloud', max_sh_degree: int = 3, lambda_dssim: float = 0.2)#
Bases:
Config
- re_init(num_points)#
re-initialize the point cloud.
- setup(point_cloud=None)#
The function for setting up the point cloud.
- Parameters:
point_cloud (PointCloud) – The point cloud for initialisation.
- class pointrix.model.point_cloud.points.PointCloud(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseModule
- class Config(point_cloud_type: str = '', initializer: dict = <factory>, trainable: bool = True, unwarp_prefix: str = 'point_cloud')#
Bases:
object
- extand_points(new_atributes: dict, optimizer: Optimizer | None = None) None #
Extend the attributes of the point cloud with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- extend_optimizer(new_atributes: dict, optimizer: Optimizer) dict #
Extend the point cloud in the optimizer with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- Returns:
new_tensors
- Return type:
dict
- get_all_atributes() list #
Return all attributes of the point cloud.
- Returns:
atributes – The list of all attributes of the point cloud.
- Return type:
list
- list_of_attributes() list #
return the list of all attributes of the point cloud for ply saving.
- load_ply(path: Path)#
load the point cloud from ply file.
- Parameters:
path (Path) – The path of the ply file.
- prune_optimizer(mask: Tensor, optimizer: Optimizer | None) None #
prune the point cloud in optimizer with mask.
- Parameters:
mask (Tensor) – The mask for removing the points.
optimizer (Optimizer) – The optimizer for the point cloud.
- re_init(num_points) None #
re-initialize the point cloud.
- register_atribute(name: str, value: Float[Tensor, '3 1'], trainable=True) None #
Register a trainable attribute of the point cloud.
- Parameters:
name (str) – The name of the attribute.
value (Tensor) – The value of the attribute.
trainable (bool) – Whether the attribute is trainable.
Examples
>>> point_cloud = PointCloud(cfg)
>>> point_cloud.register_atribute('position', position)
>>> point_cloud.register_atribute('rgb', rgb)
- remove_points(mask: Tensor, optimizer: Optimizer | None = None) None #
remove points of the point cloud with mask.
- Parameters:
mask (Tensor) – The mask for removing the points.
- replace(new_atributes: dict, optimizer: Optimizer | None = None) None #
Replace the attributes of the point cloud with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- replace_optimizer(new_atributes: dict, optimizer: Optimizer) dict #
Replace the point cloud in the optimizer with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- save_ply(path: Path) None #
save the point cloud to ply file.
- Parameters:
path (Path) – The path of the ply file.
- select_atributes(mask: Tensor) dict #
Select attributes of the point cloud by the input mask.
- Parameters:
mask (Tensor) – The mask for selecting the attributes.
- Returns:
selected_atributes – The dict of selected attributes.
- Return type:
dict
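Mask-based selection and removal can be sketched on plain Python lists standing in for tensors. The assumptions here are that every attribute has one row per point, that a True mask entry selects a point in select_atributes, and that a True entry marks a point for deletion in remove_points. The helper names are hypothetical.

```python
def select_by_mask(attributes, mask):
    """Keep, for every attribute, only the rows whose mask entry is True."""
    return {
        name: [row for row, keep in zip(rows, mask) if keep]
        for name, rows in attributes.items()
    }


def remove_by_mask(attributes, mask):
    """Drop the rows whose mask entry is True (assumed removal semantics)."""
    return select_by_mask(attributes, [not flag for flag in mask])


attributes = {
    "position": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    "opacity": [0.9, 0.01, 0.5],
}
mask = [False, True, False]  # flags the nearly transparent second point

selected = select_by_mask(attributes, mask)
remaining = remove_by_mask(attributes, mask)
```

Applying the same mask to every attribute keeps the per-point rows aligned across the whole dictionary.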
- set_all_atributes_trainable() None #
Set all attributes of the point cloud trainable.
- set_prefix_name(name: str) None #
set the prefix name to distinguish different point cloud.
- Parameters:
name (str) – The prefix name.
- setup(point_cloud: dict | None = None) None #
The function for setting up the point cloud.
- Parameters:
point_cloud (PointCloud) – The point cloud for initialisation.
- unwarp(name) str #
Remove the prefix name of the attribute.
- Parameters:
name (str) – The name of the attribute.
- Returns:
name – The name of the attribute without prefix name.
- Return type:
str
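A minimal sketch of prefix handling, assuming the prefix and the attribute name are joined with a "." separator (for example "point_cloud.position"); the source does not specify the separator. `add_prefix` and `strip_prefix` are hypothetical helpers.

```python
SEPARATOR = "."  # assumed; not confirmed by the source


def add_prefix(prefix, name):
    """Qualify an attribute name with the point cloud's prefix."""
    return f"{prefix}{SEPARATOR}{name}"


def strip_prefix(prefix, name):
    """Return the bare attribute name, leaving unprefixed names untouched."""
    head = prefix + SEPARATOR
    return name[len(head):] if name.startswith(head) else name


qualified = add_prefix("point_cloud", "position")
bare = strip_prefix("point_cloud", qualified)
```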
- class pointrix.model.renderer.base_splatting.GaussianSplattingRender(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
MsplatRender
A class for rendering point clouds using Gaussian splatting.
- render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) dict #
Render the point cloud for one iteration
- Parameters:
height (int) – The height of the image.
width (int) – The width of the image.
extrinsic_matrix (torch.Tensor) – The extrinsic matrix.
intrinsic_params (list) – The intrinsic parameters.
camera_center (torch.Tensor) – The camera center.
position (torch.Tensor) – The position of the point cloud.
opacity (torch.Tensor) – The opacity of the point cloud.
scaling (torch.Tensor) – The scaling of the point cloud.
rotation (torch.Tensor) – The rotation of the point cloud.
shs (torch.Tensor) – The spherical harmonics.
- Returns:
The rendered point cloud.
- Return type:
dict
- setup(white_bg, device, **kwargs)#
Setup the renderer.
- Parameters:
white_bg (bool) – Whether the background is white.
device (str) – The device used in the pipeline.
- class pointrix.model.renderer.msplat.MsplatRender(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseObject
A class for rendering point clouds using msplat.
- class Config(update_sh_iter: int = 1000, max_sh_degree: int = 3, render_depth: bool = False)#
Bases:
object
- Parameters:
update_sh_iter (int, optional) – The iteration to update the spherical harmonics degree, by default 1000.
max_sh_degree (int, optional) – The maximum spherical harmonics degree, by default 3.
render_depth (bool, optional) – Whether to render the depth or not, by default False.
- load_state_dict(state_dict)#
Load the state dictionary of render.
- Parameters:
state_dict (dict) – The state dictionary
- render_batch(render_dict: dict, batch: List[dict]) dict #
Render the batch of point clouds.
- Parameters:
render_dict (dict) – The render dictionary.
batch (List[dict]) – The batch data.
- Returns:
The rendered image, the viewspace points, the visibility filter, the radii, the xyz, the color, the rotation, the scales, and the xy.
- Return type:
dict
- render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) dict #
Render the point cloud for one iteration
- Parameters:
height (int) – The height of the image.
width (int) – The width of the image.
extrinsic_matrix (torch.Tensor) – The extrinsic matrix.
intrinsic_params (list) – The intrinsic parameters.
camera_center (torch.Tensor) – The camera center.
position (torch.Tensor) – The position of the point cloud.
opacity (torch.Tensor) – The opacity of the point cloud.
scaling (torch.Tensor) – The scaling of the point cloud.
rotation (torch.Tensor) – The rotation of the point cloud.
shs (torch.Tensor) – The spherical harmonics.
- Returns:
The rendered point cloud.
- Return type:
dict
- setup(white_bg, device, **kwargs)#
Setup the renderer.
- Parameters:
white_bg (bool) – Whether the background is white.
device (str) – The device used in the pipeline.
- state_dict()#
Get the state dictionary of render.
- Returns:
The state dictionary
- Return type:
dict
- update_sh_degree(step)#
Update the spherical harmonics degree in render
- Parameters:
step (int) – The current training step.
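How update_sh_iter and max_sh_degree interact is not spelled out in the source. A common schedule in Gaussian-splatting training, assumed here, starts at degree 0 and raises the active degree by one every update_sh_iter steps until max_sh_degree is reached. `active_sh_degree` is a hypothetical helper.

```python
def active_sh_degree(step, update_sh_iter=1000, max_sh_degree=3):
    """Spherical-harmonics degree in use at a given training step.

    Assumption: one degree is unlocked every update_sh_iter steps,
    capped at max_sh_degree.
    """
    return min(step // update_sh_iter, max_sh_degree)


schedule = [active_sh_degree(step) for step in (0, 999, 1000, 2500, 10000)]
```

Starting with low-degree harmonics lets the optimizer fit view-independent color first and adds view-dependent detail later.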
- class pointrix.model.renderer.gsplat.GsplatRender(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
MsplatRender
A class for rendering point clouds using gsplat
- render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) dict #
Render the point cloud for one iteration
- Parameters:
height (int) – The height of the image.
width (int) – The width of the image.
extrinsic_matrix (torch.Tensor) – The extrinsic matrix.
intrinsic_params (torch.Tensor) – The intrinsic parameters.
camera_center (torch.Tensor) – The camera center.
position (torch.Tensor) – The position of the point cloud.
opacity (torch.Tensor) – The opacity of the point cloud.
scaling (torch.Tensor) – The scaling of the point cloud.
rotation (torch.Tensor) – The rotation of the point cloud.
shs (torch.Tensor) – The spherical harmonics.