AtCoder Regular Contest 048 D: たこ焼き屋とQ人の高橋君 (Takoyaki Shops and Q Takahashis)

コード

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

// Heavy-light decomposition of a tree rooted at vertex 0.
//
// After construction:
//   vertexID[v] - position of v in a pre-order DFS that visits the heavy
//                 child first, so every heavy path has consecutive IDs;
//   head[v]     - topmost vertex of the heavy path containing v;
//   depth[v]    - number of edges from the root (depth[0] == 0);
//   parent[v]   - parent of v (-1 for the root);
//   size[v]     - number of vertices in the subtree of v;
//   g           - copy of the adjacency list, reordered so that g[v][0] is
//                 the heavy child of v whenever v has a child.
//
// Path iteration: call set(x, y), then call next() while hasNext() is true.
// Each next() returns a half-open vertexID range [l, r); together the ranges
// cover exactly the vertices on the x-y path.
struct HL {
  vector<vector<int>> g;
  vector<int> vertexID, depth, head, parent, size;
  // Path-iteration cursor (v == -1 means "no more ranges") and the next
  // vertexID to hand out during construction.
  int u = 0, v = -1, k = 0;

  HL(const vector<vector<int>> &adj) {
    g = adj;
    vertexID.resize(adj.size());
    depth.resize(adj.size());
    head.resize(adj.size());
    parent.resize(adj.size());
    size.assign(adj.size(), 1);
    k = 0;
    if (adj.empty()) return;  // nothing to decompose; vertex 0 does not exist

    dfs(0, -1);
    dfs2(0, -1);
  }

  // Fills parent, depth and subtree size below curr and moves the child with
  // the largest subtree to g[curr][0] (ties keep the child already at the
  // front). Every vertex starts as its own head; dfs2 extends heavy paths.
  // Recursive: stack depth equals the tree height.
  void dfs(int curr, int prev) {
    head[curr] = curr;
    parent[curr] = prev;
    // `next` is a reference into g[curr] so the swap can move the heavy child
    // to the front while iterating. The slot it receives in exchange holds an
    // entry at index 0 that was already iterated (a visited child or the
    // parent), so no child is skipped or visited twice.
    for (int &next : g[curr]) {
      if (next == prev) continue;
      depth[next] = depth[curr] + 1;
      dfs(next, curr);
      size[curr] += size[next];
      int &h = g[curr][0];
      if (h == prev || size[h] < size[next]) swap(h, next);
    }
  }

  // Assigns vertexIDs in pre-order, heavy child (g[curr][0]) first, and lets
  // the heavy child inherit curr's head.
  void dfs2(int curr, int prev) {
    vertexID[curr] = k++;
    // g[curr][0] is the heavy child unless curr has no children: then it is
    // the parent (non-root leaf) or g[curr] is empty (single-vertex tree).
    // Testing `!= prev` instead of `size() > 1` also handles a root with a
    // single child, whose adjacency list has size 1.
    if (!g[curr].empty() && g[curr][0] != prev) head[g[curr][0]] = head[curr];
    for (int next : g[curr])
      if (next != prev) dfs2(next, curr);
  }

  // Starts iterating over the vertices on the path between x and y.
  void set(int x, int y) {
    u = x;
    v = y;
  }

  // True when a and b lie on the same heavy path.
  bool same(int a, int b) { return head[a] == head[b]; }

  // Returns the next half-open vertexID range [l, r) of the current path and
  // advances the cursor. Each step consumes the endpoint whose heavy path
  // starts deeper, so the cursor never climbs above the endpoints' LCA.
  pair<int, int> next() {
    if (depth[u] > depth[v]) swap(u, v);
    if (depth[head[u]] > depth[head[v]]) swap(u, v);
    int l = same(u, v) ? vertexID[u] : vertexID[head[v]];
    int r = vertexID[v] + 1;
    v = same(u, v) ? -1 : parent[head[v]];
    return {l, r};
  }

  // True while next() still has ranges to return.
  bool hasNext() { return v != -1; }
};

// Binary-lifting lowest-common-ancestor table for a tree rooted at vertex 0.
// parent[j][x] is the 2^j-th ancestor of x, or -1 when it does not exist;
// depth[x] is the number of edges between x and the root.
struct LCA {
  LCA() {}
  vector<vector<int>> G;
  vector<vector<int>> parent;
  vector<int> depth;
  int root, logV;

  // Records depth d and direct parent p of v, then recurses into every
  // neighbour except p. Stack depth equals the tree height.
  void dfs(int v, int p, int d) {
    depth[v] = d;
    parent[0][v] = p;
    for (const int child : G[v]) {
      if (child == p) continue;
      dfs(child, v, d + 1);
    }
  }

