Codeforces Round #373 (Div. 1) C. Sasha and Array
問題
http://codeforces.com/contest/718/problem/C
下記の二種類のクエリを処理せよ。
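- 数列 $a$ の $l$ 番目から $r$ 番目までの各要素に $v$ を足せ。
- 数列 $a$ の $l$ 番目から $r$ 番目までの各 $i$ について $f(a_i)$ を求め、その和を $10^9+7$ で割った余りを出力しろ。ここで $f$ はフィボナッチ数列（$f(1)=f(2)=1$, $f(n)=f(n-1)+f(n-2)$）である。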
解法
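StarrySkyTree の各ノードに行列をもたせる。$B=\begin{pmatrix}1&1\\1&0\end{pmatrix}$ とすると $B^{n-1}$ の左上成分は $f(n)$ なので、各要素 $a_i$ を行列 $B^{a_i-1}$ として持つ。区間に $v$ を足す操作は区間内の行列すべてに $B^v$ を掛ける操作になり、区間和は行列の和の左上成分として求まる。登場する行列はすべて $B$ の多項式なので互いに可換であり、遅延させた掛け算を区間和の左右どちらから掛けてもよい。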
コード
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.NoSuchElementException;

/**
 * Solution to Codeforces Round #373 (Div. 1) C, "Sasha and Array".
 *
 * <p>Every array element a_i is represented by the 2x2 matrix B^(a_i - 1) with
 * B = [[1, 1], [1, 0]]; its top-left entry is the Fibonacci number F(a_i) (F(1) = F(2) = 1).
 * Adding v to a range becomes multiplying each matrix in the range by B^v, and the answer to a
 * range query is the top-left entry of the sum of the matrices, modulo 1e9+7. All matrices used
 * here are polynomials in B, so they commute and a lazy multiplier may be applied on either side.
 */
public class E {
    private static class Task {
        final int MOD = (int) (1e9 + 7);
        /** Fibonacci step matrix B = [[1, 1], [1, 0]], stored row-major. */
        Matrix base = new Matrix(new long[]{1, 1, 1, 0});
        /** Identity matrix: the neutral lazy multiplier. */
        Matrix one = new Matrix(new long[]{1, 0, 0, 1});
        /** Zero matrix: the neutral element for sums. */
        Matrix zero = new Matrix(new long[4]);

        /**
         * Reads the array and the queries from {@code in} and prints one line per sum query.
         *
         * <p>Query "1 l r v" adds v to a[l..r]; any other query type "t l r" prints the sum of
         * F(a_i) over i in [l, r] modulo MOD. Indices in the input are 1-based.
         */
        void solve(FastScanner in, PrintWriter out) {
            int N = in.nextInt();
            int Q = in.nextInt();
            int[] a = in.nextIntArray(N);
            Matrix[] ms = new Matrix[N];
            for (int i = 0; i < N; i++) {
                // B^(a_i - 1) has F(a_i) in its top-left entry.
                ms[i] = powMatrix(base, a[i] - 1);
            }
            StarrySkyTreeMatrix seg = new StarrySkyTreeMatrix(ms);
            for (int i = 0; i < Q; i++) {
                int t = in.nextInt();
                if (t == 1) {
                    int l = in.nextInt() - 1;
                    int r = in.nextInt() - 1;
                    int v = in.nextInt();
                    // Adding v to every a_i multiplies every stored matrix by B^v.
                    seg.add(l, r + 1, powMatrix(base, v));
                } else {
                    int l = in.nextInt() - 1;
                    int r = in.nextInt() - 1;
                    out.println(seg.getSum(l, r + 1).mat[0]);
                }
            }
        }

        /**
         * Segment tree over matrices supporting range multiplication and range sum.
         *
         * <p>Invariant: seg[k] is the sum of node k's subtree with the lazy multipliers of all
         * strict descendants applied, but NOT segAdd[k] itself; segAdd[k] is the multiplier
         * pending on node k's whole range.
         */
        class StarrySkyTreeMatrix {
            /** Number of leaves: the smallest power of two not less than the array length. */
            final int N;
            Matrix[] seg, segAdd;

            /** Builds the tree over {@code ms}; leaves past the end hold the zero matrix. */
            StarrySkyTreeMatrix(Matrix[] ms) {
                int size = 1;
                while (size < ms.length) {
                    size <<= 1;
                }
                N = size;
                seg = new Matrix[N * 2];
                segAdd = new Matrix[N * 2];
                // Sharing zero/one is safe: matrices are never mutated in place.
                Arrays.fill(seg, zero);
                Arrays.fill(segAdd, one);
                System.arraycopy(ms, 0, seg, N, ms.length);
                for (int i = N - 1; i > 0; i--) {
                    seg[i] = addMatrix(seg[i * 2], seg[i * 2 + 1]);
                }
            }

            /** Multiplies every element with index in [a, b) by {@code v}. */
            void add(int a, int b, Matrix v) {
                add(a, b, v, 0, N, 1);
            }

            /** Recursive helper for {@link #add(int, int, Matrix)}; node k covers [l, r). */
            void add(int a, int b, Matrix v, int l, int r, int k) {
                if (r <= a || b <= l) {
                    return;
                }
                if (a <= l && r <= b) {
                    segAdd[k] = productMatrix(segAdd[k], v);
                    return;
                }
                int m = (l + r) >>> 1;
                add(a, b, v, l, m, k * 2);
                add(a, b, v, m, r, k * 2 + 1);
                // Re-establish the invariant from the children, including their pending lazies.
                seg[k] = addMatrix(
                        productMatrix(seg[k * 2], segAdd[k * 2]),
                        productMatrix(seg[k * 2 + 1], segAdd[k * 2 + 1]));
            }

            /** Returns the sum of the elements with index in [a, b). */
            Matrix getSum(int a, int b) {
                return getSum(a, b, 0, N, 1);
            }

            /** Recursive helper for {@link #getSum(int, int)}; node k covers [l, r). */
            Matrix getSum(int a, int b, int l, int r, int k) {
                if (b <= l || r <= a) {
                    return zero;
                }
                if (a <= l && r <= b) {
                    return productMatrix(seg[k], segAdd[k]);
                }
                int m = (l + r) >>> 1;
                Matrix x = getSum(a, b, l, m, k * 2);
                Matrix y = getSum(a, b, m, r, k * 2 + 1);
                // Left-multiplying is fine because all matrices here commute (powers of B).
                return productMatrix(segAdd[k], addMatrix(x, y));
            }
        }

        /** A 2x2 matrix stored row-major in a length-4 array: {m00, m01, m10, m11}. */
        class Matrix {
            long[] mat;

            Matrix(long[] mat) {
                this.mat = mat;
            }
        }

        /** Returns a + b entrywise, reduced into [0, MOD) assuming both inputs are in [0, MOD). */
        Matrix addMatrix(Matrix a, Matrix b) {
            Matrix c = new Matrix(new long[4]);
            for (int i = 0; i < 4; i++) {
                c.mat[i] = a.mat[i] + b.mat[i];
                // '>=' (not '>') so that a sum of exactly MOD is normalized to 0.
                if (c.mat[i] >= MOD) {
                    c.mat[i] -= MOD;
                }
            }
            return c;
        }

        /**
         * Returns the product a * b modulo MOD. With entries below MOD each intermediate sum is
         * below 2 * MOD^2, which fits in a long.
         */
        Matrix productMatrix(Matrix a, Matrix b) {
            Matrix c = new Matrix(new long[4]);
            c.mat[0] = (a.mat[0] * b.mat[0] + a.mat[1] * b.mat[2]) % MOD;
            c.mat[1] = (a.mat[0] * b.mat[1] + a.mat[1] * b.mat[3]) % MOD;
            c.mat[2] = (a.mat[2] * b.mat[0] + a.mat[3] * b.mat[2]) % MOD;
            c.mat[3] = (a.mat[2] * b.mat[1] + a.mat[3] * b.mat[3]) % MOD;
            return c;
        }

        /** Returns a^p modulo MOD by binary exponentiation; p <= 0 yields the identity. */
        Matrix powMatrix(Matrix a, long p) {
            Matrix ret = new Matrix(new long[]{1, 0, 0, 1});
            while (p > 0) {
                if (p % 2 == 1) {
                    ret = productMatrix(ret, a);
                }
                a = productMatrix(a, a);
                p /= 2;
            }
            return ret;
        }
    }

    /* Everything below is reusable template code. */

    /** Entry point: reads from standard input and writes answers to standard output. */
    public static void main(String[] args) {
        OutputStream outputStream = System.out;
        FastScanner in = new FastScanner();
        // try-with-resources flushes and closes the writer even if solve() throws.
        try (PrintWriter out = new PrintWriter(outputStream)) {
            new Task().solve(in, out);
        }
    }

    /** Minimal buffered token reader over System.in; tokens are runs of printable ASCII. */
    private static class FastScanner {
        private final InputStream in = System.in;
        private final byte[] buffer = new byte[1024];
        private int ptr = 0;
        private int bufferLength = 0;

        /** Returns true if an unread byte is available, refilling the buffer when exhausted. */
        private boolean hasNextByte() {
            if (ptr < bufferLength) {
                return true;
            }
            ptr = 0;
            try {
                bufferLength = in.read(buffer);
            } catch (IOException e) {
                e.printStackTrace();
                // Treat a failed read as end of input rather than re-reading stale buffer bytes.
                bufferLength = -1;
            }
            return bufferLength > 0;
        }

        /** Returns the next byte, or -1 at end of input. */
        private int readByte() {
            if (hasNextByte()) {
                return buffer[ptr++];
            }
            return -1;
        }

        private static boolean isPrintableChar(int c) {
            return 33 <= c && c <= 126;
        }

        private void skipUnprintable() {
            while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
                ptr++;
            }
        }

        boolean hasNext() {
            skipUnprintable();
            return hasNextByte();
        }

        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            StringBuilder sb = new StringBuilder();
            int b = readByte();
            while (isPrintableChar(b)) {
                sb.appendCodePoint(b);
                b = readByte();
            }
            return sb.toString();
        }

        /** Parses the next token as a (possibly negative) decimal long. */
        long nextLong() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            long n = 0;
            boolean minus = false;
            int b = readByte();
            if (b == '-') {
                minus = true;
                b = readByte();
            }
            if (b < '0' || '9' < b) {
                throw new NumberFormatException();
            }
            while (true) {
                if ('0' <= b && b <= '9') {
                    n *= 10;
                    n += b - '0';
                } else if (b == -1 || !isPrintableChar(b)) {
                    return minus ? -n : n;
                } else {
                    throw new NumberFormatException();
                }
                b = readByte();
            }
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        double[] nextDoubleArray(int n) {
            double[] array = new double[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextDouble();
            }
            return array;
        }

        double[][] nextDoubleMap(int n, int m) {
            double[][] map = new double[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextDoubleArray(m);
            }
            return map;
        }

        public int nextInt() {
            return (int) nextLong();
        }

        public int[] nextIntArray(int n) {
            int[] array = new int[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextInt();
            }
            return array;
        }

        public long[] nextLongArray(int n) {
            long[] array = new long[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextLong();
            }
            return array;
        }

        public String[] nextStringArray(int n) {
            String[] array = new String[n];
            for (int i = 0; i < n; i++) {
                array[i] = next();
            }
            return array;
        }

        public char[][] nextCharMap(int n) {
            char[][] array = new char[n][];
            for (int i = 0; i < n; i++) {
                array[i] = next().toCharArray();
            }
            return array;
        }

        public int[][] nextIntMap(int n, int m) {
            int[][] map = new int[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextIntArray(m);
            }
            return map;
        }
    }
}