陽性例が陰性例に比べて著しく少ないデータセットを「偏ったデータセット」、または不均衡データセットといいます。
この記事では、分類タスクに対してBalanced Loss、回帰タスクに対してCentered Lossを誤差関数として使用することを提案し、これらにより、偏りのあるデータセットに対して精度の高い機械学習が行えることを示します。
記事中に登場するコードの完全版は、GithubリポジトリにJupyter Notebookとしてアップロードしています。
分類タスクにおけるデータの偏りを修正する
偏ったデータセット、または不均衡データセットについての定義や説明は、こちらの記事を参照してください。
ここでは、自由に利用できるデータを利用して偏ったデータセットを再現し、分類タスクを正常に行うための方法としてBalanced Lossを導入することを提案します。
乳がんデータセットから偏ったデータを生成
Pythonのscikit-learn
ライブラリに付属している「乳がんデータセット」を使用します。このデータセットには、良性腫瘍357例・悪性腫瘍212例と、それぞれに対応する30種のパラメータ(変数:腫瘍半径など)が含まれます。
from sklearn.datasets import load_breast_cancer
df_bc = load_breast_cancer()
ここでは212例ある悪性腫瘍から30例のみを抽出します。また、パラメータ30種は多すぎるため、冒頭の4つ(mean radius
, mean texture
, mean perimeter
, mean area
)のみを使用することにします。
import numpy as np
n_malignant = 30
df_bc["target"] = (df_bc["target"] - 1) * -1 # 良性を0、悪性を1にする
ind_bc = np.hstack([np.where(df_bc["target"] == 0)[0], np.where(df_bc["target"] == 1)[0][:n_malignant]])
data_bc = df_bc["data"][ind_bc, :4]
data_bc = (data_bc - data_bc.mean()) / data_bc.std() # パラメータの値を平均0、標準偏差1になるようにスケールする
これで、良性357例・悪性30例からなる偏ったデータセットを作成することができました。

