在CentOS上評估PyTorch模型的效果,通常涉及以下幾個步驟:
準備數據集:
加載模型:
.pth或.pt文件。設置模型為評估模式:
model.eval()來實現。處理數據:
進行預測:
計算評估指標:
分析結果:
下面是一個簡單的示例代碼,展示了如何在CentOS上使用PyTorch評估一個分類模型的效果:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from your_model import YourModel # 假設你的模型定義在這個文件中
from sklearn.metrics import accuracy_score, classification_report
# 設置設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加載評估數據集
transform = transforms.Compose([
# 定義你的數據預處理操作
])
test_dataset = datasets.YourDataset(root='path/to/your/test/data', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# 加載模型
model = YourModel() # 假設你的模型類定義在這個文件中
model.load_state_dict(torch.load('path/to/your/model.pth'))
model.to(device)
model.eval()
# 進行預測并計算評估指標
predictions = []
true_labels = []
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
predictions.extend(predicted.cpu().numpy())
true_labels.extend(labels.cpu().numpy())
# 計算準確率
accuracy = accuracy_score(true_labels, predictions)
print(f'Accuracy: {accuracy:.2f}')
# 打印分類報告
print(classification_report(true_labels, predictions))
請根據你的具體情況調整上述代碼,例如數據集路徑、模型定義、預處理操作等。