動く混合ガウス分布(実装編)- 動く PRML シリーズ(1)

こちらもどうぞ - 動く混合ガウス分布(導出編)

実装には python, SciPy と matplotlib を使います。
テストデータには Old Faithful 間欠泉データを使います。

データの読み込み

Old Faithful 間欠泉データを PRML のホームページからダウンロードし、作業ディレクトリに置きます。保存したデータは、SciPy の loadtxt 関数で読み込みます。

from scipy import loadtxt
def faithful():
    return loadtxt('faithful.txt')

適当に正規化します。

from scipy import sqrt
def faithful_norm():
    dat = faithful()
    dat[:, 0] -= dat[:, 0].mean()
    dat[:, 1] -= dat[:, 1].mean()
    dat[:, 0] /= sqrt(dat[:, 0].var())
    dat[:, 1] /= sqrt(dat[:, 1].var())
    return dat

matplotlib の初期化

matplotlib には interaction モードという動作モードがあり、この状態にするとグラフをリアルタイムでプロットすることが出来ます。

import pylab
pylab.ion() # interactive on

次に、レンタリング用のウィンドウを作ります。

from matplotlib import pyplot as pl
figsize = [8, 6] # ウィンドウの大きさ
fig = pl.figure(figsize = figsize)

作成したウィンドウに、グラフ表示用の領域を追加します。

ax = fig.add_subplot(111)

これも関数にまとめておきます。

def init_figure():
    pylab.ion()
    figsize = [8, 6]
    fig = pl.figure(figsize = figsize)
    ax = fig.add_subplot(111)
    return ax

EM アルゴリズムの実装

次に、EM アルゴリズムを実装します。
推定に使うクラス数と繰り返し回数を適当に変更できるようにします。
ついでにデータも読み込みます。

def estimate(num_class = 3, num_iter = 100):
    x = faithful_norm()
    num = len(x) # データ点の数。
    ndim = 2 # データの次元数。Old Faithful は二次元なので 2。

初期化

EM アルゴリズムを実際に動作させるためには、初期値を決めてやる必要があります。
初期値は推定精度にかかわるので、通常は k-means 法などで初期化します、が、良い初期値から始めるとグラフが動かず、面白くないので今回はわざとひどい初期値を与えます。

from scipy import eye, float64, ones, rand, zeros

def estimate(...):
    ...
    pi = ones(num_class) / num_class # 均一割り当て
    mu = (rand(num_class * ndim) * 2 - 1).reshape([num_class, ndim]) # 一様乱数
    var = zeros([num_class, ndim, ndim], dtype = float64)
    var[:] = eye(2) # 単位行列

グラフのウィンドウも作ります。

def estimate(...):
    ...
    ax = init_figure()

M-Step の実装

E-Step はややこしいので、先に M-Step を実装します。
導出編に書いた通り、クラス重み、平均、分散を更新します。

def estimate(...):
    ...
    for iiter in xrange(num_iter):
        response = gmm_response(x, pi, mu, var) # これが E-Step
        Nk = response.sum(0) # クラスの実効観測数を計算する
        pi = Nk / num
        for k in xrange(num_class):
            mu[k] = (x * response[:, k, None]).sum(0) / response[:, k].sum()
            for i in xrange(ndim):
                for j in xrange(ndim):
                    var[k, i, j] = (response[:, k] * (x[:, i] - mu[k, i]) * (x[:, j] - mu[k, j])).sum() / response[:, k].sum()

        preview_stage(ax, x, pi, mu, var) # 毎回グラフを表示する

E-Step の実装

E-Step では正規分布の確率密度を計算する必要がありますが、この値は exp(x**2) に比例してぐいぐい大きくなったり小さくなったりするので、普通に実装するとすぐにオーバーフローします。そのため普通はそういうのを気にする人は(うちとか)、対数領域で計算を済ませ、必要に応じて普通の値に直します。まあ Old Faithful ぐらいなら大丈夫ですケド。

確率密度関数ぐらいは書きましょう。

from scipy import dot, log
from scipy import pi as mpi
from scipy.linalg import det, inv

def logpdf(x, mu, var):
    num = x.shape[0]
    ndim = x.shape[1]
    ln2pi = log(2 * mpi)

    prec = inv(var)
    term = zeros(num, dtype = float64)
    term -= (ndim * ln2pi + log(det(var))) / 2
    for n in xrange(num):
        diff = x[n] - mu
        term[n] -= dot(diff, dot(prec, diff)) / 2
    return term

K 個のクラスでまとめて対数確率密度を計算する関数を書きます。

