「PyTorch入門 2. データセットとデータローダー」¶

菅間修正済み 2025/05/15

菅間修正済み2 2025/05/18

【原題】DATASETS & DATALOADERS

【原著】 Suraj Subramanian、Seth Juarez 、Cassie Breviu 、Dmitry Soshnikov、Ari Bornstein

【元URL】https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 小川 雄太郎

【日付】2021年03月20日

【チュトーリアル概要】

本チュートリアルでは、PyTorchでサンプルデータを扱う基本要素である、DatasetとDataLoaderについて解説を行います。


Datasets & Dataloaders¶

サンプルデータを処理するコードは複雑であり、メンテナンスも大変です。

データセットに関するコードは可読性とモジュール性を考慮し、モデルの訓練コードから切り離すのが理想的です。

PyTorchにはデータセットを扱う基本要素が2つあります。

torch.utils.data.DataLoaderと、torch.utils.data.Datasetです。

これらを活用することであらかじめ用意されたデータセットや自分で作成したデータを使用することができます。

Datasetにはサンプルとそれに対応するラベルが格納され、DataLoaderにはイテレート処理が可能なデータが格納されます。

DataLoaderは、サンプルを簡単に利用できるように、Datasetをイテレート処理可能なものへとラップします。

PyTorch domain librariesでは、多くのデータセット(FashionMNISTなど)を提供しています。

これらは torch.utils.data.Dataset を継承しており、各ドメインのデータに対して必要な、固有の機能を実装しています。

また、皆様が実装したモデルのベンチマークにも使うことができます。

さらなる詳細は以下をご覧ください。

  • Image Datasets

  • Text Datasets

  • Audio Datasets


Datasetの読み込み¶

TorchVisionからFashion-MNISTをロードする例を紹介します。

Fashion-MNISTは、60,000個の訓練データと10,000個のテストデータから構成された、Zalandoの記事画像のデータセットです。

各サンプルは、28×28のグレースケール画像と、10クラスのうちの1つのラベルから構成されています。

FashionMNIST Datasetを読み込む際には、以下のパラメータを使用します。

  • root :訓練/テストデータが格納されているパスを指定
  • train :訓練データまたはテストデータセットを指定
  • download=True:root にデータが存在しない場合は、インターネットからデータをダウンロードを指定
  • transform と target_transform:特徴量とラベルの変換を指定
In [ ]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Processing...
Done!
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

データセットの反復処理と可視化¶

Datasetの特定indexを指定する際には、リスト操作と同様に、training_data[index]と記載します。

matplotlibを使用し、訓練データのいくつかのサンプルを可視化しましょう。

In [ ]:
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.savefig('fashionmnist_samples.png')
plt.close()

カスタムデータセットの作成(やらなくてもOK)¶

自分でカスタムしたDatasetクラスを作る際には、 __init__、__len__、__getitem__の3つの関数は必ず実装する必要があります。

これらの関数の実装を確認します。

FashionMNISTの画像データをimg_dirフォルダに、ラベルはCSVファイルannotations_fileとして保存します。

これから、各関数がどのような操作を行っているのか詳細に確認します。

In [ ]:
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        sample = {"image": image, "label": label}
        return sample

__init__

__init__関数はDatasetオブジェクトがインスタンス化される際に1度だけ実行されます。

画像、アノテーションファイル、そしてそれらに対する変換処理(transforms:次のセクションで解説します)の初期設定を行います。


ここで、labels.csvファイルは以下のような内容となっています。

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
In [ ]:
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__len__

__len__関数はデータセットのサンプル数を返す関数です。

In [ ]:
def __len__(self):
    return len(self.img_labels)

__getitem__

__getitem__関数は指定されたidxに対応するサンプルをデータセットから読み込んで返す関数です。

indexに基づいて、画像ファイルのパスを特定し、read_imageを使用して画像ファイルをテンソルに変換します。

加えて、self.img_labelsから対応するラベルを抜き出します。

そしてtransform functionsを必要に応じて画像およびラベルに適用し、最終的にPythonの辞書型変数で画像とラベルを返します。

In [ ]:
def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    sample = {"image": image, "label": label}
    return sample

DataLoaderの使用方法¶

Datasetを使用することで1つのサンプルの、データとラベルを取り出せます。

ですが、モデルの訓練時にはミニバッチ("minibatches")単位でデータを扱いたく、また各epochでデータはシャッフルされて欲しいです(訓練データへの過学習を防ぐ目的です)。

加えて、Pythonの multiprocessingを使用し、複数データの取り出しを高速化したいところです。

DataLoaderは上記に示した複雑な処理を簡単に実行できるようにしてくれるAPIとなります。

In [ ]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoaderを用いた繰り返し処理¶

データセットを Dataloader に読み込ませ、必要に応じてデータセットを反復処理することができます。

以下の各反復処理ではtrain_features と train_labelsのミニバッチを返します(それぞれ、64個のサンプルで構成されるミニバッチです)。

今回shuffle=Trueと指定しているので、データセットのデータを全て取り出したら、データの順番はシャッフルされます。


さらなるデータ読み込み操作の詳細については、こちらのSamplersをご覧ください。

In [ ]:
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.savefig('train_first_label.png')
plt.close()
print(f"Label: {label}")
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2

さらなる詳細¶

以下のページも参考ください。

  • torch.utils.data API

以上。