CS Academy Round #29: Root Change

問題

https://csacademy.com/contest/round-29/task/root-change/

ツリーの各頂点 i について、 i を根とした時に切断してもツリーの高さが変化しない辺の数を出力してください。

解法

editorial のコード通りです。いわゆる全方位木dpと呼ばれる手法です。
切断しても高さが変わらない辺の数は、全辺数 N-1 から「切断すると高さが変わる辺の数」を引けば求まるので、後者を求めることにします。

まず、頂点 0 を根として DFS を行います。0 を根とした時の頂点 v の親を p とすると、この DFS によって「v を根とした時の、切断すると高さの変わる辺の数(p 方向以外)」が求まります。これを down[v] とします。


f:id:kenkoooo:20170524125244j:plain

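あとは、親方向の情報 up[v] を根から順に求めます。v の子 u に対する up[u] は、up[v] と、u 以外の子の down を prefix・suffix で合成したものに、辺 u-v を 1 本足したものです。詳しくはこの図を参照してください。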

コード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.NoSuchElementException;

public class Main {

  /**
   * What can be seen from a vertex in one or more directions. {@code height} is the depth of the
   * deepest vertex. {@code unremovableEdgeNum} is the number of edges that lie on every deepest
   * path, i.e. the edges whose removal would lower the height.
   */
  static class Height {

    int height, unremovableEdgeNum;

    Height(int height, int unremovableEdgeNum) {
      this.height = height;
      this.unremovableEdgeNum = unremovableEdgeNum;
    }

    /**
     * Folds another direction into this summary. The deeper direction wins. On a tie, the deepest
     * paths leave through different directions, so no edge is common to all of them and the
     * count becomes 0.
     */
    void merge(Height h) {
      if (h.height > height) {
        height = h.height;
        unremovableEdgeNum = h.unremovableEdgeNum;
      } else if (h.height == height) {
        unremovableEdgeNum = 0;
      }
    }

    /**
     * Returns this summary as seen from one edge further away. The height grows by one, and the
     * new edge lies on every deepest path.
     */
    Height goUp() {
      return new Height(height + 1, unremovableEdgeNum + 1);
    }
  }

  private ArrayList<Integer>[] tree;
  // ans[v]: number of edges whose removal changes the height when the tree is rooted at v.
  private int[] ans;
  // down[u]: summary of u's subtree, with the tree rooted at vertex 0 (looking toward the leaves).
  // up[u]: summary seen from u through the edge to its parent, with the tree rooted at vertex 0.
  private Height[] down, up;

  /**
   * Reads a tree with N vertices and 1-indexed edges. For every vertex, prints how many edges
   * can be cut without changing the height when that vertex is the root.
   */
  @SuppressWarnings("unchecked")
  private void solve(FastScanner in, PrintWriter out) {
    int N = in.nextInt();
    tree = new ArrayList[N];
    for (int i = 0; i < N; i++) {
      tree[i] = new ArrayList<>();
    }

    for (int i = 0; i < N - 1; i++) {
      int a = in.nextInt() - 1;
      int b = in.nextInt() - 1;
      tree[a].add(b);
      tree[b].add(a);
    }

    ans = new int[N];
    down = new Height[N];
    for (int i = 0; i < N; i++) {
      down[i] = new Height(0, 0);
    }
    up = new Height[N];
    for (int i = 0; i < N; i++) {
      up[i] = new Height(0, 0);
    }

    downDfs(0, -1);
    dfs(0, -1);
    for (int a : ans) {
      // Edges that may be cut freely = all N - 1 edges minus those that change the height.
      out.println(N - 1 - a);
    }
  }

  /** Fills down[] for the subtree of v; p is v's parent (-1 at the root). */
  private void downDfs(int v, int p) {
    down[v] = new Height(0, 0);
    for (int u : tree[v]) {
      if (u == p) {
        continue;
      }
      downDfs(u, v);
      down[v].merge(down[u].goUp());
    }
  }

  /**
   * Fills up[u] for every child u of v, computes ans[v], and recurses into the children.
   * Requires down[] to be complete and up[v] to be already set.
   */
  private void dfs(int v, int p) {
    // Pass 1 (right to left): temporarily store in up[u] the merge of the children that come
    // after u in tree[v] (the suffix), as seen from v.
    Height suffix = new Height(0, 0);
    for (int i = tree[v].size() - 1; i >= 0; i--) {
      int u = tree[v].get(i);
      if (u == p) {
        continue;
      }
      up[u] = new Height(suffix.height, suffix.unremovableEdgeNum);

      // Merge in what v sees in the direction of u.
      suffix.merge(down[u].goUp());
    }

    // Pass 2 (left to right): v's view excluding u is the parent side, plus the children before
    // u (prefix), plus the children after u (suffix); then extend it by the edge u-v.
    Height prefix = new Height(0, 0);
    for (int u : tree[v]) {
      if (u == p) {
        continue;
      }

      // v -> p
      Height h = new Height(up[v].height, up[v].unremovableEdgeNum);
      // v -> prefix
      h.merge(prefix);
      // v -> suffix
      h.merge(up[u]);

      // u -> v -> ...
      up[u] = h.goUp();

      dfs(u, v);

      // Merge in what v sees in the direction of u (v -> u).
      prefix.merge(down[u].goUp());
    }

    // prefix now covers every child; adding the parent side gives the whole view from v.
    prefix.merge(up[v]);
    ans[v] = prefix.unremovableEdgeNum;
  }

  public static void main(String[] args) {
    // downDfs/dfs recurse once per tree level, so a path-shaped tree can exceed the default
    // thread stack. Run the solver on a thread with an explicitly large stack instead.
    final Throwable[] failure = new Throwable[1];
    Thread worker = new Thread(null, () -> {
      PrintWriter out = new PrintWriter(System.out);
      try {
        new Main().solve(new FastScanner(), out);
      } catch (RuntimeException | Error e) {
        failure[0] = e;
      } finally {
        out.close();
      }
    }, "solver", 1L << 26);
    worker.start();
    try {
      worker.join();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return;
    }
    // Re-raise anything the solver threw, so failures are not hidden inside the worker thread.
    if (failure[0] instanceof RuntimeException) {
      throw (RuntimeException) failure[0];
    }
    if (failure[0] != null) {
      throw (Error) failure[0];
    }
  }

  /** Minimal byte-buffered reader for whitespace-separated integers on standard input. */
  private static class FastScanner {

    private final InputStream in = System.in;
    private final byte[] buffer = new byte[1024];
    private int ptr = 0;
    private int bufferLength = 0;

    /** Ensures at least one unread byte is buffered; returns false at end of input. */
    private boolean hasNextByte() {
      if (ptr < bufferLength) {
        return true;
      }
      ptr = 0;
      try {
        bufferLength = in.read(buffer);
      } catch (IOException e) {
        // Surface read failures instead of silently re-reading stale buffer contents.
        throw new java.io.UncheckedIOException(e);
      }
      return bufferLength > 0;
    }

    /** Returns the next byte, or -1 at end of input. */
    private int readByte() {
      if (hasNextByte()) {
        return buffer[ptr++];
      } else {
        return -1;
      }
    }

    /** Visible ASCII characters, excluding the space character. */
    private static boolean isPrintableChar(int c) {
      return 33 <= c && c <= 126;
    }

    private void skipUnprintable() {
      while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
        ptr++;
      }
    }

    boolean hasNext() {
      skipUnprintable();
      return hasNextByte();
    }

    /** Parses the next token as a signed decimal long. */
    long nextLong() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      long n = 0;
      boolean minus = false;
      int b = readByte();
      if (b == '-') {
        minus = true;
        b = readByte();
      }
      if (b < '0' || '9' < b) {
        throw new NumberFormatException();
      }
      while (true) {
        if ('0' <= b && b <= '9') {
          n *= 10;
          n += b - '0';
        } else if (b == -1 || !isPrintableChar(b)) {
          return minus ? -n : n;
        } else {
          throw new NumberFormatException();
        }
        b = readByte();
      }
    }

    public int nextInt() {
      return (int) nextLong();
    }
  }
}

ABC062/ARC074

Rust でやってみました。標準入出力は他の人のを拝借してきたけど、 println マクロは出力ごとに flush しているようなので注意が必要そう。もっときれいに書けるようになりたい。

A - Grouping

http://abc062.contest.atcoder.jp/tasks/abc062_a

/// Reads one line from standard input, including its trailing newline.
/// Read errors are ignored and leave whatever was read so far.
fn next() -> String {
    let mut line = String::new();
    let _ = std::io::stdin().read_line(&mut line);
    line
}

fn main() {
    let line = next();
    let parts: Vec<&str> = line.trim().split(' ').collect();
    let x: i32 = parts[0].parse().unwrap();
    let y: i32 = parts[1].parse().unwrap();

    // Months grouped as in the problem; the answer is Yes when x and y share a group.
    let groups = [vec![1, 3, 5, 7, 8, 10, 12], vec![4, 6, 9, 11], vec![2]];
    let same_group = groups.iter().any(|g| g.contains(&x) && g.contains(&y));
    println!("{}", if same_group { "Yes" } else { "No" });
}

B - Picture Frame

http://abc062.contest.atcoder.jp/tasks/abc062_b

use std::fmt::Debug;
use std::io;
use std::io::{Read, Stdin};
use std::str;
use std::str::FromStr;

fn main() {
    let mut sc = Scanner::new();
    let h = sc.parse::<usize>();
    let w = sc.parse::<usize>();
    // Read the whole picture first, then print it surrounded by a one-character '#' frame.
    let mut rows: Vec<String> = Vec::with_capacity(h);
    for _ in 0..h {
        rows.push(sc.parse::<String>());
    }
    let border: String = std::iter::repeat('#').take(w + 2).collect();
    println!("{}", border);
    for row in &rows {
        println!("#{}#", row);
    }
    println!("{}", border);
}


/// Whitespace-separated token reader over standard input.
struct Scanner {
    stdin: Stdin,
    buf: Vec<u8>,
}

impl Scanner {
    /// Creates a scanner over the process's standard input.
    fn new() -> Scanner {
        Scanner {
            stdin: io::stdin(),
            buf: Vec::with_capacity(256),
        }
    }

    /// Returns true for the bytes that separate tokens.
    /// `\r` and `\t` are included so CRLF and tab-separated input also parse.
    fn is_space(c: u8) -> bool {
        c == b' ' || c == b'\n' || c == b'\r' || c == b'\t'
    }

    /// Reads the next whitespace-separated token and parses it as `T`.
    ///
    /// A final token is accepted even when no newline follows it.
    /// Panics in these cases:
    /// - input ends before a token starts;
    /// - an I/O error occurs;
    /// - the token is not UTF-8;
    /// - the token does not parse as `T`.
    fn parse<T: FromStr>(&mut self) -> T
        where <T as FromStr>::Err: Debug
    {
        self.buf.clear();
        let mut it = self.stdin.lock().bytes();
        let mut c = it.next().unwrap().unwrap();
        while Scanner::is_space(c) {
            c = it.next().unwrap().unwrap();
        }
        loop {
            self.buf.push(c);
            c = match it.next() {
                Some(byte) => byte.unwrap(),
                // End of input terminates the last token instead of panicking.
                None => break,
            };
            if Scanner::is_space(c) {
                break;
            }
        }
        str::from_utf8(&self.buf).unwrap().parse::<T>().unwrap()
    }
}

C - Chocolate Bar

http://abc062.contest.atcoder.jp/tasks/arc074_a

use std::fmt::Debug;
use std::io;
use std::io::{Read, Stdin};
use std::str;
use std::str::FromStr;
use std::cmp;

/// Best achievable (largest piece - smallest piece) when the first cut is horizontal.
///
/// The top piece takes `top_rows` full rows; the rest is split vertically as evenly as
/// possible. The fold starts at `w`, which is the value obtained when all three pieces are
/// horizontal strips.
fn calc(h: i64, w: i64) -> i64 {
    // The vertical split of the lower part does not depend on the loop variable.
    let left = w / 2;
    let right = w - left;
    (1..h + 1).fold(w, |best, top_rows| {
        let top = top_rows * w;
        let rest = h - top_rows;
        let a = left * rest;
        let b = right * rest;
        // For three values, the largest pairwise |difference| equals max - min.
        let spread = cmp::max(cmp::max(top, a), b) - cmp::min(cmp::min(top, a), b);
        cmp::min(best, spread)
    })
}

fn main() {
    let mut sc = Scanner::new();
    let h = sc.parse::<i64>();
    let w = sc.parse::<i64>();
    // A side divisible by 3 allows three equal strips; otherwise try both cut orientations.
    let ans = if h % 3 == 0 || w % 3 == 0 {
        0
    } else {
        cmp::min(calc(h, w), calc(w, h))
    };
    println!("{}", ans);
}


/// Whitespace-separated token reader over standard input.
struct Scanner {
    stdin: Stdin,
    buf: Vec<u8>,
}

impl Scanner {
    /// Creates a scanner over the process's standard input.
    fn new() -> Scanner {
        Scanner {
            stdin: io::stdin(),
            buf: Vec::with_capacity(256),
        }
    }

    /// Returns true for the bytes that separate tokens.
    /// `\r` and `\t` are included so CRLF and tab-separated input also parse.
    fn is_space(c: u8) -> bool {
        c == b' ' || c == b'\n' || c == b'\r' || c == b'\t'
    }

    /// Reads the next whitespace-separated token and parses it as `T`.
    ///
    /// A final token is accepted even when no newline follows it.
    /// Panics in these cases:
    /// - input ends before a token starts;
    /// - an I/O error occurs;
    /// - the token is not UTF-8;
    /// - the token does not parse as `T`.
    fn parse<T: FromStr>(&mut self) -> T
        where <T as FromStr>::Err: Debug
    {
        self.buf.clear();
        let mut it = self.stdin.lock().bytes();
        let mut c = it.next().unwrap().unwrap();
        while Scanner::is_space(c) {
            c = it.next().unwrap().unwrap();
        }
        loop {
            self.buf.push(c);
            c = match it.next() {
                Some(byte) => byte.unwrap(),
                // End of input terminates the last token instead of panicking.
                None => break,
            };
            if Scanner::is_space(c) {
                break;
            }
        }
        str::from_utf8(&self.buf).unwrap().parse::<T>().unwrap()
    }
}

D - 3N Numbers

http://arc074.contest.atcoder.jp/tasks/arc074_b

use std::fmt::Debug;
use std::io;
use std::io::{Read, Stdin};
use std::str;
use std::str::FromStr;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::usize;
use std::cmp;


fn main() {
    let mut sc = Scanner::new();
    let n = sc.parse::<usize>();
    let a: Vec<i64> = (0..3 * n).map(|_| sc.parse::<i64>()).collect();

    // best_front[i]: largest sum of n elements picked from a[0..=i] (filled once i >= n - 1).
    // The heap holds negated values, so its top is the smallest kept element.
    let mut best_front: Vec<i64> = vec![0; 3 * n];
    let mut kept: BinaryHeap<i64> = BinaryHeap::new();
    let mut total: i64 = 0;
    for i in 0..2 * n {
        if kept.len() < n {
            kept.push(-a[i]);
            total += a[i];
        } else {
            let smallest = -kept.pop().unwrap();
            let keep = cmp::max(smallest, a[i]);
            total += keep - smallest;
            kept.push(-keep);
        }
        if kept.len() == n {
            best_front[i] = total;
        }
    }

    // best_back[i]: smallest sum of n elements picked from a[i..] (filled once i <= 2n).
    // Here the heap holds raw values, so its top is the largest kept element.
    let mut best_back: Vec<i64> = vec![0; 3 * n];
    kept.clear();
    total = 0;
    for i in (n..3 * n).rev() {
        if kept.len() < n {
            kept.push(a[i]);
            total += a[i];
        } else {
            let largest = kept.pop().unwrap();
            let keep = cmp::min(largest, a[i]);
            total += keep - largest;
            kept.push(keep);
        }
        if kept.len() == n {
            best_back[i] = total;
        }
    }

    // Try every split point between the front half and the back half.
    let ans = (n..2 * n + 1)
        .map(|i| best_front[i - 1] - best_back[i])
        .fold(-1000000000000000000, cmp::max);
    println!("{}", ans);
}


/// Whitespace-separated token reader over standard input.
struct Scanner {
    stdin: Stdin,
    buf: Vec<u8>,
}

impl Scanner {
    /// Creates a scanner over the process's standard input.
    fn new() -> Scanner {
        Scanner {
            stdin: io::stdin(),
            buf: Vec::with_capacity(256),
        }
    }

    /// Returns true for the bytes that separate tokens.
    /// `\r` and `\t` are included so CRLF and tab-separated input also parse.
    fn is_space(c: u8) -> bool {
        c == b' ' || c == b'\n' || c == b'\r' || c == b'\t'
    }

    /// Reads the next whitespace-separated token and parses it as `T`.
    ///
    /// A final token is accepted even when no newline follows it.
    /// Panics in these cases:
    /// - input ends before a token starts;
    /// - an I/O error occurs;
    /// - the token is not UTF-8;
    /// - the token does not parse as `T`.
    fn parse<T: FromStr>(&mut self) -> T
        where <T as FromStr>::Err: Debug
    {
        self.buf.clear();
        let mut it = self.stdin.lock().bytes();
        let mut c = it.next().unwrap().unwrap();
        while Scanner::is_space(c) {
            c = it.next().unwrap().unwrap();
        }
        loop {
            self.buf.push(c);
            c = match it.next() {
                Some(byte) => byte.unwrap(),
                // End of input terminates the last token instead of panicking.
                None => break,
            };
            if Scanner::is_space(c) {
                break;
            }
        }
        str::from_utf8(&self.buf).unwrap().parse::<T>().unwrap()
    }
}

E - RGB Sequence

http://arc074.contest.atcoder.jp/tasks/arc074_c

use std::fmt::Debug;
use std::io;
use std::io::{Read, Stdin};
use std::str;
use std::str::FromStr;
use std::usize;
use std::cmp;

/// One input condition, stored in `rules[r]` where `r` is the condition's right end.
/// `check` accepts a state only if exactly `x` of the three last-used positions
/// (r, g, b) are >= `left`.
struct Rule {
    // Left end of the checked range.
    left: usize,
    // Required number of last-used positions inside [left, right end].
    x: usize,
}

fn main() {
    let mut sc = Scanner::new();
    let n = sc.parse::<usize>();
    let m = sc.parse::<usize>();

    let mut rules: Vec<Vec<Rule>> = Vec::new();
    for _ in 0..(n + 1) {
        rules.push(Vec::new());
    }
    for _ in 0..m {
        let l = sc.parse::<usize>();
        let r = sc.parse::<usize>();
        let x = sc.parse::<usize>();
        rules[r].push(Rule { left: l, x: x });
    }

    let mut dp = vec![vec![vec![0;(n+1)];(n+1)];(n+1)];
    dp[0][0][0] = 1;
    let modulo = 1000000007;
    for r in 0..n {
        for g in 0..n {
            for b in 0..n {
                let k = cmp::max(cmp::max(r, g), b);
                if check(&rules, k + 1, g, b) {
                    dp[k + 1][g][b] += dp[r][g][b];
                    dp[k + 1][g][b] %= modulo;
                }
                if check(&rules, r, k + 1, b) {
                    dp[r][k + 1][b] += dp[r][g][b];
                    dp[r][k + 1][b] %= modulo;
                }
                if check(&rules, r, g, k + 1) {
                    dp[r][g][k + 1] += dp[r][g][b];
                    dp[r][g][k + 1] %= modulo;
                }
            }
        }
    }

    let mut ans = 0;
    for i in 0..n {
        for j in 0..n {
            ans += dp[i][j][n];
            ans %= modulo;
            ans += dp[i][n][j];
            ans %= modulo;
            ans += dp[n][i][j];
            ans %= modulo;
        }
    }
    println!("{}", ans);
}

/// Returns true if the state (r, g, b) satisfies every condition whose right end is
/// max(r, g, b). A condition holds when exactly `x` of r, g, b are >= its `left`.
fn check(rules: &Vec<Vec<Rule>>, r: usize, g: usize, b: usize) -> bool {
    let newest = cmp::max(r, cmp::max(g, b));
    rules[newest].iter().all(|rule| {
        let inside = [r, g, b].iter().filter(|&&p| p >= rule.left).count();
        inside == rule.x
    })
}

/// Whitespace-separated token reader over standard input.
struct Scanner {
    stdin: Stdin,
    buf: Vec<u8>,
}

impl Scanner {
    /// Creates a scanner over the process's standard input.
    fn new() -> Scanner {
        Scanner {
            stdin: io::stdin(),
            buf: Vec::with_capacity(256),
        }
    }

    /// Returns true for the bytes that separate tokens.
    /// `\r` and `\t` are included so CRLF and tab-separated input also parse.
    fn is_space(c: u8) -> bool {
        c == b' ' || c == b'\n' || c == b'\r' || c == b'\t'
    }

    /// Reads the next whitespace-separated token and parses it as `T`.
    ///
    /// A final token is accepted even when no newline follows it.
    /// Panics in these cases:
    /// - input ends before a token starts;
    /// - an I/O error occurs;
    /// - the token is not UTF-8;
    /// - the token does not parse as `T`.
    fn parse<T: FromStr>(&mut self) -> T
        where <T as FromStr>::Err: Debug
    {
        self.buf.clear();
        let mut it = self.stdin.lock().bytes();
        let mut c = it.next().unwrap().unwrap();
        while Scanner::is_space(c) {
            c = it.next().unwrap().unwrap();
        }
        loop {
            self.buf.push(c);
            c = match it.next() {
                Some(byte) => byte.unwrap(),
                // End of input terminates the last token instead of panicking.
                None => break,
            };
            if Scanner::is_space(c) {
                break;
            }
        }
        str::from_utf8(&self.buf).unwrap().parse::<T>().unwrap()
    }
}

F - Lotus Leaves

http://arc074.contest.atcoder.jp/tasks/arc074_d

use std::fmt::Debug;
use std::io;
use std::io::{Read, Stdin};
use std::str;
use std::str::FromStr;
use std::usize;
use std::cmp;
use std::collections::vec_deque::VecDeque;
use std::i64::MAX;

/// A residual-graph edge. `rev` is the index of the paired reverse edge inside `g[to]`.
pub struct Edge {
    pub to: usize,
    pub rev: usize,
    pub cap: i64,
}

/// Dinic's maximum-flow algorithm over an adjacency-list residual graph.
struct Dinitz {
    g: Vec<Vec<Edge>>,
    level: Vec<i32>,
    iter: Vec<usize>,
}

impl Dinitz {
    /// Creates an empty network with vertices `0..v`.
    fn new(v: usize) -> Dinitz {
        Dinitz {
            g: (0..v).map(|_| Vec::new()).collect(),
            level: vec![0; v],
            iter: vec![0; v],
        }
    }

    /// Adds a directed edge of capacity `cap` plus its zero-capacity reverse edge.
    /// Both reverse indices are read before either push happens.
    fn add_edge(&mut self, from: usize, to: usize, cap: i64) {
        let forward_rev = self.g[to].len();
        let backward_rev = self.g[from].len();
        self.g[from].push(Edge { to: to, rev: forward_rev, cap: cap });
        self.g[to].push(Edge { to: from, rev: backward_rev, cap: 0 });
    }

    /// Pushes at most `f` units along one level-increasing path from `v` to `t`.
    /// Returns the amount pushed (0 if no such path remains).
    fn dfs(&mut self, v: usize, t: usize, f: i64) -> i64 {
        if v == t {
            return f;
        }
        while self.iter[v] < self.g[v].len() {
            let i = self.iter[v];
            let (next, cap, back) = {
                let edge = &self.g[v][i];
                (edge.to, edge.cap, edge.rev)
            };
            if cap > 0 && self.level[v] < self.level[next] {
                // Levels strictly increase along the path, so iter[v] cannot change in the recursion.
                let pushed = self.dfs(next, t, cmp::min(f, cap));
                if pushed > 0 {
                    self.g[v][i].cap -= pushed;
                    self.g[next][back].cap += pushed;
                    return pushed;
                }
            }
            self.iter[v] += 1;
        }
        0
    }

    /// Recomputes BFS distances from `s` over edges with positive residual capacity.
    /// Unreachable vertices get -1.
    fn bfs(&mut self, s: usize) {
        let n = self.level.len();
        self.level = vec![-1; n];
        self.level[s] = 0;
        let mut queue = VecDeque::new();
        queue.push_back(s);
        while let Some(u) = queue.pop_front() {
            for edge in &self.g[u] {
                if edge.cap > 0 && self.level[edge.to] < 0 {
                    self.level[edge.to] = self.level[u] + 1;
                    queue.push_back(edge.to);
                }
            }
        }
    }

    /// Returns the maximum flow from `s` to `t`, leaving the residual graph in `g`.
    fn max_flow(&mut self, s: usize, t: usize) -> i64 {
        let n = self.level.len();
        let mut total: i64 = 0;
        loop {
            self.bfs(s);
            if self.level[t] < 0 {
                return total;
            }
            self.iter = vec![0; n];
            let mut pushed = self.dfs(s, t, MAX);
            while pushed > 0 {
                total += pushed;
                pushed = self.dfs(s, t, MAX);
            }
        }
    }
}

fn main() {
    let mut sc = Scanner::new();
    let h = sc.parse::<usize>();
    let w = sc.parse::<usize>();

    // Vertex layout: rows are 0..h, columns are h..h+w, followed by the source and the sink.
    let source = h + w;
    let sink = h + w + 1;
    let mut dinitz = Dinitz::new(h + w + 2);

    let grid = (0..h).map(|_| sc.parse::<String>()).collect::<Vec<_>>();
    let mut start = (0, 0);
    let mut goal = (0, 0);
    for i in 0..h {
        let cells = grid[i].as_bytes();
        for j in 0..w {
            let col = h + j;
            match cells[j] {
                b'S' => {
                    start = (i, j);
                    dinitz.add_edge(source, i, MAX);
                    dinitz.add_edge(source, col, MAX);
                }
                b'T' => {
                    goal = (i, j);
                    dinitz.add_edge(i, sink, MAX);
                    dinitz.add_edge(col, sink, MAX);
                }
                b'o' => {
                    // A leaf links its row and its column with unit capacity in both directions.
                    dinitz.add_edge(i, col, 1);
                    dinitz.add_edge(col, i, 1);
                }
                _ => {}
            }
        }
    }

    // S and T in the same row or column: print -1 instead of running the flow.
    if start.0 == goal.0 || start.1 == goal.1 {
        println!("-1");
        return;
    }
    println!("{}", dinitz.max_flow(source, sink));
}

/// Whitespace-separated token reader over standard input.
struct Scanner {
    stdin: Stdin,
    buf: Vec<u8>,
}

impl Scanner {
    /// Creates a scanner over the process's standard input.
    fn new() -> Scanner {
        Scanner {
            stdin: io::stdin(),
            buf: Vec::with_capacity(256),
        }
    }

    /// Returns true for the bytes that separate tokens.
    /// `\r` and `\t` are included so CRLF and tab-separated input also parse.
    fn is_space(c: u8) -> bool {
        c == b' ' || c == b'\n' || c == b'\r' || c == b'\t'
    }

    /// Reads the next whitespace-separated token and parses it as `T`.
    ///
    /// A final token is accepted even when no newline follows it.
    /// Panics in these cases:
    /// - input ends before a token starts;
    /// - an I/O error occurs;
    /// - the token is not UTF-8;
    /// - the token does not parse as `T`.
    fn parse<T: FromStr>(&mut self) -> T
        where <T as FromStr>::Err: Debug
    {
        self.buf.clear();
        let mut it = self.stdin.lock().bytes();
        let mut c = it.next().unwrap().unwrap();
        while Scanner::is_space(c) {
            c = it.next().unwrap().unwrap();
        }
        loop {
            self.buf.push(c);
            c = match it.next() {
                Some(byte) => byte.unwrap(),
                // End of input terminates the last token instead of panicking.
                None => break,
            };
            if Scanner::is_space(c) {
                break;
            }
        }
        str::from_utf8(&self.buf).unwrap().parse::<T>().unwrap()
    }
}

AtCoder Beginner Contest 010 D - 浮気予防

問題

D: 浮気予防 - AtCoder Beginner Contest 010 | AtCoder

Rust で Dinitz を実装した

コード

use std::fmt::Debug;
use std::io;
use std::io::{Read, Stdin};
use std::str;
use std::str::FromStr;
use std::usize;
use std::cmp;
use std::collections::vec_deque::VecDeque;
use std::i64::MAX;

/// A residual-graph edge. `rev` is the index of the paired reverse edge inside `g[to]`.
pub struct Edge {
    pub to: usize,
    pub rev: usize,
    pub cap: i64,
}

/// Dinic's maximum-flow algorithm over an adjacency-list residual graph.
struct Dinitz {
    g: Vec<Vec<Edge>>,
    level: Vec<i32>,
    iter: Vec<usize>,
}

impl Dinitz {
    /// Creates an empty network with vertices `0..v`.
    fn new(v: usize) -> Dinitz {
        Dinitz {
            g: (0..v).map(|_| Vec::new()).collect(),
            level: vec![0; v],
            iter: vec![0; v],
        }
    }

    /// Adds a directed edge of capacity `cap` plus its zero-capacity reverse edge.
    /// Both reverse indices are read before either push happens.
    fn add_edge(&mut self, from: usize, to: usize, cap: i64) {
        let forward_rev = self.g[to].len();
        let backward_rev = self.g[from].len();
        self.g[from].push(Edge { to: to, rev: forward_rev, cap: cap });
        self.g[to].push(Edge { to: from, rev: backward_rev, cap: 0 });
    }

    /// Pushes at most `f` units along one level-increasing path from `v` to `t`.
    /// Returns the amount pushed (0 if no such path remains).
    fn dfs(&mut self, v: usize, t: usize, f: i64) -> i64 {
        if v == t {
            return f;
        }
        while self.iter[v] < self.g[v].len() {
            let i = self.iter[v];
            let (next, cap, back) = {
                let edge = &self.g[v][i];
                (edge.to, edge.cap, edge.rev)
            };
            if cap > 0 && self.level[v] < self.level[next] {
                // Levels strictly increase along the path, so iter[v] cannot change in the recursion.
                let pushed = self.dfs(next, t, cmp::min(f, cap));
                if pushed > 0 {
                    self.g[v][i].cap -= pushed;
                    self.g[next][back].cap += pushed;
                    return pushed;
                }
            }
            self.iter[v] += 1;
        }
        0
    }

    /// Recomputes BFS distances from `s` over edges with positive residual capacity.
    /// Unreachable vertices get -1.
    fn bfs(&mut self, s: usize) {
        let n = self.level.len();
        self.level = vec![-1; n];
        self.level[s] = 0;
        let mut queue = VecDeque::new();
        queue.push_back(s);
        while let Some(u) = queue.pop_front() {
            for edge in &self.g[u] {
                if edge.cap > 0 && self.level[edge.to] < 0 {
                    self.level[edge.to] = self.level[u] + 1;
                    queue.push_back(edge.to);
                }
            }
        }
    }

    /// Returns the maximum flow from `s` to `t`, leaving the residual graph in `g`.
    fn max_flow(&mut self, s: usize, t: usize) -> i64 {
        let n = self.level.len();
        let mut total: i64 = 0;
        loop {
            self.bfs(s);
            if self.level[t] < 0 {
                return total;
            }
            self.iter = vec![0; n];
            let mut pushed = self.dfs(s, t, MAX);
            while pushed > 0 {
                total += pushed;
                pushed = self.dfs(s, t, MAX);
            }
        }
    }
}

fn main() {
    let mut sc = Scanner::new();
    let n = sc.parse::<usize>();
    let g = sc.parse::<usize>();
    let e = sc.parse::<usize>();

    // Vertices 0..n come from the input; the extra vertex n is a sink.
    // Every vertex read in the first loop below gets a unit edge to it.
    let sink = n;
    let mut dinitz = Dinitz::new(n + 1);
    for _ in 0..g {
        let marked = sc.parse::<usize>();
        dinitz.add_edge(marked, sink, 1);
    }
    for _ in 0..e {
        let a = sc.parse::<usize>();
        let b = sc.parse::<usize>();
        // Undirected edge: one unit of capacity in each direction.
        dinitz.add_edge(a, b, 1);
        dinitz.add_edge(b, a, 1);
    }
    println!("{}", dinitz.max_flow(0, sink));
}

/// Whitespace-separated token reader over standard input.
struct Scanner {
    stdin: Stdin,
    buf: Vec<u8>,
}

impl Scanner {
    /// Creates a scanner over the process's standard input.
    fn new() -> Scanner {
        Scanner {
            stdin: io::stdin(),
            buf: Vec::with_capacity(256),
        }
    }

    /// Returns true for the bytes that separate tokens.
    /// `\r` and `\t` are included so CRLF and tab-separated input also parse.
    fn is_space(c: u8) -> bool {
        c == b' ' || c == b'\n' || c == b'\r' || c == b'\t'
    }

    /// Reads the next whitespace-separated token and parses it as `T`.
    ///
    /// A final token is accepted even when no newline follows it.
    /// Panics in these cases:
    /// - input ends before a token starts;
    /// - an I/O error occurs;
    /// - the token is not UTF-8;
    /// - the token does not parse as `T`.
    fn parse<T: FromStr>(&mut self) -> T
        where <T as FromStr>::Err: Debug
    {
        self.buf.clear();
        let mut it = self.stdin.lock().bytes();
        let mut c = it.next().unwrap().unwrap();
        while Scanner::is_space(c) {
            c = it.next().unwrap().unwrap();
        }
        loop {
            self.buf.push(c);
            c = match it.next() {
                Some(byte) => byte.unwrap(),
                // End of input terminates the last token instead of panicking.
                None => break,
            };
            if Scanner::is_space(c) {
                break;
            }
        }
        str::from_utf8(&self.buf).unwrap().parse::<T>().unwrap()
    }
}

Codeforces Round #411 (Div. 1) C. Ice cream coloring

問題

http://codeforces.com/contest/804/problem/C

n 頂点のツリー T と m 種類のアイスクリームがあります。T の各頂点はアイスクリームの集合をもっています。あるアイスクリーム i をもつ頂点同士は、連結なサブグラフになっています。

m 頂点の無向重みなしグラフ G を作ります。アイスクリーム u と v を同時に集合に含む頂点が T 上に存在する時、G 上で頂点 u と頂点 v の間に辺を張ります。

隣り合った頂点同士が同じ色にならないように G の頂点を塗り分ける時、必要な最小の色数と、G の各頂点の色を出力してください。

解法

T 上の各集合について、その集合に含まれるアイスクリーム同士には辺が張られるので、集合の中だけでは完全グラフになっています。つまり、G は図のようにクリーク同士がクリークによって連結されているグラフになるはずです。

f:id:kenkoooo:20170505134216j:plain

ここで重要なのが、「あるアイスクリーム i をもつ頂点同士は、連結なサブグラフになっている」という制約で、図でアイスクリーム 4 とアイスクリーム 6 を同時に集合に持つ頂点が T 上には存在すると仮定すると、T 内に閉路ができてしまうため、4 と 6 を同時にもつ頂点は存在しない事がわかります。

f:id:kenkoooo:20170505134314j:plain

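そのため、T を根から辺に沿って(DFS や BFS で)辿り、各頂点の集合内でまだ色の決まっていないアイスクリームに、その集合でまだ使われていない色を順に割り当てていけば、後から矛盾することはありません。親から順に辿るのが重要で、ある頂点を処理する時点で既に色の決まっているアイスクリームは必ず親の集合にも含まれるため、それらは互いに異なる色になっています。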

f:id:kenkoooo:20170505134335j:plain

コード

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.TreeSet;

public class C {

  /**
   * Hands out colors within one tree vertex. Colors taken in the current round are marked with
   * the round number, so starting a new round is O(1).
   */
  static class ColorManager {

    final int size;
    int cur = 0;
    final int[] used;
    int pos = 0;

    ColorManager(int size) {
      this.size = size;
      used = new int[size];
    }

    /** Starts a new round; every color becomes free again. */
    void fill() {
      cur++;
      pos = 0;
    }

    /** Marks color {@code idx} as taken in the current round. */
    void use(int idx) {
      used[idx] = cur;
    }

    /**
     * Returns the first color at or after the scan position that is free in this round, and
     * marks it taken. When every use() call of the round comes first, this is the smallest free
     * color.
     */
    int next() {
      while (used[pos] == cur) {
        pos++;
      }

      int ret = pos;
      used[pos] = cur;
      return ret;
    }
  }

  /**
   * Reads the tree T and the ice-cream sets, colors the ice creams greedily while walking T from
   * parent to child, and prints the number of colors followed by each ice cream's color
   * (1-indexed).
   */
  @SuppressWarnings("unchecked")
  private void solve(FastScanner in, PrintWriter out) {
    int N = in.nextInt();
    int M = in.nextInt();

    // sets[v]: ice-cream types (0-indexed) held by tree vertex v.
    ArrayList<Integer>[] sets = new ArrayList[N];
    for (int i = 0; i < N; i++) {
      sets[i] = new ArrayList<>();
    }
    for (int i = 0; i < N; i++) {
      int s = in.nextInt();
      for (int j = 0; j < s; j++) {
        sets[i].add(in.nextInt() - 1);
      }
    }

    // The N - 1 tree edges follow the sets in the input. The traversal below must follow them:
    // walking between vertices via shared ice creams can process a child before its parent and
    // give two ice creams of the same set the same color.
    ArrayList<Integer>[] tree = new ArrayList[N];
    for (int i = 0; i < N; i++) {
      tree[i] = new ArrayList<>();
    }
    for (int i = 0; i < N - 1; i++) {
      int u = in.nextInt() - 1;
      int v = in.nextInt() - 1;
      tree[u].add(v);
      tree[v].add(u);
    }

    int[] color = new int[M];
    Arrays.fill(color, -1);

    int maxSize = 0;
    for (ArrayList<Integer> set : sets) {
      maxSize = Math.max(set.size(), maxSize);
    }

    // BFS over T. Every vertex is processed after its parent, so the processed vertices always
    // form a connected subtree. Any ice cream of v that is already colored therefore also belongs
    // to v's parent (its holders are connected), and such ice creams already have pairwise
    // distinct colors. The remaining ones take the smallest free colors, so at most maxSize
    // colors are ever used.
    ColorManager manager = new ColorManager(maxSize);
    boolean[] visited = new boolean[N];
    ArrayDeque<Integer> queue = new ArrayDeque<>();
    for (int root = 0; root < N; root++) {
      if (visited[root]) {
        continue;
      }
      visited[root] = true;
      queue.add(root);
      while (!queue.isEmpty()) {
        int v = queue.poll();
        manager.fill();
        for (int s : sets[v]) {
          if (color[s] >= 0) {
            manager.use(color[s]);
          }
        }
        for (int s : sets[v]) {
          if (color[s] < 0) {
            color[s] = manager.next();
          }
        }
        for (int u : tree[v]) {
          if (!visited[u]) {
            visited[u] = true;
            queue.add(u);
          }
        }
      }
    }

    // Ice creams that appear in no set have no edges in G, so color 0 is valid for them.
    int m = 0;
    for (int i = 0; i < M; i++) {
      if (color[i] < 0) {
        color[i] = 0;
      }
      m = Math.max(m, color[i] + 1);
    }

    out.println(m);
    for (int i = 0; i < M; i++) {
      if (i > 0) {
        out.print(" ");
      }
      out.print(color[i] + 1);
    }
    out.println();
  }

  public static void main(String[] args) {
    PrintWriter out = new PrintWriter(System.out);
    new C().solve(new FastScanner(), out);
    out.close();
  }

  /** Minimal byte-buffered reader for whitespace-separated integers on standard input. */
  private static class FastScanner {

    private final InputStream in = System.in;
    private final byte[] buffer = new byte[1024];
    private int ptr = 0;
    private int bufferLength = 0;

    /** Ensures at least one unread byte is buffered; returns false at end of input. */
    private boolean hasNextByte() {
      if (ptr < bufferLength) {
        return true;
      }
      ptr = 0;
      try {
        bufferLength = in.read(buffer);
      } catch (IOException e) {
        // Surface read failures instead of silently re-reading stale buffer contents.
        throw new java.io.UncheckedIOException(e);
      }
      return bufferLength > 0;
    }

    /** Returns the next byte, or -1 at end of input. */
    private int readByte() {
      if (hasNextByte()) {
        return buffer[ptr++];
      } else {
        return -1;
      }
    }

    /** Visible ASCII characters, excluding the space character. */
    private static boolean isPrintableChar(int c) {
      return 33 <= c && c <= 126;
    }

    private void skipUnprintable() {
      while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
        ptr++;
      }
    }

    boolean hasNext() {
      skipUnprintable();
      return hasNextByte();
    }

    /** Parses the next token as a signed decimal long. */
    long nextLong() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      long n = 0;
      boolean minus = false;
      int b = readByte();
      if (b == '-') {
        minus = true;
        b = readByte();
      }
      if (b < '0' || '9' < b) {
        throw new NumberFormatException();
      }
      while (true) {
        if ('0' <= b && b <= '9') {
          n *= 10;
          n += b - '0';
        } else if (b == -1 || !isPrintableChar(b)) {
          return minus ? -n : n;
        } else {
          throw new NumberFormatException();
        }
        b = readByte();
      }
    }

    public int nextInt() {
      return (int) nextLong();
    }
  }
}

Google Cloud Platform の GPU インスタンスの Ubuntu 16.04 LTS に TensorFlow をインストール

自分用メモ。使っている GPU は NVIDIA Tesla K80 を 1 枚。

# Notes: install CUDA 8.0, cuDNN and GPU TensorFlow on an Ubuntu 16.04 GCP instance (Tesla K80).

# Install CUDA together with the driver and related packages via NVIDIA's apt repository package
sudo curl -O http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
sudo dpkg -i ./cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
sudo apt-get update
sudo apt-get install cuda -y
sudo reboot

# Verify the installation
nvidia-smi

# Fetch cuDNN (the bucket path is elided here; substitute your own copy)
# Apparently CUDA 8.0 requires cuDNN 5.1 or later
gsutil cp gs://.../cudnn-8.0-linux-x64-v5.1.tgz .
tar xvf cudnn-8.0-linux-x64-v5.1.tgz
sudo cp -a cuda/lib64/* /usr/local/cuda-8.0/lib64/
sudo cp -a cuda/include/* /usr/local/cuda-8.0/include/
sudo ldconfig

# Install pip and the Python 3 development headers
sudo apt-get install python3-pip python3-dev

# pip did not work properly with the ja_JP.UTF-8 locale, so switch to en_US.UTF-8
export LC_ALL="en_US.UTF-8"
export LC_CTYPE="en_US.UTF-8"
sudo dpkg-reconfigure locales

# Install the GPU build of TensorFlow
sudo pip3 install tensorflow-gpu

TopCoder SRM 710 Div 1 Easy: ReverseMancala

解法

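操作Bは操作Aの逆操作になっている(穴 s に対する操作Aで最後に石が入った穴に操作Bを行うと元に戻る)。そこで、start と target のそれぞれについて、操作Aを繰り返して全ての石を最後の穴に集める手順を求める。start についてはその手順をそのまま適用し、続いて target の手順を逆順にたどりながら各操作Aを対応する操作Bに置き換えて適用すれば、target を復元できる。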

コード

import java.util.ArrayList;
import java.util.Collections;

public class ReverseMancala {

  /**
   * Repeatedly applies operation A until every stone sits in the last hole.
   *
   * <p>Each step empties the highest-indexed non-empty hole among {@code 0..N-2} and sows its
   * stones one per hole into the following holes, wrapping around cyclically.
   *
   * @param holes stone count per hole; NOT modified (the simulation runs on a copy)
   * @return the applied moves in order, each as {@code {source, lastHoleReached}}, where
   *     {@code lastHoleReached} is the hole that received the final stone
   */
  private static ArrayList<int[]> typeA(int[] holes) {
    int n = holes.length;
    // Work on a copy so the caller's array (start/target) is left untouched.
    int[] array = holes.clone();
    ArrayList<int[]> list = new ArrayList<>();
    while (true) {
      int source = n - 2;
      while (source >= 0 && array[source] == 0) {
        source--;
      }
      if (source < 0) {
        // Everything is already in hole n-1.
        return list;
      }

      int num = array[source];
      int end = (source + num) % n;
      list.add(new int[] {source, end});
      array[source] = 0;
      for (int j = 1; j <= num; j++) {
        array[(source + j) % n]++;
      }
    }
  }

  /**
   * Finds a sequence of moves turning {@code start} into {@code target}.
   *
   * <p>Both configurations are reduced to "all stones in the last hole" with operation A. The
   * answer is the A-moves for {@code start}, followed by the A-moves for {@code target} undone in
   * reverse order as B-moves applied at each move's last reached hole.
   *
   * @param start initial stone counts; not modified
   * @param target desired stone counts (same length as {@code start}); not modified
   * @return move codes: {@code x} means operation A on hole {@code x}, {@code x + N} means
   *     operation B on hole {@code x}
   */
  public int[] findMoves(int[] start, int[] target) {
    int n = start.length;

    ArrayList<int[]> gatherStart = typeA(start);
    ArrayList<int[]> gatherTarget = typeA(target);
    Collections.reverse(gatherTarget);

    int[] ret = new int[gatherStart.size() + gatherTarget.size()];
    int k = 0;
    for (int[] move : gatherStart) {
      ret[k++] = move[0];
    }
    for (int[] move : gatherTarget) {
      ret[k++] = move[1] + n;
    }
    return ret;
  }
}

Codeforces Round #402 (Div. 2) E. Bitwise Formula

# 問題

http://codeforces.com/contest/779/problem/E

Mビットからなる変数がN個与えられます。各変数は、Mビットの定数か、既に定義された変数または '?'(ピーターが選ぶMビットの数)を2つ取って AND / OR / XOR のいずれかで演算した結果です(例: `? XOR ?`)。すべての変数の合計値を最小にする '?' と最大にする '?' をそれぞれ求めてください。合計値が同じになる候補が複数ある場合は、その中で最小の数を答えます。

# 解法

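ビット演算は桁をまたいで影響せず、各桁の重みはすべての変数で共通なので、'?' の各ビットは独立に決められる。各ビット位置について、'?' のそのビットを0にした場合と1にした場合それぞれで、その位置が1になる変数の個数を数える。最大化では個数が多い方を、最小化では少ない方を採用する。同数のときは、答えを最小にするため0を選ぶ。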

# コード

import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Scanner;
import java.util.TreeMap;

public class E {

  /** Bitwise operators allowed on the right-hand side of a computed definition. */
  enum OP {
    XOR,
    OR,
    AND
  }

  /**
   * Right-hand side {@code operand1 OP operand2} of a computed variable. An operand key is the
   * index of a variable, or -1 for Peter's number {@code ?}.
   */
  static class Operation {

    final int key1;
    final int key2;
    final OP op;

    Operation(int key1, int key2, OP op) {
      this.key1 = key1;
      this.key2 = key2;
      this.op = op;
    }
  }

  // operations[i]: formula of variable i, or null if variable i is a constant.
  Operation[] operations;
  // variables[i]: binary-string constant of variable i, or null if variable i is computed.
  String[] variables;
  // Memo of computed variables' bits for the current (value, pos); -1 means "not evaluated yet".
  int[] tmp;

  /**
   * Returns bit {@code pos} of variable {@code key} when the same bit of {@code ?} is {@code
   * value}. Bits of computed variables are memoized in {@code tmp}, so the caller must reset
   * {@code tmp} to -1 whenever {@code value} or {@code pos} changes.
   */
  int dfs(int key, int value, int pos) {
    if (tmp[key] >= 0) {
      return tmp[key];
    }
    if (variables[key] != null) {
      return variables[key].charAt(pos) - '0';
    }
    if (operations[key] == null) {
      throw new IllegalArgumentException();
    }

    Operation operation = operations[key];
    int key1 = operation.key1;
    int v1 = key1 >= 0 ? dfs(key1, value, pos) : value;
    int key2 = operation.key2;
    int v2 = key2 >= 0 ? dfs(key2, value, pos) : value;

    switch (operation.op) {
      case AND:
        tmp[key] = v1 & v2;
        break;
      case OR:
        tmp[key] = v1 | v2;
        break;
      case XOR:
        tmp[key] = v1 ^ v2;
        break;
    }
    return tmp[key];
  }

  /**
   * Counts the variables whose bit {@code pos} is 1 when bit {@code pos} of {@code ?} is {@code
   * value}.
   */
  private int countOnes(int value, int pos) {
    Arrays.fill(tmp, -1);
    int ones = 0;
    for (int i = 0; i < tmp.length; i++) {
      ones += dfs(i, value, pos);
    }
    return ones;
  }

  /**
   * Reads the definitions and prints the value of {@code ?} minimizing the sum of all variables,
   * then the value maximizing it, each as an M-bit binary string. Ties are broken towards 0 so
   * the smallest optimal number is printed.
   */
  private void solve(Scanner in, PrintWriter out) {
    String[] nm = in.nextLine().split(" ");
    int N = Integer.parseInt(nm[0]);
    int M = Integer.parseInt(nm[1]);

    // Map each variable name to its index; "?" (Peter's number) gets the sentinel -1.
    TreeMap<String, Integer> shrink = new TreeMap<>();
    shrink.put("?", -1);
    operations = new Operation[N];
    variables = new String[N];
    tmp = new int[N];

    // First pass: register every name so operands can be resolved in the second pass.
    String[][] formulas = new String[N][];
    for (int i = 0; i < N; i++) {
      String[] formula = in.nextLine().split(" ");
      shrink.put(formula[0], i);
      formulas[i] = formula;
    }

    for (String[] formula : formulas) {
      int i = shrink.get(formula[0]);
      if (formula.length == 3) {
        // "name := <binary constant>"
        variables[i] = formula[2];
      } else {
        // "name := operand1 OP operand2"
        int key1 = shrink.get(formula[2]);
        int key2 = shrink.get(formula[4]);
        OP op = OP.valueOf(formula[3]);
        operations[i] = new Operation(key1, key2, op);
      }
    }

    // Bitwise operations never carry between positions, and position pos has the same weight in
    // every variable, so each bit of ? can be chosen independently. Evaluate both choices once
    // per bit and reuse the counts for both the minimum and the maximum.
    char[] min = new char[M];
    char[] max = new char[M];
    for (int pos = 0; pos < M; pos++) {
      int ones0 = countOnes(0, pos);
      int ones1 = countOnes(1, pos);
      // On a tie prefer 0 so the printed number is the smallest optimal one.
      min[pos] = ones1 < ones0 ? '1' : '0';
      max[pos] = ones1 > ones0 ? '1' : '0';
    }

    out.println(new String(min));
    out.println(new String(max));
  }

  public static void main(String[] args) {
    PrintWriter out = new PrintWriter(System.out);
    new E().solve(new Scanner(System.in), out);
    out.close();
  }
}