CS Academy Round #12 Prefix Suffix Counting

問題

https://csacademy.com/contest/round-12/#task/prefix-suffix-counting/

巨大な整数 N と M が与えられる。M は K 桁である。この時、1 以上 N 以下の範囲で、上 K 桁と下 K 桁がともに M と一致する整数は何通りあるか。

解法

整数 N の桁数を L とする。

場合分けすると、以下の3通りが考えられる。

  1. 2K 桁未満のとき
  2. 2K 桁以上、L桁未満のとき
  3. L 桁のとき

2K 桁未満のとき、M が重なっているので、条件を満たす整数が作れるか考えれば良い。例えば M=232 のとき、 5 桁の整数 23232は条件を満たすが、条件を満たす 4 桁の整数は作れない。

2K 桁以上、L桁未満のとき、上下のK桁ずつを除いた桁はなんの数字でも良いので、l 桁の整数を作るとき、 10^{l-2K} 通りの整数が作れる。

L 桁の整数を作るとき、N の上K桁がMより大きければ、「2K 桁以上、L桁未満のとき」と同じになる。上K桁がMより小さければ、L桁の整数は作れない。上K桁がMに等しいときは桁dp をする。

コード

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.util.NoSuchElementException;

/*
                   _ooOoo_
                  o8888888o
                  88" . "88
                  (| -_- |)
                  O\  =  /O
               ____/`---'\____
             .'  \\|     |//  `.
            /  \\|||  :  |||//  \
           /  _||||| -:- |||||-  \
           |   | \\\  -  /// |   |
           | \_|  ''\---/''  |   |
           \  .-\__  `-`  ___/-. /
         ___`. .'  /--.--\  `. . __
      ."" '<  `.___\_<|>_/___.'  >'"".
     | | :  `- \`.;`\ _ /`;.`/ - ` : | |
     \  \ `-.   \_ __\ /__ _/   .-` /  /
======`-.____`-.___\_____/___.-`____.-'======
                   `=---='
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            pass System Test!
*/

public class Main {
  private static class Task {
    final long MOD = (long) 1e9 + 7;

    long modPow(long x, long e) {
      long ret = 1;
      long cur = x;
      while (e > 0) {
        if ((e & 1) != 0) ret = (ret * cur) % MOD;
        cur = (cur * cur) % MOD;
        e /= 2;
      }
      return ret;
    }

    int digitDP(String A, int p) {
      int N = A.length();
      long[][] dp = new long[N + 1][2];
      dp[0][0] = 1;
      for (int i = 0; i < N; i++) {
        for (int j = 0; j < 2; j++) {
          int lim = j > 0 ? 9 : A.charAt(i) - '0';
          for (int d = 0; d <= lim; d++) {
            int k = j > 0 || d < lim ? 1 : 0;
            dp[i + 1][k] += dp[i][j];
            if (dp[i + 1][k] > MOD) dp[i + 1][k] -= MOD;
          }
        }
      }

      long ans = 0;
      for (int i = 0; i < 2; i++) {
        if (p < 0 && i == 0) continue;
        ans += dp[N][i];
      }
      return (int) (ans % MOD);
    }

    void solve(FastScanner in, PrintWriter out) {
      String N = in.next();
      String M = in.next();
      int L = N.length();
      int K = M.length();

      if (L < K) {
        out.println(0);
        return;
      }

      long ans = 0;

      // 2*K 未満
      RollingHash64 hash64 = new RollingHash64(M);
      for (int l = K; l < 2 * K; l++) {
        if (l > L) break;
        int lcp = hash64.lcp(0, l - K);
        if (K * 2 - lcp == l) {
          if (l < L) ans++;
          else {
            String check = M.substring(0, l - K) + M;
            if (N.equals(check)) ans++;
          }
        }
      }

      if (L < 2 * K) {
        out.println(ans);
        return;
      }

      if (2 * K == L && N.compareTo(M + M) >= 0) {
        out.println(ans + 1);
        return;
      }

      // L 未満 2*K 以上
      for (int l = 2 * K; l < L; l++) {
        ans += modPow(10, l - 2 * K);
        if (ans > MOD) ans -= MOD;
      }

      String head = N.substring(0, K);
      int c = head.compareTo(M);
      if (c < 0) {
        out.println(ans);
        return;
      }
      if (c > 0) {
        ans += modPow(10, L - 2 * K);
        if (ans > MOD) ans -= MOD;
        out.println(ans);
        return;
      }

      String tail = N.substring(L - K);
      int p = tail.compareTo(M);
      ans += digitDP(N.substring(K, L - K), p);
      ans %= MOD;
      out.println(ans);
    }
    class RollingHash64 {
      private final long base = 1000000007;
      private long[] hash, pow;
      private int n;

