model#
- class pointrix.model.base_model.BaseModel(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseModule
Base class for all models.
- Parameters:
cfg (Optional[Union[dict, DictConfig]]) – The configuration dictionary.
datapipeline (BaseDataPipeline) – The data pipeline which is used to initialize the point cloud.
device (str, optional) – The device to use, by default “cuda”.
- class Config(camera_model: dict = <factory>, point_cloud: dict = <factory>, renderer: dict = <factory>, lambda_ssim: float = 0.2)#
Bases:
object
- forward(batch=None, training=True, render=True, iteration=None) dict#
Forward pass of the model.
- Parameters:
batch (dict) – The batch of data.
- Returns:
The render results which will be the input of renderers.
- Return type:
dict
- get_loss_dict(render_results, batch) dict#
Get the loss dictionary.
- Parameters:
render_results (dict) – The render results which is the output of the renderer.
batch (dict) – The batch of data which contains the ground truth images.
- Returns:
The loss dictionary, which contains the losses used for backpropagation.
- Return type:
dict
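The `lambda_ssim` default of 0.2 in `Config` suggests that the loss blends a pixel-wise term with a structural term. The following is a minimal sketch, assuming the weighting follows the common 3D Gaussian Splatting form (1 - lambda) * L1 + lambda * (1 - SSIM). The function name, the scalar inputs, and the dictionary keys are hypothetical; the actual keys returned by `get_loss_dict` are not stated here.

```python
def blend_losses(l1_loss: float, ssim_value: float, lambda_ssim: float = 0.2) -> dict:
    """Blend an L1 term with a D-SSIM term (assumed form, not confirmed by the docs)."""
    # SSIM measures similarity, so 1 - SSIM acts as a loss.
    dssim = 1.0 - ssim_value
    total = (1.0 - lambda_ssim) * l1_loss + lambda_ssim * dssim
    # Hypothetical keys, chosen only for illustration.
    return {"loss": total, "l1_loss": l1_loss, "ssim_loss": dssim}


losses = blend_losses(l1_loss=0.1, ssim_value=0.9)
```

With the default weight, the L1 term contributes 80% of the total and the D-SSIM term 20%.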
- get_metric_dict(render_results, batch) dict#
Get the metric dictionary.
- Parameters:
render_results (dict) – The render results which is the output of the renderer.
batch (dict) – The batch of data which contains the ground truth images.
- Returns:
The metric dictionary which contains the metrics for evaluation.
- Return type:
dict
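Which metrics the dictionary contains is not listed here. PSNR is a common choice for comparing a rendered image against a ground truth image, so the sketch below shows how such a value could be computed. It uses plain Python sequences instead of tensors, and the function name is an assumption rather than part of Pointrix.

```python
import math


def psnr(pred, target, max_val=1.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0.0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


value = psnr([0.5, 0.5], [0.4, 0.6])
```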
- get_optimizer_dict(loss_dict, render_results, white_bg) dict#
Get the optimizer dictionary, which is passed to the optimizer to update the model.
- Parameters:
loss_dict (dict) – The loss dictionary.
render_results (dict) – The render results which is the output of the renderer.
white_bg (bool) – The white background flag.
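The signatures of `forward`, `get_loss_dict`, and `get_optimizer_dict` suggest that the three calls are chained within a training step. The sketch below shows that order under the assumption that `forward` returns the render results. `StubModel` and the `"rgb"` and `"image"` keys are hypothetical stand-ins so the example runs without Pointrix; where the backward pass happens is not stated here, so it is omitted.

```python
def train_step(model, batch, white_bg=False, iteration=None):
    """Chain the three calls in the order the signatures suggest."""
    render_results = model.forward(batch, training=True, render=True, iteration=iteration)
    loss_dict = model.get_loss_dict(render_results, batch)
    optimizer_dict = model.get_optimizer_dict(loss_dict, render_results, white_bg)
    return loss_dict, optimizer_dict


class StubModel:
    """Hypothetical stand-in that mimics the documented method signatures."""

    def forward(self, batch=None, training=True, render=True, iteration=None):
        return {"rgb": [0.5]}

    def get_loss_dict(self, render_results, batch):
        return {"loss": abs(render_results["rgb"][0] - batch["image"][0])}

    def get_optimizer_dict(self, loss_dict, render_results, white_bg):
        return {"loss": loss_dict["loss"], "white_bg": white_bg}


loss_dict, optimizer_dict = train_step(StubModel(), {"image": [0.25]})
```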
- load_ply(path)#
Load the ply model for point cloud.
- Parameters:
path (str) – The path of the ply file.
- load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)#
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True the optimizer must be created after the call to load_state_dict.
- Parameters:
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
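The meaning of the two returned lists can be illustrated without PyTorch. The sketch below compares the keys a module expects with the keys of a state dictionary; `KeyReport` and `compare_keys` are hypothetical names, and the attribute names in the usage line are only examples.

```python
from typing import Iterable, List, Mapping, NamedTuple


class KeyReport(NamedTuple):
    missing_keys: List[str]      # expected by the module, absent from state_dict
    unexpected_keys: List[str]   # present in state_dict, unknown to the module


def compare_keys(module_keys: Iterable[str], state_dict: Mapping[str, object]) -> KeyReport:
    expected = list(module_keys)
    missing = [k for k in expected if k not in state_dict]
    unexpected = [k for k in state_dict if k not in expected]
    return KeyReport(missing, unexpected)


report = compare_keys(["position", "opacity"], {"position": 1, "scaling": 2})
```

With `strict=True`, PyTorch raises a `RuntimeError` when either list is non-empty; with `strict=False`, the lists are returned for inspection.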
- class pointrix.model.camera.camera_model.CameraModel(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseObject
Camera class used in Pointrix.
- class Config(name: str = 'CameraModel', enable_training: bool = False, scene_scale: float = 1.0)#
Bases:
object
- Parameters:
enable_training (bool) – Whether the camera is trainable
scene_scale (float) – The scale of the scene
- camera_centers(idx_list) Float[Tensor, 'C 3']#
Get the camera center from the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
_camera_center
- Return type:
Float[Tensor, “C 3”]
Notes
property of the camera class
- extrinsic_matrices(idx_list) Float[Tensor, 'C 4 4']#
Get the extrinsic matrix from the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
_extrinsic_matrix
- Return type:
Float[Tensor, “C 4 4”]
Notes
property of the camera class
- property image_height: int#
Get the image height from the cameras.
- Returns:
height – The image height.
- Return type:
int
Notes
property of the camera class
- property image_width: int#
Get the image width from the cameras.
- Returns:
width – The image width.
- Return type:
int
Notes
property of the camera class
- intrinsic_params(idx_list) Float[Tensor, 'C 4']#
Get the intrinsic parameters of the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
intrinsic_params
- Return type:
Float[Tensor, “C 4”]
Notes
property of the camera class
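The four intrinsic parameters per camera are not named here. The sketch below assumes the common pinhole ordering (fx, fy, cx, cy) and shows how such parameters project a point given in camera coordinates onto the image plane. Both the ordering and the function name are assumptions.

```python
def project_point(point_cam, intrinsic_params):
    """Project a camera-space point to pixel coordinates with a pinhole model."""
    fx, fy, cx, cy = intrinsic_params  # assumed ordering, not confirmed by the docs
    x, y, z = point_cam
    if z <= 0:
        raise ValueError("point must be in front of the camera (z > 0)")
    return (fx * x / z + cx, fy * y / z + cy)


pixel = project_point((1.0, 2.0, 2.0), (100.0, 100.0, 50.0, 40.0))
```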
- rotation_matrices(idx_list) Float[Tensor, 'C 3 3']#
Get the rotation matrix of the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
rotation_matrix
- Return type:
Float[Tensor, “C 3 3”]
- setup(camerasprior: CamerasPrior, device='cuda') None#
Set up the camera class.
- Parameters:
camerasprior (CamerasPrior) – The camera priors
- translation_vectors(idx_list) Float[Tensor, 'C 3']#
Get the translation vector from the cameras.
- Parameters:
idx_list (int) – The index list of the camera.
- Returns:
_translation_vector
- Return type:
Float[Tensor, “C 3”]
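The rotation, translation, and camera center are related by a simple identity when the extrinsic matrix maps world coordinates to camera coordinates, that is, X_cam = R X_world + t. Under that assumption (the convention is not stated here), the camera center is C = -Rᵀ t. A minimal sketch with nested lists in place of tensors:

```python
def camera_center(rotation, translation):
    """Return C = -R^T t for a world-to-camera pose (assumed convention)."""
    return [
        -sum(rotation[row][col] * translation[row] for row in range(3))
        for col in range(3)
    ]


identity = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
rot_z_90 = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]  # 90 degree rotation about z

center_a = camera_center(identity, [1, 2, 3])
center_b = camera_center(rot_z_90, [1, 0, 0])
```

Substituting `center_b` back into R X + t gives the zero vector, which confirms that the point is the camera origin.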
- class pointrix.model.point_cloud.gaussian_points.GaussianPointCloud(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
PointCloud
A class for Gaussian point clouds.
- Parameters:
PointCloud (PointCloud) – The point cloud for initialisation.
- class Config(point_cloud_type: str = '', initializer: dict = <factory>, trainable: bool = True, unwarp_prefix: str = 'point_cloud', max_sh_degree: int = 3, lambda_dssim: float = 0.2)#
Bases:
Config
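The `max_sh_degree` default of 3 determines how many spherical harmonics coefficients each color channel can carry: a basis up to and including degree d has (d + 1)² functions. The helper below is a hypothetical illustration; how Pointrix lays these coefficients out in memory is not stated here.

```python
def num_sh_coefficients(degree: int) -> int:
    """Number of real spherical harmonics basis functions up to `degree`."""
    if degree < 0:
        raise ValueError("degree must be non-negative")
    return (degree + 1) ** 2


per_channel = num_sh_coefficients(3)
```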
- re_init(num_points)#
Re-initialize the point cloud.
- setup(point_cloud=None)#
The function for setting up the point cloud.
- Parameters:
point_cloud (PointCloud) – The point cloud for initialisation.
- class pointrix.model.point_cloud.points.PointCloud(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseModule
- class Config(point_cloud_type: str = '', initializer: dict = <factory>, trainable: bool = True, unwarp_prefix: str = 'point_cloud')#
Bases:
object
- extand_points(new_atributes: dict, optimizer: Optimizer | None = None) None#
Extend the attributes of the point cloud with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- extend_optimizer(new_atributes: dict, optimizer: Optimizer) dict#
Extend the point cloud in the optimizer with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- Returns:
new_tensors
- Return type:
dict
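Extending a point cloud amounts to appending new points to every per-point attribute at once. The sketch below illustrates the idea with lists in place of tensors. It assumes the new dictionary must provide the same attribute names as the existing one, and it leaves out the optimizer state handling that `extend_optimizer` performs. The function name is hypothetical.

```python
def extend_attributes(attributes: dict, new_attributes: dict) -> dict:
    """Append new per-point values to every attribute (lists stand in for tensors)."""
    if set(attributes) != set(new_attributes):
        # Assumed requirement: every attribute needs values for the new points.
        raise KeyError("new attributes must cover exactly the existing attribute names")
    return {name: attributes[name] + new_attributes[name] for name in attributes}


cloud = {"position": [[0.0, 0.0, 0.0]], "opacity": [0.5]}
cloud = extend_attributes(cloud, {"position": [[1.0, 1.0, 1.0]], "opacity": [0.9]})
```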
- get_all_atributes() list#
Return all attributes of the point cloud.
- Returns:
atributes – The list of all attributes of the point cloud.
- Return type:
list
- list_of_attributes() list#
Return the list of all attributes of the point cloud for ply saving.
- load_ply(path: Path)#
Load the point cloud from a ply file.
- Parameters:
path (Path) – The path of the ply file.
- prune_optimizer(mask: Tensor, optimizer: Optimizer | None) None#
Prune the point cloud in the optimizer with a mask.
- Parameters:
mask (Tensor) – The mask for removing the points.
optimizer (Optimizer) – The optimizer for the point cloud.
- re_init(num_points) None#
Re-initialize the point cloud.
- register_atribute(name: str, value: Float[Tensor, '3 1'], trainable=True) None#
Register an attribute of the point cloud (trainable by default).
- Parameters:
name (str) – The name of the attribute.
value (Tensor) – The value of the attribute.
trainable (bool) – Whether the attribute is trainable.
Examples
>>> point_cloud = PointCloud(cfg)
>>> point_cloud.register_atribute('position', position)
>>> point_cloud.register_atribute('rgb', rgb)
- remove_points(mask: Tensor, optimizer: Optimizer | None = None) None#
Remove points from the point cloud with a mask.
- Parameters:
mask (Tensor) – The mask for removing the points.
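Removing points applies the same boolean mask to every per-point attribute so that all attributes stay aligned. The sketch below assumes that a `True` entry marks a point for removal, which matches the wording "mask for removing the points" but is not confirmed here. Lists stand in for tensors, and the function name is hypothetical.

```python
def remove_masked_points(attributes: dict, mask: list) -> dict:
    """Drop every point whose mask entry is True (assumed convention)."""
    return {
        name: [value for value, drop in zip(values, mask) if not drop]
        for name, values in attributes.items()
    }


cloud = {"position": [[0, 0, 0], [1, 1, 1], [2, 2, 2]], "opacity": [0.1, 0.5, 0.9]}
pruned = remove_masked_points(cloud, [True, False, False])
```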
- replace(new_atributes: dict, optimizer: Optimizer | None = None) None#
Replace attributes of the point cloud with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- replace_optimizer(new_atributes: dict, optimizer: Optimizer) dict#
Replace the point cloud in the optimizer with new attributes.
- Parameters:
new_atributes (dict) – The dict of new attributes.
optimizer (Optimizer) – The optimizer for the point cloud.
- save_ply(path: Path) None#
Save the point cloud to a ply file.
- Parameters:
path (Path) – The path of the ply file.
- select_atributes(mask: Tensor) dict#
Select attributes of the point cloud with the input mask.
- Parameters:
mask (Tensor) – The mask for selecting the attributes.
- Returns:
selected_atributes – The dict of selected attributes.
- Return type:
dict
- set_all_atributes_trainable() None#
Set all attributes of the point cloud as trainable.
- set_prefix_name(name: str) None#
Set the prefix name used to distinguish different point clouds.
- Parameters:
name (str) – The prefix name.
- setup(point_cloud: dict | None = None) None#
The function for setting up the point cloud.
- Parameters:
point_cloud (PointCloud) – The point cloud for initialisation.
- unwarp(name) str#
Remove the prefix name from the attribute name.
- Parameters:
name (str) – The name of the attribute.
- Returns:
name – The name of the attribute without the prefix name.
- Return type:
str
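Prefix stripping can be sketched as follows. The default prefix `point_cloud` comes from `unwarp_prefix` in `Config`; the `.` separator and the standalone function form are assumptions made for illustration, since the actual naming scheme is not stated here.

```python
def unwarp(name: str, prefix: str = "point_cloud") -> str:
    """Strip a leading '<prefix>.' from an attribute name, if present."""
    head = prefix + "."  # assumed separator
    return name[len(head):] if name.startswith(head) else name


bare = unwarp("point_cloud.position")
```

Names that do not carry the prefix are returned unchanged, so the call is safe to apply to mixed lists of names.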
- class pointrix.model.renderer.base_splatting.GaussianSplattingRender(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
MsplatRender
A class for rendering point clouds using Gaussian splatting.
- render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) dict#
Render the point cloud for one iteration.
- Parameters:
height (int) – The height of the image.
width (int) – The width of the image.
extrinsic_matrix (torch.Tensor) – The extrinsic matrix.
intrinsic_params (list) – The intrinsic parameters.
camera_center (torch.Tensor) – The camera center.
position (torch.Tensor) – The position of the point cloud.
opacity (torch.Tensor) – The opacity of the point cloud.
scaling (torch.Tensor) – The scaling of the point cloud.
rotation (torch.Tensor) – The rotation of the point cloud.
shs (torch.Tensor) – The spherical harmonics.
- Returns:
The rendered point cloud.
- Return type:
dict
- setup(white_bg, device, **kwargs)#
Set up the renderer.
- Parameters:
white_bg (bool) – Whether the background is white.
device (str) – The device used in the pipeline.
- class pointrix.model.renderer.msplat.MsplatRender(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
BaseObject
A class for rendering point clouds using msplat.
- class Config(update_sh_iter: int = 1000, max_sh_degree: int = 3, render_depth: bool = False)#
Bases:
object
- Parameters:
update_sh_iter (int, optional) – The iteration to update the spherical harmonics degree, by default 1000.
max_sh_degree (int, optional) – The maximum spherical harmonics degree, by default 3.
render_depth (bool, optional) – Whether to render the depth or not, by default False.
- load_state_dict(state_dict)#
Load the state dictionary of the renderer.
- Parameters:
state_dict (dict) – The state dictionary
- render_batch(render_dict: dict, batch: List[dict]) dict#
Render the batch of point clouds.
- Parameters:
render_dict (dict) – The render dictionary.
batch (List[dict]) – The batch data.
- Returns:
The rendered image, the viewspace points, the visibility filter, the radii, the xyz, the color, the rotation, the scales, and the xy.
- Return type:
dict
- render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) dict#
Render the point cloud for one iteration.
- Parameters:
height (int) – The height of the image.
width (int) – The width of the image.
extrinsic_matrix (torch.Tensor) – The extrinsic matrix.
intrinsic_params (list) – The intrinsic parameters.
camera_center (torch.Tensor) – The camera center.
position (torch.Tensor) – The position of the point cloud.
opacity (torch.Tensor) – The opacity of the point cloud.
scaling (torch.Tensor) – The scaling of the point cloud.
rotation (torch.Tensor) – The rotation of the point cloud.
shs (torch.Tensor) – The spherical harmonics.
- Returns:
The rendered point cloud.
- Return type:
dict
- setup(white_bg, device, **kwargs)#
Set up the renderer.
- Parameters:
white_bg (bool) – Whether the background is white.
device (str) – The device used in the pipeline.
- state_dict()#
Get the state dictionary of the renderer.
- Returns:
The state dictionary
- Return type:
dict
- update_sh_degree(step)#
Update the spherical harmonics degree in the renderer.
- Parameters:
step (int) – The current training step.
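The `update_sh_iter` and `max_sh_degree` settings together describe a schedule for the active spherical harmonics degree. The sketch below assumes the degree starts at 0 and rises by one every `update_sh_iter` steps until it reaches `max_sh_degree`, as is common in Gaussian splatting training. The exact rule used by `update_sh_degree` is not stated here, and the function name is hypothetical.

```python
def active_sh_degree(step: int, update_sh_iter: int = 1000, max_sh_degree: int = 3) -> int:
    """Active SH degree at a training step under an assumed stepwise schedule."""
    return min(step // update_sh_iter, max_sh_degree)


degrees = [active_sh_degree(s) for s in (0, 999, 1000, 2500, 10000)]
```

Starting with a low degree lets the optimization settle the view-independent color first before the view-dependent terms are enabled.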
- class pointrix.model.renderer.gsplat.GsplatRender(cfg: dict | DictConfig | None = None, *args, **kwargs)#
Bases:
MsplatRender
A class for rendering point clouds using gsplat.
- render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) dict#
Render the point cloud for one iteration.
- Parameters:
height (int) – The height of the image.
width (int) – The width of the image.
extrinsic_matrix (torch.Tensor) – The extrinsic matrix.
intrinsic_params (torch.Tensor) – The intrinsic parameters.
camera_center (torch.Tensor) – The camera center.
position (torch.Tensor) – The position of the point cloud.
opacity (torch.Tensor) – The opacity of the point cloud.
scaling (torch.Tensor) – The scaling of the point cloud.
rotation (torch.Tensor) – The rotation of the point cloud.
shs (torch.Tensor) – The spherical harmonics.