LINEでみりあちゃんと会話できるようにした(Seq2Seqとキャラ対話データを用いた転移学習によるキャラクター性対話ボットの作成)

この記事はシンデレラガールズAdvent Calendar 13日目の記事です.

目次

はじめに

みりあちゃん大好き

私はアイドルマスターシンデレラガールズ赤城みりあちゃんが大好きです.

f:id:muscle_keisuke:20171211214127j:plain:w300f:id:muscle_keisuke:20171211214131j:plain:w300
引用:http://deremas.doorblog.jp/archives/32507031.html

みりあちゃんと出会ったのは去年の冬のことでした.友達とAndroidアプリを作る話が上がり, サークルのタブレットを見た時に誰かがふざけてインストールしたであろう

アイドルマスターシンデレラガールズスターライトステージ

がありました. 「なんで,共用のタブレットにゲームが入っているんだ!誰だ!」と思いつつ,私はそのアプリを起動しました. そのゲームを見て,私の人生は完全に変わりました. 見事な完成度のモデル,そしてそれらのモーションが作り出すライブ.楽曲も素晴らしいものばかりでした. そして,可愛くてそれぞれのキャラを持ったアイドル達.私はこのゲームに一気にハマりました. 自分のスマホデレステを速攻でインストールし,それから,しばらくはずっとデレステをやっていました.そしてしばくらして,担当のアイドルもできました. それが

赤城みりあ

です.みりあちゃんは小学5年生の元気な女の子です.私はこのアイドルのキャラクターとヴィジュアルに惹かれ,みりあちゃんの担当Pとなりました. 4月はみりあちゃんの誕生日ということでケーキを買ってお祝いもしました.

私がデレステに出会い,そこからデレマスというコンテンツにハマりました.CDも手に入る分は全て聴き,LIVE BDも1stから4thまで見ました.

そして今年の6月に5th LIVEに参加してきました.みりあちゃんの中の人が出る静岡公演には現地で参加し,SSAのライブは現地には行けませんでしたが,ライブビューイングで2日間参加してきました. その後,アイマスのアニメ+劇場版とデレアニも見て,アイマスとみりあちゃんに対する愛はどんどん深まっていきました.

「担当としてもっと,みりあちゃんにできることはないか」

「もっと.みりあちゃんに関わりたい」

そしてある日,思いました.

「そもそも,担当Pなのにみりあちゃんとお話できないのはおかしくないか?」

どうやってみりあちゃんとお話するか

自分の使える技術でなんとかみりあちゃんとお話できないだろうか,と考えました. そこで,自然言語処理と深層学習を使って対話ボットを作ることにしました.

みりあちゃんモデルの作成

Seq2Seqで対話ボットの学習

Seq2Seqとは

Seq2SeqはRNNもしくはLSTMを用いたEncoder-Decoderモデルの一種です.機械翻訳のモデルとして紹介されることが多いです. 例えば,英語からフランス語に機械翻訳するモデルであれば,原文が英語,目的文がフランス語となります. Encoderには単語毎に分割し,それぞれベクトル化した原文(x1xT)をそれぞれの隠れ層に入力します.隠れ層はお互いに時系列(1tT)の関係にあり,それぞれの層は前の層の隠れ要素と入力された単語から自身の層の隠れ要素(h1hT)を更新します. 更新式は

ht=f(Whxxt+Whhht1)

で,fは活性化関数です.

Decoderでは,Encoderで更新した隠れ要素(h1hT)と隠れ層から出力層の重み行列Wyh及び,自身の前の時系列の単語yt1から出力単語ytを得ます.最終的に得られる単語列y1yTが目的文になります.

yt=Wyhht

f:id:muscle_keisuke:20171212033019p:plain 引用:https://qiita.com/odashi_t/items/a1be7c4964fbea6a116e

詳しい説明などは

ChainerとRNNと機械翻訳 - Qiita

とか,元の論文

http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf

を読んでください.

モデルの作成

