「画像分類タスクに対する転移学習の方法」¶

菅間修正済み 2025/05/15

【原題】TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL

【原著】asank Chilamkurthy

【元URL】https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

【翻訳】電通国際情報サービスISID AIトランスフォーメーションセンター 御手洗 拓真

【日付】2020年11月16日

【チュトーリアル概要】

CNNを使用した画像分類モデルに対して、転移学習を実施する方法を解説します。


本チュートリアルでは、画像分類用の畳み込みニューラルネットワーク(以下、ConvNetと記載)に対して、転移学習を使用して訓練する方法を学びます。

転移学習自体の詳細については、cs231nコースのメモをご参照ください。

(日本語訳注:スタンフォード大学が公開している画像認識をテーマとしたコースの資料)

以下に、上記メモの内容を引用します。

    実際には、(ランダムな初期値を使用して)畳み込みネットワーク全体をゼロから訓練するケースは非常に稀です。
    なぜなら、ネットワークをゼロから訓練するために十分なサイズのデータセットを用意できるケースがほとんど無いからです。
    ゼロからネットワークを訓練する代わりに、一般的には次のような対応をします。
    まず、非常に大規模なデータセット(例えばImageNetなどは120万枚の画像を1000のカテゴリに分けて収録しています)でConvNetを事前学習します。
    そしてこの訓練済みのConvNetを初期値、もしくは特徴量抽出器として実際のタスクで活用します。

上記の引用で紹介されていた転移学習の活用方法は、以下の2通りとなります。

  • ConvNetをファインチューニングする :
    ランダムな値の代わりに訓練済みのパラメータを、訓練するネットワークの初期値として利用します。
    例えば、imagenet1000 dataset で訓練したネットワークをこの用途に使うことができます。
    訓練済みのパラメータを初期値として使う点以外は、通常通りにネットワークを訓練します。

  • ConvNetを特徴量抽出器として使う :
    最後の全結合層を除いて訓練済みネットワークの重みを固定します。
    次に最後の全結合層のみをランダムな重みを持つ新たなものに置き換えます。
    そして、この最終層だけを訓練します。

In [ ]:
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy


データの読み込み¶

データの読み込みにはtorchvisionとtorch.utils.dataパッケージを使用します。


本チュートリアルでは、課題例としてアリとハチの画像を分類するモデルを訓練します。

アリとハチ、それぞれについて約120枚の訓練用画像があり、各クラスの評価用画像は75枚あります。

通常、このデータセットの画像枚数はゼロからConvNetを訓練してモデルの性能を汎化させるには不十分です。

しかし今回は転移学習(ファインチューニング)を使用するため、この量のデータセットでも効率的にモデルを汎化させることができると考えられます。

なお、このデータセットはImageNetデータセットのごく一部(サブセット)です。

注意:

データはこちらからダウンロードして、カレントディレクトリに解凍してください。


(日本語訳注:日本語版チュートリアルでは、データセットをダウンロードするコードを以下に実装しています)


In [ ]:

まず,ターミナルで以下のコマンドを実行してデータをダウンロード・解凍

wget -P ./data https://download.pytorch.org/tutorial/hymenoptera_data.zip

unzip ./data/hymenoptera_data.zip -d ./data

--2020-12-10 05:17:43--  https://download.pytorch.org/tutorial/hymenoptera_data.zip
Resolving download.pytorch.org (download.pytorch.org)... 13.32.204.93, 13.32.204.65, 13.32.204.34, ...
Connecting to download.pytorch.org (download.pytorch.org)|13.32.204.93|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 47286322 (45M) [application/zip]
Saving to: ‘./data/hymenoptera_data.zip’

hymenoptera_data.zi 100%[===================>]  45.10M   119MB/s    in 0.4s    

2020-12-10 05:17:44 (119 MB/s) - ‘./data/hymenoptera_data.zip’ saved [47286322/47286322]

