動く変分混合ガウス分布(導出編)- 動く PRML シリーズ(2)

やりたいこと

動く PRML シリーズ、第2回は変分混合ガウス分布 (variational Bayesian Gaussian mixture model, VB-GMM) です。
はじめに、前回の繰り返しになりますが、反復繰り返し型の機械学習アルゴリズムを理解するためには、大きく分けて二つのステップがあることを再確認します。一つ目はもちろん、

  • 更新式を導出すること。

反復アルゴリズムの理論的性質はすべて更新式の形に反映されています。従って、この更新式を自力で導出することはとても勉強になります。

そして、二つ目は、

  • イテレーションの内容をグラフにプロットし、実際の挙動を体感すること。

更新式には確かに必要なこと全てが記されていますし、熟練の研究者であれば、更新式の形から、その実際の挙動をある程度予測することが可能です。


しかし、それは楽譜だけを見てオーケストラを聴き取るようなものではないでしょうか。そのような力を手に入れる最良の訓練は、もちろん、楽譜と実際の音楽を何度も繰り返し聴き比べることです。

そこで本記事では、変分混合ガウス分布 (VB-GMM) の導出から始め、繰り返しごとに結果をプロットするプログラムを python で実装し、初期値応答、局所解頑健性、収束性能などを体感することを目標とします。

生成モデル

VB-GMM では、GMM と同じく、K 個のクラスから N 個のデータ点が生成されていると考えます。つまり、この分布からデータを生成したければ、K 個の面があるサイコロを N 回振って、k の目が出た回数だけ k 番目の正規分布からデータ点をサンプルすればよいのです。簡単ですね。
尤度関数は GMM と同じく、

  
  
となります。
さらに、変分ベイズ法はベイズ推定アルゴリズムなので、パラメータの事前分布 を導入する必要があります。

ベイズ推定の位置付け

変分ベイズ法は EM アルゴリズムに基づく繰り返し最適化法のひとつですが、EM アルゴリズムが適用できる問題には大きく分けて最尤推定、MAP 推定、ベイズ推定の三種類があり、それぞれ推定の目標が異なります。
観測変数を X、モデルパラメータを とすると、最尤推定

  

を、MAP 推定は

  

を、ベイズ推定は

  

を推定することに相当しています。MAP 推定やベイズ推定では、生成分布 の他に事前分布 が必要なことが、ベイズの定理から分かります。

事前分布

先に述べた通り、変分ベイズ法では事前分布が必要なため、これを導入します。事前分布として、閉形式で更新式が導出できることが保証されている共役事前分布

  
  

を使用します。

変分事後分布

さてベイズ推定では、観測変数で条件付けられた潜在変数の事後分布 を推定するわけですが、様々な理由から、これをこのまま閉形式で導出することは出来ません。
そこで、 をいくつかの変分事後分布 の積で近似し、変分事後分布の積 と真の事後分布 の KL ダイバージェンスが最小となるように各分布を更新していきます。変分法の一般的な導出は PRML に譲りますが、大事なことは、

  • 共役事前分布を導入すれば、変分事後分布は事前分布と同じ形の指数型分布を使って書ける。例えば、正規分布なら正規分布、ディリクレ分布ならディリクレ分布というように。
  • さらに、 の変分事後分布は、全ての確率変数の同時分布 の対数尤度の に関する期待値を用いて計算できる。

という二点です。変分混合ガウス分布では、様々な理由*1から、 の形に分解します。

変分 E ステップ (VB-E Step)

変分推論の更新式はひたすら機械的に導出することができます。
まず、Z は多項分布にしたがうので、Z の変分事後分布も多項分布の形で書くことができます。ここでは、多項分布のパラメータを とおきます。

  

Z の変分事後分布の対数形は、Z 以外の潜在変数に関する完全同時分布の期待値として書けるので、

  
  
  

となります。従って
  
となります。

ちなみに具体的な値はというと、例えば
  
となります。ここで、 は事後ハイパーパラメータ(後述)であり、 はディガンマ関数です。 が常に 1 より小さいことを考えると、この期待値は通常 -3 とか -5 とかいう値になります。このことを覚えておくとデバッグの際に重宝するでしょう。


については書く気すら起こりません…。というか、覚えていません。PRML を見て関数として実装して、その中身については忘れてしまうのが吉でしょう。

変分 M ステップ (VB-M Step)

自然な分解

E-Step では明示しませんでしたが、変分事後分布を導出する際には、更新しようとしている変数のみを残し、不要な項をどんどん定数項に押し込んでいきます。次に述べる変分 M ステップでは の変分事後分布を計算するわけですが、これをちゃんと計算してやると、なんと

  

の形で書けることが分かります。そこで、文献によっては始めからこの形の分解を与えているものもあります。この導出はかなり文章を食うので省略します。

クラス混合比の更新

モクモクと計算していきます。変分事後分布を事後ハイパーパラメータ (posterior hyperparameter) を用いて と書き、混合比の対数事前確率

  
を用いると、

  
  
  
  
となります。ここで、 です。

さらに、 と比較することで、クラス混合比に対する更新式

  

を得ます。some cheat なんて言われることもありますが、変分事後分布の計算は基本的に係数合わせだけで出来るので、微分してゼロを解く必要はありません。らくちんですね。

クラス平均と精度の更新

簡単なんですが、手間はかかります。手で導出する時は正規分布のハイパーパラメータを先に計算して、後でウィシャート分布のパラメータに集中するのが吉だと思います。導出は、例によって一次元で行います。


一次元版の正規ウィシャート分布は、変数変換により正規ガンマ分布で書くことができるため、事前分布と事後分布を
  
  
と書きます。事前分布の対数確率は

  

と書けるので、
  
  
  
  
となります。


係数を比べることにより、
  
  
  
  
を得ます。


更に、ガンマ分布とウィシャート分布は
  
と変数変換できるため、
  
  
を得ます。多次元ウィシャート分布の場合は、
  
となることが知られています。(PRML を参照のこと。)

初期化

通常は負担率の初期値を k-means 法によって与えますが、一様乱数から始めたほうが推論特性が見やすいため、負担率を一様乱数からサンプルして推論を始めます。

実装

長くなりましたが、これでようやく実装に必要な式を手に入れることが出来ました。それでは実装編へどうぞ。(鋭意執筆中…)

*1:一般的な分解法については Winn, Bishop の Variational Message Passing を参照のこと。基本的には、グラフィカルモデル上で潜在変数同士がリンクを持たないように分解しなければならない。

動く混合ガウス分布(実装編)- 動く PRML シリーズ(1)

こちらもどうぞ - 動く混合ガウス分布(導出編)

実装には python, SciPy と matplotlib を使います。
テストデータには Old Faithful 間欠泉データを使います。

データの読み込み

Old Faithful 間欠泉データを PRML のホームページからダウンロードし、作業ディレクトリに置きます。保存したデータは、SciPy の loadtxt 関数で読み込みます。

from scipy import loadtxt
def faithful():
    return loadtxt('faithful.txt')

適当に正規化します。

from scipy import sqrt
def faithful_norm():
    dat = faithful()
    dat[:, 0] -= dat[:, 0].mean()
    dat[:, 1] -= dat[:, 1].mean()
    dat[:, 0] /= sqrt(dat[:, 0].var())
    dat[:, 1] /= sqrt(dat[:, 1].var())
    return dat

matplotlib の初期化

matplotlib には interaction モードという動作モードがあり、この状態にするとグラフをリアルタイムでプロットすることが出来ます。

import pylab
pylab.ion() # interactive on

次に、レンタリング用のウィンドウを作ります。

from matplotlib import pyplot as pl
figsize = [8, 6] # ウィンドウの大きさ
fig = pl.figure(figsize = figsize)

作成したウィンドウに、グラフ表示用の領域を追加します。

ax = fig.add_subplot(111)

これも関数にまとめておきます。

def init_figure():
    pylab.ion()
    figsize = [8, 6]
    fig = pl.figure(figsize = figsize)
    ax = fig.add_subplot(111)
    return ax

EM アルゴリズムの実装

次に、EM アルゴリズムを実装します。
推定に使うクラス数と繰り返し回数を適当に変更できるようにします。
ついでにデータも読み込みます。

