陽性例が陰性例に比べて著しく少ないデータセットを「偏ったデータセット」、または不均衡データセットといいます。
この記事では、分類タスクに対してBalanced Loss、回帰タスクに対してCentered Lossを誤差関数として使用することを提案し、これらにより、偏りのあるデータセットに対して精度の高い機械学習が行えることを示します。
記事中に登場するコードの完全版は、GithubリポジトリにJupyter Notebookとしてアップロードしています。
分類タスクにおけるデータの偏りを修正する
偏ったデータセット、または不均衡データセットについての定義や説明は、こちらの記事を参照してください。
ここでは、自由に利用できるデータを利用して偏ったデータセットを再現し、分類タスクを正常に行うための方法としてBalanced Lossを導入することを提案します。
乳がんデータセットから偏ったデータを生成
Pythonのscikit-learn
ライブラリに付属している「乳がんデータセット」を使用します。このデータセットには、良性腫瘍357例・悪性腫瘍212例と、それぞれに対応する30種のパラメータ(変数:腫瘍半径など)が含まれます。
from sklearn.datasets import load_breast_cancer
df_bc = load_breast_cancer()
ここでは212例ある悪性腫瘍から30例のみを抽出します。また、パラメータ30種は多すぎるため、冒頭の4つ(mean radius
, mean texture
, mean perimeter
, mean area
)のみを使用することにします。
import numpy as np
n_malignant = 30
df_bc["target"] = (df_bc["target"] - 1) * -1 # 良性を0、悪性を1にする
ind_bc = np.hstack([np.where(df_bc["target"] == 0)[0], np.where(df_bc["target"] == 1)[0][:n_malignant]])
data_bc = df_bc["data"][ind_bc, :4]
data_bc = (data_bc - data_bc.mean()) / data_bc.std() # パラメータの値を平均0、標準偏差1になるようにスケールする
これで、良性357例・悪性30例からなる偏ったデータセットを作成することができました。
平均二乗誤差で機械学習した場合の失敗例(分類)
良性・悪性腫瘍の分類にはロジスティック回帰を使用します。分類タスクの誤差関数としては、クロスエントロピー誤差を使用するのが一般的ですが、のちの説明のため、ここでは平均二乗誤差MSELoss
を使って学習を行います。
from torch import nn
class LogiReg(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.l1 = nn.Linear(in_dim, out_dim)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.l1(x)
return self.sigmoid(x)
mse_loss = nn.MSELoss()
model_mse = train(LogiReg, mse_loss, data_bc, target_bc, seed=0)
学習の結果を以下に表示します。
十分なイテレーションで学習を行いましたが、良性と悪性をほとんど区別することができていません。予測値はともに0付近であり、ラベルに関わらず「良性」であると予測されています。つまり、多数派である良性データが学習結果全体を引っ張り、少数派の悪性データは無視されていることがわかります。
Balanced Lossの導入による性能改善
\(i\) 番目のデータに対応するラベルと予測値をそれぞれ \(y_i, \tilde{y}_i\) とすると、平均二乗誤差は以下の式で計算することができます。
$$MSE=\frac{1}{N}\sum_{i=0}^{N-1}(y_i-\tilde{y}_i)^2$$
この式では、すべてのデータについての二乗誤差をそのまま足し合わせています。ここに重みづけを導入したものが、以下の式で定義されるBalanced Lossです。
$$BL=\frac{|Z|}{|Z_0|}\sum_{i\in Z_0}(y_i-\tilde{y}_i)^2+\frac{|Z|}{|Z_1|}\sum_{i\in Z_1}(y_i-\tilde{y}_i)^2$$
ここで、 \(Z_0, Z_1\) はそれぞれ良性・悪性ラベルの付いたデータIDの集合、 \(Z\) は全データIDの集合をあらわします。また、 \(|A|\) で集合 \(A\) の要素数を示すことにします。Balanced Lossでは、実際のラベルと予測値の二乗誤差を足し合わせる際に、 \(Z_0\) 由来のデータは \(|Z|/|Z_0|\) 倍、 \(Z_1\) 由来のデータは \(|Z|/|Z_1|\) 倍します。つまり、より数の少ないデータに対する誤差の方が、より重く評価されることになります。
Balanced Lossを誤差関数として使用し、再度ロジスティック回帰を実行します。
import torch
class BalancedLoss(nn.Module):
def __init__(self, y):
super().__init__()
n = len(y)
n0 = sum(y == 0)
n1 = n - n0
self.a = n / (2 * n0)
self.b = n / (2 * n1)
def forward(self, outputs, targets):
tmp = torch.square(outputs - targets)
tmp[targets == 0] *= self.a
tmp[targets == 1] *= self.b
return tmp.sum()
balanced_loss = BalancedLoss(target_bc)
model_balanced = train(LogiReg, balanced_loss, data_bc, target_bc, seed=1)
学習結果は以下の通りです。