CS Academy Round #14 Subarrays Xor Sum
解法
eha くんに解説をしてもらってようやく理解した。
各ビットごとに独立なので、ビットごとに分ける。
0と1のみの数列の長さ k 以下の数列の xor の合計値をとりたい。 i 番目の数を末尾とする長さ k 以下の数列のうち、 xor が z となるものをカウントする。
コード
import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; import java.util.NoSuchElementException; public class Main { private static final int MOD = (int) (1e9 + 7); /** * 長さ k 以下のものについて計算する */ private long count(int[] bit, int k) { if (k == 0) { return 0; } int N = bit.length; long res = 0; // count[i] := 長さ k 以下の数列で、 xor が i となるものの数 int[] count = new int[2]; int[] imos = new int[N + 1]; for (int i = 0; i < N; i++) { imos[i + 1] = imos[i] ^ bit[i]; } for (int i = 0; i < N; i++) { // 末尾が 1 ならば、 xor が 0 の数列は 1 に、 1 の数列は 0 になる if (bit[i] == 1) { int t = count[0]; count[0] = count[1]; count[1] = t; } // 長さ 1 の数列を追加する if (bit[i] == 1) { count[1]++; } else { count[0]++; } // 長さ k+1 になってしまった列を取り除く if (i - k >= 0) { int w = imos[i + 1] ^ imos[i - k]; if (w == 1) { count[1]--; } else { count[0]--; } } // i 番目を末尾とする、長さ k 以下の数列の個数のうち xor が 1 のものを足しておく res += count[1]; } return res; } private void solve(FastScanner in, PrintWriter out) { int N = in.nextInt(); int a = in.nextInt(); int b = in.nextInt(); int[] array = in.nextIntArray(N); long ans = 0; for (int bit = 0; bit < 31; bit++) { int[] tmp = new int[N]; for (int i = 0; i < N; i++) { tmp[i] = (array[i] >> bit) & 1; } long count = (count(tmp, b) - count(tmp, a - 1)); if (count < 0) { count += MOD; } count %= MOD; ans += count * (1 << bit); ans %= MOD; } out.println(ans); } public static void main(String[] args) { PrintWriter out = new PrintWriter(System.out); new Main().solve(new FastScanner(), out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) { return buffer[ptr++]; } else { return -1; } } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) { ptr++; } } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) { throw new NoSuchElementException(); } StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) { throw new NoSuchElementException(); } long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) { array[i] = nextInt(); } return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) { array[i] = nextLong(); } return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) { array[i] = next(); } return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) { array[i] = next().toCharArray(); } return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }