廿TT

譬如水怙牛過窓櫺 頭角四蹄都過了 因甚麼尾巴過不得

グループドデータの非負値行列因子分解

今日の川柳

ていねいなしらない人の顔写真

モチベーション

たとえばこういう表がある。

userGender userAgeBracket (direct) anlyznews.com b.hatena.ne.jp bing facebook.com google langstat.hatenablog.com m.facebook.com matome.naver.jp search.fenrir-inc.com search.smt.docomo sp-web.search.auone.jp t.co yahoo
female 18-24 241 0 0 370 0 2546 0 0 13 0 13 0 34 579
female 25-34 370 0 0 236 0 2896 0 0 17 0 24 21 47 687
female 35-44 352 0 0 54 0 1190 0 0 17 0 31 25 36 577
female 45-54 122 0 0 24 0 347 0 0 0 0 0 0 13 190
female 55-64 52 0 0 0 0 135 0 0 0 0 0 0 0 91
female 65+ 30 0 0 0 0 62 0 0 0 0 0 0 0 62
male 18-24 859 0 15 1084 0 9714 0 0 0 0 14 0 328 1265
male 25-34 1107 11 63 583 15 11186 17 28 15 0 20 31 800 1476
male 35-44 791 0 39 226 0 5957 0 25 10 12 24 14 511 1355
male 45-54 255 0 0 108 0 1736 0 12 0 0 0 12 157 643
male 55-64 104 0 0 71 0 683 0 0 0 0 0 0 13 327
male 65+ 57 0 0 57 0 234 0 0 0 0 0 0 0 248
view raw gadata.csv hosted with ❤ by GitHub
gist.github.com

表の左のほうにユーザー層の情報、右の方にユーザー層ごとのブログへのアクセス経路が書かれている。

どのユーザー層がどの経路を好むか知りたいとする。

そこでトピックモデルとしてポアソン分布を使った非負値行列因子分解を考える。
トピックモデルシリーズ 6 GaP (Gamma-Poisson Model) - StatModeling Memorandum などを参照。)

アクセス経路が単語に対応する。

ユーザー層の情報を捨てて、行列を分解してしまうのはおもしろくない。

ユーザー層の情報を説明変数として、ユーザー層ごとにトピックの構成が変わるようなモデルにしたい。

トピックごとの単語の出現しやすさは説明変数によらず一定とする。

モデル

観測値を行列の積 XVH で近似することを目指します。

Y: 観測された分解したい行列(N行K列)
X: 観測された説明変数(N行J列)
V: パラメータ(J行L列)
H: パラメータ(L行K列)

X の中身は0、1のダミー変数でユーザー層を表したものです。

また、W=XVH とします。

L はトピック数です。分析者が決めます。

行列W、Y、X、V、Hの要素は小文字のw、y、x、v、xに下付きの添字を付けて表します。

観測モデルは以下です。

yn,kPoisson(wn,k)

事前分布として以下を仮定します。
vn,lGamma(αv,βv)
hl,jGamma(αh,βh)

事前分布を入れないと0の多い行列は途中計算でNaNになっちゃうことが多いです。

以下の潜在変数を考えます。
yn,k=ljsn,j,l,k
sn,j,l,kPoisson(xn,jvj,lhl,k)

ポアソン分布の和はポアソン分布に従うという性質を使いました。

変分ベイズ

ベイズ推論による機械学習入門』に出てくる非負値行列因子分解とほぼ同じ計算なので詳しくはそちらを参照してください。

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)

本を買うのがいやな人は論文をみてください。

Bayesian Inference for Nonnegative Matrix Factorisation Models

変分ベイズの更新式を導出するのはじめてなので、間違ってたらすみません。

事後分布が以下のように分解できるとして近似します。

q(S,V,H)=q(S)q(V)q(H)

a の期待値を取る操作を E(a) で表します。

変分ベイズの更新式は以下のようになります。

q(vj,l)=Gamma(α^v(j,l),β^v(j,l))

