PR

データの偏りを補正するための誤差関数の提案

AI便利帳
Sponsored

陽性例が陰性例に比べて著しく少ないデータセットを「偏ったデータセット」、または不均衡データセットといいます。

この記事では、分類タスクに対してBalanced Loss、回帰タスクに対してCentered Lossを誤差関数として使用することを提案し、これらにより、偏りのあるデータセットに対して精度の高い機械学習が行えることを示します。

記事中に登場するコードの完全版は、GithubリポジトリにJupyter Notebookとしてアップロードしています。

分類タスクにおけるデータの偏りを修正する

偏ったデータセット、または不均衡データセットについての定義や説明は、こちらの記事を参照してください。

ここでは、自由に利用できるデータを利用して偏ったデータセットを再現し、分類タスクを正常に行うための方法としてBalanced Lossを導入することを提案します。

乳がんデータセットから偏ったデータを生成

Pythonのscikit-learnライブラリに付属している「乳がんデータセット」を使用します。このデータセットには、良性腫瘍357例・悪性腫瘍212例と、それぞれに対応する30種のパラメータ(変数:腫瘍半径など)が含まれます。

from sklearn.datasets import load_breast_cancer

df_bc = load_breast_cancer()

ここでは212例ある悪性腫瘍から30例のみを抽出します。また、パラメータ30種は多すぎるため、冒頭の4つ(mean radius, mean texture, mean perimeter, mean area)のみを使用することにします。

import numpy as np

n_malignant = 30
df_bc["target"] = (df_bc["target"] - 1) * -1 # 良性を0、悪性を1にする
ind_bc = np.hstack([np.where(df_bc["target"] == 0)[0], np.where(df_bc["target"] == 1)[0][:n_malignant]])
data_bc = df_bc["data"][ind_bc, :4]
data_bc = (data_bc - data_bc.mean()) / data_bc.std() # パラメータの値を平均0、標準偏差1になるようにスケールする

これで、良性357例・悪性30例からなる偏ったデータセットを作成することができました。

平均二乗誤差で機械学習した場合の失敗例(分類)

良性・悪性腫瘍の分類にはロジスティック回帰を使用します。分類タスクの誤差関数としては、クロスエントロピー誤差を使用するのが一般的ですが、のちの説明のため、ここでは平均二乗誤差MSELossを使って学習を行います。

from torch import nn

class LogiReg(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.l1 = nn.Linear(in_dim, out_dim)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        x = self.l1(x)
        return self.sigmoid(x)

mse_loss = nn.MSELoss()
model_mse = train(LogiReg, mse_loss, data_bc, target_bc, seed=0)

学習の結果を以下に表示します。

十分なイテレーションで学習を行いましたが、良性と悪性をほとんど区別することができていません。予測値はともに0付近であり、ラベルに関わらず「良性」であると予測されています。つまり、多数派である良性データが学習結果全体を引っ張り、少数派の悪性データは無視されていることがわかります。

Balanced Lossの導入による性能改善

\(i\) 番目のデータに対応するラベルと予測値をそれぞれ \(y_i, \tilde{y}_i\) とすると、平均二乗誤差は以下の式で計算することができます。

$$MSE=\frac{1}{N}\sum_{i=0}^{N-1}(y_i-\tilde{y}_i)^2$$

この式では、すべてのデータについての二乗誤差をそのまま足し合わせています。ここに重みづけを導入したものが、以下の式で定義されるBalanced Lossです。

$$BL=\frac{|Z|}{|Z_0|}\sum_{i\in Z_0}(y_i-\tilde{y}_i)^2+\frac{|Z|}{|Z_1|}\sum_{i\in Z_1}(y_i-\tilde{y}_i)^2$$

ここで、 \(Z_0, Z_1\) はそれぞれ良性・悪性ラベルの付いたデータIDの集合、 \(Z\) は全データIDの集合をあらわします。また、 \(|A|\) で集合 \(A\) の要素数を示すことにします。Balanced Lossでは、実際のラベルと予測値の二乗誤差を足し合わせる際に、 \(Z_0\) 由来のデータは \(|Z|/|Z_0|\) 倍、 \(Z_1\) 由来のデータは \(|Z|/|Z_1|\) 倍します。つまり、より数の少ないデータに対する誤差の方が、より重く評価されることになります。

Balanced Lossを誤差関数として使用し、再度ロジスティック回帰を実行します。

import torch

class BalancedLoss(nn.Module):
    def __init__(self, y):
        super().__init__()
        n = len(y)
        n0 = sum(y == 0)
        n1 = n - n0
        self.a = n / (2 * n0)
        self.b = n / (2 * n1)
    
    def forward(self, outputs, targets):
        tmp = torch.square(outputs - targets)
        tmp[targets == 0] *= self.a
        tmp[targets == 1] *= self.b
        return tmp.sum()

balanced_loss = BalancedLoss(target_bc)
model_balanced = train(LogiReg, balanced_loss, data_bc, target_bc, seed=1)

学習結果は以下の通りです。