圖片顯示
pytorch 載入的數據集是元組tuple 形式,里面包括了數據及標簽(train_data,label),其中的train_data數據可以轉換為torch.Tensor形式,方便后面計算使用。
同樣給一些剛入門的同學在使用載入的數據顯示圖片的時候帶來一些難以理解的地方,這里主要是將Tensor與numpy轉換的過程,理解了這些就可以就行轉換了
CIAFA10數據集
首先載入數據集,這里做了一些數據處理,包括圖片尺寸、數據歸一化等
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torchvision.datasets as dset
import torchvision.transforms as transforms
from autoencoder import AutoEncoder
import torch.nn as nn
import torchvision
import numpy as np
dataset = dset.CIFAR10(root='../train/data', download=True,
transform=transforms.Compose([
transforms.Scale(200),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
transforms.Gray()
]))
在這里 dataset 是一個CIFAR10對象,(大家可以查看一下他的源代碼)
方式一
dataset[1] = ([torch.FloatTensor of size 1x200x200],9)
載入的第二個數據是個tensor格式,包含一個標簽 9
這里我們做的就是將torch.FloatTensor 轉換為numpy,然后顯示
b = dataset[1][0].numpy() #取數據,不取標簽
因為這里的b仍然是1*200*200的大小,所以要重新reshape一下,適合輸出圖像
plt.imshow(b.reshape(200,200),cmap = 'gray') plt.show()
然后可以顯示圖像了
方式二
利用torch的接口
img = torchvision.utils.make_grid(dataset[1][0]).numpy() plt.imshow(np.transpose(img,(1,2,0))) plt.show()
這用np.transpose 是因為plt.imshow在顯示 時候輸入的是(imgsize,imgsieze,channels),而這里得到的img是(3,200,200)的格式,所以進行了轉換,才能顯示
以上這篇pytorch 數據集圖片顯示方法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。