CodeChef Snackdown 2016 : Online Elimination Round - Cube Towers

問題

Contest Page | CodeChef

N 個の空文字列と M 個の単語が与えられる。Q 個のクエリに応答せよ。クエリは以下の 3 種類からなる。

  • 文字列 X の先頭に文字 C を挿入する。
  • L〜R 番目の文字列の中で単語 P を含むものがいくつあるか出力する。
  • L〜R 番目の文字列の中で単語 P がいくつ含まれるか出力する。

解法

【重要】CodeChef は Java の実行時間制限がゆるいので Java で書く。

まず M 個の単語のハッシュを計算して map に保存しておく。

クエリを全て読んでおき、「単語 m がクエリ q の時に文字列 n に現れた」という情報に変換する。この時、クエリ q の時に操作される文字列は文字列 X[q] のみなので、n=X[q] であることがわかるとして、m を求める方法を考える。

クエリ q の時に文字列 X[q] に登場する単語は複数ありうるので、すべての単語について考える必要があるが、各クエリについて M 個の文字列を見ていると O(QM) となり TLE してしまう。そこで、「文字列 X[q] の末尾 L 文字と一致する単語」を探すことにする。文字列の長さ L は M^0.5 通りしかないので、ハッシュを計算しておけば、各クエリについて O(M^0.5 logN)で見つけることができるようになる。HashMap 神の力で O(M^0.5 ほげ) にすることで全体を O(QM^0.5 ほげ) にすることができる。

あとは各単語 m について、応答していく。この辺は BIT を使えば O((Q+M)logN) くらいになる(追加クエリが増えうるけどだいたいこんなもんだと思う)。


コード

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.*;

public class Main {
  private static class Task {

    void solve(FastScanner in, PrintWriter out) throws Exception {
      int N = in.nextInt(), M = in.nextInt(), Q = in.nextInt();

      // 単語を読み込んでリバースしておく
      String[] words = new String[M];
      for (int i = 0; i < M; i++) words[i] = new StringBuffer(in.next()).reverse().toString();

      //クエリは先に全て読んでおく
      int[] type = new int[Q];
      int[] X = new int[Q];
      char[] C = new char[Q];
      int[] L = new int[Q];
      int[] R = new int[Q];
      int[] P = new int[Q];
      for (int i = 0; i < Q; i++) {
        type[i] = in.nextInt();
        if (type[i] == 0) {
          X[i] = in.nextInt() - 1;
          C[i] = in.next().toCharArray()[0];
        } else {
          L[i] = in.nextInt() - 1;
          R[i] = in.nextInt() - 1;
          P[i] = in.nextInt() - 1;
        }
      }

      TreeSet<Integer> lengthSet = new TreeSet<>();
      HashMap<Long, Integer> wordHash = new HashMap<>();
      for (int m = 0; m < M; m++) {
        lengthSet.add(words[m].length());
        RollingHash64 hash64 = new RollingHash64(words[m]);
        wordHash.put(hash64.getHash(0, words[m].length()), m);
      }

      String[] towers = new String[N];
      Arrays.fill(towers, "");
      for (int q = 0; q < Q; q++) if (type[q] == 0) towers[X[q]] += C[q];

      RollingHash64[] towerHash = new RollingHash64[N];
      for (int i = 0; i < N; i++) towerHash[i] = new RollingHash64(towers[i]);

      // wordQuery.get(m) に単語 m に関するクエリを入れていく
      ArrayList<ArrayList<Integer>> wordQuery = new ArrayList<>();
      for (int i = 0; i < M; i++) wordQuery.add(new ArrayList<>());

      int[] curLength = new int[N];
      for (int q = 0; q < Q; q++) {
        // 追加クエリでなければ、貯めるだけ
        if (type[q] != 0) {
          wordQuery.get(P[q]).add(q);
          continue;
        }

        // 文字列の長さについてループすれば、O(M^0.5) で回せる
        for (int l : lengthSet) {
          if (curLength[X[q]] + 1 < l) break;
          long hash = towerHash[X[q]].getHash(curLength[X[q]] - l + 1, curLength[X[q]] + 1);
          if (wordHash.containsKey(hash)) wordQuery.get(wordHash.get(hash)).add(q);
        }
        curLength[X[q]]++;
      }

      long[] answer = new long[Q];
      for (int m = 0; m < M; m++) {
        FenwickTree bit2 = new FenwickTree(N);
        FenwickTree bit1 = new FenwickTree(N);
        boolean[] used = new boolean[N];
        for (int q : wordQuery.get(m)) {
          if (type[q] == 0) {
            bit2.add(X[q], 1);
            if (!used[X[q]]) {
              bit1.add(X[q], 1);
              used[X[q]] = true;
            }
          } else if (type[q] == 1) {
            answer[q] = bit1.sum(L[q], R[q] + 1);
          } else {
            answer[q] = bit2.sum(L[q], R[q] + 1);
          }
        }
      }

      for (int q = 0; q < Q; q++)
        if (type[q] != 0) {
          out.println(answer[q]);
        }
    }
  }

