モチベーション
たとえばこういう表がある。
userGender | userAgeBracket | (direct) | anlyznews.com | b.hatena.ne.jp | bing | facebook.com | langstat.hatenablog.com | m.facebook.com | matome.naver.jp | search.fenrir-inc.com | search.smt.docomo | sp-web.search.auone.jp | t.co | yahoo | ||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
female | 18-24 | 241 | 0 | 0 | 370 | 0 | 2546 | 0 | 0 | 13 | 0 | 13 | 0 | 34 | 579 | |
female | 25-34 | 370 | 0 | 0 | 236 | 0 | 2896 | 0 | 0 | 17 | 0 | 24 | 21 | 47 | 687 | |
female | 35-44 | 352 | 0 | 0 | 54 | 0 | 1190 | 0 | 0 | 17 | 0 | 31 | 25 | 36 | 577 | |
female | 45-54 | 122 | 0 | 0 | 24 | 0 | 347 | 0 | 0 | 0 | 0 | 0 | 0 | 13 | 190 | |
female | 55-64 | 52 | 0 | 0 | 0 | 0 | 135 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 91 | |
female | 65+ | 30 | 0 | 0 | 0 | 0 | 62 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 62 | |
male | 18-24 | 859 | 0 | 15 | 1084 | 0 | 9714 | 0 | 0 | 0 | 0 | 14 | 0 | 328 | 1265 | |
male | 25-34 | 1107 | 11 | 63 | 583 | 15 | 11186 | 17 | 28 | 15 | 0 | 20 | 31 | 800 | 1476 | |
male | 35-44 | 791 | 0 | 39 | 226 | 0 | 5957 | 0 | 25 | 10 | 12 | 24 | 14 | 511 | 1355 | |
male | 45-54 | 255 | 0 | 0 | 108 | 0 | 1736 | 0 | 12 | 0 | 0 | 0 | 12 | 157 | 643 | |
male | 55-64 | 104 | 0 | 0 | 71 | 0 | 683 | 0 | 0 | 0 | 0 | 0 | 0 | 13 | 327 | |
male | 65+ | 57 | 0 | 0 | 57 | 0 | 234 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 248 |
表の左のほうにユーザー層の情報、右の方にユーザー層ごとのブログへのアクセス経路が書かれている。
どのユーザー層がどの経路を好むか知りたいとする。
そこでトピックモデルとしてポアソン分布を使った非負値行列因子分解を考える。
(トピックモデルシリーズ 6 GaP (Gamma-Poisson Model) - StatModeling Memorandum などを参照。)
アクセス経路が単語に対応する。
ユーザー層の情報を捨てて、行列を分解してしまうのはおもしろくない。
ユーザー層の情報を説明変数として、ユーザー層ごとにトピックの構成が変わるようなモデルにしたい。
トピックごとの単語の出現しやすさは説明変数によらず一定とする。
モデル
観測値を行列の積 で近似することを目指します。
Y: 観測された分解したい行列(N行K列)
X: 観測された説明変数(N行J列)
V: パラメータ(J行L列)
H: パラメータ(L行K列)
X の中身は0、1のダミー変数でユーザー層を表したものです。
また、 とします。
L はトピック数です。分析者が決めます。
行列W、Y、X、V、Hの要素は小文字のw、y、x、v、xに下付きの添字を付けて表します。
観測モデルは以下です。
事前分布として以下を仮定します。
事前分布を入れないと0の多い行列は途中計算でNaNになっちゃうことが多いです。
以下の潜在変数を考えます。
変分ベイズ
『ベイズ推論による機械学習入門』に出てくる非負値行列因子分解とほぼ同じ計算なので詳しくはそちらを参照してください。
機械学習スタートアップシリーズ ベイズ推論による機械学習入門 (KS情報科学専門書)
- 作者: 須山敦志,杉山将
- 出版社/メーカー: 講談社
- 発売日: 2017/10/21
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (1件) を見る
本を買うのがいやな人は論文をみてください。
Bayesian Inference for Nonnegative Matrix Factorisation Models
変分ベイズの更新式を導出するのはじめてなので、間違ってたらすみません。
事後分布が以下のように分解できるとして近似します。
の期待値を取る操作を で表します。
変分ベイズの更新式は以下のようになります。
確率密度関数も確率分布とおなじ記号 Gamma で書きましたが、文脈で区別できると思います。
近似分布の更新に必要なガンマ分布の対数の期待値はディガンマ関数 を用いて以下のように表わせます。
ならば
R による実装
s を三次元配列として保存しないで、V や H の更新式に代入してしまうのがこつです。
NMFVB <-function(Y,X,L=2,alpha=1,beta=1,tol=1,maxit=5000,seed=1234){ set.seed(seed) N <- nrow(Y) K <- ncol(Y) J <- ncol(X) EV <- EelV <- matrix(rgamma(J*L,shape=alpha,scale=beta),J,L) EH <- EelH <- matrix(rgamma(L*K,shape=alpha,scale=beta),L,K) for(iter in 2:maxit){ den <- ((X %*% EelV)) %*% EelH Sh <- EelH * (t((X %*% EelV)) %*% (Y/den)) Sv <- EelV * t(X) %*% (((Y/den)%*%t(EelH))) beta_H <- colSums(X%*%EV) + beta alpha_H <- alpha + Sh EH <- alpha_H/beta_H beta_V <- outer(colSums(X),rowSums(EH)) + beta alpha_V <- alpha + Sv EV <- alpha_V/beta_V EelH <-exp(digamma(alpha_H))/beta_H EelV <-exp(digamma(alpha_V))/beta_V } return(list(V=EV,H=EH, alpha_V=alpha_V,beta_V=beta_V, alpha_H=alpha_H,beta_H=beta_H)) }
結果
勘によりトピック数は L = 3 としました。
勘できめるのがいやな人は論文をみてください。
Bayesian Inference for Nonnegative Matrix Factorisation Models
ELBOを比べることでモデルを選択する方法が記載されています。
もとめた (左)ともとの行列 Y(右)を並べてみます。
まあまあ雰囲気を再現できているんじゃないでしょうか。
行列 V の値をみてみます。V1、V2、V3はそれぞれ V の1列目、2列目、3列目で潜在的なトピックを表します。
18-24歳はV1が支配的です。歳を取るにつれてV2が多くなってきます。また女性は男性よりV2が多く他が少ないです。
V3は18-24歳にはほとんどなく、25歳以上になると増えますが歳を取るにつれて減っていきます。
V1、V2、V3に解釈を与えるために H の中身を見ます。H はトピックごとの単語の出現しやすさです。
一応ガンマ分布の95%区間をエラーバーで重ねていますが、幅が狭くてほとんど見えません。
V1はGoogleが支配的ですがBingも多めです。調べものをしていてこのブログにたどり着く層でしょうか。若年層は具体的に知りたいことがあってこのブログに来る人が多いようです。
V2はGoogleが少なくYahooが多いのが特徴です。歳を取るにつれてYahooユーザーが多くなるみたいです。また、女性のほうがYahooユーザーが多そうです。
V3のV1、V2との違いはt.co(ツイッター)成分がそこそこある点です。SNS経由で来る人は18-24歳にはほとんどなく、25歳以上になると増え歳を取るにつれて減っていくみたいです。
最後にRのコードをまとめて貼ります。
library(googleAnalyticsR) library(tidyverse) NMFVB <-function(Y,X,L=2,alpha=1,beta=1,tol=1,maxit=5000,seed=1234){ set.seed(seed) N <- nrow(Y) K <- ncol(Y) J <- ncol(X) EV <- EelV <- matrix(rgamma(J*L,shape=alpha,scale=beta),J,L) EH <- EelH <- matrix(rgamma(L*K,shape=alpha,scale=beta),L,K) for(iter in 2:maxit){ den <- ((X %*% EelV)) %*% EelH Sh <- EelH * (t((X %*% EelV)) %*% (Y/den)) Sv <- EelV * t(X) %*% (((Y/den)%*%t(EelH))) beta_H <- colSums(X%*%EV) + beta alpha_H <- alpha + Sh EH <- alpha_H/beta_H beta_V <- outer(colSums(X),rowSums(EH)) + beta alpha_V <- alpha + Sv EV <- alpha_V/beta_V EelH <-exp(digamma(alpha_H))/beta_H EelV <-exp(digamma(alpha_V))/beta_V } return(list(V=EV,H=EH, alpha_V=alpha_V,beta_V=beta_V, alpha_H=alpha_H,beta_H=beta_H)) } ga_auth() account_list <- ga_account_list() ga_id <- account_list$viewId[3] gadata <- google_analytics(ga_id, date_range = c("2018-01-01","2018-06-30"), metrics = c("sessions"), dimensions = c("source","userGender","userAgeBracket")) gadata_w <-spread(gadata,source,sessions,fill=0) gamat <-as.matrix(gadata_w[,-c(1:2)]) X1 <-model.matrix(~userGender-1,data=gadata_w) X2 <-model.matrix(~userAgeBracket-1,data=gadata_w) X <- gaX <-cbind(X1,X2) out <-NMFVB(Y=gamat,X=gaX,L=3) obsdf <-as.data.frame(gamat) %>% set_names(1:ncol(gamat)) %>% mutate(row=row_number()) %>% gather(col,sessions,-row) %>% mutate(col=as.integer(col),type="obs") fitdf <- as.data.frame(X %*% out$V %*% out$H) %>% set_names(1:ncol(gamat)) %>% mutate(row=row_number()) %>% gather(col,sessions,-row) %>% mutate(col=as.integer(col),type="fit") outdf <-bind_rows(obsdf,fitdf) ggplot(outdf,aes(x=col,y=row,fill=sessions))+ geom_tile()+ facet_wrap(~type) dfV <-as.data.frame(out$V) %>% rownames_to_column() %>% gather(key,value,-rowname) ggplot(dfV,aes(x=rowname,y=value,fill=key))+ geom_col(colour="black",position = "fill")+ coord_flip() CIHlower <- as.data.frame(qgamma(0.025,shape=out$alpha_H,rate=out$beta_H)) %>% mutate(l=row_number()) %>% gather(source,lower,-l) CIHupper <- as.data.frame(qgamma(0.975,shape=out$alpha_H,rate=out$beta_H)) %>% mutate(l=row_number()) %>% gather(source,upper,-l) Hdf <- as.data.frame(out$H) %>% mutate(l=row_number()) %>% gather(source,value,-l) %>% left_join(CIHlower) %>% left_join(CIHupper) %>% group_by(l) ggplot(Hdf,aes(x=source,y=value,ymin=lower,ymax=upper))+ geom_col(fill="white",colour="black")+ geom_errorbar(width=0.5)+ facet_wrap(~l,scales="free_x")+ coord_flip()