Rust で競技プログラミングをする時に使う高速な標準入力 Scanner

struct Scanner {
    ptr: usize,
    length: usize,
    buf: Vec<u8>,
    small_cache: Vec<u8>,
}

#[allow(dead_code)]
impl Scanner {
    fn new() -> Scanner {
        Scanner {
            ptr: 0,
            length: 0,
            buf: vec![0; 1024],
            small_cache: vec![0; 1024],
        }
    }

    fn load(&mut self) {
        use std::io::Read;
        let mut s = std::io::stdin();
        self.length = s.read(&mut self.buf).unwrap();
    }

    fn byte(&mut self) -> u8 {
        if self.ptr >= self.length {
            self.ptr = 0;
            self.load();
            if self.length == 0 {
                self.buf[0] = b'\n';
                self.length = 1;
            }
        }

        self.ptr += 1;
        return self.buf[self.ptr - 1];
    }

    fn is_space(b: u8) -> bool {
        b == b'\n' || b == b'\r' || b == b'\t' || b == b' '
    }

    fn read_vec<T>(&mut self, n: usize) -> Vec<T>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        (0..n).map(|_| self.read()).collect()
    }

    fn usize_read(&mut self) -> usize {
        self.read()
    }

    fn read<T>(&mut self) -> T
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        let mut b = self.byte();
        while Scanner::is_space(b) {
            b = self.byte();
        }

        for pos in 0..self.small_cache.len() {
            self.small_cache[pos] = b;
            b = self.byte();
            if Scanner::is_space(b) {
                return String::from_utf8_lossy(&self.small_cache[0..(pos + 1)])
                    .parse()
                    .unwrap();
            }
        }

        let mut v = self.small_cache.clone();
        while !Scanner::is_space(b) {
            v.push(b);
            b = self.byte();
        }
        return String::from_utf8_lossy(&v).parse().unwrap();
    }
}

同じディレクトリに複数の CMakeLists.txt を置きたい時

github.com

.
├── CMakeLists.txt      # Root.cmake を include するだけ
├── main.cpp            # main()
├── Mods                #
│   ├── Mod1.cmake      # mod1.cpp => mod1 にする
│   ├── mod1.cpp        # 
│   ├── mod1.h          # 
│   ├── Mod2.cmake      # mod2.cpp => mod2 にする。実は Mod1 に依存している。
│   ├── mod2.cpp        #
│   └── mod2.h          #
└── Root.cmake          # mod1 と mod2 をビルドして ./a.out を固める

AGC 003 D - Anticube

解説

各 s_i を素因数分解する(方法は後で考える)。

まず、 s_i に 3 つ同じ素因数が含まれるとき、それらを取り除いてしまっても構わない。例えば  s_i = 3^4 5^2 のとき、 3^3 を取り除いて  s_i = 3 \dot 5^2 と考えて良い。このように各 s_i を素数の 3 乗で割れるだけ割っておく。すると、各 s_i は  s_i = (0 個以上の相異なる素数の積)^1 (0 個以上の相異なる素数の積)^2 と表せる。ここで前者を p_i 後者を q_i として  s_i = p_i^1 q_i^2とする。この s_i との積が立法数になる s_j は  s_j = p_i^2q_i^1 と表せるはずである。

なので、問題は  s_i = p_i^1 q_i^2 s_j = p_i^2q_i^1 の数をそれぞれ求め、多い方を採用するようにする。また、  s_i = 1^1 1^2 と表せるものは 1 つだけ採用できる。

最初に戻って s_i を素因数分解する方法を考える。 s_i \geq p^3 となるような全ての素数 p について調べたあとの s_i は素数の 3 乗を含まないので、以下の 3 つのいずれかである。

素数の2乗のとき、  s_i = 1^1 q_i^2 と表せる。そうでない時は  s_i = q_i^1 1^2 と表せる。このレベルまで分かれば十分なので、完全に素因数分解する必要はなく、実行時間制限に間に合う。

コード

use std::cmp;
use std::collections::{BTreeMap, BTreeSet};

