2021/11/17

Codeforces Round #751 (Div. 1) C. Optimal Insertion

codeforces.com

とりあえずbはソートしても良いことが分かって、aに挿入された後のb全体で見てもソートされていることが分かる。すると、bを小さい順にaに挿入していくと、既に挿入されたbの値と後から挿入されるbの値はお互いに影響しないことが分かる。よって、各biの挿入は独立に考えてよく、biは挿入した際に増える転倒数が最小であるような位置に挿入すれば良い。複数ある場合は一番前にすると無難。

use crate::fenwick_tree::FenwickTree;
use crate::lazy_segment_tree::LazySegmentTree;

/// Entry point: reads the number of test cases, then for each case reads
/// `n`, `m`, the `n` values of `a` and the `m` values of `b`, and prints the
/// value returned by `solve` followed by a newline.
fn main() {
    let stdin = std::io::stdin();
    let stdout = std::io::stdout();
    let mut sc = IO::new(stdin.lock(), stdout.lock());

    let cases: usize = sc.read();
    for _ in 0..cases {
        let n: usize = sc.read();
        let m: usize = sc.read();

        // Remember each element's original position so it survives sorting.
        let mut a: Vec<(i64, usize)> = (0..n).map(|pos| (sc.read::<i64>(), pos)).collect();
        let mut b: Vec<i64> = sc.vec(m);

        // `solve` receives both sequences in ascending order. The pairs in `a`
        // are pairwise distinct, so an unstable sort yields the same order.
        a.sort_unstable();
        b.sort_unstable();

        sc.write(solve(a, b));
        sc.write('\n');
    }
}
/// Large sentinel cost (2^60). `solve` uses it as the cost component of the
/// identity element of its minimum segment tree.
const INF: i64 = 1 << 60;

/// Computes the inversion count of `a` plus, for every value of `b`, the
/// smallest number of additional inversions obtainable by choosing a gap of
/// `a` for that value.
///
/// * `a` - pairs `(value, original_index)`; the caller passes them sorted
///   ascending. Equal values must be adjacent for the grouping below to work.
/// * `b` - the values to insert; the caller passes them sorted ascending.
///
/// A lazy segment tree over the `a.len() + 1` gaps keeps a cost per gap and
/// answers "minimum cost over all gaps". The values of `a` are swept in
/// ascending order and the per-gap costs are adjusted as the sweep passes
/// each value, so that every run of `b` values is charged the current minimum.
fn solve(a: Vec<(i64, usize)>, b: Vec<i64>) -> i64 {
    // Number of gaps: before each element of `a`, plus one after the last.
    let n = a.len() + 1;

    // Node value: (cost, gap index). Lazy action: an integer added to the cost.
    // The gap index is only carried along; this function never reads it.
    let mut min_seg = LazySegmentTree::new(
        n,
        || (INF, n),
        |&s: &(i64, usize), &t: &(i64, usize)| s.min(t),
        |&f, &(x, i)| (f + x, i),
        |&f, &g| f + g,
        || 0_i64,
    );

    // Before the sweep every element of `a` counts as "greater", so gap `g`
    // starts with a cost equal to the number of elements in front of it.
    for gap in 0..n {
        min_seg.set(gap, (gap as i64, gap));
    }

    let mut cost: i64 = 0;
    let mut a_cur = 0;
    let mut b_cur = 0;
    while a_cur < a.len() {
        let value = a[a_cur].0;

        // Values of `b` strictly below the current value of `a`.
        let smaller = b[b_cur..].iter().take_while(|&&v| v < value).count();
        if smaller > 0 {
            let (best, _) = min_seg.prod(0..n);
            cost += smaller as i64 * best;
            b_cur += smaller;
        }

        // End (exclusive) of the run of `a` entries sharing `value`.
        let next = a_cur
            + a[a_cur..]
                .iter()
                .take_while(|&&(v, _)| v == value)
                .count();

        // These entries are no longer greater than the values still to come,
        // so they stop costing anything for gaps located after them.
        for &(_, idx) in &a[a_cur..next] {
            min_seg.apply_range((idx + 1)..n, -1);
        }

        // Values of `b` equal to the current value of `a`.
        let equal = b[b_cur..].iter().take_while(|&&v| v == value).count();
        if equal > 0 {
            let (best, _) = min_seg.prod(0..n);
            cost += equal as i64 * best;
            b_cur += equal;
        }

        // From now on these entries are smaller than every remaining `b`
        // value, so they cost one for gaps located before them.
        // NOTE(review): gap `idx` itself is left out of the range even though a
        // value placed there also precedes `a[idx]`. The stored cost of such a
        // gap then equals the true cost of gap `idx + 1`, so the minimum (the
        // only thing read) should be unaffected - confirm before relying on
        // the individual per-gap costs.
        for &(_, idx) in &a[a_cur..next] {
            min_seg.apply_range(0..idx, 1);
        }

        a_cur = next;
    }

    // Values of `b` greater than every value of `a`.
    let remaining = b.len() - b_cur;
    if remaining > 0 {
        let (best, _) = min_seg.prod(0..n);
        cost += remaining as i64 * best;
    }

    // Inversions inside `a` itself: visiting entries in sorted order, count the
    // already-visited entries whose original index is not smaller.
    let mut bit = FenwickTree::new(n, || 0_i64);
    let mut inversions = 0_i64;
    for (_, idx) in a {
        inversions += bit.sum(idx, n);
        bit.add(idx, 1);
    }

    cost + inversions
}
pub mod fenwick_tree {
    /// `FenwickTree` is a data structure that can efficiently update elements
    /// and calculate prefix sums in a table of numbers.
    /// [https://en.wikipedia.org/wiki/Fenwick_tree](https://en.wikipedia.org/wiki/Fenwick_tree)
    ///
    /// `initialize` produces the neutral starting value (the "zero") used both
    /// to fill the table and to seed every prefix-sum accumulation.
    pub struct FenwickTree<T, F> {
        /// Number of slots in `data` (`size + 1`).
        n: usize,
        data: Vec<T>,
        initialize: F,
    }

