混合精度训练适配
混合精度训练是指在训练时,对神经网络不同的运算采用不同的数值精度的运算策略。对于conv、matmul等运算占比较大的神经网络,其训练速度通常会有较大的加速比。mindspore.amp模块提供了便捷的自动混合精度接口,用户可以在不同的硬件后端通过简单的接口调用获得训练加速。
目前由于框架机制不同,MindTorch目前暂未支持torch.cuda.amp.autocast
模块下接口功能(计划支持中)。用户需要将torch.cuda.amp.autocast
接口替换成mindspore.amp.auto_mixed_precision
接口,从而使能MindSpore的自动混合精度训练。
迁移前代码:
from torch.cuda.amp import autocast, GradScaler
model = Net().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
scaler = GradScaler()
model.train()
for epoch in epochs:
for inputs, target in data:
optimizer.zero_grad()
with autocast():
output = model(input)
loss = loss_fn(output, target)
loss = scaler.scale(loss) # 损失缩放
loss.backward()
scaler.unscale_(optimizer) # 反向缩放梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) # 梯度裁剪
scaler.step(optimizer) # 梯度更新
scaler.update() # 更新系数
...
迁移后代码:
import mindtorch.torch as torch
from mindtorch.torch.cuda.amp import GradScaler
from mindspore.amp import auto_mixed_precision
...
model = Net().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)
scaler = GradScaler()
model.train() # model的方法调用需放在混合精度模型转换前
model = auto_mixed_precision(model, 'O3') # Ascend环境推荐配置'O3',GPU环境推荐配置'O2'
def forward_fn(data, target):
logits = model(data)
logits = torch.cast_to_adapter_tensor(logits) # model为混合精度模型,需要对输出tensor进行类型转换
loss = criterion(logits, target)
loss = scaler.scale(loss) # 损失缩放
return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters)
def train_step(data, target):
loss, grads = grad_fn(data, target)
return loss, grads
for epoch in epochs:
for inputs, target in data:
loss, grads = train_step(input, target)
scaler.unscale_(optimizer, grads) # 反向缩放梯度
grads = ms.ops.clip_by_global_norm(grads, max_norm) # 梯度裁剪
scaler.step(optimizer, grads) # 梯度更新
scaler.update() # 更新系数
...
Step 1:调用
auto_mixed_precision
自动生成混合精度模型,如果需要调用原始模型的方法请在混合精度模型生成前执行,如model.train()
;Step 2(可选):如果后续有对网络输出Tensor的操作,需调用
cast_to_adapter_tensor
手动将输出Tensor转换为MindTorch Tensor。Step 3:调用
GradScaler
对梯度进行缩放时,由于自动微分机制和接口区别,unscale_
和step
等接口需要把梯度grads作为入参传入。
更多细节请参考自动混合精度使用教程。