def estimate(num_class = 3, num_iter = 100):
    x = faithful_norm()
    num = len(x) # データ点の数。
    ndim = 2 # データの次元数。Old Faithful は二次元なので 2。

初期化

EM アルゴリズムを実際に動作させるためには、初期値を決めてやる必要があります。
初期値は推定精度にかかわるので、通常は k-means 法などで初期化します、が、良い初期値から始めるとグラフが動かず、面白くないので今回はわざとひどい初期値を与えます。

from scipy import eye, float64, ones, rand, zeros

def estimate(...):
    ...
    pi = ones(num_class) / num_class # 均一割り当て
    mu = (rand(num_class * ndim) * 2 - 1).reshape([num_class, ndim]) # 一様乱数
    var = zeros([num_class, ndim, ndim], dtype = float64)
    var[:] = eye(2) # 単位行列

グラフのウィンドウも作ります。

def estimate(...):
    ...
    ax = init_figure()

M-Step の実装

E-Step はややこしいので、先に M-Step を実装します。
導出編に書いた通り、クラス重み、平均、分散を更新します。

def estimate(...):
    ...
    for iiter in xrange(num_iter):
        response = gmm_response(x, pi, mu, var) # これが E-Step
        Nk = response.sum(0) # クラスの実効観測数を計算する
        pi = Nk / num
        for k in xrange(num_class):
            mu[k] = (x * response[:, k, None]).sum(0) / response[:, k].sum()
            for i in xrange(ndim):
                for j in xrange(ndim):
                    var[k, i, j] = (response[:, k] * (x[:, i] - mu[k, i]) * (x[:, j] - mu[k, j])).sum() / response[:, k].sum()

        preview_stage(ax, x, pi, mu, var) # 毎回グラフを表示する

E-Step の実装

E-Step では正規分布の確率密度を計算する必要がありますが、この値は exp(x**2) に比例してぐいぐい大きくなったり小さくなったりするので、普通に実装するとすぐにオーバーフローします。そのため普通はそういうのを気にする人は(うちとか)、対数領域で計算を済ませ、必要に応じて普通の値に直します。まあ Old Faithful ぐらいなら大丈夫ですケド。

確率密度関数ぐらいは書きましょう。

from scipy import dot, log
from scipy import pi as mpi
from scipy.linalg import det, inv

def logpdf(x, mu, var):
    num = x.shape[0]
    ndim = x.shape[1]
    ln2pi = log(2 * mpi)

    prec = inv(var)
    term = zeros(num, dtype = float64)
    term -= (ndim * ln2pi + log(det(var))) / 2
    for n in xrange(num):
        diff = x[n] - mu
        term[n] -= dot(diff, dot(prec, diff)) / 2
    return term

K 個のクラスでまとめて対数確率密度を計算する関数を書きます。

def gmm_logpdf(x, pi, mu, var):
    num = x.shape[0]
    ncl = pi.shape[0]

    lpdf = zeros([num, ncl], dtype = float64)
    for k in xrange(ncl):
        lpdf[:, k] = log(pi[k]) + logpdf(x, mu[k], var[k])

    return lpdf

M-Step の実装

負担率は対数では困るので、対数領域で正規化した後、普通の値に直します。

from scipy import exp, maximum
from scipy.maxentropy import logsumexp # このパッケージは何なんでしょうか

def gmm_response(x, pi, mu, var):
    num = len(x)
    ncl = pi.shape[0]
    lpdf = gmm_logpdf(x, pi, mu, var)
    response = zeros([num, ncl], dtype = float64)

    for i in xrange(num):
        lr = lpdf[i] - logsumexp(lpdf[i]) # exp(lr) のデータ点ごとの総和が 1 になるようにする
        response[i] = exp(lr)

    response = maximum(response, 1e-10) # ゼロめんどくさいです
    response /= response.sum(1)[:, None] # 無理やりずらしたので総和を正規化し直します
    return response

レンタリング

動く GMM 最大の山場はレンタリングです。なぜって、PRML に載ってるみたいに確率分布の等高線を引くためには、共分散行列の固有値を求めないといけないからです。なかなか面倒です。
とりあえず簡単なところをプロットしてしまいます。

