训练器 (Trainer)

训练器 (Trainer)#

训练器提供了一个完整的默认训练流程,包括钩子、数据流水线、渲染器、模型、优化器和日志记录器的初始化,以及训练步骤、训练循环和验证过程。

DefaultTrainer#

DefaultTrainer是用于训练和测试模型的类,其框架主要包括以下部分:

  • 初始化:在初始化阶段,DefaultTrainer根据配置文件和实验目录构建训练环境。这包括解析配置文件、构建数据流水线、渲染器、模型、优化器、调度器和钩子。

  • 训练循环:train_loop方法定义了整个训练过程。它在每个训练步骤中更新模型参数,并在每个验证间隔时验证模型。train_step方法定义了每个训练步骤的过程。它首先通过模型获取渲染结果,然后计算损失,最后通过优化器更新模型参数。validation方法定期调用,并用于评估模型在验证集上的性能。它遍历验证集中的所有数据,对每个数据进行渲染,然后计算评估指标。

  • 测试:test方法用于评估模型在测试集上的性能。它首先加载模型,然后渲染测试集的数据,并保存渲染结果,如一些新颖视图图像。

  • 加载和保存:save_modelload_model方法用于保存和加载模型。模型的状态保存在一个字典中,其中包括全局步数、优化器状态、模型状态和渲染器状态。

  • 钩子函数:call_hook方法用于调用钩子函数。钩子函数是在特定阶段执行的函数(如训练开始、训练结束、每个训练步骤前后等),可以用于实现一些自定义功能。

Note

您还可以通过定义钩子函数或继承DefaultTrainer类并添加自己的修改来定义自己的训练过程。

数据流程#

在这个框架中,数据的流程主要经历以下步骤:

  1. 数据加载:通过数据流水线加载训练数据和验证数据。

  2. 训练:

    • 前向传播:将训练数据集输入到模型中,进行前向传播,并获得渲染结果

    • 损失计算:基于渲染结果和实际数据计算损失字典

    • 反向传播和参数更新:根据损失进行反向传播,然后生成优化器字典,通过优化器更新模型参数和结构。

    • 验证:在验证阶段(定期调用),也执行数据加载、前向传播和评估指标计算的步骤,但不执行反向传播和参数更新

  3. 保存模型:在训练过程的特定步骤或训练结束时,保存模型的状态。

Note

如果要测试模型的效果,只需调用test()方法,它将加载模型并渲染图像以评估模型。