この記事はシンデレラガールズAdvent Calendar 13日目の記事です.
目次
- 目次
- はじめに
- みりあちゃんモデルの作成
- データの収集
- 実際に学習を行う
- LINEでみりあちゃんとお話しできるようにする
- 学習済みボットを物理サーバーに載せる
- みりあちゃんとお話しする
- おわりに
- おまけ
はじめに
みりあちゃん大好き
私はアイドルマスターシンデレラガールズの赤城みりあちゃんが大好きです.
みりあちゃんと出会ったのは去年の冬のことでした.友達とAndroidアプリを作る話が上がり, サークルのタブレットを見た時に誰かがふざけてインストールしたであろう
がありました. 「なんで,共用のタブレットにゲームが入っているんだ!誰だ!」と思いつつ,私はそのアプリを起動しました. そのゲームを見て,私の人生は完全に変わりました. 見事な完成度のモデル,そしてそれらのモーションが作り出すライブ.楽曲も素晴らしいものばかりでした. そして,可愛くてそれぞれのキャラを持ったアイドル達.私はこのゲームに一気にハマりました. 自分のスマホにデレステを速攻でインストールし,それから,しばらくはずっとデレステをやっていました.そしてしばくらして,担当のアイドルもできました. それが
「赤城みりあ」
です.みりあちゃんは小学5年生の元気な女の子です.私はこのアイドルのキャラクターとヴィジュアルに惹かれ,みりあちゃんの担当Pとなりました. 4月はみりあちゃんの誕生日ということでケーキを買ってお祝いもしました.
私がデレステに出会い,そこからデレマスというコンテンツにハマりました.CDも手に入る分は全て聴き,LIVE BDも1stから4thまで見ました.
そして今年の6月に5th LIVEに参加してきました.みりあちゃんの中の人が出る静岡公演には現地で参加し,SSAのライブは現地には行けませんでしたが,ライブビューイングで2日間参加してきました. その後,アイマスのアニメ+劇場版とデレアニも見て,アイマスとみりあちゃんに対する愛はどんどん深まっていきました.
「担当としてもっと,みりあちゃんにできることはないか」
「もっと.みりあちゃんに関わりたい」
そしてある日,思いました.
「そもそも,担当Pなのにみりあちゃんとお話できないのはおかしくないか?」
どうやってみりあちゃんとお話するか
自分の使える技術でなんとかみりあちゃんとお話できないだろうか,と考えました. そこで,自然言語処理と深層学習を使って対話ボットを作ることにしました.
みりあちゃんモデルの作成
Seq2Seqで対話ボットの学習
Seq2Seqとは
Seq2SeqはRNNもしくはLSTMを用いたEncoder-Decoderモデルの一種です.機械翻訳のモデルとして紹介されることが多いです. 例えば,英語からフランス語に機械翻訳するモデルであれば,原文が英語,目的文がフランス語となります. Encoderには単語毎に分割し,それぞれベクトル化した原文()をそれぞれの隠れ層に入力します.隠れ層はお互いに時系列()の関係にあり,それぞれの層は前の層の隠れ要素と入力された単語から自身の層の隠れ要素()を更新します. 更新式は
で,は活性化関数です.
Decoderでは,Encoderで更新した隠れ要素()と隠れ層から出力層の重み行列及び,自身の前の時系列の単語から出力単語を得ます.最終的に得られる単語列が目的文になります.
詳しい説明などは
とか,元の論文
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
を読んでください.
モデルの作成
Tensorflowには様々なチュートリアルが用意されているのでその中のSeq2Seqのチュートリアルのソースコードをひっぱてきました.
http://tensorflow.classcat.com/2016/02/24/tensorflow-tutorials-sequence-to-sequence-models/ (本家のページが1ヶ月前に更新されていたので,本家を翻訳したページを貼ります)
本家のSeq2Seqモデルは英語とフランス語の対訳コーパスから学習しています. つまり,英語を入力するとフランス語に翻訳するモデルとなっています. 学習するデータをTwitterのツイートとそれに対するリプライのデータに変更します. これによって,日本語で話すと日本語で応答するモデルを作ることができます.
転移学習でみりあちゃんの口調を学習
おそらく,Seq2SeqとTwitterデータだけでも対話ボットは作れるでしょう(実際,作れた). ただ,それはみりあちゃんではなく草を生やすだけのクソリプボットにしかなりません. なので,転移学習でみりあちゃんの口調を学習していきます.
転移学習とは
転移学習とは教師データが少ないドメインの学習を行う為に十分なデータで学習した異なるドメインのモデルのパラメータを引き継いで更に学習を行う方法です.
今回は,みりあちゃんの対話データが少ないのでTwitterのデータから学習を行い,みりあちゃんの対話データで転移学習を行います.
口調の学習を行う方法
Seq2Seqはメモリ節約や過学習防止の為に教師データ内から学習する単語数に制限を設けます. 学習する単語は出現頻度が高い単語です.この学習する単語に口調を学習するための単語を混ぜます. 転移学習時の学習単語数を単語とします.その内,口調を学習するために単語の内下位単語をみりあちゃん対話データにおいて出現頻度上位単語と差し替えます.
後は全てのパラメータを引き継いで学習を行います.
学習方法は以下の論文を参考にしました.
http://www.anlp.jp/proceedings/annual_meeting/2017/pdf_dir/B3-3.pdf
データの収集
Twitterから対話データの収集
ツイートとリプライを取得
学習するデータがなければ学習できません.なので,学習データをTwitterから取ってきます. データ取得にはTwitter Streaming APIを使います. ソースは以下の方を参考に研究室の後輩が書いたソースを参考にしました.
ただし,91行目の
line = HanziConv.toTraditional(line)
is_zh = re.compile(r'([\p{IsHan}]+)', re.UNICODE)
を
is_zh = re.compile(r'[一-鿐]+', re.UNICODE)
などに代替しないと,実行できません.
スクリプトを回したら後は,ツイートとリプライが集まってくるのを待つだけです. (私は研究室の後輩がツイートを集めていたのでそれをもらいましたが)
データの整形
集めたデータをツイートとリプライのペアにしなければなりません. しかし,あまりにもデータが多くて時間がかかるので以下の記事を参考に並列処理しました.
スクリプトは以下のような感じです.
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import re | |
import MeCab | |
from datetime import datetime | |
from joblib import Parallel, delayed | |
from time import time | |
from more_itertools import pairwise | |
def tokenizer_jp(sentence): | |
return MeCab.Tagger("-Owakati -d xxx/mylex").parse(sentence) | |
def creaning(sentence): | |
sentence = re.sub(r"@([A-Za-z0-9_]+)", "", sentence) | |
sentence = re.sub(r'https?:\/\/.*', "", sentence) | |
return sentence | |
def process(line_in, line_out, pattern): | |
line_in = tokenizer_jp(creaning(line_in)) | |
line_out = tokenizer_jp(creaning(line_out)) | |
if not re.match(pattern, line_in) and not re.match(pattern, line_out): | |
return (line_in, line_out) | |
return () | |
pattern = re.compile(r"^\s*$") | |
start = time() | |
with open("tweets.txt","r",encoding="utf-8")as f, open("input.txt","w",encoding="utf-8")as f_in, open("output.txt","w",encoding="utf-8")as f_out: | |
result = Parallel(n_jobs=-1, verbose=7)([delayed(process)(line_in, line_out, pattern) for line_in, line_out in zip(f,f)]) | |
for x in result: | |
if len(x) > 0: | |
f_in.write(x[0]) | |
f_out.write(x[1]) | |
print(f"result={time() - start}") |
約2000万件のツイートとリプライのペアを取得しました.
デレマスのSSなどからみりあちゃんの対話データを収集
みりあちゃんの対話データを集める方法を色々考えました.
- ストーリコミュなどから別のアイドルと話しているデータを取ってくる
- コミュのスクショを取ってOCRでテキストに起こす
- コミュを鑑賞しながら手打ちする
- デレアニの台詞から取ってくる
- 音声からテキストを起こす
- デレアニを鑑賞しながら手打ちする
- 個人のSSから取ってくる
1つ目の方法ですが,これはテキストなどでコミュを落とすことができないので実際にコミュを見て手打ちするか, コミュのスクショからOCR(光学文字認識)でテキストに起こす他ないと思います. とりあえず,みりあちゃんが登場する全てコミュをスクショしました.
OCRですが,結論から言うと,やってみて精度が悪すぎたのでやめました.おそらくコミュのフォントとの相性が悪いのが原因だと思います(小文字が大体認識できない). 最初はログのスクショをそのままOCRにかけたのですが...
何言ってんだって感じですね. ネガポジ反転させると,精度はよくなりますが,実用的ではないですね.
2つ目の方法ですが,デレアニのDVDやデータは持っていないし,またレンタルしてくるのもだるいのでやってません.
3つ目の方法は今回取った方法です. 一番手間がかからず,大量のデータが用意できるのがメリットと言えます.デメリットとしては前処理がとても大変なことと, データの質が公式のデータに比べ低いことです.しかし,深層学習はデータ数が物を言うので,今回はこの方法を採用しました.
まずはSSのまとめサイトをまとめたサイトからスクレイピングし,リンクのリストを取得しました.
import functools as ft | |
import lxml.html | |
from selenium import webdriver | |
import re | |
def get_link(): | |
driver = webdriver.PhantomJS() | |
for page in range(1,1612+1): | |
url = f"http://ssmania.info/category/%E3%83%87%E3%83%AC%E3%83%9E%E3%82%B9?page={page}" | |
driver.get(url) | |
root = lxml.html.fromstring(driver.page_source) | |
link_urls = [link_element.get("href") for link_element in root.cssselect('a.articlelink')] | |
print(link_urls) | |
yield link_urls | |
get_link_gen = get_link() | |
with open("linkes.dat", "w") as f: | |
for links in get_link_gen: | |
for link in links: | |
print(link) | |
f.write(link) |
得られたリンクが約48000件でした.
リンク集からseleniumとgoogle chromeのheadlessブラウザでスクレイピングしました. サイトによってHTMLのソースが異なりCSSセレクタなどで本文のみ絞るのが難しかったのでほぼ全部のソースを取ってきてます.
import lxml.html | |
from selenium import webdriver | |
from selenium.common.exceptions import TimeoutException | |
import re | |
from functools import reduce | |
import os | |
from time import sleep | |
import socket | |
from selenium.webdriver.chrome.options import Options | |
options = Options() | |
options.binary_location = "/usr/bin/google-chrome-stable" | |
options.add_argument("--headless") | |
driver = webdriver.Chrome(chrome_options=options) | |
def get_article(url): | |
driver.get(url) | |
root = lxml.html.fromstring(driver.page_source) | |
law_articles = root.cssselect("div") | |
if len(law_articles) == 0: | |
with open("articles/problem_urls.csv", "a") as f: | |
f.write(url) | |
raise Exception("problem occured in {url}") | |
articles = [article.text_content() for article in law_articles | |
if article.text_content() is not None] | |
return articles | |
def get_many_article(): | |
with open("links.dat", "r") as f: | |
with open("checkpoint", "r") as g: | |
cursor = int(g.readline()) | |
links = f.readlines()[cursor:] | |
for current_elapsed, url in enumerate(links): | |
url = url.replace("\n", "") | |
while True: | |
try: | |
articles = get_article(url) | |
except socket.error as serr: | |
print(serr) | |
with open("articles/problem_urls.dat", "a") as f: | |
f.write(url + "\n") | |
print(f"problem occured in {url}") | |
break | |
except TimeoutException as e: | |
print(e) | |
continue | |
else: | |
print(f"{url} ... scraping is done") | |
print(f"elapsed...{current_elapsed}/{len(links)}") | |
print("wait 10 seconds") | |
sleep(10) | |
break | |
with open("checkpoint", "w") as f: | |
f.write(f"{current_elapsed+cursor}") | |
yield [url, reduce(lambda x, y: x + y, articles)] | |
get_article_gen = get_many_article() | |
for url, article in get_article_gen: | |
domain = re.search(r"http://(.+)/", url).group(1) | |
if ".html" not in url: | |
file_name = re.search(r"(p=[0-9]+)", url).group(1) | |
else: | |
file_name = re.search(r".+/(.+)\.html", url).group(1) | |
try: | |
os.makedirs(f"articles/{domain}") | |
except OSError: | |
print("dir is already exist") | |
print(f"dir_name={domain},file_name={file_name}") | |
with open(f"articles/{domain}/{file_name}-articles.dat", "w") as f: | |
f.write(article) |
取得したデータの整形
みりあちゃんの台詞とその前に喋っている人の台詞を抽出します.大変でした.本文がプレーンテキストなので
執筆者の数だけフォーマットがある
という状態でした.それでも大体
喋っている人 「○○」
という形式が多いのでいくつかの取りこぼしやノイズに目をつぶり,正規表現でゴリ押しました. また,時間かかるのでこれも並列処理化してます.
import re | |
import glob | |
from itertools import chain | |
import MeCab | |
from joblib import Parallel, delayed | |
target_name = "みりあ" | |
tagger = MeCab.Tagger("-Owakati") | |
def extract_conversation(line): | |
return re.findall(r"([^「]*\s*「.*?」)", line) | |
def split_speaker(conv): | |
speaker, sentence = re.search(r"(.*)\s*「(.*)」", conv).groups() | |
return {"speaker": speaker, "sentence": sentence} | |
def shape_conversation(lines): | |
convs = ([split_speaker(conv) for conv in list(chain.from_iterable(Parallel(n_jobs=-1, verbose=7)([delayed(extract_conversation)(line) for line in lines])))]) | |
return [generate_conv_pair(convs, target_cursor) for target_cursor in [cursor for cursor, conv in enumerate(convs) if target_name in conv['speaker']]] | |
def generate_conv_pair(convs, cursor): | |
if cursor is not 0: | |
conv_pair = {"q": convs[cursor-1], "a": convs[cursor]} | |
else: | |
conv_pair = {"q": "(BOS)", "a": convs[cursor]} | |
print_interaction(conv_pair) | |
return conv_pair | |
def print_interaction(conv): | |
print() | |
print(f"{conv['q']['speaker']} ... {conv['q']['sentence']}") | |
print("----------------------------------") | |
print(f"{conv['a']['speaker']} ... {conv['a']['sentence']}") | |
print() | |
def dump_interaction(convs, number): | |
questions = [tagger.parse(conv['q']['sentence']) for conv in convs] | |
answers = [tagger.parse(conv['a']['sentence']) for conv in convs] | |
with open(f"input_style_{number}.txt", "w") as f: | |
for q in questions: | |
f.write(q + "\n") | |
with open(f"output_style_{number}.txt", "w") as f: | |
for a in answers: | |
f.write(a + "\n") | |
files = glob.iglob("./articles/**/*.dat", recursive=True) | |
convs = [] | |
for number, file_name in enumerate(files): | |
print(file_name) | |
with open(file_name, "r") as f: | |
lines = f.readlines() | |
convs = shape_conversation(lines) | |
dump_interaction(convs, number) |
取得したデータ数
複数のSSまとめサイトから約20000ページ(=20000件のSS)をスクレイピングしましたが, まとめサイトなので重複したSSがかなり見つかり,それらを削除して,更にみりあちゃんが登場するSSを絞った為, ページ数の割に得られた対話データは2700件程です.
実際に学習を行う
環境
今回モデル作成に使う環境は次の通りです.
- Python 3.6
- Tensorflow 1.0.0
- GeForce GTX 1080Ti
Twitterデータでクソリプボットに
データでかすぎ問題
Twitterから得たデータを基にGPUを使って学習をします. 学習するTwitterデータは2000万件の予定でしたが, それによって作成されるモデルがでかすぎてGPUのメモリを最大限使っても乗らないので, 1/4の500万件で学習を行いました.設定したパラメータなどは以下の通りです.
パラメータ名 | 値 |
---|---|
原文の学習単語数 | 120000 |
目的文の学習単語数 | 120000 |
隠れ層の数 | 1024 |
隠れ層の深さ | 1 |
バッチサイズ | 64 |
学習時間は約1日です.パープレキシティは9くらいまで下がりました.
モデルの会話例
できたモデルの会話を見てみます.
ことごとく,草が生えていて,こちらを小馬鹿にしたような対応をしてきます. こんなのみりあちゃんじゃない!
目的文の教師データにおける単語出現頻度の上位10単語を見てみると
- _PAD
- _GO
- _EOS
- _UNK
- !
- て
- w
- の
- (
- )
となってました.そりゃ草も生えますわ.
転移学習でみりあちゃんボットに
このクソリプボットも転移学習でみりあちゃんみたいになるのでしょうか.
学習単語数は原文,目的文それぞれ1/2の60000単語に設定し,みりあちゃん対話データから得られる頻出単語上位1000単語を元のデータの下位1000単語と付け替えます.他のパラメータは全て引き継ぎ,学習を行いました. 学習時間は約12時間, パープレキシティは3くらいまで下がりました.
モデルの会話例
みりあちゃんじゃん!
ちょっと,黙ったりするのが怖いですけど草も生えなくなりましたし,思ったよりもみりあちゃんには近い気がします.本家だと言わなそうなことも混じってる気がしますが,これはデータの性質上仕方がないですね.
LINEでみりあちゃんとお話しできるようにする
LINE APIの使用
黒い画面でみりあちゃんと話しても雰囲気が出ないので,やっぱりここはLINEを使っていきたいと思います. こちらからの問いかけに対する返答のみであればLINE Messaging APIが無料で提供されているのでこれを使っていきます.LINEのアカウントを持っていれば誰でも作れます.
LINE Messaging APIを利用して応答メッセージを送るにはサーバが必要になります.
VPSにサーバを建てる
ちょうど,VPSを借りていたのでここにLINE API用のサーバを建てます. コールバック用のURLはhttpsなのでSSLの証明書が必要です . VPSがCore OSなのでDockerを使ってnginx proxy + letsencryptでHTTPSサーバを建てました.
学習済みボットを物理サーバーに載せる
VPSのCPU貧弱問題
本当はLINE APIを動かしているVPSにモデルを乗せて回したかったのですが, モデルがでかすぎるのか,CPUが貧弱すぎるのかモデルがロードできませんでした. というわけで,家の中に眠っていたデスクトップのPCを物理サーバとして建てることにしました. VPSからソケット通信で物理サーバに文章を送ると,回答をソケット通信でVPSに返すシステムになっています. 物理サーバにLINE APIを導入すればこんなことしなくても済むのですが,ドメインもなくSSLの証明書が発行できないのでこうしてます.
みりあちゃんとお話しする
さて,サーバーとDockerコンテナも建てたので,いよいよみりあちゃんとお話してみます.
おぉ,実際にLINEでみりあちゃんと話してるみたいですね. これからも改良していきたいですね.
おわりに
今回は自分の知識がほぼゼロからやったので無駄に時間がかかった気がします.特にサーバ関係とか... あとは,データの前処理はやっぱり大変でした.ソースコードのほとんどがデータ前処理のコードでした.
改良点
- Attentionモデルに差し替えて学習する
- みりあちゃん対話データをしっかりクレンジングする
- 画像やスタンプにも反応するように学習する(im2txtとか?)
おまけ
失敗例
最後に明らかにみりあちゃんじゃないだろというような失敗例を見せたいと思います.
注意:みりあPの人は見ない方がいいかもしれません
アカギ違いですね.