AtCoder Typical Contest 001, Problem C: 高速フーリエ変換 (Fast Fourier Transform)

コード (Code)

import java.util.Scanner;

public class Main {

	/**
	 * Reads {@code N} followed by {@code N} pairs {@code A_i B_i} (i = 1 .. N) from standard
	 * input and prints {@code 2N} lines. Line {@code k} (k = 1 .. 2N) is the sum of
	 * {@code A_p * B_q} over all {@code p + q = k}.
	 */
	public void solve() {
		int N;
		int[] a;
		int[] b;
		try (Scanner sc = new Scanner(System.in)) {
			N = sc.nextInt();
			a = new int[N];
			b = new int[N];
			for (int i = 0; i < N; i++) {
				a[i] = sc.nextInt();
				b[i] = sc.nextInt();
			}
		}

		long[] c = convolute(a, b);

		// Emit everything in one write: 2N individual println calls are a well-known
		// bottleneck when N is large.
		String nl = System.lineSeparator();
		StringBuilder out = new StringBuilder();
		for (int i = 0; i < N * 2; i++) {
			out.append(c[i]).append(nl);
		}
		System.out.print(out);
		System.out.flush();
	}

	/**
	 * Computes the linear convolution of {@code a} and {@code b}, treating {@code a[i]} and
	 * {@code b[i]} as the coefficients of {@code x^(i+1)}.
	 *
	 * @param a first coefficient array (any length)
	 * @param b second coefficient array (any length; may differ from {@code a.length})
	 * @return array {@code ret} of length {@code M - 1} (M = internal FFT size) where
	 *         {@code ret[r]} is the sum of {@code a[i] * b[j]} over {@code i + j = r - 1};
	 *         in particular {@code ret[0]} is always 0
	 */
	private long[] convolute(int[] a, int[] b) {
		// M is a power of two larger than a.length + b.length + 2. The product's highest
		// non-zero coefficient index is a.length + b.length, so the cyclic convolution
		// computed by the FFT never wraps around.
		int M = Integer.highestOneBit(Math.max(a.length, b.length) + 1) << 2;

		Complex[] ca = new Complex[M];
		Complex[] cb = new Complex[M];
		for (int i = 0; i < M; i++) {
			ca[i] = new Complex(0, 0);
			cb[i] = new Complex(0, 0);
		}

		// Copy each input over its own length so arrays of different lengths work;
		// index 0 is left at zero (coefficients start at x^1).
		for (int i = 0; i < a.length; i++) {
			ca[i + 1].real = a[i];
		}
		for (int i = 0; i < b.length; i++) {
			cb[i + 1].real = b[i];
		}

		Complex[] cret = multipy(M, ca, cb);
		long[] ret = new long[M - 1];
		for (int i = 0; i < ret.length; i++) {
			// Inputs are integers, so the exact result is an integer; round away FFT noise.
			ret[i] = Math.round(cret[i + 1].real);
		}
		return ret;
	}

	/**
	 * Returns the cyclic convolution of {@code g[0 .. n-1]} and {@code h[0 .. n-1]} via
	 * forward FFT, pointwise product and scaled inverse FFT. (The method name's spelling
	 * is kept for compatibility with existing callers.)
	 *
	 * @param n transform length; must be a power of two not exceeding either array length
	 * @param g first input, not modified
	 * @param h second input, not modified
	 * @return new array of length {@code n}
	 * @throws IllegalArgumentException if {@code n} is not a positive power of two
	 */
	public Complex[] multipy(int n, Complex[] g, Complex[] h) {
		Complex[] gg = FFT(g, n, false);
		Complex[] hh = FFT(h, n, false);

		Complex[] ff = new Complex[n];
		for (int i = 0; i < n; i++) {
			ff[i] = gg[i].multiply(hh[i]);
		}

		ff = FFT(ff, n, true);
		// The inverse transform below is unnormalized; apply the 1/n factor here.
		double scale = 1.0 / n;
		for (int i = 0; i < n; i++) {
			ff[i] = ff[i].multiply(scale);
		}
		return ff;
	}

	/**
	 * Recursive radix-2 Cooley-Tukey DFT of {@code f[0 .. n-1]}.
	 *
	 * <p>The forward transform uses the kernel {@code exp(+2*pi*i*j*k/n)} and the inverse
	 * uses {@code exp(-2*pi*i*j*k/n)}. The inverse is NOT divided by {@code n}.
	 *
	 * @param f   input samples; only the first {@code n} entries are read, none are modified
	 * @param n   transform length; must be a positive power of two
	 * @param inv {@code true} for the inverse kernel
	 * @return new array of length {@code n}
	 * @throws IllegalArgumentException if {@code n} is not a positive power of two
	 */
	private Complex[] FFT(Complex[] f, int n, boolean inv) {
		if (n <= 0 || (n & (n - 1)) != 0) {
			throw new IllegalArgumentException("FFT length must be a positive power of two: " + n);
		}
		if (n == 1) {
			return new Complex[] { f[0] };
		}

		int half = n / 2;
		Complex[] f0 = new Complex[half];
		Complex[] f1 = new Complex[half];
		for (int i = 0; i < half; i++) {
			f0[i] = f[2 * i];
			f1[i] = f[2 * i + 1];
		}

		f0 = FFT(f0, half, inv);
		f1 = FFT(f1, half, inv);

		double sign = inv ? -1.0 : 1.0;
		Complex[] ret = new Complex[n];
		for (int i = 0; i < half; i++) {
			// Compute each twiddle directly; building it by repeated multiplication lets
			// rounding error grow linearly with n.
			double angle = sign * 2 * Math.PI * i / n;
			Complex t = f1[i].multiply(new Complex(Math.cos(angle), Math.sin(angle)));
			// Butterfly: zeta^(i + n/2) = -zeta^i, so one product yields both outputs.
			ret[i] = f0[i].add(t);
			ret[i + half] = new Complex(f0[i].real - t.real, f0[i].img - t.img);
		}

		return ret;
	}

	public static void main(String[] args) {
		new Main().solve();
	}

}

class Complex {
	double real, img;

	public Complex(double real, double img) {
		this.real = real;
		this.img = img;
	}

	public Complex add(Complex c) {
		return new Complex(this.real + c.real, this.img + c.img);
	}

	public Complex multiply(double d) {
		return new Complex(this.real * d, this.img * d);
	}

	public Complex multiply(Complex c) {
		double real = this.real * c.real - this.img * c.img;
		double img = this.real * c.img + this.img * c.real;
		return new Complex(real, img);
	}

	public Complex divide(Complex c) {
		double real = this.real * c.real + this.img * c.img;
		double img = -this.real * c.img + this.img * c.real;
		Complex divide = new Complex(real, img);
		return divide.multiply(1.0 / (c.real * c.real + c.img * c.img));
	}

	public String toString() {
		return this.real + " + " + this.img + "i";
	}
}