天下一プログラマーコンテスト2016本戦 C - たんごたくさん
解法
文字列 S の位置 pos を見ているとき、S[pos..pos+l) と一致する長さ l の単語を列挙する。長さ l を決めれば部分文字列 S[pos..pos+l) はただ一つに定まり、l の最大値は 200 なので、列挙される単語もたかだか 200 個しかない。
dp[i] を「S の先頭 i 文字までで得られる最大スコア」とする。Trie 木を作って単語の列挙を効率的に行い、各位置から遷移が最大 200 通り（これに加えて 1 文字読み飛ばす遷移）の dp をすれば良い。
コード
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.NoSuchElementException;

/**
 * Tenka1 Programmer Contest 2016 Final, problem C ("Tango Takusan").
 *
 * <p>Input: a string {@code S}, a count {@code M}, then {@code M} words, then their {@code M}
 * weights. Output: the maximum total weight obtainable by covering parts of {@code S} with
 * non-overlapping occurrences of the words (every character may also be left uncovered).
 *
 * <p>Approach: {@code dp[i]} is the best score for the prefix {@code S[0..i)}. From each position
 * the words matching there are enumerated by walking a trie along {@code S}; since words are at
 * most 200 characters long, at most 200 transitions leave each position.
 */
public class Main {

    /**
     * Computes the best total weight of non-overlapping word occurrences in {@code s}.
     *
     * @param s       the text to cover
     * @param words   the dictionary words; a word listed more than once keeps its largest weight
     * @param weights {@code weights[i]} is the score gained per occurrence of {@code words[i]}
     * @return the maximum achievable total weight (0 if no word can be used profitably)
     * @throws IllegalArgumentException if the arrays differ in length or a word contains a
     *                                  character outside {@code 'a'..'z'}
     */
    static long maxScore(char[] s, char[][] words, int[] weights) {
        if (words.length != weights.length) {
            throw new IllegalArgumentException(
                    "words and weights differ in length: " + words.length + " vs " + weights.length);
        }
        Trie trie = new Trie();
        for (int i = 0; i < words.length; i++) {
            trie.add(words[i], weights[i]);
        }

        int n = s.length;
        // long: the total can reach |S| * max(W), which overflows int.
        long[] dp = new long[n + 1];
        for (int pos = 0; pos < n; pos++) {
            // dp[pos] is final here: every transition goes strictly forward.
            dp[pos + 1] = Math.max(dp[pos + 1], dp[pos]); // leave s[pos] uncovered
            int node = Trie.ROOT;
            for (int end = pos; end < n; end++) {
                node = trie.child(node, s[end]);
                if (node == Trie.NONE) {
                    break; // no dictionary word continues with this prefix
                }
                int w = trie.weight(node);
                if (w > 0) { // s[pos..end] is a word worth taking
                    dp[end + 1] = Math.max(dp[end + 1], dp[pos] + w);
                }
            }
        }
        return dp[n];
    }

    /** Prefix tree over lowercase words; each node stores the best weight of the word ending there. */
    private static final class Trie {
        /** Index of the root node. */
        static final int ROOT = 0;
        /** Returned by {@link #child} when there is no such edge. */
        static final int NONE = -1;

        private static final int ALPHABET = 26;

        private static final class Node {
            // 0 means "no child": the root (index 0) is never anyone's child.
            final int[] next = new int[ALPHABET];
            int weight = 0;
        }

        private final ArrayList<Node> nodes = new ArrayList<>();

        Trie() {
            nodes.add(new Node());
        }

        /** Inserts {@code word}, keeping the larger weight if the word was already present. */
        void add(char[] word, int weight) {
            int cur = ROOT;
            for (char ch : word) {
                int c = index(ch);
                if (c < 0) {
                    throw new IllegalArgumentException(
                            "word must consist of 'a'-'z' only: " + new String(word));
                }
                Node node = nodes.get(cur);
                if (node.next[c] == 0) {
                    node.next[c] = nodes.size();
                    nodes.add(new Node());
                }
                cur = node.next[c];
            }
            Node end = nodes.get(cur);
            end.weight = Math.max(end.weight, weight);
        }