Tensorflowには様々なチュートリアルが用意されているのでその中のSeq2Seqのチュートリアルソースコードをひっぱてきました.

http://tensorflow.classcat.com/2016/02/24/tensorflow-tutorials-sequence-to-sequence-models/ (本家のページが1ヶ月前に更新されていたので,本家を翻訳したページを貼ります)

本家のSeq2Seqモデルは英語とフランス語の対訳コーパスから学習しています. つまり,英語を入力するとフランス語に翻訳するモデルとなっています. 学習するデータをTwitterのツイートとそれに対するリプライのデータに変更します. これによって,日本語で話すと日本語で応答するモデルを作ることができます.

転移学習でみりあちゃんの口調を学習

おそらく,Seq2SeqとTwitterデータだけでも対話ボットは作れるでしょう(実際,作れた). ただ,それはみりあちゃんではなく草を生やすだけのクソリプボットにしかなりません. なので,転移学習でみりあちゃんの口調を学習していきます.

転移学習とは

転移学習とは教師データが少ないドメインの学習を行う為に十分なデータで学習した異なるドメインのモデルのパラメータを引き継いで更に学習を行う方法です.

今回は,みりあちゃんの対話データが少ないのでTwitterのデータから学習を行い,みりあちゃんの対話データで転移学習を行います.

口調の学習を行う方法

Seq2Seqはメモリ節約や過学習防止の為に教師データ内から学習する単語数に制限を設けます. 学習する単語は出現頻度が高い単語です.この学習する単語に口調を学習するための単語を混ぜます. 転移学習時の学習単語数をNp単語とします.その内,口調を学習するためにNp単語の内下位Ns単語をみりあちゃん対話データにおいて出現頻度上位Ns単語と差し替えます.

後は全てのパラメータを引き継いで学習を行います.

学習方法は以下の論文を参考にしました.

http://www.anlp.jp/proceedings/annual_meeting/2017/pdf_dir/B3-3.pdf

データの収集

Twitterから対話データの収集

ツイートとリプライを取得

学習するデータがなければ学習できません.なので,学習データをTwitterから取ってきます. データ取得にはTwitter Streaming APIを使います. ソースは以下の方を参考に研究室の後輩が書いたソースを参考にしました.

github.com

ただし,91行目の

line = HanziConv.toTraditional(line)

コメントアウトしないと漢字がすべて繁体字になるのと,

is_zh = re.compile(r'([\p{IsHan}]+)', re.UNICODE) 

is_zh = re.compile(r'[一-鿐]+', re.UNICODE)

などに代替しないと,実行できません.

スクリプトを回したら後は,ツイートとリプライが集まってくるのを待つだけです. (私は研究室の後輩がツイートを集めていたのでそれをもらいましたが)

データの整形

集めたデータをツイートとリプライのペアにしなければなりません. しかし,あまりにもデータが多くて時間がかかるので以下の記事を参考に並列処理しました.

qiita.com

スクリプトは以下のような感じです.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import re
import MeCab
from datetime import datetime
from joblib import Parallel, delayed
from time import time
from more_itertools import pairwise
def tokenizer_jp(sentence):
return MeCab.Tagger("-Owakati -d xxx/mylex").parse(sentence)
def creaning(sentence):
sentence = re.sub(r"@([A-Za-z0-9_]+)", "", sentence)
sentence = re.sub(r'https?:\/\/.*', "", sentence)
return sentence
def process(line_in, line_out, pattern):
line_in = tokenizer_jp(creaning(line_in))
line_out = tokenizer_jp(creaning(line_out))
if not re.match(pattern, line_in) and not re.match(pattern, line_out):
return (line_in, line_out)
return ()
pattern = re.compile(r"^\s*$")
start = time()
with open("tweets.txt","r",encoding="utf-8")as f, open("input.txt","w",encoding="utf-8")as f_in, open("output.txt","w",encoding="utf-8")as f_out:
result = Parallel(n_jobs=-1, verbose=7)([delayed(process)(line_in, line_out, pattern) for line_in, line_out in zip(f,f)])
for x in result:
if len(x) > 0:
f_in.write(x[0])
f_out.write(x[1])
print(f"result={time() - start}")
view raw preprocess_tweets.py hosted with ❤ by GitHub

