菅間修正済み 2025/05/15
【原題】TRANSFORMS
【原著】 Suraj Subramanian、Seth Juarez 、Cassie Breviu 、Dmitry Soshnikov、Ari Bornstein
【元URL】https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html
【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 小川 雄太郎
【日付】2021年03月20日
【チュトーリアル概要】
本チュートリアルでは、PyTorchのデータ変換処理であるTransformについて解説を行います。
機械学習アルゴリズムの学習に必要な、最終的な処理が施された形でデータが手に入るとは限りません。
そこでtransformを使用してデータに何らかの処理を行い、学習に適した形へと変換します。
TorchVisionの全データセットには、特徴量(データ)を変換処理するためのtransformと、ラベルを変換処理するためのtarget_transformという2つのパラメータがあります。
さらに、変換ロジックを記載した callable を受け取ります。
torchvision.transformsモジュールは、一般的に頻繁に使用される変換を提供しています。
FashionMNISTデータセットの特徴量はPIL形式の画像であり、ラベルはint型です。
訓練では、正規化された特徴量テンソルと、ワンホットエンコーディングされたラベルテンソルが必要となります。
これらのデータを作るために、ToTensor と Lambdaを使用します。
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw Processing... Done!
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
以上。