Codeforces Round #359 Div2 D. Kay and Snowflake
問題
頂点数 N の木が与えられる。頂点 v を根とする部分木を考えた時、その部分木に含まれる頂点の u うち、u を消去してできる残りの部分木が、いずれも元の部分木の半分以下のサイズになるような u を「v の centroid である」とする。centroid を聞く Q このクエリに答えよ。
解法
各頂点 v について、
- 子要素の数 (c とする)
- v の下流の部分木のうち最大のもののサイズ (m とする)
を求めておく。
すると、頂点 v が centroid と成り得る部分木は、サイズが (c+1)*2以上、m*2以下のものである事がわかる。あとは筋肉。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.NoSuchElementException; public class D { private static class Task { ArrayList<Integer>[] graph; ArrayDeque<Integer>[] deque; int N; int[] children;// children[v] := v の子要素の合計数 int[] centroid;// centroid[v] := v の centroid int[] maxPart;// maxPart[v] := v を消去した時、v の子の部分集合の中で最大のもの // v の子について色々求める DFS int dfs(int v) { for (int u : graph[v]) { int d = dfs(u); children[v] += d; maxPart[v] = Math.max(maxPart[v], d); } if (children[v] == 0 || maxPart[v] * 2 <= children[v] + 1) centroid[v] = v; return children[v] + 1; } // centroid を求める DFS void centroidDFS(int v) { int vSubtree = children[v] + 1;// v を根とする subtree のサイズ int from = maxPart[v] * 2;// v が centroid と成り得る subtree のサイズの下限 int to = vSubtree * 2;// v が centroid と成り得る subtree のサイズの上限 // v が centroid と成り得る subtree を処理 for (int i = from; i <= Math.min(to, N); i++) while (!deque[i].isEmpty()) { int u = deque[i].poll(); centroid[u] = v; } if (centroid[v] < 0) deque[vSubtree].add(v);// v の centroid が決まっていなければキューに入れておく for (int u : graph[v]) centroidDFS(u); } void solve(FastScanner in, PrintWriter out) throws Exception { N = in.nextInt(); int Q = in.nextInt(); graph = new ArrayList[N]; for (int i = 0; i < N; i++) graph[i] = new ArrayList<>(); for (int i = 1; i <= N - 1; i++) { int p = in.nextInt() - 1; graph[p].add(i); } children = new int[N]; centroid = new int[N]; Arrays.fill(centroid, -1); maxPart = new int[N]; deque = new ArrayDeque[N + 1]; for (int i = 0; i < N + 1; i++) deque[i] = new ArrayDeque<>(); dfs(0); centroidDFS(0); for (int q = 0; q < Q; q++) { int v = in.nextInt() - 1; out.println(centroid[v] + 1); } } } // Template public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) { array[i] = nextInt(); } return array; } } }