模型保存与加载

error_log

import mindtorch.torch as torch

# 加载来自PyTorch原生脚本的预训练权重pth
net.load_state_dict(torch.load('pytorch.pth'))

...
# 网络训练脚本
...

# 模型保存
torch.save(net.state_dict(),'mindtorch.pth')

# 加载来自MindTorch迁移模型保存的pth进行finetune
net.load_state_dict(torch.load('mindtorch.pth'))

我们支持PyTorch原生的模型保存与加载语法,允许用户保存网络权重或以字典形式保存其他数据;对于模型加载阶段,当前暂不支持加载网络模型结构。用户同样可以加载来自PyTorch原生的pth文件,当前仅支持加载网络权重,不支持加载网络结构。基于MindTorch保存的pth文件不支持PyTorch原生脚本使用。

如果使用mindtorch 0.3之前的版本,可参考FAQ中的方法来加载PyTorch权重文件。