「Pytorch:ClassImbalance」の版間の差分

提供:classwiki
ナビゲーションに移動 検索に移動
107行目: 107行目:
         return loss
         return loss


ここで,yとtはそれぞれ,モデルの出力とGround truthです.また,torch.nn.functional.cross_entropy() は,Pytorchで用意してくれているクロスエントロピー損失を計算する関数です.この部分を,Pytorchの標準ライブラリに頼らない方法で書き直していきます.
ここで,yとtはそれぞれ,モデルの出力とGround truthです.また,torch.nn.functional.cross_entropy() は,Pytorchで用意してくれているクロスエントロピー損失を計算する関数です.この部分を,Pytorchのビルトイン関数に頼らない方法で書き直していきます.


注意:forward(self, y, t)のyはモデルが出力したものですが,ソフトマックス関数を適用する前のものです.また,tは正解クラスラベルを表すスカラであり,one-hotベクトルではありません.よって,torch.nn.functional.cross_entropy(y,t)では,実際には
注意:forward(self, y, t)のyはモデルが出力したものですが,ソフトマックス関数を適用する前のものです.また,tは正解クラスラベルを表すスカラであり,one-hotベクトルではありません.よって,torch.nn.functional.cross_entropy(y,t)では,実際には

2025年3月8日 (土) 13:53時点における版

クラスインバランス

学習データセット内で,クラスごとのデータ数に偏りがある状態のことを,クラスインバランスと表現します.今回は,インバランスなデータでも比較的精度良く学習させる方法について学びます.

なぜクラスインバランスが問題か?

極端な例で考えてみることにします.クラスAとBがあり,学習用データセットにおけるクラスAのデータ数は990であり,クラスBのデータ数は10であるとします.識別モデルの気持ちになって考えてみると,どのような入力データに対しても,それがクラスAであると推定しておけば,正解率は99%になります.よって,この学習用データセットの中だけで考えれば,どんな入力であってもクラスAであると推定することは,悪くない戦略であると言えます.しかし,常にクラスAが出力されるモデルであれば,実運用時には何の役にも立ちません.

本質的に同じ問題は,上記のような場面以外にも,色々なところで現れます.例えば,テニスの試合の映像で,ボールの位置をヒートマップにより推定することを考えてみましょう(参考:TrackNet).ボールが存在する領域は,画像全体に対してごくわずかです.ボールの位置を推定するモデルの気持ちになって考えてみると,ほとんどの領域にはボールは存在しないわけですから,全ての領域が真っ黒のヒートマップを出力する(言い換えると,どこにもボールは存在しないと主張する)ことは,合理的な戦略に思えます.しかし,これではボールを検出することはできません.

対策案

ここでは識別問題に絞って考えてみます.クラスインバランスへの対策として,よく使われるのは次の3つの方法です.

  1. データが豊富なクラスのみで学習
  2. データ拡張
  3. 各データへの重みづけ

上記1は,異常検知の問題で使用される方法であり,2クラス分類の問題に対して使える方法です.(多クラス分類に拡張できるのかどうかはわかりません.)異常検知の問題においては,正常データは多数得られるが,異常データは滅多に得られない,という想定がなされます.これは,クラスAには大量の正常データ,クラスBには僅かな異常データ,という2クラス分類問題そのものです.

しかし,単純に識別問題として解けば良いかというと,そう単純でもありません.これまでに見たことがない新しいタイプの異常が,今後現れる可能性があるからです.(犯罪者は,これまでに無かった新しい方法で,セキュリティシステムを欺こうとします.)よって,異常検知の分野では,A(正常)とB(異常)の既知データの間で線を引く問題(識別問題)ではなく,Aのデータだけを用いて学習を行い,Aの分布の内か外かという基準で,正常・異常を判断する問題が扱われます.

上記2は,データに対する摂動を加えたり,GANや拡散モデルなどの生成AIを用いて生成した画像を学習データに加えることで,数が少ないクラスのデータを拡張する方法です.

