菅間修正済み 2025/05/15
【原題】OPTIMIZING MODEL PARAMETERS
【原著】 Suraj Subramanian、Seth Juarez 、Cassie Breviu 、Dmitry Soshnikov、Ari Bornstein
【元URL】https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 小川 雄太郎
【日付】2021年03月20日
【チュトーリアル概要】
本チュートリアルでは、オプティマイザー(Optimizer)を使用した、パラメータの最適化(≒学習)について解説を行います。
モデルとデータを用意できたので続いてはモデルを訓練、検証することで、データに対してモデルのパラメータを最適化し、テストを行います。
モデルの訓練は反復的なプロセスとなります。
各イテレーション(エポックと呼ばれます)で、モデルは出力を計算し、損失を求めます。そして各パラメータについて損失に対する偏微分の値を求めます。
その後、勾配降下法に基づいてパラメータを最適化します。
この最適化プロセスの流れについては、以下の動画も参考にご覧ください。
入門シリーズの「2. データセットとデータローダー」および、「4. モデル構築」からコードを再利用します。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU()
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw Processing... Done!
/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:143.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
今回は、訓練用のハイパーパラメータとして以下の値を使用します。
learning_rate = 1e-3
batch_size = 64
epochs = 5
ハイパーパラメータを設定後、訓練で最適化のループを回すことで、モデルを最適化します。
最適化ループの1回のイテレーションは、エポック(epoch)と呼ばれます。
各エポックでは2種類のループから構成されます。
訓練ループ:データセットに対して訓練を実行し、パラメータを収束させます
検証 / テストループ:テストデータセットでモデルを評価し、性能が向上しているか確認します
訓練ループ内で使用される概念について、簡単に把握しておきましょう。
本チュートリアルの最後には、最適化ループの完全な実装を紹介します。
損失関数:Loss Function
データが与えられても、訓練されていないネットワークは正しい答えを出力しない可能性があります。
損失関数はモデルが推論した結果と、実際の正解との誤差の大きさを測定する関数です。訓練ではこの損失関数の値を小さくしていきます。
損失を計算するためには、入力データに対するモデルの推論結果を求め、その値と正解のラベルとの違いを比較します。
一般的な損失関数としては、回帰タスクではnn.MSELoss(Mean Square Error)、分類タスクではnn.NLLLoss(Negative Log Likelihood) が使用されます。
nn.CrossEntropyLossは、nn.LogSoftmax と nn.NLLLossを結合した損失関数となります。
モデルが出力するlogit値をnn.CrossEntropyLossに与えて正規化し、予測誤差を求めます。
# loss functionの初期化、定義
loss_fn = nn.CrossEntropyLoss()
最適化器:Optimizer
最適化は各訓練ステップにおいてモデルの誤差を小さくなるように、モデルパラメータを調整するプロセスです。
最適化アルゴリズム:Optimization algorithms
最適化アルゴリズムは、最適化プロセスの具体的な手続きです(本チュートリアルでは確率的勾配降下法:Stochastic Gradient Descentを使用します)。
最適化のロジックは全てoptimizerオブジェクト内に隠ぺいされます。
今回はSGD optimizerを使用します。ただし、最適化関数にはADAMやRMSPropなど、様々な種類があります。
詳細については、こちらを参照ください。
訓練したいモデルパラメータをoptimizerに登録し、合わせて学習率をハイパーパラメータとして渡すことで初期化を行います。
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
訓練ループ内で、最適化(optimization)は3つのステップから構成されます。
[1] optimizer.zero_grad()を実行し、モデルパラメータの勾配をリセットします。
勾配の計算は蓄積されていくので、毎イテレーション、明示的にリセットします。
[2] 続いて、loss.backwards()を実行し、バックプロパゲーションを実行します。
PyTorchは損失に対する各パラメータの偏微分の値(勾配)を求めます。
[3] 最後に、optimizer.step()を実行し、各パラメータの勾配を使用してパラメータの値を調整します。
最適化を実行するコードをループするtrain_loopと、テストデータに対してモデルの性能を評価するtest_loopを定義します。
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# 予測と損失の計算
pred = model(X)
loss = loss_fn(pred, y)
# バックプロパゲーション
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= size
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
損失関数とoptimizerを初期化し、それを train_loop と test_loop に渡します。
以下の実装例において、モデルの性能を向上させるために、epoch数は自由に変えてみてください。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1 ------------------------------- loss: 2.304568 [ 0/60000] loss: 2.301557 [ 6400/60000] loss: 2.291586 [12800/60000] loss: 2.284993 [19200/60000] loss: 2.272847 [25600/60000] loss: 2.264585 [32000/60000] loss: 2.278295 [38400/60000] loss: 2.260727 [44800/60000] loss: 2.262039 [51200/60000] loss: 2.237576 [57600/60000] Test Error: Accuracy: 34.4%, Avg loss: 0.035269 Epoch 2 ------------------------------- loss: 2.266718 [ 0/60000] loss: 2.257207 [ 6400/60000] loss: 2.242595 [12800/60000] loss: 2.240602 [19200/60000] loss: 2.190395 [25600/60000] loss: 2.192866 [32000/60000] loss: 2.221788 [38400/60000] loss: 2.184941 [44800/60000] loss: 2.187555 [51200/60000] loss: 2.159560 [57600/60000] Test Error: Accuracy: 36.9%, Avg loss: 0.033832 Epoch 3 ------------------------------- loss: 2.193260 [ 0/60000] loss: 2.161435 [ 6400/60000] loss: 2.118552 [12800/60000] loss: 2.133410 [19200/60000] loss: 2.033287 [25600/60000] loss: 2.051484 [32000/60000] loss: 2.107544 [38400/60000] loss: 2.034562 [44800/60000] loss: 2.046951 [51200/60000] loss: 2.031325 [57600/60000] Test Error: Accuracy: 38.3%, Avg loss: 0.031261 Epoch 4 ------------------------------- loss: 2.056879 [ 0/60000] loss: 1.990355 [ 6400/60000] loss: 1.911940 [12800/60000] loss: 1.963883 [19200/60000] loss: 1.805866 [25600/60000] loss: 1.858948 [32000/60000] loss: 1.958336 [38400/60000] loss: 1.850322 [44800/60000] loss: 1.877860 [51200/60000] loss: 1.900923 [57600/60000] Test Error: Accuracy: 46.7%, Avg loss: 0.028436 Epoch 5 ------------------------------- loss: 1.902853 [ 0/60000] loss: 1.806532 [ 6400/60000] loss: 1.703594 [12800/60000] loss: 1.804330 [19200/60000] loss: 1.610781 [25600/60000] loss: 1.699996 [32000/60000] loss: 1.834617 [38400/60000] loss: 1.711799 [44800/60000] loss: 1.740380 [51200/60000] loss: 1.810726 [57600/60000] Test Error: Accuracy: 53.5%, Avg loss: 0.026318 Epoch 6 ------------------------------- loss: 1.773794 [ 0/60000] loss: 1.668828 [ 6400/60000] loss: 1.544228 [12800/60000] loss: 1.693585 [19200/60000] loss: 1.476874 [25600/60000] loss: 1.586178 [32000/60000] loss: 1.739542 [38400/60000] loss: 1.612905 [44800/60000] loss: 1.635378 [51200/60000] loss: 1.740324 [57600/60000] Test Error: Accuracy: 56.8%, Avg loss: 0.024745 Epoch 7 ------------------------------- loss: 1.669946 [ 0/60000] loss: 1.569386 [ 6400/60000] loss: 1.422041 [12800/60000] loss: 1.609000 [19200/60000] loss: 1.379347 [25600/60000] loss: 1.500004 [32000/60000] loss: 1.665772 [38400/60000] loss: 1.538614 [44800/60000] loss: 1.554731 [51200/60000] loss: 1.684841 [57600/60000] Test Error: Accuracy: 58.1%, Avg loss: 0.023533 Epoch 8 ------------------------------- loss: 1.587800 [ 0/60000] loss: 1.496115 [ 6400/60000] loss: 1.325578 [12800/60000] loss: 1.539307 [19200/60000] loss: 1.307962 [25600/60000] loss: 1.434174 [32000/60000] loss: 1.609863 [38400/60000] loss: 1.483990 [44800/60000] loss: 1.491577 [51200/60000] loss: 1.642921 [57600/60000] Test Error: Accuracy: 59.1%, Avg loss: 0.022607 Epoch 9 ------------------------------- loss: 1.523828 [ 0/60000] loss: 1.441160 [ 6400/60000] loss: 1.250425 [12800/60000] loss: 1.483303 [19200/60000] loss: 1.255319 [25600/60000] loss: 1.382149 [32000/60000] loss: 1.566745 [38400/60000] loss: 1.441864 [44800/60000] loss: 1.443685 [51200/60000] loss: 1.610611 [57600/60000] Test Error: Accuracy: 59.6%, Avg loss: 0.021893 Epoch 10 ------------------------------- loss: 1.472739 [ 0/60000] loss: 1.398461 [ 6400/60000] loss: 1.192398 [12800/60000] loss: 1.438026 [19200/60000] loss: 1.215621 [25600/60000] loss: 1.340659 [32000/60000] loss: 1.531828 [38400/60000] loss: 1.409429 [44800/60000] loss: 1.405748 [51200/60000] loss: 1.584560 [57600/60000] Test Error: Accuracy: 60.3%, Avg loss: 0.021334 Done!
以上。