RUPC 2016 Day3 D: Complex Oracle (O(N(logN)^2) 解法)

これの高速化ver
kenkoooo.hatenablog.com

解法

BIT で抜けた数を持っておき、「今残っている数の中でmid番目」を当てる二分探索をする。

コード

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Debug printer: streams a vector as "[a, b, c]".
//
// Separators are written only between elements, so the output contains no
// terminal control characters and is also correct when redirected to a file
// or a string stream. An empty vector is written as "[]".
//
// Returns the stream to allow chaining.
template <typename T>
std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
  out << '[';
  bool first = true;
  for (const auto &elem : v) {
    if (!first) out << ", ";
    out << elem;
    first = false;
  }
  out << ']';
  return out;
}
// "Change to minimum": overwrites t with f when t is strictly greater than f,
// and leaves t untouched otherwise.
template <class T, class U>
void chmin(T &t, U f) {
  if (!(t > f)) {
    return;
  }
  t = f;
}
// "Change to maximum": overwrites t with f when t is strictly less than f,
// and leaves t untouched otherwise.
template <class T, class U>
void chmax(T &t, U f) {
  if (!(t < f)) {
    return;
  }
  t = f;
}

// Fenwick tree (binary indexed tree) over the 0-based index range [0, N).
//
// Supports point updates and prefix/range sums, each touching O(log N)
// internal nodes.
template <typename T>
class FenwickTree {
 private:
  int N;               // number of addressable positions
  std::vector<T> dat;  // dat[x] holds the sum of positions (x & (x + 1)) .. x

 public:
  // Creates a tree with N positions, all initialised to zero.
  FenwickTree(int N) : N(N) { dat.assign(N, 0); }

  // Adds val to position k. k must be non-negative; positions >= N are
  // silently ignored (the loop simply does not run).
  void add(int k, T val) {
    assert(k >= 0);
    for (int x = k; x < N; x |= x + 1) {
      dat[x] += val;
    }
  }

  // Returns the sum over the half-open prefix [0, k).
  // k is clamped to N, so sum(N) (or anything larger) is the total over the
  // whole tree; k <= 0 yields 0.
  T sum(int k) const {
    // Clamp to N, not N - 1: the range is half-open, so k == N must still
    // include the last position.
    if (k > N) k = N;
    T ret = 0;
    for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) {
      ret += dat[x];
    }
    return ret;
  }

  // Returns the sum over the half-open range [l, r).
  T sum(int l, int r) const { return sum(r) - sum(l); }

  // Returns the current value stored at position k (0 <= k < N).
  T get(int k) const {
    assert(0 <= k && k < N);
    return sum(k + 1) - sum(k);
  }
};

// Testcase generator.
//
// Shuffles the values 1..N into a random permutation (seeded from
// std::random_device, so every call differs) and returns, for each position
// i, how many later positions j > i hold a smaller value.
//
// N defaults to 10, the size the generator was originally hard-coded to.
// A non-positive N yields an empty vector. The counting is a plain O(N^2)
// double loop, which is fine for the small sizes used when debugging.
std::vector<int> test(int N = 10) {
  if (N <= 0) return {};

  std::vector<int> v(N);
  std::iota(v.begin(), v.end(), 1);

  std::random_device rd;
  std::mt19937 g(rd());
  std::shuffle(v.begin(), v.end(), g);

  std::vector<int> ret(N, 0);
  for (int i = 0; i < N; ++i) {
    for (int j = i + 1; j < N; ++j) {
      if (v[i] > v[j]) ret[i]++;
    }
  }
  return ret;
}

int main() {
  cin.tie(0);
  ios::sync_with_stdio(false);

  int N;
  cin >> N;
  vector<int> point(N, 0);

  printf("? %d %d\n", 1, N);
  fflush(stdout);
  ll prev;
  cin >> prev;

  for (int i = 1; i < N; ++i) {
    printf("? %d %d\n", (i + 1), N);
    fflush(stdout);

    ll resp;
    cin >> resp;
    ll diff = prev - resp;
    point[i - 1] = diff;
    prev = resp;
  }

  FenwickTree<int> bit(N);
  printf("!");
  for (int i = 0; i < N; ++i) {
    int high = N * 2, low = 0;
    while (high > low) {
      int mid = (high + low) / 2;
      int v = mid - bit.sum(mid + 1);
      if (v < point[i]) {
        low = mid + 1;
      } else {
        high = mid;
      }
    }

    printf(" %d", low + 1);
    bit.add(low, 1);
  }
  printf("\n");
  fflush(stdout);
}