約2000万件のツイートとリプライのペアを取得しました.

デレマスのSSなどからみりあちゃんの対話データを収集

みりあちゃんの対話データを集める方法を色々考えました.

  1. ストーリコミュなどから別のアイドルと話しているデータを取ってくる
    • コミュのスクショを取ってOCRでテキストに起こす
    • コミュを鑑賞しながら手打ちする
  2. デレアニの台詞から取ってくる
    • 音声からテキストを起こす
    • デレアニを鑑賞しながら手打ちする
  3. 個人のSSから取ってくる

1つ目の方法ですが,これはテキストなどでコミュを落とすことができないので実際にコミュを見て手打ちするか, コミュのスクショからOCR(光学文字認識)でテキストに起こす他ないと思います. とりあえず,みりあちゃんが登場する全てコミュをスクショしました.

f:id:muscle_keisuke:20171212034122p:plain:w200f:id:muscle_keisuke:20171212034116p:plain:w200f:id:muscle_keisuke:20171212034112p:plain:w200f:id:muscle_keisuke:20171212034107p:plain:w200f:id:muscle_keisuke:20171212034103p:plain:w200

OCRですが,結論から言うと,やってみて精度が悪すぎたのでやめました.おそらくコミュのフォントとの相性が悪いのが原因だと思います(小文字が大体認識できない). 最初はログのスクショをそのままOCRにかけたのですが...

f:id:muscle_keisuke:20171213024842j:plain

何言ってんだって感じですね. ネガポジ反転させると,精度はよくなりますが,実用的ではないですね.

f:id:muscle_keisuke:20171212035233j:plain

2つ目の方法ですが,デレアニのDVDやデータは持っていないし,またレンタルしてくるのもだるいのでやってません.

3つ目の方法は今回取った方法です. 一番手間がかからず,大量のデータが用意できるのがメリットと言えます.デメリットとしては前処理がとても大変なことと, データの質が公式のデータに比べ低いことです.しかし,深層学習はデータ数が物を言うので,今回はこの方法を採用しました.

まずはSSのまとめサイトをまとめたサイトからスクレイピングし,リンクのリストを取得しました.

view raw get_links.py hosted with ❤ by GitHub

得られたリンクが約48000件でした.

リンク集からseleniumgoogle chromeのheadlessブラウザでスクレイピングしました. サイトによってHTMLのソースが異なりCSSセレクタなどで本文のみ絞るのが難しかったのでほぼ全部のソースを取ってきてます.

import lxml.html
from selenium import webdriver
from selenium.common.exceptions import TimeoutException
import re
from functools import reduce
import os
from time import sleep
import socket
from selenium.webdriver.chrome.options import Options
options = Options()
options.binary_location = "/usr/bin/google-chrome-stable"
options.add_argument("--headless")
driver = webdriver.Chrome(chrome_options=options)
def get_article(url):
driver.get(url)
root = lxml.html.fromstring(driver.page_source)
law_articles = root.cssselect("div")
if len(law_articles) == 0:
with open("articles/problem_urls.csv", "a") as f:
f.write(url)
raise Exception("problem occured in {url}")
articles = [article.text_content() for article in law_articles
if article.text_content() is not None]
return articles
def get_many_article():
with open("links.dat", "r") as f:
with open("checkpoint", "r") as g:
cursor = int(g.readline())
links = f.readlines()[cursor:]
for current_elapsed, url in enumerate(links):
url = url.replace("\n", "")
while True:
try:
articles = get_article(url)
except socket.error as serr:
print(serr)
with open("articles/problem_urls.dat", "a") as f:
f.write(url + "\n")
print(f"problem occured in {url}")
break
except TimeoutException as e:
print(e)
continue
else:
print(f"{url} ... scraping is done")
print(f"elapsed...{current_elapsed}/{len(links)}")
print("wait 10 seconds")
sleep(10)
break
with open("checkpoint", "w") as f:
f.write(f"{current_elapsed+cursor}")
yield [url, reduce(lambda x, y: x + y, articles)]
get_article_gen = get_many_article()
for url, article in get_article_gen:
domain = re.search(r"http://(.+)/", url).group(1)
if ".html" not in url:
file_name = re.search(r"(p=[0-9]+)", url).group(1)
else:
file_name = re.search(r".+/(.+)\.html", url).group(1)
try:
os.makedirs(f"articles/{domain}")
except OSError:
print("dir is already exist")
print(f"dir_name={domain},file_name={file_name}")
with open(f"articles/{domain}/{file_name}-articles.dat", "w") as f:
f.write(article)
view raw scrape_article.py hosted with ❤ by GitHub

