購読中です 読者をやめる 読者になる 読者になる

テストステ論

高テス協会会長が, テストステロンに関する情報をお届けします.

セグメントツリーを実装しました

セグメントツリーは大変有益なデータ構造なので, 汎用的に作った.

  • セグメントツリーのノードの値型Tをとる
  • 子ノードから親ノードを計算する方法をOpとして抽象化する(型Op)
  • 今回はRMQとSum(prefix sum)の実装をした.
  • queryは, logNで得られたノードからOpを使ってさらに答えを計算する. これは問題によっては使えないこともあるので, 範囲クエリを分解するsplitを提供することにした. 返り値vectorの大きさはlogNなので, 分離することのオーバーヘッドは無視出来る.
  • もしこれで間に合わないのであればその場でオーダーメイドする.
#include <bits/stdc++.h>
using namespace std;

template<typename T> class RMQ {
public:
        T operator()(T left, T right) { return min(left, right); }
        T init() { return 1 << 20; }
};

template<typename T> class Sum {
public:
        T operator()(T left, T right) { return left + right; }
        T init() { return 0; }
};

struct RL {
        int ind, i, j, low, high;
};
vector<int> do_split(int i, int j, int low, int high){
        vector<int> result;
        queue<RL> q; // bfs
        q.push({0, i, j, low, high});
        while (!q.empty()) {
                RL t = q.front(); q.pop();
                // cout << t.ind << ":" << t.i << " " << t.j << " " << t.low << " " << t.high << endl;
                if (t.i == t.low && t.j == t.high) { // perfect match
                        result.push_back(t.ind);
                } else if (t.i < t.j && // non-zero query range
                           t.low+1 != t.high) { // can split more
                        int mid = (t.low + t.high) / 2;
                        q.push({t.ind*2+1, t.i, min(t.j, mid), t.low, mid});
                        q.push({t.ind*2+2, max(t.i, mid), t.j, mid, t.high});
                }
        }
        return result;
}

template<typename T, typename Op> class SegTree {
public:
        int n;
        vector<T> data;
        T t0;
        SegTree(){};
        SegTree(int n_) {
                int n = n_;
                int m = 1;
                while (m < n) { m = m << 1; }
                data = vector<T>(m*2);
                int i;
                for (i=0; i<data.size(); i++) { data[i] = Op().init(); }
        }
        void update(int i, T x) {
                int k = (data.size() / 2) + i - 1;
                data[k] = x;
                while (k>0) {
                        k = (k-1) / 2;
                        T t = Op()(data[k*2+1], data[k*2+2]);
                        data[k] = t;
                }
        }
        // [i, j)
        vector<int> split(int i, int j){
                return do_split(i, j, 0, data.size()/2);
        }
        T operator[](int i){
                return data[i];
        }
        T query(int i, int j){
                T x = Op().init();
                for(int e: split(i, j)){
                        x = Op()(x, data[e]);
                }
                return x;
        }

        void p() {
                int i;
                for(i=0; i<data.size(); i++){
                        cout << data[i] << " ";
                }
                cout<<endl;
        }
};


typedef SegTree<int, RMQ<int> > ST;
// typedef SegTree<int, Sum<int> > ST;
ST st;

void test(int i, int j)
{
        cout<<i<<","<<j<<":";
        for(int e: st.split(i, j)) {
                cout << e;
        }
        cout << "=>" << st.query(i,j);
        cout<<endl;
}
int main()
{
        st = ST(3);
        st.update(0, 3);
        st.update(1, 2);
        st.p();
        test(0,1);
        test(0,3);
        test(0,4);
        test(2,3);
        test(2,4);
        test(3,4);
        return 0;
}