Adding Supervision for 3DGS#
We use surface normals as an example to illustrate how to add supervision from surface normal priors to the point cloud rendering model.
All code can be found in the example/supervise directory.
The data download link for this tutorial is provided below:
https://pan.baidu.com/s/1NlhxylY7q3SmVf9j29we3Q?pwd=f95m.
We employ the DSINE model to generate normal maps for the truck scene in the Tanks and Temples dataset.
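If you are preparing such priors for your own data, the normal maps should live in a folder alongside the images, one file per image in the same sorted order, stored as 8-bit PNGs with the unit normals mapped from [-1, 1] to [0, 255] (this matches how the dataset and renderer code below decode them). The following is a minimal sketch of that export step, not Pointrix code; run_dsine stands in for whatever DSINE inference wrapper you use, and the paths are placeholders:

import os
import numpy as np
from PIL import Image

def save_normal_as_png(normal, out_path):
    # Map unit normals from [-1, 1] to [0, 255] and save as an 8-bit PNG.
    normal_uint8 = ((normal + 1.0) * 0.5 * 255.0).clip(0, 255).astype(np.uint8)
    Image.fromarray(normal_uint8).save(out_path)

data_root = "path/to/truck"  # placeholder
os.makedirs(os.path.join(data_root, "normals"), exist_ok=True)
for name in sorted(os.listdir(os.path.join(data_root, "images"))):
    # run_dsine is a hypothetical helper returning an (H, W, 3) array in [-1, 1]
    normal = run_dsine(os.path.join(data_root, "images", name))
    out_name = os.path.splitext(name)[0] + ".png"
    save_normal_as_png(normal, os.path.join(data_root, "normals", out_name))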
Data Section Modifications#
Since the Tanks and Temples dataset is in COLMAP format, we modify the Colmap Dataset provided by Pointrix by inheriting from it. To read the normal prior outputs of the DSINE model, we first need to modify the configuration:
trainer.datapipeline.dataset.observed_data_dirs_dict={"image": "images", "normal": "normals"}
where normals (the dictionary value) is the name of the folder in which the normal maps are stored, and normal (the key) is the variable name under which this data is exposed.
Pointrix automatically reads the data from the given data path and folder name, dispatching on the file suffix. The relevant reading code in Pointrix is shown below:
def load_observed_data(self, split):
    """
    The function for loading the observed data.

    Parameters
    ----------
    split: str
        The split of the dataset.

    Returns
    -------
    observed_data: List[Dict[str, Any]]
        The observed data for the dataset.
    """
    observed_data = []
    for k, v in self.observed_data_dirs_dict.items():
        observed_data_path = self.data_root / Path(v)
        if not os.path.exists(observed_data_path):
            Logger.error(f"observed_data path {observed_data_path} does not exist.")
        observed_data_file_names = sorted(os.listdir(observed_data_path))
        index = self.train_index if split == "train" else self.val_index
        observed_data_file_names_split = [observed_data_file_names[i] for i in index]
        cached_progress = ProgressLogger(description='Loading cached observed_data', suffix='iters/s')
        cached_progress.add_task(f'cache_{k}', f'Loading {split} cached {k}', len(observed_data_file_names_split))
        with cached_progress.progress as progress:
            for idx, file in enumerate(observed_data_file_names_split):
                if len(observed_data) <= idx:
                    observed_data.append({})
                if file.endswith('.npy'):
                    observed_data[idx].update({k: np.load(observed_data_path / Path(file))})
                elif file.endswith(('.png', '.jpg', '.JPG')):
                    observed_data[idx].update({k: Image.open(observed_data_path / Path(file))})
                else:
                    print(f"File format {file} is not supported.")
                cached_progress.update(f'cache_{k}', step=1)
    return observed_data
After utilizing Pointrix’s automatic data reading feature, we need to process the loaded normal data. We override the Colmap Dataset and modify the _transform_observed_data function to handle the observed data (surface normals). The specific code is located in examples/gaussian_splatting_supervise/dataset.py.
# Registry
@DATA_SET_REGISTRY.register()
class ColmapDepthNormalDataset(ColmapDataset):
    def _transform_observed_data(self, observed_data, split):
        cached_progress = ProgressLogger(description='transforming cached observed_data', suffix='iters/s')
        cached_progress.add_task('Transforming', f'Transforming {split} cached observed_data', len(observed_data))
        with cached_progress.progress as progress:
            for i in range(len(observed_data)):
                # Transform the image
                image = observed_data[i]['image']
                w, h = image.size
                image = image.resize((int(w * self.scale), int(h * self.scale)))
                image = np.array(image) / 255.
                if image.shape[2] == 4:
                    image = image[:, :, :3] * image[:, :, 3:4] + self.bg * (1 - image[:, :, 3:4])
                observed_data[i]['image'] = torch.from_numpy(np.array(image)).permute(2, 0, 1).float().clamp(0.0, 1.0)
                cached_progress.update('Transforming', step=1)

                # Transform the normal map: [0, 255] uint8 -> [0, 1] float, (3, H, W)
                observed_data[i]['normal'] = \
                    (torch.from_numpy(np.array(observed_data[i]['normal'])) \
                     / 255.0).float().permute(2, 0, 1)
        return observed_data
After storing the processed normal data in observed_data, the data section modifications are complete, and Pointrix’s Datapipeline automatically supplies the corresponding data during training.
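As a quick sanity check outside the pipeline, you can apply the same transform to a single normal map by hand; per the code above, the result should be a (3, H, W) float tensor with values in [0, 1]. The file path below is a placeholder:

import numpy as np
import torch
from PIL import Image

normal = Image.open("normals/000001.png")  # placeholder path
normal = (torch.from_numpy(np.array(normal)) / 255.0).float().permute(2, 0, 1)
assert normal.shape[0] == 3, "expected a 3-channel normal map"
assert 0.0 <= normal.min() and normal.max() <= 1.0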
Model Section Modifications#
First, we import the basic models from Pointrix so that we can inherit from, register, and modify them.
from pointrix.model.base_model import BaseModel, MODEL_REGISTRY
The basic model includes a Gaussian point cloud model and a camera model. Since we aim to obtain surface normals from the point cloud model, we modify the Gaussian point cloud model so that its forward function also outputs surface normals. Additionally, we compute the normal loss in the get_loss_dict function so that normal supervision enters the backward pass; with the weights used below, the total loss is (1 - lambda_ssim) * L1 + lambda_ssim * (1 - SSIM) + 0.1 * L1_normal. Finally, in the get_metric_dict function we collect the rendered surface normal images in preparation for visualizing the predicted normals.
@MODEL_REGISTRY.register()
class NormalModel(BaseModel):
    def forward(self, batch=None, training=True) -> dict:

        frame_idx_list = [batch[i]["frame_idx"] for i in range(len(batch))]
        extrinsic_matrix = self.training_camera_model.extrinsic_matrices(frame_idx_list) \
            if training else self.validation_camera_model.extrinsic_matrices(frame_idx_list)
        intrinsic_params = self.training_camera_model.intrinsic_params(frame_idx_list) \
            if training else self.validation_camera_model.intrinsic_params(frame_idx_list)
        camera_center = self.training_camera_model.camera_centers(frame_idx_list) \
            if training else self.validation_camera_model.camera_centers(frame_idx_list)

        # Obtain the surface normals
        point_normal = self.get_normals
        projected_normal = self.process_normals(point_normal, camera_center, extrinsic_matrix)

        render_dict = {
            "extrinsic_matrix": extrinsic_matrix,
            "intrinsic_params": intrinsic_params,
            "camera_center": camera_center,
            "position": self.point_cloud.position,
            "opacity": self.point_cloud.get_opacity,
            "scaling": self.point_cloud.get_scaling,
            "rotation": self.point_cloud.get_rotation,
            "shs": self.point_cloud.get_shs,
            "normals": projected_normal
        }
        return render_dict

    # Get the normal from the axis with the minimum scale
    @property
    def get_normals(self):
        scaling = self.point_cloud.scaling.clone()
        normal_arg_min = torch.argmin(scaling, dim=-1)
        normal_each = F.one_hot(normal_arg_min, num_classes=3)
        normal_each = normal_each.float()

        rotation_matrix = unitquat_to_rotmat(self.point_cloud.get_rotation)
        normal_each = torch.bmm(
            rotation_matrix, normal_each.unsqueeze(-1)).squeeze(-1)

        normal_each = F.normalize(normal_each, dim=-1)
        return normal_each

    # Project the surface normals into camera coordinates
    def process_normals(self, normals, camera_center, E):
        camera_center = camera_center.squeeze(0)
        E = E.squeeze(0)
        xyz = self.point_cloud.position
        direction = (camera_center.repeat(
            xyz.shape[0], 1).cuda().detach() - xyz.cuda().detach())
        direction = direction / direction.norm(dim=1, keepdim=True)
        # Flip normals that face away from the camera
        dot_for_judge = torch.sum(direction * normals, dim=-1)
        normals[dot_for_judge < 0] = -normals[dot_for_judge < 0]
        w2c = E[:3, :3].cuda().float()
        normals_image = normals @ w2c.T
        return normals_image

    def get_loss_dict(self, render_results, batch) -> dict:
        loss = 0.0
        gt_images = torch.stack(
            [batch[i]["image"] for i in range(len(batch))],
            dim=0
        )
        normal_images = torch.stack(
            [batch[i]["normal"] for i in range(len(batch))],
            dim=0
        )
        L1_loss = l1_loss(render_results['rgb'], gt_images)
        ssim_loss = 1.0 - ssim(render_results['rgb'], gt_images)
        loss += (1.0 - self.cfg.lambda_ssim) * L1_loss
        loss += self.cfg.lambda_ssim * ssim_loss
        # Normal loss
        normal_loss = 0.1 * l1_loss(render_results['normal'], normal_images)
        loss += normal_loss
        loss_dict = {"loss": loss,
                     "L1_loss": L1_loss,
                     "ssim_loss": ssim_loss,
                     "normal_loss": normal_loss}
        return loss_dict

    @torch.no_grad()
    def get_metric_dict(self, render_results, batch) -> dict:
        gt_images = torch.clamp(torch.stack(
            [batch[i]["image"].to(self.device) for i in range(len(batch))],
            dim=0), 0.0, 1.0)
        rgb = torch.clamp(render_results['rgb'], 0.0, 1.0)
        L1_loss = l1_loss(rgb, gt_images).mean().double()
        psnr_test = psnr(rgb.squeeze(), gt_images.squeeze()).mean().double()
        ssims_test = ssim(rgb, gt_images, size_average=True).mean().item()
        lpips_vgg_test = self.lpips_func(rgb, gt_images).mean().item()
        metric_dict = {"L1_loss": L1_loss,
                       "psnr": psnr_test,
                       "ssims": ssims_test,
                       "lpips": lpips_vgg_test,
                       "gt_images": gt_images,
                       "images": rgb,
                       "rgb_file_name": batch[0]["camera"].rgb_file_name}

        if 'depth' in render_results:
            depth = render_results['depth']
            metric_dict['depth'] = depth

        if 'normal' in render_results:
            normal = render_results['normal']
            metric_dict['normal'] = normal

        if 'normal' in batch[0]:
            normal = batch[0]['normal']
            metric_dict['normal_gt'] = normal

        return metric_dict
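To make the geometry in get_normals above concrete: each Gaussian is treated as a flattened ellipsoid, the local axis with the smallest scale is taken as its normal, and that axis is rotated into world space. Below is a self-contained sketch of the same computation, with rotation matrices passed in directly instead of going through unitquat_to_rotmat:

import torch
import torch.nn.functional as F

def normals_from_min_scale(scaling, rotation_matrix):
    """scaling: (N, 3) per-axis scales; rotation_matrix: (N, 3, 3)."""
    # Pick the local axis with the smallest extent as the normal direction.
    normal_local = F.one_hot(torch.argmin(scaling, dim=-1), num_classes=3).float()
    # Rotate the local axis into world space and re-normalize.
    normal_world = torch.bmm(rotation_matrix, normal_local.unsqueeze(-1)).squeeze(-1)
    return F.normalize(normal_world, dim=-1)

# A disc-like Gaussian squashed along z, with identity rotation:
scaling = torch.tensor([[0.5, 0.5, 0.01]])
R = torch.eye(3).unsqueeze(0)
print(normals_from_min_scale(scaling, R))  # tensor([[0., 0., 1.]])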
Rendering Section Modifications#
Thanks to Msplat’s multi-target rendering capability, we only need to modify render_iter so that the normal features output by the point cloud model are fed into the renderer. As before, the newly modified renderer must be registered through the registry so that we can reference it in the configuration. The normals parameter in the render_iter signature corresponds to the "normals" entry returned by the model’s forward function; Pointrix interfaces the two automatically.
@RENDERER_REGISTRY.register()
class MsplatNormalRender(MsplatRender):
    """
    A class for rendering point clouds using msplat.

    Parameters
    ----------
    cfg : dict
        The configuration dictionary.
    white_bg : bool
        Whether the background is white or not.
    device : str
        The device to use.
    update_sh_iter : int, optional
        The iteration to update the spherical harmonics degree, by default 1000.
    """

    def render_iter(self,
                    height,
                    width,
                    extrinsic_matrix,
                    intrinsic_params,
                    camera_center,
                    position,
                    opacity,
                    scaling,
                    rotation,
                    shs,
                    normals,
                    **kwargs) -> dict:

        direction = (position -
                     camera_center.repeat(position.shape[0], 1))
        direction = direction / direction.norm(dim=1, keepdim=True)
        rgb = msplat.compute_sh(shs.permute(0, 2, 1), direction)
        extrinsic_matrix = extrinsic_matrix[:3, :]

        (uv, depth) = msplat.project_point(
            position,
            intrinsic_params,
            extrinsic_matrix,
            width, height)

        visible = depth != 0

        # compute cov3d
        cov3d = msplat.compute_cov3d(scaling, rotation, visible)

        # ewa project
        (conic, radius, tiles_touched) = msplat.ewa_project(
            position,
            cov3d,
            intrinsic_params,
            extrinsic_matrix,
            uv,
            width,
            height,
            visible
        )

        # sort
        (gaussian_ids_sorted, tile_range) = msplat.sort_gaussian(
            uv, depth, width, height, radius, tiles_touched
        )

        # pack rgb, depth and normal into a single feature tensor
        render_features = RenderFeatures(rgb=rgb, depth=depth, normal=normals)
        features = render_features.combine()

        ndc = torch.zeros_like(uv, requires_grad=True)
        try:
            ndc.retain_grad()
        except Exception:
            raise ValueError("ndc does not have grad")

        # alpha blending
        rendered_features = msplat.alpha_blending(
            uv, conic, opacity, features,
            gaussian_ids_sorted, tile_range, self.bg_color, width, height, ndc
        )
        rendered_features_split = render_features.split(rendered_features)

        normals = rendered_features_split["normal"]

        # normalize the blended normals and map them from [-1, 1] to [0, 1]
        normals_im = normals / normals.norm(dim=0, keepdim=True)
        normals_im = (normals_im + 1) / 2

        rendered_features_split["normal"] = normals_im

        return {"rendered_features_split": rendered_features_split,
                "uv_points": ndc,
                "visibility": radius > 0,
                "radii": radius
                }
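The RenderFeatures helper used above packs several named per-point features into one tensor so msplat can alpha-blend them in a single pass, then splits the blended image back by name. Its actual implementation lives in Pointrix; the class below is only an illustrative sketch of the idea, not the real API:

import torch

class RenderFeaturesSketch:
    """Pack named per-point features into one tensor for single-pass blending."""

    def __init__(self, **features):
        # features: name -> (N, C_i) tensor; remember each channel width
        self.names = list(features.keys())
        self.dims = [features[n].shape[-1] for n in self.names]
        self.features = features

    def combine(self):
        # Concatenate along the channel dimension: (N, sum(C_i))
        return torch.cat([self.features[n] for n in self.names], dim=-1)

    def split(self, blended):
        # blended: (sum(C_i), H, W) image from alpha blending; split by name
        chunks = torch.split(blended, self.dims, dim=0)
        return dict(zip(self.names, chunks))

# Example: rgb (N, 3), depth (N, 1), normal (N, 3) -> combined (N, 7)
feats = RenderFeaturesSketch(rgb=torch.rand(5, 3),
                             depth=torch.rand(5, 1),
                             normal=torch.rand(5, 3))
combined = feats.combine()        # (5, 7)
rendered = torch.rand(7, 32, 32)  # stand-in for the blended image
split = feats.split(rendered)     # {'rgb': (3,32,32), 'depth': (1,32,32), 'normal': (3,32,32)}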
Adding Relevant Logging Using Hook Functions#
Finally, we want to visualize the predicted surface normal images during each validation pass. To do so, we modify the corresponding hook function so that the surface normals are written to the logger after each validation iteration.
@HOOK_REGISTRY.register()
class NormalLogHook(LogHook):
    def after_val_iter(self, trainner) -> None:
        self.progress_bar.update("validation", step=1)
        for key, value in trainner.metric_dict.items():
            if key in self.losses_test:
                self.losses_test[key] += value

        image_name = os.path.basename(trainner.metric_dict['rgb_file_name'])
        iteration = trainner.global_step
        if 'depth' in trainner.metric_dict:
            visual_depth = visualize_depth(trainner.metric_dict['depth'].squeeze(), tensorboard=True)
            trainner.writer.write_image(
                "test" + f"_view_{image_name}/depth",
                visual_depth, step=iteration)
        trainner.writer.write_image(
            "test" + f"_view_{image_name}/render",
            trainner.metric_dict['images'].squeeze(),
            step=iteration)

        trainner.writer.write_image(
            "test" + f"_view_{image_name}/ground_truth",
            trainner.metric_dict['gt_images'].squeeze(),
            step=iteration)

        trainner.writer.write_image(
            "test" + f"_view_{image_name}/normal",
            trainner.metric_dict['normal'].squeeze(),
            step=iteration)
        trainner.writer.write_image(
            "test" + f"_view_{image_name}/normal_gt",
            trainner.metric_dict['normal_gt'].squeeze(),
            step=iteration)
Lastly, we need to modify our configuration to integrate the updated model, renderer, dataset, and hook functions into the Pointrix training pipeline.
Warning
If you have added learnable parameters (such as convolutional networks or MLPs) on top of the BaseModel, please include them in the optimizer configuration so that the new parameters are optimized.
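In plain PyTorch terms, this warning amounts to the standard rule that any tensor the optimizer should update must appear in one of its param groups; the following minimal sketch is ordinary PyTorch, not the Pointrix optimizer config:

import torch

position = torch.nn.Parameter(torch.rand(100, 3))
mlp = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.ReLU(), torch.nn.Linear(64, 3))

optimizer = torch.optim.Adam([
    {"params": [position], "lr": 0.00016},
    # Without this second group, the MLP would silently never be updated.
    {"params": mlp.parameters(), "lr": 0.001},
], eps=1e-15)

With that in mind, the full configuration used in this tutorial is: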
name: "garden"

trainer:
  output_path: "/home/linzhuo/clz/log/garden"
  max_steps: 30000
  val_interval: 5000
  training: True

  model:
    name: NormalModel
    lambda_ssim: 0.2
    point_cloud:
      point_cloud_type: "GaussianPointCloud"
      max_sh_degree: 3
      trainable: true
      unwarp_prefix: "point_cloud"
      initializer:
        init_type: 'colmap'
        feat_dim: 3
    camera_model:
      enable_training: False
  renderer:
    name: "MsplatNormalRender"
    max_sh_degree: ${trainer.model.point_cloud.max_sh_degree}

  controller:
    normalize_grad: False

  optimizer:
    optimizer_1:
      type: BaseOptimizer
      name: Adam
      args:
        eps: 1e-15
      extra_cfg:
        backward: False
      params:
        point_cloud.position:
          lr: 0.00016
        point_cloud.features:
          lr: 0.0025
        point_cloud.features_rest:
          lr: 0.000125 # features/20
        point_cloud.scaling:
          lr: 0.005
        point_cloud.rotation:
          lr: 0.001
        point_cloud.opacity:
          lr: 0.05
        # camera_params:
        #   lr: 1e-3

  scheduler:
    name: "ExponLRScheduler"
    params:
      point_cloud.position:
        init: 0.00016
        final: 0.0000016
        max_steps: ${trainer.max_steps}
  datapipeline:
    data_set: "ColmapDepthNormalDataset"
    shuffle: True
    batch_size: 1
    num_workers: 0
    dataset:
      data_path: "/home/linzhuo/gj/data/garden"
      cached_observed_data: ${trainer.training}
      scale: 0.25
      white_bg: False
      observed_data_dirs_dict: {"image": "images", "normal": "normals"}

  writer:
    writer_type: "TensorboardWriter"

  hooks:
    LogHook:
      name: NormalLogHook
    CheckPointHook:
      name: CheckPointHook

  exporter:
    exporter_a:
      type: MetricExporter
    exporter_b:
      type: TSDFFusion
      extra_cfg:
        voxel_size: 0.02
        sdf_truc: 0.08
        total_points: 8_000_000
    exporter_c:
      type: VideoExporter
After the above modifications, we have successfully implemented supervision of Gaussian point cloud surface normals. All code can be found in the example/supervise directory.
We run the following command:
python launch.py --config colmap.yaml trainer.datapipeline.dataset.data_path=your_data_path trainer.datapipeline.dataset.scale=0.5 trainer.output_path=your_log_path
The results are shown below: