動く変分混合ガウス分布(実装編)- 動く PRML シリーズ(2)

こちらもどうぞ - 動く変分混合ガウス分布(導出編)

実装には python, SciPy と matplotlib を使います。
テストデータには Old Faithful 間欠泉データを使います。
また、データの読み込み、プロットは混合ガウス分布の際に実装したものを再利用しますので、こちらからダウンロードしておいてください。

必要な関数の読み込み

はじめに、必要な関数を読み込みます。

from gmm import faithful_norm, init_figure, preview_stage
from scipy import arange, array, exp, eye, float64, log, maximum, ones, outer, pi, rand, zeros
from scipy.linalg import det, inv
from scipy.maxentropy import logsumexp
from scipy.special import digamma

主な関数は以下の通り。

  • faithful_norm … データの読み込み
  • init_figure … グラフウィンドウの作成
  • preview_stage … 混合ガウス分布のレンタリング
  • arange … 非負数列 [0, 1, 2, …] の作成
  • eye … 単位行列の作成
  • outer … テンソル の計算
  • logsumexp … の計算
  • digamma … ディガンマ関数

関数定義、データ読み込み

前回と同じく、estimate 関数を実行するとデモが見られるようにしておきます。

def estimate(num_class = 6, num_iter = 100, alpha0 = None, beta0 = None,
             m0 = None, W0 = None, nu0 = None):
    # Data
    x = faithful_norm()
    num, ndim = x.shape

    # Graph
    ax = init_figure()

今回は変分ベイズ法を使用するため、クラス数は多めに指定しておくと自動的に最適化されます。今回は 6 クラスを指定しています。繰り返し回数は 100 回としました。
クラス負担率、平均、精度のハイパーパラメータ は、関数呼び出しの際に自由に値を設定できるようにしておきます。明示されていない時 (None の時) は後でデフォルト値を代入します。

ハイパーパラメータの初期化

None の場合のみデフォルト値に差し替える関数を作ります。

def fill_param(param, default):
    if (param == None):
        return default
    else:
        return param

次に、ハイパーパラメータの値を設定します。デフォルト値は無情報事前分布 (non-informative prior) と呼ばれる値で、これは推論に恣意的な事前情報を持ち込まないことを意味しています。(alpha0 = 1000 とかで動かしてみるとよく分かります。)
ただし、精度に無情報事前分布を使うと不安定になりがちなので、精度 1,有効観測数 1 の事前情報を与えます。

# Hyperparameters
alpha0 = fill_param(alpha0, ones(num_class) * 1e-3)
beta0 = fill_param(beta0, 1e-3)
m0 = fill_param(m0, zeros(ndim))
W0 = fill_param(W0, eye(2))
nu0 = fill_param(nu0, 1)
inv_W0 = inv(W0)

負担率の初期化

今回は変分 E ステップで初期化を行うので、負担率を一様乱数からサンプリングします。

# Initial VB-E Step
r_nk = rand(num, num_class)
r_nk /= r_nk.sum(1)[:, None]

変分 M ステップ

負担率が手元にあるので、この値を使ってハイパーパラメータを更新していきます。

クラス混合比の更新

  

alpha = alpha0 + Nk
平均、精度の更新

変分ベイズ法の更新式は、基本的に最尤推定の結果と事前分布の重み付き平均になるので、はじめに最尤推定の場合の結果を計算します。
  

def calc_xbar(x, r_nk):
    num, ndim = x.shape
    num, num_class = r_nk.shape
    ret = zeros([num_class, ndim], dtype = float64)

    for k in xrange(num_class):
        clres = r_nk[:, k]
        for i in xrange(ndim):
            ret[k, i] = (clres * x[:, i]).sum()
        ret[k, :] /= clres.sum()

    return ret

  
python の特性上、N でループを回すより K で回したほうが速いので、次の式を使って計算します。
  

def calc_S(x, xbar, r_nk):
    num, ndim = x.shape
    num, num_class = r_nk.shape
    ret = zeros([num_class, ndim, ndim], dtype = float64)

    for k in xrange(num_class):
        clres = r_nk[:, k]
        for i in xrange(ndim):
            diff_i = x[:, i] - xbar[k, i]
            for j in xrange(ndim):
                diff_j = x[:, j] - xbar[k, j]
                ret[k, i, j] = (clres * diff_i * diff_j).sum()
        ret[k] /= clres.sum()
  • 平均の更新

  