α^v(j,l)=nkE(sn,l,k)+αv
β^v(j,l)=nkxn,jE(hl,k)+βv

q(hl,k)=Gamma(α^v(l,k),β^v(l,k))

α^h(l,k)=njE(sn,l,k)+αh
β^h(l,k)=njxn,jE(vj,l)+βh

q(sn,:,:,k)=Multinomial(yn,k,p^n,j,l,k)

p^n,j,l,kexpE(log(xn,jvj,lhl,k))

確率密度関数も確率分布とおなじ記号 Gamma で書きましたが、文脈で区別できると思います。

近似分布の更新に必要なガンマ分布の対数の期待値はディガンマ関数 ψ を用いて以下のように表わせます。

XGamma(a,b)
ならば
E(logX)=ψ(a)logb

R による実装

s を三次元配列として保存しないで、V や H の更新式に代入してしまうのがこつです。

NMFVB <-function(Y,X,L=2,alpha=1,beta=1,tol=1,maxit=5000,seed=1234){
  set.seed(seed)
  N <- nrow(Y)
  K <- ncol(Y)
  J <- ncol(X)
  EV <- EelV <- matrix(rgamma(J*L,shape=alpha,scale=beta),J,L)
  EH <- EelH <- matrix(rgamma(L*K,shape=alpha,scale=beta),L,K)
  for(iter in 2:maxit){
    den <- ((X %*% EelV)) %*% EelH
    Sh <- EelH * (t((X %*% EelV)) %*% (Y/den))
    Sv <- EelV * t(X) %*% (((Y/den)%*%t(EelH)))
    beta_H <- colSums(X%*%EV) + beta
    alpha_H <- alpha + Sh
    EH <- alpha_H/beta_H
    beta_V <- outer(colSums(X),rowSums(EH)) + beta
    alpha_V <- alpha + Sv
    EV <- alpha_V/beta_V
    EelH <-exp(digamma(alpha_H))/beta_H
    EelV <-exp(digamma(alpha_V))/beta_V
  }
  return(list(V=EV,H=EH,
              alpha_V=alpha_V,beta_V=beta_V,
              alpha_H=alpha_H,beta_H=beta_H))
}

結果

勘によりトピック数は L = 3 としました。

勘できめるのがいやな人は論文をみてください。

Bayesian Inference for Nonnegative Matrix Factorisation Models

ELBOを比べることでモデルを選択する方法が記載されています。

もとめた W=XVH(左)ともとの行列 Y(右)を並べてみます。

f:id:abrahamcow:20180705021358p:plain

まあまあ雰囲気を再現できているんじゃないでしょうか。

行列 V の値をみてみます。V1、V2、V3はそれぞれ V の1列目、2列目、3列目で潜在的なトピックを表します。

f:id:abrahamcow:20180705021620p:plain

18-24歳はV1が支配的です。歳を取るにつれてV2が多くなってきます。また女性は男性よりV2が多く他が少ないです。
V3は18-24歳にはほとんどなく、25歳以上になると増えますが歳を取るにつれて減っていきます。

V1、V2、V3に解釈を与えるために H の中身を見ます。H はトピックごとの単語の出現しやすさです。

f:id:abrahamcow:20180705022544p:plain

一応ガンマ分布の95%区間をエラーバーで重ねていますが、幅が狭くてほとんど見えません。

V1はGoogleが支配的ですがBingも多めです。調べものをしていてこのブログにたどり着く層でしょうか。若年層は具体的に知りたいことがあってこのブログに来る人が多いようです。
V2はGoogleが少なくYahooが多いのが特徴です。歳を取るにつれてYahooユーザーが多くなるみたいです。また、女性のほうがYahooユーザーが多そうです。
V3のV1、V2との違いはt.co(ツイッター)成分がそこそこある点です。SNS経由で来る人は18-24歳にはほとんどなく、25歳以上になると増え歳を取るにつれて減っていくみたいです。

