菅間修正済み 2025/05/15
【原題】SAVE AND LOAD THE MODEL
【原著】 Suraj Subramanian、Seth Juarez 、Cassie Breviu 、Dmitry Soshnikov、Ari Bornstein
【元URL】https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html
【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 小川 雄太郎
【日付】2021年03月20日
【チュトーリアル概要】
本チュートリアルでは、PyTorchでモデルを保存する方法、および保存したモデルのロードについて解説します。
本チュートリアルでは、モデルの状態を継続させるために、モデルの保存する方法とモデルを読み込み推論を実行する方法について解説します。
import torch
import torch.onnx as onnx
import torchvision.models as models
PyTorchのモデルは学習したパラメータを内部に状態辞書(state_dict)として保持しています。
これらのパラメータの値は torch.save を使用することで、永続化させることができます。
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))
モデルの重みを読み込むためには、予め同じモデルの形をしたインスタンスを用意します。
そしてそのインスタンスに対してload_state_dict()メソッドを使用し、パラメータの値を読み込みます。
model = models.vgg16() # pretrained=Trueを引数に入れていないので、デフォルトのランダムな値になっています
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
【注意】
ドロップアウトやバッチノーマライゼーションレイヤーをevaluationモードに切り替えるために、推論前には model.eval()を実行することを忘れないようにしてください。
これを忘れると、推論結果が正確ではなくなります。
モデルの重みをロードする場合は、先にモデルのインスタンスを用意する必要があります。
モデルクラスの構造も一緒に保存したい場合もあるかと思います。
その際は保存時に、model.state_dict()ではなくmodelを渡します。
torch.save(model, 'model.pth')
モデルをロードするには、以下のように記載します。
model = torch.load('model.pth')
【注意】
上記の方法はPythonのpickleモジュールをモデルのシリアライズに使用します。
そのため、モデルのロード時に実際のクラス定義が利用可能である必要があります。
【日本語訳注】
上記の表現は理解が少し難しいのですが、言いたいことは、モデルのモジュールに独自クラスを定義して使用している場合、torch.loadを実行する前に、その独自クラスをimportするか宣言するかして、使用可能な状態にしておく必要があります、という意味です。
でないと、load時に不明なクラスを使用することになり読み込みエラーとなります。
PyTorchはONNX形式でのモデル出力もサポートしています。
しかしPyTorchの計算グラフは動的に生成されるため、出力処理では計算グラフを一度実行して作成してから、ONNXモデルを生成する必要があります。
すなわち、実際に一度データを流してみる必要があります。
そのため、テスト用の適切なテンソルサイズの入力変数を用意し、モデル出力の処理に渡す必要があります。
以下ではダミーのゼロテンソルを適切なサイズで作成して使用しています。
input_image = torch.zeros((1,3,224,224))
onnx.export(model, input_image, 'model.onnx')
# 日本語訳注:このセルを実行するとmodel.onnxというファイルが生成されます
ONNXモデルを使用することで、異なるプラットフォームや異なるプログラミング言語でディープラーニングモデルの推論を実行させるなど、様々なことが可能です。
さらなる詳細については、こちらのONNX tutorialをご覧ください。
おつかれまさです! これでPyTorch beginner tutorialは完了です。
再度目次ページを見たり、次の「8. クイックスタート」を見て、内容を振り返ってみてください。
本チュートリアルシリーズが、PyTorchでディープラーニングを始める際のお役に立てれば幸いです。幸運を祈ります。
以上。