Educational Codeforces Round 121 (Rated for Div. 2) E. Black and White Tree

問題

https://codeforces.com/contest/1626/problem/E

木があります。いくつかの頂点が黒く塗られています。黒い頂点は必ず2つ以上あります。ある頂点に駒があるとき、黒い頂点を1つ選択することで、駒を選択した頂点の方向に辺1本分だけ動かすことができます。ただし、同じ黒い頂点を2回連続で選択することはできません。各頂点 i について、最初に頂点 i に置かれた駒を、黒い頂点を適切に選択していくことで、いずれかの黒い頂点まで移動させることができるかを求めてください。

解法

駒が置かれた頂点から見て同じ方向に2つ黒い頂点があるとき、これらを交互に選択することでその方向に連続で移動することができる。このことから、黒い頂点に移動できない頂点たちは、黒い頂点に囲まれた領域にあり、かつ、その領域の外に黒い頂点はない。

f:id:kenkoooo:20220121023933p:plain

コード

use std::collections::{BTreeSet, VecDeque};
use std::time::Instant;

// Reads a tree with a colour value per vertex from stdin and prints one 0/1 value per
// vertex, space-separated on a single line.
//
// NOTE(review): per the problem statement above, colour 1 is presumably "black" and an
// output of 1 presumably means "a chip starting here can be moved onto a black vertex";
// the code itself only shows the comparisons against 0 and 1.
fn main() {
    let (r, w) = (std::io::stdin(), std::io::stdout());
    let mut sc = IO::new(r.lock(), w.lock());

    // Input: vertex count, n colour values, then n - 1 edges with 1-based endpoints
    // (converted to 0-based by `usize0`). Each edge is stored in both directions.
    let n: usize = sc.read();
    let color: Vec<i64> = sc.vec(n);
    let mut graph = vec![vec![]; n];
    for _ in 1..n {
        let u = sc.usize0();
        let v = sc.usize0();
        graph[u].push(v);
        graph[v].push(u);
    }

    // `a` is the lowest-indexed vertex with colour 1; `b` is the colour-1 vertex other
    // than `a` that is closest to `a`. Both `unwrap`s panic if fewer than two such
    // vertices exist.
    let a = (0..n).find(|&i| color[i] == 1).unwrap();
    let dist_a = bfs(&graph, a);
    let b = (0..n)
        .filter(|&i| i != a && color[i] == 1)
        .min_by_key(|&i| dist_a[i])
        .unwrap();
    let dist_b = bfs(&graph, b);
    // Distance between `a` and `b`.
    let ab = dist_b[a];

    // Look for a colour-0 vertex `x` lying on the a-b path (its distances to `a` and `b`
    // add up to `ab`).
    let ans = if let Some(x) = (0..n).find(|&i| dist_b[i] + dist_a[i] == ab && color[i] == 0) {
        // BFS from `x` that assigns a distance to colour-1 vertices but never expands
        // through them, so it covers the region around `x` bounded by colour-1 vertices.
        // `dist[i] == n` marks "not reached". Note that `n` is shadowed here by
        // `graph.len()`, which is the same value.
        let mut q = VecDeque::new();
        q.push_back(x);
        let n = graph.len();
        let mut dist = vec![n; n];
        dist[x] = 0;
        while let Some(v) = q.pop_front() {
            if color[v] == 1 {
                continue;
            }
            for &next in graph[v].iter() {
                if dist[next] > dist[v] + 1 {
                    dist[next] = dist[v] + 1;
                    q.push_back(next);
                }
            }
        }

        // If some colour-1 vertex was not reached by the restricted BFS, every vertex
        // gets answer 1.
        let reached = (0..n).filter(|&i| color[i] == 1 && dist[i] < n).count();
        let total = (0..n).filter(|&i| color[i] == 1).count();
        if reached != total {
            vec![1; n]
        } else {
            // Start from "all 1", then set every reached vertex to 0.
            let mut ans = vec![1; n];
            for i in 0..n {
                if dist[i] < n {
                    ans[i] = 0;
                }
            }

            // Colour-1 vertices and vertices adjacent to one are set back to 1.
            // (The closure parameter `i` shadows the loop variable inside `any`.)
            for i in 0..n {
                if graph[i].iter().any(|&i| color[i] == 1) {
                    ans[i] = 1;
                }
                if color[i] == 1 {
                    ans[i] = 1;
                }
            }

            // `s[i]`: neighbours `j` of `i` such that both `i` and `j` were reached,
            // i.e. the adjacency of the subgraph induced by the reached vertices.
            let mut s = vec![BTreeSet::new(); n];
            for i in 0..n {
                for &j in graph[i].iter() {
                    if dist[i] < n && dist[j] < n {
                        s[i].insert(j);
                    }
                }
            }

            // Repeatedly peel colour-0 vertices that have exactly one remaining
            // neighbour in `s`: detach them from that neighbour, enqueue the neighbour
            // if it became such a leaf itself, and leave the peeled vertex with an
            // empty set.
            let mut q = VecDeque::new();
            for i in 0..n {
                if s[i].len() == 1 && color[i] == 0 {
                    q.push_back(i);
                }
            }
            while let Some(i) = q.pop_front() {
                let next = s[i].clone();
                for next in next {
                    // The removal is the side effect; the assert additionally checks
                    // that `i` was still recorded as a neighbour of `next`.
                    assert!(s[next].remove(&i));
                    if s[next].len() == 1 && color[next] == 0 {
                        q.push_back(next);
                    }
                }
                s[i].clear();
            }

            // For every vertex `v` with at least one colour-1 neighbour: each neighbour
            // whose set in `s` is empty has its whole side of the tree (away from `v`)
            // marked 1 by `dfs`.
            for v in 0..n {
                if graph[v].iter().all(|&i| color[i] == 0) {
                    continue;
                }

                for &next in graph[v].iter() {
                    if s[next].is_empty() {
                        dfs(next, v, &graph, &mut ans);
                    }
                }
            }

            // If a vertex that still has three or more neighbours in `s` after peeling
            // is adjacent to a colour-1 vertex, every vertex gets answer 1; otherwise
            // the per-vertex answers computed above are used.
            if (0..n)
                .filter(|&i| s[i].len() >= 3)
                .any(|v| graph[v].iter().any(|&next| color[next] == 1))
            {
                vec![1; n]
            } else {
                ans
            }
        }
    } else {
        // No colour-0 vertex lies on the a-b path; the code expects `a` and `b` to be
        // adjacent in that case and answers 1 everywhere.
        assert_eq!(ab, 1);
        vec![1; n]
    };

    // Print the answers separated by single spaces, followed by a newline.
    for (i, ans) in ans.into_iter().enumerate() {
        if i > 0 {
            sc.write(' ');
        }
        sc.write(ans);
    }
    sc.write('\n');
}
/// Sets `ans[u] = 1` for `v` and for every vertex reachable from `v` without stepping
/// back to the vertex it was entered from (initially `p`).
///
/// On a tree this marks exactly the component containing `v` after removing the edge
/// `p - v`. The traversal uses an explicit stack instead of recursion, so a very deep
/// tree (for example a long path) cannot overflow the call stack.
///
/// `graph` must be a tree (or forest) given as undirected adjacency lists; on a graph
/// with cycles the traversal does not terminate.
///
/// # Panics
/// Panics if a visited vertex index is out of bounds for `graph` or `ans`.
fn dfs(v: usize, p: usize, graph: &Vec<Vec<usize>>, ans: &mut Vec<i64>) {
    // Each entry is (vertex to visit, vertex it was reached from).
    let mut stack = vec![(v, p)];
    while let Some((cur, parent)) = stack.pop() {
        ans[cur] = 1;
        stack.extend(
            graph[cur]
                .iter()
                .copied()
                .filter(|&next| next != parent)
                .map(|next| (next, cur)),
        );
    }
}