    impl<T, F> FenwickTree<T, F>
    where
        T: Copy + std::ops::AddAssign + std::ops::Sub<Output = T>,
        F: Fn() -> T,
    {
        /// Constructs a new `FenwickTree`. The size of `FenwickTree` should be specified by `size`.
        ///
        /// The tree allocates `size + 1` slots, each initialised with `initialize()`.
        pub fn new(size: usize, initialize: F) -> FenwickTree<T, F> {
            let n = size + 1;
            FenwickTree {
                n,
                data: vec![initialize(); n],
                initialize,
            }
        }

        /// Adds `value` to the element at 0-based index `k`.
        ///
        /// Indices at or beyond the number of allocated slots are ignored.
        pub fn add(&mut self, k: usize, value: T) {
            let mut x = k;
            while x < self.n {
                self.data[x] += value;
                // Jump to the next node whose range covers index `k`.
                x |= x + 1;
            }
        }

        /// Returns a sum of range `[l, r)`
        ///
        /// # Panics
        /// Panics if `l` or `r` exceeds the `size` given to [`FenwickTree::new`].
        pub fn sum(&self, l: usize, r: usize) -> T {
            self.sum_one(r) - self.sum_one(l)
        }

        /// Returns a sum of range `[0, k)`
        ///
        /// # Panics
        /// Panics if `k` exceeds the `size` given to [`FenwickTree::new`].
        pub fn sum_one(&self, k: usize) -> T {
            assert!(k < self.n, "Cannot calculate for range [{}, {})", k, self.n);
            let mut result = (self.initialize)();
            // `x` is the number of elements still to be covered. The node at
            // slot `x - 1` covers the last `lowbit(x)` of them, so clearing
            // the lowest set bit moves on to the preceding node. Staying in
            // `usize` avoids a narrowing cast of the index.
            let mut x = k;
            while x > 0 {
                result += self.data[x - 1];
                x &= x - 1;
            }

            result
        }
    }
}

pub mod lazy_segment_tree {
    type Range = std::ops::Range<usize>;

    /// Segment tree with lazy propagation over `n` elements of type `S`, with
    /// pending actions of type `F`.
    ///
    /// The behaviour is supplied through closures:
    /// * `op(a, b)` merges two values,
    /// * `e()` returns the identity value for `op`,
    /// * `mapping(f, x)` applies action `f` to value `x`,
    /// * `composition(f, g)` returns the action equivalent to applying `g`
    ///   first and `f` afterwards,
    /// * `id()` returns the action that changes nothing.
    pub struct LazySegmentTree<S, Op, E, F, Mapping, Composition, Id> {
        /// Number of elements visible to the user.
        n: usize,
        /// Number of leaves: `n` rounded up to a power of two.
        size: usize,
        /// Height of the tree (`size == 1 << log`).
        log: usize,
        /// Node values, 1-based heap layout; leaves start at index `size`.
        data: Vec<S>,
        /// Pending actions of the internal nodes.
        lazy: Vec<F>,
        op: Op,
        e: E,
        mapping: Mapping,
        composition: Composition,
        id: Id,
    }