        /** Returns the child of {@code node} along {@code ch}, or {@link #NONE} if absent. */
        int child(int node, char ch) {
            int c = index(ch);
            if (c < 0) {
                return NONE;
            }
            int nxt = nodes.get(node).next[c];
            return nxt == 0 ? NONE : nxt;
        }

        /** Returns the weight of the word ending at {@code node} (0 if no word ends there). */
        int weight(int node) {
            return nodes.get(node).weight;
        }

        private static int index(char ch) {
            return ('a' <= ch && ch <= 'z') ? ch - 'a' : -1;
        }
    }

    /** Parses the contest input format and prints the answer. */
    private static class Task {
        void solve(FastScanner in, PrintWriter out) {
            char[] s = in.next().toCharArray();
            int m = in.nextInt();
            char[][] words = new char[m][];
            for (int i = 0; i < m; i++) {
                words[i] = in.next().toCharArray();
            }
            int[] weights = new int[m];
            for (int i = 0; i < m; i++) {
                weights[i] = in.nextInt();
            }
            out.println(maxScore(s, words, weights));
        }
    }

    /*
     * Everything below is reusable I/O template code.
     */
    public static void main(String[] args) {
        OutputStream outputStream = System.out;
        FastScanner in = new FastScanner();
        PrintWriter out = new PrintWriter(outputStream);
        Task solver = new Task();
        solver.solve(in, out);
        out.close();
    }

    /** Minimal buffered tokenizer over {@code System.in} for ASCII input. */
    private static class FastScanner {
        private final InputStream in = System.in;
        private final byte[] buffer = new byte[1024];
        private int ptr = 0;
        private int bufferLength = 0;

        private boolean hasNextByte() {
            if (ptr < bufferLength) {
                return true;
            }
            ptr = 0;
            try {
                bufferLength = in.read(buffer);
            } catch (IOException e) {
                // Continuing would re-serve the stale buffer contents as fresh input.
                throw new UncheckedIOException("failed to read standard input", e);
            }
            return bufferLength > 0;
        }

        private int readByte() {
            if (hasNextByte()) {
                return buffer[ptr++];
            }
            return -1;
        }

        private static boolean isPrintableChar(int c) {
            return 33 <= c && c <= 126;
        }

        private void skipUnprintable() {
            while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
                ptr++;
            }
        }

        boolean hasNext() {
            skipUnprintable();
            return hasNextByte();
        }

        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            StringBuilder sb = new StringBuilder();
            int b = readByte();
            while (isPrintableChar(b)) {
                sb.appendCodePoint(b);
                b = readByte();
            }
            return sb.toString();
        }

        long nextLong() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            long n = 0;
            boolean minus = false;
            int b = readByte();
            if (b == '-') {
                minus = true;
                b = readByte();
            }
            if (b < '0' || '9' < b) {
                throw new NumberFormatException();
            }
            while (true) {
                if ('0' <= b && b <= '9') {
                    n *= 10;
                    n += b - '0';
                } else if (b == -1 || !isPrintableChar(b)) {
                    return minus ? -n : n;
                } else {
                    throw new NumberFormatException();
                }
                b = readByte();
            }
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        double[] nextDoubleArray(int n) {
            double[] array = new double[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextDouble();
            }
            return array;
        }

        double[][] nextDoubleMap(int n, int m) {
            double[][] map = new double[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextDoubleArray(m);
            }
            return map;
        }

        public int nextInt() {
            return (int) nextLong();
        }

        public int[] nextIntArray(int n) {
            int[] array = new int[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextInt();
            }
            return array;
        }

        public long[] nextLongArray(int n) {
            long[] array = new long[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextLong();
            }
            return array;
        }

        public String[] nextStringArray(int n) {
            String[] array = new String[n];
            for (int i = 0; i < n; i++) {
                array[i] = next();
            }
            return array;
        }

        public char[][] nextCharMap(int n) {
            char[][] array = new char[n][];
            for (int i = 0; i < n; i++) {
                array[i] = next().toCharArray();
            }
            return array;
        }

        public int[][] nextIntMap(int n, int m) {
            int[][] map = new int[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextIntArray(m);
            }
            return map;
        }
    }
}