Codeforces Round #302 Div2 C: Writing Code(動的計画法)

問題

codeforces.com

解法

動的計画法で dp[i][j][k]:=i人目まででj行書いてバグがk個ある組み合わせ を求める。

ある行fまで終わっている時にfからf+1、f+2、...への遷移を考えると以下のような4重ループになる。

// dp[i][j][k]:=i人目までで、j行書いて、バグがk個あるパターン
int[][][] dp = new int[N + 1][MAX_LINES + 1][MAX_BUGS + 1];
dp[0][0][0] = 1;
for (int person = 0; person < N; person++) {
	for (int finished = 0; finished <= MAX_LINES; finished++) {
		for (int lines = 0; lines + finished <= MAX_LINES; lines++) {
			if (a[person] * lines > MAX_BUGS) {
				break;
			}
			for (int bugs = 0; bugs <= MAX_BUGS; bugs++) {
				if (bugs + a[person] * lines > MAX_BUGS) {
					break;
				}
				dp[person + 1][finished + lines][bugs + a[person] * lines] += dp[person][finished][bugs];
				dp[person + 1][finished + lines][bugs + a[person] * lines] %= MOD;
			}
		}
	}
}

実はf->f+1->f+2となるので、f+2への遷移はf+1の時に考えればよく、ここでf->f+1以外の遷移は考える必要はない。そこでDPを書きなおして3重ループにする。

int[][][] dp = new int[N + 1][M + 1][B + 1];
dp[0][0][0] = 1;
for (int n = 1; n <= N; n++) {
	for (int m = 0; m <= M; m++) {
		for (int b = 0; b <= B; b++) {
			dp[n][m][b] += dp[n - 1][m][b];
			dp[n][m][b] %= MOD;
			if (b >= a[n] && m > 0) {
				dp[n][m][b] += dp[n][m - 1][b - a[n]];
				dp[n][m][b] %= MOD;
			}
		}
	}
}

これで時間内に計算が終わるが、このままだとMLEする。ここでdp[n][m][b] <- dp[n - 1][m][b]の遷移は配列内にただ残しておけば良いだけなので、nについて別個の配列を作る必要はない。

最終的に以下のように書き換えられる。

int[][] dp = new int[M + 1][B + 1];
dp[0][0] = 1;
for (int n = 1; n <= N; n++) {
	for (int m = 1; m <= M; m++) {
		for (int b = a[n]; b <= B; b++) {
			dp[m][b] += dp[m - 1][b - a[n]];
			dp[m][b] %= MOD;
		}
	}
}

コード

import java.util.Scanner;

public class Main {

	public void solve() {
		Scanner sc = new Scanner(System.in);
		int N = sc.nextInt();
		int M = sc.nextInt();
		int B = sc.nextInt();
		int MOD = sc.nextInt();
		int[] a = new int[N + 1];
		for (int i = 0; i < N; i++) {
			a[i + 1] = sc.nextInt();
		}
		sc.close();

		// dp[m][b]:= m行目まで終わってバグがb個になっている組み合わせ
		int[][] dp = new int[M + 1][B + 1];
		dp[0][0] = 1;
		for (int n = 1; n <= N; n++) {
			for (int m = 1; m <= M; m++) {
				for (int b = a[n]; b <= B; b++) {
					dp[m][b] += dp[m - 1][b - a[n]];
					dp[m][b] %= MOD;
				}
			}
		}

		long ans = 0;
		for (int bugs = 0; bugs <= B; bugs++) {
			ans += dp[M][bugs];
			ans %= MOD;
		}
		System.out.println(ans);
	}

	public static void main(String[] args) {
		new Main().solve();
	}

}