RustでCUDAカーネルを書く

  • 14
    いいね
  • 0
    コメント

数値計算をする以上GPUを使用できないというのは現代では致命的だ。
この記事ではCUDA経由でGPUをRustから使う方法についてまとめる。
特に今回はRustでCUDAカーネルを記述することに挑戦する。

Write CUDA kernel in Rust

Rustは2016/12よりNVPTXへのコンパイルに対応している。NVPTXとはLLVMの一部で、nvidia GPU用のアセンブラのようなものである。CUDAで書かれたkernelは一旦NVPTXに変換されて最適化された後、GPUの命令に変換される。

二つのベクトルA, Bを足してベクトルCを作る演算を考えよう。CUDAで書くと以下のようになる:

__global__ void
vectorAdd(const float *A, const float *B, float *C, int numElements) {
    int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i < numElements) {
        C[i] = A[i] + B[i];
    }
}

(CUDAの公式サンプルより拝借)
ベクトルのアドレスをもらい、block・threadのIDから自分の担当場所を決めてそこの加算を実行する。これは個々のthreadがする処理が記述してある。これを複数起動するのはカーネルを起動する側の役割である。これをRustで書いてみよう。

注意する点は2つ、__global__修飾子とthreadIdx.x等の組み込み値だ。
上記のCUDAコードがNVPTXにコンパイルされるとき、これらはそれぞれアノテーション特別なレジスタに変換される。RustのコードをNVPTXにコンパイルするにはこの2つを正しく指定する必要がある。RustコンパイラがNVPTXターゲットをサポートしたことで、これらを指定する方法が提供された。

まず__global__の方は簡単で、extern "ptx-kernel"修飾子を関数に付けることに相当する。

一方threadIdx.xの方は多少複雑である。NVPTXはLLVMの一部なので、同じように無限個のレジスタを持つ仮想マシン上で動作しているように記述される。threadIdx.xの値はある特殊なレジスタを読み取ることで得られるように定義されている。これはSIMD拡張用のレジスタ等と同じレベルでLLVMにターゲット固有機能として実装されている。LLVMの関数

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()

がこのレジスタから値を読み取って返してくれる。
Rustからこの関数を呼べるようにするにはplatform-intrinsic機能を使う。残念ながらまだunstableである。これによりextern "platform-intrinsic"修飾された関数をLLVMの該当する関数に置き換えることができる。これはSIMD拡張のために開発されている機能である。
https://github.com/rust-lang/rust/issues/27731

このextern "platform-intrinsics"関数を通常のRust関数にマップしてくれいてるのがnvptx-builtinsで、これを使用するとついにRustでカーネルが書けるようになる。

#![feature(abi_ptx)]
#![no_std]

extern crate nvptx_builtins as intrinsics;

#[no_mangle]
pub unsafe extern "ptx-kernel"
fn add(a: *const f32, b: *const f32, c: *mut f32, n: usize) {
    let i = intrinsics::block_dim_x()
        .wrapping_mul(intrinsics::block_idx_x())
        .wrapping_add(intrinsics::thread_idx_x()) as isize;

    if (i as usize) < n {
        *c.offset(i) = *a.offset(i) + *b.offset(i);
    }
}

https://github.com/japaric/nvptx/blob/master/kernel/src/lib.rs
wrapping_*に目をつぶれば元のCUDAのコードとほぼ同じなのが見て取れるだろう。

コンパイル方法についてはjaparic/nvptxを参照してね(疲れたから略)

なお、NVPTXターゲットはx86やARMとかと同じ扱いなので、一つのcrateをNVPTXにコンパイルするという方式をとる。つまり、あるプロジェクトで必要なカーネルは本体のプロジェクトとは別のcrateとして扱う必要があるということだ。これは*-sysのようなcrateを分けて作るRustの文化ではさほど障害にはならないだろう。

以上でRustからPTXを生成することに成功した。しかしPTXは単独では実行することはできない。
これを実行するにはPTXをGPUバイナリにJITコンパイルしたうえでCUDA Driver APIに含まれるcuLaunchKernel関数を使うが、それについては次回の記事でまとめることにする。

References