1

この記事は最終更新日から1年以上が経過しています。

更新日

投稿日 2021年09月18日

class で jit を使用する方法

python高速化のためのjitをclassで使用したいとき
jitclassより@staticmethod を使ったほうがいいのではないかという個人メモ。

jitclass の使用例

jitclassを使えばクラスでもjitを使用することができる。

import numpy as np
import matplotlib.pyplot as plt
from numba import jit, f8
from numba.experimental import jitclass
import time
spec = [
    ('arr', f8[:,:,:]),               
    ('new_arr', f8[:,:,:]),        
]

@jitclass(spec)
class class_jit():
    def __init__(self,arr):
        self.arr = arr
        self.new_arr=np.zeros_like(arr, dtype=np.float64)
    def func(self):
        shape=self.arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    self.new_arr[x,y,z]=1e6*x+1e3*y+z
        return self.new_arr
    def plot(self):
        plt.imshow(self.new_arr[0])

if __name__=='__main__':
    n=256
    arr=np.zeros([n]*3)
    st=time.time()
    obj1=class_jit(arr)
    obj1.func()
    print(f'"class jit " elapsed_time {time.time()-st:2f}') 
"class jit " elapsed_time 0.512740

jitclass で困ったこと

jitclassではjit の対応していないメンバ関数を使用することができない。
例えばクラスのほかのメンバ関数でmatplotを使おうとすると

obj1.plot()

plt.imshowのところでエラーを返される。

TypingError: - Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'imshow' of type Module

@staticmethodを使って解決

jitを使用したいメンバ関数を@staticmethodで静的関数として、@jitを使用したら上手くいった。インスタンス変数を参照したい場合は、ラッパーにすればよい。

class class_jit2():
    def __init__(self):
        pass

    def func(self, arr): #wrapper
        self.arr=arr
        self.new_arr=self.__func(arr)

    @staticmethod
    @jit(f8[:,:,:](f8[:,:,:])) #型指定
    def __func(arr):
        new_arr=np.zeros_like(arr, dtype=np.float64)
        shape=arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    new_arr[x,y,z]=1e6*x+1e3*y+z
        return new_arr

    def plot(self):
        plt.imshow(self.new_arr[0])

if __name__=='__main__':
    n=256
    arr=np.zeros([n]*3)        
    st=time.time()
    obj2=class_jit2()
    obj2.func(arr)
    print(f'"class jit staticmethod " elapsed_time {time.time()-st:2f}') 
    obj2.plot()

これなら、jitmatplotも同一のクラスで使用できる。

"class jit staticmethod " elapsed_time 0.044842

image.png

申し訳程度の速度比較

import numpy as np
import matplotlib.pyplot as plt
from numba import jit, f8
from numba.experimental import jitclass
import time


def func(arr): #jitなし
    new_arr=np.zeros_like(arr)
    shape=arr.shape
    for x in range(shape[0]):
        for y in range(shape[1]):
            for z in range(shape[2]):
                new_arr[x,y,z]=1e6*x+1e3*y+z
    return new_arr

@jit(nopython=True) #jit 型指定なし
def func_jit(arr):
    new_arr=np.zeros_like(arr)
    shape=arr.shape
    for x in range(shape[0]):
        for y in range(shape[1]):
            for z in range(shape[2]):
                new_arr[x,y,z]=1e6*x+1e3*y+z   
    return new_arr

@jit(f8[:,:,:](f8[:,:,:])) #jit 型指定
def func_jit2(arr):
    new_arr=np.zeros_like(arr)
    shape=arr.shape
    for x in range(shape[0]):
        for y in range(shape[1]):
            for z in range(shape[2]):
                new_arr[x,y,z]=1e6*x+1e3*y+z   
    return new_arr


spec = [
    ('arr', f8[:,:,:]),               # a simple scalar field
    ('new_arr', f8[:,:,:]),          # an array field
]

@jitclass(spec)
class class_jit():
    def __init__(self,arr):
        self.arr = arr
        self.new_arr=np.zeros_like(arr, dtype=np.float64)
    def func(self):
        shape=self.arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    self.new_arr[x,y,z]=1e6*x+1e3*y+z
        return self.new_arr

class class_jit2():
    def __init__(self):
        pass

    def func(self, arr):
        self.arr=arr
        self.new_arr=self.__func(arr)

    @staticmethod
    @jit(f8[:,:,:](f8[:,:,:]))
    def __func(arr):
        new_arr=np.zeros_like(arr, dtype=np.float64)
        shape=arr.shape
        for x in range(shape[0]):
            for y in range(shape[1]):
                for z in range(shape[2]):
                    new_arr[x,y,z]=1e6*x+1e3*y+z
        return new_arr

if __name__=='__main__':

    n=256
    arr=np.zeros([n]*3)

    st=time.time()
    func(arr)
    print(f'"w/o_jit" elapsed_time {time.time()-st:2f}') 

    st=time.time()
    func_jit(arr)
    print(f'"jit + w/o type sepc elapsed_time {time.time()-st:2f}') 

    st=time.time()
    func_jit2(arr)
    print(f'"jit + type spec" elapsed_time {time.time()-st:2f}') 

    st=time.time()
    obj1=class_jit(arr)
    obj1.func()
    print(f'"class jit " elapsed_time {time.time()-st:2f}') 

    st=time.time()
    obj2=class_jit2()
    obj2.func(arr)
    print(f'"class jit staticmethod " elapsed_time {time.time()-st:2f}') 

"w/o_jit" elapsed_time 3.964147
"jit + w/o type sepc elapsed_time 0.163591
"jit + type spec" elapsed_time 0.047845
"class jit " elapsed_time 0.431932
"class jit staticmethod " elapsed_time 0.043882

@jitつけるだけで24倍
型指定すると83倍の高速化
(jitclass の速度がいまいちなのは謎。jitclassの型指定がうまくいってない...?)

新規登録して、もっと便利にQiitaを使ってみよう

  1. あなたにマッチした記事をお届けします
  2. 便利な情報をあとで効率的に読み返せます
ログインすると使える機能について

コメント

この記事にコメントはありません。
あなたもコメントしてみませんか :)
新規登録
すでにアカウントを持っている方はログイン
1