Archive:  ./data/hymenoptera_data.zip
   creating: ./data/hymenoptera_data/
   creating: ./data/hymenoptera_data/train/
   creating: ./data/hymenoptera_data/train/ants/
  inflating: ./data/hymenoptera_data/train/ants/0013035.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1095476100_3906d8afde.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1099452230_d1949d3250.jpg  
  inflating: ./data/hymenoptera_data/train/ants/116570827_e9c126745d.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1225872729_6f0856588f.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1262877379_64fcada201.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1269756697_0bce92cdab.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1286984635_5119e80de1.jpg  
  inflating: ./data/hymenoptera_data/train/ants/132478121_2a430adea2.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1360291657_dc248c5eea.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1368913450_e146e2fb6d.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1473187633_63ccaacea6.jpg  
  inflating: ./data/hymenoptera_data/train/ants/148715752_302c84f5a4.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1489674356_09d48dde0a.jpg  
  inflating: ./data/hymenoptera_data/train/ants/149244013_c529578289.jpg  
  inflating: ./data/hymenoptera_data/train/ants/150801003_3390b73135.jpg  
  inflating: ./data/hymenoptera_data/train/ants/150801171_cd86f17ed8.jpg  
  inflating: ./data/hymenoptera_data/train/ants/154124431_65460430f2.jpg  
  inflating: ./data/hymenoptera_data/train/ants/162603798_40b51f1654.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1660097129_384bf54490.jpg  
  inflating: ./data/hymenoptera_data/train/ants/167890289_dd5ba923f3.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1693954099_46d4c20605.jpg  
  inflating: ./data/hymenoptera_data/train/ants/175998972.jpg  
  inflating: ./data/hymenoptera_data/train/ants/178538489_bec7649292.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1804095607_0341701e1c.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1808777855_2a895621d7.jpg  
  inflating: ./data/hymenoptera_data/train/ants/188552436_605cc9b36b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1917341202_d00a7f9af5.jpg  
  inflating: ./data/hymenoptera_data/train/ants/1924473702_daa9aacdbe.jpg  
  inflating: ./data/hymenoptera_data/train/ants/196057951_63bf063b92.jpg  
  inflating: ./data/hymenoptera_data/train/ants/196757565_326437f5fe.jpg  
  inflating: ./data/hymenoptera_data/train/ants/201558278_fe4caecc76.jpg  
  inflating: ./data/hymenoptera_data/train/ants/201790779_527f4c0168.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2019439677_2db655d361.jpg  
  inflating: ./data/hymenoptera_data/train/ants/207947948_3ab29d7207.jpg  
  inflating: ./data/hymenoptera_data/train/ants/20935278_9190345f6b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/224655713_3956f7d39a.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2265824718_2c96f485da.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2265825502_fff99cfd2d.jpg  
  inflating: ./data/hymenoptera_data/train/ants/226951206_d6bf946504.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2278278459_6b99605e50.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2288450226_a6e96e8fdf.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg  
  inflating: ./data/hymenoptera_data/train/ants/2292213964_ca51ce4bef.jpg  
  inflating: ./data/hymenoptera_data/train/ants/24335309_c5ea483bb8.jpg  
  inflating: ./data/hymenoptera_data/train/ants/245647475_9523dfd13e.jpg  
  inflating: ./data/hymenoptera_data/train/ants/255434217_1b2b3fe0a4.jpg  
  inflating: ./data/hymenoptera_data/train/ants/258217966_d9d90d18d3.jpg  
  inflating: ./data/hymenoptera_data/train/ants/275429470_b2d7d9290b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/28847243_e79fe052cd.jpg  
  inflating: ./data/hymenoptera_data/train/ants/318052216_84dff3f98a.jpg  
  inflating: ./data/hymenoptera_data/train/ants/334167043_cbd1adaeb9.jpg  
  inflating: ./data/hymenoptera_data/train/ants/339670531_94b75ae47a.jpg  
  inflating: ./data/hymenoptera_data/train/ants/342438950_a3da61deab.jpg  
  inflating: ./data/hymenoptera_data/train/ants/36439863_0bec9f554f.jpg  
  inflating: ./data/hymenoptera_data/train/ants/374435068_7eee412ec4.jpg  
  inflating: ./data/hymenoptera_data/train/ants/382971067_0bfd33afe0.jpg  
  inflating: ./data/hymenoptera_data/train/ants/384191229_5779cf591b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/386190770_672743c9a7.jpg  
  inflating: ./data/hymenoptera_data/train/ants/392382602_1b7bed32fa.jpg  
  inflating: ./data/hymenoptera_data/train/ants/403746349_71384f5b58.jpg  
  inflating: ./data/hymenoptera_data/train/ants/408393566_b5b694119b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/424119020_6d57481dab.jpg  
  inflating: ./data/hymenoptera_data/train/ants/424873399_47658a91fb.jpg  
  inflating: ./data/hymenoptera_data/train/ants/450057712_771b3bfc91.jpg  
  inflating: ./data/hymenoptera_data/train/ants/45472593_bfd624f8dc.jpg  
  inflating: ./data/hymenoptera_data/train/ants/459694881_ac657d3187.jpg  
  inflating: ./data/hymenoptera_data/train/ants/460372577_f2f6a8c9fc.jpg  
  inflating: ./data/hymenoptera_data/train/ants/460874319_0a45ab4d05.jpg  
  inflating: ./data/hymenoptera_data/train/ants/466430434_4000737de9.jpg  
  inflating: ./data/hymenoptera_data/train/ants/470127037_513711fd21.jpg  
  inflating: ./data/hymenoptera_data/train/ants/474806473_ca6caab245.jpg  
  inflating: ./data/hymenoptera_data/train/ants/475961153_b8c13fd405.jpg  
  inflating: ./data/hymenoptera_data/train/ants/484293231_e53cfc0c89.jpg  
  inflating: ./data/hymenoptera_data/train/ants/49375974_e28ba6f17e.jpg  
  inflating: ./data/hymenoptera_data/train/ants/506249802_207cd979b4.jpg  
  inflating: ./data/hymenoptera_data/train/ants/506249836_717b73f540.jpg  
  inflating: ./data/hymenoptera_data/train/ants/512164029_c0a66b8498.jpg  
  inflating: ./data/hymenoptera_data/train/ants/512863248_43c8ce579b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/518773929_734dbc5ff4.jpg  
  inflating: ./data/hymenoptera_data/train/ants/522163566_fec115ca66.jpg  
  inflating: ./data/hymenoptera_data/train/ants/522415432_2218f34bf8.jpg  
  inflating: ./data/hymenoptera_data/train/ants/531979952_bde12b3bc0.jpg  
  inflating: ./data/hymenoptera_data/train/ants/533848102_70a85ad6dd.jpg  
  inflating: ./data/hymenoptera_data/train/ants/535522953_308353a07c.jpg  
  inflating: ./data/hymenoptera_data/train/ants/540889389_48bb588b21.jpg  
  inflating: ./data/hymenoptera_data/train/ants/541630764_dbd285d63c.jpg  
  inflating: ./data/hymenoptera_data/train/ants/543417860_b14237f569.jpg  
  inflating: ./data/hymenoptera_data/train/ants/560966032_988f4d7bc4.jpg  
  inflating: ./data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg  
  inflating: ./data/hymenoptera_data/train/ants/6240329_72c01e663e.jpg  
  inflating: ./data/hymenoptera_data/train/ants/6240338_93729615ec.jpg  
  inflating: ./data/hymenoptera_data/train/ants/649026570_e58656104b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/662541407_ff8db781e7.jpg  
  inflating: ./data/hymenoptera_data/train/ants/67270775_e9fdf77e9d.jpg  
  inflating: ./data/hymenoptera_data/train/ants/6743948_2b8c096dda.jpg  
  inflating: ./data/hymenoptera_data/train/ants/684133190_35b62c0c1d.jpg  
  inflating: ./data/hymenoptera_data/train/ants/69639610_95e0de17aa.jpg  
  inflating: ./data/hymenoptera_data/train/ants/707895295_009cf23188.jpg  
  inflating: ./data/hymenoptera_data/train/ants/7759525_1363d24e88.jpg  
  inflating: ./data/hymenoptera_data/train/ants/795000156_a9900a4a71.jpg  
  inflating: ./data/hymenoptera_data/train/ants/822537660_caf4ba5514.jpg  
  inflating: ./data/hymenoptera_data/train/ants/82852639_52b7f7f5e3.jpg  
  inflating: ./data/hymenoptera_data/train/ants/841049277_b28e58ad05.jpg  
  inflating: ./data/hymenoptera_data/train/ants/886401651_f878e888cd.jpg  
  inflating: ./data/hymenoptera_data/train/ants/892108839_f1aad4ca46.jpg  
  inflating: ./data/hymenoptera_data/train/ants/938946700_ca1c669085.jpg  
  inflating: ./data/hymenoptera_data/train/ants/957233405_25c1d1187b.jpg  
  inflating: ./data/hymenoptera_data/train/ants/9715481_b3cb4114ff.jpg  
  inflating: ./data/hymenoptera_data/train/ants/998118368_6ac1d91f81.jpg  
  inflating: ./data/hymenoptera_data/train/ants/ant photos.jpg  
  inflating: ./data/hymenoptera_data/train/ants/Ant_1.jpg  
  inflating: ./data/hymenoptera_data/train/ants/army-ants-red-picture.jpg  
  inflating: ./data/hymenoptera_data/train/ants/formica.jpeg  
  inflating: ./data/hymenoptera_data/train/ants/hormiga_co_por.jpg  
  inflating: ./data/hymenoptera_data/train/ants/imageNotFound.gif  
  inflating: ./data/hymenoptera_data/train/ants/kurokusa.jpg  
  inflating: ./data/hymenoptera_data/train/ants/MehdiabadiAnt2_600.jpg  
  inflating: ./data/hymenoptera_data/train/ants/Nepenthes_rafflesiana_ant.jpg  
  inflating: ./data/hymenoptera_data/train/ants/swiss-army-ant.jpg  
  inflating: ./data/hymenoptera_data/train/ants/termite-vs-ant.jpg  
  inflating: ./data/hymenoptera_data/train/ants/trap-jaw-ant-insect-bg.jpg  
  inflating: ./data/hymenoptera_data/train/ants/VietnameseAntMimicSpider.jpg  
   creating: ./data/hymenoptera_data/train/bees/
  inflating: ./data/hymenoptera_data/train/bees/1092977343_cb42b38d62.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1093831624_fb5fbe2308.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1097045929_1753d1c765.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1232245714_f862fbe385.jpg  
  inflating: ./data/hymenoptera_data/train/bees/129236073_0985e91c7d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1295655112_7813f37d21.jpg  
  inflating: ./data/hymenoptera_data/train/bees/132511197_0b86ad0fff.jpg  
  inflating: ./data/hymenoptera_data/train/bees/132826773_dbbcb117b9.jpg  
  inflating: ./data/hymenoptera_data/train/bees/150013791_969d9a968b.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1508176360_2972117c9d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/154600396_53e1252e52.jpg  
  inflating: ./data/hymenoptera_data/train/bees/16838648_415acd9e3f.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1691282715_0addfdf5e8.jpg  
  inflating: ./data/hymenoptera_data/train/bees/17209602_fe5a5a746f.jpg  
  inflating: ./data/hymenoptera_data/train/bees/174142798_e5ad6d76e0.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1799726602_8580867f71.jpg  
  inflating: ./data/hymenoptera_data/train/bees/1807583459_4fe92b3133.jpg  
  inflating: ./data/hymenoptera_data/train/bees/196430254_46bd129ae7.jpg  
  inflating: ./data/hymenoptera_data/train/bees/196658222_3fffd79c67.jpg  
  inflating: ./data/hymenoptera_data/train/bees/198508668_97d818b6c4.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2031225713_50ed499635.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2037437624_2d7bce461f.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2053200300_8911ef438a.jpg  
  inflating: ./data/hymenoptera_data/train/bees/205835650_e6f2614bee.jpg  
  inflating: ./data/hymenoptera_data/train/bees/208702903_42fb4d9748.jpg  
  inflating: ./data/hymenoptera_data/train/bees/21399619_3e61e5bb6f.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2227611847_ec72d40403.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2321139806_d73d899e66.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2330918208_8074770c20.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2345177635_caf07159b3.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2358061370_9daabbd9ac.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2364597044_3c3e3fc391.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2384149906_2cd8b0b699.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2397446847_04ef3cd3e1.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2405441001_b06c36fa72.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2445215254_51698ff797.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2452236943_255bfd9e58.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2467959963_a7831e9ff0.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2470492904_837e97800d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2477324698_3d4b1b1cab.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2477349551_e75c97cf4d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2486729079_62df0920be.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2486746709_c43cec0e42.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2493379287_4100e1dacc.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2495722465_879acf9d85.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2528444139_fa728b0f5b.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2538361678_9da84b77e3.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2551813042_8a070aeb2b.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2580598377_a4caecdb54.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2601176055_8464e6aa71.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2610833167_79bf0bcae5.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2610838525_fe8e3cae47.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2617161745_fa3ebe85b4.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2625499656_e3415e374d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2634617358_f32fd16bea.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2638074627_6b3ae746a0.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2645107662_b73a8595cc.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2651621464_a2fa8722eb.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2652877533_a564830cbf.jpg  
  inflating: ./data/hymenoptera_data/train/bees/266644509_d30bb16a1b.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2683605182_9d2a0c66cf.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2704348794_eb5d5178c2.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2707440199_cd170bd512.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2710368626_cb42882dc8.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2722592222_258d473e17.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2728759455_ce9bb8cd7a.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2756397428_1d82a08807.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2765347790_da6cf6cb40.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2781170484_5d61835d63.jpg  
  inflating: ./data/hymenoptera_data/train/bees/279113587_b4843db199.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2792000093_e8ae0718cf.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2801728106_833798c909.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2822388965_f6dca2a275.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2861002136_52c7c6f708.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2908916142_a7ac8b57a8.jpg  
  inflating: ./data/hymenoptera_data/train/bees/29494643_e3410f0d37.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2959730355_416a18c63c.jpg  
  inflating: ./data/hymenoptera_data/train/bees/2962405283_22718d9617.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3006264892_30e9cced70.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3030189811_01d095b793.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3030772428_8578335616.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3044402684_3853071a87.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3074585407_9854eb3153.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3079610310_ac2d0ae7bc.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3090975720_71f12e6de4.jpg  
  inflating: ./data/hymenoptera_data/train/bees/3100226504_c0d4f1e3f1.jpg  
  inflating: ./data/hymenoptera_data/train/bees/342758693_c56b89b6b6.jpg  
  inflating: ./data/hymenoptera_data/train/bees/354167719_22dca13752.jpg  
  inflating: ./data/hymenoptera_data/train/bees/359928878_b3b418c728.jpg  
  inflating: ./data/hymenoptera_data/train/bees/365759866_b15700c59b.jpg  
  inflating: ./data/hymenoptera_data/train/bees/36900412_92b81831ad.jpg  
  inflating: ./data/hymenoptera_data/train/bees/39672681_1302d204d1.jpg  
  inflating: ./data/hymenoptera_data/train/bees/39747887_42df2855ee.jpg  
  inflating: ./data/hymenoptera_data/train/bees/421515404_e87569fd8b.jpg  
  inflating: ./data/hymenoptera_data/train/bees/444532809_9e931e2279.jpg  
  inflating: ./data/hymenoptera_data/train/bees/446296270_d9e8b93ecf.jpg  
  inflating: ./data/hymenoptera_data/train/bees/452462677_7be43af8ff.jpg  
  inflating: ./data/hymenoptera_data/train/bees/452462695_40a4e5b559.jpg  
  inflating: ./data/hymenoptera_data/train/bees/457457145_5f86eb7e9c.jpg  
  inflating: ./data/hymenoptera_data/train/bees/465133211_80e0c27f60.jpg  
  inflating: ./data/hymenoptera_data/train/bees/469333327_358ba8fe8a.jpg  
  inflating: ./data/hymenoptera_data/train/bees/472288710_2abee16fa0.jpg  
  inflating: ./data/hymenoptera_data/train/bees/473618094_8ffdcab215.jpg  
  inflating: ./data/hymenoptera_data/train/bees/476347960_52edd72b06.jpg  
  inflating: ./data/hymenoptera_data/train/bees/478701318_bbd5e557b8.jpg  
  inflating: ./data/hymenoptera_data/train/bees/507288830_f46e8d4cb2.jpg  
  inflating: ./data/hymenoptera_data/train/bees/509247772_2db2d01374.jpg  
  inflating: ./data/hymenoptera_data/train/bees/513545352_fd3e7c7c5d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/522104315_5d3cb2758e.jpg  
  inflating: ./data/hymenoptera_data/train/bees/537309131_532bfa59ea.jpg  
  inflating: ./data/hymenoptera_data/train/bees/586041248_3032e277a9.jpg  
  inflating: ./data/hymenoptera_data/train/bees/760526046_547e8b381f.jpg  
  inflating: ./data/hymenoptera_data/train/bees/760568592_45a52c847f.jpg  
  inflating: ./data/hymenoptera_data/train/bees/774440991_63a4aa0cbe.jpg  
  inflating: ./data/hymenoptera_data/train/bees/85112639_6e860b0469.jpg  
  inflating: ./data/hymenoptera_data/train/bees/873076652_eb098dab2d.jpg  
  inflating: ./data/hymenoptera_data/train/bees/90179376_abc234e5f4.jpg  
  inflating: ./data/hymenoptera_data/train/bees/92663402_37f379e57a.jpg  
  inflating: ./data/hymenoptera_data/train/bees/95238259_98470c5b10.jpg  
  inflating: ./data/hymenoptera_data/train/bees/969455125_58c797ef17.jpg  
  inflating: ./data/hymenoptera_data/train/bees/98391118_bdb1e80cce.jpg  
   creating: ./data/hymenoptera_data/val/
   creating: ./data/hymenoptera_data/val/ants/
  inflating: ./data/hymenoptera_data/val/ants/10308379_1b6c72e180.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1053149811_f62a3410d3.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1073564163_225a64f170.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1119630822_cd325ea21a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1124525276_816a07c17f.jpg  
  inflating: ./data/hymenoptera_data/val/ants/11381045_b352a47d8c.jpg  
  inflating: ./data/hymenoptera_data/val/ants/119785936_dd428e40c3.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1247887232_edcb61246c.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1262751255_c56c042b7b.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1337725712_2eb53cd742.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1358854066_5ad8015f7f.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1440002809_b268d9a66a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/147542264_79506478c2.jpg  
  inflating: ./data/hymenoptera_data/val/ants/152286280_411648ec27.jpg  
  inflating: ./data/hymenoptera_data/val/ants/153320619_2aeb5fa0ee.jpg  
  inflating: ./data/hymenoptera_data/val/ants/153783656_85f9c3ac70.jpg  
  inflating: ./data/hymenoptera_data/val/ants/157401988_d0564a9d02.jpg  
  inflating: ./data/hymenoptera_data/val/ants/159515240_d5981e20d1.jpg  
  inflating: ./data/hymenoptera_data/val/ants/161076144_124db762d6.jpg  
  inflating: ./data/hymenoptera_data/val/ants/161292361_c16e0bf57a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/170652283_ecdaff5d1a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/17081114_79b9a27724.jpg  
  inflating: ./data/hymenoptera_data/val/ants/172772109_d0a8e15fb0.jpg  
  inflating: ./data/hymenoptera_data/val/ants/1743840368_b5ccda82b7.jpg  
  inflating: ./data/hymenoptera_data/val/ants/181942028_961261ef48.jpg  
  inflating: ./data/hymenoptera_data/val/ants/183260961_64ab754c97.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2039585088_c6f47c592e.jpg  
  inflating: ./data/hymenoptera_data/val/ants/205398178_c395c5e460.jpg  
  inflating: ./data/hymenoptera_data/val/ants/208072188_f293096296.jpg  
  inflating: ./data/hymenoptera_data/val/ants/209615353_eeb38ba204.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2104709400_8831b4fc6f.jpg  
  inflating: ./data/hymenoptera_data/val/ants/212100470_b485e7b7b9.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2127908701_d49dc83c97.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2191997003_379df31291.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2211974567_ee4606b493.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2219621907_47bc7cc6b0.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2238242353_52c82441df.jpg  
  inflating: ./data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg  
  inflating: ./data/hymenoptera_data/val/ants/239161491_86ac23b0a3.jpg  
  inflating: ./data/hymenoptera_data/val/ants/263615709_cfb28f6b8e.jpg  
  inflating: ./data/hymenoptera_data/val/ants/308196310_1db5ffa01b.jpg  
  inflating: ./data/hymenoptera_data/val/ants/319494379_648fb5a1c6.jpg  
  inflating: ./data/hymenoptera_data/val/ants/35558229_1fa4608a7a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/412436937_4c2378efc2.jpg  
  inflating: ./data/hymenoptera_data/val/ants/436944325_d4925a38c7.jpg  
  inflating: ./data/hymenoptera_data/val/ants/445356866_6cb3289067.jpg  
  inflating: ./data/hymenoptera_data/val/ants/459442412_412fecf3fe.jpg  
  inflating: ./data/hymenoptera_data/val/ants/470127071_8b8ee2bd74.jpg  
  inflating: ./data/hymenoptera_data/val/ants/477437164_bc3e6e594a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/488272201_c5aa281348.jpg  
  inflating: ./data/hymenoptera_data/val/ants/502717153_3e4865621a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/518746016_bcc28f8b5b.jpg  
  inflating: ./data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg  
  inflating: ./data/hymenoptera_data/val/ants/562589509_7e55469b97.jpg  
  inflating: ./data/hymenoptera_data/val/ants/57264437_a19006872f.jpg  
  inflating: ./data/hymenoptera_data/val/ants/573151833_ebbc274b77.jpg  
  inflating: ./data/hymenoptera_data/val/ants/649407494_9b6bc4949f.jpg  
  inflating: ./data/hymenoptera_data/val/ants/751649788_78dd7d16ce.jpg  
  inflating: ./data/hymenoptera_data/val/ants/768870506_8f115d3d37.jpg  
  inflating: ./data/hymenoptera_data/val/ants/800px-Meat_eater_ant_qeen_excavating_hole.jpg  
  inflating: ./data/hymenoptera_data/val/ants/8124241_36b290d372.jpg  
  inflating: ./data/hymenoptera_data/val/ants/8398478_50ef10c47a.jpg  
  inflating: ./data/hymenoptera_data/val/ants/854534770_31f6156383.jpg  
  inflating: ./data/hymenoptera_data/val/ants/892676922_4ab37dce07.jpg  
  inflating: ./data/hymenoptera_data/val/ants/94999827_36895faade.jpg  
  inflating: ./data/hymenoptera_data/val/ants/Ant-1818.jpg  
  inflating: ./data/hymenoptera_data/val/ants/ants-devouring-remains-of-large-dead-insect-on-red-tile-in-Stellenbosch-South-Africa-closeup-1-DHD.jpg  
  inflating: ./data/hymenoptera_data/val/ants/desert_ant.jpg  
  inflating: ./data/hymenoptera_data/val/ants/F.pergan.28(f).jpg  
  inflating: ./data/hymenoptera_data/val/ants/Hormiga.jpg  
   creating: ./data/hymenoptera_data/val/bees/
  inflating: ./data/hymenoptera_data/val/bees/1032546534_06907fe3b3.jpg  
  inflating: ./data/hymenoptera_data/val/bees/10870992_eebeeb3a12.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1181173278_23c36fac71.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1297972485_33266a18d9.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1328423762_f7a88a8451.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1355974687_1341c1face.jpg  
  inflating: ./data/hymenoptera_data/val/bees/144098310_a4176fd54d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1486120850_490388f84b.jpg  
  inflating: ./data/hymenoptera_data/val/bees/149973093_da3c446268.jpg  
  inflating: ./data/hymenoptera_data/val/bees/151594775_ee7dc17b60.jpg  
  inflating: ./data/hymenoptera_data/val/bees/151603988_2c6f7d14c7.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1519368889_4270261ee3.jpg  
  inflating: ./data/hymenoptera_data/val/bees/152789693_220b003452.jpg  
  inflating: ./data/hymenoptera_data/val/bees/177677657_a38c97e572.jpg  
  inflating: ./data/hymenoptera_data/val/bees/1799729694_0c40101071.jpg  
  inflating: ./data/hymenoptera_data/val/bees/181171681_c5a1a82ded.jpg  
  inflating: ./data/hymenoptera_data/val/bees/187130242_4593a4c610.jpg  
  inflating: ./data/hymenoptera_data/val/bees/203868383_0fcbb48278.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2060668999_e11edb10d0.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2086294791_6f3789d8a6.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2103637821_8d26ee6b90.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2104135106_a65eede1de.jpg  
  inflating: ./data/hymenoptera_data/val/bees/215512424_687e1e0821.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2173503984_9c6aaaa7e2.jpg  
  inflating: ./data/hymenoptera_data/val/bees/220376539_20567395d8.jpg  
  inflating: ./data/hymenoptera_data/val/bees/224841383_d050f5f510.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2321144482_f3785ba7b2.jpg  
  inflating: ./data/hymenoptera_data/val/bees/238161922_55fa9a76ae.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2407809945_fb525ef54d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2415414155_1916f03b42.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2438480600_40a1249879.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2444778727_4b781ac424.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2457841282_7867f16639.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2470492902_3572c90f75.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2501530886_e20952b97d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2506114833_90a41c5267.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2509402554_31821cb0b6.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2525379273_dcb26a516d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/26589803_5ba7000313.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2668391343_45e272cd07.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2670536155_c170f49cd0.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2685605303_9eed79d59d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2702408468_d9ed795f4f.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2709775832_85b4b50a57.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2717418782_bd83307d9f.jpg  
  inflating: ./data/hymenoptera_data/val/bees/272986700_d4d4bf8c4b.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2741763055_9a7bb00802.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2745389517_250a397f31.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2751836205_6f7b5eff30.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2782079948_8d4e94a826.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2809496124_5f25b5946a.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2815838190_0a9889d995.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2841437312_789699c740.jpg  
  inflating: ./data/hymenoptera_data/val/bees/2883093452_7e3a1eb53f.jpg  
  inflating: ./data/hymenoptera_data/val/bees/290082189_f66cb80bfc.jpg  
  inflating: ./data/hymenoptera_data/val/bees/296565463_d07a7bed96.jpg  
  inflating: ./data/hymenoptera_data/val/bees/3077452620_548c79fda0.jpg  
  inflating: ./data/hymenoptera_data/val/bees/348291597_ee836fbb1a.jpg  
  inflating: ./data/hymenoptera_data/val/bees/350436573_41f4ecb6c8.jpg  
  inflating: ./data/hymenoptera_data/val/bees/353266603_d3eac7e9a0.jpg  
  inflating: ./data/hymenoptera_data/val/bees/372228424_16da1f8884.jpg  
  inflating: ./data/hymenoptera_data/val/bees/400262091_701c00031c.jpg  
  inflating: ./data/hymenoptera_data/val/bees/416144384_961c326481.jpg  
  inflating: ./data/hymenoptera_data/val/bees/44105569_16720a960c.jpg  
  inflating: ./data/hymenoptera_data/val/bees/456097971_860949c4fc.jpg  
  inflating: ./data/hymenoptera_data/val/bees/464594019_1b24a28bb1.jpg  
  inflating: ./data/hymenoptera_data/val/bees/485743562_d8cc6b8f73.jpg  
  inflating: ./data/hymenoptera_data/val/bees/540976476_844950623f.jpg  
  inflating: ./data/hymenoptera_data/val/bees/54736755_c057723f64.jpg  
  inflating: ./data/hymenoptera_data/val/bees/57459255_752774f1b2.jpg  
  inflating: ./data/hymenoptera_data/val/bees/576452297_897023f002.jpg  
  inflating: ./data/hymenoptera_data/val/bees/586474709_ae436da045.jpg  
  inflating: ./data/hymenoptera_data/val/bees/590318879_68cf112861.jpg  
  inflating: ./data/hymenoptera_data/val/bees/59798110_2b6a3c8031.jpg  
  inflating: ./data/hymenoptera_data/val/bees/603709866_a97c7cfc72.jpg  
  inflating: ./data/hymenoptera_data/val/bees/603711658_4c8cd2201e.jpg  
  inflating: ./data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/6a00d8341c630a53ef00e553d0beb18834-800wi.jpg  
  inflating: ./data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg  
  inflating: ./data/hymenoptera_data/val/bees/759745145_e8bc776ec8.jpg  
  inflating: ./data/hymenoptera_data/val/bees/936182217_c4caa5222d.jpg  
  inflating: ./data/hymenoptera_data/val/bees/abeja.jpg  
