糞糞糞ネット弁慶

読んだ論文についてメモを書きます.趣味の話は http://repose.hatenablog.com

2つの GMM(Gaussian Mixture Model) の類似度を KL Divergence で測る

結論から言うと,結構面倒なのでサンプリングで近似すれば良い.死ぬほど精度が必要とかで無い限り, 後述する Variational Approximation を使えば良さそう.

目的

GMMは正規分布の重み付き和で表現される確率分布.二つのGMMの類似度を測る必要が生じたので調べていたら案外と面倒だったのでメモしておく.

KL Divergence

確率分布の類似度と言えばKLダイバージェンス.最近では正規分布間のKLダイバージェンスの導出 - 唯物是真 @Scaled_Wurm多変量(多次元)正規分布のKLダイバージェンスの求め方 - EchizenBlog-Zweiでも触れられている.
誰か実装できる形にまで書き下しているかと思ったら閉じた形では書けないとのこと.そしてこれをどうにか近似するというのは最近でも取り組まれている研究テーマであるらしい.
Approximating the Kullback Leibler Divergence Between Gaussian Mixture Models - IEEE Conference Publication
Lower and upper bounds for approximation of the Kullback-Leibler divergence between Gaussian Mixture Models - Infoscience
Artificial Intelligence Blog · Approximation of KL distance between mixtures of Gaussians
どうにか近似しようという話があったので一番簡単な Monte Carlo samplingを試す. (GMMを一つのgaussianで近似しようとかいう手法も紹介されてはいるが三件目のブログで this is bad, don’t do it, if you do do it,“ I told you so” とか言われていて面白い)

手法

手法は非常にシンプル.二つの正規分布の和で書かれた確率分布について

としてやる.あとはnを十分大きくすればになるという.
実装の方針としては

  • 混合比に従って正規分布を一つ選ぶ
  • 選ばれた番目の正規分布から乱数を一つサンプリング
  • それを使ってを計算
  • あとはこれをN回繰り返して平均を取る

実装する上で注意しなければならないのは,いくらKL Divergence がその定義上非負だからといって,乱数の種の初期化のタイミングを間違えた上で使う乱数が偏っているとf/gを計算する時に負の値が本当に大量に出てくるので念のため非負の判定を加える必要がある点だと思う.

実験

Rubyでの実装.実際GMMクラスでは分布の推定までをやっているがそれは長くなるので今回はパラメータだけを持つようにした.EMでGMMの推定までやるコードはGistに貼る
正規分布からのサンプリングに rubystats を使っているので gem install rubystats が必要. piのサンプリングの部分,ライブラリがあったはずと思ってGitHub - todesking/weighted_sample: Enumerable#weighted_sample_byを見たが整数の重みに対応していなかった.

# -*- coding: utf-8 -*-
require 'rubystats'
PROB_LIM = 10 ** -100

class GMM
  attr_reader :pi, :mu, :sigma
  def initialize(opt = {  })
    @k = opt[:k] || 2

    # initialize paramters
    @mu = opt[:mu]
    @sigma = opt[:sigma]
    @pi = opt[:pi]
  end
  
  def dist(x, mu, sigma)
    prob = 1.0 / (2 * Math::PI * sigma) ** 0.5 * Math::exp(-(x - mu) ** 2 / (2 * sigma))
    prob < PROB_LIM ? PROB_LIM : prob
  end

  def prob(new_x)
    sum = 0.0
    @k.times do |k|
      sum += @pi[k] * dist(new_x, @mu[k], @sigma[k])
    end
    sum
  end
end

class KLDivergence
  def self.kl_divergence(gmm_1, gmm_2, iter = 10000)
    random_xs = [ ]
    # ここでまずは pi をサンプリングする前準備としてソート
    pi = gmm_1.pi.sort_by{|e| e.first}

    iter.times do |i|
      # select a_i
      prob_tables = [ ]
      pi.each do |elem|
        # elem #=> [pi_i, val]
        val = elem.last
        prob_tables.push val
        prob_tables[-1] += prob_tables[-2] if prob_tables.size > 1
      end

      # [0, 1] に正規化
      prob_tables.each_index do |i|
        prob_tables[i] /= prob_tables[-1]
      end

      # sampling
      select_a = 0
      v = rand
      if v > prob_tables[0]
          1.upto (prob_tables.size - 1) do |i|
          if prob_tables[i - 1] <= v && prob_tables[i] > v
            select_a = i
            break
          end
        end
      end

      # sampling N_{select_a}
      mean = gmm_1.mu[select_a]
      sigma = gmm_1.sigma[select_a]
      random_xs.push Rubystats::NormalDistribution.new(mean, sigma ** 0.5).rng
    end

    # average
    ret = 0.0
    random_xs.each do |x|
      ret += Math::log(gmm_1.prob(x) / gmm_2.prob(x))
    end

    ret > 0 ? ret / iter : 0.0
  end

  def self.sim(gmm_1, gmm_2, iter = 100000)
    1.0 / (1 + kl_divergence(gmm_1, gmm_2, iter))
  end
end

if __FILE__ == $0
  mu = {0 => 100, 1 => 10000}
  sigma = {0 => 3, 1 => 500}

  gmm_1 = GMM.new(mu: mu, sigma: sigma, pi: {0 => 0.7, 1 => 0.3})
  gmm_2 = GMM.new(mu: mu, sigma: sigma, pi: {0 => 0.3, 1 => 0.7})

  TOTAL = 100
  srand(0)
  puts "|*試行回数|*最小値|*最大値|*平均|*分散|"
  [10, 100, 1000, 10000].each do |iter|
    ary = [ ]
    TOTAL.times do
      ary.push KLDivergence.kl_divergence(gmm_1, gmm_2, iter)
    end
    mean = ary.inject(:+) / ary.size
    var = ary.inject(0.0){|s, v| s += (v - mean) ** 2} / ary.size
    puts  "|%d|%04f|%04f|%04f|%04f|" % [iter, ary.min, ary.max, mean, var]
  end
end

上記論文ではサンプリングの回数でかなり値がばらつくと言われているので実験を行った.
,二つの分布のKLダイバージェンスをサンプリング回数を変えながら求める.
サンプリング回数を10,100,1000,10000と変化させ,それぞれ100回ずつKL Divergence を計算した際の最小値,最大値,平均,分散を表にした.図は気力が尽きた.

試行回数 最小値 最大値 平均 分散
10 0.000000 0.847298 0.364338 0.047311
100 0.152514 0.559217 0.337733 0.006371
1000 0.282997 0.394841 0.339326 0.000578
10000 0.317228 0.356035 0.337225 0.000056
100000 0.333242 0.345376 0.338746 0.000005

結果がこの表.当然,サンプリング回数を増やすと安定はしてくるが,ではどの程度サンプリングすれば良いのか,それは求める精度はどの程度のものなのか,によってくるので非常に微妙に思える.大体10000回も回せば十分なのではないかとも思うけれども(最小値最大値ともに異常な値が出ていないように思える).

KL以外の尺度

Closed-Form Expression for Divergence of Gaussian Mixture Models | kittipatkampa
Cauchy-Schwarz Divergence なら閉じた形で書けるから問題無いという論文を発見.別段KLにこだわりは無いので最初はこちらを試そうとしたが,式を見て挫折した.

09/03 追記

Monte Carlo samplingで求めていたがそこそこ重く,ボトルネックになってしまった.
そこで,Approximating the Kullback Leibler Divergence Between Gaussian Mixture Models - IEEE Conference Publication でも触れられている Variational Approximation を試してみてその誤差が許容できるかどうかを確認する.

Variational Approximation

Variational Approximation は,二つの正規分布の和で書かれた確率分布について

として近似するもの.計算には通常の(閉じた形で求められる) gaussian の KL Divergence を使う.サンプリングが要らないので計算時間は早い.

実装

class KLDivergence
  # http://sucrose.hatenablog.com/entry/2013/07/20/190146
  # kld between two gaussians
  def self.kld(mu_1, sigma_1, mu_2, sigma_2)
    (Math::log(sigma_2 / sigma_1) + sigma_2 / sigma_1 +  (mu_1 - mu_2) ** 2 / sigma_2  - 1) * 0.5
  end
  
  def self.kl_divergence_by_variational_approximation(gmm_1, gmm_2)
    ret = 0.0
    size_1 = gmm_1.mu.size
    size_2 = gmm_2.mu.size
    0.upto (size_1 - 1) do |a|
      pi_a = gmm_1.pi[a]
      mu_a = gmm_1.mu[a]
      sigma_a = gmm_1.sigma[a]

      # numer
      numer = 0.0
      0.upto (size_1 - 1) do |i|
        mu = gmm_1.mu[i]
        sigma = gmm_1.sigma[i]
        pi = gmm_1.pi[i]
        numer += pi * Math::exp(-kld(mu_a, sigma_a, mu, sigma))
      end

      # denom
      denom = 0.0
      0.upto (size_2 - 1) do |i|
        mu = gmm_2.mu[i]
        sigma = gmm_2.sigma[i]
        pi = gmm_2.pi[i]
        denom += pi * Math::exp(-kld(mu_a, sigma_a, mu, sigma))
      end
      ret += pi_a * Math::log(numer / denom)
    end
    ret
  end
end

前回の実装を使いまわす形で追加で定義.結果は0.338919となった.
前回までの結果と比較するため,上の表に継ぎ足す形で貼る.

試行回数 最小値 最大値 平均 分散
10 0.000000 0.847298 0.364338 0.047311
100 0.152514 0.559217 0.337733 0.006371
1000 0.282997 0.394841 0.339326 0.000578
10000 0.317228 0.356035 0.337225 0.000056
100000 0.333242 0.345376 0.338746 0.000005
VA - - 値: 0.338919 10k回との差: 0.000173

これぐらいの精度なら実装を差し替えても十分かもしれない.