模型保存与加载
import mindtorch.torch as torch
# 加载来自PyTorch原生脚本的预训练权重pth
net.load_state_dict(torch.load('pytorch.pth'))
...
# 网络训练脚本
...
# 模型保存
torch.save(net.state_dict(),'mindtorch.pth')
# 加载来自MindTorch迁移模型保存的pth进行finetune
net.load_state_dict(torch.load('mindtorch.pth'))
我们支持PyTorch原生的模型保存与加载语法,允许用户保存网络权重或以字典形式保存其他数据;对于模型加载阶段,当前暂不支持加载网络模型结构。用户同样可以加载来自PyTorch原生的pth文件,当前仅支持加载网络权重,不支持加载网络结构。基于MindTorch保存的pth文件不支持PyTorch原生脚本使用。
如果使用mindtorch 0.3之前的版本,可参考FAQ中的方法来加载PyTorch权重文件。