上記3は,数が少ないクラスのデータに対して大きな重みを与えることで,クラス間のバランスを取ろう,という方法です. 今回は,この方法を実践します.

Focal loss

分類問題でよく使われる損失関数は,クロスエントロピー損失です.これは,モデルの出力とGround Truthをそれぞれおよび,クラス数をとすると,次の式で書けます.


一方,Focal Lossは,インバランスなデータを扱う際に使われる損失関数であり,次の式で計算します.


ここで,はハイパーパラメータです.の時は,クロスエントロピー損失と等価になります.

このがどのような意味を持つかを考えてみます.このサイトの2つ目の図から,の値を変化させた場合に,損失関数の形状がどのように変化するか,を確認できます.の値が大きくなると,0付近ではより勾配が大きく,1付近ではより勾配が小さくなっていることがわかります. 例えば,クラスに属するあるデータについて, かつ の時,損失の値は既に十分に小さく,その付近の勾配も0に近い値になります.よって,そのデータはほとんど学習に寄与しません.一方で,付近では,より大きな勾配を持つことがわかります.

このように,の値が1に近いとき,すなわちモデルがそのデータについて十分な学習が既にできているとき,そのデータによる学習への寄与は小さくなります.反対に,の値が0に近いとき,すなわちモデルがそのデータについて十分に学習していないとき,そのデータはより強く学習に寄与します.これは,の値に関係なく,一貫してみられる傾向ですが,の値が大きいほど,学習が十分でないデータの寄与度合いが相対的に大きくなることがわかります.

インバランスデータでは,多数派のクラスのデータに対しては学習が進みやすく,少数派のクラスのデータに対しては,学習が進みにくい,と考えるのは自然です.よって,少数派により大きな重みを与えることができるFocal Lossが有効に働くと考えられます.

演習

Focal Lossの実装を行います.

準備

まず,作業用ディレクトリに移動した後に,以下のコマンドを実行して,プログラムとデータをダウンロードしてください.

wget https://vrl.sys.wakayama-u.ac.jp/class/pytorch_tutorial/exersise_classimbalance/exersise_classimbalance.py
wget -P ./data -r https://vrl.sys.wakayama-u.ac.jp/class/pytorch_tutorial/datasets/cifar-10-batches-py/

ダウンロードしたら,exersise_classimbalance.pyをそのまま実行してください.おそらく,テストデータに対する精度は93%くらいになったと思います.

クラスインバランスなデータの作成

今回は,CIFAR-10という10クラス分類のベンチマークデータセットを使います.これらのうち,2クラスのみを取り出して使います.そのための処理を書いてみましょう.

プログラム中にある以下の領域に,10クラスあるデータセットから,2クラス(クラス0と1,飛行機と車)のみを取り出す処理が既に書かれています.

###################### ここを完成させましょう 1 #######################

# テストデータのうち,ラベルが0のデータと1のデータのインデックス
indices = []
...
train_dataset = Subset(train_dataset, indices)

###################################################################

次の4行では,テストデータから,我々にとって関心がある2クラスのデータのインデックスを抽出して,Subsetを作成しています.

indices = []
for i in range(len(test_dataset)):
    if test_dataset.targets[i] < 2:
        indices.append(i)
test_dataset = Subset(test_dataset, indices)

学習用データについても,2つのクラスのデータを取り出していますが,まずクラス0のデータのみを,続いてクラス1のデータのみを,別々に取り出しています.これは,同じラベルのデータを揃えておくことで,後々クラスごとの学習データ数の比率を調整しやすくなるからです.

indices = []
for i in range(len(train_dataset)):
    if train_dataset.targets[i] == 0:
        indices.append(i)
for i in range(len(train_dataset)):
    if train_dataset.targets[i] == 1:
        indices.append(i)
train_dataset = Subset(train_dataset, indices)

現在,train_datasetにはクラス0と1のデータが5000個ずつ含まれています.ここで,クラス0のデータを100個だけ残して,残りを間引きましょう.最後の行を,次のように書き換えてください.