  static class FenwickTree {
    int N;
    long[] data;

    FenwickTree(int N) {
      this.N = N + 1;
      data = new long[N + 1];
    }

    void add(int k, long val) {
      for (int x = k; x < N; x |= x + 1) {
        data[x] += val;
      }
    }

    // [0, k)
    long sum(int k) {
      if (k >= N) k = N - 1;
      int ret = 0;
      for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) {
        ret += data[x];
      }
      return ret;
    }

    // [l, r)
    long sum(int l, int r) {
      return sum(r) - sum(l);
    }

    long get(int k) {
      assert (0 <= k && k < N);
      return sum(k + 1) - sum(k);
    }

    int getAsSetOf(int w) {
      w++;
      if (w <= 0) return -1;
      int x = 0;
      int k = 1;
      while (k * 2 <= N) k *= 2;
      for (; k > 0; k /= 2) {
        if (x + k <= N && data[x + k - 1] < w) {
          w -= data[x + k - 1];
          x += k;
        }
      }
      return x;
    }
  }

  static class RollingHash64 {
    private final long base = 1000000007;
    private long[] hash, pow;
    private int n;

    RollingHash64(String S) {
      n = S.length();
      hash = new long[n + 1];
      pow = new long[n + 1];
      hash[0] = 0;
      pow[0] = 1;
      for (int j = 0; j < n; j++) {
        pow[j + 1] = pow[j] * base;
        hash[j + 1] = (hash[j] * base + S.charAt(j));
      }
    }

    long getHash(int l, int r) {
      return (hash[r] - hash[l] * pow[r - l]);
    }

    boolean match(int l1, int r1, int l2, int r2) {
      return getHash(l1, r1) == getHash(l2, r2);
    }

    boolean match(int l1, int l2, int k) {
      return match(l1, l1 + k, l2, l2 + k);
    }

    int lcp(int i, int j) {
      int l = 0, r = Math.min(n - i, n - j) + 1;
      while (l + 1 < r) {
        int m = (l + r) / 2;
        if (match(i, j, m))
          l = m;
        else
          r = m;
      }
      return l;
    }
  }

  // Template
  public static void main(String[] args) throws Exception {
    OutputStream outputStream = System.out;
    FastScanner in = new FastScanner();
    PrintWriter out = new PrintWriter(outputStream);
    Task solver = new Task();
    solver.solve(in, out);
    out.close();
  }
  private static class FastScanner {
    private final InputStream in = System.in;
    private final byte[] buffer = new byte[1024];
    private int ptr = 0;
    private int bufferLength = 0;

    private boolean hasNextByte() {
      if (ptr < bufferLength) {
        return true;
      } else {
        ptr = 0;
        try {
          bufferLength = in.read(buffer);
        } catch (IOException e) {
          e.printStackTrace();
        }
        if (bufferLength <= 0) {
          return false;
        }
      }
      return true;
    }

    private int readByte() {
      if (hasNextByte()) return buffer[ptr++];
      else return -1;
    }

    private static boolean isPrintableChar(int c) {
      return 33 <= c && c <= 126;
    }

    private void skipUnprintable() {
      while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++;
    }

    boolean hasNext() {
      skipUnprintable();
      return hasNextByte();
    }

    public String next() {
      if (!hasNext()) throw new NoSuchElementException();
      StringBuilder sb = new StringBuilder();
      int b = readByte();
      while (isPrintableChar(b)) {
        sb.appendCodePoint(b);
        b = readByte();
      }
      return sb.toString();
    }

    long nextLong() {
      if (!hasNext()) throw new NoSuchElementException();
      long n = 0;
      boolean minus = false;
      int b = readByte();
      if (b == '-') {
        minus = true;
        b = readByte();
      }
      if (b < '0' || '9' < b) {
        throw new NumberFormatException();
      }
      while (true) {
        if ('0' <= b && b <= '9') {
          n *= 10;
          n += b - '0';
        } else if (b == -1 || !isPrintableChar(b)) {
          return minus ? -n : n;
        } else {
          throw new NumberFormatException();
        }
        b = readByte();
      }
    }

    double nextDouble() {
      return Double.parseDouble(next());
    }

    double[] nextDoubleArray(int n) {
      double[] array = new double[n];
      for (int i = 0; i < n; i++) {
        array[i] = nextDouble();
      }
      return array;
    }

    double[][] nextDoubleMap(int n, int m) {
      double[][] map = new double[n][];
      for (int i = 0; i < n; i++) {
        map[i] = nextDoubleArray(m);
      }
      return map;
    }

    public int nextInt() {
      return (int) nextLong();
    }

    public int[] nextIntArray(int n) {
      int[] array = new int[n];
      for (int i = 0; i < n; i++) {
        array[i] = nextInt();
      }
      return array;
    }
  }
}