    impl<S, Op, E, F, Mapping, Composition, Id> LazySegmentTree<S, Op, E, F, Mapping, Composition, Id>
    where
        S: Clone,
        E: Fn() -> S,
        F: Clone,
        Op: Fn(&S, &S) -> S,
        Mapping: Fn(&F, &S) -> S,
        Composition: Fn(&F, &F) -> F,
        Id: Fn() -> F,
    {
        /// Creates a tree of `n` elements, all initialised to `e()`.
        pub fn new(
            n: usize,
            e: E,
            op: Op,
            mapping: Mapping,
            composition: Composition,
            id: Id,
        ) -> Self {
            let size = n.next_power_of_two();
            LazySegmentTree {
                n,
                size,
                log: size.trailing_zeros() as usize,
                data: vec![e(); 2 * size],
                lazy: vec![id(); size],
                e,
                op,
                mapping,
                composition,
                id,
            }
        }

        /// Overwrites the element at `index` with `value`.
        ///
        /// # Panics
        /// Panics if `index >= n`.
        pub fn set(&mut self, mut index: usize, value: S) {
            assert!(index < self.n);
            index += self.size;
            // Flush pending actions on the root-to-leaf path before writing.
            for i in (1..=self.log).rev() {
                self.push(index >> i);
            }
            self.data[index] = value;
            // Recompute every ancestor from the leaf upwards.
            for i in 1..=self.log {
                self.update(index >> i);
            }
        }

        /// Returns the current element at `index`.
        ///
        /// # Panics
        /// Panics if `index >= n`.
        pub fn get(&mut self, mut index: usize) -> S {
            assert!(index < self.n);
            index += self.size;
            for i in (1..=self.log).rev() {
                self.push(index >> i);
            }
            self.data[index].clone()
        }

        /// Returns `op` folded over the elements of `range`, or `e()` when the
        /// range is empty.
        ///
        /// # Panics
        /// Panics if `range.start > range.end` or `range.end > n`.
        pub fn prod(&mut self, range: Range) -> S {
            let mut l = range.start;
            let mut r = range.end;
            assert!(l <= r && r <= self.n);
            // An empty range folds to the identity, mirroring `apply_range`,
            // which also accepts empty ranges.
            if l == r {
                return (self.e)();
            }

            l += self.size;
            r += self.size;

            // Flush pending actions on the two boundary paths only.
            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push(r >> i);
                }
            }

            // Separate accumulators keep the operand order left-to-right.
            let mut sum_l = (self.e)();
            let mut sum_r = (self.e)();
            while l < r {
                if l & 1 != 0 {
                    sum_l = (self.op)(&sum_l, &self.data[l]);
                    l += 1;
                }
                if r & 1 != 0 {
                    r -= 1;
                    sum_r = (self.op)(&self.data[r], &sum_r);
                }
                l >>= 1;
                r >>= 1;
            }

