「Pytorch:ClassImbalance」の版間の差分
| 30行目: | 30行目: | ||
ここで,<math> \gamma \ge 0 </math>はハイパーパラメータです.<math> \gamma = 0 </math>の時は,Cross Entropy Lossと等価になります. | ここで,<math> \gamma \ge 0 </math>はハイパーパラメータです.<math> \gamma = 0 </math>の時は,Cross Entropy Lossと等価になります. | ||
この<math> \gamma </math>がどのような意味を持つかを考えてみます.[https://qiita.com/agatan/items/53fe8d21f2147b0ac982 このサイト]の2つ目の図から,<math> \gamma </math>の値を変化させた場合に,損失関数の形状がどのように変化するか,を確認できます.<math> \gamma </math> | この<math> \gamma </math>がどのような意味を持つかを考えてみます.[https://qiita.com/agatan/items/53fe8d21f2147b0ac982 このサイト]の2つ目の図から,<math> \gamma </math>の値を変化させた場合に,損失関数の形状がどのように変化するか,を確認できます.<math> \gamma </math>の値が大きくなると,0付近ではより勾配が大きく,1付近ではより勾配が小さくなっていることがわかります. | ||
例えば,クラス<math>c'</math>に属するあるデータについて,<math> \gamma = 5 </math> かつ <math>y_{ c' } = 0.8</math>の時,損失の値は既に十分に小さく,その付近の勾配も0に近い値になります.よって,そのデータはほとんど学習に寄与しません.一方で,<math>y_{ c' } = 0.1</math>付近では,より大きな勾配を持つことがわかります. | |||
このように,<math>y_{ c' }</math>の値が1に近いとき,すなわちモデルがそのデータについて十分な学習が既にできているとき,そのデータによる学習への寄与は小さくなります.反対に,<math>y_{ c' }</math>の値が0に近いとき,すなわちモデルがそのデータについて十分に学習していないとき,そのデータはより強く学習に寄与します.これは,<math>\gamma</math>の値に関係なく,一貫してみられる傾向ですが,<math>\gamma</math>の値が大きいほど,学習が十分でないデータの寄与度合いが相対的に大きくなることがわかります. | |||
インバランスデータでは,多数派のクラスのデータに対しては学習が進みやすく,少数派のクラスのデータに対しては,学習が進みにくい,と考えるのは自然です.よって,少数派により大きな重みを与えることができるFocal Lossが有効に働くと考えられます. | |||
= 演習 = | = 演習 = | ||
2025年3月6日 (木) 13:45時点における版
クラスインバランス
学習データセット内で,クラスごとのデータ数に偏りがある状態のことを,クラスインバランスと言います.今回は,インバランスなデータでも比較的精度良く学習させる方法について学びます.
なぜクラスインバランスが問題か?
極端な例で考えてみることにします.クラスAとBがあり,学習用データセットにおけるクラスAのデータ数は990であり,クラスBのデータ数は10であるとします.識別モデルの気持ちになって考えてみると,どのような入力データに対しても,それがクラスAであると推定しておけば,正解率は99%になります.よって,この学習用データセットの中だけで考えれば,どんな入力であってもクラスAであると推定することは,悪くない戦略であると言えます.しかし,常にクラスAが出力されるモデルであれば,実運用時には何の役にも立ちません.
本質的に同じ問題は,上記のような場面以外にも,色々なところで現れます.例えば,テニスの試合の映像で,ボールの位置をヒートマップにより推定することを考えてみましょう(参考:TrackNet).ボールが存在する領域は,画像全体に対してごくわずかです.ボールの位置を推定するモデルの気持ちになって考えてみると,ほとんどの領域にはボールは存在しないわけですから,全ての領域が真っ黒のヒートマップを出力する(言い換えると,どこにもボールは存在しないと主張する)ことは,合理的な戦略に思えます.しかし,これではボールを検出することはできません.
対策案
ここでは識別問題に絞って考えてみます.クラスインバランスへの対策として,よく使われるのは次の3つの方法です.
- データが豊富なクラスのみで学習
- データ拡張
- 各データへの重みづけ
上記1は,異常検知の問題で使用される方法であり,2クラス分類の問題に対して使える方法です.(多クラス分類に拡張できるのかどうかはわかりません.)異常検知の問題においては,正常データは多数得られるが,異常データは滅多に得られない,という想定がなされます.これは,クラスAには大量の正常データ,クラスBには僅かな異常データ,という2クラス分類問題そのものです.
しかし,単純に識別問題として解けば良いかというと,そう単純でもありません.これまでに見たことがない新しいタイプの異常が,今後現れる可能性があるからです.(犯罪者は,これまでに無かった新しい方法で,セキュリティシステムを欺こうとします.)よって,異常検知の分野では,A(正常)とB(異常)の既知データの間で線を引く問題(識別問題)ではなく,Aのデータだけを用いて学習を行い,Aの分布の内か外かという基準で,正常・異常を判断する問題が扱われます.
上記2は,データに対する摂動を加えたり,GANや拡散モデルなどの生成AIを用いて生成した画像を学習データに加えることで,数が少ないクラスのデータを拡張する方法です.
上記3は,数が少ないクラスのデータに対して大きな重みを与えることで,クラス間のバランスを取ろう,という方法です. 今回は,この方法を実践します.
Focal loss
分類問題でよく使われる損失関数は,Cross Entropy Lossです.これは,モデルの出力とGround Truthをそれぞれおよび,クラス数をとすると,次の式で書けます.
一方,Focal Lossは,インバランスなデータを扱う際に使われる損失関数であり,次の式で計算します.
ここで,はハイパーパラメータです.の時は,Cross Entropy Lossと等価になります.
このがどのような意味を持つかを考えてみます.このサイトの2つ目の図から,の値を変化させた場合に,損失関数の形状がどのように変化するか,を確認できます.の値が大きくなると,0付近ではより勾配が大きく,1付近ではより勾配が小さくなっていることがわかります. 例えば,クラスに属するあるデータについて, かつ の時,損失の値は既に十分に小さく,その付近の勾配も0に近い値になります.よって,そのデータはほとんど学習に寄与しません.一方で,付近では,より大きな勾配を持つことがわかります.
このように,の値が1に近いとき,すなわちモデルがそのデータについて十分な学習が既にできているとき,そのデータによる学習への寄与は小さくなります.反対に,の値が0に近いとき,すなわちモデルがそのデータについて十分に学習していないとき,そのデータはより強く学習に寄与します.これは,の値に関係なく,一貫してみられる傾向ですが,の値が大きいほど,学習が十分でないデータの寄与度合いが相対的に大きくなることがわかります.
インバランスデータでは,多数派のクラスのデータに対しては学習が進みやすく,少数派のクラスのデータに対しては,学習が進みにくい,と考えるのは自然です.よって,少数派により大きな重みを与えることができるFocal Lossが有効に働くと考えられます.
演習
まず,作業用ディレクトリに移動した後に,以下のコマンドを実行して,プログラムとデータをダウンロードしてください.
wget https://vrl.sys.wakayama-u.ac.jp/class/pytorch_tutorial/exersise_classimbalance/exersise_classimbalance.py wget -P ./data -r https://vrl.sys.wakayama-u.ac.jp/class/pytorch_tutorial/datasets/cifar-10-batches-py/