Out[ ]:

In [ ]:
# 訓練データ用のデータ拡張と正規化
# 検証データ用には正規化のみ実施
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

画像の可視化¶

データ拡張(もしくはデータオーギュメンテーションと呼ぶ:Data Augmentation)を理解するために、いくつかの訓練画像を可視化してみましょう。

In [ ]:
def imshow(inp, title=None, filename='a.png'):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # プロット図が更新されるように少しだけ一時停止
    plt.savefig(filename)


# 訓練データのバッチを取得する
for inputs, classes in dataloaders['train']:
    break

# バッチからグリッドを作成する
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes], filename='data_samples.png')

モデルの訓練¶

それではモデルの訓練用に、一般的な関数実装しましょう。

以下のサンプルコードでは、次の内容を実装しています。

  • 学習率のスケジューリング
  • ベストモデルの保存

サンプルコードのschedulerという変数は、学習率をスケジュールするオブジェクト(学習率を特定のタイミングで変更する)です。

torch.optim.lr_schedulerクラスからインスタンス化して作られています。

In [ ]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 各エポックには訓練フェーズと検証フェーズがあります
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # モデルを訓練モードに設定します
            else:
                model.eval()   # モードを評価するモデルを設定します

            running_loss = 0.0
            running_corrects = 0

            # データをイレテートします
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # パラメータの勾配をゼロにします
                optimizer.zero_grad()

                # 順伝播
                # 訓練の時だけ、履歴を保持します
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # 訓練の時だけ逆伝播+オプティマイズを行います
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 損失を計算します
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # モデルをディープ・コピーします
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # ベストモデルの重みをロードします
    model.load_state_dict(best_model_wts)
    return model

