混合精度训练适配

混合精度训练是指在训练时,对神经网络不同的运算采用不同的数值精度的运算策略。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比。mindspore.amp模块提供了便捷的自动混合精度接口,用户可以在不同的硬件后端通过简单的接口调用获得训练加速。

目前由于框架机制不同,MindTorch目前暂未支持torch.cuda.amp.autocast模块下接口功能(计划支持中)。用户需要将torch.cuda.amp.autocast接口替换成mindspore.amp.auto_mixed_precision接口,从而使能MindSpore的自动混合精度训练。

迁移前代码:

from torch.cuda.amp import autocast, GradScaler

model = Net().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

scaler = GradScaler()

model.train()
for epoch in epochs:
    for inputs, target in data:
        optimizer.zero_grad()

        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        loss = scaler.scale(loss)  # 损失缩放
        loss.backward()
        scaler.unscale_(optimizer) # 反向缩放梯度
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # 梯度裁剪
        scaler.step(optimizer)     # 梯度更新
        scaler.update()            # 更新系数
...

迁移后代码:

import mindtorch.torch as torch
from mindtorch.torch.cuda.amp import GradScaler
from mindspore.amp import auto_mixed_precision

...
model = Net().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

scaler = GradScaler()

model.train()    # model的方法调用需放在混合精度模型转换前
model = auto_mixed_precision(model, 'O3')    # Ascend环境推荐配置'O3',GPU环境推荐配置'O2'

def forward_fn(data, target):
    logits = model(data)
    logits = torch.cast_to_adapter_tensor(logits)  # model为混合精度模型,需要对输出tensor进行类型转换
    loss = criterion(logits, target)
    loss = scaler.scale(loss)   # 损失缩放
    return loss

grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters)

def train_step(data, target):
    loss, grads = grad_fn(data, target)
    return loss, grads

for epoch in epochs:
    for inputs, target in data:
        loss, grads = train_step(input, target)
        scaler.unscale_(optimizer, grads)     # 反向缩放梯度
        grads = ms.ops.clip_by_global_norm(grads, max_norm)  # 梯度裁剪
        scaler.step(optimizer, grads)         # 梯度更新
        scaler.update()                       # 更新系数
...
  • Step 1:调用auto_mixed_precision自动生成混合精度模型,如果需要调用原始模型的方法请在混合精度模型生成前执行,如model.train()

  • Step 2(可选):如果后续有对网络输出Tensor的操作,需调用cast_to_adapter_tensor手动将输出Tensor转换为MindTorch Tensor。

  • Step 3:调用GradScaler对梯度进行缩放时,由于自动微分机制和接口区别,unscale_step等接口需要把梯度grads作为入参传入。

更多细节请参考自动混合精度使用教程