fn get_primes(n: usize) -> Vec<u64> {
    let mut is_prime = vec![true; n + 1];
    is_prime[0] = false;
    is_prime[1] = false;
    for p in 2..(n + 1) {
        if !is_prime[p] {
            continue;
        }
        let mut cur = 2 * p;
        while cur <= n {
            is_prime[cur] = false;
            cur += p;
        }
    }

    is_prime
        .iter()
        .enumerate()
        .filter(|&(_, &is_prime)| is_prime)
        .map(|(p, _)| p as u64)
        .collect()
}

fn main() {
    let primes = get_primes(100000);
    let prime2 = primes
        .iter()
        .map(|&p| (p * p, p))
        .collect::<BTreeMap<u64, u64>>();
    let mut sc = Scanner::new();
    let n = sc.usize_read();

    let mut count = BTreeMap::new();

    for _ in 0..n {
        let mut s: u64 = sc.read();
        let mut s1 = 1;
        let mut s2 = 1;
        for &prime in &primes {
            let p2 = prime * prime;
            let p3 = p2 * prime;
            if p3 > s {
                break;
            }
            while s % p3 == 0 {
                s /= p3;
            }
            if s % p2 == 0 {
                s /= p2;
                s2 *= prime;
            } else if s % prime == 0 {
                s /= prime;
                s1 *= prime;
            }
        }

        match prime2.get(&s) {
            Some(&p) => s2 *= p,
            None => s1 *= s,
        }

        let &cur = count.get(&(s1, s2)).unwrap_or(&0);
        count.insert((s1, s2), cur + 1);
    }

    let mut ans = 0;
    let mut set = BTreeSet::new();
    set.insert((1, 1));
    for &(s1, s2) in count.keys() {
        if set.contains(&(s1, s2)) {
            continue;
        }
        let &count1 = count.get(&(s1, s2)).unwrap_or(&0);
        let &count2 = count.get(&(s2, s1)).unwrap_or(&0);
        ans += cmp::max(count1, count2);
        set.insert((s1, s2));
        set.insert((s2, s1));
    }

    if count.contains_key(&(1, 1)) {
        ans += 1;
    }
    println!("{}", ans);
}

struct Scanner {
    ptr: usize,
    length: usize,
    buf: Vec<u8>,
    small_cache: Vec<u8>,
}

#[allow(dead_code)]
impl Scanner {
    fn new() -> Scanner {
        Scanner {
            ptr: 0,
            length: 0,
            buf: vec![0; 1024],
            small_cache: vec![0; 1024],
        }
    }

    fn load(&mut self) {
        use std::io::Read;
        let mut s = std::io::stdin();
        self.length = s.read(&mut self.buf).unwrap();
    }

    fn byte(&mut self) -> u8 {
        if self.ptr >= self.length {
            self.ptr = 0;
            self.load();
            if self.length == 0 {
                self.buf[0] = b'\n';
                self.length = 1;
            }
        }

        self.ptr += 1;
        return self.buf[self.ptr - 1];
    }

    fn is_space(b: u8) -> bool {
        b == b'\n' || b == b'\r' || b == b'\t' || b == b' '
    }

    fn read_vec<T>(&mut self, n: usize) -> Vec<T>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        (0..n).map(|_| self.read()).collect()
    }

    fn usize_read(&mut self) -> usize {
        self.read()
    }

    fn read<T>(&mut self) -> T
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        let mut b = self.byte();
        while Scanner::is_space(b) {
            b = self.byte();
        }

        for pos in 0..self.small_cache.len() {
            self.small_cache[pos] = b;
            b = self.byte();
            if Scanner::is_space(b) {
                return String::from_utf8_lossy(&self.small_cache[0..(pos + 1)])
                    .parse()
                    .unwrap();
            }
        }

        let mut v = self.small_cache.clone();
        while !Scanner::is_space(b) {
            v.push(b);
            b = self.byte();
        }
        return String::from_utf8_lossy(&v).parse().unwrap();
    }
}

ARC 072 D - Alice&Brown

解法

実験すると |x-y|<=1 の時に grundy 数が 0 になりそうな気がするので、結論ありきで帰納法で証明する。

コード

