Hatena::ブログ(Diary)

shi3zの長文日記 RSSフィード

2013-05-02

よくわかる最適化

https://fbcdn-sphotos-a-a.akamaihd.net/hphotos-ak-ash3/577651_10152310170600752_582578387_n.jpg

 enchantMOONのソフト開発は専らMacで行っています。

 ついに僕もソースコードを確認しないと気が済まなくなりました。本当にちゃんと最適化してるのか、自分の目と頭とで確認しています。


 もっと高速にできる方法があったとしても、それを試すと今度は不安定になったり、最適化は本当にギリギリのバランスの取り合いです。


 最初のリリースで落ちまくるのは嫌なので、少し保守的な設定にすることにします。


 けど、それで遅くなってしまうのもまた嫌なので、高速度カメラで何度も撮影しながら、何度も何度もセッティングを試します。


 これはしかし本当に終わりなき果てしない世界です。

 ちょっとenchantMOONでどのような最適化をしているか簡単に紹介しましょう。


 例えば入力された情報を、完全にそのまま取得するとやはりどうしても粗くなる部分があるので、入力した情報に補間を掛けるわけですが、この補間を掛ける関数にもいくつかの種類があります。


 パッと思いつくだけでも、ベジェ補間、スプライン補間などがあり、たとえば人気の高いベジェ補間の場合は、滑らかな線が得られるものの、得られた情報の半分を喪ってしまいます。


 全ての点を通るスプライン補間を使おうとすると、スプライン補間だけでも沢山の補間方式があります。

 ポピュラーなCatmull Romスプライン曲線の式表現はこんな感じです。



http://upload.wikimedia.org/math/f/d/1/fd17d4ec345a67bc5520b45a641f28ec.png