  LCA(const vector<vector<int>> &adj) {
    const int V = adj.size();
    G = adj;
    root = 0;
    depth.assign(V, 0);

    // One extra level for every doubling of 1 that stays <= V.
    logV = 1;
    for (int span = 1; span <= V; span *= 2) ++logV;
    parent.assign(logV, vector<int>(V));

    dfs(root, -1, 0);

    // Level j is level j-1 applied twice; a missing ancestor (-1) propagates.
    for (int j = 1; j < logV; ++j) {
      const vector<int> &prevLevel = parent[j - 1];
      vector<int> &level = parent[j];
      for (int x = 0; x < V; ++x) {
        const int mid = prevLevel[x];
        level[x] = mid < 0 ? -1 : prevLevel[mid];
      }
    }
  }

  // Lowest common ancestor of u and v.
  int getLCA(int u, int v) {
    if (depth[u] > depth[v]) swap(u, v);
    // Lift v up to u's depth, one set bit of the depth difference at a time.
    const int diff = depth[v] - depth[u];
    for (int j = 0; j < logV; ++j)
      if ((diff >> j) & 1) v = parent[j][v];
    if (u == v) return u;
    // Lift both as long as their ancestors still differ; they then sit
    // directly below the LCA.
    for (int j = logV - 1; j >= 0; --j) {
      if (parent[j][u] == parent[j][v]) continue;
      u = parent[j][u];
      v = parent[j][v];
    }
    return parent[0][u];
  }

  // Number of edges on the path between u and v.
  int getLength(int u, int v) {
    const int ancestor = getLCA(u, v);
    return depth[u] + depth[v] - 2 * depth[ancestor];
  }
};

// Point-update / range-minimum segment tree over a fixed N = 2^17 positions,
// stored as an implicit heap: node k has children 2k+1 and 2k+2 and position
// i lives at node i + N - 1. Every position starts at 1e18, and the same
// sentinel is returned for a range containing no positions.
struct RMQ {
  static const int N = 1 << 17;
  vector<long long> seg;
  RMQ() : seg(N * 2, (long long)1e18) {}

  // Sets position k to value and recomputes the minimum of every ancestor.
  void update(int k, long long value) {
    int node = k + N - 1;
    seg[node] = value;
    while (node > 0) {
      node = (node - 1) / 2;
      const long long left = seg[2 * node + 1];
      const long long right = seg[2 * node + 2];
      seg[node] = left < right ? left : right;
    }
  }

  // Minimum over positions [a, b). k is the node covering [l, r); callers
  // normally pass only a and b.
  long long query(int a, int b, int k = 0, int l = 0, int r = N) {
    const bool disjoint = b <= l || r <= a;
    if (disjoint) return 1e18;
    const bool covered = a <= l && r <= b;
    if (covered) return seg[k];
    const int mid = (l + r) / 2;
    return min(query(a, b, 2 * k + 1, l, mid), query(a, b, 2 * k + 2, mid, r));
  }
};

int main() {
  cin.tie(0);
  ios::sync_with_stdio(false);

  int N, Q;
  cin >> N >> Q;
  vector<vector<int>> adj(N, vector<int>());
  for (int i = 0; i < N - 1; ++i) {
    int A, B;
    cin >> A >> B;
    A--, B--;
    adj[A].push_back(B);
    adj[B].push_back(A);
  }
  LCA lca(adj);
  HL hl(adj);
  RMQ rmq1, rmq2;

  string S;
  cin >> S;
  queue<pair<int, int>> que;
  for (int i = 0; i < N; ++i)
    if (S[i] == '1') que.push({i, 0});
  vector<int> cost(N, N);
  while (!que.empty()) {
    pair<int, int> p = que.front();
    que.pop();
    if (cost[p.first] <= p.second) continue;
    cost[p.first] = p.second;
    for (const auto &v : adj[p.first]) que.push({v, p.second + 1});
  }
  for (int i = 0; i < N; ++i) {
    rmq1.update(hl.vertexID[i], cost[i] * 3 - hl.depth[i]);
    rmq2.update(hl.vertexID[i], cost[i] * 3 + hl.depth[i]);
  }

  while (Q--) {
    int s, t;
    cin >> s >> t;
    s--, t--;
    int d = lca.getLength(s, t);
    int l = lca.getLCA(s, t);
    ll ans = d * 2;

    ll min1 = 1e9;
    hl.set(s, l);
    while (hl.hasNext()) {
      pair<int, int> p = hl.next();
      min1 = min(min1, rmq1.query(p.first, p.second));
    }
    ans = min(ans, d + hl.depth[s] + min1);

    ll min2 = 1e9;
    hl.set(t, l);
    while (hl.hasNext()) {
      pair<int, int> p = hl.next();
      min2 = min(min2, rmq2.query(p.first, p.second));
    }
    ans = min(ans, d + hl.depth[s] - hl.depth[l] * 2 + min2);

    cout << ans << endl;
  }
}