model

class pointrix.model.base_model.BaseModel(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: BaseModule

Base class for all models.

Parameters:
  • cfg (Optional[Union[dict, DictConfig]]) – The configuration dictionary.

  • datapipeline (BaseDataPipeline) – The data pipeline which is used to initialize the point cloud.

  • device (str, optional) – The device to use, by default “cuda”.

class Config(camera_model: dict = <factory>, point_cloud: dict = <factory>, renderer: dict = <factory>, lambda_ssim: float = 0.2)

Bases: object
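
The `<factory>` defaults in the signature above indicate fields that are built fresh for each instance. A minimal sketch of that behaviour, assuming a plain dataclass that mirrors the documented fields; `ModelConfig` and `parse_config` are hypothetical names used for illustration, not Pointrix API.

```python
from dataclasses import dataclass, field, fields


@dataclass
class ModelConfig:
    """Hypothetical stand-in mirroring the documented BaseModel.Config fields."""
    camera_model: dict = field(default_factory=dict)
    point_cloud: dict = field(default_factory=dict)
    renderer: dict = field(default_factory=dict)
    lambda_ssim: float = 0.2


def parse_config(cfg=None):
    """Overlay a user-supplied dict on the defaults (hypothetical helper)."""
    cfg = dict(cfg or {})
    known = {f.name for f in fields(ModelConfig)}
    unknown = set(cfg) - known
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return ModelConfig(**cfg)


default_cfg = parse_config()
# The nested renderer key below is only an example value.
custom_cfg = parse_config({"lambda_ssim": 0.1, "renderer": {"max_sh_degree": 2}})
print(default_cfg.lambda_ssim, custom_cfg.lambda_ssim)
```

Because each dict default comes from a factory, two configurations never share the same nested dictionary object.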

forward(batch=None, training=True, render=True, iteration=None) → dict

Forward pass of the model.

Parameters:

batch (dict) – The batch of data.

Returns:

The render results which will be the input of renderers.

Return type:

dict

get_loss_dict(render_results, batch) → dict

Get the loss dictionary.

Parameters:
  • render_results (dict) – The render results output by the renderer.

  • batch (dict) – The batch of data, which contains the ground truth images.

Returns:

The loss dictionary, which contains the losses for backpropagation.

Return type:

dict
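
The `lambda_ssim` field of `Config` suggests that the loss blends a pixel-wise term with an SSIM term. The sketch below assumes the weighting commonly used in 3D Gaussian Splatting, (1 - lambda) * L1 + lambda * (1 - SSIM); this page does not state the formula, and the function name and dictionary keys are hypothetical.

```python
def combine_losses(l1_loss, ssim_value, lambda_ssim=0.2):
    """Blend an L1 term with a structural dissimilarity term (assumed form)."""
    ssim_loss = 1.0 - ssim_value
    total = (1.0 - lambda_ssim) * l1_loss + lambda_ssim * ssim_loss
    # Keys are illustrative, not necessarily the keys Pointrix uses.
    return {"loss": total, "L1_loss": l1_loss, "ssim_loss": ssim_loss}


loss_dict = combine_losses(l1_loss=0.1, ssim_value=0.9, lambda_ssim=0.2)
print(loss_dict["loss"])
```

With `lambda_ssim=0` the total reduces to the L1 term alone, which is a quick way to check the weighting.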

get_metric_dict(render_results, batch) → dict

Get the metric dictionary.

Parameters:
  • render_results (dict) – The render results output by the renderer.

  • batch (dict) – The batch of data, which contains the ground truth images.

Returns:

The metric dictionary which contains the metrics for evaluation.

Return type:

dict

get_optimizer_dict(loss_dict, render_results, white_bg) → dict

Get the optimizer dictionary, which is the input the optimizer uses to update the model.

Parameters:
  • loss_dict (dict) – The loss dictionary.

  • render_results (dict) – The render results output by the renderer.

  • white_bg (bool) – The white background flag.

load_ply(path)

Load the PLY model for the point cloud.

Parameters:

path (str) – The path of the ply file.

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type:

NamedTuple with missing_keys and unexpected_keys fields

Note

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
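
The key-matching rule behind `strict`, `missing_keys` and `unexpected_keys` can be illustrated on plain dictionaries. This is a simplified sketch, not the real implementation (which copies tensors into the module); `compare_keys`, `KeyReport` and the key names are hypothetical.

```python
from collections import namedtuple

KeyReport = namedtuple("KeyReport", ["missing_keys", "unexpected_keys"])


def compare_keys(module_keys, state_dict, strict=True):
    """Report keys the module expects but lacks, and keys it does not know."""
    missing = [k for k in module_keys if k not in state_dict]
    unexpected = [k for k in state_dict if k not in module_keys]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return KeyReport(missing, unexpected)


# Illustrative key names only.
module_keys = ["point_cloud.position", "point_cloud.opacity"]
checkpoint = {"point_cloud.position": [0.0], "point_cloud.rgb": [1.0]}

report = compare_keys(module_keys, checkpoint, strict=False)
print(report.missing_keys, report.unexpected_keys)
```

With `strict=True` the same mismatch raises instead of returning the report.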

class pointrix.model.camera.camera_model.CameraModel(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: BaseObject

Camera class used in Pointrix.

class Config(name: str = 'CameraModel', enable_training: bool = False, scene_scale: float = 1.0)

Bases: object

Parameters:
  • enable_training (bool) – Whether the camera is trainable

  • scene_scale (float) – The scale of the scene

camera_centers(idx_list) → Float[Tensor, 'C 3']

Get the camera centers of the selected cameras.

Parameters:

idx_list (int) – The list of camera indices.

Returns:

_camera_center

Return type:

Float[Tensor, “C 3”]

Notes

property of the camera class

extrinsic_matrices(idx_list) → Float[Tensor, 'C 4 4']

Get the extrinsic matrices of the selected cameras.

Parameters:

idx_list (int) – The list of camera indices.

Returns:

_extrinsic_matrix

Return type:

Float[Tensor, “C 4 4”]

Notes

property of the camera class

property image_height: int

Get the image height from the cameras.

Returns:

height – The image height.

Return type:

int

Notes

property of the camera class

property image_width: int

Get the image width from the cameras.

Returns:

width – The image width.

Return type:

int

Notes

property of the camera class

intrinsic_params(idx_list) → Float[Tensor, 'C 4']

Get the intrinsic parameters of the selected cameras.

Parameters:

idx_list (int) – The list of camera indices.

Returns:

intrinsic_params

Return type:

Float[Tensor, “C 4”]

Notes

property of the camera class
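
Each camera contributes four intrinsic values. The sketch below assumes they are ordered as focal lengths and principal point, `(fx, fy, cx, cy)`, and shows how such a vector would project a camera-space point with a pinhole model; the ordering and the helper name `project_point` are assumptions, not confirmed by this page.

```python
def project_point(point_cam, intrinsic_params):
    """Project a camera-space point to pixel coordinates (pinhole model)."""
    fx, fy, cx, cy = intrinsic_params  # assumed ordering
    x, y, z = point_cam
    if z <= 0:
        raise ValueError("point must lie in front of the camera (z > 0)")
    return (fx * x / z + cx, fy * y / z + cy)


pixel = project_point((1.0, 2.0, 4.0), (100.0, 100.0, 50.0, 40.0))
print(pixel)
```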

rotation_matrices(idx_list) → Float[Tensor, 'C 3 3']

Get the rotation matrices of the selected cameras.

Parameters:

idx_list (int) – The list of camera indices.

Returns:

rotation_matrix

Return type:

Float[Tensor, “C 3 3”]

setup(camerasprior: CamerasPrior, device='cuda') → None

Set up the camera class.

Parameters:

camerasprior (CamerasPrior) – The camera priors.

translation_vectors(idx_list) → Float[Tensor, 'C 3']

Get the translation vectors of the selected cameras.

Parameters:

idx_list (int) – The list of camera indices.

Returns:

_translation_vector

Return type:

Float[Tensor, “C 3”]
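
The camera accessors expose a rotation, a translation, an extrinsic matrix and a camera center per camera. A minimal sketch of how these quantities typically relate, assuming the world-to-camera convention x_cam = R x_world + t (this page does not state which convention Pointrix uses); the helper names are hypothetical and plain lists stand in for tensors.

```python
def transpose(m):
    return [list(row) for row in zip(*m)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def camera_center(rotation, translation):
    """C = -R^T t, the world-space point that maps to the camera origin."""
    return [-c for c in matvec(transpose(rotation), translation)]


def extrinsic_matrix(rotation, translation):
    """Stack [R | t] over the row [0, 0, 0, 1] to form a 4x4 matrix."""
    rows = [list(r) + [t] for r, t in zip(rotation, translation)]
    return rows + [[0.0, 0.0, 0.0, 1.0]]


# A 90 degree rotation about the z axis and an arbitrary translation.
R = [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
t = [1.0, 2.0, 3.0]

center = camera_center(R, t)
E = extrinsic_matrix(R, t)
# Sanity check: applying the extrinsics to the center gives the origin.
origin = [a + b for a, b in zip(matvec(R, center), t)]
print(center, origin)
```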

class pointrix.model.point_cloud.gaussian_points.GaussianPointCloud(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: PointCloud

A class for Gaussian point cloud.

Parameters:

PointCloud (PointCloud) – The point cloud for initialisation.

class Config(point_cloud_type: str = '', initializer: dict = <factory>, trainable: bool = True, unwarp_prefix: str = 'point_cloud', max_sh_degree: int = 3, lambda_dssim: float = 0.2)

Bases: Config

re_init(num_points)

Re-initialize the point cloud.

setup(point_cloud=None)

The function for setting up the point cloud.

Parameters:

point_cloud (PointCloud) – The point cloud for initialisation.

class pointrix.model.point_cloud.points.PointCloud(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: BaseModule

class Config(point_cloud_type: str = '', initializer: dict = <factory>, trainable: bool = True, unwarp_prefix: str = 'point_cloud')

Bases: object

extand_points(new_atributes: dict, optimizer: Optimizer | None = None) → None

Extend the attributes of the point cloud with new attributes.

Parameters:
  • new_atributes (dict) – The dict of new attributes.

  • optimizer (Optimizer) – The optimizer for the point cloud.

extend_optimizer(new_atributes: dict, optimizer: Optimizer) → dict

Extend the point cloud in the optimizer with new attributes.

Parameters:
  • new_atributes (dict) – The dict of new attributes.

  • optimizer (Optimizer) – The optimizer for the point cloud.

Returns:

new_tensors

Return type:

dict
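
Extending a point cloud amounts to appending the same number of new rows to every attribute. The sketch below shows that bookkeeping on plain lists; the real methods operate on tensors and, when an optimizer is passed, also extend the optimizer's entries. `extend_attributes` is a hypothetical helper, and the consistency checks are assumptions about reasonable behaviour.

```python
def extend_attributes(attributes, new_attributes):
    """Append new per-point rows to every attribute (hypothetical helper)."""
    if set(attributes) != set(new_attributes):
        raise KeyError("new attributes must cover exactly the existing names")
    counts = {len(rows) for rows in new_attributes.values()}
    if len(counts) > 1:
        raise ValueError("every new attribute needs the same number of points")
    return {name: attributes[name] + new_attributes[name] for name in attributes}


attributes = {"position": [[0.0, 0.0, 0.0]], "opacity": [[0.5]]}
new_attributes = {
    "position": [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
    "opacity": [[0.1], [0.2]],
}
extended = extend_attributes(attributes, new_attributes)
print(len(extended["position"]), len(extended["opacity"]))
```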

get_all_atributes() → list

Return all attributes of the point cloud.

Returns:

atributes – The list of all attributes of the point cloud.

Return type:

list

list_of_attributes() → list

Return the list of all attributes of the point cloud for PLY saving.

load_ply(path: Path)

Load the point cloud from a PLY file.

Parameters:

path (Path) – The path of the ply file.

prune_optimizer(mask: Tensor, optimizer: Optimizer | None) → None

Prune the point cloud in the optimizer with a mask.

Parameters:
  • mask (Tensor) – The mask for removing the points.

  • optimizer (Optimizer) – The optimizer for the point cloud.

re_init(num_points) → None

Re-initialize the point cloud.

register_atribute(name: str, value: Float[Tensor, '3 1'], trainable=True) → None

Register an attribute of the point cloud, trainable by default.

Parameters:
  • name (str) – The name of the attribute.

  • value (Tensor) – The value of the attribute.

  • trainable (bool) – Whether the attribute is trainable.

Examples

>>> point_cloud = PointCloud(cfg)
>>> point_cloud.register_atribute('position', position)
>>> point_cloud.register_atribute('rgb', rgb)

remove_points(mask: Tensor, optimizer: Optimizer | None = None) → None

Remove points from the point cloud with a mask.

Parameters:
  • mask (Tensor) – The mask for removing the points.

  • optimizer (Optimizer) – The optimizer for the point cloud.

replace(new_atributes: dict, optimizer: Optimizer | None = None) → None

Replace the attributes of the point cloud with new attributes.

Parameters:
  • new_atributes (dict) – The dict of new attributes.

  • optimizer (Optimizer) – The optimizer for the point cloud.

replace_optimizer(new_atributes: dict, optimizer: Optimizer) → dict

Replace the point cloud in the optimizer with new attributes.

Parameters:
  • new_atributes (dict) – The dict of new attributes.

  • optimizer (Optimizer) – The optimizer for the point cloud.

save_ply(path: Path) → None

Save the point cloud to a PLY file.

Parameters:

path (Path) – The path of the ply file.
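
For orientation, a PLY file is a short text header followed by one record per vertex. The sketch below writes and reads an ASCII PLY containing positions only, using just the standard library; it is not the Pointrix implementation, which may use a binary layout and stores many more attributes. `save_ascii_ply` and `load_ascii_ply` are hypothetical helpers.

```python
import tempfile
from pathlib import Path


def save_ascii_ply(path, positions):
    """Write xyz positions as an ASCII PLY file."""
    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(positions)}",
        "property float x",
        "property float y",
        "property float z",
        "end_header",
    ]
    body = [" ".join(repr(float(c)) for c in p) for p in positions]
    Path(path).write_text("\n".join(header + body) + "\n")


def load_ascii_ply(path):
    """Read back the xyz positions written by save_ascii_ply."""
    lines = Path(path).read_text().splitlines()
    end = lines.index("end_header")
    count = next(
        int(line.split()[2])
        for line in lines[:end]
        if line.startswith("element vertex")
    )
    return [[float(v) for v in line.split()] for line in lines[end + 1 : end + 1 + count]]


points = [[0.0, 0.0, 0.0], [1.0, 2.5, -3.0]]
with tempfile.TemporaryDirectory() as tmp:
    ply_path = Path(tmp) / "points.ply"
    save_ascii_ply(ply_path, points)
    restored = load_ascii_ply(ply_path)
print(restored)
```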

select_atributes(mask: Tensor) → dict

Select attributes of the point cloud by the input mask.

Parameters:

mask (Tensor) – The mask for selecting the attributes.

Returns:

selected_atributes – The dict of selected attributes.

Return type:

dict
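
Mask-based selection applies one boolean per point to every attribute. A minimal sketch on plain lists follows (the real method indexes tensors); `select_by_mask` is a hypothetical helper. It also assumes that, for removal, a True entry marks a point to drop, so removal is selection with the inverted mask.

```python
def select_by_mask(attributes, mask):
    """Keep the rows of every attribute whose mask entry is True."""
    selected = {}
    for name, rows in attributes.items():
        if len(rows) != len(mask):
            raise ValueError(f"{name!r} has {len(rows)} rows, mask has {len(mask)}")
        selected[name] = [row for row, keep in zip(rows, mask) if keep]
    return selected


attributes = {
    "position": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]],
    "opacity": [[0.9], [0.1], [0.5]],
}
# Flag nearly transparent points; the threshold is an arbitrary example value.
low_opacity = [row[0] < 0.2 for row in attributes["opacity"]]

selected = select_by_mask(attributes, low_opacity)
# Assumed removal semantics: drop the flagged points, keep the rest.
remaining = select_by_mask(attributes, [not m for m in low_opacity])
print(selected["position"], remaining["opacity"])
```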

set_all_atributes_trainable() → None

Set all attributes of the point cloud trainable.

set_prefix_name(name: str) → None

Set the prefix name to distinguish different point clouds.

Parameters:

name (str) – The prefix name.

setup(point_cloud: dict | None = None) → None

The function for setting up the point cloud.

Parameters:

point_cloud (PointCloud) – The point cloud for initialisation.

unwarp(name) → str

Remove the prefix name of the attribute.

Parameters:

name (str) – The name of the attribute.

Returns:

name – The name of the attribute without the prefix name.

Return type:

str
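
Prefix removal can be pictured as plain string handling. The sketch below uses the documented default `unwarp_prefix` of 'point_cloud' and assumes the prefix is joined to the attribute name with a dot; the separator and the helper name `strip_prefix` are assumptions.

```python
def strip_prefix(name, prefix="point_cloud"):
    """Return name without a leading '<prefix>.', or unchanged if absent."""
    marker = prefix + "."  # assumed separator
    return name[len(marker):] if name.startswith(marker) else name


print(strip_prefix("point_cloud.position"))
print(strip_prefix("opacity"))
```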

class pointrix.model.renderer.base_splatting.GaussianSplattingRender(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: MsplatRender

A class for rendering point clouds using Gaussian splatting.

render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) → dict

Render the point cloud for one iteration.

Parameters:
  • height (int) – The height of the image.

  • width (int) – The width of the image.

  • extrinsic_matrix (torch.Tensor) – The extrinsic matrix.

  • intrinsic_params (list) – The intrinsic parameters.

  • camera_center (torch.Tensor) – The camera center.

  • position (torch.Tensor) – The position of the point cloud.

  • opacity (torch.Tensor) – The opacity of the point cloud.

  • scaling (torch.Tensor) – The scaling of the point cloud.

  • rotation (torch.Tensor) – The rotation of the point cloud.

  • shs (torch.Tensor) – The spherical harmonics.

Returns:

The rendered point cloud.

Return type:

dict

setup(white_bg, device, **kwargs)

Set up the renderer.

Parameters:
  • white_bg (bool) – Whether the background is white.

  • device (str) – The device used in the pipeline.

class pointrix.model.renderer.msplat.MsplatRender(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: BaseObject

A class for rendering point clouds using msplat.

class Config(update_sh_iter: int = 1000, max_sh_degree: int = 3, render_depth: bool = False)

Bases: object

Parameters:
  • update_sh_iter (int, optional) – The iteration to update the spherical harmonics degree, by default 1000.

  • max_sh_degree (int, optional) – The maximum spherical harmonics degree, by default 3.

  • render_depth (bool, optional) – Whether to render the depth or not, by default False.

load_state_dict(state_dict)

Load the state dictionary of the renderer.

Parameters:

state_dict (dict) – The state dictionary.

render_batch(render_dict: dict, batch: List[dict]) → dict

Render the batch of point clouds.

Parameters:
  • render_dict (dict) – The render dictionary.

  • batch (List[dict]) – The batch data.

Returns:

The rendered image, the viewspace points, the visibility filter, the radii, the xyz, the color, the rotation, the scales, and the xy.

Return type:

dict
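
One plausible reading of `render_batch` is that it calls `render_iter` once per view and gathers the per-view outputs under shared keys. The sketch below illustrates only that control flow; how Pointrix actually merges the results is not stated here, and `render_batch_sketch`, `fake_render_iter` and all dictionary keys are hypothetical.

```python
def render_batch_sketch(render_iter, render_dict, batch):
    """Call render_iter per view and collect each output key into a list."""
    collected = {}
    for view in batch:
        # Point-cloud inputs are shared; camera inputs come from the view.
        output = render_iter(**render_dict, **view)
        for key, value in output.items():
            collected.setdefault(key, []).append(value)
    return collected


def fake_render_iter(height, width, position, **kwargs):
    """Stand-in renderer that returns a black image and a point count."""
    return {
        "rgb": [[0.0] * width for _ in range(height)],
        "num_points": len(position),
    }


render_dict = {"position": [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]}
batch = [{"height": 2, "width": 3}, {"height": 2, "width": 3}]
results = render_batch_sketch(fake_render_iter, render_dict, batch)
print(len(results["rgb"]), results["num_points"])
```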

render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) → dict

Render the point cloud for one iteration.

Parameters:
  • height (int) – The height of the image.

  • width (int) – The width of the image.

  • extrinsic_matrix (torch.Tensor) – The extrinsic matrix.

  • intrinsic_params (list) – The intrinsic parameters.

  • camera_center (torch.Tensor) – The camera center.

  • position (torch.Tensor) – The position of the point cloud.

  • opacity (torch.Tensor) – The opacity of the point cloud.

  • scaling (torch.Tensor) – The scaling of the point cloud.

  • rotation (torch.Tensor) – The rotation of the point cloud.

  • shs (torch.Tensor) – The spherical harmonics.

Returns:

The rendered point cloud.

Return type:

dict

setup(white_bg, device, **kwargs)

Set up the renderer.

Parameters:
  • white_bg (bool) – Whether the background is white.

  • device (str) – The device used in the pipeline.

state_dict()

Get the state dictionary of the renderer.

Returns:

The state dictionary

Return type:

dict

update_sh_degree(step)

Update the spherical harmonics degree in the renderer.

Parameters:

step (int) – The current training step.
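
The `update_sh_iter` and `max_sh_degree` settings suggest a schedule in which the active spherical harmonics degree grows during training until it reaches the maximum. A minimal sketch of one such schedule, assuming the degree rises by one every `update_sh_iter` steps starting from degree 0; the exact rule Pointrix applies is not stated on this page, and `active_sh_degree` is a hypothetical helper.

```python
def active_sh_degree(step, update_sh_iter=1000, max_sh_degree=3):
    """Degree used at a given step under the assumed step-wise schedule."""
    return min(step // update_sh_iter, max_sh_degree)


for step in (0, 999, 1000, 2500, 50000):
    print(step, active_sh_degree(step))
```

Starting with low degrees lets the optimization fit view-independent color first and add view-dependent detail later.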

class pointrix.model.renderer.gsplat.GsplatRender(cfg: dict | DictConfig | None = None, *args, **kwargs)

Bases: MsplatRender

A class for rendering point clouds using gsplat.

render_iter(height, width, extrinsic_matrix, intrinsic_params, camera_center, position, opacity, scaling, rotation, shs, **kwargs) → dict

Render the point cloud for one iteration.

Parameters:
  • height (int) – The height of the image.

  • width (int) – The width of the image.

  • extrinsic_matrix (torch.Tensor) – The extrinsic matrix.

  • intrinsic_params (torch.Tensor) – The intrinsic parameters.

  • camera_center (torch.Tensor) – The camera center.

  • position (torch.Tensor) – The position of the point cloud.

  • opacity (torch.Tensor) – The opacity of the point cloud.

  • scaling (torch.Tensor) – The scaling of the point cloud.

  • rotation (torch.Tensor) – The rotation of the point cloud.

  • shs (torch.Tensor) – The spherical harmonics.