在CentOS系統上,使用PyTorch保存和加載模型的步驟與其他操作系統相同。以下是保存和加載PyTorch模型的基本方法:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
# 假設模型已經訓練完成
torch.save(model.state_dict(), 'model.pth')
model.state_dict() 是一個包含模型所有參數的字典。torch.save() 函數將這個字典保存到文件 model.pth 中。
# 創建相同結構的模型實例
model = MyModel()
# 加載權重
model.load_state_dict(torch.load('model.pth'))
# 如果模型是在GPU上訓練的,需要將模型移動到CPU并設置為評估模式
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()
map_location 參數用于指定加載模型權重時的設備位置。如果模型是在GPU上訓練的,你需要將其加載到CPU上。model.eval() 將模型設置為評估模式,這在推理時是必要的。
通過以上步驟,你可以在CentOS系統上輕松地保存和加載PyTorch模型。