Adding Supervision for 3DGS

We use surface normals as an example to show how to add supervision from a prior (here, surface normal priors) to a point-cloud (3D Gaussian Splatting) rendering model. All code can be found in the example/supervise directory. The data download link for this tutorial is provided below:

We use the DSINE model to generate normal maps for the Truck scene of the Tanks and Temples dataset.

Data Section Modifications

The Tanks and Temples dataset is in COLMAP format, so we extend Pointrix's ColmapDataset. To load the normal priors predicted by DSINE, we first add them to the observed data in the configuration:

trainer.datapipeline.dataset.observed_data_dirs_dict={"image": "images", "normal": "normals"}

Here, "normals" is the name of the folder that stores the normal maps. "normal" is the key under which this data is accessed, for example as batch[i]["normal"] during training.
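Pointrix pairs files from different folders purely by their sorted position (see the loading code below). Every folder listed in observed_data_dirs_dict should therefore hold exactly one file per view, with names that sort in the same order as the images.

The following is a minimal pre-flight check using only the standard library. check_observed_dirs is a hypothetical helper for illustration, not part of Pointrix.

```python
import os
import tempfile
from pathlib import Path


def check_observed_dirs(data_root, dirs_dict):
    """Hypothetical pre-flight check (not part of Pointrix).

    Returns {key: sorted file names}. Raises if a folder is missing or if the
    folders hold different numbers of files, because the loader pairs files
    across folders by sorted index only.
    """
    listing = {}
    for key, folder in dirs_dict.items():
        path = Path(data_root) / folder
        if not path.is_dir():
            raise FileNotFoundError(f"{key}: folder {path} does not exist")
        listing[key] = sorted(os.listdir(path))
    counts = {key: len(names) for key, names in listing.items()}
    if len(set(counts.values())) > 1:
        raise ValueError(f"file counts differ between folders: {counts}")
    return listing


# Toy scene: two views, one normal map per view.
with tempfile.TemporaryDirectory() as root:
    for folder, names in {"images": ["000.jpg", "001.jpg"],
                          "normals": ["000.png", "001.png"]}.items():
        os.makedirs(os.path.join(root, folder))
        for name in names:
            Path(root, folder, name).touch()
    listing = check_observed_dirs(root, {"image": "images", "normal": "normals"})
    for image_name, normal_name in zip(listing["image"], listing["normal"]):
        print(image_name, "<->", normal_name)
```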

Pointrix automatically reads every file in each listed folder under the data path. It picks the loader from the file suffix: .npy files are loaded with NumPy, and .png/.jpg files are opened as images. The relevant reading code in Pointrix is shown below:

The code in ColmapDataset that automatically reads observed data according to the file suffix.
import os
from pathlib import Path

import numpy as np
from PIL import Image
# Logger and ProgressLogger are Pointrix's logging utilities.

def load_observed_data(self, split):
    """
    The function for loading the observed_data.

    Parameters:
    -----------
    split: str
        The split of the dataset.

    Returns:
    --------
    observed_data: List[Dict[str, Any]]
        The observed_data for the dataset.
    """
    observed_data = []
    for k, v in self.observed_data_dirs_dict.items():
        observed_data_path = self.data_root / Path(v)
        if not os.path.exists(observed_data_path):
            Logger.error(f"observed_data path {observed_data_path} does not exist.")
        # Files from different folders are paired by their sorted index
        observed_data_file_names = sorted(os.listdir(observed_data_path))
        split_index = self.train_index if split == "train" else self.val_index
        observed_data_file_names_split = [observed_data_file_names[i] for i in split_index]
        cached_progress = ProgressLogger(description='Loading cached observed_data', suffix='iters/s')
        cached_progress.add_task(f'cache_{k}', f'Loading {split} cached {k}', len(observed_data_file_names_split))
        with cached_progress.progress as progress:
            for idx, file in enumerate(observed_data_file_names_split):
                if len(observed_data) <= idx:
                    observed_data.append({})
                if file.endswith('.npy'):
                    # NumPy arrays are loaded as-is
                    observed_data[idx].update({k: np.load(observed_data_path / Path(file))})
                elif file.endswith(('png', 'jpg', 'JPG')):
                    # Images stay PIL images; they are converted in _transform_observed_data
                    observed_data[idx].update({k: Image.open(observed_data_path / Path(file))})
                else:
                    print(f"File format {file} is not supported.")
                cached_progress.update(f'cache_{k}', step=1)
    return observed_data
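The loop above turns one sorted file list per key into a list of per-view dictionaries, selecting views through train_index or val_index. The sketch below reproduces only that bookkeeping, without the progress bar and file I/O. File names stand in for the decoded arrays, and pair_observed_files is a hypothetical name used for illustration.

```python
def pair_observed_files(files_by_key, train_index, val_index, split):
    """Mimic the per-view grouping done by load_observed_data (sketch only).

    files_by_key maps a data key (e.g. "image", "normal") to its file names.
    Returns one dict per selected view: {key: file_name, ...}.
    """
    indices = train_index if split == "train" else val_index
    observed = []
    for key, names in files_by_key.items():
        selected = [sorted(names)[i] for i in indices]
        for idx, name in enumerate(selected):
            if len(observed) <= idx:
                observed.append({})
            observed[idx][key] = name
    return observed


views = pair_observed_files(
    {"image": ["001.jpg", "000.jpg", "002.jpg"],
     "normal": ["000.png", "002.png", "001.png"]},
    train_index=[0, 2], val_index=[1], split="train")
print(views)
# [{'image': '000.jpg', 'normal': '000.png'}, {'image': '002.jpg', 'normal': '002.png'}]
```

Because pairing relies only on sort order, an image and its normal map should share the same base name, or at least sort identically.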

After Pointrix has read the raw files, the normal maps still need to be converted into tensors. To do this, we subclass ColmapDataset and override its _transform_observed_data method so that it also processes the surface normals. The full code is in examples/gaussian_splatting_supervise/
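The dataset transform divides each stored normal map by 255, giving values in [0, 1]. The renderer later maps its rendered normals from [-1, 1] to [0, 1] with (n + 1) / 2. For the two to be comparable in the loss, the prior PNGs must use the same encoding, pixel = round((n + 1) / 2 * 255).

We assume here that the DSINE outputs were saved this way. The plain NumPy sketch below uses hypothetical function names and shows three things:
- the assumed encoding;
- the decoding performed by the dataset;
- how to recover unit vectors from the pixels.

```python
import numpy as np


def encode_normal_png(normals):
    """Map unit normals in [-1, 1], shape (H, W, 3), to uint8 pixels (assumed encoding)."""
    return np.round((normals + 1.0) / 2.0 * 255.0).astype(np.uint8)


def decode_normal_for_loss(pixels):
    """Same as the dataset transform: uint8 (H, W, 3) -> float (3, H, W) in [0, 1]."""
    return (pixels.astype(np.float32) / 255.0).transpose(2, 0, 1)


def decode_normal_to_unit(pixels):
    """Recover approximately unit-length normals in [-1, 1] from the PNG pixels."""
    n = pixels.astype(np.float32) / 255.0 * 2.0 - 1.0
    return n / np.clip(np.linalg.norm(n, axis=-1, keepdims=True), 1e-8, None)


# Every pixel's normal points back toward the camera (-z in an OpenCV-style camera frame).
normals = np.zeros((2, 2, 3), dtype=np.float32)
normals[..., 2] = -1.0
pixels = encode_normal_png(normals)
print(pixels[0, 0].tolist())                 # [128, 128, 0]
print(decode_normal_for_loss(pixels).shape)  # (3, 2, 2)
```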

The modified part is the normal transform at the end of the loop.
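# Assumed imports: numpy as np, torch, and Pointrix's ColmapDataset, ProgressLogger and dataset registry.

# Register the dataset so that the configuration can refer to it by name.
# (Registry name assumed; check the example directory for the exact import.)
@DATA_SET_REGISTRY.register()
class ColmapDepthNormalDataset(ColmapDataset):
    def _transform_observed_data(self, observed_data, split):
        cached_progress = ProgressLogger(description='transforming cached observed_data', suffix='iters/s')
        cached_progress.add_task('Transforming', f'Transforming {split} cached observed_data', len(observed_data))
        with cached_progress.progress as progress:
            for i in range(len(observed_data)):
                # Transform image: rescale, map to [0, 1], composite RGBA onto the background
                image = observed_data[i]['image']
                w, h = image.size
                new_size = (int(w * self.scale), int(h * self.scale))
                image = image.resize(new_size)
                image = np.array(image) / 255.
                if image.shape[2] == 4:
                    # self.bg: background value, 1.0 (white) or 0.0 (black); attribute name assumed
                    image = image[:, :, :3] * image[:, :, 3:4] + self.bg * (1 - image[:, :, 3:4])
                observed_data[i]['image'] = torch.from_numpy(np.array(image)).permute(2, 0, 1).float().clamp(0.0, 1.0)
                cached_progress.update('Transforming', step=1)

                # Transform normal: resize to the image resolution so the per-pixel loss lines up,
                # then map uint8 values to [0, 1] and reorder to (3, H, W)
                normal = observed_data[i]['normal'].resize(new_size)
                observed_data[i]['normal'] = \
                    (torch.from_numpy(np.array(normal)) / 255.0).float().permute(2, 0, 1)
        return observed_data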

Once the processed normals are stored in observed_data, the data section is complete. Pointrix's DataPipeline automatically delivers them with each training batch, where they can be accessed as batch[i]["normal"].
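The normal loss compares the rendered normal map with the prior pixel by pixel. Each normal map must therefore end up with the same height and width as its rescaled image, for example after applying scale: 0.25.

A quick check over the transformed observed_data could look like the following. It is written as a hypothetical helper working on arrays or tensors in (C, H, W) layout.

```python
import numpy as np


def check_matching_sizes(observed_data, keys=("image", "normal")):
    """Raise if, for any view, the (H, W) of the given entries differ.

    observed_data is a list of dicts of (C, H, W) arrays or tensors,
    as produced by _transform_observed_data.
    """
    for i, view in enumerate(observed_data):
        sizes = {k: tuple(view[k].shape[-2:]) for k in keys if k in view}
        if len(set(sizes.values())) > 1:
            raise ValueError(f"view {i}: mismatched sizes {sizes}")


ok = [{"image": np.zeros((3, 120, 160)), "normal": np.zeros((3, 120, 160))}]
check_matching_sizes(ok)  # passes silently

bad = [{"image": np.zeros((3, 120, 160)), "normal": np.zeros((3, 480, 640))}]
try:
    check_matching_sizes(bad)
except ValueError as err:
    print(err)  # view 0: mismatched sizes {'image': (120, 160), 'normal': (480, 640)}
```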

Model Section Modifications

First, we import the base model and the model registry from Pointrix so that we can subclass, modify, and register our own model.

from pointrix.model.base_model import BaseModel, MODEL_REGISTRY

The base model contains a Gaussian point cloud and a camera model. Since the surface normals come from the point cloud, we modify the model in three places:
- forward computes a normal for each Gaussian, expresses it in camera coordinates, and passes it to the renderer;
- get_loss_dict adds a normal loss, so that the normal supervision takes part in backpropagation;
- get_metric_dict returns the rendered and ground-truth normal images, so that they can be visualized.
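The geometric idea behind get_normals and process_normals (in the listing below) takes three steps:
1. Take a flattened Gaussian's normal to be its shortest principal axis, i.e. the column of its rotation matrix that corresponds to the smallest scale.
2. Flip that axis, if necessary, so that it faces the camera.
3. Rotate it into the camera frame with the world-to-camera rotation.

The NumPy sketch below illustrates these steps. It is not Pointrix's code. It assumes quaternions stored in (w, x, y, z) order, as in the original 3DGS implementation, so check which convention your unitquat_to_rotmat expects.

```python
import numpy as np


def quat_wxyz_to_rotmat(q):
    """Unit quaternions (N, 4) in (w, x, y, z) order -> rotation matrices (N, 3, 3)."""
    q = q / np.linalg.norm(q, axis=-1, keepdims=True)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    rows = [
        np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], axis=-1),
        np.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], axis=-1),
        np.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)], axis=-1),
    ]
    return np.stack(rows, axis=-2)


def gaussian_normals(scales, quats):
    """Per-Gaussian normal: the rotated unit axis with the smallest scale."""
    axis = np.eye(3)[np.argmin(scales, axis=-1)]                 # one-hot rows, (N, 3)
    normals = np.einsum("nij,nj->ni", quat_wxyz_to_rotmat(quats), axis)
    return normals / np.linalg.norm(normals, axis=-1, keepdims=True)


def orient_and_project(normals, positions, camera_center, w2c_rot):
    """Flip normals that face away from the camera, then rotate them into camera space."""
    to_camera = camera_center[None, :] - positions
    away = np.sum(to_camera * normals, axis=-1) < 0
    normals = np.where(away[:, None], -normals, normals)
    return normals @ w2c_rot.T                                   # row-vector form of R @ n


# Two flat Gaussians (smallest scale along local z); the second is rotated +90 deg about x.
scales = np.array([[1.0, 2.0, 0.1], [1.0, 2.0, 0.1]])
c, s = np.cos(np.pi / 4), np.sin(np.pi / 4)
quats = np.array([[1.0, 0.0, 0.0, 0.0], [c, s, 0.0, 0.0]])
normals = gaussian_normals(scales, quats)                        # approx. [[0, 0, 1], [0, -1, 0]]

# Camera on the -z side of the origin, identity world-to-camera rotation:
# the first normal (0, 0, 1) points away from it and is flipped to (0, 0, -1).
camera_normals = orient_and_project(normals, np.zeros((2, 3)), np.array([0.0, 0.0, -5.0]), np.eye(3))
```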

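The modified parts are the normal computation in forward, get_normals, process_normals, the normal loss in get_loss_dict, and the normal entries in get_metric_dict.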
import torch
import torch.nn.functional as F
# BaseModel and MODEL_REGISTRY are imported as shown above; l1_loss, ssim, psnr and
# unitquat_to_rotmat are helpers used by the example (import paths not shown here).

# Register the model so that the configuration can refer to it by name.
@MODEL_REGISTRY.register()
class NormalModel(BaseModel):
    def forward(self, batch=None, training=True) -> dict:
        frame_idx_list = [batch[i]["frame_idx"] for i in range(len(batch))]
        camera_model = self.training_camera_model if training else self.validation_camera_model
        extrinsic_matrix = camera_model.extrinsic_matrices(frame_idx_list)
        intrinsic_params = camera_model.intrinsic_params(frame_idx_list)
        camera_center = camera_model.camera_centers(frame_idx_list)

        # Get the surface normals and express them in camera coordinates
        point_normal = self.get_normals
        projected_normal = self.process_normals(point_normal, camera_center, extrinsic_matrix)

        render_dict = {
            "extrinsic_matrix": extrinsic_matrix,
            "intrinsic_params": intrinsic_params,
            "camera_center": camera_center,
            "position": self.point_cloud.position,
            "opacity": self.point_cloud.get_opacity,
            "scaling": self.point_cloud.get_scaling,
            "rotation": self.point_cloud.get_rotation,
            "shs": self.point_cloud.get_shs,
            # Passed to the renderer's `normals` argument
            "normals": projected_normal
        }
        return render_dict

    # The normal of each Gaussian is its shortest axis (the axis with the minimum scale)
    @property
    def get_normals(self):
        scaling = self.point_cloud.scaling.clone()
        normal_arg_min = torch.argmin(scaling, dim=-1)
        normal_each = F.one_hot(normal_arg_min, num_classes=3).float()

        # Rotate the selected local axis into world coordinates
        rotation_matrix = unitquat_to_rotmat(self.point_cloud.get_rotation)
        normal_each = torch.bmm(rotation_matrix, normal_each.unsqueeze(-1)).squeeze(-1)

        normal_each = F.normalize(normal_each, dim=-1)
        return normal_each

    # Orient the normals towards the camera and rotate them into camera coordinates
    def process_normals(self, normals, camera_center, E):
        camera_center = camera_center.squeeze(0)
        E = E.squeeze(0)
        xyz = self.point_cloud.position
        direction = (camera_center.repeat(xyz.shape[0], 1).cuda().detach()
                     - xyz.cuda().detach())
        direction = direction / direction.norm(dim=1, keepdim=True)
        # Flip normals that point away from the camera (out of place, so the input tensor is not modified)
        dot_for_judge = torch.sum(direction * normals, dim=-1)
        normals = torch.where(dot_for_judge.unsqueeze(-1) < 0, -normals, normals)
        # World-to-camera rotation n_cam = R @ n_world, written for row vectors
        w2c = E[:3, :3].cuda().float()
        normals_image = normals @ w2c.T
        return normals_image

    def get_loss_dict(self, render_results, batch) -> dict:
        loss = 0.0
        gt_images = torch.stack(
            [batch[i]["image"] for i in range(len(batch))],
            dim=0
        )
        normal_images = torch.stack(
            [batch[i]["normal"] for i in range(len(batch))],
            dim=0
        )
        L1_loss = l1_loss(render_results['rgb'], gt_images)
        ssim_loss = 1.0 - ssim(render_results['rgb'], gt_images)
        loss += (1.0 - self.cfg.lambda_ssim) * L1_loss
        loss += self.cfg.lambda_ssim * ssim_loss
        # Normal loss: L1 between rendered and prior normal maps (both in [0, 1]), weighted by 0.1
        normal_loss = 0.1 * l1_loss(render_results['normal'], normal_images)
        loss += normal_loss
        loss_dict = {"loss": loss,
                     "L1_loss": L1_loss,
                     "ssim_loss": ssim_loss,
                     "normal_loss": normal_loss}
        return loss_dict

    @torch.no_grad()
    def get_metric_dict(self, render_results, batch) -> dict:
        gt_images = torch.clamp(torch.stack(
            [batch[i]["image"].to(self.device) for i in range(len(batch))],
            dim=0), 0.0, 1.0)
        rgb = torch.clamp(render_results['rgb'], 0.0, 1.0)
        L1_loss = l1_loss(rgb, gt_images).mean().double()
        psnr_test = psnr(rgb.squeeze(), gt_images.squeeze()).mean().double()
        ssims_test = ssim(rgb, gt_images, size_average=True).mean().item()
        lpips_vgg_test = self.lpips_func(rgb, gt_images).mean().item()
        metric_dict = {"L1_loss": L1_loss,
                       "psnr": psnr_test,
                       "ssims": ssims_test,
                       "lpips": lpips_vgg_test,
                       "gt_images": gt_images,
                       "images": rgb,
                       "rgb_file_name": batch[0]["camera"].rgb_file_name}

        if 'depth' in render_results:
            depth = render_results['depth']
            metric_dict['depth'] = depth

        # Rendered normals, for visualization
        if 'normal' in render_results:
            normal = render_results['normal']
            metric_dict['normal'] = normal

        # Prior (ground-truth) normals, for visualization
        if 'normal' in batch[0]:
            normal = batch[0]['normal']
            metric_dict['normal_gt'] = normal

        return metric_dict
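Written out, the objective assembled in get_loss_dict has two parts: the usual 3DGS photometric loss, and an L1 term on the normal maps with a fixed weight of 0.1. Here $\hat{I}, I$ are the rendered and ground-truth images, and $\hat{N}, N$ are the rendered and prior normal maps, both in the [0, 1] encoding. The configuration below uses $\lambda_{\text{ssim}} = 0.2$. The L1 terms are read as mean absolute errors, assuming l1_loss averages over pixels as in the reference 3DGS code.

```latex
\mathcal{L} = (1-\lambda_{\text{ssim}})\,\lVert \hat{I} - I \rVert_1
            + \lambda_{\text{ssim}}\,\bigl(1 - \operatorname{SSIM}(\hat{I}, I)\bigr)
            + 0.1\,\lVert \hat{N} - N \rVert_1
```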

Rendering Section Modifications

Msplat can render several targets in one pass, so we only need to modify render_iter to feed the normals produced by the point cloud model into the renderer. As with the model, the new renderer must be registered in a registry so that the configuration can refer to it. The normals argument of render_iter corresponds to the "normals" entry in the dictionary returned by the model's forward function, and Pointrix passes it to the renderer automatically.
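Multi-target rendering works because alpha blending is linear in the per-Gaussian features. Whatever channels are stacked next to the colors (depth, normals, ...) are composited with the same per-pixel weights and can be split apart afterwards, which is what RenderFeatures.combine and split handle in the listing below.

The NumPy sketch below shows this for a single pixel with front-to-back compositing. It then applies the renormalization and [-1, 1] to [0, 1] mapping that render_iter performs on the normal channels. It is not Msplat's implementation. The channel layout and helper names are illustrative assumptions, and the background term is omitted.

```python
import numpy as np

# Channel layout assumed for this sketch: rgb (3) | depth (1) | normal (3).
LAYOUT = {"rgb": 3, "depth": 1, "normal": 3}


def combine(features):
    """Concatenate per-Gaussian features along the channel axis -> (N, C)."""
    return np.concatenate([features[k] for k in LAYOUT], axis=-1)


def split(blended):
    """Split a blended (C,) vector back into named parts."""
    out, start = {}, 0
    for name, size in LAYOUT.items():
        out[name] = blended[start:start + size]
        start += size
    return out


def blend_pixel(alphas, features):
    """Front-to-back alpha blending of depth-sorted Gaussians at one pixel."""
    transmittance, result = 1.0, np.zeros(features.shape[-1])
    for alpha, feat in zip(alphas, features):
        result += transmittance * alpha * feat
        transmittance *= 1.0 - alpha
    return result


def normal_to_image(normal):
    """Renormalize a blended normal and map it from [-1, 1] to [0, 1]."""
    normal = normal / max(np.linalg.norm(normal), 1e-8)
    return (normal + 1.0) / 2.0


gaussians = {
    "rgb": np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
    "depth": np.array([[2.0], [3.0]]),
    "normal": np.array([[0.0, 0.0, -1.0], [0.0, -1.0, 0.0]]),
}
# Blending weights are 0.5 and 0.5 * (1 - 0.5) = 0.25:
# rgb = 0.5 * red + 0.25 * blue, depth = 0.5 * 2 + 0.25 * 3 = 1.75
parts = split(blend_pixel([0.5, 0.5], combine(gaussians)))
normal_image = normal_to_image(parts["normal"])
```

Because the blended normal is a weighted sum of unit vectors, it is generally shorter than unit length. That is why it is renormalized before being mapped to the [0, 1] image range.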

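The modified parts are the normals argument, the normal channel added to RenderFeatures, and the post-processing of the rendered normals.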
# Register the renderer so that the configuration can refer to it by name.
# (Registry name assumed; check the example directory for the exact import.)
@RENDERER_REGISTRY.register()
class MsplatNormalRender(MsplatRender):
    """
    A renderer that extends MsplatRender to also render surface normals.

    Parameters
    ----------
    cfg : dict
        The configuration dictionary.
    white_bg : bool
        Whether the background is white or not.
    device : str
        The device to use.
    update_sh_iter : int, optional
        The iteration to update the spherical harmonics degree, by default 1000.
    """

    def render_iter(self,
                    height,
                    width,
                    extrinsic_matrix,
                    intrinsic_params,
                    camera_center,
                    position,
                    opacity,
                    scaling,
                    rotation,
                    shs,
                    normals,
                    **kwargs) -> dict:

        # View-dependent color from spherical harmonics
        direction = (position -
                     camera_center.repeat(position.shape[0], 1))
        direction = direction / direction.norm(dim=1, keepdim=True)
        rgb = msplat.compute_sh(shs.permute(0, 2, 1), direction)
        extrinsic_matrix = extrinsic_matrix[:3, :]

        (uv, depth) = msplat.project_point(
            position,
            intrinsic_params,
            extrinsic_matrix,
            width, height)

        visible = depth != 0

        # compute cov3d
        cov3d = msplat.compute_cov3d(scaling, rotation, visible)

        # ewa project
        (conic, radius, tiles_touched) = msplat.ewa_project(
            position,
            cov3d,
            intrinsic_params,
            extrinsic_matrix,
            uv,
            width,
            height,
            visible
        )

        # sort
        (gaussian_ids_sorted, tile_range) = msplat.sort_gaussian(
            uv, depth, width, height, radius, tiles_touched
        )

        # Stack color, depth and normal into one feature tensor so they are blended together
        Render_Features = RenderFeatures(rgb=rgb, depth=depth, normal=normals)
        render_features = Render_Features.combine()

        ndc = torch.zeros_like(uv, requires_grad=True)
        try:
            ndc.retain_grad()
        except RuntimeError as err:
            raise ValueError("ndc does not have grad") from err

        # alpha blending
        rendered_features = msplat.alpha_blending(
            uv, conic, opacity, render_features,
            gaussian_ids_sorted, tile_range, self.bg_color, width, height, ndc
        )
        rendered_features_split = Render_Features.split(rendered_features)

        normals = rendered_features_split["normal"]

        # Renormalize each pixel's normal (F.normalize clamps the norm, avoiding NaN
        # where no Gaussian contributes), then convert normals from [-1, 1] to [0, 1]
        normals_im = torch.nn.functional.normalize(normals, dim=0)
        normals_im = (normals_im + 1) / 2

        rendered_features_split["normal"] = normals_im

        return {"rendered_features_split": rendered_features_split,
                "uv_points": ndc,
                "visibility": radius > 0,
                "radii": radius
                }

Adding Relevant Logging Using Hook Functions

Finally, to visualize the predicted surface normals during validation, we override the after_val_iter method of LogHook. After each validation iteration, it now also writes the rendered and ground-truth normal images to the logger.
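get_metric_dict only adds the normal entries when they are available. A logging hook is therefore more robust if it skips missing keys rather than indexing them unconditionally. Depth is handled separately in the listing because it first goes through visualize_depth.

The sketch below shows the skip-missing-keys pattern as a hypothetical helper that builds the tag-to-image mapping the hook writes, using the same tag scheme. It returns the mapping instead of calling the writer, so it can be checked without a TensorBoard writer.

```python
import os


def collect_val_images(metric_dict, keys=("images", "gt_images", "normal", "normal_gt")):
    """Hypothetical helper: map TensorBoard-style tags to the images present in metric_dict."""
    suffix = {"images": "render", "gt_images": "ground_truth",
              "normal": "normal", "normal_gt": "normal_gt"}
    image_name = os.path.basename(metric_dict["rgb_file_name"])
    return {f"test_view_{image_name}/{suffix[k]}": metric_dict[k]
            for k in keys if k in metric_dict}


# Strings stand in for image tensors; "normal_gt" is deliberately missing.
metrics = {"rgb_file_name": "/data/truck/images/000001.jpg",
           "images": "rendered rgb", "gt_images": "gt rgb", "normal": "rendered normal"}
for tag in collect_val_images(metrics):
    print(tag)
# test_view_000001.jpg/render
# test_view_000001.jpg/ground_truth
# test_view_000001.jpg/normal
```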

The modified part is the logging of the rendered and ground-truth normals.
# Register the hook so that the configuration can refer to it by name.
# (Registry name assumed; check the example directory for the exact import.)
@HOOK_REGISTRY.register()
class NormalLogHook(LogHook):
    def after_val_iter(self, trainner) -> None:
        self.progress_bar.update("validation", step=1)
        for key, value in trainner.metric_dict.items():
            if key in self.losses_test:
                self.losses_test[key] += value

        image_name = os.path.basename(trainner.metric_dict['rgb_file_name'])
        iteration = trainner.global_step
        if 'depth' in trainner.metric_dict:
            visual_depth = visualize_depth(trainner.metric_dict['depth'].squeeze(), tensorboard=True)
            trainner.writer.write_image(
                f"test_view_{image_name}/depth",
                visual_depth, step=iteration)
        trainner.writer.write_image(
            f"test_view_{image_name}/render",
            trainner.metric_dict['images'].squeeze(),
            step=iteration)

        trainner.writer.write_image(
            f"test_view_{image_name}/ground_truth",
            trainner.metric_dict['gt_images'].squeeze(),
            step=iteration)

        # Log rendered and prior normals only when get_metric_dict provided them
        if 'normal' in trainner.metric_dict:
            trainner.writer.write_image(
                f"test_view_{image_name}/normal",
                trainner.metric_dict['normal'].squeeze(),
                step=iteration)
        if 'normal_gt' in trainner.metric_dict:
            trainner.writer.write_image(
                f"test_view_{image_name}/normal_gt",
                trainner.metric_dict['normal_gt'].squeeze(),
                step=iteration)

Lastly, we need to modify our configuration to integrate the updated model, renderer, dataset, and hook functions into the Pointrix training pipeline.


If you add learnable parameters (such as convolutional networks or MLPs) on top of BaseModel, also add them to the params section of the optimizer configuration. Otherwise they will not be updated during training.
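One way to catch a forgotten entry is to compare the model's parameter names with the keys under the optimizer's params. The sketch below does this with plain strings. It assumes that a key selects every parameter whose dotted name equals it or starts with it, as the point_cloud.* keys in the configuration suggest. That matching rule and the normal_mlp module name are hypothetical.

```python
def uncovered_parameters(param_names, optimizer_keys):
    """Return parameter names not matched by any optimizer key (prefix match assumed)."""
    def covered(name):
        return any(name == key or name.startswith(key + ".") for key in optimizer_keys)
    return [name for name in param_names if not covered(name)]


params = ["point_cloud.position", "point_cloud.scaling",
          "normal_mlp.0.weight", "normal_mlp.0.bias"]   # normal_mlp: hypothetical extra MLP
keys = ["point_cloud.position", "point_cloud.scaling"]
print(uncovered_parameters(params, keys))
# ['normal_mlp.0.weight', 'normal_mlp.0.bias']
print(uncovered_parameters(params, keys + ["normal_mlp"]))
# []
```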

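The modified entries are the model name (NormalModel), the renderer name (MsplatNormalRender), the dataset (ColmapDepthNormalDataset and its observed_data_dirs_dict), and the LogHook name (NormalLogHook).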
name: "garden"

trainer:
  output_path: "/home/linzhuo/clz/log/garden"
  max_steps: 30000
  val_interval: 5000
  training: True

  model:
    name: NormalModel
    lambda_ssim: 0.2
    point_cloud:
      point_cloud_type: "GaussianPointCloud"
      max_sh_degree: 3
      trainable: true
      unwarp_prefix: "point_cloud"
      initializer:
        init_type: 'colmap'
        feat_dim: 3
    camera_model:
      enable_training: False
    renderer:
      name: "MsplatNormalRender"
      max_sh_degree: ${trainer.model.point_cloud.max_sh_degree}

  controller:
    normalize_grad: False

  optimizer:
    optimizer_1:
      type: BaseOptimizer
      name: Adam
      args:
        eps: 1e-15
      extra_cfg:
        backward: False
      params:
        point_cloud.position:
          lr: 0.00016
        point_cloud.features:
          lr: 0.0025
        point_cloud.features_rest:
          lr: 0.000125 # features/20
        point_cloud.scaling:
          lr: 0.005
        point_cloud.rotation:
          lr: 0.001
        point_cloud.opacity:
          lr: 0.05
      # camera_params:
      #   lr: 1e-3

  scheduler:
    name: "ExponLRScheduler"
    params:
      point_cloud.position:
        init:  0.00016
        final: 0.0000016
        max_steps: ${trainer.max_steps}
  datapipeline:
    data_set: "ColmapDepthNormalDataset"
    shuffle: True
    batch_size: 1
    num_workers: 0
    dataset:
      data_path: "/home/linzhuo/gj/data/garden"
      cached_observed_data: ${}
      scale: 0.25
      white_bg: False
      observed_data_dirs_dict: {"image": "images", "normal": "normals"}

  writer:
    writer_type: "TensorboardWriter"

  hooks:
    LogHook:
      name: NormalLogHook
    CheckPointHook:
      name: CheckPointHook

  exporter:
    exporter_a:
      type: MetricExporter
    exporter_b:
      type: TSDFFusion
      extra_cfg:
        voxel_size: 0.02
        sdf_truc: 0.08
        total_points: 8_000_000
    exporter_c:
      type: VideoExporter

With the modifications above, the Gaussian point cloud is supervised by the surface normal priors in addition to the images.

We run the following command:

python --config colmap.yaml trainer.datapipeline.dataset.data_path=your_data_path trainer.datapipeline.dataset.scale=0.5 trainer.output_path=your_log_path
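Each key=value argument after the config file overrides one entry of the YAML, with the dots walking down the nesting (trainer, then datapipeline, then dataset, then scale). The pure-Python sketch below illustrates that dot-list semantics on a nested dict. It is not Pointrix's config loader, which we only assume behaves similarly, and apply_override is a hypothetical name.

Values are parsed with ast.literal_eval when possible, so the dictionary-valued observed_data_dirs_dict override from the data section also works.

```python
import ast


def apply_override(cfg, override):
    """Apply one 'a.b.c=value' override to a nested dict in place (sketch)."""
    dotted_key, raw_value = override.split("=", 1)
    try:
        value = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError):
        value = raw_value            # fall back to a plain string, e.g. a path
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return cfg


cfg = {"trainer": {"output_path": "/tmp/log",
                   "datapipeline": {"dataset": {"scale": 0.25}}}}
for arg in ["trainer.datapipeline.dataset.scale=0.5",
            "trainer.output_path=your_log_path",
            'trainer.datapipeline.dataset.observed_data_dirs_dict='
            '{"image": "images", "normal": "normals"}']:
    apply_override(cfg, arg)
print(cfg["trainer"]["datapipeline"]["dataset"])
# {'scale': 0.5, 'observed_data_dirs_dict': {'image': 'images', 'normal': 'normals'}}
```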

The results are shown below: