lmori's Library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub lmorinn/library

:heavy_check_mark: Point Set Range Frequency
(data-structure/wavelet-matrix/query/PointSetRangeFreq.hpp)

概要

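一点更新と「区間内に値 val がいくつあるか」というクエリをオフラインで処理する。`set(pos, val)` は `pos` 番目の要素を `val` に変更し、`range_freq(l, r, val)` は区間 `[l, r)` に含まれる `val` の個数を問い合わせる。すべてのクエリを登録した後に `build()` を呼ぶと、`range_freq` を登録した順に答えが返る。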

計算量

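列の長さを N、クエリ数を Q として、`build()` 全体で期待 O((N + Q) log(N + Q))（ハッシュマップの操作を期待 O(1) とした場合）。`set` と `range_freq` の登録はそれぞれ期待 O(1)。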

Depends on

Verified with

Code

#include "../WaveletMatrixTemplate.hpp"

template <class S>
class PointSetRangeFreq {
   private:
    unordered_map<S, vector<unsigned>> m;
    unordered_map<S, unsigned> cnt;
    vector<vector<S>> q;
    unordered_map<S, bool> printq;
    unsigned n;

    unsigned set_query;
    unsigned output_query;
    vector<S> prev;

   public:
    PointSetRangeFreq(const vector<S> &v, unsigned query) {
        n = v.size();
        prev.resize(n);

        for (unsigned i = 0; i < n; i++) {
            m[v[i]].emplace_back(i);
            cnt[v[i]]++;
            prev[i] = v[i];
        }
        set_query = 0;
        output_query = 0;

        q = vector<vector<S>>(query, vector<S>(4));
    }

    void set(unsigned pos, S val) {
        unsigned idx = set_query + output_query;
        q[idx][0] = 0;
        q[idx][1] = pos;
        q[idx][2] = prev[pos];
        q[idx][3] = val;

        m[prev[pos]].emplace_back(pos + n);
        m[val].emplace_back(pos);
        set_query++;
        prev[pos] = val;
    }

    void range_freq(unsigned l, unsigned r, S val) {
        unsigned idx = set_query + output_query;
        q[idx][0] = 1;
        q[idx][1] = l;
        q[idx][2] = r;
        q[idx][3] = val;
        printq[val] = true;
        output_query++;
    }

    vector<unsigned> build() {
        for (int i = 0; i < set_query + output_query; i++) {
            if (q[i][0] == 0 and printq.contains(q[i][3])) {
                m[q[i][3]].emplace_back(q[i][1]);
                if (printq.contains(q[i][2])) {
                    m[q[i][2]].emplace_back(unsigned(q[i][1]) + n);
                }
            }
        }
        unordered_map<S, WaveletMatrix<unsigned>> wm;
        for (const pair<S, vector<unsigned>> &p : m) {
            if (!printq.contains(p.first)) continue;
            wm.emplace(make_pair(p.first, WaveletMatrix<unsigned>(p.second)));
        }
        vector<unsigned> ret(output_query);
        int idx = 0;
        for (int i = 0; i < set_query + output_query; i++) {
            S com = q[i][0];
            if (com == 0) {
                cnt[q[i][2]]++;
                cnt[q[i][3]]++;
            } else {
                ret[idx] = wm[q[i][3]].range_freq(0, cnt[q[i][3]], q[i][1], q[i][2]);
                ret[idx] -= wm[q[i][3]].range_freq(0, cnt[q[i][3]], n + q[i][1], n + q[i][2]);
                idx++;
            }
        }
        return ret;
    }
};
#line 1 "data-structure/wavelet-matrix/WaveletMatrixTemplate.hpp"
// Fixed-size bit vector with constant-time rank, stored in 32-bit blocks.
// Typical use: construct with the size, set() the wanted bits, call build()
// once, then query with access() / rank().
struct BitVector {
  unsigned sz = 0;         // number of addressable bits
  unsigned blocksize = 0;  // number of 32-bit blocks allocated
  // bit: packed bits. sum[i]: number of set bits in blocks [0, i) after build().
  std::vector<unsigned> bit, sum;

  BitVector() {}

  BitVector(unsigned siz) {
    sz = siz;
    // Always keep one block at index sz / 32, so that rank(sz) stays in range
    // even when sz is a multiple of 32 (including sz == 0).
    blocksize = (sz >> 5) + 1;
    bit.assign(blocksize, 0U);
    sum.assign(blocksize, 0U);
  }

  // Sets bit k to 1. Requires 0 <= k < sz.
  inline void set(int k) { bit[k >> 5] |= 1U << (k & 31); }

  // Computes the per-block prefix popcounts; call after the last set().
  inline void build() {
    if (sum.empty()) return;  // default-constructed: nothing to index
    sum[0] = 0U;
    for (unsigned i = 1; i < blocksize; i++) {
      sum[i] = sum[i - 1] + __builtin_popcount(bit[i - 1]);
    }
  }

  // Returns bit k. Requires k < sz.
  inline bool access(unsigned k) const {
    return (bool((bit[k >> 5] >> (k & 31)) & 1));
  }

  // Returns the number of set bits in positions [0, k). Requires 0 <= k <= sz
  // and a preceding build().
  inline int rank(int k) const {
    // Whole blocks come from sum; the partial block is masked to its low k % 32 bits.
    return (sum[k >> 5] + __builtin_popcount(bit[k >> 5] & ((1U << (k & 31)) - 1)));
  }
};

template <class T>
class WaveletMatrix {
 private:
  unsigned n;
  unsigned bitsize;
  vector<BitVector> b;
  vector<unsigned> zero;
  vector<T> cmp;
  T MI, MA;

  inline unsigned compress(const T &x) {
    return lower_bound(cmp.begin(), cmp.end(), x) - begin(cmp);
  }

 public:
  // コンストラクタ
  WaveletMatrix() {}
  WaveletMatrix(const vector<T> &v) {
    MI = numeric_limits<T>::min();
    MA = numeric_limits<T>::max();
    n = v.size();
    cmp = v;
    sort(cmp.begin(), cmp.end());
    cmp.erase(unique(cmp.begin(), cmp.end()), cmp.end());
    vector<unsigned> compressed(n);
    vector<unsigned> tmpc(n);
    unsigned size_mx = v.size();
    for (unsigned i = 0; i < n; i++) {
      compressed[i] = compress(v[i]);
    }
    bitsize = bit_width(cmp.size());
    b.resize(bitsize);
    zero.assign(bitsize, 0);
    int cur = 0;

    for (unsigned i = 0; i < bitsize; i++) {
      b[i] = BitVector(n + 1);
      cur = 0;
      for (unsigned j = 0; j < n; j++) {
        if (compressed[j] & (1U << (bitsize - i - 1))) {
          b[i].set(j);
        } else {
          zero[i]++;
          tmpc[cur] = compressed[j];
          cur++;
        }
      }
      b[i].build();

      for (unsigned j = 0; j < n; j++) {
        if (compressed[j] & (1U << (bitsize - i - 1))) {
          tmpc[cur] = compressed[j];
          cur++;
        }
      }
      swap(tmpc, compressed);
    }
  }

  // get v[k]
  T access(unsigned k) {
    unsigned res = 0;
    unsigned cur = k;
    for (unsigned i = 0; i < bitsize; i++) {
      if (b[i].access(cur)) {
        res |= (1U << (bitsize - i - 1));
        cur = zero[i] + b[i].rank(cur);
      } else {
        cur -= b[i].rank(cur);
      }
    }
    return cmp[res];
  }

  // v[l,r) の中でk番目(1-origin)に小さい値を返す
  T kth_smallest(unsigned l, unsigned r, unsigned k) {
    unsigned res = 0;
    unsigned rank1_l, rank1_r, num0;
    for (unsigned i = 0; i < bitsize; i++) {
      rank1_l = b[i].rank(l);
      rank1_r = b[i].rank(r);
      num0 = r - l - (rank1_r - rank1_l);
      if (num0 < k) {
        res |= (1U << (bitsize - i - 1));
        l = zero[i] + rank1_l;
        r = zero[i] + rank1_r;
        k -= num0;
      } else {
        l -= rank1_l;
        r -= rank1_r;
      }
    }
    return cmp[res];
  }

  // v[l,r) の中でk番目(1-origin)に大きい値を返す
  T kth_largest(unsigned l, unsigned r, unsigned k) {
    return kth_smallest(l, r, r - l - k + 1);
  }

  // v[l,r) の中で[mink,maxk)に入る値の個数を返す
  unsigned range_freq(int vl, int vr, T mink, T maxk) {
    int D = compress(mink);
    int U = compress(maxk);
    unsigned res = 0;
    auto dfs = [&](auto &rec, int d, int L, int R, int A, int B) -> void {
      if (U <= A or B <= D) return;
      if (D <= A and B <= U) {
        res += (R - L);
        return;
      }
      if (d == bitsize) {
        if (D <= A and A < U) {
          res += (R - L);
        }
        return;
      }
      int C = (A + B) / 2;
      int rank0_l = L - b[d].rank(L);
      int rank0_r = R - b[d].rank(R);
      int rank1_l = b[d].rank(L) + zero[d];
      int rank1_r = b[d].rank(R) + zero[d];
      rec(rec, d + 1, rank0_l, rank0_r, A, C);
      rec(rec, d + 1, rank1_l, rank1_r, C, B);
    };
    dfs(dfs, 0, vl, vr, 0, 1 << bitsize);
    return res;
  }

	// v[l,r)の中でval未満の要素のうち最大の値を返す
  T prev_value(unsigned l, unsigned r, T val) {
    int num = range_freq(l, r, MI, val);
    if (num == 0) {
      return MA;
    } else {
      return kth_smallest(l, r, num);
    }
  }

  // v[l,r)の中でvalより大きい要素のうち最小の値を返す
  T next_value(unsigned l, unsigned r, T val) {
    int num = range_freq(l, r, MI, val + 1);
    if (num == r - l) {
      return MI;
    } else {
      return kth_smallest(l, r, num + 1);
    }
  }
};
#line 2 "data-structure/wavelet-matrix/query/PointSetRangeFreq.hpp"

template <class S>
class PointSetRangeFreq {
   private:
    unordered_map<S, vector<unsigned>> m;
    unordered_map<S, unsigned> cnt;
    vector<vector<S>> q;
    unordered_map<S, bool> printq;
    unsigned n;

    unsigned set_query;
    unsigned output_query;
    vector<S> prev;

   public:
    PointSetRangeFreq(const vector<S> &v, unsigned query) {
        n = v.size();
        prev.resize(n);

        for (unsigned i = 0; i < n; i++) {
            m[v[i]].emplace_back(i);
            cnt[v[i]]++;
            prev[i] = v[i];
        }
        set_query = 0;
        output_query = 0;

        q = vector<vector<S>>(query, vector<S>(4));
    }

    void set(unsigned pos, S val) {
        unsigned idx = set_query + output_query;
        q[idx][0] = 0;
        q[idx][1] = pos;
        q[idx][2] = prev[pos];
        q[idx][3] = val;

        m[prev[pos]].emplace_back(pos + n);
        m[val].emplace_back(pos);
        set_query++;
        prev[pos] = val;
    }

    void range_freq(unsigned l, unsigned r, S val) {
        unsigned idx = set_query + output_query;
        q[idx][0] = 1;
        q[idx][1] = l;
        q[idx][2] = r;
        q[idx][3] = val;
        printq[val] = true;
        output_query++;
    }

    vector<unsigned> build() {
        for (int i = 0; i < set_query + output_query; i++) {
            if (q[i][0] == 0 and printq.contains(q[i][3])) {
                m[q[i][3]].emplace_back(q[i][1]);
                if (printq.contains(q[i][2])) {
                    m[q[i][2]].emplace_back(unsigned(q[i][1]) + n);
                }
            }
        }
        unordered_map<S, WaveletMatrix<unsigned>> wm;
        for (const pair<S, vector<unsigned>> &p : m) {
            if (!printq.contains(p.first)) continue;
            wm.emplace(make_pair(p.first, WaveletMatrix<unsigned>(p.second)));
        }
        vector<unsigned> ret(output_query);
        int idx = 0;
        for (int i = 0; i < set_query + output_query; i++) {
            S com = q[i][0];
            if (com == 0) {
                cnt[q[i][2]]++;
                cnt[q[i][3]]++;
            } else {
                ret[idx] = wm[q[i][3]].range_freq(0, cnt[q[i][3]], q[i][1], q[i][2]);
                ret[idx] -= wm[q[i][3]].range_freq(0, cnt[q[i][3]], n + q[i][1], n + q[i][2]);
                idx++;
            }
        }
        return ret;
    }
};
Back to top page