      RollingHash64(String S) {
        n = S.length();
        hash = new long[n + 1];
        pow = new long[n + 1];
        hash[0] = 0;
        pow[0] = 1;
        for (int j = 0; j < n; j++) {
          pow[j + 1] = pow[j] * base;
          hash[j + 1] = (hash[j] * base + S.charAt(j));
        }
      }

      long getHash(int l, int r) {
        return (hash[r] - hash[l] * pow[r - l]);
      }

      boolean match(int l1, int r1, int l2, int r2) {
        return getHash(l1, r1) == getHash(l2, r2);
      }

      boolean match(int l1, int l2, int k) {
        return match(l1, l1 + k, l2, l2 + k);
      }

      int lcp(int i, int j) {
        int l = 0, r = Math.min(n - i, n - j) + 1;
        while (l + 1 < r) {
          int m = (l + r) / 2;
          if (match(i, j, m))
            l = m;
          else
            r = m;
        }
        return l;
      }
    }
  }

  /**
   * ここから下はテンプレートです。
   */
  public static void main(String[] args) {
    OutputStream outputStream = System.out;
    FastScanner in = new FastScanner();
    PrintWriter out = new PrintWriter(outputStream);
    Task solver = new Task();
    solver.solve(in, out);
    out.close();
  }
  private static class FastScanner {
    private final InputStream in = System.in;
    private final byte[] buffer = new byte[1024];
    private int ptr = 0;
    private int bufferLength = 0;

    private boolean hasNextByte() {
      if (ptr < bufferLength) {
        return true;
      } else {
        ptr = 0;
        try {
          bufferLength = in.read(buffer);
        } catch (IOException e) {
          e.printStackTrace();
        }
        if (bufferLength <= 0) {
          return false;
        }
      }
      return true;
    }

    private int readByte() {
      if (hasNextByte()) return buffer[ptr++];
      else return -1;
    }

    private static boolean isPrintableChar(int c) {
      return 33 <= c && c <= 126;
    }

    private void skipUnprintable() {
      while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++;
    }

    boolean hasNext() {
      skipUnprintable();
      return hasNextByte();
    }

    public String next() {
      if (!hasNext()) throw new NoSuchElementException();
      StringBuilder sb = new StringBuilder();
      int b = readByte();
      while (isPrintableChar(b)) {
        sb.appendCodePoint(b);
        b = readByte();
      }
      return sb.toString();
    }

    long nextLong() {
      if (!hasNext()) throw new NoSuchElementException();
      long n = 0;
      boolean minus = false;
      int b = readByte();
      if (b == '-') {
        minus = true;
        b = readByte();
      }
      if (b < '0' || '9' < b) {
        throw new NumberFormatException();
      }
      while (true) {
        if ('0' <= b && b <= '9') {
          n *= 10;
          n += b - '0';
        } else if (b == -1 || !isPrintableChar(b)) {
          return minus ? -n : n;
        } else {
          throw new NumberFormatException();
        }
        b = readByte();
      }
    }

    double nextDouble() {
      return Double.parseDouble(next());
    }

    double[] nextDoubleArray(int n) {
      double[] array = new double[n];
      for (int i = 0; i < n; i++) {
        array[i] = nextDouble();
      }
      return array;
    }

    double[][] nextDoubleMap(int n, int m) {
      double[][] map = new double[n][];
      for (int i = 0; i < n; i++) {
        map[i] = nextDoubleArray(m);
      }
      return map;
    }

    public int nextInt() {
      return (int) nextLong();
    }

    public int[] nextIntArray(int n) {
      int[] array = new int[n];
      for (int i = 0; i < n; i++) array[i] = nextInt();
      return array;
    }

    public long[] nextLongArray(int n) {
      long[] array = new long[n];
      for (int i = 0; i < n; i++) array[i] = nextLong();
      return array;
    }

    public String[] nextStringArray(int n) {
      String[] array = new String[n];
      for (int i = 0; i < n; i++) array[i] = next();
      return array;
    }

    public char[][] nextCharMap(int n) {
      char[][] array = new char[n][];
      for (int i = 0; i < n; i++) array[i] = next().toCharArray();
      return array;
    }

    public int[][] nextIntMap(int n, int m) {
      int[][] map = new int[n][];
      for (int i = 0; i < n; i++) {
        map[i] = nextIntArray(m);
      }
      return map;
    }
  }
}