在Ubuntu上使用PyTorch進行模型的保存與加載,可以按照以下步驟操作:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
torch.save()
函數來保存整個模型或模型的狀態字典。model = MyModel()
# 假設模型已經訓練完成
torch.save(model, 'model.pth') # 保存整個模型
# 或者
torch.save(model.state_dict(), 'model_state_dict.pth') # 只保存模型的狀態字典
torch.load()
函數來加載模型。如果你之前保存了整個模型,可以直接加載;如果只保存了狀態字典,則需要先實例化模型,然后再加載狀態字典。# 加載整個模型
model = torch.load('model.pth')
# 或者,如果只保存了狀態字典
model = MyModel() # 先實例化模型
model.load_state_dict(torch.load('model_state_dict.pth'))
注意:在加載模型時,確保模型類(在本例中為MyModel
)已經在當前環境中定義。
# 假設我們有一些輸入數據
input_data = torch.randn(1, 10)
# 使用模型進行推理
output = model(input_data)
print(output)