在Ubuntu系統中使用PyTorch進行模型的保存與加載是一個常見的操作。以下是詳細的步驟和示例代碼:
torch.save()函數將整個模型或僅保存模型的狀態字典(state_dict)。import torch
import torch.nn as nn
import torch.optim as optim
# 定義一個簡單的神經網絡模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(784, 10) # 假設輸入是784維,輸出是10類
def forward(self, x):
x = x.view(x.size(0), -1) # 將輸入展平
x = self.fc(x)
return x
# 創建模型實例
model = SimpleNet()
# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 假設我們有一些訓練數據
inputs = torch.randn(64, 1, 28, 28) # 示例輸入
labels = torch.randint(0, 10, (64,)) # 示例標簽
# 訓練模型(這里省略了訓練循環)
for epoch in range(5):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 保存整個模型
torch.save(model, 'model.pth')
# 或者僅保存模型的狀態字典
torch.save(model.state_dict(), 'model_state_dict.pth')
torch.load()函數加載模型或模型的狀態字典。# 加載整個模型
loaded_model = torch.load('model.pth')
# 或者加載模型的狀態字典
model = SimpleNet() # 創建一個新的模型實例
model.load_state_dict(torch.load('model_state_dict.pth'))
# 確保模型在評估模式
model.eval()
# 使用加載的模型進行預測
with torch.no_grad():
test_inputs = torch.randn(1, 1, 28, 28) # 示例測試輸入
predictions = loaded_model(test_inputs)
print(predictions)
model = model.to('cpu')
通過以上步驟,你可以在Ubuntu系統中輕松地進行PyTorch模型的保存與加載。