CodeChef Snackdown 2016 : Online Elimination Round - Alliances
問題
頂点数 N のツリーが与えられる。
K 個のギャングがいて、各ギャングはツリー上のある頂点集合を支配している。
クエリが Q 個くる。各クエリはある頂点 v と、ギャングのリストが含まれている。
リスト中のギャングが支配している頂点および、それらの頂点を結ぶパス上にある頂点を「支配されている頂点」とするとき、v から最も近い「支配されている頂点」への距離を求めよ。
解法
ALLIANCE - Editorial - CodeChef Discuss
支配されている頂点集合のLCAを、u とする。
v が u を根とするサブツリーに含まれていなければ、dist(v, u) が答えになる。
v が u を根とするサブツリーに含まれるとき、v 自身が支配されていれば答えは 0 である。そうでなければ、「v を含み、かつ、支配されている頂点を含まない最大のサブツリー」を探す。そのサブツリーの根を w とすると、w の親は支配されているので、dist(v, w)+1 が答えになる。
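「v を含み、かつ、支配されている頂点を含まない最大のサブツリー」は、オイラーツアーと BIT（Fenwick 木）で求められる。オイラーツアーで各サブツリーを連続した区間に対応させると、サブツリー内にある「リスト中のギャングが支配している頂点」の個数を BIT の区間和で数えられる。「サブツリーにそのような頂点を含まない」という性質は v から根に向かって一度だけ真から偽に変わる（単調な）ので、ダブリング（binary lifting）でこの条件を満たす最も上の祖先 w を求めればよい。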
コード
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.NoSuchElementException;

/**
 * CodeChef SnackDown 2016 Online Elimination Round, problem ALLIANCE.
 *
 * <p>Input: a tree with N vertices, K gangs (each a list of vertices) and Q queries (a vertex v
 * plus a list of gangs). For each query, prints the distance from v to the nearest vertex that is
 * controlled by a listed gang or lies on a path between such vertices.
 *
 * <p>Queries are bucketed by gang. For each gang, its cities are loaded into a Fenwick tree indexed
 * by Euler-tour position. Every query mentioning that gang takes the minimum with the distance that
 * gang contributes, and the Fenwick tree is then cleared again.
 */
public class ALLIANCE {
    private static class Task {
        /**
         * Reads the tree, the gangs and the queries from {@code in} and prints one answer per
         * query, in input order, to {@code out}. Vertex and gang ids in the input are 1-based.
         */
        void solve(FastScanner in, PrintWriter out) {
            int N = in.nextInt();
            ArrayList<ArrayList<Integer>> adj = new ArrayList<>(N);
            for (int i = 0; i < N; i++) {
                adj.add(new ArrayList<>());
            }
            for (int i = 0; i < N - 1; i++) {
                int u = in.nextInt() - 1;
                int v = in.nextInt() - 1;
                adj.get(u).add(v);
                adj.get(v).add(u);
            }
            LCA lca = new LCA(adj);

            // gangs[g] = cities of gang g; gangGroupLCA[g] = LCA of those cities (-1 if none).
            int gangGroupNum = in.nextInt();
            int[][] gangs = new int[gangGroupNum][];
            int[] gangGroupLCA = new int[gangGroupNum];
            for (int gang = 0; gang < gangGroupNum; gang++) {
                int cityNum = in.nextInt();
                gangs[gang] = new int[cityNum];
                gangGroupLCA[gang] = -1; // getLCA(-1, x) == x, so -1 means "no city yet"
                for (int j = 0; j < cityNum; j++) {
                    int city = in.nextInt() - 1;
                    gangs[gang][j] = city;
                    gangGroupLCA[gang] = lca.getLCA(gangGroupLCA[gang], city);
                }
            }

            int Q = in.nextInt();
            int[] queryCity = new int[Q];
            int[] queryGangLCA = new int[Q]; // LCA over the gang LCAs listed in the query
            int[] queryAnswer = new int[Q];
            // Queries are bucketed by gang id, so there must be one bucket per gang
            // (the number of gangs is independent of the number of vertices).
            ArrayList<ArrayList<Integer>> queryGangs = new ArrayList<>(gangGroupNum);
            for (int gang = 0; gang < gangGroupNum; gang++) {
                queryGangs.add(new ArrayList<>());
            }
            for (int q = 0; q < Q; q++) {
                queryCity[q] = in.nextInt() - 1;
                int queryGangNum = in.nextInt();
                queryGangLCA[q] = -1;
                for (int j = 0; j < queryGangNum; j++) {
                    int gang = in.nextInt() - 1;
                    queryGangs.get(gang).add(q);
                    queryGangLCA[q] = lca.getLCA(queryGangLCA[q], gangGroupLCA[gang]);
                }
                queryAnswer[q] = Integer.MAX_VALUE / 2;
            }

            FenwickTree bit = new FenwickTree(lca.eulerTourCnt);
            for (int gang = 0; gang < gangGroupNum; gang++) {
                if (gangGroupLCA[gang] < 0) {
                    // A gang without cities controls nothing and cannot shorten any answer.
                    continue;
                }
                addGangCities(bit, lca, gangs[gang], 1);
                for (int query : queryGangs.get(gang)) {
                    int answer = distanceViaGang(
                            lca, bit, queryCity[query], gangGroupLCA[gang], queryGangLCA[query]);
                    queryAnswer[query] = Math.min(queryAnswer[query], answer);
                }
                addGangCities(bit, lca, gangs[gang], -1); // reset for the next gang
            }

            for (int i = 0; i < Q; i++) {
                out.println(queryAnswer[i]);
            }
        }

        /** Adds {@code delta} at both Euler-tour positions (pre and post) of every city. */
        private static void addGangCities(FenwickTree bit, LCA lca, int[] cities, int delta) {
            for (int city : cities) {
                bit.add(lca.preEuler[city], delta);
                bit.add(lca.postEuler[city], delta);
            }
        }

        /** Returns the Fenwick-tree total over v's subtree interval [preEuler[v], postEuler[v]]. */
        private static long countInSubtree(FenwickTree bit, LCA lca, int v) {
            return bit.sum(lca.preEuler[v], lca.postEuler[v] + 1);
        }

        /**
         * Returns the candidate answer one gang contributes to a query.
         *
         * <p>Exactly this gang's cities must currently be loaded in {@code bit}.
         *
         * @param targetCity the query vertex
         * @param gangLCA LCA of the gang's cities (must be non-negative)
         * @param queryLCA LCA over all gangs of the query; an ancestor of {@code gangLCA}
         * @return distance from {@code targetCity} to the gang's dominated part of the tree
         */
        private static int distanceViaGang(
                LCA lca, FenwickTree bit, int targetCity, int gangLCA, int queryLCA) {
            int l = lca.getLCA(targetCity, gangLCA);
            if (l != gangLCA) {
                // targetCity is outside gangLCA's subtree. Measure the distance to the path
                // gangLCA -> queryLCA: it meets targetCity's root path at l when queryLCA is at
                // or above l, and otherwise ends at queryLCA, which is below l.
                return lca.depth[targetCity] - lca.depth[l]
                        + Math.max(0, lca.depth[queryLCA] - lca.depth[l]);
            }
            if (countInSubtree(bit, lca, targetCity) != 0) {
                // targetCity is inside gangLCA's subtree and has a gang city below it,
                // so it lies on the path from gangLCA to that city.
                return 0;
            }
            // Climb to the highest ancestor w whose subtree holds no gang city. Going up,
            // "subtree is empty" can only switch from true to false, so binary lifting finds w.
            // w's parent is dominated, hence the +1.
            int current = targetCity;
            for (int k = lca.parent.length - 1; k >= 0; --k) {
                int next = lca.parent[k][current];
                if (next < 0) {
                    // Jumped past the root; the root's subtree contains the gang, so it is
                    // never accepted below.
                    next = lca.root;
                }
                if (countInSubtree(bit, lca, next) == 0) {
                    current = next;
                }
            }
            return lca.getLength(current, targetCity) + 1;
        }
    }

    /**
     * Fenwick tree (binary indexed tree) over positions {@code 0 .. size-1} supporting point
     * updates and prefix sums in O(log size). One spare slot is allocated.
     */
    static class FenwickTree {
        /** Number of slots in {@link #data}, i.e. the constructor argument + 1. */
        int N;
        long[] data;

        FenwickTree(int N) {
            this.N = N + 1;
            data = new long[N + 1];
        }

        /** Adds {@code val} at position {@code k}. */
        void add(int k, long val) {
            for (int x = k; x < N; x |= x + 1) {
                data[x] += val;
            }
        }

        /** Returns the sum over positions [0, k); a {@code k >= N} is clamped to {@code N - 1}. */
        long sum(int k) {
            if (k >= N) {
                k = N - 1;
            }
            long ret = 0; // accumulate in long, matching data, so large totals are not truncated
            for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) {
                ret += data[x];
            }
            return ret;
        }

        /** Returns the sum over positions [l, r). */
        long sum(int l, int r) {
            return sum(r) - sum(l);
        }

        /** Returns the value stored at position {@code k}. */
        long get(int k) {
            assert (0 <= k && k < N);
            return sum(k + 1) - sum(k);
        }
    }

    /**
     * Rooted-tree helper built from an adjacency list and rooted at vertex 0.
     *
     * <p>It provides depths, a binary-lifting ancestor table, and pre/post Euler-tour indices. The
     * indices satisfy: every vertex in the subtree of v has both of its indices inside
     * [preEuler[v], postEuler[v]].
     */
    static class LCA {
        ArrayList<ArrayList<Integer>> G;
        /** parent[k][v] is the 2^k-th ancestor of v, or -1 if it does not exist. */
        int[][] parent;
        int[] depth;
        int root, logV;
        int[] preEuler, postEuler;
        /** Number of Euler-tour indices handed out (two per reached vertex). */
        int eulerTourCnt = 0;

        /**
         * Iterative DFS from {@code root} filling parent[0], depth, preEuler and postEuler.
         *
         * <p>Assumes G is a tree; vertices not reachable from root keep depth -1.
         */
        void build(int root) {
            Arrays.fill(depth, -1);
            ArrayDeque<Integer> stack = new ArrayDeque<>();
            stack.addFirst(root);
            parent[0][root] = -1;
            depth[root] = 0;
            while (!stack.isEmpty()) {
                int v = stack.peekFirst();
                for (int u : G.get(v)) {
                    if (depth[u] >= 0) {
                        continue; // already discovered (v's parent, or a sibling pushed earlier)
                    }
                    parent[0][u] = v;
                    depth[u] = depth[v] + 1;
                    stack.addFirst(u);
                }
                if (stack.peekFirst().intValue() == v) {
                    // Nothing new was pushed: v's subtree is complete.
                    stack.pollFirst();
                    if (preEuler[v] < 0) {
                        preEuler[v] = eulerTourCnt++; // leaf: entered and left back to back
                    }
                    postEuler[v] = eulerTourCnt++;
                } else {
                    // First visit of an inner vertex: its children are now on the stack.
                    preEuler[v] = eulerTourCnt++;
                }
            }
        }

        LCA(final ArrayList<ArrayList<Integer>> adj) {
            int V = adj.size();
            root = 0;
            G = adj;
            depth = new int[V];
            preEuler = new int[V];
            Arrays.fill(preEuler, -1);
            postEuler = new int[V];
            logV = 1;
            for (int i = 1; i <= V; ) {
                i *= 2;
                logV++;
            }
            parent = new int[logV][V];
            build(root);
            for (int k = 0; k + 1 < logV; ++k) {
                for (int v = 0; v < V; ++v) {
                    parent[k + 1][v] = parent[k][v] < 0 ? -1 : parent[k][parent[k][v]];
                }
            }
        }

        /**
         * Returns the lowest common ancestor of {@code u} and {@code v}.
         *
         * <p>A negative argument means "absent", and the other argument is returned unchanged.
         */
        int getLCA(int u, int v) {
            if (u < 0) {
                return v;
            }
            if (v < 0) {
                return u;
            }
            if (depth[u] > depth[v]) {
                int tu = u;
                u = v;
                v = tu;
            }
            int diff = depth[v] - depth[u];
            for (int k = 0; k < logV; ++k) {
                if ((diff >> k & 1) != 0) {
                    v = parent[k][v];
                }
            }
            if (u == v) {
                return u;
            }
            for (int k = logV - 1; k >= 0; --k) {
                if (parent[k][u] != parent[k][v]) {
                    u = parent[k][u];
                    v = parent[k][v];
                }
            }
            return parent[0][u];
        }

        /** Returns the number of edges on the tree path between {@code u} and {@code v}. */
        int getLength(int u, int v) {
            int lca = getLCA(u, v);
            return depth[u] + depth[v] - depth[lca] * 2;
        }
    }

    /** Reads the problem from standard input and writes the answers to standard output. */
    public static void main(String[] args) throws InterruptedException {
        Task solver = new Task();
        OutputStream outputStream = System.out;
        FastScanner in = new FastScanner();
        PrintWriter out = new PrintWriter(outputStream);
        solver.solve(in, out);
        out.close();
    }

    /** Buffered whitespace-separated token reader over {@code System.in} (ASCII input). */
    private static class FastScanner {
        private final InputStream in = System.in;
        private final byte[] buffer = new byte[1024];
        private int ptr = 0;
        private int bufferLength = 0;

        /** Ensures an unread byte is buffered, refilling when needed; false at end of input. */
        private boolean hasNextByte() {
            if (ptr < bufferLength) {
                return true;
            }
            ptr = 0;
            try {
                bufferLength = in.read(buffer);
            } catch (IOException e) {
                // Propagate: if swallowed, bufferLength would keep its old value and the stale
                // buffer contents would be parsed again as fresh input.
                throw new UncheckedIOException(e);
            }
            return bufferLength > 0;
        }

        private int readByte() {
            return hasNextByte() ? buffer[ptr++] : -1;
        }

        private static boolean isPrintableChar(int c) {
            return 33 <= c && c <= 126;
        }

        private void skipUnprintable() {
            while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
                ptr++;
            }
        }

        boolean hasNext() {
            skipUnprintable();
            return hasNextByte();
        }

        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            StringBuilder sb = new StringBuilder();
            int b = readByte();
            while (isPrintableChar(b)) {
                sb.appendCodePoint(b);
                b = readByte();
            }
            return sb.toString();
        }

        long nextLong() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            long n = 0;
            boolean minus = false;
            int b = readByte();
            if (b == '-') {
                minus = true;
                b = readByte();
            }
            if (b < '0' || '9' < b) {
                throw new NumberFormatException();
            }
            while (true) {
                if ('0' <= b && b <= '9') {
                    n *= 10;
                    n += b - '0';
                } else if (b == -1 || !isPrintableChar(b)) {
                    return minus ? -n : n;
                } else {
                    throw new NumberFormatException();
                }
                b = readByte();
            }
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        double[] nextDoubleArray(int n) {
            double[] array = new double[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextDouble();
            }
            return array;
        }

        double[][] nextDoubleMap(int n, int m) {
            double[][] map = new double[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextDoubleArray(m);
            }
            return map;
        }

        public int nextInt() {
            return (int) nextLong();
        }

        public int[] nextIntArray(int n) {
            int[] array = new int[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextInt();
            }
            return array;
        }
    }
}