Codeforces Round #311 Div2 C: Arthur and Table

問題

https://codeforces.com/contest/557/problem/C

解法

残した脚の中で最大となる長さLを総当たりする。長さLの脚がM本あるとき、Lより長い脚はすべて切る。テーブルが安定するには最大長の脚が残った脚の過半数を占める必要があるため、Lより短い脚は残りがM-1本以下になるように、切るのに必要なエネルギーが小さいものから順に切る。各Lについて求めたコストの最小値が答えとなる。

コード

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Scanner;

/**
 * Solver for Codeforces Round #311 Div2 C "Arthur and Table".
 */
public class Main {

	/**
	 * Reads n, then n leg lengths, then n removal energies from standard input and prints
	 * the minimum total energy needed to make the table stable.
	 */
	public void solve() {
		// Closing the Scanner also closes System.in, as the original code did.
		try (Scanner scanner = new Scanner(System.in)) {
			int n = scanner.nextInt();
			int[] legLengths = new int[n];
			for (int i = 0; i < n; i++) {
				legLengths[i] = scanner.nextInt();
			}
			int[] removeEnergy = new int[n];
			for (int i = 0; i < n; i++) {
				removeEnergy[i] = scanner.nextInt();
			}
			System.out.println(minEnergy(legLengths, removeEnergy));
		}
	}

	/**
	 * Returns the minimum total energy needed to remove legs so that the table becomes stable,
	 * i.e. strictly more than half of the remaining legs have the maximum remaining length.
	 *
	 * <p>Every distinct length h is tried as the final maximum. With m legs of length h, all
	 * longer legs are removed, and at most m - 1 shorter legs may stay, so the cheapest
	 * (shorter - (m - 1)) shorter legs are removed. The answer is the minimum over all h.
	 *
	 * @param legLengths   length of each leg
	 * @param removeEnergy energy needed to remove each leg; removeEnergy[i] belongs to legLengths[i]
	 * @return the minimum total removal energy, or 0 when there are no legs
	 * @throws IllegalArgumentException if the two arrays differ in length
	 */
	static long minEnergy(int[] legLengths, int[] removeEnergy) {
		if (legLengths.length != removeEnergy.length) {
			throw new IllegalArgumentException("length mismatch: " + legLengths.length
					+ " legs, " + removeEnergy.length + " energies");
		}
		int n = legLengths.length;
		if (n == 0) {
			return 0;
		}

		// Removal energies grouped by leg length, and how many legs of each energy are still
		// shorter than the current candidate length (initially: every leg).
		HashMap<Integer, ArrayList<Integer>> energiesByLength = new HashMap<>();
		HashMap<Integer, Integer> shorterCountByEnergy = new HashMap<>();
		for (int i = 0; i < n; i++) {
			ArrayList<Integer> group = energiesByLength.get(legLengths[i]);
			if (group == null) {
				group = new ArrayList<>();
				energiesByLength.put(legLengths[i], group);
			}
			group.add(removeEnergy[i]);
			Integer count = shorterCountByEnergy.get(removeEnergy[i]);
			shorterCountByEnergy.put(removeEnergy[i], count == null ? 1 : count + 1);
		}

		Integer[] lengths = energiesByLength.keySet().toArray(new Integer[0]);
		Arrays.sort(lengths);
		Integer[] energies = shorterCountByEnergy.keySet().toArray(new Integer[0]);
		Arrays.sort(energies);

		long best = Long.MAX_VALUE;
		long longerCost = 0; // total energy of all legs longer than the current candidate
		int shorterLegs = n;
		// Walk candidates from longest to shortest so longer legs accumulate in longerCost.
		for (int i = lengths.length - 1; i >= 0; i--) {
			ArrayList<Integer> group = energiesByLength.get(lengths[i]);
			long groupCost = 0;
			for (int energy : group) {
				groupCost += energy;
				// This leg is not shorter than this or any later (smaller) candidate.
				shorterCountByEnergy.put(energy, shorterCountByEnergy.get(energy) - 1);
			}
			shorterLegs -= group.size();

			// Stability allows at most group.size() - 1 shorter legs to remain;
			// remove the surplus, cheapest first.
			int toRemove = Math.max(0, shorterLegs - (group.size() - 1));
			long shorterCost = 0;
			for (int j = 0; j < energies.length && toRemove > 0; j++) {
				int energy = energies[j];
				int take = Math.min(toRemove, shorterCountByEnergy.get(energy));
				toRemove -= take;
				// Widen before multiplying so large energies cannot overflow int.
				shorterCost += (long) take * energy;
			}

			best = Math.min(best, longerCost + shorterCost);
			longerCost += groupCost;
		}
		return best;
	}

	public static void main(String[] args) {
		new Main().solve();
	}
}