モデル予測値の可視化¶

適当な数枚の画像に対する予測結果を表示する、汎用的な関数を実装します。

In [ ]:
def visualize_model(model, num_images=6, filename_base='im'):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))

                imshow(inputs.cpu().data[j], filename=filename_base + '.png')

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return


        model.train(mode=was_training)

ConvNetをファインチューニングする方法¶

訓練済みモデルをロードし、最後の全結合層を新しいものに置き換えます。

In [ ]:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# ここでは,各出力サンプルのサイズは2に設定されています
# なお、NN.Linear(num_ftrs, len(class_names))という書き方で一般化することもできます。
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# ネットワークのすべてのパラメータが最適化対象です
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 7エポックごとに学習率を1/10ずつ減衰させます
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))

訓練と評価¶

訓練と評価にかかる時間はCPU環境では15~25分くらいで、GPU環境では1分未満です。

(日本語訳注:本ノートブックはGoogle ColaboratoryのGPU設定で保存しています)

In [ ]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
Epoch 0/24
----------
train Loss: 0.6579 Acc: 0.6598
val Loss: 0.3647 Acc: 0.8693

Epoch 1/24
----------
train Loss: 0.4436 Acc: 0.8115
val Loss: 0.2602 Acc: 0.8889

Epoch 2/24
----------
train Loss: 0.3758 Acc: 0.8525
val Loss: 0.3663 Acc: 0.8627

Epoch 3/24
----------
train Loss: 0.4801 Acc: 0.7746
val Loss: 0.8678 Acc: 0.7190

Epoch 4/24
----------
train Loss: 0.4152 Acc: 0.8484
val Loss: 0.3685 Acc: 0.8758

Epoch 5/24
----------
train Loss: 0.5817 Acc: 0.7992
val Loss: 0.5917 Acc: 0.7974

Epoch 6/24
----------
train Loss: 0.4582 Acc: 0.8443
val Loss: 0.3075 Acc: 0.9150

Epoch 7/24
----------
train Loss: 0.3692 Acc: 0.8443
val Loss: 0.2574 Acc: 0.9216

Epoch 8/24
----------
train Loss: 0.3042 Acc: 0.8648
val Loss: 0.2448 Acc: 0.9085

Epoch 9/24
----------
train Loss: 0.2838 Acc: 0.8811
val Loss: 0.2392 Acc: 0.9281

Epoch 10/24
----------
train Loss: 0.2779 Acc: 0.8811
val Loss: 0.2316 Acc: 0.9216

Epoch 11/24
----------
train Loss: 0.4327 Acc: 0.8361
val Loss: 0.2830 Acc: 0.8954

Epoch 12/24
----------
train Loss: 0.3345 Acc: 0.8566
val Loss: 0.2187 Acc: 0.9150

