Multiclass Classificationとmacro/micro-Precision, Recall

2021/12/16

機械学習

t f B! P L

テキスト分類問題は自然言語処理でも比較的成果を出しやすいタスクです。
しかしながら微妙に問題のタイプがあるので、それぞれのタイプについて適切な評価指標を見ていきましょう。
問題によって、別の評価指標でも同じ値になったりします。バグじゃないよ。

Binary Classification

2値分類、2クラス分類。
スパム検出など(スパムかそうでないか)に使われる。

まず以下のConfusion matrix(混合行列)を覚える。

正解0 正解1
予測0 TP FP
予測1 FN TN

クラス0がPositive(陽性)、クラス1がNegative(陰性)とする。
スパムならクラス0でスパム、1でスパムじゃない。
(ここ逆じゃないかと思うと思うけど後半との整合性のためこう書いてる)

  • TP(True Positive): 正解0のデータに対して予測0を出力して正解
  • FP(False Positive): 正解1のデータに対して予測0を出力して不正解
  • FN(False Negative): 正解0のデータに対して予測1を出力して不正解
  • TN(True Negative): 正解1のデータに対して予測1を出力して正解

Posi/Negaは予測の陰陽そのまま。
True/Falseは正解クラスのTrue/Falseではなく、予測が正解しているかどうか。
なので、予測と正解が一致している対角成分のみがTrueになる。
これに従うと

  • TP(True Positive)は「True(成功した)Positive予測」
  • FP(False Positive)は「False(失敗した)Positive予測」
  • FN(False Negative)は「False(失敗した)Negative予測」
  • TN(True Negative)は「True(成功した)Negative予測」

になって覚えやすい。

各評価指標は混合行列に着目すると以下のように定義できる。

Accurary: 全事象(TP+FP+FN+TN)に対して、正解した事象(TP+TN)の割合

Accuracy=TP+TNTP+FP+FN+TNAccuracy = \frac{TP+TN}{TP+FP+FN+TN}

Precision: 予測がPositiveの事象(TP+FP)に対して、正解した事象(TP)の割合

Precision=TPTP+FPPrecision = \frac{TP}{TP+FP}

Recall: 正解がPositiveの事象(TP+FN)に対して、正解した事象(TP)の割合

Recall=TPTP+FNRecall = \frac{TP}{TP+FN}

F1-score: Precision(P)とRecall(R)の調和平均

F1=2PRP+R F_1 = \frac{2P*R}{P+R}

PとRは、TPなどを用いて表現できるので代入して式変形すると

F1=2TP2TP+FP+FNF_1 = \frac{2TP}{2TP+FP+FN}

2値分類では、これらの指標は特殊な場合を除いて同じ値にはなりません(例えば、偶然TP=TNの場合、AccuraryとF1は一緒の値になります)。

特にクラスラベルに偏りがある場合に、Accuraryだけで判断せずPrecision/Recallをしっかり見ようというのは基本のきとしてよく言われると思います。

Multiclass Classification

多クラス分類です。ニュース記事の分類などが有名ですね。

ここでは、1データポイントあたり正解クラスが1つであり、予測も1データポイントあたり1つの予測クラスを与えることします。

つまりデータ数がNであれば、混合行列の全事象数もN個になります。

クラスが3つ以上ある場合、Binary Classificationのように1つのクラスのみをPositiveとすることはできないので、すべてのクラスについて着目しているそのクラスをPositive、それ以外のクラスをNegativeとして、集計します。

その集計方法にmacroとmicroがあります。
これにより、Precision, Recall, F1が微妙に違う値になります(Accuracyは変わりません)。

混合行列に実際の数値を記入します。

正解0 正解1 正解2
予測0 5 1 1
予測1 2 6 3
予測2 1 2 6

Accuracy
(5 + 6 + 6) / (5 + 1 + 1 + 2 + 6 + 3 + 1 + 2 + 6) = 0.6296

macro

macroの場合、クラスごとにPrecision/Recallを計算してから、その結果を単純に平均します。

macro Precision
分子が対角成分i、分母が行iの行和になります。

  • クラス0: 5 / (5 + 1 + 1) = 0.7142
  • クラス1: 6 / (2 + 6 + 3) = 0.5454
  • クラス2: 6 / (1 + 2 + 6) = 0.6666

平均値 = 1/3 * (0.7142 + 0.5454 + 0.6666) = 0.6421

macro Recall

分子が対角成分i、分母が列iの列和になります。

  • クラス0: 5 / (5 + 2 + 1) = 0.625
  • クラス1: 6 / (1 + 6 + 2) = 0.6666
  • クラス2: 6 / (1 + 3 + 6) = 0.6

平均値 = 1/3 * (0.625 + 0.6666 + 0.6) = 0.6305

macro Precision/Recallは特殊なケースを除いてAccuracyとは一致しません。
F1は省略しましたが、F1も基本的には同じ値になりません。

上の式を考えればわかりますが、偶然、混合行列の行和、または列和がすべて同じ値だと、Accuracyに一致します。

行和が揃うのはモデルから出力される予測クラスの個数が、各クラスごとに同じになることを意味するので狙ってコントロールできるものでもないと思います。

一方で、列和が揃うのは、各クラスごとに同数のデータポイントが存在するケースなので、綺麗なデータセットだとこのシチュエーションはなくもないです。

micro

microの場合、実質的にクラスを区別せずPrecision/Recallを一発で計算します。

micro Precision

(5 + 6 + 6) / ((5 + 1 + 1) + (2 + 6 + 3) + (1 + 2 + 6)) = 0.6296

micro Recall

(5 + 6 + 6) / ((5 + 2 + 1) + (1 + 6 + 2) + (1 + 3 + 6)) = 0.6296

てこれAccuracyやんかい!!!

やっときました。
Multiclass Classificationのmicro Precision/RecallはAccuracyに一致するんですね。

分子が対角成分の和、分母が行和または列和の和になるので、分母は混合行列の全事象に他なりません。

F1はP=Rの場合、定義式よりP=R=Fになるので、これもAccuracyに一致します。

どれを使うか

情報量は混合行列>クラスごとのPrecision/Recall>macro/microのPrecision/Recallの順で大きいです。

多くの場合問題なのは、正解クラスのラベル、予測クラスのラベルに偏りがある場合です。

その場合、micro/macroレベルで見てしまうと、特定のクラスについて精度が出てない状態を見落とす可能性があります。

特にmacro集計の場合は、クラスごとのimblanceを考慮せず、各クラスの精度をすべて重み1で考慮してしまうので、あまり好きではないです。

クラスごとのPrecision/Recallであったり、出来れば混合行列で見て、このクラスとこのクラスは誤認識が多い、というような分析ができると良いと思います。

ラベル

QooQ