use std::collections::{BTreeMap, BTreeSet};
fn main() {
    let mut sc = Scanner::new();
    let x: u64 = sc.read();
    let y: u64 = sc.read();
    let (x, y) = if x > y { (y, x) } else { (x, y) };
    println!("{}", if y - x <= 1 { "Brown" } else { "Alice" });
    // exp();
}

fn exp() {
    let mut map = BTreeMap::new();
    map.insert((0, 0), 0);
    map.insert((0, 1), 0);
    map.insert((1, 1), 0);

    for i in 0..1000 {
        for j in i..1000 {
            grundy(i, j, &mut map);
        }
    }
}

fn grundy(x: usize, y: usize, map: &mut BTreeMap<(usize, usize), usize>) -> usize {
    let (x, y) = if x > y { (y, x) } else { (x, y) };

    if map.contains_key(&(x, y)) {
        return map[&(x, y)];
    }

    let mut grundy_nums = BTreeSet::new();
    for x_to_y in 1..(x / 2 + 1) {
        let next_x = x - x_to_y * 2;
        let next_y = y + x_to_y;
        grundy_nums.insert(grundy(next_x, next_y, map));
    }
    for y_to_x in 1..(y / 2 + 1) {
        let next_x = x + y_to_x;
        let next_y = y - y_to_x * 2;
        grundy_nums.insert(grundy(next_x, next_y, map));
    }

    let mut g = 0;
    loop {
        if !grundy_nums.contains(&g) {
            map.insert((x, y), g);
            if g == 0 {
                println!("{} {}", x, y);
            }
            return g;
        }
        g += 1;
    }
}

struct Scanner {
    ptr: usize,
    length: usize,
    buf: Vec<u8>,
    small_cache: Vec<u8>,
}

impl Scanner {
    fn new() -> Scanner {
        Scanner {
            ptr: 0,
            length: 0,
            buf: vec![0; 1024],
            small_cache: vec![0; 1024],
        }
    }

    fn load(&mut self) {
        use std::io::Read;
        let mut s = std::io::stdin();
        self.length = s.read(&mut self.buf).unwrap();
    }

    fn byte(&mut self) -> u8 {
        if self.ptr >= self.length {
            self.ptr = 0;
            self.load();
            if self.length == 0 {
                self.buf[0] = b'\n';
                self.length = 1;
            }
        }

        self.ptr += 1;
        return self.buf[self.ptr - 1];
    }

    fn is_space(b: u8) -> bool {
        b == b'\n' || b == b'\r' || b == b'\t' || b == b' '
    }

    fn read<T>(&mut self) -> T
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
    {
        let mut b = self.byte();
        while Scanner::is_space(b) {
            b = self.byte();
        }

        for pos in 0..self.small_cache.len() {
            self.small_cache[pos] = b;
            b = self.byte();
            if Scanner::is_space(b) {
                return String::from_utf8_lossy(&self.small_cache[0..(pos + 1)])
                    .parse()
                    .unwrap();
            }
        }

        let mut v = self.small_cache.clone();
        while !Scanner::is_space(b) {
            v.push(b);
            b = self.byte();
        }
        return String::from_utf8_lossy(&v).parse().unwrap();
    }
}

CODE FESTIVAL 2016 Grand Final C - Cheating Nim

解法

Grundy 数の原理から、「各  a_i もしくは  a_i-1 の XOR の値を 0 にすることが出来るか?出来るなら  a_i-1 を使った回数は何回か?」という問題に言い換えることができます。

ここで、ある i について  a_i-1 を使うと XOR の値 x は  x \oplus a_i \oplus (a_i-1) になります。すなわち、現在の XOR の値と  a_i \oplus (a_i -1) の XOR をとった値になります。ところで  a_i \oplus (a_i-1) は必ず  2^k-1 の形になります。証明もできますが、実験でもわかります。なので、貪欲に上のビットから潰していくことで解けます。

コード

fn main() {
    let mut sc = Scanner::new();
    let n: usize = sc.read();
    let a: Vec<u32> = (0..n).map(|_| sc.read()).collect();


    let mut xor_sum = 0;
    for &a in &a {
        xor_sum ^= a;
    }

    let mut ans = 0;
    for bit in (0..30).rev() {
        if ((1 << bit) & xor_sum) == 0 { continue; }
        let x = (1 << (bit + 1)) - 1;
        for &a in &a {
            let y = a ^ (a - 1);
            if y == x {
                xor_sum ^= y;
                ans += 1;
                break;
            }
        }
    }

    if xor_sum != 0 {
        println!("-1");
    } else {
        println!("{}", ans);
    }
}