def gmm_logpdf(x, pi, mu, var):
    num = x.shape[0]
    ncl = pi.shape[0]

    lpdf = zeros([num, ncl], dtype = float64)
    for k in xrange(ncl):
        lpdf[:, k] = log(pi[k]) + logpdf(x, mu[k], var[k])

    return lpdf

M-Step の実装

負担率は対数では困るので、対数領域で正規化した後、普通の値に直します。

from scipy import exp, maximum
from scipy.maxentropy import logsumexp # このパッケージは何なんでしょうか

def gmm_response(x, pi, mu, var):
    num = len(x)
    ncl = pi.shape[0]
    lpdf = gmm_logpdf(x, pi, mu, var)
    response = zeros([num, ncl], dtype = float64)

    for i in xrange(num):
        lr = lpdf[i] - logsumexp(lpdf[i]) # exp(lr) のデータ点ごとの総和が 1 になるようにする
        response[i] = exp(lr)

    response = maximum(response, 1e-10) # ゼロめんどくさいです
    response /= response.sum(1)[:, None] # 無理やりずらしたので総和を正規化し直します
    return response

レンタリング

動く GMM 最大の山場はレンタリングです。なぜって、PRML に載ってるみたいに確率分布の等高線を引くためには、共分散行列の固有値を求めないといけないからです。なかなか面倒です。
とりあえず簡単なところをプロットしてしまいます。

def preview_stage(ax, x, pi, mu, var):
    num_class = pi.shape[0]

    ax.clear() # まずグラフをクリアします
    ax.plot(x[:, 0], x[:, 1], '+') # 観測点をプロットします
    ax.plot(mu[:, 0], mu[:, 1], 'o') # 各クラスの平均をプロットします

各クラスの共分散行列から固有値を計算し、標準偏差に対応する 2 つのベクトルを求めます。これを単位円周上の座標と掛けて、単位円を線形変換し、各クラスの平均の値を足せば完成です。PRML の付録とか見るのがよいと思います。

from scipy import cos, linspace, sin
from scipy.linalg import eigh

def unit_ring(ndiv = 20):
    angles = linspace(0, 2 * mpi, ndiv)
    ring = zeros([ndiv, 2], dtype = float64)
    ring[:, 0] = cos(angles)
    ring[:, 1] = sin(angles)

    return ring

def calc_unit(cov):
    assert cov.ndim == 2
    (vals, vecs) = eigh(cov)
    vals = sqrt(vals)
    
    buf = zeros([2, 2], dtype = float64)
    buf[0] = vals[0] * vecs[0]
    buf[1] = vals[1] * vecs[1]

    return buf

def make_ring(cov, ndiv = 20):
    ring = unit_ring(ndiv)
    units = calc_unit(cov)
    buf = zeros([ndiv, 2], dtype = float64)
    for i in xrange(ndiv):
        buf[i, :] = units[0, :] * ring[i, 0] + units[1, :] * ring[i, 1]

    return buf

標準偏差の大きさで単位円を引きます。クラス負担率もなんとなく表示できるようにします。
プロットが終わったら、x 軸と y 軸の範囲を設定して、draw 関数でグラフを更新します。

    weight = pi * 2 # 見やすい大きさにする

    for k in xrange(num_class):
        rbuf = make_ring(var[k], ndiv = 50)
        ax.plot(rbuf[:, 0] + mu[k, 0], rbuf[:, 1] + mu[k, 1], 'b')
        ax.plot(rbuf[:, 0] * weight[k] + mu[k, 0], rbuf[:, 1] * weight[k] + mu[k, 1], '0.8')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.figure.canvas.draw()
    ax.figure.canvas.draw()

ソースコード

できあがり!

動かすときは

python -m pdb gmm.py

とかするとよいです。-m pdb (デバッグモード) にしないと推定終了時にウィンドウを閉じちゃいます。

# -*- coding: utf-8 -*-
import pylab
from matplotlib import pyplot as pl
from scipy import cos, dot, exp, eye, float64, linspace, loadtxt, log, maximum
from scipy import ones, rand, sqrt, sin, zeros
from scipy import pi as mpi
from scipy.linalg import det, eigh, inv
from scipy.maxentropy import logsumexp

# ========== ========== ========== ==========
#
#   Old Faithful
#
# ========== ========== ========== ==========
def faithful():
    return loadtxt('faithful.txt')

def faithful_norm():
    dat = faithful()
    dat[:, 0] -= dat[:, 0].mean()
    dat[:, 1] -= dat[:, 1].mean()
    dat[:, 0] /= sqrt(dat[:, 0].var())
    dat[:, 1] /= sqrt(dat[:, 1].var())
    return dat

