Python
画像処理
機械学習
バグ
データ分析

Pythonで機械学習やデータ解析やってる人はバグを無くすためにassert文を使うべき

背景

研究で画像を扱っていたとき、

  • 次元の違い
  • 正規化する範囲の違い

がバグに繋がるケースが数多くありました。

自分の中でルールを決めていたとしても、

  • 未来の自分が覚えてるかわからない
  • 他の人とルールが同じとは限らない
  • 極力バグを減らしたい

ので最近はassert文を多用しています。

assert文

Pythonの公式リファレンスによると

assert <条件式>

assert <条件式>, <エラーの時に表示したい文字列>

となります。

参考:
7.3. assert 文

用例

機械学習で何かしらの画像分類モデルを作ろうと思ったとき、入力画像を0~1に正規化する必要があるのに忘れていた、なんてことないでしょうか。

「値が0~1になってると思ってたけど、実際は0〜255だった〜〜!!」ってのはよく聞く話です。

assert文が1行だけでも入っていたら、これは防げます。

import numpy as np
from PIL import Image


img = np.array(Image.open("sample.png"))
img = img / 255.0

# 0~1に正規化されててほしい!(正規化されてなければ、エラーが出て処理が止まる)
assert (0.0 <= img).all() and (img <= 1.0).all()

# 画像が3次元(width, heigt, channels)なのを確かめたい場合
assert img.ndim == 3

# widthとheightは同じ値であってほしい! エラー文のカスタマイズはこんな感じ。
assert img.shape[0] == img.shape[1], "widthとheightの値は同じにしてね"

# チャンネル数はRGBの3チャンネルじゃなくてRGBAの4チャンネルであってほしい!
assert img.shape[2] == 4

まとめ

assert文多用しましょう!

友達にコード見てもらうときも、assert文があると必要条件が分かるのでスムーズに理解できそう。

追記

皆さん、多くのいいねありがとうございます!!
補足ですが、エラーは下記のように出力されます。

import numpy as np

tmp = np.arange(10).reshape(2,5)
print("tmp.shape=",tmp.shape)
assert tmp.shape == (2,6)

print("End")

出力

tmp.shape= (2, 5)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-3-b4bb1727ab0e> in <module>()
      3 tmp = np.arange(10).reshape(2,5)
      4 print("tmp.shape=",tmp.shape)
----> 5 assert tmp.shape == (2,6)
      6
      7 print("End")

AssertionError:

エラーが出た時点で実行が止まるので、安心です。