Epoch 13/24
----------
train Loss: 0.3179 Acc: 0.8607
val Loss: 0.2028 Acc: 0.9216

Epoch 14/24
----------
train Loss: 0.3340 Acc: 0.8484
val Loss: 0.2109 Acc: 0.9281

Epoch 15/24
----------
train Loss: 0.3628 Acc: 0.8238
val Loss: 0.2308 Acc: 0.9085

Epoch 16/24
----------
train Loss: 0.2692 Acc: 0.8566
val Loss: 0.2209 Acc: 0.9216

Epoch 17/24
----------
train Loss: 0.3819 Acc: 0.8361
val Loss: 0.2053 Acc: 0.9085

Epoch 18/24
----------
train Loss: 0.2927 Acc: 0.8770
val Loss: 0.2768 Acc: 0.9020

Epoch 19/24
----------
train Loss: 0.2942 Acc: 0.8566
val Loss: 0.2131 Acc: 0.9216

Epoch 20/24
----------
train Loss: 0.2187 Acc: 0.9139
val Loss: 0.2271 Acc: 0.9150

Epoch 21/24
----------
train Loss: 0.2382 Acc: 0.8852
val Loss: 0.2285 Acc: 0.9216

Epoch 22/24
----------
train Loss: 0.2750 Acc: 0.8484
val Loss: 0.2348 Acc: 0.9281

