Transformed Weight-normalized Complement Naive Bayes(TWCNB)についての実験と結果

自作ソフトにテキストの多クラス分類機能を組み込みたくて、調べてみたら Complement Naive Bayes(CNB、補集合ナイーブベイズ)というアルゴリズムが最近の流行のようで、これを検証してみることにしました。

元論文 を一通り読んでから検証を進めていきました。実装される際は目を通すことをオススメします。

使用したコーパスは以下のようなもの

想定する用途に合わせて、それなりにクラス間でデータの量にばらつきがあります。

クラス ファイル数 サイズ
A 832 121MB
B 491 182MB
C 449 59MB
D 312 111MB
E 298 26MB
F 245 67MB
G 234 73MB
H 210 33MB
I 123 33MB
J 63 3MB
K 62 14MB
L 47 6MB
M 47 5MB

ひとまず、シンプルなナイーブベイズ

集合知プログラミング を参考に実装しました。

#SQL定義
conn.execute("""
    CREATE TABLE IF NOT EXISTS class_word_count(
    class TEXT NOT NULL,
    word TEXT NOT NULL,
    count LONG NOT NULL,
    UNIQUE(class, word)
    );
    """)
conn.execute("""
    CREATE TABLE IF NOT EXISTS class_count(
    class TEXT PRIMARY KEY NOT NULL,
    count LONG NOT NULL
    );
    """)

#テスト
def getClassWordCount(c, word): #あるクラスにある単語が出現する回数を返す
def getClassDocCount(c): #あるクラスに属する文書の数を返す

results = []
for targetc in classes:
    p = 1.0
    for d in data:
        word, count = d
        prob = float(getClassWordCount(targetc, word)) / getClassDocCount(targetc)
        p *= ( 重み * 仮確率 + 全てのクラスでのwordの出現回数 * prob ) / ( 全てのクラスでのwordの出現回数 + 重み )
    results.append((p, targetc))



y軸は、クラスごとにテストデータを分類し、実際のクラスが平均何位であるか。1に近いほど良い。
x軸は左から順にA、B、C...と続く。


大きいクラスへの偏りが大きく、ほとんど意味を成していません。クラス間の差が増える([:30] -> [:-10])と、偏りが大きくなっているのがわかります。
…ダメです。スパム判定などの用途にはギリギリ使えるかなーという精度でした(非常に素朴な実装なので、こんなコードで実用になりそう、というのはむしろスゴい)。


いろいろとダメな原因はあるのですが、

TWCNBで補正しているものは、

  • テスト時、単語の出現頻度(multinomial naive bayes)
  • 訓練時、単語の出現頻度(TF transform)
  • 訓練時、単語の珍しさ(IDF transform)
  • 訓練時、文書の長さ(document length normalization)
  • テスト時、クラスごとの訓練データ量(complement naive bayes)
  • テスト時、クラスごとの結果の偏り(weight normalization)

です。それぞれ簡単に説明すると、

multinomial naive bayes

上記のコードでは文書中の、単語の出現回数が使われず捨てられています。これを使うと精度が上がるよーというのがMNBです。

TF transform

ただ、上のコードを改変してcountを使用するようにしても、思うように精度は上がりません。
各クラスの特徴をうまく拾う必要があり、そのために TF-IDF を導入します。
TF(term frequency、単語の出現頻度)については、頻出語の影響を低減して、実際の単語の確率分布(べき分布)に近づけるために、出現回数の対数を取って簡単に近似しています。
この実装の基点はここになります。


IDF transform

IDFはinverse document frequency、逆出現頻度です。要するに、ある単語が特定の文書にだけ出現するなら、その単語のスコアが高くなるアルゴリズムです。
IDFも対数を取り、TFに掛けます。



document length normalization

クラスの特徴として長い文書に偏っている場合など、精度に大きな影響があるので、これを補正する必要があります。
具体的には、訓練時に、1文書中の単語のウェイトが合計で1になるように正規化します。



ここまでを訓練の際に行います。以下はテスト時です。

complement naive bayes

「あるクラスに、ある単語が含まれる場合」の補集合は「あるクラス以外に、ある単語が含まれる場合」であることを利用し、MNBからさらに精度を上げます。
これの何がいいかというと、「あるクラス」よりも「あるクラス以外」のほうが圧倒的に大きい(今回のコーパスでは最大約70倍)という所です。
大きい方が誤差が減り、つまり精度が上がります。



weight normalization

ここまでの対策でも、やはりクラス間で偏りがあります。大きいクラスのほうが有利で、これはテスト文書での未知語が、小さいクラスで多いことが原因です。
これに対応する方法は力業で、訓練データがテストデータと近似する(しないとダメなんだけど)と仮定し、全訓練データを1つのテストデータとして各クラスに与え、その結果を使って補正する、というものです。
計算にかなり時間はかかりますが、確かに効果的です。



最尤推定。正規化した単語ウェイトに単語の出現回数を掛け、足していき、クラスウェイトを求めます。


結果