train_dataset = Subset(train_dataset, indices[4900:])

これで,クラス0と1のデータ数の比は,1:50になりました.

プログラムをそのまま実行して,テストデータに対する精度を確かめてください.

Focal lossの実装

さて,ここからが本番です.Focal lossを実装します.プログラムの次の領域を見てください.

###################### ここを完成させましょう 2 #######################
CRITERION = nn.CrossEntropyLoss

###################################################################

CRITERION に,pytorchのクロスエントロピー損失のクラスであるnn.CrossEntropyLossが代入されています.さらに,以下の行において,クラスCRITERION(=クラスnn.CrossEntropyLoss)からインスタンスcriterionが作成されており,これが損失関数として学習に使用されています.

criterion = CRITERION()

ひとまずクロスエントロピー損失の自作

まず,クロスエントロピー損失を計算するクラスを自作してみましょう.ResNetの実装のところでもやっているように,nn.Moduleを継承したクラスを作ることで,実装できます, とりあえず,以下の通りに書いてください.

class CRITERION(nn.Module):
    def __init__(self):
        super(CRITERION, self).__init__()

    def forward(self, y, t):
        loss = torch.nn.functional.cross_entropy(y,t)
        return loss

ここで,yとtはそれぞれ,モデルの出力とGround truthです.また,torch.nn.functional.cross_entropy() は,Pytorchで用意してくれているクロスエントロピー損失を計算する関数です.この部分を,Pytorchのビルトイン関数に頼らない方法で書き直していきます.

注意:forward(self, y, t)のyはモデルが出力したものですが,ソフトマックス関数を適用する前のものです.また,tは正解クラスラベルを表すスカラであり,one-hotベクトルではありません.よって,torch.nn.functional.cross_entropy(y,t)では,実際には

  1. yにソフトマックス関数を適用
  2. tをone-hotベクトルに変換
  3. 上記の処理を経たyとtを用いてクロスエントロピー損失を計算

の3つの処理がなされています.我々も,その通りに実装します.

以下の部分を削除するかコメントアウトしてください.

loss = torch.nn.functional.cross_entropy(y,t)

では,tをone-hotベクトルに変換する処理を書いてください.これには,torch.nn.functional.one_shotを使います.

続いて,yにtorch.softmax()を適用します.ただし,yは単一のデータではなく,複数のデータを含むミニバッチのものであることに注意して,適切に引数を渡してください.

yとtからクロスエントロピー損失を計算する式を書いてください.ミニバッチで複数データを同時に処理しますが,各データについての損失の平均値を,forward関数の戻り値に設定してください.

プログラムを実行して,先ほどと同じくらいの精度になったらOKです.

Focal lossへの拡張

Focal lossでは,ハイパーパラメータ を使います.これを,今作成しているCRITERIONクラスに,メンバ変数(self.gamma)として持たせてやりましょう.コンストラクタ__init__()の引数にgammaを設定して,中の処理で

self.gamma = gamma

のように代入してやることで,CRITERIONクラスのどのメソッドからでも,self.gamma で参照できるようになります.

では,forward()を編集して,Focal lossを実装してください.gamma の値をいろいろ変えて実験してみてください.

コマンドライン引数

このままでは,gammaの値を変更するために一々プログラムを書き直しては保存する必要があります.これでは面倒なので,コマンドライン引数でgammaの値を変更できるようにしましょう.

プログラムの先頭付近に,以下の部分があります.

###################### ここを完成させましょう 3 #######################

###################################################################

ここに,以下のように書いてください.

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--gamma', type=float)
args = parser.parse_args()

細かい解説は省きますが,上記のように書いてから

$ python exersise_classimbalance.py --gamma 3

のように実行すると,args.gamma に 3 が代入されます.CRITERIONクラスのインスタンス生成時に,引数gammaにこのargs.gamma を渡すようにすれば,一々プログラムを書き直さなくても,いろいろなgammaの値を試すことができます.

答え

答え