Epoch 23/24
----------
train Loss: 0.2746 Acc: 0.8730
val Loss: 0.2103 Acc: 0.9216

Epoch 24/24
----------
train Loss: 0.2915 Acc: 0.8730
val Loss: 0.2610 Acc: 0.9020

Training complete in 1m 37s
Best val Acc: 0.928105
In [ ]:
visualize_model(model_ft, filename_base='a')

Conv Netを特徴抽出器として使う方法¶

今回のケースでは、最後の全結合層を除くすべてのネットワークのパラメータを固定します。

パラメータを固定するためには、requires_grad = False を設定する必要があります。

これによって、backward()で勾配が計算されないようになります。


勾配計算の詳細はこちらのドキュメントをご覧ください

In [ ]:
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# 新しく構築されたモジュールのパラメータは、デフォルトでは requires_grad=True になっています。
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

 # ファインチューニングとは違い、最終層のパラメータのみが最適化されていることを確認してください
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# 7エポックごとに学習率を0.1倍ずつ減衰させる
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

訓練と評価¶

ファインチューニングの場合に比べると、訓練と評価にかかる時間はCPUを使う場合でも約半分になります。

ほとんどのネットワークで勾配を計算する必要がないために、このような訓練時間の短縮を見込むことができます。

ただし、順伝播の計算が不要になるわけではありません。