TMNB


sはスムージングパラメータ。

TCNB


同上。


NBより精度が上がっています。TMNB・TCNB共に、結果がスムージングパラメータに大きく影響されています。より良い結果のためには、最適なスムージングパラメータを探す必要があります。そのためのAPNBCというアルゴリズムが論文中で紹介されていますが、今回は検証しませんでした。以下のグラフで明らかであるように、この論文のキモであるウェイトの正規化で、この問題がほぼ解決しているためです。

TWMNB


sはスムージングパラメータ。

TWCNB


同上。


TWMNBはクラスごとにやや精度に差が。TWCNBは均一な結果になりました。MNBよりCNBのほうが、結果にばらつきが無いことがわかります。
Aは他と比べ特徴の薄い、特殊なクラスなので精度が悪くなっています。B〜Mについても、クラス間で語彙が深く重なっている場合があり、(GとJなど)やや影響が見られます。
スムージングパラメータの影響も、TMNB・TCNBほどには受けていません。


なお、論文には



とありますが、今回検証した限りでは



とした方が精度が出ました。僕がどこかで勘違いしている可能性も否定できませんが…。

TWMNB, TWCNBの実装

訓練
conn.execute("""
    CREATE TABLE IF NOT EXISTS word_doc_count (
        word TEXT PRIMARY KEY NOT NULL,
        doc_count INTEGER NOT NULL
    )
    """)
conn.execute("""
    CREATE TABLE IF NOT EXISTS class_word_weight (
        class TEXT NOT NULL,
        word TEXT NOT NULL,
        weight REAL NOT NULL,
        UNIQUE(class, word)
    )
    """)
conn.execute("""
    CREATE TABLE IF NOT EXISTS word_count (
        word TEXT NOT NULL,
        count INTEGER NOT NULL
    )
    """)

#IDF計算のためにまずword_doc_countを格納する
for c in classes:
    files = getFiles(c)
    for f in files:
        words = getWords(f)
        data = toDataArray(words)
        #word_doc_countをインクリメント
        bulkIncrWordDocCount(data)
        totalDocCount += 1

#TF-IDFを計算し、格納する
for c in classes:
    files = getFiles(c)
    for f in files:
        words = getWords(f)
        data = toDataArray(words)
        weights = []
        weightTotal = 0
        for d in data:
            word, count = d
            docCount = getWordDocCount(word)
            #TF transform
            tf = getTF(count)
            #IDF transform
            idf = getIDF(docCount, totalDocCount)
            weight = tf * idf
            weights.append((weight, word))
            weightTotal += weight ** 2
        weightTotal = math.sqrt(weightTotal)
        #document length normallization
        nWeights = []
        nWeightTotal = 0
        for d in weights:
            weight, word = d
            weight /= weightTotal
            nWeights.append((weight, word))
        #word_countを加算
        bulkAddWordCount(data)
        #class_word_weightを加算
        bulkAddClassWordWeight(c, nWeights)
conn.commit()
TWMNB
#weight normalizationに使うclass weightの計算
cw = {}
for c in classes:
    cw[c] = 0
    cur = conn.execute("SELECT DISTINCT word FROM word_doc_count")
    for row in cur:
        word = row[0]
        count = getWordCount(word)
        weight = getClassWordWeight(c, word)
        #cw[c] += math.log((weight + s) / (getClassWeight(c) + s_all))
        cw[c] += float(count) * math.log((weight + s) / (getClassWeight(c) + s_all))

#各クラスのスコアを求める
results = []
for targetc in classes:
    weight = 0
    for d in data:
        word, count = d
        w = math.log( (getClassWordWeight(targetc, word) + s ) / (getClassWeight(c) + s_all ) )
        w *= count
        #weight normalization
        w /= cw[targetc]
        weight += w
    results.append((weight, targetc))
TWCNB
#weight normalizationに使うclass weightの計算
cw = {}
for c in classes:
    cw[c] = 0
    denominator = 0
    for ec in classes:
        if ec != c:
            denominator += getClassWeight(ec)
    cur = conn.execute("SELECT DISTINCT word FROM word_doc_count")
    for row in cur:
        word = row[0]
        count = getWordCount(word)
        numerator = 0
        for ec in classes:
            if ec != c:
                numerator += getClassWordWeight(ec, word)
        #cw[c] += math.log( (numerator + s) / (denominator + s_all) )
        cw[c] += float(count) * math.log( (numerator + s) / (denominator + s_all) )

#各クラスのスコアを求める
results = []
for targetc in classes:
    denominator = 0
    for ec in classes:
        if targetc != ec: #complement
            denominator += getClassWeight(ec)
    weight = 0
    for d in data:
        word, count = d
        numerator = 0
        for ec in classes:
            if targetc != ec: #complement
                numerator += getClassWordWeight(ec, word)
        w = math.log( (numerator + s) / (denominator + s_all) )
        w *= count
        #weight normalization
        w /= cw[targetc]
        weight += w
    results.append((-weight, targetc))