現実世界では、陰性例に比べて陽性例が少なすぎるという偏ったデータセットによく出会います。通常の機械学習手法では、このようなデータセットを正しく学習するのは困難です。
しかし、異常検知の手法を使うことで、偏ったデータセットに対する分類タスクを解決できる場合があります。ここでは偏ったデータセットの問題点を解説したあと、Pythonで異常検知を実行する方法を説明します。その後、異常検知が失敗するケースと代替案についても紹介します。
記事中に登場するコードの完全版は、GithubリポジトリにJupyter Notebookとしてアップロードしています。
機械学習におけるデータの偏り
機械学習を適用するためのデータセットの中に、特定のラベルを持つデータのみが多く含まれている状況がよく見られます。
ここでは、このような「データの偏り」の具体例と、それが引き起こす問題について解説します。
「偏ったデータセット」の意味と具体例
機械学習で用いられるデータセットには、「データ」と「ラベル」が含まれます。病気の予測の場合には、年齢・性別や血液検査の結果などが「データ」、陽性(病気あり)や陰性(病気なし)といった診断が「ラベル」になります。統計学の言葉を使うと、「データ」は独立変数、「ラベル」は従属変数に相当します。
基本的に、機械学習を行う際にはラベルの内訳(陽性と陰性の割合)のバランスが取れていることが理想とされています。しかし現実世界では、なかなかそうはいきません。
たとえば、全人口の10%が罹っている病気Aがあるとします。これを、「Aの有病率は10%である」と表現します。
一般的に、有病率10%は非常に高い数値です。しかしそれでも、病気の人を1人に対して正常な人が9人存在することになります。よって、病気Aのデータセットを自然に作成した場合、ラベルの内訳が陽性:陰性=1:9と偏った状態になります。
偏ったデータセットの問題点
使用するデータセットに偏りがある場合、機械学習が失敗し、精度が低下しやすくなります。その理由は次のように考えることができます。
たとえばラベルが陽性:陰性=1:9である場合、どのようなデータの入力に対しても「陰性」と出力するだけで精度90%を達成することができます。ここから機械学習のパラメータを更新して陽性例に対する予測性能を高めようとしても、それに伴い陰性例に対する性能が低下してしまう場合、全体の精度に対して、後者は前者の9倍影響を与えます。
以上の理由から、偏ったデータセットを使用すると、機械学習モデルが少数派を無視する現象が頻発します。その結果、機械学習が失敗してしまうのです。
偏ったデータセットに異常検知を適用する
異常検知という手法を用いると、偏ったデータセットに対しても適切な分類をできる場合があります。
ここでは、異常検知の原理を説明し、Pythonによる実行例を示します。
偏ったデータセットを再現する
誰でも利用できるデータセットとして、ここではPythonのscikit-learn
ライブラリに付属している「乳がんデータセット」を使用します。
from sklearn.datasets import load_breast_cancer
df_bc = load_breast_cancer()
このデータセットには、良性腫瘍357例・悪性腫瘍212例と、それぞれに対応する30種のパラメータ(変数:腫瘍半径など)が含まれます。ここでは212例ある悪性腫瘍から30例のみを抽出します。また、パラメータ30種は多すぎるため、冒頭の4つ(mean radius
, mean texture
, mean perimeter
, mean area
)のみを使用することにします。
import numpy as np
n_malignant = 30
df_bc["target"] = (df_bc["target"] - 1) * -1 # 良性を0、悪性を1にする
ind_bc = np.hstack([np.where(df_bc["target"] == 0)[0], np.where(df_bc["target"] == 1)[0][:n_malignant]])
data_bc = df_bc["data"][ind_bc, :4]
data_bc = (data_bc - data_bc.mean()) / data_bc.std() # パラメータの値を平均0、標準偏差1になるようにスケールする
以上により、良性357例・悪性30例からなる偏ったデータセットを作成することができました。
偏ったデータセットでの学習失敗例
このデータセットをロジスティック回帰を用いて学習してみます。このとき、評価関数としてはクロスエントロピー誤差を使用するのが一般的ですが、ここでは後の説明のため、平均二乗誤差(MSELoss
)を使っています。
from torch import nn
mse_loss = nn.MSELoss()
model_mse = train(LogiReg, mse_loss, data_bc, target_bc, seed=0)
上図はロジスティック回帰の結果を表示したものです。十分な回数データを参照して分類モデルを訓練しましたが、このモデルは良性と悪性をほとんど区別できていないことがわかります。データの良性・悪性にかかわらず、モデルの予測値は常に0(=良性)付近となっています。
異常検知の原理とマハラノビス距離
異常検知においては、データセットのうち、多数派のラベルを持つデータのみを使用します。つまり、乳がんデータセットの場合は「良性」のデータのみを使います。
同じ「良性」の中でもデータにバラツキが存在しています。しかし、その範囲はある一定の分布内に収まっているはずです。一方、「悪性」のデータはこの範囲から大きく外れたところに存在していると考えられます。異常検知は、以上のことを想定した手法です。
良性データの範囲は、データの次元と同じ多次元正規分布(平均: \(\boldsymbol{\mu}\), 分散共分散行列: \(\boldsymbol{\Sigma\}\))によってモデル化することができます。マハラノビス距離
$$D_M=\sqrt{(\mathbf{x}-\boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{x}-\boldsymbol{\mu})}$$
は、データの分布(分散共分散)を考慮し、データ点 \(\mathbf{x}\) が平均からどれだけ離れているかを評価する指標です。これを利用して、データ点のマハラノビス距離が一定の閾値より大きい場合に「悪性」であると判断します。
異常検知をPythonで実行する
Pythonを使用すると、乳がんデータセ