训练器 (Trainer)#
训练器提供了一个完整的默认训练流程,包括钩子、数据流水线、渲染器、模型、优化器和日志记录器的初始化,以及训练步骤、训练循环和验证过程。
DefaultTrainer#
DefaultTrainer是用于训练和测试模型的类,其框架主要包括以下部分:
初始化:在初始化阶段,DefaultTrainer根据配置文件和实验目录构建训练环境。这包括解析配置文件、构建数据流水线、渲染器、模型、优化器、调度器和钩子。
训练循环:
train_loop
方法定义了整个训练过程。它在每个训练步骤中更新模型参数,并在每个验证间隔时验证模型。train_step
方法定义了每个训练步骤的过程。它首先通过模型获取渲染结果,然后计算损失,最后通过优化器更新模型参数。validation
方法定期调用,并用于评估模型在验证集上的性能。它遍历验证集中的所有数据,对每个数据进行渲染,然后计算评估指标。测试:
test
方法用于评估模型在测试集上的性能。它首先加载模型,然后渲染测试集的数据,并保存渲染结果,如一些新颖视图图像。加载和保存:
save_model
和load_model
方法用于保存和加载模型。模型的状态保存在一个字典中,其中包括全局步数、优化器状态、模型状态和渲染器状态。钩子函数:
call_hook
方法用于调用钩子函数。钩子函数是在特定阶段执行的函数(如训练开始、训练结束、每个训练步骤前后等),可以用于实现一些自定义功能。
Note
您还可以通过定义钩子函数或继承DefaultTrainer类并添加自己的修改来定义自己的训练过程。
数据流程#
在这个框架中,数据的流程主要经历以下步骤:
数据加载:通过数据流水线加载训练数据和验证数据。
训练:
前向传播:将训练数据集输入到模型中,进行前向传播,并获得渲染结果。
损失计算:基于渲染结果和实际数据计算损失字典。
反向传播和参数更新:根据损失进行反向传播,然后生成优化器字典,通过优化器更新模型参数和结构。
验证:在验证阶段(定期调用),也执行数据加载、前向传播和评估指标计算的步骤,但不执行反向传播和参数更新。
保存模型:在训练过程的特定步骤或训练结束时,保存模型的状态。
Note
如果要测试模型的效果,只需调用test()
方法,它将加载模型并渲染图像以评估模型。