強化学習と呼ばれる機械学習の一分野があります。機械学習というと、入力に対して正解の出力を当てる教師あり学習が話題になることが多いですが、強化学習では明示的に正解が与えられません。ある行動を試してみて、それに対して得られる「報酬」から自分でどのような行動が良い結果をもたらすのかを判断して、より良い行動を学習するアルゴリズムになっています。
強化学習にはチェスやリバーシなどといったボードゲームのAIやロボットの行動学習などの応用例があります。この前話題になったDeep Q Network、通称DQNも強化学習の一種です。応用例が面白いにも関わらず、PRMLなどの主要な機械学習の教科書では強化学習を扱わないことが多いので、いま強化学習だけの参考書を買って勉強しています。
- 作者: Richard S.Sutton,Andrew G.Barto,三上貞芳,皆川雅章
- 出版社/メーカー: 森北出版
- 発売日: 2000/12/01
- メディア: 単行本(ソフトカバー)
- 購入: 5人 クリック: 76回
- この商品を含むブログ (29件) を見る
ちなみにこの本の原書は無料公開されています。「ギャンブラーの問題」に対応するのは Example 4.3: Gambler's Problemです。
問題設定
考え方
詳細は本を読んでもらうのが分かりやすいのですが、簡単に説明しておきます。
- 所持金sのときの勝率を表す関数
を用意する。(
は状態価値関数と呼ぶ。配列で代用する)
- 所持金が0ドルの場合は勝率0、所持金が100ドルの場合は勝率1に設定する。
- 所持金が1ドル〜99ドルの場合の勝率を次のように更新する
ドル賭けると、次のターンの勝率は
の確率で
に、
の確率で勝率は
になる。
を次のターンの勝率の期待値
の最大値に更新する。
- 更新を繰り返し、変化量が十分小さくなったら真の勝率が得られたと判断して更新を止める。
- 各所持金に対して、次のターンの勝率の期待値が最も高くなる金額を賭けるようにする。(所持金sの時の最適戦略を方策関数
で表す。
こうすることで、ギャンブラーは最適な金額を賭けられるようになります。
ソースコード
#include <iostream> #include <cmath> #include <algorithm> using namespace std; int main(void){ double V[101]; // 状態価値関数 int pi[101]; // 方策 const double p = 0.4; //表が出る確率 // 状態価値関数の初期化 for(int s=0; s<100; ++s) V[s] = 0; V[100] = 1.0; const double theta = 1e-5; // ループ終了のしきい値 double delta = 1.0; // 最大変更量 // 状態価値関数の更新 while(delta >= theta){ delta = 0.0; for(int s=1; s<100; ++s){ double V_old = V[s]; double cand = 0.0; // 可能な掛け金ごとに勝率を調べる for(int bet=1; bet<=min(s, 100-s); ++bet){ double tmp = p*V[s+bet] + (1.0-p)*V[s-bet]; cand = max(tmp, cand); } V[s] = cand; delta = max(delta, abs(V_old-V[s])); // 変更量のチェック } //状態価値関数の表示 for(int i=0; i<101; i++){ cout << V[i] << ", "; } cout << endl << endl; } // 最適方策の更新 double threshold = 1e-5; for(int s=1; s<100; ++s){ double cand = 0; double tmp; for(int bet=1; bet<=min(s, 100-s); ++bet){ tmp = p*V[s+bet]+(1-p)*V[s-bet]; if(tmp > cand + threshold){ cand = tmp; pi[s] = bet; } } } // 最適方策の表示 for(int i=1; i<100; i++){ cout << pi[i] << ", "; } cout << endl << endl; }
結果
まず各賭け金における勝率のグラフを示します。(プロットはPythonでやりました。ソースは後述)
更新を繰り返すと収束していきます。所持金額が増えるほど勝率が増えるという妥当な結果になっているのが分かります。
つぎに、所持金ごとの最適な賭け金の額です。
とても面白い結果になっています。50ドルの時には全額を賭けるのが良いのに対し、49ドルや51ドルの時には1ドルだけ賭けるのが最も良いというのは不思議です。
「51ドルの時に1ドルだけ賭けるのが最適になっているのはなぜか?」というのは演習問題になっていたのですが、僕はいくら考えても分かりませんでした。解答にあったことを要約すると
- 表が出る確率は0.4なのでコイントスの回数が増えるほどギャンブラーにとって基本的には不利になる
- 50ドル持っているときに全額賭けると勝利に必要なコイントス回数が1回で済むためこれが最適戦略になる。
- 51ドル持っているときに1ドル賭けて負けても、50ドルという勝率0.4で勝てる状態になるだけで済む。
- 51ドル持っているときに1ドル賭けて勝てば所持金が増える。所持金が75ドルを超えれば、25ドル賭けることで勝率0.4で勝利し、負けても50ドル以上手元にあるのでまだ勝率は0.4という有利な状態になる。
- 51ドルの時に49ドル賭けると勝率は0.4だが、1ドルの賭けなら負けても勝率0.4、勝てば所持金が増えてノーリスクでの所持金75ドル超えに挑むことができる。
- よって1ドル賭けるのが最適な戦略である
自分では思いつきませんでしたが、賢い人がいるものです。そして、その賢い戦略を発見できるアルゴリズムの優秀さも分かります。
いろいろ実験する
さて、ソースコードを真面目に読んだ人が仮にいたとすると、最適行動を決めるところでthresholdという怪しい変数で最適行動の更新に制限をかけていることに気づいたと思います。
実は、本に書いてある条件の通りにいろいろ試してもどうしてもさっきのような綺麗なグラフになりませんでした。諦めてホームページで公開されているソースコードを見たところ、更新するときに今までの行動よりもthreshold(このプログラムでは0.00001)だけ大きくないと更新しないという処理が追加されていました。
このthresholdを0にした結果がこれです。
さっきのきれいなグラフとはだいぶ変わってしまいました。51ドルの場合の戦略は不変ですが、49ドルなどの場合の戦略が変更されています。0.00001未満の差なのであまり気にすることはないかもしれませんが、この方がほんの少し有利な戦略になっているはずです。
以下thresholdは0.00001とします。
次に胴元がイカサマをして表が出る確率が0.2だったときのグラフを示します。
確率が0.4のときのグラフと変わりません。コイントスをすればするほど不利になるという条件が変わらないことが原因だと思われます。
次は胴元が確率設定を間違えてしまい、表が出る確率が0.6になった場合のグラフを示します。
コイントスの回数が多いほど有利になるので、1ドルずつ賭けてなるべくコイントスの回数を増やす戦略を取るのが有効なようです。
最後に一般的な設定として表が出る確率が0.5, threshold=0の時のグラフを示します。
50ドル以下の時は全額賭ける、50ドル以上の場合は75ドルを狙うか、100ドルぴったりになるような金額を賭けるという戦略が有効なようです。
このグラフをよく見るとところどころ変な凹みがあることに気が付きます。もしかしたらまだ収束していないのかもしれないと思って、終了条件thetaを小さくして答えの変化がなくなるまで回したグラフがこれです。
ループ回数を増やすと、基本的に1ドル賭けるのが有利・50ドルなどの特異な場所では少し多く賭けるのが良いという結論になるようです。終了条件を1e-5からさらに減らしたことによる結果なので、どちらの戦略をとっても勝率には大きな違いはないでしょう。微小な終了条件の違いで全く違う戦略を取ることになったのでこのアルゴリズムによる戦略決定は結構不安定なのかもしれません。
まとめ
強化学習でギャンブラーがコイントスゲームに挑んだときにいくら賭けるべきかという問題を解いてみました。得られた結論は面白いものでしたが、条件を少し変えるだけで大きく変わってしまう不安定な解である可能性が捨て切れませんでした。
もう少し読み進めて何か良い方法があったらまた実装してみようと思います。
おまけ
プロットに使ったプログラムを置いておきます。
プログラムの出力値をコピペしている、手抜きなプログラムですがご容赦ください。
状態価値関数のプロット
import pylab import argparse import json x = range(0, 101) repeat1 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.64, 0.784, 0.784, 0.784, 0.784, 0.784, 0.784, 0.8704, 0.8704, 0.8704, 0.92224, 0.92224, 0.953344, 1] repeat2 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.16, 0.256, 0.256, 0.256, 0.256, 0.256, 0.256, 0.3136, 0.3136, 0.3136, 0.34816, 0.34816, 0.368896, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4096, 0.4096, 0.4096, 0.44416, 0.464896, 0.477338, 0.496, 0.496, 0.496, 0.50176, 0.50176, 0.534938, 0.5536, 0.5536, 0.557056, 0.58816, 0.590234, 0.61014, 0.64, 0.64, 0.64, 0.64, 0.647834, 0.666496, 0.686403, 0.6976, 0.701056, 0.720963, 0.73216, 0.752896, 0.766084, 0.784, 0.784, 0.799898, 0.81856, 0.832578, 0.851738, 0.8704, 0.879939, 0.899547, 0.92224, 0.939728, 0.963837, 1] repeat3 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.064, 0.064, 0.064, 0.064, 0.064, 0.064, 0.1024, 0.1024, 0.1024, 0.12544, 0.12544, 0.139264, 0.16, 0.16, 0.16, 0.16, 0.16384, 0.177664, 0.190935, 0.1984, 0.200704, 0.213975, 0.22144, 0.235264, 0.244056, 0.256, 0.256, 0.266598, 0.27904, 0.288385, 0.301158, 0.3136, 0.319959, 0.333031, 0.34816, 0.359819, 0.375891, 0.4, 0.4, 0.4, 0.4, 0.40384, 0.417664, 0.430935, 0.4384, 0.440704, 0.453975, 0.46144, 0.475264, 0.484056, 0.496, 0.496, 0.506598, 0.51904, 0.528385, 0.541158, 0.5536, 0.559959, 0.573031, 0.58816, 0.599819, 0.615891, 0.64, 0.64, 0.643133, 0.658561, 0.664422, 0.676864, 0.690434, 0.6976, 0.711424, 0.724695, 0.735975, 0.752896, 0.769535, 0.784, 0.795137, 0.806118, 0.81856, 0.834817, 0.851738, 0.8704, 0.883671, 0.90089, 0.92224, 0.940534, 0.96432, 1] end = [0, 0.00206467, 0.00516393, 0.00922521, 0.0129101, 0.0173852, 0.0230635, 0.027814, 0.0322754, 0.0376844, 0.0434633, 0.0503542, 0.0576592, 0.0652388, 0.0695351, 0.0744311, 0.0806884, 0.0866106, 0.0942125, 0.103143, 0.108659, 0.115966, 0.125886, 0.13358, 0.144148, 0.16, 0.163098, 0.167746, 0.173838, 0.179365, 0.186078, 0.194595, 0.201721, 0.208413, 0.216527, 0.225195, 0.235532, 0.246489, 0.257859, 0.264303, 0.271647, 0.281033, 0.289916, 0.301319, 0.314715, 0.322988, 0.33395, 0.348829, 0.36037, 0.376222, 0.4, 0.403098, 0.407746, 0.413838, 0.419365, 0.426078, 0.434595, 0.441721, 0.448413, 0.456527, 0.465195, 0.475532, 0.486489, 0.497859, 0.504303, 0.511647, 0.521033, 0.529916, 0.541319, 0.554715, 0.562988, 0.57395, 0.588829, 0.60037, 0.616222, 0.64, 0.644648, 0.651619, 0.660757, 0.669048, 0.679117, 0.691893, 0.702582, 0.71262, 0.724791, 0.737793, 0.753298, 0.769733, 0.786789, 0.796454, 0.80747, 0.821549, 0.834875, 0.851979, 0.872073, 0.884482, 0.900925, 0.923244, 0.940555, 0.964333, 1] pylab.xlabel("money") pylab.ylabel("value") pylab.plot(x, repeat1, label="1st repeat") pylab.plot(x, repeat2, label="2nd repeat") pylab.plot(x, repeat3, label="3rd repeat") pylab.plot(x, end, label="last repeat") pylab.legend(loc = 'lower right') pylab.show()
最適方策関数のプロット
import pylab import argparse import json x = range(1, 100) y = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 25, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 50, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 25, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] pylab.ylim([0,50]) pylab.xlabel("money") pylab.ylabel("bet") pylab.bar(x, y) pylab.show()