TopCoder SRM 501 Div1 Medium: FoxAverageSequence

解法

topcoder.g.hatena.ne.jp

dp[pos][sum][last][flag]（pos番目までの要素の和がsum、最後の要素がlast、flagは直前の遷移が狭義減少だったかどうか）という配列をそのまま作ると時間もメモリも間に合わないが、posについてはpos->pos+1の遷移しか起きないので、dp[sum][last][flag]とnext[sum][last][flag]の2つだけを用意して使い回せば良い。

コード

public class FoxAverageSequence {

	private final int MOD = 1000000007;

	public int theCount(int[] seq) {
		int N = seq.length;

		int[][][] dp = new int[1601][41][2];
		if (seq[0] == -1) {
			for (int num = 0; num <= 40; num++) {
				dp[num][num][0] = 1;
			}
		} else {
			dp[seq[0]][seq[0]][0] = 1;
		}

		for (int pos = 0; pos < N - 1; pos++) {
			int[][][] next = new int[1601][41][2];
			for (int sum = 0; sum <= 1600; sum++) {
				if ((pos + 1) * 40 < sum) {
					break;
				}
				for (int prev = 0; prev <= 40; prev++) {
					if (seq[pos + 1] != -1) {
						int last = seq[pos + 1];
						if (sum + last <= 1600 && last * (pos + 1) <= sum) {
							if (prev > last) {
								next[sum + last][last][1] += dp[sum][prev][0];
								next[sum + last][last][1] %= MOD;
							} else {
								next[sum + last][last][0] += dp[sum][prev][0];
								next[sum + last][last][0] %= MOD;
								next[sum + last][last][0] += dp[sum][prev][1];
								next[sum + last][last][0] %= MOD;
							}
						}
						continue;
					}

					for (int last = 0; last <= 40; last++) {
						if (dp[sum][prev][0] == 0) {
							break;
						}
						if (sum + last > 1600 || last * (pos + 1) > sum) {
							break;
						}
						if (prev > last) {
							next[sum + last][last][1] += dp[sum][prev][0];
							next[sum + last][last][1] %= MOD;
						} else {
							next[sum + last][last][0] += dp[sum][prev][0];
							next[sum + last][last][0] %= MOD;
						}
					}

					for (int last = 0; last <= 40; last++) {
						if (dp[sum][prev][1] == 0) {
							break;
						}
						if (sum + last > 1600 || last * (pos + 1) > sum) {
							break;
						}
						if (prev <= last) {
							next[sum + last][last][0] += dp[sum][prev][1];
							next[sum + last][last][0] %= MOD;
						}
					}
				}
			}
			dp = next;
		}

		long ans = 0;
		for (int sum = 0; sum <= 1600; sum++) {
			for (int last = 0; last <= 40; last++) {
				ans += dp[sum][last][0];
				ans %= MOD;
				ans += dp[sum][last][1];
				ans %= MOD;
			}
		}
		return (int) ans;
	}
}