在Linux系統下,使用PyTorch保存和加載模型的主要方法是使用torch.save()和torch.load()函數。以下是詳細的步驟和示例:
定義模型: 首先,你需要定義一個PyTorch模型。
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
保存模型:
使用torch.save()函數將模型保存到文件中。
torch.save(model, 'model.pth')
這會將整個模型(包括其結構和參數)保存到一個名為model.pth的文件中。
加載模型:
使用torch.load()函數從文件中加載模型。
loaded_model = torch.load('model.pth')
使用加載的模型: 加載的模型可以直接用于推理或進一步訓練。
input_data = torch.randn(1, 10) # 示例輸入數據
output = loaded_model(input_data)
print(output)
設備兼容性:
如果你在GPU上訓練模型,保存的模型會包含GPU相關的信息。在加載模型到CPU時,可以使用map_location參數。
loaded_model = torch.load('model.pth', map_location=torch.device('cpu'))
版本兼容性: 確保保存和加載模型的PyTorch版本一致,否則可能會出現兼容性問題。
自定義對象: 如果模型中使用了自定義層或函數,需要在加載模型時提供這些自定義對象的定義。
def custom_function(x):
return x * 2
torch.save({'model_state_dict': model.state_dict(), 'custom_function': custom_function}, 'model.pth')
# 加載模型時
checkpoint = torch.load('model.pth')
model = MyModel()
model.load_state_dict(checkpoint['model_state_dict'])
model.custom_function = checkpoint['custom_function']
通過以上步驟,你可以在Linux系統下方便地保存和加載PyTorch模型。