(日本語訳注:上記説明では、勾配をrequires_grad = Falseにすると記載・説明されていますが、明示的にその操作をせず、代わりに、SGDでの訓練対象をmodel_conv.fc.parameters()として、最終層のみが訓練対象になるように設定されています。)

In [ ]:
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
Epoch 0/24
----------
train Loss: 0.5942 Acc: 0.7049
val Loss: 0.2298 Acc: 0.9281

Epoch 1/24
----------
train Loss: 0.3885 Acc: 0.8238
val Loss: 0.2044 Acc: 0.9346

Epoch 2/24
----------
train Loss: 0.3899 Acc: 0.8238
val Loss: 0.2545 Acc: 0.9020

Epoch 3/24
----------
train Loss: 0.5229 Acc: 0.7582
val Loss: 0.2526 Acc: 0.8954

Epoch 4/24
----------
train Loss: 0.4906 Acc: 0.8115
val Loss: 0.1795 Acc: 0.9477

Epoch 5/24
----------
train Loss: 0.3060 Acc: 0.8852
val Loss: 0.2763 Acc: 0.8824

Epoch 6/24
----------
train Loss: 0.4984 Acc: 0.7377
val Loss: 0.2357 Acc: 0.8954

Epoch 7/24
----------
train Loss: 0.4188 Acc: 0.8033
val Loss: 0.1752 Acc: 0.9412

Epoch 8/24
----------
train Loss: 0.3272 Acc: 0.8648
val Loss: 0.2000 Acc: 0.9150

Epoch 9/24
----------
train Loss: 0.2764 Acc: 0.8934
val Loss: 0.1667 Acc: 0.9608

Epoch 10/24
----------
train Loss: 0.3596 Acc: 0.8402
val Loss: 0.2018 Acc: 0.9216

Epoch 11/24
----------
train Loss: 0.3567 Acc: 0.8730
val Loss: 0.1905 Acc: 0.9150

Epoch 12/24
----------
train Loss: 0.3074 Acc: 0.8525
val Loss: 0.1652 Acc: 0.9542

Epoch 13/24
----------
train Loss: 0.4013 Acc: 0.8279
val Loss: 0.1610 Acc: 0.9542

Epoch 14/24
----------
train Loss: 0.2848 Acc: 0.8730
val Loss: 0.1622 Acc: 0.9477

Epoch 15/24
----------
train Loss: 0.3529 Acc: 0.8648
val Loss: 0.1757 Acc: 0.9477

Epoch 16/24
----------
train Loss: 0.3388 Acc: 0.8279
val Loss: 0.1842 Acc: 0.9542

Epoch 17/24
----------
train Loss: 0.3896 Acc: 0.8443
val Loss: 0.1590 Acc: 0.9608

Epoch 18/24
----------
train Loss: 0.4212 Acc: 0.8279
val Loss: 0.2018 Acc: 0.9216

Epoch 19/24
----------
train Loss: 0.3185 Acc: 0.8402
val Loss: 0.1690 Acc: 0.9412

Epoch 20/24
----------
train Loss: 0.3104 Acc: 0.8525
val Loss: 0.1745 Acc: 0.9542

Epoch 21/24
----------
train Loss: 0.3066 Acc: 0.8607
val Loss: 0.1720 Acc: 0.9542

Epoch 22/24
----------
train Loss: 0.3096 Acc: 0.8648
val Loss: 0.1673 Acc: 0.9542

Epoch 23/24
----------
train Loss: 0.3565 Acc: 0.8238
val Loss: 0.1631 Acc: 0.9542

Epoch 24/24
----------
train Loss: 0.3169 Acc: 0.8689
val Loss: 0.1737 Acc: 0.9608

Training complete in 1m 23s
Best val Acc: 0.960784
In [ ]:
visualize_model(model_conv, filename_base='b')

さらなる学習のために¶

転移学習の応用についてさらに学びたい方は、画像認識のための量子を用いた転移学習チュートリアル(Quantized Transfer Learning for Computer Vision Tutorial)もチェックしてみてください。