「PyTorch入門 3. Transform」¶

菅間修正済み 2025/05/15

【原題】TRANSFORMS

【原著】 Suraj Subramanian、Seth Juarez 、Cassie Breviu 、Dmitry Soshnikov、Ari Bornstein

【元URL】https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html

【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 小川 雄太郎

【日付】2021年03月20日

【チュトーリアル概要】

本チュートリアルでは、PyTorchのデータ変換処理であるTransformについて解説を行います。


Transforms¶

機械学習アルゴリズムの学習に必要な、最終的な処理が施された形でデータが手に入るとは限りません。

そこでtransformを使用してデータに何らかの処理を行い、学習に適した形へと変換します。

TorchVisionの全データセットには、特徴量(データ)を変換処理するためのtransformと、ラベルを変換処理するためのtarget_transformという2つのパラメータがあります。

さらに、変換ロジックを記載した callable を受け取ります。

torchvision.transformsモジュールは、一般的に頻繁に使用される変換を提供しています。

FashionMNISTデータセットの特徴量はPIL形式の画像であり、ラベルはint型です。

訓練では、正規化された特徴量テンソルと、ワンホットエンコーディングされたラベルテンソルが必要となります。

これらのデータを作るために、ToTensor と Lambdaを使用します。

In [ ]:
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Processing...
Done!
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

ToTensor()¶

ToTensorはPIL形式の画像、もしくはNumpyのndarrayを、FloatTensorに変換します。


加えて、画像の場合にはピクセルごとの値を [0., 1.]の範囲に変換します。

Lambda Transforms(やらなくてOK)¶

Lambda transformsは、ユーザーが定義した関数を実行するPython関数です。

本チュートリアルではint型のデータを、ワンホットエンコーディングしたテンソルへと変換しています。


最初に大きさ10のゼロテンソルを作成し(10はクラス数に対応)、scatter_ を用いて、ラベルyの値のindexのみ1のワンホットエンコーディングに変換しています。

In [ ]:
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

さらなる詳細¶

以下のページも参考ください。

  • torchvision.transforms API

以上。