CS Academy Round #12 Prefix Suffix Counting
問題
https://csacademy.com/contest/round-12/#task/prefix-suffix-counting/
巨大な整数 N と M が与えられる。M は K 桁である。この時、1 以上 N 以下の範囲で、上 K 桁と下 K 桁がともに M と一致する整数は何通りあるか。
解法
整数 N の桁数を L とする。
場合分けすると、以下の3通りが考えられる。
- 2K 桁未満のとき
- 2K 桁以上、L桁未満のとき
- L 桁のとき
2K 桁未満のとき、M が重なっているので、条件を満たす整数が作れるか考えれば良い。例えば M=232 のとき、 5 桁の整数 23232は条件を満たすが、条件を満たす 4 桁の整数は作れない。
2K 桁以上、L桁未満のとき、上下のK桁ずつを除いた桁はなんの数字でも良いので、l 桁の整数を作るとき、 通りの整数が作れる。
L 桁の整数を作るとき、N の上K桁がMより大きければ、「2K 桁以上、L桁未満のとき」と同じになる。上K桁がMより小さければ、L桁の整数は作れない。上K桁がMに等しいときは桁dp をする。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.NoSuchElementException; /* _ooOoo_ o8888888o 88" . "88 (| -_- |) O\ = /O ____/`---'\____ .' \\| |// `. / \\||| : |||// \ / _||||| -:- |||||- \ | | \\\ - /// | | | \_| ''\---/'' | | \ .-\__ `-` ___/-. / ___`. .' /--.--\ `. . __ ."" '< `.___\_<|>_/___.' >'"". | | : `- \`.;`\ _ /`;.`/ - ` : | | \ \ `-. \_ __\ /__ _/ .-` / / ======`-.____`-.___\_____/___.-`____.-'====== `=---=' ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ pass System Test! */ public class Main { private static class Task { final long MOD = (long) 1e9 + 7; long modPow(long x, long e) { long ret = 1; long cur = x; while (e > 0) { if ((e & 1) != 0) ret = (ret * cur) % MOD; cur = (cur * cur) % MOD; e /= 2; } return ret; } int digitDP(String A, int p) { int N = A.length(); long[][] dp = new long[N + 1][2]; dp[0][0] = 1; for (int i = 0; i < N; i++) { for (int j = 0; j < 2; j++) { int lim = j > 0 ? 9 : A.charAt(i) - '0'; for (int d = 0; d <= lim; d++) { int k = j > 0 || d < lim ? 1 : 0; dp[i + 1][k] += dp[i][j]; if (dp[i + 1][k] > MOD) dp[i + 1][k] -= MOD; } } } long ans = 0; for (int i = 0; i < 2; i++) { if (p < 0 && i == 0) continue; ans += dp[N][i]; } return (int) (ans % MOD); } void solve(FastScanner in, PrintWriter out) { String N = in.next(); String M = in.next(); int L = N.length(); int K = M.length(); if (L < K) { out.println(0); return; } long ans = 0; // 2*K 未満 RollingHash64 hash64 = new RollingHash64(M); for (int l = K; l < 2 * K; l++) { if (l > L) break; int lcp = hash64.lcp(0, l - K); if (K * 2 - lcp == l) { if (l < L) ans++; else { String check = M.substring(0, l - K) + M; if (N.equals(check)) ans++; } } } if (L < 2 * K) { out.println(ans); return; } if (2 * K == L && N.compareTo(M + M) >= 0) { out.println(ans + 1); return; } // L 未満 2*K 以上 for (int l = 2 * K; l < L; l++) { ans += modPow(10, l - 2 * K); if (ans > MOD) ans -= MOD; } String head = N.substring(0, K); int c = head.compareTo(M); if (c < 0) { out.println(ans); return; } if (c > 0) { ans += modPow(10, L - 2 * K); if (ans > MOD) ans -= MOD; out.println(ans); return; } String tail = N.substring(L - K); int p = tail.compareTo(M); ans += digitDP(N.substring(K, L - K), p); ans %= MOD; out.println(ans); } class RollingHash64 { private final long base = 1000000007; private long[] hash, pow; private int n; RollingHash64(String S) { n = S.length(); hash = new long[n + 1]; pow = new long[n + 1]; hash[0] = 0; pow[0] = 1; for (int j = 0; j < n; j++) { pow[j + 1] = pow[j] * base; hash[j + 1] = (hash[j] * base + S.charAt(j)); } } long getHash(int l, int r) { return (hash[r] - hash[l] * pow[r - l]); } boolean match(int l1, int r1, int l2, int r2) { return getHash(l1, r1) == getHash(l2, r2); } boolean match(int l1, int l2, int k) { return match(l1, l1 + k, l2, l2 + k); } int lcp(int i, int j) { int l = 0, r = Math.min(n - i, n - j) + 1; while (l + 1 < r) { int m = (l + r) / 2; if (match(i, j, m)) l = m; else r = m; } return l; } } } /** * ここから下はテンプレートです。 */ public static void main(String[] args) { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) array[i] = nextInt(); return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) array[i] = nextLong(); return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) array[i] = next(); return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) array[i] = next().toCharArray(); return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }