PR

Pythonによる異常検知の実装―偏ったデータセットでの機械学習

AI便利帳
Sponsored

現実世界では、陰性例に比べて陽性例が少なすぎるという偏ったデータセットによく出会います。通常の機械学習手法では、このようなデータセットを正しく学習するのは困難です。

しかし、異常検知の手法を使うことで、偏ったデータセットに対する分類タスクを解決できる場合があります。ここでは偏ったデータセットの問題点を解説したあと、Pythonで異常検知を実行する方法を説明します。その後、異常検知が失敗するケースと代替案についても紹介します。

記事中に登場するコードの完全版は、GithubリポジトリにJupyter Notebookとしてアップロードしています。

機械学習におけるデータの偏り

機械学習を適用するためのデータセットの中に、特定のラベルを持つデータのみが多く含まれている状況がよく見られます。

ここでは、このような「データの偏り」の具体例と、それが引き起こす問題について解説します。

「偏ったデータセット」の意味と具体例

機械学習で用いられるデータセットには、「データ」と「ラベル」が含まれます。病気の予測の場合には、年齢・性別や血液検査の結果などが「データ」、陽性(病気あり)や陰性(病気なし)といった診断が「ラベル」になります。統計学の言葉を使うと、「データ」は独立変数、「ラベル」は従属変数に相当します。

基本的に、機械学習を行う際にはラベルの内訳(陽性と陰性の割合)のバランスが取れていることが理想とされています。しかし現実世界では、なかなかそうはいきません。

たとえば、全人口の10%が罹っている病気Aがあるとします。これを、「Aの有病率は10%である」と表現します。

一般的に、有病率10%は非常に高い数値です。しかしそれでも、病気の人を1人に対して正常な人が9人存在することになります。よって、病気Aのデータセットを自然に作成した場合、ラベルの内訳が陽性:陰性=1:9と偏った状態になります。

偏ったデータセットの問題点

使用するデータセットに偏りがある場合、機械学習が失敗し、精度が低下しやすくなります。その理由は次のように考えることができます。

たとえばラベルが陽性:陰性=1:9である場合、どのようなデータの入力に対しても「陰性」と出力するだけで精度90%を達成することができます。ここから機械学習のパラメータを更新して陽性例に対する予測性能を高めようとしても、それに伴い陰性例に対する性能が低下してしまう場合、全体の精度に対して、後者は前者の9倍影響を与えます。

以上の理由から、偏ったデータセットを使用すると、機械学習モデルが少数派を無視する現象が頻発します。その結果、機械学習が失敗してしまうのです。

偏ったデータセットに異常検知を適用する

異常検知という手法を用いると、偏ったデータセットに対しても適切な分類をできる場合があります。

ここでは、異常検知の原理を説明し、Pythonによる実行例を示します。

偏ったデータセットを再現する

誰でも利用できるデータセットとして、ここではPythonのscikit-learnライブラリに付属している「乳がんデータセット」を使用します。

from sklearn.datasets import load_breast_cancer

df_bc = load_breast_cancer()

このデータセットには、良性腫瘍357例・悪性腫瘍212例と、それぞれに対応する30種のパラメータ(変数:腫瘍半径など)が含まれます。ここでは212例ある悪性腫瘍から30例のみを抽出します。また、パラメータ30種は多すぎるため、冒頭の4つ(mean radius, mean texture, mean perimeter, mean area)のみを使用することにします。

import numpy as np

n_malignant = 30
df_bc["target"] = (df_bc["target"] - 1) * -1 # 良性を0、悪性を1にする
ind_bc = np.hstack([np.where(df_bc["target"] == 0)[0], np.where(df_bc["target"] == 1)[0][:n_malignant]])
data_bc = df_bc["data"][ind_bc, :4]
data_bc = (data_bc - data_bc.mean()) / data_bc.std() # パラメータの値を平均0、標準偏差1になるようにスケールする