def preview_stage(ax, x, pi, mu, var):
    num_class = pi.shape[0]

    ax.clear() # まずグラフをクリアします
    ax.plot(x[:, 0], x[:, 1], '+') # 観測点をプロットします
    ax.plot(mu[:, 0], mu[:, 1], 'o') # 各クラスの平均をプロットします

各クラスの共分散行列から固有値を計算し、標準偏差に対応する 2 つのベクトルを求めます。これを単位円周上の座標と掛けて、単位円を線形変換し、各クラスの平均の値を足せば完成です。PRML の付録とか見るのがよいと思います。

from scipy import cos, linspace, sin
from scipy.linalg import eigh

def unit_ring(ndiv = 20):
    angles = linspace(0, 2 * mpi, ndiv)
    ring = zeros([ndiv, 2], dtype = float64)
    ring[:, 0] = cos(angles)
    ring[:, 1] = sin(angles)

    return ring

def calc_unit(cov):
    assert cov.ndim == 2
    (vals, vecs) = eigh(cov)
    vals = sqrt(vals)
    
    buf = zeros([2, 2], dtype = float64)
    buf[0] = vals[0] * vecs[0]
    buf[1] = vals[1] * vecs[1]

    return buf

def make_ring(cov, ndiv = 20):
    ring = unit_ring(ndiv)
    units = calc_unit(cov)
    buf = zeros([ndiv, 2], dtype = float64)
    for i in xrange(ndiv):
        buf[i, :] = units[0, :] * ring[i, 0] + units[1, :] * ring[i, 1]

    return buf

標準偏差の大きさで単位円を引きます。クラス負担率もなんとなく表示できるようにします。
プロットが終わったら、x 軸と y 軸の範囲を設定して、draw 関数でグラフを更新します。

    weight = pi * 2 # 見やすい大きさにする

    for k in xrange(num_class):
        rbuf = make_ring(var[k], ndiv = 50)
        ax.plot(rbuf[:, 0] + mu[k, 0], rbuf[:, 1] + mu[k, 1], 'b')
        ax.plot(rbuf[:, 0] * weight[k] + mu[k, 0], rbuf[:, 1] * weight[k] + mu[k, 1], '0.8')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.figure.canvas.draw()
    ax.figure.canvas.draw()

ソースコード

できあがり!

動かすときは

python -m pdb gmm.py

とかするとよいです。-m pdb (デバッグモード) にしないと推定終了時にウィンドウを閉じちゃいます。

# -*- coding: utf-8 -*-
import pylab
from matplotlib import pyplot as pl
from scipy import cos, dot, exp, eye, float64, linspace, loadtxt, log, maximum
from scipy import ones, rand, sqrt, sin, zeros
from scipy import pi as mpi
from scipy.linalg import det, eigh, inv
from scipy.maxentropy import logsumexp

# ========== ========== ========== ==========
#
#   Old Faithful
#
# ========== ========== ========== ==========
def faithful():
    return loadtxt('faithful.txt')

def faithful_norm():
    dat = faithful()
    dat[:, 0] -= dat[:, 0].mean()
    dat[:, 1] -= dat[:, 1].mean()
    dat[:, 0] /= sqrt(dat[:, 0].var())
    dat[:, 1] /= sqrt(dat[:, 1].var())
    return dat

# ========== ========== ========== ==========
#
#   Figure
#
# ========== ========== ========== ==========
def init_figure():
    pylab.ion()
    figsize = [8, 6]
    fig = pl.figure(figsize = figsize)
    ax = fig.add_subplot(111)
    return ax

# ========== ========== ========== ==========
#
#   GMM
#
# ========== ========== ========== ==========
def logpdf(x, mu, var):
    num = x.shape[0]
    ndim = x.shape[1]
    ln2pi = log(2 * mpi)

    prec = inv(var)
    term = zeros(num, dtype = float64)
    term -= (ndim * ln2pi + log(det(var))) / 2
    for n in xrange(num):
        diff = x[n] - mu
        term[n] -= dot(diff, dot(prec, diff)) / 2
    return term