def calc_m(xbar, Nk, m0, beta0, beta):
    num_class, ndim = xbar.shape
    ret = zeros([num_class, ndim], dtype = float64)

    for k in xrange(num_class):
        ret[k] = (beta0 * m0 + Nk[k] * xbar[k]) / beta[k]

    return ret
  • 精度の更新

  

def calc_W(xbar, Sk, Nk, m0, beta0, inv_W0):
    num_class, ndim = xbar.shape
    ret = zeros([num_class, ndim, ndim], dtype = float64)

    for k in xrange(num_class):
        ret[k] = inv_W0 + Nk[k] * Sk[k]
        fact = beta0 * Nk[k] / (beta0 + Nk[k])
        diff = xbar[k] - m0
        for i in xrange(ndim):
            for j in xrange(ndim):
                term = diff[i] * diff[j]
                ret[k, i, j] += fact * term
        ret[k] = inv(ret[k])

    return ret

最後に、ここまでの関数をつなぎ合わせて変分 M ステップを組みます。

for iiter in xrange(num_iter):
    # VB-M Step (Dirichlet)
    Nk = r_nk.sum(0)
    alpha = alpha0 + Nk

    # VB-M Step (Normal-Wishart)
    xbar = calc_xbar(x, r_nk)
    Sk = calc_S(x, xbar, r_nk)
    beta = beta0 + Nk
    mk = calc_m(xbar, Nk, m0, beta0, beta)
    W = calc_W(xbar, Sk, Nk, m0, beta0, inv_W0)
    nu = nu0 + Nk

    # VB-E Step
    ...
    # Visualization
    ...

変分 E ステップ

変分 E ステップの実装では、様々な確率変数の期待値を計算する必要があるので、これを順番に実装していきます。プロットに必要な期待値の計算も行います。モクモクと作ります。(PRML を参照のこと。)

ディリクレ分布

  

def expect_pi(alpha):
    return alpha / alpha.sum()

  

def expect_lpi(alpha):
    return digamma(alpha) - digamma(alpha.sum())
ウィシャート分布

  

def expect_llambda(W, nu):
    ndim = W.shape[0]
    arr = float64(nu - arange(ndim)) / 2
    return digamma(arr).sum() + ndim * ln2 + log(det(W))

  

def expect_lambda(W, nu):
    return W * nu
正規分布

  
ここでは、二次形式 を計算する関数を実装し、これを活用します。

def quad(A, x):
    num, ndim = x.shape
    ret = zeros(num, dtype = float64)

    for i in xrange(ndim):
        for j in xrange(ndim):
            ret += A[i, j] * x[:, i] * x[:, j]

    return ret

def expect_quad(x, m, beta, W, nu):
    ndim = x.shape[1]
    return ndim / beta + nu * quad(W, x - m[None, :])

  

def expect_log(x, m, beta, W, nu):
    ndim = x.shape[1]

    ex_llambda = expect_llambda(W, nu)
    ex_quad = expect_quad(x, m, beta, W, nu)
    return (ex_llambda - ndim * ln2pi - ex_quad) / 2
負担率の更新

導出編に示した通り、対数負担率は
  
という形で更新できるので、これを実装します。

ex_lpi = expect_lpi(alpha)
ex_log = zeros([num, num_class], dtype = float64)
for k in xrange(num_class):
    ex_log[:, k] = expect_log(x, mk[k], beta[k], W[k], nu[k])
lrho = ex_lpi[None, :] + ex_log

さらに、データ点ごとの負担率の総和が 1 になるように正規化します。

def normalize_response(lrho):
    num, num_class = lrho.shape
    ret = zeros([num, num_class], dtype = float64)

    for i in xrange(num):
        lr = lrho[i] - logsumexp(lrho[i])
        ret[i] = exp(lr)

    ret = maximum(ret, 1e-10) # ゼロよけ
    ret /= ret.sum(1)[:, None]
    return ret

可視化

