AtCoder Beginner Contest 003 D: AtCoder社の冬 (包除原理・bit管理・動的計画法)

解法

ステップとしては、

  • R*C内のX*Yの取りうるパターンを調べる。
  • X*Y内のD+Lの取りうるパターンを調べる。
  • D+L内のDとLのパターンを調べる。

の3つをこなす必要がある。

R*C内のX*Yの取りうるパターンは(R-X+1)*(C-Y+1)であり、D+L内のDとLのパターンは _{D+L} C_{D}である。

X*Y内のD+Lの取りうるパターンは、dp[i][j][bit]:= i番目の区画まで見た時にjこ埋まっているパターンで、一番上の行、一番下の行、一番左の列、一番右の列を埋めたかどうかbit管理する。

ただしJavaでdp[X*Y+1][D+L+1][(1<<<4)]を作ろうとするとMLEするので、適当に工夫する。

コード

import java.io.IOException;
import java.math.BigInteger;

public class Main {
	private static final int MOD = 1000000007;

	public static void main(String[] args) throws Exception {
		int R = nextInt();
		int C = nextInt();
		int X = nextInt();
		int Y = nextInt();
		int D = nextInt();
		int L = nextInt();

		long ans = (R - X + 1) * (C - Y + 1);
		ans %= MOD;
		int combi = combination(D + L, Math.min(D, L));
		System.gc();

		ans *= combi;
		ans %= MOD;

		int[][] dp = new int[D + L + 2][(1 << 4)];
		int[][] dp2 = new int[D + L + 2][(1 << 4)];
		dp[0][0] = 1;
		for (int i = 0; i < X * Y; i++) {
			for (int j = 0; j < (D + L + 1); j++) {
				for (int k = 0; k < (1 << 4); k++) {
					int bit = k;

					if (i < Y) {
						// 一番上の行にいる
						bit |= 1;
					}
					if (i % Y == 0) {
						// 一番左の列にいる
						bit |= 1 << 1;
					}
					if (X * Y - Y <= i) {
						// 一番下の列にいる
						bit |= 1 << 2;
					}
					if (i % Y == Y - 1) {
						bit |= 1 << 3;
					}
					dp2[j][k] += dp[j][k];
					dp2[j][k] %= MOD;

					dp2[j + 1][bit] += dp[j][k];
					dp2[j + 1][bit] %= MOD;

				}
			}
			for (int j = 0; j < dp2.length; j++) {
				for (int j2 = 0; j2 < dp2[j].length; j2++) {
					dp[j][j2] = dp2[j][j2];
					dp2[j][j2] = 0;
				}
			}
		}

		ans *= dp[D + L][(1 << 4) - 1];
		ans %= MOD;

		System.out.println(ans);

	}

	static int nextInt() {
		int c;
		try {
			c = System.in.read();
			while (c != '-' && (c < '0' || c > '9'))
				c = System.in.read();
			if (c == '-')
				return -nextInt();
			int res = 0;
			while (c >= '0' && c <= '9') {
				res = res * 10 + c - '0';
				c = System.in.read();
			}
			return res;
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
		return -1;
	}

	static int combination(long m, long n) {
		if (m < n) {
			return 0;
		}
		// mCnをもとめる
		long ans = 1;
		long min = Math.min(m - n, n);
		if (min == 0) {
			return 1;
		}

		for (int i = 1; i <= min; i++) {
			ans *= m - i + 1;
			ans %= MOD;
			BigInteger modInv = new BigInteger(String.valueOf(i));
			modInv = modInv.modInverse(new BigInteger(String.valueOf(MOD)));
			ans *= modInv.longValue();
			ans %= MOD;
		}

		return (int) ans;

	}

}