セグメントツリーは大変有益なデータ構造なので, 汎用的に作った.
- セグメントツリーのノードの値型Tをとる
- 子ノードから親ノードを計算する方法をOpとして抽象化する(型Op)
- 今回はRMQとSum(prefix sum)の実装をした.
- queryは, logNで得られたノードからOpを使ってさらに答えを計算する. これは問題によっては使えないこともあるので, 範囲クエリを分解するsplitを提供することにした. 返り値vectorの大きさはlogNなので, 分離することのオーバーヘッドは無視出来る.
- もしこれで間に合わないのであればその場でオーダーメイドする.
#include <bits/stdc++.h> using namespace std; template<typename T> class RMQ { public: T operator()(T left, T right) { return min(left, right); } T init() { return 1 << 20; } }; template<typename T> class Sum { public: T operator()(T left, T right) { return left + right; } T init() { return 0; } }; struct RL { int ind, i, j, low, high; }; vector<int> do_split(int i, int j, int low, int high){ vector<int> result; queue<RL> q; // bfs q.push({0, i, j, low, high}); while (!q.empty()) { RL t = q.front(); q.pop(); // cout << t.ind << ":" << t.i << " " << t.j << " " << t.low << " " << t.high << endl; if (t.i == t.low && t.j == t.high) { // perfect match result.push_back(t.ind); } else if (t.i < t.j && // non-zero query range t.low+1 != t.high) { // can split more int mid = (t.low + t.high) / 2; q.push({t.ind*2+1, t.i, min(t.j, mid), t.low, mid}); q.push({t.ind*2+2, max(t.i, mid), t.j, mid, t.high}); } } return result; } template<typename T, typename Op> class SegTree { public: int n; vector<T> data; T t0; SegTree(){}; SegTree(int n_) { int n = n_; int m = 1; while (m < n) { m = m << 1; } data = vector<T>(m*2); int i; for (i=0; i<data.size(); i++) { data[i] = Op().init(); } } void update(int i, T x) { int k = (data.size() / 2) + i - 1; data[k] = x; while (k>0) { k = (k-1) / 2; T t = Op()(data[k*2+1], data[k*2+2]); data[k] = t; } } // [i, j) vector<int> split(int i, int j){ return do_split(i, j, 0, data.size()/2); } T operator[](int i){ return data[i]; } T query(int i, int j){ T x = Op().init(); for(int e: split(i, j)){ x = Op()(x, data[e]); } return x; } void p() { int i; for(i=0; i<data.size(); i++){ cout << data[i] << " "; } cout<<endl; } }; typedef SegTree<int, RMQ<int> > ST; // typedef SegTree<int, Sum<int> > ST; ST st; void test(int i, int j) { cout<<i<<","<<j<<":"; for(int e: st.split(i, j)) { cout << e; } cout << "=>" << st.query(i,j); cout<<endl; } int main() { st = ST(3); st.update(0, 3); st.update(1, 2); st.p(); test(0,1); test(0,3); test(0,4); test(2,3); test(2,4); test(3,4); return 0; }