/// Breadth-first search over an unweighted adjacency-list graph.
///
/// Returns, for every vertex, the number of edges on a shortest path from `from`.
/// Vertices that cannot be reached keep the sentinel value `graph.len()`.
///
/// # Panics
/// Panics if `from` is not a valid vertex index.
fn bfs(graph: &Vec<Vec<usize>>, from: usize) -> Vec<usize> {
    let unreached = graph.len();
    let mut dist = vec![unreached; unreached];
    let mut frontier = VecDeque::new();

    dist[from] = 0;
    frontier.push_back(from);

    while let Some(cur) = frontier.pop_front() {
        // Every neighbour is offered the same candidate distance.
        let candidate = dist[cur] + 1;
        for &neighbour in &graph[cur] {
            if candidate < dist[neighbour] {
                dist[neighbour] = candidate;
                frontier.push_back(neighbour);
            }
        }
    }

    dist
}

/// Generic rerooting DP over a tree.
///
/// After [`ReRooting::build`], `ans[v]` holds the DP value of the whole tree when it is
/// rooted at `v`, for every vertex `v` in the component of vertex 0.
///
/// The DP is described by three callbacks:
/// * `identity()` - neutral element for `merge` (the value of "no children");
/// * `merge(a, b)` - combines the contributions of two neighbouring subtrees;
/// * `add_root(x)` - turns the merged children value into the value of the subtree
///   including its root vertex.
pub struct ReRooting<T, Identity, Merge, AddRoot> {
    /// `dp[v][i]`: contribution to `v` of the part of the tree behind its `i`-th neighbour.
    dp: Vec<Vec<T>>,
    /// `ans[v]`: `add_root` applied to the merge of all entries of `dp[v]`.
    ans: Vec<T>,
    /// Adjacency lists; edges are stored exactly as passed to `add_edge`.
    graph: Vec<Vec<usize>>,
    identity: Identity,
    merge: Merge,
    add_root: AddRoot,
}