def gmm_logpdf(x, pi, mu, var):
    num = x.shape[0]
    ncl = pi.shape[0]

    lpdf = zeros([num, ncl], dtype = float64)
    for k in xrange(ncl):
        lpdf[:, k] = log(pi[k]) + logpdf(x, mu[k], var[k])

    return lpdf

def gmm_response(x, pi, mu, var):
    num = len(x)
    ncl = pi.shape[0]
    lpdf = gmm_logpdf(x, pi, mu, var)
    response = zeros([num, ncl], dtype = float64)

    for i in xrange(num):
        lr = lpdf[i] - logsumexp(lpdf[i])
        response[i] = exp(lr)

    response = maximum(response, 1e-10)
    response /= response.sum(1)[:, None]
    return response

# ========== ========== ========== ==========
#
#   Rendering
#
# ========== ========== ========== ==========
def unit_ring(ndiv = 20):
    angles = linspace(0, 2 * mpi, ndiv)
    ring = zeros([ndiv, 2], dtype = float64)
    ring[:, 0] = cos(angles)
    ring[:, 1] = sin(angles)

    return ring

def calc_unit(cov):
    assert cov.ndim == 2
    (vals, vecs) = eigh(cov)
    vals = sqrt(vals)
    
    buf = zeros([2, 2], dtype = float64)
    buf[0] = vals[0] * vecs[0]
    buf[1] = vals[1] * vecs[1]

    return buf

def make_ring(cov, ndiv = 20):
    ring = unit_ring(ndiv)
    units = calc_unit(cov)
    buf = zeros([ndiv, 2], dtype = float64)
    for i in xrange(ndiv):
        buf[i, :] = units[0, :] * ring[i, 0] + units[1, :] * ring[i, 1]

    return buf

def preview_stage(ax, x, pi, mu, var):
    num_class = pi.shape[0]
    ax.clear()
    ax.plot(x[:, 0], x[:, 1], '+')
    ax.plot(mu[:, 0], mu[:, 1], 'o')

    weight = pi * 2
    for k in xrange(num_class):
        rbuf = make_ring(var[k], ndiv = 50)
        ax.plot(rbuf[:, 0] + mu[k, 0], rbuf[:, 1] + mu[k, 1], 'b')
        ax.plot(rbuf[:, 0] * weight[k] + mu[k, 0], rbuf[:, 1] * weight[k] + mu[k, 1], '0.8')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.figure.canvas.draw()
    ax.figure.canvas.draw()

# ========== ========== ========== ==========
#
#   Estimation
#
# ========== ========== ========== ==========
def estimate(num_class = 3, num_iter = 100):
    x = faithful_norm()
    ax = init_figure()
    num = len(x)
    ndim = 2

    pi = ones(num_class) / num_class
    mu = (rand(num_class * ndim) * 2 - 1).reshape([num_class, ndim])
    var = zeros([num_class, ndim, ndim], dtype = float64)
    var[:] = eye(2)

    for iiter in xrange(num_iter):
        response = gmm_response(x, pi, mu, var)
        Nk = response.sum(0)
        pi = Nk / num
        for k in xrange(num_class):
            mu[k] = (x * response[:, k, None]).sum(0) / response[:, k].sum()
            for i in xrange(ndim):
                for j in xrange(ndim):
                    var[k, i, j] = (response[:, k] * (x[:, i] - mu[k, i]) * (x[:, j] - mu[k, j])).sum() / response[:, k].sum()

        preview_stage(ax, x, pi, mu, var)

# ========== ========== ========== ==========
#
#   Main
#
# ========== ========== ========== ==========
if (__name__ == '__main__'):
    estimate()

動く混合ガウス分布(導出編)- 動く PRML シリーズ(1)

はじめに

混合ガウス分布 (Gaussian Mixture Model, GMM) は、多次元の特徴量を持つデータ点の集合を機械学習により分類するための重要な手法です。特に、GMM は応用範囲が広く、様々な手法の基礎となっているため、自ら更新式を導出するなどして特性をよく理解することが重要です。