struct Scanner {
    ptr: usize,
    length: usize,
    buf: Vec<u8>,
    small_cache: Vec<u8>,
}

impl Scanner {
    fn new() -> Scanner {
        Scanner { ptr: 0, length: 0, buf: vec![0; 1024], small_cache: vec![0; 1024] }
    }

    fn load(&mut self) {
        use std::io::Read;
        let mut s = std::io::stdin();
        self.length = s.read(&mut self.buf).unwrap();
    }

    fn byte(&mut self) -> u8 {
        if self.ptr >= self.length {
            self.ptr = 0;
            self.load();
            if self.length == 0 {
                self.buf[0] = b'\n';
                self.length = 1;
            }
        }

        self.ptr += 1;
        return self.buf[self.ptr - 1];
    }

    fn is_space(b: u8) -> bool { b == b'\n' || b == b'\r' || b == b'\t' || b == b' ' }

    fn read<T>(&mut self) -> T where T: std::str::FromStr, T::Err: std::fmt::Debug, {
        let mut b = self.byte();
        while Scanner::is_space(b) {
            b = self.byte();
        }

        for pos in 0..self.small_cache.len() {
            self.small_cache[pos] = b;
            b = self.byte();
            if Scanner::is_space(b) {
                return String::from_utf8_lossy(&self.small_cache[0..(pos + 1)]).parse().unwrap();
            }
        }

        let mut v = self.small_cache.clone();
        while !Scanner::is_space(b) {
            v.push(b);
            b = self.byte();
        }
        return String::from_utf8_lossy(&v).parse().unwrap();
    }
}

「みんなのプロコン2018」 D - XOR XorY

問題

D - XOR XorY

解法

解説を読んでも分からなかったので 「みんなのプロコン 2018」: D - XOR XorY · うさぎ小屋 を参考にした。

 a_i \oplus a_j \oplus X = A_{i,j} または  a_i \oplus a_j \oplus Y = A_{i,j} ということは  a_i \oplus a_j \oplus A_{i,j} \in \{X, Y\} なので  B_{i,j}=A_{i,j} \oplus X, Z= Y \oplus X とすると  a_i \oplus a_j \oplus B_{i,j} \in \{0,Z\} となる。以後、  a_i \oplus a_j \oplus B_{i,j} \in \{0,Z\} を満たす  \{a_I\} を数え上げることにする。

  • i=j でも条件を満たすため  a_i \oplus a_i \oplus B_{i,i} \in \{0,Z\} より  B_{i,i} \in \{0,Z\}
  • i=j で条件を満たすとき j=i でも条件を満たすので  a_i \oplus a_j \oplus B_{i,j} \in \{0,Z\} かつ  a_j \oplus a_i \oplus B_{j,i} \in \{0,Z\} より  B_{i,j} \oplus B_{j,i} \in \{0,Z\}
  •  a_0 を選んだとき  a_0 \oplus a_j \oplus B_{0,j} \in \{0,Z\} より  a_j \in \{ a_0 \oplus B_{0,j} ,a_0 \oplus B_{0,j} \oplus Z \} だから各 j について  a_0 \oplus B_{0,j} または  a_0 \oplus B_{0,j} \oplus Z となるものを  \{x\} から選んで [ex: a_j] とする。

コード

use std::cmp;

const MAX_A: usize = 2048;
const MOD: usize = 1000000007;

fn main() {
    let mut sc = Scanner::new();
    let n: usize = sc.read();
    let k: usize = sc.read();
    let x: usize = sc.read();
    let y: usize = sc.read();
    let c: Vec<usize> = (0..n).map(|_| sc.read()).collect();

    let mut a = vec![vec![0; k]; k];
    for i in 0..k {
        for j in 0..k {
            a[i][j] = sc.read::<usize>() ^ x;
        }
    }

    println!("{}", solve(k, x ^ y, &c, &a));
}