# ========== ========== ========== ==========
#
#   Figure
#
# ========== ========== ========== ==========
def init_figure():
    pylab.ion()
    figsize = [8, 6]
    fig = pl.figure(figsize = figsize)
    ax = fig.add_subplot(111)
    return ax

# ========== ========== ========== ==========
#
#   GMM
#
# ========== ========== ========== ==========
def logpdf(x, mu, var):
    num = x.shape[0]
    ndim = x.shape[1]
    ln2pi = log(2 * mpi)

    prec = inv(var)
    term = zeros(num, dtype = float64)
    term -= (ndim * ln2pi + log(det(var))) / 2
    for n in xrange(num):
        diff = x[n] - mu
        term[n] -= dot(diff, dot(prec, diff)) / 2
    return term

def gmm_logpdf(x, pi, mu, var):
    num = x.shape[0]
    ncl = pi.shape[0]

    lpdf = zeros([num, ncl], dtype = float64)
    for k in xrange(ncl):
        lpdf[:, k] = log(pi[k]) + logpdf(x, mu[k], var[k])

    return lpdf

def gmm_response(x, pi, mu, var):
    num = len(x)
    ncl = pi.shape[0]
    lpdf = gmm_logpdf(x, pi, mu, var)
    response = zeros([num, ncl], dtype = float64)

    for i in xrange(num):
        lr = lpdf[i] - logsumexp(lpdf[i])
        response[i] = exp(lr)

    response = maximum(response, 1e-10)
    response /= response.sum(1)[:, None]
    return response

# ========== ========== ========== ==========
#
#   Rendering
#
# ========== ========== ========== ==========
def unit_ring(ndiv = 20):
    angles = linspace(0, 2 * mpi, ndiv)
    ring = zeros([ndiv, 2], dtype = float64)
    ring[:, 0] = cos(angles)
    ring[:, 1] = sin(angles)

    return ring

def calc_unit(cov):
    assert cov.ndim == 2
    (vals, vecs) = eigh(cov)
    vals = sqrt(vals)
    
    buf = zeros([2, 2], dtype = float64)
    buf[0] = vals[0] * vecs[0]
    buf[1] = vals[1] * vecs[1]

    return buf

def make_ring(cov, ndiv = 20):
    ring = unit_ring(ndiv)
    units = calc_unit(cov)
    buf = zeros([ndiv, 2], dtype = float64)
    for i in xrange(ndiv):
        buf[i, :] = units[0, :] * ring[i, 0] + units[1, :] * ring[i, 1]

    return buf

def preview_stage(ax, x, pi, mu, var):
    num_class = pi.shape[0]
    ax.clear()
    ax.plot(x[:, 0], x[:, 1], '+')
    ax.plot(mu[:, 0], mu[:, 1], 'o')

    weight = pi * 2
    for k in xrange(num_class):
        rbuf = make_ring(var[k], ndiv = 50)
        ax.plot(rbuf[:, 0] + mu[k, 0], rbuf[:, 1] + mu[k, 1], 'b')
        ax.plot(rbuf[:, 0] * weight[k] + mu[k, 0], rbuf[:, 1] * weight[k] + mu[k, 1], '0.8')
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    ax.figure.canvas.draw()
    ax.figure.canvas.draw()

# ========== ========== ========== ==========
#
#   Estimation
#
# ========== ========== ========== ==========
def estimate(num_class = 3, num_iter = 100):
    x = faithful_norm()
    ax = init_figure()
    num = len(x)
    ndim = 2

    pi = ones(num_class) / num_class
    mu = (rand(num_class * ndim) * 2 - 1).reshape([num_class, ndim])
    var = zeros([num_class, ndim, ndim], dtype = float64)
    var[:] = eye(2)

    for iiter in xrange(num_iter):
        response = gmm_response(x, pi, mu, var)
        Nk = response.sum(0)
        pi = Nk / num
        for k in xrange(num_class):
            mu[k] = (x * response[:, k, None]).sum(0) / response[:, k].sum()
            for i in xrange(ndim):
                for j in xrange(ndim):
                    var[k, i, j] = (response[:, k] * (x[:, i] - mu[k, i]) * (x[:, j] - mu[k, j])).sum() / response[:, k].sum()

        preview_stage(ax, x, pi, mu, var)

# ========== ========== ========== ==========
#
#   Main
#
# ========== ========== ========== ==========
if (__name__ == '__main__'):
    estimate()