yukicoder No. 174: カードゲーム(Hard)

解法

#17295 No.174 カードゲーム(Hard) - yukicoder

dp[k]:= あるターンturnにおける、j番目に小さいカードより小さいカードがk枚ある確率、をメモする。

turnが遷移する時、

dp[0] *= (1 - P[player]);
dp[card - 1] += dp[card] * (P[player] + (card - 1) * Prest);
dp[card] *= (totalCardNum - card) * Prest;

というふうに遷移するので、これらを元にturnでplayerがxを出す確率を計算すれば良い。

コード

import java.util.Arrays;
import java.util.Scanner;

public class Main {

	public static void main(String[] args) {
		Scanner sc = new Scanner(System.in);
		int N = sc.nextInt();

		double[] P = new double[2];
		for (int i = 0; i < 2; i++) {
			P[i] = sc.nextDouble();
		}

		int[] a = new int[N];
		for (int i = 0; i < a.length; i++) {
			a[i] = sc.nextInt();
		}
		int[] b = new int[N];
		for (int i = 0; i < b.length; i++) {
			b[i] = sc.nextInt();
		}
		sc.close();

		Arrays.sort(a);
		Arrays.sort(b);

		// probability[player][turn][x]:=playerがturnにxのカードを出す確率
		double[][][] probability = new double[2][N][N];
		for (int player = 0; player < 2; player++) {
			for (int j = 0; j < N; j++) {
				// jが出される確率を求める

				// dp[k]:= turnにjより小さいカードがk枚残り、かつjが残っている確率
				double[] dp = new double[N];
				dp[j] = 1;
				for (int turn = 0; turn < N; turn++) {
					if (turn == N - 1) {
						// 最終ターン
						probability[player][N - 1][j] = dp[0];
						break;
					}

					// dp[0]:= jが残っていて、かつ、jが最小の確率
					probability[player][turn][j] += P[player] * dp[0];
					dp[0] *= (1 - P[player]);

					// turnでの残りカード枚数
					int totalCardNum = N - turn - 1;

					// 最小以外のカードが出る1枚あたりの確率
					double Prest = (1 - P[player]) / totalCardNum;

					for (int card = 1; card < N; card++) {
						dp[card - 1] += dp[card] * (P[player] + (card - 1) * Prest);

						probability[player][turn][j] += dp[card] * Prest;
						if (totalCardNum - card >= 0) {
							dp[card] *= (totalCardNum - card) * Prest;
						}
					}

				}
			}
		}

		double ans = 0;
		for (int turn = 0; turn < N; turn++) {
			for (int i = 0; i < N; i++) {
				for (int j = 0; j < N; j++) {
					if (a[i] > b[j]) {
						ans += probability[0][turn][i] * probability[1][turn][j] * (a[i] + b[j]);
					}
				}
			}

		}
		System.out.println(ans);

	}
}