Hook Functions#

The hook module serves as a utility class that performs operations on the trainer at fixed points in time. CheckPointHook and LogHook are two example hooks inheriting from the base class; they are used to save checkpoints and to log training and validation losses during the training loop, respectively.

LogHook relies on a Console object named Logger to print log messages in the terminal, and on a ProgressLogger object to visualize training progress. At the specified points in time, it calls the trainer's writer to record log information.

The writer bound to the trainer should inherit from the base class Writer and implement the three abstract functions write_scalar(), write_image(), and write_config(). You can create more types of Writer. For example, TensorboardWriter wraps torch.utils.tensorboard.SummaryWriter by overriding these three interface functions. You can also conveniently specify which writer type to use in the .yaml configuration file.
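
The Writer abstraction described above can be sketched as follows. The three abstract method names come from the text; the `InMemoryWriter` subclass and its method signatures are hypothetical, shown only to illustrate how a custom writer plugs into the same interface as TensorboardWriter:

```python
from abc import ABC, abstractmethod

class Writer(ABC):
    """Sketch of the writer base class: three abstract interfaces."""

    @abstractmethod
    def write_scalar(self, name, value, step):
        ...

    @abstractmethod
    def write_image(self, name, image, step):
        ...

    @abstractmethod
    def write_config(self, config):
        ...

class InMemoryWriter(Writer):
    """A toy writer that stores records in a list instead of TensorBoard."""

    def __init__(self):
        self.records = []

    def write_scalar(self, name, value, step):
        self.records.append(("scalar", name, value, step))

    def write_image(self, name, image, step):
        # the image payload is omitted here; only metadata is recorded
        self.records.append(("image", name, step))

    def write_config(self, config):
        self.records.append(("config", config))

w = InMemoryWriter()
w.write_scalar("loss", 0.5, 1)
print(w.records[0])  # → ('scalar', 'loss', 0.5, 1)
```

Because all writers share the same interface, the trainer can call `writer.write_scalar(...)` without knowing which concrete writer the configuration selected.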

Where Hooks Are Used#

In the trainer, hook functions are called at specific points:

We highlight the hook part.#
class DefaultTrainer:
    """
    The default trainer class for training and testing the model.

    Parameters
    ----------
    cfg : Config
        The configuration object.
    exp_dir : Path
        The experiment directory.
    device : str, optional
        The device to use, by default "cuda".
    """
    def __init__(self, cfg: Config, exp_dir: Path, device: str = "cuda") -> None:
        super().__init__()
        self.exp_dir = exp_dir
        self.device = device

        self.start_steps = 1
        self.global_step = 0

        # build config
        self.cfg = parse_structured(self.Config, cfg)
        # build hooks
        self.hooks = parse_hooks(self.cfg.hooks)
        self.call_hook("before_run")
        # build data pipeline

        # some code is omitted

    @torch.no_grad()
    def validation(self):
        self.val_dataset_size = len(self.datapipeline.validation_dataset)
        for i in range(0, self.val_dataset_size):
            self.call_hook("before_val_iter")
            batch = self.datapipeline.next_val(i)
            render_dict = self.model(batch)
            render_results = self.renderer.render_batch(render_dict, batch)
            self.metric_dict = self.model.get_metric_dict(render_results, batch)
            self.call_hook("after_val_iter")

    def train_loop(self) -> None:
        """
        The training loop for the model.
        """
        loop_range = range(self.start_steps, self.cfg.max_steps + 1)
        self.global_step = self.start_steps
        self.call_hook("before_train")
        for iteration in loop_range:
            self.call_hook("before_train_iter")
            batch = self.datapipeline.next_train(self.global_step)
            self.renderer.update_sh_degree(iteration)
            self.schedulers.step(self.global_step, self.optimizer)
            self.train_step(batch)
            self.optimizer.update_model(**self.optimizer_dict)
            self.call_hook("after_train_iter")
            self.global_step += 1
            if iteration % self.cfg.val_interval == 0 or iteration == self.cfg.max_steps:
                self.call_hook("before_val")
                self.validation()
                self.call_hook("after_val")
        self.call_hook("after_train")

    def call_hook(self, fn_name: str, **kwargs) -> None:
        """
        Call the hook method.

        Parameters
        ----------
        fn_name : str
            The hook method name.
        kwargs : dict
            The keyword arguments.
        """
        for hook in self.hooks:
            # support adding additional custom hook methods
            if hasattr(hook, fn_name):
                try:
                    getattr(hook, fn_name)(self, **kwargs)
                except TypeError as e:
                    raise TypeError(f'{e} in {hook}') from None
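
The dispatch pattern behind `call_hook()` can be isolated into a minimal, self-contained sketch: each hook defines only the trigger methods it cares about, and the trainer looks them up by name with `getattr`. The class names below (`MiniTrainer`, `PrintHook`, `CounterHook`) are hypothetical stand-ins for illustration:

```python
class PrintHook:
    # this hook only implements the "before_train" trigger
    def before_train(self, trainer):
        trainer.events.append("before_train")

class CounterHook:
    # this hook only implements the "after_train_iter" trigger
    def after_train_iter(self, trainer):
        trainer.events.append("after_train_iter")

class MiniTrainer:
    def __init__(self, hooks):
        self.hooks = hooks
        self.events = []

    def call_hook(self, fn_name, **kwargs):
        # same pattern as DefaultTrainer.call_hook: hooks silently
        # skip any trigger they do not define
        for hook in self.hooks:
            if hasattr(hook, fn_name):
                getattr(hook, fn_name)(self, **kwargs)

t = MiniTrainer([PrintHook(), CounterHook()])
t.call_hook("before_train")
t.call_hook("after_train_iter")
print(t.events)  # → ['before_train', 'after_train_iter']
```

Because missing trigger methods are simply skipped, adding a new trigger point to the trainer never forces existing hooks to change.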

More Types of Hooks#

You can modify the trainer's behavior by defining hook functions, for example, if you want to log something after each training iteration:

Note

The trainer is fully accessible inside hook functions. A log hook and a checkpoint hook are provided by default.

@HOOK_REGISTRY.register()
class LogHook(Hook):
    """
    A hook to log the training and validation losses.
    """

    def __init__(self):
        self.ema_loss_for_log = 0.
        self.bar_info = {}

        self.losses_test = {"L1_loss": 0., "psnr": 0., "ssims": 0., "lpips": 0.}

    def after_train_iter(self, trainer) -> None:
        """
        Some operations after the training iteration ends.

        Parameters
        ----------
        trainer : Trainer
            The trainer object.
        """
        pos_lr = None
        for param_group in trainer.optimizer.param_groups:
            if param_group['name'] == "point_cloud.position":
                pos_lr = param_group['lr']
                break

        log_dict = {
            "num_pt": len(trainer.model.point_cloud),
            "pos_lr": pos_lr
        }
        log_dict.update(trainer.loss_dict)

        for key, value in log_dict.items():
            if key == 'loss':
                # exponential moving average of the loss for the progress bar
                self.ema_loss_for_log = 0.4 * value.item() + 0.6 * self.ema_loss_for_log
                self.bar_info.update({key: f"{self.ema_loss_for_log:.7f}"})

            if trainer.logger and key != "optimizer_params":
                trainer.logger.write_scalar(key, value, trainer.global_step)
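
A new hook follows the same shape: register it, inherit from Hook, and implement only the triggers you need. The sketch below is a hypothetical TimerHook that measures per-iteration wall-clock time; HOOK_REGISTRY and Hook normally come from the framework, so minimal stubs are included here just to keep the example self-contained:

```python
import time

class Hook:
    """Stub stand-in for the framework's Hook base class."""
    pass

class Registry:
    """Stub stand-in for the framework's HOOK_REGISTRY."""
    def __init__(self):
        self._hooks = {}

    def register(self):
        def wrap(cls):
            self._hooks[cls.__name__] = cls
            return cls
        return wrap

HOOK_REGISTRY = Registry()

@HOOK_REGISTRY.register()
class TimerHook(Hook):
    """A hypothetical hook that times each training iteration."""

    def before_train_iter(self, trainer):
        self._t0 = time.perf_counter()

    def after_train_iter(self, trainer):
        # store the elapsed time on the trainer so other hooks
        # (e.g. a log hook) can pick it up
        trainer.iter_time = time.perf_counter() - self._t0

# usage with a dummy trainer object
trainer = type("DummyTrainer", (), {})()
hook = TimerHook()
hook.before_train_iter(trainer)
hook.after_train_iter(trainer)
```

Since `call_hook()` invokes triggers by name, TimerHook needs no changes to the trainer itself; it only has to be listed in the hooks section of the configuration.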

You can refer to the tutorial section or the method section for more examples of hook functions.