            (self.op)(&sum_l, &sum_r)
        }

        /// Returns the value stored at the root node.
        pub fn all_prod(&self) -> S {
            self.data[1].clone()
        }

        /// Applies action `f` to the single element at `index`.
        ///
        /// # Panics
        /// Panics if `index >= n`.
        pub fn apply(&mut self, mut index: usize, f: F) {
            assert!(index < self.n);
            index += self.size;
            for i in (1..=self.log).rev() {
                self.push(index >> i);
            }
            self.data[index] = (self.mapping)(&f, &self.data[index]);
            for i in 1..=self.log {
                self.update(index >> i);
            }
        }

        /// Applies action `f` to every element of `range`. An empty range is a no-op.
        ///
        /// # Panics
        /// Panics if `range.start > range.end` or `range.end > n`.
        pub fn apply_range(&mut self, range: Range, f: F) {
            let mut l = range.start;
            let mut r = range.end;
            assert!(l <= r && r <= self.n);
            if l == r {
                return;
            }

            l += self.size;
            r += self.size;

            // Flush pending actions on the two boundary paths.
            for i in (1..=self.log).rev() {
                if ((l >> i) << i) != l {
                    self.push(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.push((r - 1) >> i);
                }
            }

            {
                // Work on copies: the original bounds are needed again below
                // to recompute the ancestors.
                let mut l = l;
                let mut r = r;
                while l < r {
                    if l & 1 != 0 {
                        self.all_apply(l, f.clone());
                        l += 1;
                    }
                    if r & 1 != 0 {
                        r -= 1;
                        self.all_apply(r, f.clone());
                    }
                    l >>= 1;
                    r >>= 1;
                }
            }

            // Recompute the ancestors of both boundaries, bottom-up.
            for i in 1..=self.log {
                if ((l >> i) << i) != l {
                    self.update(l >> i);
                }
                if ((r >> i) << i) != r {
                    self.update((r - 1) >> i);
                }
            }
        }

        /// Recomputes node `k` from its two children.
        fn update(&mut self, k: usize) {
            self.data[k] = (self.op)(&self.data[2 * k], &self.data[2 * k + 1]);
        }

        /// Applies `f` to node `k` and, for internal nodes, records it as a
        /// pending action on top of whatever is already pending there.
        fn all_apply(&mut self, k: usize, f: F) {
            self.data[k] = (self.mapping)(&f, &self.data[k]);
            if k < self.size {
                self.lazy[k] = (self.composition)(&f, &self.lazy[k]);
            }
        }

        /// Hands the pending action of node `k` down to both children and
        /// resets it to the identity action.
        fn push(&mut self, k: usize) {
            self.all_apply(2 * k, self.lazy[k].clone());
            self.all_apply(2 * k + 1, self.lazy[k].clone());
            self.lazy[k] = (self.id)();
        }
    }
}
/// Minimal token-based reader paired with a buffered writer.
///
/// Field `0` is the input source, read byte by byte; field `1` is the output
/// sink wrapped in a `BufWriter`.
pub struct IO<R, W: std::io::Write>(R, std::io::BufWriter<W>);

impl<R: std::io::Read, W: std::io::Write> IO<R, W> {
    /// Wraps reader `r` and writer `w`; output is buffered.
    pub fn new(r: R, w: W) -> IO<R, W> {
        IO(r, std::io::BufWriter::new(w))
    }

    /// Writes the string representation of `s` to the buffered output.
    ///
    /// # Panics
    /// Panics if the underlying writer reports an error.
    pub fn write<S: ToString>(&mut self, s: S) {
        use std::io::Write;
        self.1
            .write_all(s.to_string().as_bytes())
            .expect("Failed to write output.");
    }

    /// Reads the next token delimited by space, `\n`, `\r` or `\t` and parses
    /// it as `T`. The delimiter that ends the token is consumed as well.
    ///
    /// # Panics
    /// Panics with "Parse error." if the token is not valid UTF-8 or cannot be
    /// parsed as `T` (this includes an empty token at end of input), and
    /// panics if the underlying reader reports an error.
    pub fn read<T: std::str::FromStr>(&mut self) -> T {
        use std::io::Read;
        let token = self
            .0
            .by_ref()
            .bytes()
            .map(|b| b.expect("Failed to read input."))
            .skip_while(|&b| b == b' ' || b == b'\n' || b == b'\r' || b == b'\t')
            .take_while(|&b| b != b' ' && b != b'\n' && b != b'\r' && b != b'\t')
            .collect::<Vec<_>>();
        // The bytes come straight from external input, so validate them rather
        // than assuming they form valid UTF-8.
        std::str::from_utf8(&token)
            .ok()
            .and_then(|text| text.parse::<T>().ok())
            .expect("Parse error.")
    }

    /// Reads a 1-based index and returns it converted to 0-based.
    ///
    /// # Panics
    /// Panics if the value read is `0`, or on any failure of [`IO::read`].
    pub fn usize0(&mut self) -> usize {
        self.read::<usize>()
            .checked_sub(1)
            .expect("usize0 expects a positive (1-based) value.")
    }

    /// Reads `n` consecutive tokens, each parsed as `T`.
    pub fn vec<T: std::str::FromStr>(&mut self, n: usize) -> Vec<T> {
        (0..n).map(|_| self.read()).collect()
    }

    /// Reads one token and returns its characters.
    pub fn chars(&mut self) -> Vec<char> {
        self.read::<String>().chars().collect()
    }
}