最後にRのコードをまとめて貼ります。

library(googleAnalyticsR)
library(tidyverse)
NMFVB <-function(Y,X,L=2,alpha=1,beta=1,tol=1,maxit=5000,seed=1234){
  set.seed(seed)
  N <- nrow(Y)
  K <- ncol(Y)
  J <- ncol(X)
  EV <- EelV <- matrix(rgamma(J*L,shape=alpha,scale=beta),J,L)
  EH <- EelH <- matrix(rgamma(L*K,shape=alpha,scale=beta),L,K)
  for(iter in 2:maxit){
    den <- ((X %*% EelV)) %*% EelH
    Sh <- EelH * (t((X %*% EelV)) %*% (Y/den))
    Sv <- EelV * t(X) %*% (((Y/den)%*%t(EelH)))
    beta_H <- colSums(X%*%EV) + beta
    alpha_H <- alpha + Sh
    EH <- alpha_H/beta_H
    beta_V <- outer(colSums(X),rowSums(EH)) + beta
    alpha_V <- alpha + Sv
    EV <- alpha_V/beta_V
    EelH <-exp(digamma(alpha_H))/beta_H
    EelV <-exp(digamma(alpha_V))/beta_V
  }
  return(list(V=EV,H=EH,
              alpha_V=alpha_V,beta_V=beta_V,
              alpha_H=alpha_H,beta_H=beta_H))
}

ga_auth()
account_list <- ga_account_list()
ga_id <- account_list$viewId[3]

gadata <-
  google_analytics(ga_id,
                   date_range = c("2018-01-01","2018-06-30"),
                   metrics = c("sessions"),
                   dimensions = c("source","userGender","userAgeBracket"))

gadata_w <-spread(gadata,source,sessions,fill=0)

gamat <-as.matrix(gadata_w[,-c(1:2)])

X1 <-model.matrix(~userGender-1,data=gadata_w)
X2 <-model.matrix(~userAgeBracket-1,data=gadata_w)
X <- gaX <-cbind(X1,X2)

out <-NMFVB(Y=gamat,X=gaX,L=3)

obsdf <-as.data.frame(gamat) %>% 
  set_names(1:ncol(gamat)) %>% 
  mutate(row=row_number()) %>% 
  gather(col,sessions,-row) %>% 
  mutate(col=as.integer(col),type="obs")

fitdf <- as.data.frame(X %*% out$V %*% out$H) %>% 
  set_names(1:ncol(gamat)) %>% 
  mutate(row=row_number()) %>% 
  gather(col,sessions,-row) %>% 
  mutate(col=as.integer(col),type="fit")

outdf <-bind_rows(obsdf,fitdf)

ggplot(outdf,aes(x=col,y=row,fill=sessions))+
  geom_tile()+
  facet_wrap(~type)

dfV <-as.data.frame(out$V) %>%
  rownames_to_column() %>%
  gather(key,value,-rowname)

ggplot(dfV,aes(x=rowname,y=value,fill=key))+
  geom_col(colour="black",position = "fill")+
  coord_flip()


CIHlower <- as.data.frame(qgamma(0.025,shape=out$alpha_H,rate=out$beta_H)) %>% 
  mutate(l=row_number()) %>%
  gather(source,lower,-l)


CIHupper <- as.data.frame(qgamma(0.975,shape=out$alpha_H,rate=out$beta_H)) %>% 
  mutate(l=row_number()) %>%
  gather(source,upper,-l)

Hdf <- as.data.frame(out$H) %>% 
  mutate(l=row_number()) %>% 
  gather(source,value,-l) %>% 
  left_join(CIHlower) %>% 
  left_join(CIHupper) %>% 
  group_by(l)

ggplot(Hdf,aes(x=source,y=value,ymin=lower,ymax=upper))+
  geom_col(fill="white",colour="black")+
  geom_errorbar(width=0.5)+
  facet_wrap(~l,scales="free_x")+
  coord_flip()