溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch 數據集圖片顯示方法

發布時間:2020-09-09 00:54:04 來源:腳本之家 閱讀:264 作者:zzw小凡 欄目:開發技術

圖片顯示

pytorch 載入的數據集是元組tuple 形式,里面包括了數據及標簽(train_data,label),其中的train_data數據可以轉換為torch.Tensor形式,方便后面計算使用。

同樣給一些剛入門的同學在使用載入的數據顯示圖片的時候帶來一些難以理解的地方,這里主要是將Tensor與numpy轉換的過程,理解了這些就可以就行轉換了

CIAFA10數據集

首先載入數據集,這里做了一些數據處理,包括圖片尺寸、數據歸一化等

import torch
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
import torchvision.datasets as dset
import torchvision.transforms as transforms
from autoencoder import AutoEncoder
import torch.nn as nn
import torchvision
import numpy as np
dataset = dset.CIFAR10(root='../train/data', download=True, 
    transform=transforms.Compose([
    transforms.Scale(200),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Gray()
    ]))

在這里 dataset 是一個CIFAR10對象,(大家可以查看一下他的源代碼)

方式一

dataset[1] = ([torch.FloatTensor of size 1x200x200],9)

載入的第二個數據是個tensor格式,包含一個標簽 9

這里我們做的就是將torch.FloatTensor 轉換為numpy,然后顯示

b = dataset[1][0].numpy()
#取數據,不取標簽

因為這里的b仍然是1*200*200的大小,所以要重新reshape一下,適合輸出圖像

plt.imshow(b.reshape(200,200),cmap = 'gray')
plt.show()

然后可以顯示圖像了

方式二

利用torch的接口

img = torchvision.utils.make_grid(dataset[1][0]).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

這用np.transpose 是因為plt.imshow在顯示 時候輸入的是(imgsize,imgsieze,channels),而這里得到的img是(3,200,200)的格式,所以進行了轉換,才能顯示

以上這篇pytorch 數據集圖片顯示方法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女