CodeChef Snackdown 2016 : Online Elimination Round - Cube Towers
問題
N 個の空文字列と M 個の単語が与えられる。Q 個のクエリに応答せよ。クエリは以下の 3 種類からなる。
- 文字列 X の先頭に文字 C を挿入する。
- L〜R 番目の文字列の中で単語 P を含むものがいくつあるか出力する。
- L〜R 番目の文字列の中で単語 P がいくつ含まれるか出力する。
解法
【重要】CodeChef は Java の実行時間制限がゆるいので Java で書く。
まず M 個の単語のハッシュを計算して map に保存しておく。
クエリを全て読んでおき、「単語 m がクエリ q の時に文字列 n に現れた」という情報に変換する。この時、クエリ q の時に操作される文字列は文字列 X[q] のみなので、n=X[q] であることがわかるとして、m を求める方法を考える。
クエリ q の時に文字列 X[q] に登場する単語は複数ありうるので、すべての単語について考える必要があるが、各クエリについて M 個の文字列を見ていると O(QM) となり TLE してしまう。そこで、「文字列 X[q] の末尾 L 文字と一致する単語」を探すことにする。文字列の長さ L は M^0.5 通りしかないので、ハッシュを計算しておけば、各クエリについて O(M^0.5 logN)で見つけることができるようになる。HashMap 神の力で O(M^0.5 ほげ) にすることで全体を O(QM^0.5 ほげ) にすることができる。
あとは各単語 m について、応答していく。この辺は BIT を使えば O((Q+M)logN) くらいになる(追加クエリが増えうるけどだいたいこんなもんだと思う)。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.*; public class Main { private static class Task { void solve(FastScanner in, PrintWriter out) throws Exception { int N = in.nextInt(), M = in.nextInt(), Q = in.nextInt(); // 単語を読み込んでリバースしておく String[] words = new String[M]; for (int i = 0; i < M; i++) words[i] = new StringBuffer(in.next()).reverse().toString(); //クエリは先に全て読んでおく int[] type = new int[Q]; int[] X = new int[Q]; char[] C = new char[Q]; int[] L = new int[Q]; int[] R = new int[Q]; int[] P = new int[Q]; for (int i = 0; i < Q; i++) { type[i] = in.nextInt(); if (type[i] == 0) { X[i] = in.nextInt() - 1; C[i] = in.next().toCharArray()[0]; } else { L[i] = in.nextInt() - 1; R[i] = in.nextInt() - 1; P[i] = in.nextInt() - 1; } } TreeSet<Integer> lengthSet = new TreeSet<>(); HashMap<Long, Integer> wordHash = new HashMap<>(); for (int m = 0; m < M; m++) { lengthSet.add(words[m].length()); RollingHash64 hash64 = new RollingHash64(words[m]); wordHash.put(hash64.getHash(0, words[m].length()), m); } String[] towers = new String[N]; Arrays.fill(towers, ""); for (int q = 0; q < Q; q++) if (type[q] == 0) towers[X[q]] += C[q]; RollingHash64[] towerHash = new RollingHash64[N]; for (int i = 0; i < N; i++) towerHash[i] = new RollingHash64(towers[i]); // wordQuery.get(m) に単語 m に関するクエリを入れていく ArrayList<ArrayList<Integer>> wordQuery = new ArrayList<>(); for (int i = 0; i < M; i++) wordQuery.add(new ArrayList<>()); int[] curLength = new int[N]; for (int q = 0; q < Q; q++) { // 追加クエリでなければ、貯めるだけ if (type[q] != 0) { wordQuery.get(P[q]).add(q); continue; } // 文字列の長さについてループすれば、O(M^0.5) で回せる for (int l : lengthSet) { if (curLength[X[q]] + 1 < l) break; long hash = towerHash[X[q]].getHash(curLength[X[q]] - l + 1, curLength[X[q]] + 1); if (wordHash.containsKey(hash)) wordQuery.get(wordHash.get(hash)).add(q); } curLength[X[q]]++; } long[] answer = new long[Q]; for (int m = 0; m < M; m++) { FenwickTree bit2 = new FenwickTree(N); FenwickTree bit1 = new FenwickTree(N); boolean[] used = new boolean[N]; for (int q : wordQuery.get(m)) { if (type[q] == 0) { bit2.add(X[q], 1); if (!used[X[q]]) { bit1.add(X[q], 1); used[X[q]] = true; } } else if (type[q] == 1) { answer[q] = bit1.sum(L[q], R[q] + 1); } else { answer[q] = bit2.sum(L[q], R[q] + 1); } } } for (int q = 0; q < Q; q++) if (type[q] != 0) { out.println(answer[q]); } } } static class FenwickTree { int N; long[] data; FenwickTree(int N) { this.N = N + 1; data = new long[N + 1]; } void add(int k, long val) { for (int x = k; x < N; x |= x + 1) { data[x] += val; } } // [0, k) long sum(int k) { if (k >= N) k = N - 1; int ret = 0; for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) { ret += data[x]; } return ret; } // [l, r) long sum(int l, int r) { return sum(r) - sum(l); } long get(int k) { assert (0 <= k && k < N); return sum(k + 1) - sum(k); } int getAsSetOf(int w) { w++; if (w <= 0) return -1; int x = 0; int k = 1; while (k * 2 <= N) k *= 2; for (; k > 0; k /= 2) { if (x + k <= N && data[x + k - 1] < w) { w -= data[x + k - 1]; x += k; } } return x; } } static class RollingHash64 { private final long base = 1000000007; private long[] hash, pow; private int n; RollingHash64(String S) { n = S.length(); hash = new long[n + 1]; pow = new long[n + 1]; hash[0] = 0; pow[0] = 1; for (int j = 0; j < n; j++) { pow[j + 1] = pow[j] * base; hash[j + 1] = (hash[j] * base + S.charAt(j)); } } long getHash(int l, int r) { return (hash[r] - hash[l] * pow[r - l]); } boolean match(int l1, int r1, int l2, int r2) { return getHash(l1, r1) == getHash(l2, r2); } boolean match(int l1, int l2, int k) { return match(l1, l1 + k, l2, l2 + k); } int lcp(int i, int j) { int l = 0, r = Math.min(n - i, n - j) + 1; while (l + 1 < r) { int m = (l + r) / 2; if (match(i, j, m)) l = m; else r = m; } return l; } } // Template public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) { array[i] = nextInt(); } return array; } } }