本記事ではこれに加え、GMM の長所と短所を視覚的に確認する方法を提案し、実装します。GMM の更新式は「どう動くか」の説明にはなりますが、「長所と短所は何か」を直接教えてはくれません。たとえば、

  • どんな初期値から、どんな収束値が得られるのか。(初期値に対する特性)
  • 変な初期値を与えるとどうなるのか。(局所解頑健性)
  • どれぐらいの速さで収束するのか。(収束性能)

といった疑問にこたえるためには、EM アルゴリズムの各イテレーションでの中間状態をグラフにプロットし、アニメーションとして見る必要があります。


本記事では、まず簡単に更新式を導出し、次に動く GMM の実装方法について詳しく説明します。実装編の最後に添付したソースコードを実際に動かしながら、本記事や PRML を読むのがオススメです。

生成モデル

N 個のデータ点で構成された観測データを考えます。個々のデータ点はそれぞれ、K 個あるクラスター(クラス)のいずれかに属していると仮定します。n 番目のデータ点が k 番目のクラスから生成されたことを、 であらわします。各データ点は多項分布で各クラスに割り当てられ、各クラス固有のガウス分布により観測値が生成されます。この観測モデルの尤度は、

  
  

と書けます。EM アルゴリズムの目標は、観測データにとって最も適切なモデルパラメータの組み合わせを、E-Step と M-Step という二種類の計算の繰り返しによって再帰的に推定することです。

E-Step

E-Step では、潜在クラス割り当ての期待値を、現在のパラメータの値に基づいて決定します。

  

M-Step

M-Step では、観測データ X と得られた期待値 E[Z] の同時分布を最大化するようにモデルパラメータを決定します。
以下では、 と書きます。

クラス平均と分散(精度)の更新

完全データの対数尤度は

  

となり、これを 微分した結果を 0 と置くことで、 に対する次の更新式

  
  

を得ます。同様に、分散(精度)で偏微分することで、分散に対する次の更新式
  
を得ます。

多次元の場合は省略しますが、PRML の結果を引用すると

  

となります。

クラス混合比の更新

クラス混合比には という制約があるので、ラグランジュの未定乗数法を使って更新します。

目的関数を
  
  
とすると、ラグランジュの未定乗数法より、制約付き極値問題は 連立方程式と同値になるので、
  
  
従って、
  
を得ます。

動く混合ガウス分布(実装編)

インラインアセンブラで sprintf を呼び出す

ここで使っているのは sprintf ではなく、sprintf_s ですけど。
sprintf_s(char*, size_t, const char*, ...) をインラインアセンブラで直接呼び出そうという話。
lea 命令を使えば文字列バッファのアドレスはいくらでも読み出せるのでらくちん。


void main()
{
const size_t bufferSize = 64;
char format[] = "%d";
char s[bufferSize];

__asm
{
{
push 270; /* param 4 */
lea eax, format;
push eax; /* param 3 */
push bufferSize; /* param 2 */
lea eax, s;
push eax; /* param 1 */
call dword ptr [sprintf_s];
add esp, 16;
}
{
lea eax, s;
push eax; /* param 1 */
call dword ptr [puts];
add esp, 4;
}
}
}

インラインアセンブラで printf を呼び出す(その2)

ローカル変数のアドレスを動的に解決して printf を呼び出す方法。
lea (Load Effective Address) 命令で実効アドレスを読み込めば良いらしい。掛け算の最適化命令だと思ってたけど、基本的な使い方はこっちなのかな。


void main()
{
char hello[] = "Hello\r\n";

__asm
{
lea eax, hello;
push eax;
call dword ptr [printf];
add esp, 4;
}
}

インラインアセンブラで printf を呼び出す

とても簡単なことなんだけど誰も書いてないのでメモ。


char hello[] = "Hello\r\n";

void main()
{
__asm
{
mov eax, offset hello;
push eax;
call dword ptr [printf];
pop ebx;
}
}

DLL 内の関数を呼ぶときには関数名を dword ptr [] で括ればいいらしい。
MessageBoxA もこれで呼べる。ただし MessageBox の場合は呼び出し規約が __stdcall なので pop ebx; は必要ない。