出典:http://en.wikipedia.org/wiki/Cubic_Hermite_spline



 これはエルミート曲線の一種で、Pixarの創業者であるエド・カットマルとラフェエル・ロムが考案した公式です。

 これをプログラムで表現すると

    public static float catmullRom(float p0, float p1, float p2, float p3, float t){
        float v0 = (p2 - p0) / 2.0f ;
        float v1 = (p3 - p1) / 2.0f ;
        return       ((2.0f * p1 - 2.0f * p2) + v0 + v1) * t * t * t + 
                        ((-3.0f * p1 + 3.0f * p2) - 2.0F * v0 - v1) * t * t + v0 * t + p1;
    }

 こんな感じになります。

 この関数は、一見すると無駄がないように見えますが、実は無駄の塊です。

 たとえばコンピュータは割り算が苦手です。なのに二回も割り算をやっています。

 割り算よりは掛け算(乗算)の方が圧倒的に速い(下手すると10〜100倍くらい?)ので、まず割り算をしている部分を2.0の逆数の乗算に変更します。


    public static float catmullRom(float p0, float p1, float p2, float p3, float t){
        float v0 = (p2 - p0) * 0.5f ; //2で割るのは0.5を掛けるのと同じ
        float v1 = (p3 - p1) * 0.5f ;
        return       ((2.0f * p1 - 2.0f * p2) + v0 + v1) * t * t * t + 
                        ((-3.0f * p1 + 3.0f * p2) - 2.0f * v0 - v1) * t * t + v0 * t + p1;
    }

 次に、tをたくさん掛け算している部分があります。tの自乗ばかりで勿体ないので、tをまとめます。


    public static float catmullRom(float p0, float p1, float p2, float p3, float t){
        float v0 = (p2 - p0) * 0.5f ; //2で割るのは0.5を掛けるのと同じ
        float v1 = (p3 - p1) * 0.5f ;
        float t2 = t*t;   //tの自乗
        float t3 = t2*t;// tの三乗
        return       ((2.0f * p1 - 2.0f * p2) + v0 + v1) * t3 + 
                        ((-3.0f * p1 + 3.0f * p2) - 2.0f * v0 - v1) * t2 + v0 * t + p1;
    }

 これでtの掛け算の数が一回減りました(さっきは5回、今回は2回)。

 しかしこの式、よく見るとなにか変じゃないですか?

 そう、(2.0*p1 - 2.0*p2)と、二回も無駄な掛け算をしているのです。これは、式変形すると(p1-p2)*2.0にできますね。その下の(-3*p1+3*p2)も無駄です。

    public static float catmullRom(float p0, float p1, float p2, float p3, float t){
        float v0 = (p2 - p0) * 0.5f ; //2で割るのは0.5を掛けるのと同じ
        float v1 = (p3 - p1) * 0.5f ;
        float t2 = t*t;   //tの自乗
        float t3 = t2*t;// tの三乗
        return       ((p1 - p2)*2.0f + v0 + v1) * t3 + 
                        ((p2 - p1)*3.0f - 2.0f * v0 - v1) * t2 + v0 * t + p1;
    }

 これで最初は乗算11回、除算2回だったのが、乗算10回のみになりました。

 しかしこれは、単一の点を補間するだけの関数なので、実際には、これがx座標、y座標、そして圧力の三つのパラメータに関わることになります。

 この関数の呼び出し元はどうなっているかというと

            float d = distance(lastX[1], lastY[1], lastX[2], lastY[2]);
            int num = (int)Math.ceil((double)(d / 5.0d) + 0.5d);
            float x,y,p;
            for(int i = 0; i < num; i++){
                x = catmullRom(lastX[0], lastX[1], lastX[2], lastX[3], (float)i / (float)num);
                y = catmullRom(lastY[0], lastY[1], lastY[2], lastY[3], (float)i / (float)num);
                p = catmullRom(lastP[0], lastP[1], lastP[2], lastP[3], (float)i / (float)num);
            	stroke.add(x,y,p);
            }

 これは関数展開のチャンスです。

 まず、ここでもまたnumの除算を三回もしています。

 これは無駄なので、numの除算を外に出します。


            float d = distance(lastX[1], lastY[1], lastX[2], lastY[2]);
            int num = (int)Math.ceil((double)(d / 5.0d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            for(int i = 0; i < num; i++){
                x = catmullRom(lastX[0], lastX[1], lastX[2], lastX[3], (float)i*invertNum);
                y = catmullRom(lastY[0], lastY[1], lastY[2], lastY[3], (float)i*invertNum);
                p = catmullRom(lastP[0], lastP[1], lastP[2], lastP[3], (float)i*invertNum);
            	stroke.add(x,y,p);
            }

 3回もあった除算を乗算に変換することができました。これだけでかなりマシになります。

 が・・・あれ?しかしこここでinvertNumが何に使われているかに注目してください。

 これはループカウンタのiに乗じられているだけですね。

 ということは、そもそも、乗算そのものも要らないということになります。

 加算は乗算よりもさらに高速なので、こんな風にしてみましょう。


            float d = distance(lastX[1], lastY[1], lastX[2], lastY[2]);
            int num = (int)Math.ceil((double)(d / 5.0d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                x = catmullRom(lastX[0], lastX[1], lastX[2], lastX[3], deltaT);
                y = catmullRom(lastY[0], lastY[1], lastY[2], lastY[3], deltaT);
                p = catmullRom(lastP[0], lastP[1], lastP[2], lastP[3], deltaT);
            	stroke.add(x,y,p);
            }

 新たにdeltaTという変数を導入し、単にループ毎にinvertNumを加算するだけにしてみました。

 なんと、これでループ内部から乗算が完全に消えました。しかも三回もしていた乗算がたった一回の足し算になってしまいました。これはラッキーです。


 しかし最適化は止まるところを知りません。

 そもそもこのnumというのはなんなんでしょうか。


 よく見ると、numはdistanceという関数で得られた二点間の距離をもとに作られた係数のようです。

 distanceの中身を見ると、

    public static float distance(float x1, float y1, float x2, float y2){
        return (float)Math.sqrt((x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2));
    }

 まあご想像通り、平方根(Math.sqrt)が使われています。

 とりあえず、一旦これを関数展開してみたいと思いますが、その前に、lastX、lastY、lastPという三つの配列は、役割が変わりません。ということは配列に毎回アクセスするのは無駄ですから、思い切って、xp、yp、ppという三つの変数の後ろに0から4の数字をつけたものに変えてみましょう。


            float d =(float)Math.sqrt((xp1 - xp2) * (xp1 - xp2) + (yp1 - yp2) * (yp1 - yp2)); 
            int num = (int)Math.ceil((double)(d / 5.0d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                x = catmullRom(xp0,xp1,xp2,xp3, deltaT);
                y = catmullRom(yp0,yp1,yp2,yp3, deltaT);
                p = catmullRom(pp0,pp1,pp2,pp3, deltaT);
            	stroke.add(x,y,p);
            }

 二回も同じ引き算をしているのがイライラするので、変数をわけてみましょう。



            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d =(float)Math.sqrt(dx*dx+dy*dy); 
            int num = (int)Math.ceil((double)(d / 5.0d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                x = catmullRom(xp0,xp1,xp2,xp3, deltaT);
                y = catmullRom(yp0,yp1,yp2,yp3, deltaT);
                p = catmullRom(pp0,pp1,pp2,pp3, deltaT);
            	stroke.add(x,y,p);
            }

 ああ、すっきりした。これで引き算が二回減りました。引き算は足し算と同じコストです。

 しかしねー、そもそも平方根って気に入らないですよね。その根性が。明らかに割り算よりは遅いはずです。


 しかもこの平方根、やってることと言えば、距離に応じてnumの値を上げ下げしてるだけですよ。

 だったら、距離の自乗と対応するnumの対応表を作れば、高速化できるじゃないですか。


 そこでまず表計算ソフトを立ち上げます。

 表計算ソフトは思考ツールとして至高のもののひとつですよ。至高の思考ツールですよ。ええい主を呼べ(もういいって)


f:id:shi3z:20130502185902p:image

 すると、あはーん、係数(num)の上がり具合と距離の自乗の微妙な関係が見えてきました。

 平方根をどうしてもやりたくなかったら、これを力技でどうにかすればいいのです。


            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d2 =dx*dx+dy*dy; //平方根をとらずに距離の自乗のまま 

      if(d2<9)num=1;
      else if(d<64)num=2;
      else if(d<169)num=3;
      else if(d<324)num=4;
      else if(d<529)num=5;
      else if(d<784)num=6;
      else if(d<1089)num=7;

    // ・・・・ 省略

      else if(d<74529)num=55;
      else if(d<77284)num=56;

            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                x = catmullRom(xp0,xp1,xp2,xp3, deltaT);
                y = catmullRom(yp0,yp1,yp2,yp3, deltaT);
                p = catmullRom(pp0,pp1,pp2,pp3, deltaT);
            	stroke.add(x,y,p);
            }

 はあはあ、ぜいぜい。

 これで平方根を殺してやったぜ。


 ところがこれが大きな罠。FPUが標準装備される以前のロートルプログラマー爺たちを陥れる大いなる罠ですぜ、これは。

 当然、いくら平方根が遅いと言っても、最近のCPUには平方根命令が備わっています。

 いくら遅くても、分岐するよりはずっとマシなのです。


 そう、プログラムにとって、最も怖れるべきは分岐。

 プログラマーからは分岐は祟りのように怖れられています。くわばらくわばら。


 では続いてやるべきは何か?

 そう、関数展開です。

 関数呼び出しも立派な分岐。

 ループの中で三回も関数を呼んじゃいけませんぜ、お嬢さん。


            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d =(float)Math.sqrt(dx*dx+dy*dy); 
            int num = (int)Math.ceil((double)(d*0.2d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                float t2 = deltaT*deltaT;
                float t3 = t2*deltaT;
                
                //まずはxだけ関数展開する
                float xv0 = (xp2-xp0)*0.5f;
                float xv1 = (xp3-xp1)*0.5f;
                x = ((xp1 - xp2)*2.0f + xv0 + xv1) * t3 + 
                        ((xp2 - xp1)*3.0f - 2.0f * xv0 - xv1) * t2 + xv0 * deltaT + xp1;

                y = catmullRom(yp0,yp1,yp2,yp3, deltaT);
                p = catmullRom(pp0,pp1,pp2,pp3, deltaT);
            	stroke.add(x,y,p);
            }

 とりあえずxだけ関数展開すると、おおっと、tはdeltaTだから、t2(deltaTの自乗)やt3(deltaTの三乗)は三つの関数で使い回しできるぜ!ということに気付くはずです。これはゾクゾクもの。最適化の醍醐味のひとつです。最適化はある種、耳掃除に似ていて、時にはすごい大物をゲットしたりできるわけです。

 

 さあここでxv0とxv1はループ中ずっと同じ計算をしてることに気付いた人はいるかな?

 そう、ループの中で同じ計算を何度もするなんていうのは愚の骨頂。

 こんなものはさっさとループの外に追い出します。




            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d =(float)Math.sqrt(dx*dx+dy*dy); 
            int num = (int)Math.ceil((double)(d*0.2d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            float xv0 = (xp2-xp0)*0.5f;
             float xv1 = (xp3-xp1)*0.5f;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                float t2 = deltaT*deltaT;
                float t3 = t2*deltaT;
                
                //まずはxだけ関数展開する
                x = ((xp1 - xp2)*2.0f + xv0 + xv1) * t3 + 
                        ((xp2 - xp1)*3.0f - 2.0f * xv0 - xv1) * t2 + xv0 * deltaT + xp1;

                y = catmullRom(yp0,yp1,yp2,yp3, deltaT);
                p = catmullRom(pp0,pp1,pp2,pp3, deltaT);
            	stroke.add(x,y,p);
            }

 さあこれでいいかなー?

 いや、よーく見てみて下さい。

 よく見ると、((xp1 - xp2)*2.0f + xv0 + xv1)と((xp2 - xp1)*0.3f - 2.0f * xv0 - xv1)も、ループの中で特に変化がありません。

 この二つもループの外に出してしまいましょう。


            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d =(float)Math.sqrt(dx*dx+dy*dy); 
            int num = (int)Math.ceil((double)(d *0.2d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            float xv0 = (xp2-xp0)*0.5f;
            float xv1 = (xp3-xp1)*0.5f;
            float xfact1=((xp1 - xp2)*2.0f + xv0 + xv1);
            float xfact2=((xp2 - xp1)*3.0f - 2.0f * xv0 - xv1) ;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                float t2 = deltaT*deltaT;
                float t3 = t2*deltaT;
                
                //まずはxだけ関数展開する
                x =  xfact1* t3 + 
                       xfact2 * t2 + xv0* deltaT + xp1;

                y = catmullRom(yp0,yp1,yp2,yp3, deltaT);
                p = catmullRom(pp0,pp1,pp2,pp3, deltaT);
            	stroke.add(x,y,p);
            }

 わあ、スッキリ!これは気持ちいい。

 調子に乗ってほかのyとpも展開しましょう。


            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d =(float)Math.sqrt(dx*dx+dy*dy); 
            int num = (int)Math.ceil((double)(d*0.2d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            float xv0 = (xp2-xp0)*0.5f;
            float xv1 = (xp3-xp1)*0.5f;
            float xfact1=((xp1 - xp2)*2.0f + xv0 + xv1);
            float xfact2=((xp2 - xp1)*3.0f - 2.0f * xv0 - xv1) ;
            float yv0 = (yp2-yp0)*0.5f;
            float yv1 = (yp3-yp1)*0.5f;
            float yfact1=((yp1 - yp2)*2.0f + yv0 + yv1);
            float yfact2=((yp2 - yp1)*3.0f - 2.0f * yv0 - yv1) ;
            float pv0 = (pp2-pp0)*0.5f;
            float pv1 = (pp3-pp1)*0.5f;
            float pfact1=((pp1 - pp2)*2.0f + pv0 + pv1);
            float pfact2=((pp2 - pp1)*3.0f - 2.0f * pv0 - pv1);
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                float t2 = deltaT*deltaT;
                float t3 = t2*deltaT;
                
                //まずはxだけ関数展開する
                x =  xfact1* t3 + 
                       xfact2*t2+xv0 * deltaT + xp1;
                y =  yfact1* t3 + 
                       yfact2*t2 + yv0 * deltaT + yp1;
                p =  pfact1* t3 + 
                       pfact2*t2 + pv0 * deltaT + pp1;

            	stroke.add(x,y,p);
            }

 なんということでしょう!

 最初のループの場合、関数呼び出しが三回、関数の内部での乗算が10回なので、乗算30回が1ループにつき必要でした。


 ところが、この最適化を行うことによって、ループ内部は乗算わずか11回*1になりました。

 約3倍の高速化です。やったぜ、万歳。

 いやー、しかし最適化はまだまだでした。

 @shinhさんのコード(https://twitter.com/shinh/status/330179788457259009)では、さらに乗算を6回まで減らしています。



            float dx = xp1-xp2;
            float dy = yp1-yp2;
            float d =(float)Math.sqrt(dx*dx+dy*dy); 
            int num = (int)Math.ceil((double)(d*0.2d) + 0.5d);
            float x,y,p;
            float invertNum = 1.0f/num; // numの逆数を定義
            float deltaT = 0;
            float xv0 = (xp2-xp0)*0.5f;
            float xv1 = (xp3-xp1)*0.5f;
            float xfact1=((xp1 - xp2)*2.0f + xv0 + xv1);
            float xfact2=((xp2 - xp1)*3.0f - 2.0f * xv0 - xv1) ;
            float yv0 = (yp2-yp0)*0.5f;
            float yv1 = (yp3-yp1)*0.5f;
            float yfact1=((yp1 - yp2)*2.0f + yv0 + yv1);
            float yfact2=((yp2 - yp1)*3.0f - 2.0f * yv0 - yv1) ;
            float pv0 = (pp2-pp0)*0.5f;
            float pv1 = (pp3-pp1)*0.5f;
            float pfact1=((pp1 - pp2)*2.0f + pv0 + pv1);
            float pfact2=((pp2 - pp1)*3.0f - 2.0f * pv0 - pv1);
            float xfact1n =0;
            float yfact1n =0;
            float pfact1n =0;
            float xFact1step = xfact1 * invertNum;
            float yFact1step = yfact1 * invertNum;
            float pFact1step = pfact1 * invertNum;
            for(int i = 0; i < num; i++, deltaT+=invertNum){
                x =  ((xfact1n + xfact2) * deltaT + xv0) * deltaT + xp1;
                y =  ((yfact1n + yfact2) * deltaT + yv0) * deltaT + yp1;
                p =  ((pfact1n + pfact2) * deltaT + pv0) * deltaT + pp1;
                xfact1n += xFact1step;
                yfact1n += xFact1step;
                pfact1n += xFact1step;
                stroke.add(x,y,p);
            }

 凄いですねー。さすがにこれ以上は無理かな?まだありそうな気もするんですけど。


 実際のコードではさらにアクロバティックな工夫を繰り返しています。

 理論上はこのやり方が高速なはず・・・でも実際に動かすとそうでもない・・・最適化の開発現場はそんなことの連続です。



 こういう数式展開レベルの最適化は、どれだけコンパイラが賢くなってもまだまだやってもらえません。

 最適化とは何かを説明するときに、僕はよく幼き日のカール・フリードリヒ・ガウスのエピソードを引用します。

 後に天才数学者と呼ばれるようになるガウスは、小学校のとき、あまりにも計算を早く終えてしまうので、先生がどうでもいいやとばかりに「1から100まで足して来なさい」と言いました。

 ところがガウスは、先生が考えるよりずっと早く、答えを持ってきました。ガウスの計算速度をもってしても、1から100まで、100回の足し算をしたというにはあまりにも速すぎます。

 ガウスはどうやって、この答えを導いたのでしょうか。


 これ、有名なエピードですが、最適化とはまさにこの少年の日のガウスがやったようなことです。

 ガウスは1から100まで足すということは、100から1まで足すということの半分の作業である、と考えました。

 すると、この問題は1+2+3+・・・99+100という足し算と、100+99+・・・+3+2+1を両方足すと、常に101+101+・・・・+101+101という、101の100倍に過ぎないことがわかります。101を100倍して、2で割れば、求める答えがでるというわけです。


 100回の足し算を1回の乗算と除算にしてしまった天才ガウス。彼こそまさに最適化の神様と呼べるかもしれません。


 ちなみにしばらく椅子に座って最適化をしていたせいで、肩こりと腰痛がしてきました。

 肩に貼ったピップエレキバン

 1300ガウスのガウスは、もちろんそのガウスが由来です。



追記(5/3)

 「この程度はDalvikのJITで最適化されるんじゃないの?」

 「コンパイラが最適化するでしょ?」

 という、Twitterに意見がちらほらありました。

 それはあまりにもコンパイラの最適化に期待し過ぎです。実際に吐き出したコードを読んでみましょう。あなたがコンパイラの作者だったら、あなたがJITの作者だったら、入って来たコードから同じような最適化ができるでしょうか。まず無理です。どんな高度な最適化コンパイラも、所詮は人間の作ったコードです。コンパイラは神ではないのです。あくまでも人間の創りだした不完全な道具のひとつに過ぎません。


 論より証拠を見せましょう。

 もとのコードの場合

f:id:shi3z:20130503094836p:image

 こんな感じで23-27ミリ秒です。ちなみに本当はミリ秒切ってるので1000回ループしています。

 つまり本当は0.027ミリ秒=27マイクロ秒です。

 Coretex-A8に、DalvikVM(JIT付き)で動作させています。

 では手動で最適化するとどうなったか

f:id:shi3z:20130503095014p:image

 ご覧のように、6ミリ秒(実際は6マイクロ秒)まで短縮されました。約3.8倍の高速化です。

 乗算が30回から11回に減った比率と近いことに注意してください。


 CPUやJITの気持ちになって考えれば、「この乗算を省いていいのかどうか俺にわかるわけないだろ」ということです。そういうことは人間が判断するものです。


 ただし、これは副作用の少ない局所最適化なのでこういうわかりやすい計測ができますが、本当の最適化は全体最適なので、もっと東洋医学っぽいというか、漢方みたいな感じです。勘に頼って複数モジュール感のタイミングをバッチリあわせるとか、そういうレベルです。


 また、ベンチマークをとるときに気付きましたが、初稿では最適化後のコードがバグっていたので修正しました。

 最適化のテクニックは、プログラミングだけではなく人生の様々な局面に応用できます。

 繰り返し同じことをしているとき、最適化できないか考えてみましょう。

*1:最初8個と書きましたが数え間違っていました