fn solve(k: usize, z: usize, x: &Vec<usize>, b: &Vec<Vec<usize>>) -> usize {
    let comb = Combination::new(MAX_A * 2, MOD);

    // check
    for i in 0..k {
        if b[i][i] != 0 && b[i][i] != z { return 0; }
    }
    for i in 0..k {
        for j in 0..k {
            let delta = b[i][j] ^ b[j][i];
            if delta != 0 && delta != z { return 0; }
        }
    }

    let mut count = vec![0; MAX_A + 1];
    let mut xor_counted = vec![0; MAX_A + 1];
    for xi in x {
        let xi = *xi;
        count[cmp::min(xi, xi ^ z)] += 1;
        if (xi ^ z) < xi { xor_counted[xi ^ z] += 1; }
    }

    let mut answer = 0;
    for a0 in 0..(MAX_A + 1) {
        if count[a0] == 0 { continue; }
        let mut used = vec![0; MAX_A + 1];
        used[a0] += 1;

        let mut can_construct = true;
        for j in 1..k {
            let aj = cmp::min(b[0][j] ^ a0, b[0][j] ^ a0 ^ z);

            if used[aj] == count[aj] {
                can_construct = false;
                break;
            }
            used[aj] += 1;
        }

        if !can_construct { continue; }

        let mut ans_for_a0 = 1;
        for a in 0..(MAX_A + 1) {
            if used[a] == 0 { continue; }

            let needed = used[a];
            let max_not_xor = cmp::min(count[a] - xor_counted[a], needed);
            let min_not_xor = if needed < xor_counted[a] { 0 } else { needed - xor_counted[a] };

            let mut combination_sum = 0;
            for choose in min_not_xor..(max_not_xor + 1) {
                combination_sum += comb.get(needed, choose);
                if combination_sum > MOD { combination_sum -= MOD; }
            }
            ans_for_a0 *= combination_sum;
            ans_for_a0 %= MOD;
        }
        answer += ans_for_a0;
        if answer > MOD { answer -= MOD; }
    }

    return answer;
}

pub struct Combination {
    fact: Vec<usize>,
    inv_fact: Vec<usize>,
    modulo: usize,
}

impl Combination {
    pub fn new(max: usize, modulo: usize) -> Combination {
        let mut inv = vec![0; max + 1];
        let mut fact = vec![0; max + 1];
        let mut inv_fact = vec![0; max + 1];
        inv[1] = 1;
        for i in 2..(max + 1) {
            inv[i] = inv[modulo % i] * (modulo - modulo / i) % modulo;
        }
        fact[0] = 1;
        inv_fact[0] = 1;
        for i in 0..max { fact[i + 1] = fact[i] * (i + 1) % modulo; }
        for i in 0..max {
            inv_fact[i + 1] = inv_fact[i] * inv[i + 1] % modulo;
        }
        Combination { fact: fact, inv_fact: inv_fact, modulo: modulo }
    }

    pub fn get(&self, x: usize, y: usize) -> usize {
        assert!(x >= y);
        self.fact[x] * self.inv_fact[y] % self.modulo * self.inv_fact[x - y] % self.modulo
    }
}

struct Scanner {
    ptr: usize,
    length: usize,
    buf: Vec<u8>,
    small_cache: Vec<u8>,
}

impl Scanner {
    fn new() -> Scanner {
        Scanner { ptr: 0, length: 0, buf: vec![0; 1024], small_cache: vec![0; 1024] }
    }

    fn load(&mut self) {
        use std::io::Read;
        let mut s = std::io::stdin();
        self.length = s.read(&mut self.buf).unwrap();
    }

    fn byte(&mut self) -> u8 {
        if self.ptr >= self.length {
            self.ptr = 0;
            self.load();
            if self.length == 0 {
                self.buf[0] = b'\n';
                self.length = 1;
            }
        }

        self.ptr += 1;
        return self.buf[self.ptr - 1];
    }

    fn is_space(b: u8) -> bool { b == b'\n' || b == b'\r' || b == b'\t' || b == b' ' }