impl<T, Identity, Merge, AddRoot> ReRooting<T, Identity, Merge, AddRoot>
where
    T: Clone,
    Identity: Fn() -> T,
    Merge: Fn(T, T) -> T,
    AddRoot: Fn(T) -> T,
{
    /// Creates a solver for `n` vertices with no edges; every answer starts as `identity()`.
    pub fn new(n: usize, identity: Identity, merge: Merge, add_root: AddRoot) -> Self {
        Self {
            dp: vec![vec![]; n],
            ans: vec![identity(); n],
            graph: vec![vec![]; n],
            identity,
            merge,
            add_root,
        }
    }

    /// Adds the directed edge `a -> b`. For an undirected tree edge, call this once in
    /// each direction.
    pub fn add_edge(&mut self, a: usize, b: usize) {
        self.graph[a].push(b);
    }

    /// Runs both passes starting from vertex 0 and fills `ans`.
    ///
    /// Does nothing for an empty tree (`n == 0`). Both passes recurse once per tree
    /// level, so very deep trees need a correspondingly large stack.
    pub fn build(&mut self) {
        if self.graph.is_empty() {
            return;
        }
        // Vertex 0 is the root and uses itself as the "parent" sentinel.
        self.dfs(0, 0);
        let outside_root = (self.identity)();
        self.dfs2(0, 0, outside_root);
    }

    /// Bottom-up pass: fills `dp[v][i]` for every child edge and returns the value of
    /// the subtree rooted at `v` (with parent `p`).
    fn dfs(&mut self, v: usize, p: usize) -> T {
        let mut sum = (self.identity)();
        let deg = self.graph[v].len();
        self.dp[v] = vec![(self.identity)(); deg];
        // Index-based iteration avoids cloning the adjacency list; the graph is not
        // modified during the traversal.
        for i in 0..deg {
            let next = self.graph[v][i];
            if next == p {
                continue;
            }
            let t = self.dfs(next, v);
            self.dp[v][i] = t.clone();
            sum = (self.merge)(sum, t);
        }
        (self.add_root)(sum)
    }

    /// Top-down pass: `dp_p` is the value of everything on the parent side of `v`.
    /// Stores it in the parent slot of `dp[v]`, computes `ans[v]`, and passes each child
    /// the merge of all other neighbours' contributions.
    fn dfs2(&mut self, v: usize, p: usize, dp_p: T) {
        let deg = self.graph[v].len();
        for i in 0..deg {
            if self.graph[v][i] == p {
                self.dp[v][i] = dp_p.clone();
            }
        }

        // Prefix/suffix merges so "all neighbours except i" is dp_l[i] merged with dp_r[i + 1].
        let mut dp_l = vec![(self.identity)(); deg + 1];
        let mut dp_r = vec![(self.identity)(); deg + 1];
        for i in 0..deg {
            dp_l[i + 1] = (self.merge)(dp_l[i].clone(), self.dp[v][i].clone());
        }
        for i in (0..deg).rev() {
            dp_r[i] = (self.merge)(dp_r[i + 1].clone(), self.dp[v][i].clone());
        }

        self.ans[v] = (self.add_root)(dp_l[deg].clone());

        for i in 0..deg {
            let next = self.graph[v][i];
            if next == p {
                continue;
            }
            let without_child = (self.merge)(dp_l[i].clone(), dp_r[i + 1].clone());
            let outside = (self.add_root)(without_child);
            self.dfs2(next, v, outside);
        }
    }
}
/// Whitespace-delimited token reader paired with a buffered writer.
///
/// Field `0` is the input source; field `1` is the output wrapped in a `BufWriter`.
pub struct IO<R, W: std::io::Write>(R, std::io::BufWriter<W>);