以上により、良性357例・悪性30例からなる偏ったデータセットを作成することができました。

偏ったデータセットでの学習失敗例

このデータセットをロジスティック回帰を用いて学習してみます。このとき、評価関数としてはクロスエントロピー誤差を使用するのが一般的ですが、ここでは後の説明のため、平均二乗誤差MSELoss)を使っています。

from torch import nn

mse_loss = nn.MSELoss()
model_mse = train(LogiReg, mse_loss, data_bc, target_bc, seed=0)

上図はロジスティック回帰の結果を表示したものです。十分な回数データを参照して分類モデルを訓練しましたが、このモデルは良性と悪性をほとんど区別できていないことがわかります。データの良性・悪性にかかわらず、モデルの予測値は常に0(=良性)付近となっています。

異常検知の原理とマハラノビス距離

異常検知においては、データセットのうち、多数派のラベルを持つデータのみを使用します。つまり、乳がんデータセットの場合は「良性」のデータのみを使います。

同じ「良性」の中でもデータにバラツキが存在しています。しかし、その範囲はある一定の分布内に収まっているはずです。一方、「悪性」のデータはこの範囲から大きく外れたところに存在していると考えられます。異常検知は、以上のことを想定した手法です。

良性データの範囲は、データの次元と同じ多次元正規分布(平均: \(\boldsymbol{\mu}\), 分散共分散行列: \(\boldsymbol{\Sigma\}\))によってモデル化することができます。マハラノビス距離

$$D_M=\sqrt{(\mathbf{x}-\boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\mathbf{x}-\boldsymbol{\mu})}$$

は、データの分布(分散共分散)を考慮し、データ点 \(\mathbf{x}\) が平均からどれだけ離れているかを評価する指標です。これを利用して、データ点のマハラノビス距離が一定の閾値より大きい場合に「悪性」であると判断します。

異常検知をPythonで実行する

Pythonを使用すると、乳がんデータセットのマハラノビス距離は以下のように計算することができます。

data_bc_benign = data_bc[target_bc[:, 0] == 0] # 良性データのみ抽出
xm = data_bc - data_bc_benign.mean(axis=0) # 良性データの平均と、全データの差
cov_inv = torch.inverse(data_bc_benign.T.cov()) # 良性データの分散共分散行列(逆行列)
mahal = (xm * torch.matmul(xm, cov_inv)).sum(axis=1).detach().numpy() # マハラノビス距離

この距離にもとづいて、閾値を変化させながらROC曲線を描画すると以下のようになりました。

ROC-AUCが0.97と非常に高い値を示しており、偽陽性率20%で真陽性率90%を達成していることから、異常検知の手法はこのデータセットに対して極めて有効であることがわかります。

(なお、偏りのあるデータセットに対してはROC曲線よりもPR曲線で評価するべきとされていますが、ここでは客観的な基準を提示するためにROC曲線を使用しています。このあたりの議論に関しては、こちらの記事を参照してください)

異常検知のメリット・デメリット

しかし、異常検知では対応できないケースもあるため注意が必要です。

異常検知は、正常例(多数派)の分布と異常例(少数派)の分布が離れていることを想定しています。乳がんデータセットでもそのような傾向が見られたため、今回は無事に分類を行うことができました。しかし、データの分布が以下のような場合には異常検知は成功しません。

これは具体的には、「悪性腫瘍の特徴は、大きさがXcm以上かつYcm以下であることで、それより大きいか小さいものは良性腫瘍」というような特徴があるような場合を想定しています。良性データしか参照しない異常検知では、フタを開けてみるまで悪性データのパラメータ分布が明らかではない点が問題となり得ます。

異常検知が失敗する場合には、通常の機械学習を実行する際に、サンプリング手法を工夫したり、特別な誤差関数を使用したりすることが考えられます。U-知能デバイス研究所では、Balanced LossとCentered Lossという誤差関数を活用することを推奨しています。

もっと知りたいこと、感想を教えてください!