# PyTorch中計算準確率、召回率和F1值的方法
## 1. 引言
在機器學習分類任務中,評估模型性能是至關重要的環節。準確率(Accuracy)、召回率(Recall)和F1值(F1-Score)是最常用的評估指標之一。本文將詳細介紹如何在PyTorch框架下實現這些指標的計算,包括理論基礎、實現方法和實際應用示例。
## 2. 分類任務評估指標基礎
### 2.1 混淆矩陣
混淆矩陣(Confusion Matrix)是理解分類指標的基礎。對于二分類問題,混淆矩陣如下:
| | 預測為正類 | 預測為負類 |
|----------------|------------|------------|
| **實際為正類** | TP (真正例) | FN (假反例) |
| **實際為負類** | FP (假正例) | TN (真反例) |
### 2.2 指標定義
- **準確率(Accuracy)**: 正確預測的樣本比例
$$Accuracy = \frac{TP + TN}{TP + TN + FP + FN}$$
- **召回率(Recall)**: 正類樣本中被正確預測的比例
$$Recall = \frac{TP}{TP + FN}$$
- **精確率(Precision)**: 預測為正類的樣本中實際為正類的比例
$$Precision = \frac{TP}{TP + FP}$$
- **F1值(F1-Score)**: 精確率和召回率的調和平均
$$F1 = 2 \times \frac{Precision \times Recall}{Precision + Recall}$$
## 3. PyTorch實現基礎
在PyTorch中計算這些指標,我們需要處理模型的輸出和真實標簽。通常分類模型的輸出是每個類別的概率(logits),我們需要先將其轉換為預測類別。
### 3.1 獲取預測結果
```python
import torch
# Assume the model outputs raw logits of shape (batch_size × num_classes)
logits = torch.randn(4, 3) # 4 samples, 3-class problem
# Convert logits to hard class predictions
_, preds = torch.max(logits, dim=1) # index of the per-row maximum
真實標簽通常是類別索引形式:
targets = torch.tensor([0, 2, 1, 1]) # 真實標簽
def accuracy_binary(preds, targets):
    """Return the fraction of predictions that match the targets.

    Args:
        preds: tensor of predicted class labels.
        targets: tensor of ground-truth labels, same shape as ``preds``.

    Returns:
        0-dim float tensor holding the mean of element-wise matches.
    """
    matches = preds.eq(targets).float()
    return matches.mean()
def precision_recall_binary(preds, targets, positive_class=1):
    """Compute precision and recall treating one label as the positive class.

    Args:
        preds: tensor of predicted labels.
        targets: tensor of ground-truth labels, same shape as ``preds``.
        positive_class: label value regarded as positive (default 1).

    Returns:
        Tuple ``(precision, recall)`` of 0-dim float tensors.
    """
    pred_pos = preds == positive_class
    target_pos = targets == positive_class
    tp = (pred_pos & target_pos).sum().float()
    predicted_pos = pred_pos.sum().float()
    actual_pos = target_pos.sum().float()
    # small epsilon guards against 0/0 when the class is absent
    precision = tp / (predicted_pos + 1e-8)
    recall = tp / (actual_pos + 1e-8)
    return precision, recall
def f1_score_binary(preds, targets, positive_class=1):
    """Harmonic mean of precision and recall for the given positive class.

    Delegates the precision/recall computation to
    ``precision_recall_binary`` and combines the two scores.
    """
    p, r = precision_recall_binary(preds, targets, positive_class)
    # epsilon keeps the ratio defined when both p and r are zero
    return 2 * (p * r) / (p + r + 1e-8)
對于多分類問題,我們有兩種計算方式: - 宏平均(Macro-average): 對每個類別的指標單獨計算后平均 - 微平均(Micro-average): 將所有類別的TP,FP,FN等先求和再計算
def macro_precision_recall_f1(preds, targets, num_classes):
    """Macro-averaged precision, recall and F1 over ``num_classes`` classes.

    Each class is scored one-vs-rest and the per-class scores are averaged
    with equal weight, so rare classes count as much as frequent ones.

    Args:
        preds: tensor of predicted class indices.
        targets: tensor of ground-truth class indices, same shape as preds.
        num_classes: total number of classes to iterate over.

    Returns:
        Tuple ``(macro_precision, macro_recall, macro_f1)`` of 0-dim
        float tensors.
    """
    precisions, recalls, f1s = [], [], []
    for class_idx in range(num_classes):
        pred_mask = preds == class_idx
        target_mask = targets == class_idx
        tp = (pred_mask & target_mask).sum().float()
        fp = (pred_mask & ~target_mask).sum().float()
        fn = (~pred_mask & target_mask).sum().float()
        # epsilon avoids 0/0 for classes absent from this batch
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    # torch.stack keeps the values as tensors; torch.tensor() on a list of
    # tensors copy-constructs each element and emits a UserWarning
    macro_precision = torch.stack(precisions).mean()
    macro_recall = torch.stack(recalls).mean()
    macro_f1 = torch.stack(f1s).mean()
    return macro_precision, macro_recall, macro_f1
def micro_precision_recall_f1(preds, targets, num_classes):
    """Micro-averaged precision, recall and F1: pool TP/FP/FN across classes.

    All per-class counts are summed first and the metrics are computed
    once on the pooled totals, so frequent classes dominate the score.

    Args:
        preds: tensor of predicted class indices.
        targets: tensor of ground-truth class indices, same shape as preds.
        num_classes: total number of classes to iterate over.

    Returns:
        Tuple ``(micro_precision, micro_recall, micro_f1)``.
    """
    tp_sum, fp_sum, fn_sum = 0, 0, 0
    for cls in range(num_classes):
        is_pred = preds == cls
        is_true = targets == cls
        tp_sum = tp_sum + (is_pred & is_true).sum().float()
        fp_sum = fp_sum + (is_pred & ~is_true).sum().float()
        fn_sum = fn_sum + (~is_pred & is_true).sum().float()
    # epsilon guards against division by zero on empty input
    micro_precision = tp_sum / (tp_sum + fp_sum + 1e-8)
    micro_recall = tp_sum / (tp_sum + fn_sum + 1e-8)
    micro_f1 = 2 * (micro_precision * micro_recall) / (
        micro_precision + micro_recall + 1e-8
    )
    return micro_precision, micro_recall, micro_f1
PyTorch提供了一些內置函數可以簡化計算:
def accuracy_torch(preds, targets):
    """Element-wise match rate between predictions and ground truth."""
    correct = preds == targets
    return correct.float().mean()
from sklearn.metrics import confusion_matrix
import numpy as np
def get_confusion_matrix(preds, targets, num_classes):
    """Build a scikit-learn confusion matrix from PyTorch tensors.

    Tensors are moved to the CPU and converted to NumPy first because
    scikit-learn cannot consume CUDA tensors directly.

    Args:
        preds: tensor of predicted class indices.
        targets: tensor of ground-truth class indices.
        num_classes: number of classes; fixes the matrix to
            ``num_classes × num_classes`` even if some labels are absent.

    Returns:
        ``numpy.ndarray`` with rows = true labels, columns = predictions.
    """
    label_range = list(range(num_classes))
    y_pred = preds.cpu().numpy()
    y_true = targets.cpu().numpy()
    return confusion_matrix(y_true, y_pred, labels=label_range)
def train_epoch(model, dataloader, criterion, optimizer, device, num_classes):
    """Run one optimisation pass over ``dataloader`` and report metrics.

    For every batch: forward pass, loss, backward pass, optimiser step,
    then accuracy and macro precision/recall/F1 on that batch.

    Args:
        model: the network to train (switched to train mode here).
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: loss function taking ``(outputs, targets)``.
        optimizer: optimiser updating ``model``'s parameters.
        device: device the batches are moved to.
        num_classes: class count forwarded to the macro-metric helper.

    Returns:
        Tuple ``(avg_loss, avg_acc, avg_precision, avg_recall, avg_f1)``,
        each the per-batch value averaged over the epoch.

    NOTE(review): averaging per-batch macro metrics only approximates the
    epoch-level macro score — confirm this reporting convention is intended.
    """
    model.train()
    sums = {"loss": 0.0, "acc": 0.0, "prec": 0.0, "rec": 0.0, "f1": 0.0}
    batch_count = len(dataloader)
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Forward pass and hard predictions
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, targets)
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Per-batch metrics
        acc = accuracy_torch(preds, targets)
        prec, rec, f1 = macro_precision_recall_f1(preds, targets, num_classes)
        sums["loss"] += loss.item()
        sums["acc"] += acc.item()
        sums["prec"] += prec.item()
        sums["rec"] += rec.item()
        sums["f1"] += f1.item()
    return (
        sums["loss"] / batch_count,
        sums["acc"] / batch_count,
        sums["prec"] / batch_count,
        sums["rec"] / batch_count,
        sums["f1"] / batch_count,
    )
def evaluate(model, dataloader, criterion, device, num_classes):
    """Evaluate ``model`` over ``dataloader`` without updating weights.

    Mirrors ``train_epoch`` but runs in eval mode under ``torch.no_grad()``:
    forward pass, loss, then per-batch accuracy and macro metrics.

    Args:
        model: the network to evaluate (switched to eval mode here).
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: loss function taking ``(outputs, targets)``.
        device: device the batches are moved to.
        num_classes: class count forwarded to the macro-metric helper.

    Returns:
        Tuple ``(avg_loss, avg_acc, avg_precision, avg_recall, avg_f1)``,
        each the per-batch value averaged over the dataset.
    """
    model.eval()
    sums = {"loss": 0.0, "acc": 0.0, "prec": 0.0, "rec": 0.0, "f1": 0.0}
    batch_count = len(dataloader)
    with torch.no_grad():  # no gradients needed during evaluation
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Forward pass and hard predictions
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, targets)
            # Per-batch metrics
            acc = accuracy_torch(preds, targets)
            prec, rec, f1 = macro_precision_recall_f1(
                preds, targets, num_classes
            )
            sums["loss"] += loss.item()
            sums["acc"] += acc.item()
            sums["prec"] += prec.item()
            sums["rec"] += rec.item()
            sums["f1"] += f1.item()
    return (
        sums["loss"] / batch_count,
        sums["acc"] / batch_count,
        sums["prec"] / batch_count,
        sums["rec"] / batch_count,
        sums["f1"] / batch_count,
    )
對于多標簽分類(一個樣本可以屬于多個類別),指標計算有所不同:
def multilabel_metrics(preds, targets, threshold=0.5):
    """Per-class and micro-averaged precision/recall/F1 for multi-label output.

    Args:
        preds: per-label probabilities (e.g. after sigmoid), shape
            ``(batch, num_labels)``.
        targets: binary ground-truth matrix of the same shape.
        threshold: probability cut-off above which a label is predicted.

    Returns:
        Dict with per-class tensors under ``'precision'``/``'recall'``/
        ``'f1'`` and 0-dim micro-averaged variants under
        ``'micro_precision'``/``'micro_recall'``/``'micro_f1'``.
    """
    eps = 1e-8
    # Binarize the probabilities at the threshold
    hard = (preds > threshold).float()
    # Per-class counts, summed over the batch dimension
    tp = (hard * targets).sum(dim=0)
    fp = (hard * (1 - targets)).sum(dim=0)
    fn = ((1 - hard) * targets).sum(dim=0)
    # Per-class metrics (epsilon guards 0/0 for unseen labels)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)
    # Micro average: pool counts across all labels first
    micro_p = tp.sum() / (tp.sum() + fp.sum() + eps)
    micro_r = tp.sum() / (tp.sum() + fn.sum() + eps)
    micro_f1 = 2 * (micro_p * micro_r) / (micro_p + micro_r + eps)
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'micro_precision': micro_p,
        'micro_recall': micro_r,
        'micro_f1': micro_f1
    }
當數據集中各類別樣本數量差異很大時,準確率可能不是最佳指標。解決方案: - 使用加權指標 - 關注F1值而非準確率 - 使用過采樣/欠采樣技術
對于概率輸出,如何選擇最佳閾值: - 使用ROC曲線尋找最佳平衡點 - 根據業務需求調整(如醫療診斷可能更重視召回率)
訓練過程中指標波動大可能原因: - 學習率設置不當 - 批量大小太小 - 數據預處理不一致
本文詳細介紹了在PyTorch中計算準確率、召回率和F1值的方法,包括: - 二分類和多分類場景 - 宏平均和微平均策略 - 訓練循環中的集成方法 - 多標簽分類的特殊處理 - 性能優化和常見問題解決方案
正確計算和解讀這些指標對于模型開發和評估至關重要。希望本文能為您的PyTorch項目提供有價值的參考。
```
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。