這篇文章主要介紹Pytorch如何實現計算分類器準確率,文中介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們一定要看完!
分類器平均準確率計算:
correct = torch.zeros(1).squeeze().cuda() total = torch.zeros(1).squeeze().cuda() for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) correct += (prediction == labels).sum().float() total += len(labels) acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())
分類器各個子類準確率計算:
correct = list(0. for i in range(args.class_num)) total = list(0. for i in range(args.class_num)) for i, (images, labels) in enumerate(train_loader): images = Variable(images.cuda()) labels = Variable(labels.cuda()) output = model(images) prediction = torch.argmax(output, 1) res = prediction == labels for label_idx in range(len(labels)): label_single = label[label_idx] correct[label_single] += res[label_idx].item() total[label_single] += 1 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total)) for acc_idx in range(len(train_class_correct)): try: acc = correct[acc_idx]/total[acc_idx] except: acc = 0 finally: acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)
1.PyTorch是相當簡潔且高效快速的框架;2.設計追求最少的封裝;3.設計符合人類思維,它讓用戶盡可能地專注于實現自己的想法;4.與google的Tensorflow類似,FAIR的支持足以確保PyTorch獲得持續的開發更新;5.PyTorch作者親自維護的論壇 供用戶交流和求教問題6.入門簡單
以上是“Pytorch如何實現計算分類器準確率”這篇文章的所有內容,感謝各位的閱讀!希望分享的內容對大家有幫助,更多相關知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。