Codeforces Round #307 Div2 D: GukiZ and Binary Operations

問題

codeforces.com

解法

kをビットに直して考える。各aのi番目のビットの組み合わせは 2^n通りあるが、kのi番目のビットが0になるためには、そのうち1が連続して存在するものを取り除かなければならない。
例えば、n=3のとき、以下の5通りの組み合わせが0になる。

{0,0,0} ならば (0&0)|(0&0)=0
{0,0,1} ならば (0&0)|(0&1)=0
{0,1,0} ならば (0&1)|(1&0)=0
{1,0,0} ならば (1&0)|(0&0)=0
{1,0,1} ならば (1&0)|(0&1)=0

それ以外の組み合わせは1になる。

{1,1,0} ならば (1&1)|(1&0)=1
{0,1,1} ならば (0&1)|(1&1)=1
{1,1,1} ならば (1&1)|(1&1)=1

この時0になる組み合わせの数はフィボナッチ数列のn項目になるので、1になる組み合わせの数は 2^n-fibonacci(n)通りとなる。

あとはkのl個のビットについて、0になるか1になるかで組み合わせをかけていけば良い。

MOD=1の時や、 2^l>k以外の時は、必ず0になるということにも注意する。

フィボナッチ数を効率よく求めるために行列を用いる。

 
  A = \left(
    \begin{array}{cc}
      1 & 1 \\
      1 & 0
    \end{array}
  \right)
とした時、k番目のフィボナッチ数 F_k
 
\left(
    \begin{array}{c}
      F_{k+1} \\
      F_{k}
    \end{array}
  \right)  
 = A^k \left(
    \begin{array}{c}
      1\\
      0
    \end{array}
  \right)
によって高速に計算できる。

行列の累乗を効率的に求める方法は以下の問題で出てきた。

kenkoooo.hatenablog.com

コード

import java.util.Scanner;

public class Main {

	private final int MAX_L = 64;
	private int MOD;

	public void solve() {
		Scanner sc = new Scanner(System.in);
		long N = sc.nextLong();
		long k = sc.nextLong();
		int l = sc.nextInt();
		MOD = sc.nextInt();
		sc.close();

		if (Math.pow(2, l) <= k) {
			System.out.println(0);
			return;
		}

		long fib = fibonacci(N);
		long twoex2 = powMod(2, N);
		long rest = (twoex2 - fib + MOD) % MOD;

		long ans = 1;
		for (int i = 0; i < l; i++) {
			if (((k >> i) & 1) == 0) {
				ans *= fib;
				ans %= MOD;
			} else {
				ans *= rest;
				ans %= MOD;
			}
		}
		System.out.println(ans % MOD);

	}

	private long fibonacci(long N) {
		long[][] fibA = new long[MAX_L + 1][];
		fibA[0] = new long[] { 1, 1, 1, 0 };
		for (int i = 1; i < fibA.length; i++) {
			fibA[i] = multiplyMatrix(fibA[i - 1], fibA[i - 1]);
		}

		long[] A = { 1, 0, 0, 1 };
		for (int i = 0; ((N + 1) >> i) > 0; i++) {
			if ((((N + 1) >> i) & 1) != 0) {
				A = multiplyMatrix(fibA[i], A);
			}
		}
		return A[0];
	}

	private long powMod(long num, long n) {
		if (n == 0) {
			return 1;
		}
		int M = Long.toBinaryString(n).length() + 1;
		long[] expo = new long[M];
		expo[0] = num;
		for (int i = 1; i < expo.length; i++) {
			expo[i] = (expo[i - 1] * expo[i - 1]) % MOD;
		}

		long ret = 1;
		for (int i = 0; i < expo.length; i++) {
			if (((n >> i) & 1) != 0) {
				ret *= expo[i];
				ret %= MOD;
			}
		}
		return ret;
	}

	private long[] multiplyMatrix(long[] a, long[] b) {
		long[] ret = new long[4];

		ret[0] = (a[0] * b[0]) % MOD + (a[1] * b[2]) % MOD;
		ret[1] = (a[0] * b[1]) % MOD + (a[1] * b[3]) % MOD;
		ret[2] = (b[0] * a[2]) % MOD + (b[2] * a[3]) % MOD;
		ret[3] = (b[1] * a[2]) % MOD + (b[3] * a[3]) % MOD;
		for (int i = 0; i < 4; i++) {
			ret[i] %= MOD;
		}
		return ret;
	}

	public static void main(String[] args) {
		new Main().solve();
	}

}