自作ソフトにテキストの多クラス分類機能を組み込みたくて、調べてみたら Complement Naive Bayes(CNB、補集合ナイーブベイズ)というアルゴリズムが最近の流行のようで、これを検証してみることにしました。
元論文 を一通り読んでから検証を進めていきました。実装される際は目を通すことをオススメします。
使用したコーパスは以下のようなもの
想定する用途に合わせて、それなりにクラス間でデータの量にばらつきがあります。
クラス | ファイル数 | サイズ |
A | 832 | 121MB |
B | 491 | 182MB |
C | 449 | 59MB |
D | 312 | 111MB |
E | 298 | 26MB |
F | 245 | 67MB |
G | 234 | 73MB |
H | 210 | 33MB |
I | 123 | 33MB |
J | 63 | 3MB |
K | 62 | 14MB |
L | 47 | 6MB |
M | 47 | 5MB |
ひとまず、シンプルなナイーブベイズを
集合知プログラミング を参考に実装しました。
#SQL定義 conn.execute(""" CREATE TABLE IF NOT EXISTS class_word_count( class TEXT NOT NULL, word TEXT NOT NULL, count LONG NOT NULL, UNIQUE(class, word) ); """) conn.execute(""" CREATE TABLE IF NOT EXISTS class_count( class TEXT PRIMARY KEY NOT NULL, count LONG NOT NULL ); """) #テスト def getClassWordCount(c, word): #あるクラスにある単語が出現する回数を返す def getClassDocCount(c): #あるクラスに属する文書の数を返す results = [] for targetc in classes: p = 1.0 for d in data: word, count = d prob = float(getClassWordCount(targetc, word)) / getClassDocCount(targetc) p *= ( 重み * 仮確率 + 全てのクラスでのwordの出現回数 * prob ) / ( 全てのクラスでのwordの出現回数 + 重み ) results.append((p, targetc))
y軸は、クラスごとにテストデータを分類し、実際のクラスが平均何位であるか。1に近いほど良い。
x軸は左から順にA、B、C...と続く。
大きいクラスへの偏りが大きく、ほとんど意味を成していません。クラス間の差が増える([:30] -> [:-10])と、偏りが大きくなっているのがわかります。
…ダメです。スパム判定などの用途にはギリギリ使えるかなーという精度でした(非常に素朴な実装なので、こんなコードで実用になりそう、というのはむしろスゴい)。
いろいろとダメな原因はあるのですが、
TWCNBで補正しているものは、
- テスト時、単語の出現頻度(multinomial naive bayes)
- 訓練時、単語の出現頻度(TF transform)
- 訓練時、単語の珍しさ(IDF transform)
- 訓練時、文書の長さ(document length normalization)
- テスト時、クラスごとの訓練データ量(complement naive bayes)
- テスト時、クラスごとの結果の偏り(weight normalization)
です。それぞれ簡単に説明すると、
multinomial naive bayes
上記のコードでは文書中の、単語の出現回数が使われず捨てられています。これを使うと精度が上がるよーというのがMNBです。
TF transform
ただ、上のコードを改変してcountを使用するようにしても、思うように精度は上がりません。
各クラスの特徴をうまく拾う必要があり、そのために TF-IDF を導入します。
TF(term frequency、単語の出現頻度)については、頻出語の影響を低減して、実際の単語の確率分布(べき分布)に近づけるために、出現回数の対数を取って簡単に近似しています。
この実装の基点はここになります。
IDF transform
IDFはinverse document frequency、逆出現頻度です。要するに、ある単語が特定の文書にだけ出現するなら、その単語のスコアが高くなるアルゴリズムです。
IDFも対数を取り、TFに掛けます。
document length normalization
クラスの特徴として長い文書に偏っている場合など、精度に大きな影響があるので、これを補正する必要があります。
具体的には、訓練時に、1文書中の単語のウェイトが合計で1になるように正規化します。
ここまでを訓練の際に行います。以下はテスト時です。
complement naive bayes
「あるクラスに、ある単語が含まれる場合」の補集合は「あるクラス以外に、ある単語が含まれる場合」であることを利用し、MNBからさらに精度を上げます。
これの何がいいかというと、「あるクラス」よりも「あるクラス以外」のほうが圧倒的に大きい(今回のコーパスでは最大約70倍)という所です。
大きい方が誤差が減り、つまり精度が上がります。
weight normalization
ここまでの対策でも、やはりクラス間で偏りがあります。大きいクラスのほうが有利で、これはテスト文書での未知語が、小さいクラスで多いことが原因です。
これに対応する方法は力業で、訓練データがテストデータと近似する(しないとダメなんだけど)と仮定し、全訓練データを1つのテストデータとして各クラスに与え、その結果を使って補正する、というものです。
計算にかなり時間はかかりますが、確かに効果的です。
最尤推定。正規化した単語ウェイトに単語の出現回数を掛け、足していき、クラスウェイトを求めます。
結果
TMNB
sはスムージングパラメータ。
TCNB
同上。
NBより精度が上がっています。TMNB・TCNB共に、結果がスムージングパラメータに大きく影響されています。より良い結果のためには、最適なスムージングパラメータを探す必要があります。そのためのAPNBCというアルゴリズムが論文中で紹介されていますが、今回は検証しませんでした。以下のグラフで明らかであるように、この論文のキモであるウェイトの正規化で、この問題がほぼ解決しているためです。
TWMNB
sはスムージングパラメータ。
TWCNB
同上。
TWMNBはクラスごとにやや精度に差が。TWCNBは均一な結果になりました。MNBよりCNBのほうが、結果にばらつきが無いことがわかります。
Aは他と比べ特徴の薄い、特殊なクラスなので精度が悪くなっています。B〜Mについても、クラス間で語彙が深く重なっている場合があり、(GとJなど)やや影響が見られます。
スムージングパラメータの影響も、TMNB・TCNBほどには受けていません。
なお、論文には
とありますが、今回検証した限りでは
とした方が精度が出ました。僕がどこかで勘違いしている可能性も否定できませんが…。
TWMNB, TWCNBの実装
訓練
conn.execute(""" CREATE TABLE IF NOT EXISTS word_doc_count ( word TEXT PRIMARY KEY NOT NULL, doc_count INTEGER NOT NULL ) """) conn.execute(""" CREATE TABLE IF NOT EXISTS class_word_weight ( class TEXT NOT NULL, word TEXT NOT NULL, weight REAL NOT NULL, UNIQUE(class, word) ) """) conn.execute(""" CREATE TABLE IF NOT EXISTS word_count ( word TEXT NOT NULL, count INTEGER NOT NULL ) """) #IDF計算のためにまずword_doc_countを格納する for c in classes: files = getFiles(c) for f in files: words = getWords(f) data = toDataArray(words) #word_doc_countをインクリメント bulkIncrWordDocCount(data) totalDocCount += 1 #TF-IDFを計算し、格納する for c in classes: files = getFiles(c) for f in files: words = getWords(f) data = toDataArray(words) weights = [] weightTotal = 0 for d in data: word, count = d docCount = getWordDocCount(word) #TF transform tf = getTF(count) #IDF transform idf = getIDF(docCount, totalDocCount) weight = tf * idf weights.append((weight, word)) weightTotal += weight ** 2 weightTotal = math.sqrt(weightTotal) #document length normallization nWeights = [] nWeightTotal = 0 for d in weights: weight, word = d weight /= weightTotal nWeights.append((weight, word)) #word_countを加算 bulkAddWordCount(data) #class_word_weightを加算 bulkAddClassWordWeight(c, nWeights) conn.commit()
TWMNB
#weight normalizationに使うclass weightの計算 cw = {} for c in classes: cw[c] = 0 cur = conn.execute("SELECT DISTINCT word FROM word_doc_count") for row in cur: word = row[0] count = getWordCount(word) weight = getClassWordWeight(c, word) #cw[c] += math.log((weight + s) / (getClassWeight(c) + s_all)) cw[c] += float(count) * math.log((weight + s) / (getClassWeight(c) + s_all)) #各クラスのスコアを求める results = [] for targetc in classes: weight = 0 for d in data: word, count = d w = math.log( (getClassWordWeight(targetc, word) + s ) / (getClassWeight(c) + s_all ) ) w *= count #weight normalization w /= cw[targetc] weight += w results.append((weight, targetc))
TWCNB
#weight normalizationに使うclass weightの計算 cw = {} for c in classes: cw[c] = 0 denominator = 0 for ec in classes: if ec != c: denominator += getClassWeight(ec) cur = conn.execute("SELECT DISTINCT word FROM word_doc_count") for row in cur: word = row[0] count = getWordCount(word) numerator = 0 for ec in classes: if ec != c: numerator += getClassWordWeight(ec, word) #cw[c] += math.log( (numerator + s) / (denominator + s_all) ) cw[c] += float(count) * math.log( (numerator + s) / (denominator + s_all) ) #各クラスのスコアを求める results = [] for targetc in classes: denominator = 0 for ec in classes: if targetc != ec: #complement denominator += getClassWeight(ec) weight = 0 for d in data: word, count = d numerator = 0 for ec in classes: if targetc != ec: #complement numerator += getClassWordWeight(ec, word) w = math.log( (numerator + s) / (denominator + s_all) ) w *= count #weight normalization w /= cw[targetc] weight += w results.append((-weight, targetc))