在Linux系統中,使用PyTorch保存模型的主要方法是使用torch.save()
函數。以下是一個簡單的示例:
import torch
import torchvision.models as models
# 創建一個預訓練的ResNet18模型
model = models.resnet18(pretrained=True)
# 保存整個模型
torch.save(model, 'resnet18_model.pth')
# 如果只想保存模型的狀態字典(即權重和偏置),可以使用以下方法:
torch.save(model.state_dict(), 'resnet18_state_dict.pth')
在這個例子中,我們首先導入了torch
庫和torchvision.models
模塊。然后,我們創建了一個預訓練的ResNet18模型。接下來,我們使用torch.save()
函數將整個模型保存到一個名為resnet18_model.pth
的文件中。此外,我們還可以選擇僅保存模型的狀態字典(權重和偏置),而不是整個模型。這可以通過調用model.state_dict()
方法并將其傳遞給torch.save()
函數來實現。
要加載保存的模型,可以使用以下代碼:
# 加載整個模型
model_loaded = torch.load('resnet18_model.pth')
# 加載模型的狀態字典
model = models.resnet18(pretrained=False) # 創建一個新的ResNet18模型實例
model.load_state_dict(torch.load('resnet18_state_dict.pth'))
在這個例子中,我們首先使用torch.load()
函數加載保存的模型。然后,我們創建了一個新的ResNet18模型實例,并使用load_state_dict()
方法將之前保存的狀態字典加載到新模型中。注意,在加載狀態字典時,我們需要確保新模型的架構與保存的模型架構相同。