Python
Numba
12
どのような問題がありますか?

この記事は最終更新日から3年以上が経過しています。

投稿日

更新日

PythonをNumbaでAOT (事前) コンパイルして高速化する

Pythonは遅いと言われていますが、Pythonを高速化する方法はたくさんあります。その一つである Numba で、AOT(事前)コンパイルして高速化する方法を紹介します。

Numba は、Python のコードに@jitのデコレータを追加するだけで簡単に高速化できる JIT コンパイラーとして知られていますが、AOT コンパイルも可能です。AOT を使うメリットは起動時間が速くなることです。Python の場合は、起動時間を気にするようなケースでは使うことは少ないので、AOT を使う機会はあまりないと思いますが、例えば、Cloud Functions や AWS Lambda のサーバーレスクラウドを使う場合のように起動回数が多い場合にはかなり有効な手段になります。

AOT で使うためのコード

通常のコンパイラー言語だとコンパイルコマンドを使ってソースコードを総てコンパイルするのですが、Numba の場合は、コード全体をコンパイルするのではなくて、高速化が必要な関数のみ、すなわち @cc.export デコレータが追加されている関数だけをコンパイルします。

Numba の AOT の使い方は、公式ドキュメントの 2.3. Ahead-of-Time compilation にあります。概要としては、次のようなコードを書いて、

my_module.py
from numba.pycc import CC
from numba import njit

cc = CC('my_module')

@cc.export('square', 'f8(f8)')
def square(a):
    return a ** 2

if __name__ == "__main__":
    cc.compile()

my_code.py を実行すると、CC('my_module') で設定したモジュール名、この場合で Linux の場合だと my_module.cpython-37m-x86_64-linux-gnu.so という名前のバイナリーファイルができます。後は、そのモジュールを普通に import して使うことができます。

main.py
from my_module import square

print(my_module.square(3))

ベンチマークの結果

以前のQiitaの記事のコードを以下のように、Numba, PyArrow を使って書き換えたものを使いました。データ件数は100万件です。書き換えたことで、以前は、CSVの読み込み以降の処理で約0.42sec かかっていたものが、約0.16secまで高速化することができました。csv の読み込みに PyArrow を使ったことで読み込みが並列化されたこと、Pandas でも文字列をカテゴリー型にしたことの影響が大きいです。

項目 CSVの読み込みを含む  CSVの読み込み後 起動時間を含む実行時間
Numba JIT
型指定無し
288 ms  175 ms 951 ms
Numba JIT
型指定有り
162 ms 49 ms 948 ms
Numba AOT    158 ms 49 ms 511 ms
C# LINQ 861 ms 213 ms 923 ms

これでわかることは、AOT にすると Numba のインポートとコンパイルの時間が不要になるので、約440ms ぐらい起動時間が削減されます。また、上記のJITの起動時間は、Numba がページキャッシュにのっている状態の起動時間なので、ディスクから読み込まないといけない場合には、自分のPCでは2秒ぐらい、GCP の VM の場合だと5秒ぐらいかかります。Cloud Functions や AWS Lambda を使う場合には、AOT が有効な手段になるのがよくわかると思います。

また、JIT の場合は、型指定をするかどうかではトータルの時間は約950msでかわらないということです。一方で、プログラムの中の関数の実行時間の測定では差が出ています。型指定をしている場合は、インポート時にコンパイルが実行され、型指定をしていなければ実際に引数が渡された時点でコンパイルが実行されていると思われます。JITで使う場合には、型指定をする必要はなさそうです。

なお、C#については Numba を使った処理が結構速いということを示すために参考に載せています。

ベンチマークの実行環境及びコードについて

CPU: Intel Core i7-7700 CPU @ 3.60GHz
Python 3.7、.net core 2.2

使用したデータを作成するコードは、以下のページにあります。
サンプルデータ

Python のコード

main.py
import numpy as np
import pandas as pd
from pyarrow import csv
import my_module


def groupby(table):
    df = table.to_pandas(strings_to_categorical=True)
    a_codes = np.array(df['a'].cat.codes, dtype='int64')
    z = np.zeros(len(df['a'].cat.categories), dtype='int64')
    my_module.groupby_core(a_codes, df.x.values, df.y.values, z)
    df_groupby_a = pd.DataFrame({'a': df['a'].cat.categories, 'z': z})
    df_groupby_a.to_json('groupby_a.json', orient='records')


def main():
    table = csv.read_csv('data/test.csv')
    groupby(table)


if __name__=="__main__":
    main()
my_module.py
rom numba import njit
from numba.pycc import CC

cc = CC('my_module')

@cc.export('groupby_core', 'void(i8[:],f8[:],f8[:],i8[:])')
@njit
def groupby_core(a, x, y, z):
    for i in range(len(a)):
        if x[i] > 0:
            z[a[i]] += int(x[i] * y[i] + 0.0000001)
        else:
            z[a[i]] += int(x[i] * y[i] - 0.0000001)

if __name__ == "__main__":
    cc.compile()

C# のコード

Program.cs
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Newtonsoft.Json;

namespace csharp_test
{
    class Program
    {
        static void Main(string[] args)
        {
            var linqTest = new LinqTest();
            linqTest.Load();
            linqTest.ByLinq();
        }
    }

    class LinqTest{
        private List<TestData> TestDataList { get; set;}

        internal void Load()
        {
            TestDataList = File.ReadLines("../data/test.csv")
                .AsParallel()
                .Skip(1)
                .Select(line =>
                {
                    var columns = line.Split(',');
                    return new TestData
                    {
                        a = columns[0],
                        b = columns[1],
                        x = double.Parse(columns[2]),
                        y = double.Parse(columns[3])
                    };
                }).ToList();
        }

        internal void ByLinq()
        {
            var data = TestDataList
                .GroupBy(
                    d => d.a,
                    d => MultiplyToInt(d.x, d.y))
                .Select(g => new { a = g.Key, z = g.Sum(d => d) });

            File.WriteAllText("result.json", JsonConvert.SerializeObject(data));
        }
    }

    public class TestData
    {
        public string a { get; set; }
        public string b { get; set; }
        public double x { get; set; }
        public double y { get; set; }
    }
}
ユーザー登録して、Qiitaをもっと便利に使ってみませんか。
  1. あなたにマッチした記事をお届けします
    ユーザーやタグをフォローすることで、あなたが興味を持つ技術分野の情報をまとめてキャッチアップできます
  2. 便利な情報をあとで効率的に読み返せます
    気に入った記事を「ストック」することで、あとからすぐに検索できます
ユーザー登録ログイン
yniji

コメント

この記事にコメントはありません。
あなたもコメントしてみませんか :)
ユーザー登録
すでにアカウントを持っている方はログイン
記事投稿イベント開催中
Go強化月間~開発する上で知っておくべき知見を共有しよう~
~
エンジニア夏休み企画!~自由研究や読書感想文を発表しよう~
~
12
どのような問題がありますか?
ユーザー登録して、Qiitaをもっと便利に使ってみませんか

この機能を利用するにはログインする必要があります。ログインするとさらに下記の機能が使えます。

  1. ユーザーやタグのフォロー機能であなたにマッチした記事をお届け
  2. ストック機能で便利な情報を後から効率的に読み返せる
ユーザー登録ログイン
ストックするカテゴリー