平均二乗誤差で機械学習した場合の失敗例(分類)
良性・悪性腫瘍の分類にはロジスティック回帰を使用します。分類タスクの誤差関数としては、クロスエントロピー誤差を使用するのが一般的ですが、のちの説明のため、ここでは平均二乗誤差MSELoss
を使って学習を行います。
from torch import nn
class LogiReg(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.l1 = nn.Linear(in_dim, out_dim)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.l1(x)
return self.sigmoid(x)
mse_loss = nn.MSELoss()
model_mse = train(LogiReg, mse_loss, data_bc, target_bc, seed=0)
学習の結果を以下に表示します。

十分なイテレーションで学習を行いましたが、良性と悪性をほとんど区別することができていません。予測値はともに0付近であり、ラベルに関わらず「良性」であると予測されています。つまり、多数派である良性データが学習結果全体を引っ張り、少数派の悪性データは無視されていることがわかります。
Balanced Lossの導入による性能改善
\(i\) 番目のデータに対応するラベルと予測値をそれぞれ \(y_i, \tilde{y}_i\) とすると、平均二乗誤差は以下の式で計算することができます。
$$MSE=\frac{1}{N}\sum_{i=0}^{N-1}(y_i-\tilde{y}_i)^2$$
この式では、すべてのデータについての二乗誤差をそのまま足し合わせています。ここに重みづけを導入したものが、以下の式で定義されるBalanced Lossです。
$$BL=\frac{|Z|}{|Z_0|}\sum_{i\in Z_0}(y_i-\tilde{y}_i)^2+\frac{|Z|}{|Z_1|}\sum_{i\in Z_1}(y_i-\tilde{y}_i)^2$$
ここで、 \(Z_0, Z_1\) はそれぞれ良性・悪性ラベルの付いたデータIDの集合、 \(Z\) は全データIDの集合をあらわします。また、 \(|A|\) で集合 \(A\) の要素数を示すことにします。Balanced Lossでは、実際のラベルと予測値の二乗誤差を足し合わせる際に、 \(Z_0\) 由来のデータは \(|Z|/|Z_0|\) 倍、 \(Z_1\) 由来のデータは \(|Z|/|Z_1|\) 倍します。つまり、より数の少ないデータに対する誤差の方が、より重く評価されることになります。
Balanced Lossを誤差関数として使用し、再度ロジスティック回帰を実行します。
import torch
class BalancedLoss(nn.Module):
def __init__(self, y):
super().__init__()
n = len(y)
n0 = sum(y == 0)
n1 = n - n0
self.a = n / (2 * n0)
self.b = n / (2 * n1)
def forward(self, outputs, targets):
tmp = torch.square(outputs - targets)
tmp[targets == 0] *= self.a
tmp[targets == 1] *= self.b
return tmp.sum()
balanced_loss = BalancedLoss(target_bc)
model_balanced = train(LogiReg, balanced_loss, data_bc, target_bc, seed=1)
学習結果は以下の通りです。

閾値0.5付近を境界にして、良性・悪性のデータ集団を分離できていることがわかります。この結果は、ROC曲線からも見ることができます。

ROC-AUCが0.9と高く、偽陽性率20%で真陽性率80%を達成できることからも、精度の高さがうかがえます。
(なお、偏りのあるデータセットに対してはROC曲線よりもPR曲線で評価するべきとされていますが、ここでは客観的な基準を提示するためにROC曲線を使用しています。このあたりの議論に関しては、こちらの記事を参照してください)
回帰タスクにおけるデータの偏りを修正する
Balanced Lossを拡張した誤差関数として、回帰タスクに応用できるCentered Lossを考えることができます。
ここではCentered Lossを用いて、偏ったデータセットの回帰タスクをPythonで実装します。
糖尿病データセットから偏ったデータを生成
Pythonのscikit-learn
ライブラリに付属している「糖尿病データセット」を使用します。
from sklearn.datasets import load_diabetes
df_di = load_diabetes()
このデータセットには、数百例の患者の糖尿病の進行度とそれに対応するパラメータ(年齢、性別、その他検査値など)が含まれています。ここから以下のように症例を抽出します。
remove_rate = 10
target_di = df_di["target"]
target_di = (target_di - target_di.mean()) / target_di.std() # 進行度の値を平均0、標準偏差1になるようにスケールする
ind_di = np.hstack([np.where((target_di <= 0) & (target_di >= -1.1))[0], np.where(target_di > 0)[0][::remove_rate]]) # 進行度の高い症例数を1/10に減らす
つまり、ほとんどの症例が進行度100前後で、それよりも進行度の高い症例数は少なくなるよう、データセットを変形しました(結果をわかりやすくするため、同時に進行度が極端に低い症例も取り除いています)。

平均二乗誤差で機械学習した場合の失敗例(回帰)
隠れ層を1つ持つ2層ニューラルネットワークを使用して、各パラメータから糖尿病の進行度を回帰します。予測された進行度と実際の値の誤差は、分類タスクと同様に平均二乗誤差で評価します。このとき、機械学習の安定性を高めるため、糖尿病の進行度の値を平均0、標準偏差1になるようにスケールしています。
class TwoLayer(nn.Module):
def __init__(self, in_dim, out_dim, dim=10):
super().__init__()
self.l1 = nn.Linear(in_dim, dim)
self.relu = nn.LeakyReLU()
self.l2 = nn.Linear(dim, out_dim)
def forward(self, x):
x = self.l1(x)
x = self.relu(x)
return self.l2(x)
model_mse = train(TwoLayer, mse_loss, data_di, target_di, seed=2)
下図は、実際の糖尿病の進行度(真の値)にしたがってデータをソートし、真の値と予測値をプロットしたものです。

データID:180以降は進行度の高い症例となっていますが、ほとんど予測に反映されず、無視されていることがわかります。
Centered Lossの導入による性能改善
回帰タスクでは、以下のCentered Lossを誤差関数として使用することを提案します。
$$CL=\frac{S_p|Z|}{S_pR_n-S_nR_p}\sum_{i\in Z_n}(y_i-\tilde{y}_i)^2-\frac{S_n|Z|}{S_pR_n-S_nR_p}\sum_{i\in Z_p}(y_i-\tilde{y}_i)^2$$
ここで、 \(Z_n, Z_p\) はそれぞれ、実際の進行度が0未満または0以上となるデータIDの集合です。その他の文字については複雑になるため、補足Bの式 \((3), (4)\) を参照してください。
このCentered Loss関数では、スケールされたデータに対して0を中心に重みづけをするという操作を行っています。また、補足Bでは、Centered LossがBalanced Lossの自然な拡張であることの証明も行っています。
この誤差関数を使用して回帰モデルの訓練を行います。
class CenteredLoss(nn.Module):
def __init__(self, y):
super().__init__()
n = len(y)
tn = y[y < 0]
tp = y[y >= 0]
sn = sum(tn)
sp = sum(tp)
rn = sum(tn * tn)
rp = sum(tp * tp)
den = sp * rn - sn * rp
self.wn = sp * n / den
self.wp = -sn * n / den
def forward(self, outputs, targets):
tmp = torch.square(outputs - targets)
tmp[targets < 0] *= self.wn
tmp[targets >= 0] *= self.wp
return tmp.sum()
centered_loss = CenteredLoss(target_di)
model_centered = train(TwoLayer, centered_loss, data_di, target_di, seed=3)
その結果、値が大きいID:180以降のデータも反映して回帰を行うことができました(下図)。

偏ったデータに対する客観的な性能評価
分類タスクや回帰タスクの結果から、なんとなく機械学習はうまくいっているように見受けられます。
このことを客観的に評価するために、Balanced LossとCentered Lossの設計の背後にあるNullデコーダの基準を利用することができます。
Nullモデルで最低ラインを決める
Nullモデルとは、まったく学習を行わなくても達成できる最低限の機械学習モデルを意味します。たとえば、良性=0と悪性=1を判定する乳がんデータセットに対しては、常に「良性かもしれないし、悪性かもしれない」=0.5と出力することが、考えられる最低限の(無意味な)機械学習モデルになります。Nullモデルは、常に一定値を出力する機械学習モデルの中では、最も高い精度を出すモデルであるとします。
機械学習がうまくいった場合、少なくともこのNullモデルよりも高い精度を達成できているはずです。しかし、機械学習が失敗した場合には、Nullモデル以下の性能になることもあります。たとえば、良性データに対して常に「悪性」と予測し、悪性データに対して常に「良性」と判定してしまう場合です。これは無意味を通り越して有害なモデルです。
Balanced / Centered LossにおけるNullモデル
詳細は補足Bを参照してほしいのですが、Balanced Lossの設計にあたっては、常に0.5=1/2を出力する場合がNullモデルであり、その時の誤差が \(|Z|/2\) となるように条件を付けています。これは、常に一定値を出力するモデルの中では最も高い精度であるため、失敗例で見たような、すべての出力が0付近になるようなモデルの性能はNullモデル未満と評価されます。
一方、Balanced Lossにもとづいて学習を行った場合は、最終的な誤差が106.92と、Nullモデルの誤差 \(|Z|/2=193.50\) を大きく下回ったことから、正常に学習が行われたことが示された。
同様にCentered Lossは、常に0を出力する場合に、誤差 \(|Z|\) のNullモデルとなるように設計されている。Centered Lossにもとづいて学習した場合の結果も、誤差81.98で \(|Z|=204.00\) を大きく下回った。
このように、Balanced LossとCentered Lossは客観的な性能指標を提供する。
(補足A) 異常検知を用いた解決策
データに偏りがある場合の解決策としては、ほかにも異常検知を用いることができます。詳細は以下の記事を三h勝してください。
しかし、異常検知が成功するためには、各ラベルに対応するデータの分布が大きく離れている必要があります。この前提がみたされない場合には、やはりBalanced LossやCentered Lossを使った手法が主流になると考えられます。
(補足B) Centered LossがBalanced Lossの自然な拡張であることの証明
Balanced Lossの導出
平均二乗誤差に、ラベルごとの重みづけ \(w_0, w_1\) を導入し、
$$L=w_0\sum_{i\in Z_0}(y_i-\tilde{y}_i)^2+w_1\sum_{i\in Z_1}(y_i-\tilde{y}_i)^2$$
とおきます。また、つねに \(\tilde{y}_i=\alpha\) を出力する場合の誤差の値を \(L(\alpha)\) と書くことにします。
重み \(w_0, w_1\) の決定にあたり、Nullモデルの設定から、以下をみたすことを条件とします。
- \(L\left(\frac{1}{2}\right)=|Z|/2\)
- \(L\left(\frac{1}{2}\right)\) は最小値である
ここで、
$$L(\alpha)=w_0\sum_{i\in Z_0}(0-\alpha)^2+w_1\sum_{i\in Z_1}(1-\alpha)^2$$
$$=w_0\sum_{i\in Z_0}\alpha^2+w_1\sum_{i\in Z_1}(1-\alpha)^2=w_0|Z_0|\alpha^2+w_1|Z_1|(1-\alpha)^2$$
より、条件1.からは以下の関係式が得られます。
$$L\left(\frac{1}{2}\right)=\frac{1}{4}(w_0|Z_0|+w_1|Z_1|)=|Z|/2\tag{1}$$
条件2.については、 \(L(\alpha)\) の導関数が
$$\frac{dL}{d\alpha}(\alpha)=2\{w_0|Z_0|\alpha-w_1|Z_1|(1-\alpha)\}$$
と書けることから、 \(\alpha=0.5\) で極値をとる条件より
$$\frac{dL}{d\alpha}\left(\frac{1}{2}\right)=w_0|Z_0|-w_1|Z_1|=0\tag{2}$$
が得られます。 \((1), (2)\) を連立させて解くことで
$$w_0=\frac{|Z|}{|Z_0|},\quad w_1=\frac{|Z|}{|Z_1|}$$
であることがわかり、Balanced Lossが導かれます。
Centered Lossの導出
本文中の設定にとらわれず、一般的な誤差関数の形を導くことにします。
\(c\) を中心に、データセットを \(Z_n={i|y_i<c}, Z_p={i|y_i\geq c}\) に分割し、それぞれの予測誤差に対して \(w_n, w_p\) で重みづけを行います。
$$L=w_n\sum_{i\in Z_n}(y_i-\tilde{y}_i)^2+w_p\sum_{i\in Z_p}(y_i-\tilde{y}_i)^2$$
Balanced Lossのときと同様に、常に \(\tilde{y}_i=\alpha\) を出力するNullモデルを想定し、関数化します。
$$L(\alpha)=w_n\sum_{i\in Z_n}(y_i-\alpha)^2+w_p\sum_{i\in Z_p}(y_i-\alpha)^2$$
ここで、このあとの計算のために関数を展開します。結果が複雑になることを避けるため、
$$S_n=\sum_{i\in Z_n}y_i,\quad S_p=\sum_{i\in Z_p}y_i\tag{3}$$
$$R_n=\sum_{i\in Z_n}y_i^2,\quad R_p=\sum_{i\in Z_p}y_i^2\tag{4}$$
とおくことで、展開式は次のように書けます。
$$L(\alpha)=w_n(R_n-2\alpha S_n+|Z_n|\alpha^2)+w_p(R_p-2\alpha S_p+|Z_p|\alpha^2)\tag{5}$$
また、この関数の微分は次のとおりです。
$$\frac{dL}{d\alpha}(\alpha)=-2w_n(S_n+|Z_n|\alpha)-2w_p(S_p+|Z_p|\alpha)\tag{6}$$
ここで、Balanced Lossのときと同様、係数の決定のために
- \(L(c)=E\)
- \(L(c)\) は最小値である
という条件を設定しますが、ここではデータセットの分割点 \(c\) で最小値 \(E\) をとるものとして一般化しています。式 \((5)\) や \((6)\) に値を代入することで、それぞれの条件は
$$w_n(R_n-2cS_n+|Z_n|c^2)+w_p(R_p-2cS_p+|Z_p|c^2)=E$$
$$-2w_n(S_n+|Z_n|c)-2w_p(S_p+|Z_p|c)=0$$
と書けます。これらを連立させて解くことで
$$w_n=\frac{(S_p+|Z_p|c)E}{(S_p+|Z_p|c)(R_p-2cS_p+|Z_p|c^2)-(S_n+|Z_n|c)(R_n-2cS_n+|Z_n|c^2)}$$
$$w_p=-\frac{(S_n+|Z_n|c)E}{(S_p+|Z_p|c)(R_p-2cS_p+|Z_p|c^2)-(S_n+|Z_n|c)(R_n-2cS_n+|Z_n|c^2)}$$
が得られます。
ここで \(c=0, E=|Z|\) を代入すると、
$$w_n=\frac{S_p|Z|}{S_pR_n-S_nR_p},\quad w_p=-\frac{S_n|Z|}{S_pR_n-S_nR_p}$$
となり、Centered Lossの係数が導出できることがわかります。
\(c=0.5, E=|Z|/2\) とし、さらに乳がんデータセットにおいては \(S_n=R_n=0, S_p=R_p=|Z_1|\) が成り立つことを利用すると、Balanced Lossの係数を得ることができます。
$$w_0=\frac{|Z|}{|Z_0|},\quad w_1=-\frac{|Z|}{|Z_1|}$$
以上より、Balanced LossはCentered Lossの特殊な形であり、逆にCentered LossはBalanced Lossの自然な拡張であることが示されました。
もっと知りたいこと、感想を教えてください!