    fn read<T>(&mut self) -> T where T: std::str::FromStr, T::Err: std::fmt::Debug, {
        let mut b = self.byte();
        while Scanner::is_space(b) {
            b = self.byte();
        }

        for pos in 0..self.small_cache.len() {
            self.small_cache[pos] = b;
            b = self.byte();
            if Scanner::is_space(b) {
                return String::from_utf8_lossy(&self.small_cache[0..(pos + 1)]).parse().unwrap();
            }
        }

        let mut v = self.small_cache.clone();
        while !Scanner::is_space(b) {
            v.push(b);
            b = self.byte();
        }
        return String::from_utf8_lossy(&v).parse().unwrap();
    }
}

SoundHound Inc. Programming Contest 2018 (春) D - 建物

問題

D - 建物

解法

(i, j) から (i, k) に移動して (i+1, k) に移動する経路 (i, j) => (i+1, k) を考える。このとき、 (i, j-1) => (i+1, k) よりも多くの報酬が得られることに留意する。次に (i+1, k+1) に移動する経路を考える。このとき (i, j-1) => (i+1, k+1) よりも (i, j) => (i+1, k+1) の方がより多くの報酬が得られる。このように、下の階から上の階へ上がる経路はしゃくとり法によって求まる。

コード

use std::cmp;
use std::i64::MIN;

fn main() {
    let (h, w) = {
        let v = read_values::<usize>();
        (v[0], v[1])
    };

    let p = {
        let mut p = vec![vec![0; w]; h + 1];
        for i in 0..h {
            let v = read_values::<i64>();
            for j in 0..w {
                p[i][j] = v[j];
            }
        }
        p
    };

    let f = {
        let mut f = vec![vec![0; w]; h + 1];
        for i in 0..h {
            let v = read_values::<i64>();
            for j in 0..w {
                f[i][j] = v[j];
            }
        }
        f
    };

    let mut gain = vec![vec![MIN; w]; h + 1];
    gain[0][0] = p[0][0];

    for i in 0..h {
        let mut left_turn = vec![0; w];
        for j in 1..w {
            left_turn[j] = cmp::max(left_turn[j - 1] + p[i][j - 1] - (f[i][j - 1] + f[i][j]), 0);
        }
        let mut right_turn = vec![0; w];
        for j in (0..(w - 1)).rev() {
            right_turn[j] = cmp::max(right_turn[j + 1] + p[i][j + 1] - (f[i][j + 1] + f[i][j]), 0);
        }

        let mut sum = vec![0; w + 1];
        for j in 0..w {
            sum[j + 1] = sum[j] + p[i][j] - f[i][j];
        }

        if i == 0 {
            for j in 0..w {
                gain[i + 1][j] = p[i + 1][j] - f[i + 1][j] + sum[j + 1] + right_turn[j];
            }
        } else {
            let mut left_max = 0;
            for j in 0..w {
                let segment_sum = sum[j + 1] - sum[left_max + 1];
                let enter_gain = p[i + 1][j] - f[i + 1][j];

                let old_gain = gain[i][left_max] + segment_sum + left_turn[left_max];
                let new_gain = gain[i][j] + left_turn[j];
                if old_gain < new_gain {
                    left_max = j;
                }

                gain[i + 1][j] = cmp::max(old_gain, new_gain) + right_turn[j] + enter_gain;
            }

            let mut right_max = w - 1;
            for j in (0..w).rev() {
                let segment_sum = sum[right_max] - sum[j];
                let enter_gain = p[i + 1][j] - f[i + 1][j];

                let old_gain = gain[i][right_max] + segment_sum + right_turn[right_max];
                let new_gain = gain[i][j] + right_turn[j];
                if old_gain < new_gain {
                    right_max = j;
                }

                gain[i + 1][j] = cmp::max(
                    cmp::max(old_gain, new_gain) + left_turn[j] + enter_gain,
                    gain[i + 1][j],
                );
            }
        }
    }
    for i in 0..w {
        println!("{}", gain[h][i]);
    }
}


fn read_line() -> String {
    let stdin = std::io::stdin();
    let mut buf = String::new();
    stdin.read_line(&mut buf).unwrap();
    buf
}

fn read_values<T>() -> Vec<T>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Debug,
{
    read_line()
        .split(' ')
        .map(|a| a.trim().parse().unwrap())
        .collect()
}