取得したデータの整形

みりあちゃんの台詞とその前に喋っている人の台詞を抽出します.大変でした.本文がプレーンテキストなので

執筆者の数だけフォーマットがある

という状態でした.それでも大体

喋っている人 「○○」

という形式が多いのでいくつかの取りこぼしやノイズに目をつぶり,正規表現でゴリ押しました. また,時間かかるのでこれも並列処理化してます.

import re
import glob
from itertools import chain
import MeCab
from joblib import Parallel, delayed
target_name = "みりあ"
tagger = MeCab.Tagger("-Owakati")
def extract_conversation(line):
return re.findall(r"([^]*\s*.*?」)", line)
def split_speaker(conv):
speaker, sentence = re.search(r"(.*)\s*「(.*)」", conv).groups()
return {"speaker": speaker, "sentence": sentence}
def shape_conversation(lines):
convs = ([split_speaker(conv) for conv in list(chain.from_iterable(Parallel(n_jobs=-1, verbose=7)([delayed(extract_conversation)(line) for line in lines])))])
return [generate_conv_pair(convs, target_cursor) for target_cursor in [cursor for cursor, conv in enumerate(convs) if target_name in conv['speaker']]]
def generate_conv_pair(convs, cursor):
if cursor is not 0:
conv_pair = {"q": convs[cursor-1], "a": convs[cursor]}
else:
conv_pair = {"q": "(BOS)", "a": convs[cursor]}
print_interaction(conv_pair)
return conv_pair
def print_interaction(conv):
print()
print(f"{conv['q']['speaker']} ... {conv['q']['sentence']}")
print("----------------------------------")
print(f"{conv['a']['speaker']} ... {conv['a']['sentence']}")
print()
def dump_interaction(convs, number):
questions = [tagger.parse(conv['q']['sentence']) for conv in convs]
answers = [tagger.parse(conv['a']['sentence']) for conv in convs]
with open(f"input_style_{number}.txt", "w") as f:
for q in questions:
f.write(q + "\n")
with open(f"output_style_{number}.txt", "w") as f:
for a in answers:
f.write(a + "\n")
files = glob.iglob("./articles/**/*.dat", recursive=True)
convs = []
for number, file_name in enumerate(files):
print(file_name)
with open(file_name, "r") as f:
lines = f.readlines()
convs = shape_conversation(lines)
dump_interaction(convs, number)
view raw shaping_articles.py hosted with ❤ by GitHub

取得したデータ数

複数のSSまとめサイトから約20000ページ(=20000件のSS)をスクレイピングしましたが, まとめサイトなので重複したSSがかなり見つかり,それらを削除して,更にみりあちゃんが登場するSSを絞った為, ページ数の割に得られた対話データは2700件程です.

実際に学習を行う

環境

今回モデル作成に使う環境は次の通りです.

  • Python 3.6
  • Tensorflow 1.0.0
  • GeForce GTX 1080Ti

    Twitterデータでクソリプボットに

    データでかすぎ問題

    Twitterから得たデータを基にGPUを使って学習をします. 学習するTwitterデータは2000万件の予定でしたが, それによって作成されるモデルがでかすぎてGPUのメモリを最大限使っても乗らないので, 1/4の500万件で学習を行いました.設定したパラメータなどは以下の通りです.

パラメータ名
原文の学習単語数 120000
目的文の学習単語数 120000
隠れ層の数 1024
隠れ層の深さ 1
バッチサイズ 64

学習時間は約1日です.パープレキシティは9くらいまで下がりました.

モデルの会話例

できたモデルの会話を見てみます.

f:id:muscle_keisuke:20171212021239j:plain

ことごとく,草が生えていて,こちらを小馬鹿にしたような対応をしてきます. こんなのみりあちゃんじゃない!

目的文の教師データにおける単語出現頻度の上位10単語を見てみると

  • _PAD
  • _GO
  • _EOS
  • _UNK
  • w
  • (
  • )

となってました.そりゃ草も生えますわ.

転移学習でみりあちゃんボットに

このクソリプボットも転移学習でみりあちゃんみたいになるのでしょうか.

学習単語数は原文,目的文それぞれ1/2の60000単語に設定し,みりあちゃん対話データから得られる頻出単語上位1000単語を元のデータの下位1000単語と付け替えます.他のパラメータは全て引き継ぎ,学習を行いました. 学習時間は約12時間, パープレキシティは3くらいまで下がりました.

モデルの会話例

f:id:muscle_keisuke:20171212024436j:plain

みりあちゃんじゃん!

ちょっと,黙ったりするのが怖いですけど草も生えなくなりましたし,思ったよりもみりあちゃんには近い気がします.本家だと言わなそうなことも混じってる気がしますが,これはデータの性質上仕方がないですね.

LINEでみりあちゃんとお話しできるようにする

LINE APIの使用

黒い画面でみりあちゃんと話しても雰囲気が出ないので,やっぱりここはLINEを使っていきたいと思います. こちらからの問いかけに対する返答のみであればLINE Messaging APIが無料で提供されているのでこれを使っていきます.LINEのアカウントを持っていれば誰でも作れます.

LINE Developers

LINE Messaging APIを利用して応答メッセージを送るにはサーバが必要になります.

VPSにサーバを建てる

ちょうど,VPSを借りていたのでここにLINE API用のサーバを建てます. コールバック用のURLはhttpsなのでSSLの証明書が必要です . VPSがCore OSなのでDockerを使ってnginx proxy + letsencryptでHTTPSサーバを建てました.

学習済みボットを物理サーバーに載せる

VPSのCPU貧弱問題

本当はLINE APIを動かしているVPSにモデルを乗せて回したかったのですが, モデルがでかすぎるのか,CPUが貧弱すぎるのかモデルがロードできませんでした. というわけで,家の中に眠っていたデスクトップのPCを物理サーバとして建てることにしました. VPSからソケット通信で物理サーバに文章を送ると,回答をソケット通信でVPSに返すシステムになっています. 物理サーバにLINE APIを導入すればこんなことしなくても済むのですが,ドメインもなくSSLの証明書が発行できないのでこうしてます.

みりあちゃんとお話しする

さて,サーバーとDockerコンテナも建てたので,いよいよみりあちゃんとお話してみます.

f:id:muscle_keisuke:20171212031415j:plain

おぉ,実際にLINEでみりあちゃんと話してるみたいですね. これからも改良していきたいですね.

おわりに

今回は自分の知識がほぼゼロからやったので無駄に時間がかかった気がします.特にサーバ関係とか... あとは,データの前処理はやっぱり大変でした.ソースコードのほとんどがデータ前処理のコードでした.

改良点

  • Attentionモデルに差し替えて学習する
  • みりあちゃん対話データをしっかりクレンジングする
  • 画像やスタンプにも反応するように学習する(im2txtとか?)

おまけ

失敗例

最後に明らかにみりあちゃんじゃないだろというような失敗例を見せたいと思います.

注意:みりあPの人は見ない方がいいかもしれません






















f:id:muscle_keisuke:20171213024003p:plain

アカギ違いですね.