変分 E ステップ、変分 M ステップが組みあがったので、これをレンタリングします。
ただし、最尤推定や MAP 推定と違い、今手元には の値がありません。そこで、これらの値の期待値を計算してなんとなくそれっぽいグラフを作ることにします。

ex_pi = expect_pi(alpha)
ex_lambda = zeros([num_class, ndim, ndim], dtype = float64)
for k in xrange(num_class):
    ex_lambda[k] = expect_lambda(W[k], nu[k])
preview_stage(ax, x, ex_pi, mk, inverse_matrices(ex_lambda))

ソースコード

できあがり!
なんとなくこんな感じで動かすとホクホクします

$ python
Python 2.7.1 (r271:86832, Apr 12 2011, 16:16:18) 
[GCC 4.6.0 20110331 (Red Hat 4.6.0-2)] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import vbgmm
>>> vbgmm.estimate()
# -*- coding: utf-8 -*-
from gmm import faithful_norm, init_figure, preview_stage
from scipy import arange, array, exp, eye, float64, log, maximum, ones, outer, pi, rand, zeros
from scipy.linalg import det, inv
from scipy.maxentropy import logsumexp
from scipy.special import digamma

ln2 = log(2)
ln2pi = log(2 * pi)

# ========== ========== ========== ==========
#
#   Helpers
#
# ========== ========== ========== ==========
#
#   パラメータが null の場合,デフォルト値を返します.
#
def fill_param(param, default):
    if (param == None):
        return default
    else:
        return param

#
#   行列のリストから,対応する逆行列のリストを計算します.
#
def inverse_matrices(A):
    return array(map(inv, A))

# ========== ========== ========== ==========
#
#   Dirichlet Distribution
#
# ========== ========== ========== ==========
#
#   E[\pi_k] を計算します.
#
def expect_pi(alpha):
    return alpha / alpha.sum()

#
#   E[\ln \pi_k] を計算します.
#
def expect_lpi(alpha):
    return digamma(alpha) - digamma(alpha.sum())

# ========== ========== ========== ==========
#
#   Wishart Distribution
#
# ========== ========== ========== ==========
#
#   E[\ln \Lambda] を計算します.
#
def expect_llambda(W, nu):
    ndim = W.shape[0]
    arr = float64(nu - arange(ndim)) / 2
    return digamma(arr).sum() + ndim * ln2 + log(det(W))

#
#   E[\Lambda] を計算します.
#
def expect_lambda(W, nu):
    return W * nu

# ========== ========== ========== ==========
#
#   Normal Distribution
#
# ========== ========== ========== ==========
#
#   行列 A とベクトルの配列 {x1, x2, …, xN} から
#   二次形式の配列 {A[x1], A[x2], …, A[xN]} を計算します.
#
def quad(A, x):
    num, ndim = x.shape
    ret = zeros(num, dtype = float64)

    for i in xrange(ndim):
        for j in xrange(ndim):
            ret += A[i, j] * x[:, i] * x[:, j]

    return ret

#
#   E[\lambda (x - \mu) ** 2] を計算します.
#
def expect_quad(x, m, beta, W, nu):
    ndim = x.shape[1]
    return ndim / beta + nu * quad(W, x - m[None, :])

#
#   E[\ln N(x|\pi, \mu, \Lambda)] を計算します.
#
def expect_log(x, m, beta, W, nu):
    ndim = x.shape[1]

    ex_llambda = expect_llambda(W, nu)
    ex_quad = expect_quad(x, m, beta, W, nu)
    return (ex_llambda - ndim * ln2pi - ex_quad) / 2

# ========== ========== ========== ==========
#
#   VB-E Step Helper
#
# ========== ========== ========== ==========
#
#   対数負担率を正規化します.
#
def normalize_response(lrho):
    num, num_class = lrho.shape
    ret = zeros([num, num_class], dtype = float64)

    for i in xrange(num):
        lr = lrho[i] - logsumexp(lrho[i])
        ret[i] = exp(lr)

    ret = maximum(ret, 1e-10)
    ret /= ret.sum(1)[:, None]
    return ret

