溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch怎么加載自己的圖片數據集

發布時間:2022-06-13 10:33:28 來源:億速云 閱讀:278 作者:iii 欄目:開發技術

本文小編為大家詳細介紹“pytorch怎么加載自己的圖片數據集”,內容詳細,步驟清晰,細節處理妥當,希望這篇“pytorch怎么加載自己的圖片數據集”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來學習新知識吧。

1.ImageFolder 適合于分類數據集,并且每一個類別的圖片在同一個文件夾, ImageFolder加載的數據集, 訓練數據為文件件下的圖片, 訓練標簽是對應的文件夾, 每個文件夾為一個類別

導入ImageFolder()包
from torchvision.datasets import ImageFolder

pytorch怎么加載自己的圖片數據集

在Flower_Orig_dataset文件夾下有flower_orig 和 sunflower這兩個文件夾, 這兩個文件夾下放著同一個類別的圖片。 使用 ImageFolder 加載的圖片, 就會返回圖片信息和對應的label信息, 但是label信息是根據文件夾給出的, 如flower_orig就是標簽0, sunflower就是標簽1。

ImageFolder 加載數據集

1. 導入包和設置transform

import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
 
transforms = transforms.Compose([
    transforms.Resize(256),    # 將圖片短邊縮放至256,長寬比保持不變:
    transforms.CenterCrop(224),   #將圖片從中心切剪成3*224*224大小的圖片
    transforms.ToTensor()          #把圖片進行歸一化,并把數據轉換成Tensor類型
])

2.加載數據集: 將分類圖片的父目錄作為路徑傳遞給ImageFolder(), 并傳入transform。這樣就有了要加載的數據集, 之后就可以使用DataLoader加載數據, 并構建網絡訓練。

path = r'D:\數據集\Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
    images, labels = data
    print(images.shape)
    print(labels.shape)
    break

使用pytorch提供的Dataset類創建自己的數據集。

具體步驟:

1.  首先要有一個txt文件, 這個文件格式是: 圖片路徑  標簽.  這樣的格式, 所以使用os庫, 遍歷自己的圖片名, 并把標簽和圖片路徑寫入txt文件。

2. 有了這個txt文件, 我們就可以在類里面構造我們的數據集.

2.1    把圖片路徑和圖片標簽分割開, 有兩個列表, 一個列表是圖片路徑名, 一個列表是標簽號, 有一點就是第 i 個圖片列表和 第 i 個標簽是對應的

3. 重寫__len__方法  和  __getitem__方法

3.1 getitem方法中, 獲得對應的圖片路徑,并用PIL庫讀取文件把圖片transfrom后, 在getitem函數中返回讀取的圖片和標簽即可

4.就可以構建數據集實例和加載數據集.

 定義一個用來生成[ 圖片路徑 標簽] 這樣的txt文件函數

def make_txt(root, file_name, label):
    path = os.path.join(root, file_name)
    data = os.listdir(path)
    f = open(path+'\\'+'f.txt', 'w')
    for line in data:
        f.write(line+' '+str(label)+'\n')
    f.close()
#調用函數生成兩個文件夾下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

將連個txt文件合并成一個txt文件,表示數據集所有的圖片和標簽

def link_txt(file1, file2):
    txt_list = []
    path = r'D:\數據集\Flower_Orig_dataset\data.txt'
 
    f = open(path, 'a')
 
    f1 = open(file1, 'r')
    data1 = f1.readlines()
    for line in data1:
        txt_list.append(line)
 
    f2 = open(file2, 'r')
    data2 = f2.readlines()
    for line in data2:
        txt_list.append(line)
 
    for line in txt_list:
        f.write(line)
 
    f.close()
    f1.close()
    f2.close()
 
#調用函數, 將兩個文件夾下的txt文件合并
file1 = r'D:\數據集\Flower_Orig_dataset\flower_orig\f.txt'
file2 = r'D:\數據集\Flower_Orig_dataset\sunflower\f.txt'
link_txt(file1=file1, file2=file2)

現在我們已經有了我們制作數據集所需要的txt文件, 接下來要做的即使繼承Dataset類, 來構建自己的數據集 , 別忘了前面說的 構建數據集步驟, 在__getitem__函數中, 需要拿到圖片路徑和標簽, 并且用PIL庫方法讀取圖片,對圖片進行transform轉換后,返回圖片信息和標簽信息

Dataset加載數據集

我們讀取圖片的根目錄, 在根目錄下有所有圖片的txt文件, 拿到txt文件后, 先讀取txt文件, 之后遍歷txt文件中的每一行, 首先去除掉尾部的換行符, 在以空格切分,前半部分是圖片名稱, 后半部分是圖片標簽, 當圖片名稱和根目錄結合,就得到了我們的圖片路徑   
class MyDataset(Dataset):
    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = img_path
 
        self.txt_root = self.root + 'data.txt'
        f = open(self.txt_root, 'r')
        data = f.readlines()
 
        imgs = []
        labels = []
        for line in data:
            line = line.rstrip()
            word = line.split()
            imgs.append(os.path.join(self.root, word[1], word[0]))
 
            labels.append(word[1])
        self.img = imgs
        self.label = labels
        self.transform = transform
 
    def __len__(self):
        return len(self.label)
 
    def __getitem__(self, item):
        img = self.img[item]
        label = self.label[item]
 
        img = Image.open(img).convert('RGB')
 
        #此時img是PIL.Image類型   label是str類型
 
        if transforms is not None:
            img = self.transform(img)
 
        label = np.array(label).astype(np.int64)
        label = torch.from_numpy(label)
        
        return img, label

 加載我們的數據集:

path = r'D:\數據集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)
 
data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

接下來我們就可以構建我們的網絡架構:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.maxpool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16,5,3)
 
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(55*55*5, 1200)
        self.fc2 = nn.Linear(1200,64)
        self.fc3 = nn.Linear(64,2)
 
    def forward(self,x):
        x = self.maxpool(self.relu(self.conv1(x)))    #113
        x = self.maxpool(self.relu(self.conv2(x)))    #55
        x = x.view(-1, self.num_flat_features(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
 
        return num_features

 訓練我們的網絡:

model = Net()
 
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        images, label = data
 
        out = model(images)
 
        loss = criterion(out, label)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if(i+1)%10 == 0:
            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))
            running_loss = 0.0
 
print('finished train')

 保存網絡模型(這里不止是保存參數,還保存了網絡結構)

#保存模型
torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b權重值
 
# 讀取模型
model = torch.load('model_name.pth')

讀到這里,這篇“pytorch怎么加載自己的圖片數據集”文章已經介紹完畢,想要掌握這篇文章的知識點還需要大家自己動手實踐使用過才能領會,如果想了解更多相關內容的文章,歡迎關注億速云行業資訊頻道。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

亚洲午夜精品一区二区_中文无码日韩欧免_久久香蕉精品视频_欧美主播一区二区三区美女