CS Academy Round #29: Root Change
問題
https://csacademy.com/contest/round-29/task/root-change/
ツリーの各頂点 i について、 i を根とした時に切断してもツリーの高さが変化しない辺の数を出力してください。
解法
editorial のコード通りです。いわゆる全方位木dpと呼ばれる手法です。
切断しても高さが変わらない辺の数は、切断すると高さが変わる辺の数を求めれば良いので、後者を求めることにします。
まず、頂点0を根とした時の、切断すると高さが変わる辺の数を求めます。0 を根とした時の頂点 v の親を p とすると、この操作によって「v を根とした時の、切断すると高さの変わる辺の数(p 方向以外)」が求まります。これを down[v] とします。
あとはこの図からエスパーしていただければ幸いです。
コード
import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; import java.util.ArrayList; import java.util.NoSuchElementException; public class Main { class Height { int height, unremovableEdgeNum; Height(int height, int unremovableEdgeNum) { this.height = height; this.unremovableEdgeNum = unremovableEdgeNum; } void merge(Height h) { if (h.height > height) { height = h.height; unremovableEdgeNum = h.unremovableEdgeNum; } else if (h.height == height) { unremovableEdgeNum = 0; } } Height goUp() { return new Height(height + 1, unremovableEdgeNum + 1); } } private ArrayList<Integer>[] tree; private int[] ans; private Height[] down, up; // down[u] := 0 を根としたツリー上で葉の方向に見た時の、高さと、切ったら高さが変わる辺の数 // up[u] := 0 を根としたツリー上で根の方向に見た時の、高さと、切ったら高さが変わる辺の数 private void solve(FastScanner in, PrintWriter out) { int N = in.nextInt(); tree = new ArrayList[N]; for (int i = 0; i < N; i++) { tree[i] = new ArrayList<>(); } for (int i = 0; i < N - 1; i++) { int a = in.nextInt() - 1; int b = in.nextInt() - 1; tree[a].add(b); tree[b].add(a); } ans = new int[N]; down = new Height[N]; for (int i = 0; i < N; i++) { down[i] = new Height(0, 0); } up = new Height[N]; for (int i = 0; i < N; i++) { up[i] = new Height(0, 0); } downDfs(0, -1); dfs(0, -1); for (int a : ans) { out.println(N - 1 - a); } } private void downDfs(int v, int p) { down[v] = new Height(0, 0); for (int u : tree[v]) { if (u == p) { continue; } downDfs(u, v); down[v].merge(down[u].goUp()); } } private void dfs(int v, int p) { Height suffix = new Height(0, 0); for (int i = tree[v].size() - 1; i >= 0; i--) { int u = tree[v].get(i); if (u == p) { continue; } up[u] = new Height(suffix.height, suffix.unremovableEdgeNum); //頂点vから頂点uの方向に見た時の情報をマージしておく suffix.merge(down[u].goUp()); } //この時点で up[u] には、頂点vを介して suffix に含まれる頂点の方向に見た時の、切手はいけない辺の数が入っている Height prefix = new Height(0, 0); for (int u : tree[v]) { if (u == p) { continue; } // v->p Height h = new Height(up[v].height, up[v].unremovableEdgeNum); // v->prefix h.merge(prefix); // v->suffix h.merge(up[u]); //u->v->... up[u] = h.goUp(); dfs(u, v); //頂点v から頂点uの方向に見た時の情報をマージしておく // v->u prefix.merge(down[u].goUp()); } //v->p prefix.merge(up[v]); ans[v] = prefix.unremovableEdgeNum; } public static void main(String[] args) { PrintWriter out = new PrintWriter(System.out); new Main().solve(new FastScanner(), out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) { return buffer[ptr++]; } else { return -1; } } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) { ptr++; } } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) { throw new NoSuchElementException(); } StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } public int loadChar(char[] buf) { if (!hasNext()) { throw new NoSuchElementException(); } int pos = 0; int b = readByte(); while (isPrintableChar(b)) { buf[pos] = (char) b; b = readByte(); pos++; } return pos; } long nextLong() { if (!hasNext()) { throw new NoSuchElementException(); } long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) { array[i] = nextInt(); } return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) { array[i] = nextLong(); } return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) { array[i] = next(); } return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) { array[i] = next().toCharArray(); } return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }