Python の新しい行列積演算子 (PEP 465: Matrix Multiplication Operator)

Python 3.5 で np.dot(A, B) を A @ B と書いて動かす話。

PEP 465 “PEP 465 matrix multiplication operator” が 2014 年 4 月 7 日に Accept されました 。これによって新しい binary operator として @ operator とその in-place 版である @= operator が新たに導入され、行列演算を簡潔に表現できるようになります。 CPython の default branch (3.5 の開発ブランチ) に実装が既にマージされています 。CPython 3.5 でリリースされる見込みですが、 numpy の行列表現のデータ構造でのサポートはまだ未完というステータスです。Issue は ENH: Implement matmul function · Issue #4464 · numpy/numpy · GitHub っぽい。待ち遠しいですね。

ところで ndarray には View Casting という機構 があり、 ndarray instance の View を subclass として定義できます。 View Casting で __matmul__ を暫定的に追加すれば @-operator を簡単に試すことが出来そうだと思いついたので、実際にやってみました。 例として linear regression の weight parameter を最小二乗法で求めてみます。

行列表現すると以下:

\hat{w} = (X^T X)^{-1} X^T y

行列積が出てきました。@-operator の出番です。View を定義して @-operator を試してみます:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
Python 3.5.0a0 (default, May 31 2014, 13:34:58)
[GCC 4.8.1] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>
>>> import numpy as np
>>> from numpy.linalg import inv
>>>
>>> class C(np.ndarray):
...   def __matmul__(self, other):
...     return self.dot(other)
...
>>> X = np.random.rand(4,3).view(C)
>>> y = np.random.rand(4).view(C)
>>>
>>> w = inv(X.T @ X) @ X.T @ y
>>> w
C([-0.15002047,  0.99972529,  0.32114514])

すばらしい。np.dot で書いて比較してみます。

1
2
3
>>> w = np.dot(np.dot(inv(np.dot(X.T, X)), X.T), y)
>>> w
C([-0.15002047,  0.99972529,  0.32114514])

もう慣れてしまいましたが、式が大きくなるとつらさが増します。 研究などでここがバグると非常につらい。わかりやすさは重要です。 抽象構文木も確認してみます。

1
2
3
>>> import ast
>>> ast.dump(ast.parse('inv(X.T @ X) @ X.T @ y'))
"Module(body=[Expr(value=BinOp(left=BinOp(left=Call(func=Name(id='inv', ctx=Load()), args=[BinOp(left=Attribute(value=Name(id='X', ctx=Load()), attr='T', ctx=Load()), op=MatMult(), right=Name(id='X', ctx=Load()))], keywords=[], starargs=None, kwargs=None), op=MatMult(), right=Attribute(value=Name(id='X', ctx=Load()), attr='T', ctx=Load())), op=MatMult(), right=Name(id='y', ctx=Load())))])"

見づらいので結果を手動で整形:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
>>> ast.dump(ast.parse('inv(X.T @ X) @ X.T @ y'))
"Module(body=
[
  Expr(
    value=BinOp(
      left=BinOp(
        left=Call(
          func=Name(id='inv', ctx=Load()),
          args=[
            BinOp(
              left=Attribute(
                value=Name(id='X', ctx=Load()),
                attr='T',
                ctx=Load()
              ),
              op=MatMult(),
              right=Name(id='X', ctx=Load()))
          ],
        keywords=[], starargs=None, kwargs=None
      ),
      op=MatMult(),
      right=Attribute(
        value=Name(id='X', ctx=Load()),
        attr='T',
        ctx=Load()
      )
    ),
    op=MatMult(),
    right=Name(id='y', ctx=Load())
  )
)
])"

BinOp(left=Xの転置, op=MatMult, right=X) など期待した抽象構文木が得られました(動作を確認しているのであたりまえですが)。 数値計算の主要パッケージはだいたい Python3 compatible なので、数値計算をする人たちはさっさと Python3 へ移行しているという印象があります。 PEP 465 は Python 3.5 以降の機能となるため、完全な Python3 への移行が加速するかもしれません。知らんけど。