Codeforces Round #364 Div2 E. Connecting Universities
解法
ある辺を見た時、ペアの両端がそれぞれ辺の反対側にある方が良い。この考えに基づいて、頂点 v と p を結ぶ辺を見た時に、v 側に d 個の大学、p 側に u 個の大学があるとすると、各ペアの片方の大学は d 個の中に、もう片方の大学は u この中に、出来るだけ含まれている方が良い。よって辺 v-p を見た時、min(d, u) 個のペアの両端が、それぞれ v 側と p 側にあるようにする。すると辺 v-p を通るケーブルは min(d, u) 本である。これをすべての辺について見てやれば良い。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.ArrayList; import java.util.NoSuchElementException; public class E { private static class Task { ArrayList<Integer>[] graph; boolean[] univ; long ans = 0; int K; int dfs(int v, int p) { int down = 0;// v と p を結ぶ辺について、v 側にある univ の数 if (univ[v]) down++; for (int u : graph[v]) if (u != p) down += dfs(u, v); int up = 2 * K - down;// v と p を結ぶ辺について、p 側にある univ の数 if (p >= 0) ans += Math.min(up, down);//min(up, down) 本のケーブルが v と p を結ぶ辺を通る return down; } void solve(FastScanner in, PrintWriter out) { int N = in.nextInt(); K = in.nextInt(); univ = new boolean[N]; for (int i = 0; i < 2 * K; i++) { int x = in.nextInt() - 1; univ[x] = true; } graph = new ArrayList[N]; for (int i = 0; i < N; i++) graph[i] = new ArrayList<>(); for (int i = 0; i < N - 1; i++) { int x = in.nextInt() - 1; int y = in.nextInt() - 1; graph[x].add(y); graph[y].add(x); } dfs(0, -1); out.println(ans); } } // Template public static void main(String[] args) { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) { array[i] = nextInt(); } return array; } } }