菅間修正済み 2025/05/15
菅間修正済み2 2025/05/18
【原題】DATASETS & DATALOADERS
【原著】 Suraj Subramanian、Seth Juarez 、Cassie Breviu 、Dmitry Soshnikov、Ari Bornstein
【元URL】https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 小川 雄太郎
【日付】2021年03月20日
【チュトーリアル概要】
本チュートリアルでは、PyTorchでサンプルデータを扱う基本要素である、DatasetとDataLoaderについて解説を行います。
サンプルデータを処理するコードは複雑であり、メンテナンスも大変です。
データセットに関するコードは可読性とモジュール性を考慮し、モデルの訓練コードから切り離すのが理想的です。
PyTorchにはデータセットを扱う基本要素が2つあります。
torch.utils.data.DataLoaderと、torch.utils.data.Datasetです。
これらを活用することであらかじめ用意されたデータセットや自分で作成したデータを使用することができます。
Datasetにはサンプルとそれに対応するラベルが格納され、DataLoaderにはイテレート処理が可能なデータが格納されます。
DataLoaderは、サンプルを簡単に利用できるように、Datasetをイテレート処理可能なものへとラップします。
PyTorch domain librariesでは、多くのデータセット(FashionMNISTなど)を提供しています。
これらは torch.utils.data.Dataset を継承しており、各ドメインのデータに対して必要な、固有の機能を実装しています。
また、皆様が実装したモデルのベンチマークにも使うことができます。
さらなる詳細は以下をご覧ください。
TorchVisionからFashion-MNISTをロードする例を紹介します。
Fashion-MNISTは、60,000個の訓練データと10,000個のテストデータから構成された、Zalandoの記事画像のデータセットです。
各サンプルは、28×28のグレースケール画像と、10クラスのうちの1つのラベルから構成されています。
FashionMNIST Datasetを読み込む際には、以下のパラメータを使用します。
root :訓練/テストデータが格納されているパスを指定train :訓練データまたはテストデータセットを指定download=True:root にデータが存在しない場合は、インターネットからデータをダウンロードを指定transform と target_transform:特徴量とラベルの変換を指定import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw Processing... Done!
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Datasetの特定indexを指定する際には、リスト操作と同様に、training_data[index]と記載します。
matplotlibを使用し、訓練データのいくつかのサンプルを可視化しましょう。
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.savefig('fashionmnist_samples.png')
plt.close()
自分でカスタムしたDatasetクラスを作る際には、 __init__、__len__、__getitem__の3つの関数は必ず実装する必要があります。
これらの関数の実装を確認します。
FashionMNISTの画像データをimg_dirフォルダに、ラベルはCSVファイルannotations_fileとして保存します。
これから、各関数がどのような操作を行っているのか詳細に確認します。
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
sample = {"image": image, "label": label}
return sample
__init__
__init__関数はDatasetオブジェクトがインスタンス化される際に1度だけ実行されます。
画像、アノテーションファイル、そしてそれらに対する変換処理(transforms:次のセクションで解説します)の初期設定を行います。
ここで、labels.csvファイルは以下のような内容となっています。
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
__len__
__len__関数はデータセットのサンプル数を返す関数です。
def __len__(self):
return len(self.img_labels)
__getitem__
__getitem__関数は指定されたidxに対応するサンプルをデータセットから読み込んで返す関数です。
indexに基づいて、画像ファイルのパスを特定し、read_imageを使用して画像ファイルをテンソルに変換します。
加えて、self.img_labelsから対応するラベルを抜き出します。
そしてtransform functionsを必要に応じて画像およびラベルに適用し、最終的にPythonの辞書型変数で画像とラベルを返します。
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
sample = {"image": image, "label": label}
return sample
Datasetを使用することで1つのサンプルの、データとラベルを取り出せます。
ですが、モデルの訓練時にはミニバッチ("minibatches")単位でデータを扱いたく、また各epochでデータはシャッフルされて欲しいです(訓練データへの過学習を防ぐ目的です)。
加えて、Pythonの multiprocessingを使用し、複数データの取り出しを高速化したいところです。
DataLoaderは上記に示した複雑な処理を簡単に実行できるようにしてくれるAPIとなります。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
データセットを Dataloader に読み込ませ、必要に応じてデータセットを反復処理することができます。
以下の各反復処理ではtrain_features と train_labelsのミニバッチを返します(それぞれ、64個のサンプルで構成されるミニバッチです)。
今回shuffle=Trueと指定しているので、データセットのデータを全て取り出したら、データの順番はシャッフルされます。
さらなるデータ読み込み操作の詳細については、こちらのSamplersをご覧ください。
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.savefig('train_first_label.png')
plt.close()
print(f"Label: {label}")
Feature batch shape: torch.Size([64, 1, 28, 28]) Labels batch shape: torch.Size([64])
Label: 2
以上。