# ========== ========== ========== ==========
#
#   VB-M Step Helper
#
# ========== ========== ========== ==========
#
#   負担率を用いてクラス平均を計算します.
#
def calc_xbar(x, r_nk):
    num, ndim = x.shape
    num, num_class = r_nk.shape
    ret = zeros([num_class, ndim], dtype = float64)

    for k in xrange(num_class):
        clres = r_nk[:, k]
        for i in xrange(ndim):
            ret[k, i] = (clres * x[:, i]).sum()
        ret[k, :] /= clres.sum()

    return ret

#
#   負担率を用いてクラス分散を計算します.
#
def calc_S(x, xbar, r_nk):
    num, ndim = x.shape
    num, num_class = r_nk.shape
    ret = zeros([num_class, ndim, ndim], dtype = float64)

    for k in xrange(num_class):
        clres = r_nk[:, k]
        for i in xrange(ndim):
            diff_i = x[:, i] - xbar[k, i]
            for j in xrange(ndim):
                diff_j = x[:, j] - xbar[k, j]
                ret[k, i, j] = (clres * diff_i * diff_j).sum()
        ret[k] /= clres.sum()

    return ret


#
#   平均のハイパーパラメータを計算します.
#
def calc_m(xbar, Nk, m0, beta0, beta):
    num_class, ndim = xbar.shape
    ret = zeros([num_class, ndim], dtype = float64)

    for k in xrange(num_class):
        ret[k] = (beta0 * m0 + Nk[k] * xbar[k]) / beta[k]

    return ret

#
#   精度のハイパーパラメータを計算します.
#
def calc_W(xbar, Sk, Nk, m0, beta0, inv_W0):
    num_class, ndim = xbar.shape
    ret = zeros([num_class, ndim, ndim], dtype = float64)

    for k in xrange(num_class):
        ret[k] = inv_W0 + Nk[k] * Sk[k]
        fact = beta0 * Nk[k] / (beta0 + Nk[k])
        diff = xbar[k] - m0
        for i in xrange(ndim):
            for j in xrange(ndim):
                term = diff[i] * diff[j]
                ret[k, i, j] += fact * term
        ret[k] = inv(ret[k])

    return ret

# ========== ========== ========== ==========
#
#   Program
#
# ========== ========== ========== ==========
def estimate(num_class = 6, num_iter = 100, alpha0 = None, beta0 = None,
             m0 = None, W0 = None, nu0 = None):
    # Data
    x = faithful_norm()
    num, ndim = x.shape

    # Graph
    ax = init_figure()

    # Hyperparameters
    alpha0 = fill_param(alpha0, ones(num_class) * 1e-3)
    beta0 = fill_param(beta0, 1e-3)
    m0 = fill_param(m0, zeros(ndim))
    W0 = fill_param(W0, eye(2))
    nu0 = fill_param(nu0, 1)
    inv_W0 = inv(W0)

    # Initial VB-E Step
    r_nk = rand(num, num_class)
    r_nk /= r_nk.sum(1)[:, None]

    # Execution
    for iiter in xrange(num_iter):
        # VB-M Step (Dirichlet)
        Nk = r_nk.sum(0)
        alpha = alpha0 + Nk

        # VB-M Step (Normal-Wishart)
        xbar = calc_xbar(x, r_nk)
        Sk = calc_S(x, xbar, r_nk)
        beta = beta0 + Nk
        mk = calc_m(xbar, Nk, m0, beta0, beta)
        W = calc_W(xbar, Sk, Nk, m0, beta0, inv_W0)
        nu = nu0 + Nk

        # VB-E Step
        ex_lpi = expect_lpi(alpha)
        ex_log = zeros([num, num_class], dtype = float64)
        for k in xrange(num_class):
            ex_log[:, k] = expect_log(x, mk[k], beta[k], W[k], nu[k])
        lrho = ex_lpi[None, :] + ex_log
        r_nk = normalize_response(lrho)

        # Visualization
        ex_pi = expect_pi(alpha)
        ex_lambda = zeros([num_class, ndim, ndim], dtype = float64)
        for k in xrange(num_class):
            ex_lambda[k] = expect_lambda(W[k], nu[k])
        preview_stage(ax, x, ex_pi, mk, inverse_matrices(ex_lambda))

if (__name__ == '__main__'):
    estimate()