钩子函数
钩子模块可以作为一个工具类,在某些固定的时间点对训练器进行一些操作。CheckPointHook 和 LogHook 是两个继承自基类的示例钩子,它们分别用于在训练循环中保存检查点和记录训练和验证损失。
LogHook 依赖于一个名为 Logger 的 Console 对象,在终端打印一些日志信息,以及一个 ProgressLogger 对象来可视化训练进度。在指定的时间点,它将调用训练器的 writer 来记录日志信息。
对于与训练器绑定的 writer,它应该继承自基类 Writer,并实现 write_scalar()、write_image() 和 write_config() 这三个抽象函数。您可以创建更多类型的 Writer。例如,TensorboardWriter 通过重写这三个接口函数封装了 torch.utils.tensorboard.SummaryWriter。您还可以在 .yaml 配置文件中方便地指定要使用的 writer 类型。
使用场合
在训练器中,钩子函数将在特定位置被调用:
class DefaultTrainer:
    """
    Default trainer for training and testing a model.

    Hook callbacks (``before_run``, ``before_train``, ``before_train_iter``,
    ``after_train_iter``, ``before_val``, ``before_val_iter``,
    ``after_val_iter``, ``after_val``, ``after_train``) are dispatched at
    fixed points of the run via :meth:`call_hook`.

    Parameters
    ----------
    cfg : dict
        The configuration dictionary.
    exp_dir : str
        The experiment directory.
    device : str, optional
        The device to use, by default "cuda".
    """

    def __init__(self, cfg: Config, exp_dir: Path, device: str = "cuda") -> None:
        super().__init__()
        self.exp_dir = exp_dir
        self.device = device

        # Step bookkeeping: training starts at step 1; global_step is
        # advanced once per training iteration in train_loop().
        self.start_steps = 1
        self.global_step = 0

        # Build the structured config from the raw dictionary.
        self.cfg = parse_structured(self.Config, cfg)
        # Instantiate the configured hooks, then fire the first callback.
        self.hooks = parse_hooks(self.cfg.hooks)
        self.call_hook("before_run")
        # build datapipeline

        # some code are ignored

    @torch.no_grad()
    def validation(self):
        # Iterate the whole validation set once, rendering each batch and
        # collecting metrics; hooks wrap every iteration.
        self.val_dataset_size = len(self.datapipeline.validation_dataset)
        for idx in range(self.val_dataset_size):
            self.call_hook("before_val_iter")
            batch = self.datapipeline.next_val(idx)
            outputs = self.model(batch)
            rendered = self.renderer.render_batch(outputs, batch)
            self.metric_dict = self.model.get_metric_dict(rendered, batch)
            self.call_hook("after_val_iter")

    def train_loop(self) -> None:
        """
        The training loop for the model.
        """
        self.global_step = self.start_steps
        self.call_hook("before_train")
        for iteration in range(self.start_steps, self.cfg.max_steps + 1):
            self.call_hook("before_train_iter")
            batch = self.datapipeline.next_train(self.global_step)
            self.renderer.update_sh_degree(iteration)
            self.schedulers.step(self.global_step, self.optimizer)
            self.train_step(batch)
            self.optimizer.update_model(**self.optimizer_dict)
            self.call_hook("after_train_iter")
            self.global_step += 1
            # Validate on the configured interval and always on the last step.
            is_last = iteration == self.cfg.max_steps
            if is_last or iteration % self.cfg.val_interval == 0:
                self.call_hook("before_val")
                self.validation()
                self.call_hook("after_val")
        self.call_hook("after_train")

    def call_hook(self, fn_name: str, **kwargs) -> None:
        """
        Call the hook method.

        Parameters
        ----------
        fn_name : str
            The hook method name.
        kwargs : dict
            The keyword arguments.
        """
        for hook in self.hooks:
            # support adding additional custom hook methods
            if hasattr(hook, fn_name):
                handler = getattr(hook, fn_name)
                try:
                    handler(self, **kwargs)
                except TypeError as e:
                    # Re-raise with the offending hook named, dropping the
                    # original traceback context for a cleaner error.
                    raise TypeError(f'{e} in {hook}') from None
更多类型的钩子
您可以通过定义自己的钩子函数来扩展训练器的流程。例如,如果您想在每次训练迭代结束后记录一些内容:
Note
在钩子函数中可以完全访问训练器。 我们默认提供了日志钩子和检查点钩子。
@HOOK_REGISTRY.register()
class LogHook(Hook):
    """
    A hook to log the training and validation losses.
    """

    def __init__(self):
        # Exponential moving average of the total loss, for smooth display.
        self.ema_loss_for_log = 0.
        # Formatted key/value strings shown in the progress bar.
        self.bar_info = {}
        # Accumulators for validation metrics.
        self.losses_test = {"L1_loss": 0., "psnr": 0., "ssims": 0., "lpips": 0.}

    def after_train_iter(self, trainner) -> None:
        """
        some operations after the training iteration ends.

        Parameters
        ----------
        trainner : Trainer
            The trainer object.
        """
        # Fix: pos_lr was previously unbound when no param group named
        # "point_cloud.position" existed, raising NameError below.
        # Default to 0.0 so logging degrades gracefully instead.
        pos_lr = 0.0
        for param_group in trainner.optimizer.param_groups:
            if param_group['name'] == "point_cloud.position":
                pos_lr = param_group['lr']
                break

        log_dict = {
            "num_pt": len(trainner.model.point_cloud),
            "pos_lr": pos_lr,
        }
        log_dict.update(trainner.loss_dict)

        for key, value in log_dict.items():
            if key == 'loss':
                # EMA smoothing: 40% current loss, 60% history.
                # NOTE(review): value is presumably a scalar tensor
                # (has .item()) — confirm against the model's loss_dict.
                self.ema_loss_for_log = 0.4 * value.item() + 0.6 * self.ema_loss_for_log
                self.bar_info.update({key: f"{self.ema_loss_for_log:.7f}"})
            # Forward every scalar except the optimizer internals to the writer.
            if trainner.logger and key != "optimizer_params":
                trainner.logger.write_scalar(key, value, trainner.global_step)
您可以参考教程部分或方法部分,了解更多钩子函数的示例。