impl<R: std::io::Read, W: std::io::Write> IO<R, W> {
    /// Wraps the reader as-is and the writer in a `BufWriter`.
    pub fn new(r: R, w: W) -> IO<R, W> {
        IO(r, std::io::BufWriter::new(w))
    }

    /// Appends `s.to_string()` to the output buffer.
    ///
    /// # Panics
    /// Panics if the underlying writer reports an error.
    pub fn write<S: ToString>(&mut self, s: S) {
        use std::io::Write;
        self.1.write_all(s.to_string().as_bytes()).unwrap();
    }

    /// Reads the next token and parses it as `T`.
    ///
    /// Leading spaces, tabs, `\n` and `\r` are skipped; the token ends at the next such
    /// byte (which is consumed) or at end of input.
    ///
    /// # Panics
    /// Panics on a read error, if the token is not valid UTF-8, or if it does not parse
    /// as `T` (message `"Parse error."`).
    pub fn read<T: std::str::FromStr>(&mut self) -> T {
        use std::io::Read;
        let is_delimiter = |b: u8| matches!(b, b' ' | b'\n' | b'\r' | b'\t');
        let token = self
            .0
            .by_ref()
            .bytes()
            .map(|b| b.unwrap())
            .skip_while(|&b| is_delimiter(b))
            .take_while(|&b| !is_delimiter(b))
            .collect::<Vec<_>>();
        // Validate instead of using `from_utf8_unchecked`: the input is untrusted, and
        // building a `str`/`String` from invalid UTF-8 is undefined behaviour.
        std::str::from_utf8(&token)
            .expect("Input is not valid UTF-8.")
            .parse()
            .ok()
            .expect("Parse error.")
    }

    /// Reads a 1-based index and returns it 0-based.
    ///
    /// # Panics
    /// Panics if the value read is `0`, in addition to the panics of [`IO::read`].
    pub fn usize0(&mut self) -> usize {
        self.read::<usize>()
            .checked_sub(1)
            .expect("usize0 expects a positive (1-based) index.")
    }

    /// Reads `n` consecutive tokens, each parsed as `T`.
    pub fn vec<T: std::str::FromStr>(&mut self, n: usize) -> Vec<T> {
        (0..n).map(|_| self.read()).collect()
    }

    /// Reads one token and returns its characters.
    pub fn chars(&mut self) -> Vec<char> {
        self.read::<String>().chars().collect()
    }
}