今回はAnnoyというライブラリを使って、Pythonで簡単に近似最近傍探索を行う方法について説明します。
近似最近傍探索は類似画像検索などに用いられる技術です。
類似画像検索は「特徴量抽出」と「特徴量の類似度計算」を組み合わせることで実現されます。CNNなどを使って得られた得られた特徴量を元に、特徴量の類似度計算を行い、特徴量が類似しているものを抽出することで検索を実現します。必ずしも同じ固有の物体を見つけるのではなく同じ色調や形状など似ている画像の検索が可能となります。今回は特徴量の類似度計算を行う近似最近傍探索について書きます。
Annoyについて
AnnoyはSpotifyが開発する近似最近傍探索ライブラリです。実装はC++ですが、PythonバインディングなのでPythonから簡単に使うことができ、なおかつ高速に動作します。比較記事によると数ある近似最近傍探索ライブラリの中でもかなり高速に動作するみたいです。
Spotifyでは音楽のレコメンド機能に用いられているそうです。
近似最近傍探索
最近傍探索は与えられた特徴量から最も近いものをみつけるアルゴリズムですが、データ量と特徴量の次元数によって計算量が増えていくため、計算コストが膨大になります。これでは実装に組み込むには難しいです。
そこで、近似最近傍探索は厳密に最近傍を探索するのではなく、近似的に最近傍を見つけることでエラーを許容する代わりに計算量を削減し、高速に探索を行います。
Annoyはこの近似最近傍探索の中でもANNというアルゴリズムを実装したライブラリになります。アルゴリズムの詳細についてはこの記事では割愛します。
AnnoyのAPI
サンプルを用いてAPIの使い方を見ていきます。ここではなんとなくAPIの使い方がわかればいいです。
from annoy import AnnoyIndex
import random
f = 40
t = AnnoyIndex(f) # 与える特徴量のベクトルの次元数を渡します
for i in xrange(1000):
v = [random.gauss(0, 1) for z in xrange(f)]
t.add_item(i, v) # モデルにデータをインデックスしていきます
t.build(10) # ビルドします。これ以降データのインデックスは行えません。
t.save('test.ann') # モデルを保存することも可能です。
u = AnnoyIndex(f)
u.load('test.ann') # モデルを読み込むことも可能です。
print(u.get_nns_by_item(0, 1000)) # インデックス0付近の1000個のデータの返します
手順
1. モデルの構築
model = AnnoyIndex(f, metric='angular')
扱う次元数(f)を与えてモデルを宣言します。次元数は1000以下だといいかんじになることが多いみたいです。1000以上だとあまりよくはないみたいです。metricにはangular、euclidean、 manhattan 、hammingなどの尺度を設定できます。
2. trainデータをモデルに追加する
model.add_item(i, v)
インデックス(i)とベクトル(v)をモデルに追加します。
3. モデルをビルドする
model.build(n_trees)
モデルをビルドします。これ以降モデルにデータを追加することはできません。n_treesは設定する必要があるパラメータです。これにより木の構造が決まり、探索の精度やビルドの速度が変化するので、いくつか試すといいです。値が大きくなると精度が上がる分、ビルドが遅いです。
4. モデルを保存と読み込み
model.save(file_name)
model.unload()
model.load(file_name)
ビルドしたモデルは保存したり、後から読み込んだりできます。
5. 検索
# ベクトルvを与えると、 近傍n個のアイテムを取り出せます。これをよく使うと思います。
# include_distancesはTrueで2地点間の距離を含めます。
model.get_nns_by_vector(v, n, search_k=-1,include_distances=False)
# イデンックスiの近傍n個のアイテムを取り出せます。
model.get_nns_by_item(i, n, search_k=-1, include_distances=False)
# インデックスiのデータを取り出せます。
model.get_item_vector(i)
# インデックスi,jのデータの距離を取り出せます。
model.get_distance(i, j)
# インデックスされているアイテム数を取り出せます
model.get_n_items()
search_kはデフォルトでn * n_treesで設定されます。search_kは検索時間に影響があります。大きい方が精度がいいかわりに検索速度が遅いので、いくつか試すといいと思います。
MNISTで近似最近傍探索
実際にMNISTに対してAnnoyを使ってみます。MNISTは手書き文字データセットです。機械学習では定番のデータセットなので使ったことがある人が多いと思います。
今回使用するライブラリです。Chainerはデータセット取得に便利なので使います。全てpipで入ります。
[packages]
jupyter = "*"
numpy = "*"
matplotlib = "*"
annoy = "*"
chainer = "*"
まずデータセットをダウンロードします。
%matplotlib inline
import matplotlib.pyplot as plt
from chainer.datasets import mnist
from chainer.datasets import split_dataset_random
train_val, test = mnist.get_mnist(withlabel=True, ndim=1)
train, valid = split_dataset_random(train_val, 50000, seed=0)
train[0][0].shape
# (784,)
モデルを構築し、trainデータを追加していきます。その後、ビルドします。
dim = 784
model = AnnoyIndex(dim)
for i in range(len(train)):
x, t = train[i]
model.add_item(i, x)
model.build(30)
model.save("mnist-30tree.ann")
# True
とりあえず、近傍1個取り出してみます。
import random
for i in range(10):
index =random.randint(0, len(test))
x, t = test[index]
predict_index = model.get_nns_by_vector(x, 1)
predict = train[predict_index][0]
print(f"最近傍ラベル:{predict[1]}")
print(f"正解ラベル :{t}")
print("-"*40)
# 最近傍ラベル:4
# 正解ラベル :4
# ----------------------------------------
# 最近傍ラベル:2
# 正解ラベル :2
# ----------------------------------------
# 最近傍ラベル:3
# 正解ラベル :3
# ----------------------------------------
# 最近傍ラベル:6
# 正解ラベル :6
# ----------------------------------------
# 最近傍ラベル:4
# 正解ラベル :4
# ----------------------------------------
# 最近傍ラベル:3
# 正解ラベル :3
# ----------------------------------------
# 最近傍ラベル:3
# 正解ラベル :3
# ----------------------------------------
# 最近傍ラベル:2
# 正解ラベル :2
# ----------------------------------------
# 最近傍ラベル:0
# 正解ラベル :0
# ----------------------------------------
# 最近傍ラベル:3
# 正解ラベル :3
# ----------------------------------------
最近傍の精度も見て見ます。本来は類似画像の検索が目的なので、特定の正解はないと思うのですが、一応ラベルをもとに判定してみます。
count = 0
for i in range(len(test)):
x, t = test[i]
predict_index = model.get_nns_by_vector(x, 1, search_k=100)
predict = train[predict_index][0]
if predict[1] == t:
count += 1
print(count / len(test))
# 0.9645
うまくいってると思うので、近傍5個を出力してみます。
import random
import numpy as np
rows_count = 10
columns_count = 6
images_count = rows_count * columns_count
axes = []
fig = plt.figure(figsize=(15, 15))
for i in range(10):
index =random.randint(0, len(test))
x, t = test[index]
answer_ax = fig.add_subplot(rows_count, columns_count, i*columns_count + 1)
answer_ax.imshow(x.reshape(28, 28), cmap="gray")
plt.axis("off")
answer_ax.set_title("Correct Answer")
predict_indexes = model.get_nns_by_vector(x, 5, search_k=150)
for j, predict_i in enumerate(predict_indexes):
predict_x, predict_t = train[predict_i]
ax = fig.add_subplot(rows_count, columns_count, i*columns_count + 2+ j)
ax.imshow(predict_x.reshape(28, 28), cmap="gray")
plt.axis("off")
fig.subplots_adjust(wspace=0.2, hspace=0.2)
plt.tight_layout()
plt.savefig("mnist-30-tree.png")
plt.show()
いい感じに似ている画像が表示されていると思います。
つぎはカラー画像で試して見ます。
cifar10で近似最近傍探索
カラー画像のデータセットには小規模ですが、cifar10を使います。今回は特徴量抽出を行わずに直接画像をAnnoyに入れるのであまり次元数が大きくならないようにcifar10を選びました。それでも次元数は3072(3×32×32)でかなり多いですが。
MNISTと同様の手順で試していきます。
train, test = chainer.datasets.get_cifar10()
print(train[0][0].shape)
# (3, 32, 32)
from annoy import AnnoyIndex
dim = 3 * 32 * 32
annoy_model = AnnoyIndex(dim)
for i in range(len(train)):
x, _ = train[i]
annoy_model.add_item(i, x.reshape(-1))
annoy_model.build(80)
annoy_model.save("cifar-80tree.ann")
近傍一個の精度を見ていきます。
count = 0
for i in range(len(test)):
x, t = test[i]
predict_index = annoy_model.get_nns_by_vector(x.reshape(-1), 1)
predict = train[predict_index]
if predict[1] == t:
count += 1
print(count / len(test))
# 0.3618
いまいちですね。類似画像を検索しているので画像を見てみないと、判断できないので画像を抽出します。MNISTと同様に近傍5個で試します。
import random
import numpy as np
rows_count = 10
columns_count = 6
images_count = rows_count * columns_count
axes = []
fig = plt.figure(figsize=(15, 15))
for i in range(10):
index =random.randint(0, len(test))
x, t = test[index]
answer_ax = fig.add_subplot(rows_count, columns_count, i*columns_count + 1)
answer_ax.imshow(x.transpose(1, 2, 0))
plt.axis("off")
answer_ax.set_title("Correct Answer")
predict_indexes = annoy_model.get_nns_by_vector(x.reshape(-1), 5, search_k=150)
for j, predict_i in enumerate(predict_indexes):
predict_x, predict_t = train[predict_i]
ax = fig.add_subplot(rows_count, columns_count, i*columns_count + 2+ j)
ax.imshow(predict_x.transpose(1, 2, 0))
plt.axis("off")
fig.subplots_adjust(wspace=0.2, hspace=0.2)
plt.tight_layout()
plt.savefig("cifar10-80tree.png")
plt.show()
なんとなく雰囲気の似ている画像が抽出できている気がします。次元数が3000を超えているので、CNNなどの特徴量抽出などで次元数を削減すればもっと良い結果がでると思います。
おわりに
Annoyは「Approximate Nearest Neighbors Oh Yeah」の略になるのですが、ゆるい名前のわりに強力なライブラリだと思います。
最近はプログラミングをPythonから始める人も多いみたいですが(自分もそうだった)、Pythonが使えれば他の言語に慣れていなくてもAnnoyを使って高速な近似最近傍探索を簡単に実装に組み込めます。近似最近傍探索の詳しいアルゴリズムを理解していなくても、簡単に使えるので役に立つ場面もあると思います。使ってみてください。
次はDeep Learningで特徴量抽出と組み合わせた記事を書こうと思います。