在Ubuntu上使用PyTorch進行模型保存非常簡單。以下是保存和加載模型的基本步驟:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
model = MyModel()
torch.save()
函數來保存整個模型。torch.save(model, 'model.pth')
這會將模型的所有參數和結構保存到一個文件中。
torch.load()
函數來加載模型。loaded_model = torch.load('model.pth')
input_data = torch.randn(1, 10) # 示例輸入數據
output = loaded_model(input_data)
print(output)
# 保存模型時指定設備為CPU
torch.save(model.cpu(), 'model.pth')
# 加載模型時指定設備為CPU
loaded_model = torch.load('model.pth', map_location=torch.device('cpu'))
以下是一個完整的示例,展示了如何定義、保存和加載模型:
import torch
import torch.nn as nn
# 定義模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
# 創建模型實例
model = MyModel()
# 保存模型
torch.save(model, 'model.pth')
# 加載模型
loaded_model = torch.load('model.pth', map_location=torch.device('cpu'))
# 使用加載的模型進行推理
input_data = torch.randn(1, 10)
output = loaded_model(input_data)
print(output)
通過以上步驟,你可以在Ubuntu上輕松地保存和加載PyTorch模型。