混合ガウス分布

Last update: 2018/8/22

混合ガウス分布モデルとその学習

データの分布が複数の山(モード)を持つ場合には、この分布のモデルとして、混合ガウス分布モデル(Gaussian mixture model) がしばしば用いられます。混合ガウス分布は、図1で示すように、複数のガウス分布の重ね合わせで表され、その確率密度関数は次式で与えられます。

図1:GMMの確率密度関数

ここで、Kは、重ね合わせるガウス分布の数、(μk.,Σk)はk番目のガウス分布の平均値ベクトルと共分散行列です。.πk は、混合係数(mixing coefficients)と呼ばれ、重ね合わせるガウス分布の重みを表します。混合係数は次式を満たします。

混合係数πk は、特定のデータが観測される前(事前)に、データがk番目のガウス分布からのサンプルである確率(事前確率)を表します。これを式で書くと、次式のようになります。

一方、特定のデータ x が観測されたあと(事後)に、そのデータがk番目のガウス分布からのサンプルである確率(事前確率)は、負担率(responsibility)と呼ばれ、次式で表されます。

この負担率を用いて、データ{xn|n=1,...,N}のクラスタリングを行うことができます。モデルのパラメタ

は、次式の対数尤度を最大化するように決定されます。

ここで、X={x1,..., xN}は、観測データのセットを表します。実際のパラメタの学習には、尤度を逐次的に最大化するEMアルゴリズムが用いられます。

サンプルプログラム1

ここでは、表1の3つの正規分布を混合した分布 p(x) からサンプルを取り、GMMによるクラスタリングを行います。図2は、データ x のヒストグラムです。 p(x) は、図1に表示されている分布であり、ヒストグラムも、図1と同様の分布になっています。

学習には、sklearn.mixture.gmmモジュールを用います。predict()メソッドは、事後確率(負担率)が最大となる k を返します。結果の見易さのため、サンプルのうち先頭の30個だけをクラスタリングしたのが、図3です。クラスラベル(推定したkの値)は、マーカの色と形で表されています。この図には、学習の結果得られた3つの正規分布の平均値 g.means_ も点線で示してあります。表1の結果と概ね一致しています。

表1:混合した正規分布

import numpy as np

import scipy.stats as st

import matplotlib.pyplot as plt

from sklearn import mixture

# パラメータ

K = 3 # 混合数

mu = [-3.0, 0.0, 4.0] # 平均値

std = [1.0, 1.5, 0.8] # 標準偏差

pi = [0.3, 0.4, 0.3] # 混合係数

N = 5000 # サンプル数

# 確率密度関数

x = np.linspace(-8.0, 8.0, num=200)

lineColor = ['b','g','r','k']

legend = ['p1','p2','p3','p']

pMix = np.zeros(len(x))

fig = plt.figure()

for k in range(K):

p = st.norm.pdf(x, loc=mu[k], scale=std[k])

plt.plot(x,pi[k]*p,lineColor[k],label=legend[k])

pMix += pi[k]*p

plt.plot(x,pMix,lineColor[3], label=legend[3])

plt.xlabel('x')

plt.legend()

# データの作成

x = np.array([])

for k in range(K):

Nk = int(N*pi[k])

xk = st.norm.rvs(loc=mu[k], scale=std[k], size=Nk) # 各ガウス分布からサンプリング

x = np.append(x,xk)

buf = np.random.choice(x,size=N, replace=False) # データをランダムに並び変え

x = buf

# ヒストグラム

bin = np.arange(-8.0,8.0,0.1)

fig = plt.figure()

plt.hist(x,bins=bin,width=0.1)

plt.xlabel('x')

# GMMの学習

# データXは2次元配列として与える。列数:データの次元、行数:データ数N

g = mixture.GaussianMixture(n_components=K)

X = np.array(x,ndmin=2).T

g.fit(X)

print('Mean:',g.means_)

print('Cov:',g.covariances_)

print('Weight:',g.weights_)

#クラスタリング

Np = 30 # 例として先頭の30個だけクラスタリング

y = x[0:Np]

Y = np.array(y,ndmin=2).T

classLabel = g.predict(Y)

# 結果の表示

lineColor = ['bo','gd','rs']

fig = plt.figure()

for k in range(K):

gy = np.array([])

gx = np.array([])

for n in range(Np):

if classLabel[n] == k:

gy = np.append(gy,y[n])

gx = np.append(gx,float(n))

plt.plot(gx,gy,lineColor[k])

plt.xlabel('Data number')

plt.ylabel('x')

(xmin,ymin) = plt.xlim()

for k in range(K):

plt.hlines(g.means_[k],xmin,ymin,linestyle='dashed',colors=lineColor[k][0])

図2:観測データのヒストグラム

図3:クラスタリングの結果

サンプルプログラム2

このプログラムでは、国語と数学の10人分の成績(2次元データ)を読み込んで、GMMを用いて2つのクラスに分類します。図4に示す散布図の点の色がクラスタリングの結果、等高線が個々のガウス分布の確率密度関数を表します。

GMM2D.py

import numpy as np

import matplotlib.pyplot as plt

import scipy.stats as st

from sklearn import mixture

# データの読み込み

fileName = '../Data/etc/examResult.csv'

X = np.loadtxt(fileName, delimiter=',',skiprows=1,usecols=(1,3))

Subject = ['Kokugo','Sugaku']

# GMMの学習

CovarianceType='full' # 共分散行列のタイプ

N = 2 # クラス数

gmm = mixture.GaussianMixture(n_components=N,covariance_type=CovarianceType)

gmm.fit(X)

classLabel = gmm.predict(X)

print('[GMM] Mean:',gmm.means_)

print('[GMM] Cov:',gmm.covariances_)

print('[GMM] Weight:',gmm.weights_)

print('[GMM] Class Label:', classLabel)

# 結果の表示

Colors = ('b','r','g')

gxmin,gxmax = (30,100)

gymin,gymax = (30,100)

fig = plt.figure(figsize=(6,6))

for k,(mean,cov,color) in enumerate(zip(

gmm.means_,gmm.covariances_,Colors)):

# 散布図

plt.scatter(X[classLabel == k,0], X[classLabel == k,1], color=color)

# 等高線

x, y = np.mgrid[gxmin:gxmax:1, gymin:gymax:1]

pos = np.empty(x.shape + (2,))

pos[:, :, 0] = x;

pos[:, :, 1] = y

rv = st.multivariate_normal(mean,cov)

P = rv.pdf(pos)

plt.contour(x, y, P, 5)

plt.xlim(gxmin,gxmax)

plt.ylim(gymin,gymax)

plt.xlabel(Subject[0])

plt.ylabel(Subject[1])

plt.show()

図4:2次元データのクラスタリング

参考文献

    1. C. M. Bishop, "Pattern recognition and machine learning," Springer, 2006

    2. http://scikit-learn.org/stable/modules/mixture.html#mixture