Synthesizing Tabular Data using Generative Adversarial Networks (preprint) 読んだ
[1811.11264] Synthesizing Tabular Data using Generative Adversarial Networks]
GAN を使って表形式のデータを生成する論文は既に読んだわけですが,その発展形. 著者らによる実装も公開されており(DAI-Lab/TGAN: Generative adversarial training for synthesizing tabular data),実装を試した人もいる(テーブルデータ向けのGAN(TGAN)で、titanicのデータを増やす - u++の備忘録).
前述した tableGAN との違いは CNN を用いずに LSTM を用いていること,交差エントロピーを用いるのではなく KL divergence を使って周辺分布を学習していることの二点.
データ変換
データが 個の連続値の変数 と 個の離散値の変数 で構成されているとし,各行の各列についてそれぞれが連続値の変数なのか,離散値の変数なのかを区別して話を進める.
連続値の変数について
多くの場合連続値の変数は多峰 (multimodal) である.なのでそのまま表現せず,次のような手続きを踏む.
- に変換する
- それぞれの変数について混合数 のGMM (Gaussian Mixture Model) を学習し,平均 および標準偏差 を得る
- 番目の変数の 列目の値 が GMM の各要素から得られる確率 を得る
- を とする.この時 である.その後 を ] に clip する
一言で言えば連続値の変数を 個の正規分布でクラスタリングし,一番当てはまりが良い分布に関する情報を持つ.論文では とし,もし単峰の変数だったとしても 個の正規分布に対する重みがゼロになるから構わないとしている.この手続きの結果,連続値の変数 を と の 次元で表現する.
離散値の変数について
- 列目の離散値の全要素を として離散値 を one-hot encoding して とする
- の各次元 に なノイズを加える
- を確率に正規化する
これら二種類の処理により, 次元のデータは 次元に変換される.
また,これから説明する GAN は上記の を生成するわけですが,本来の値に戻すには次のように変換すればいい.
- 連続値
- 離散値
生成
Generator には LSTM を使う.LSTM を使う理由は we use LSTM with attention in order to generate data column by column.
としか書かれていないが,気持ちを汲み取ると各変数間の相関などを陽に考慮したいからだと思う.
LSTM の出力を として hidden vector を求め,更に として各変数を出力する.その後, ステップの LSTM に を渡す.連続値の場合は を得,次に を得る.また,離散値の場合は ステップにはそのまま渡さずに ] として渡す( は 次元の embedding).
Discriminator には mini-batch discrimination vector 入りの MLP を用いる(Generator が LSTM なのだから Discriminator も LSTM で良かったのではないか).
通常の GAN の損失関数に加え, Generator 側の損失関数について連続値変数 に関する KL divergence と離散値の変数そのものの KL divergence を追加することで学習が安定するらしい.
実験
評価は三種類.
- 学習を生成したデータ,予測対象を元データとした時にどの程度精度を保つことができるかの Machine learning efficacy
- 前回の論文で model compatibility と呼んでいたもの
- 「変数間の相関が保存されているか」の検証として,連続変数を離散化して変数間の normalized mutual information を計算し描画
- 「真のデータにどれほど近いか」の検証として,学習データとテストデータまたは生成データ全対の距離のヒストグラムを描画