AOJ 2829 Room Assignment
解法
JAG の wiki に分かりやすい解説があります。
2017/Practice/模擬国内予選/講評 - ACM-ICPC Japanese Alumni Group
場合分けとしては以下の 4 つです（最初の 2 つはどちらも答えが 0 になる場合です）
- 長さ 3 以上の閉路が存在する場合は 0
- 閉路を 1 点に潰したとき、葉の数が 3 以上の木が 1 つでもあれば 0
- 要素数 3 以上の弱連結成分が存在しない場合
- それ以外
コード
import java.util.Scanner

import scala.collection.mutable.ArrayBuffer

/** Solver for AOJ 2829 "Room Assignment".
  *
  * Reads repeated datasets from stdin, each an integer n followed by n integers a_1..a_n
  * (1-indexed vertex ids). A dataset with n == 0 ends the input. For every dataset the answer
  * modulo 1e9+7 is printed on its own line.
  */
object Main {
  private val Mod = (1e9 + 7).toInt

  // Largest index available in the precomputed tables below.
  private val Limit = 200000

  // Fact(i) = i! mod Mod for 0 <= i <= Limit.
  private val Fact: Array[Long] = {
    val fact = new Array[Long](Limit + 1)
    fact(0) = 1L
    for (i <- 1 to Limit) fact(i) = fact(i - 1) * i % Mod
    fact
  }

  // Pow(i) = 2^i mod Mod for 0 <= i <= Limit.
  private val Pow: Array[Int] = {
    val pow = new Array[Int](Limit + 1)
    pow(0) = 1
    // 2 * (Mod - 1) = 2000000012 still fits in an Int, so Int arithmetic is safe here.
    for (i <- 1 to Limit) pow(i) = pow(i - 1) * 2 % Mod
    pow
  }

  def main(args: Array[String]): Unit = {
    val in = new Scanner(System.in)
    var n = in.nextInt()
    while (n != 0) {
      // Convert the 1-indexed input to 0-indexed vertex ids.
      val parent = Array.fill(n)(in.nextInt() - 1)
      println(solve(parent))
      n = in.nextInt()
    }
  }

  /** Returns true when every strongly connected component of `graph` has at most 2 vertices.
    *
    * `solve` builds `graph` so that every vertex has exactly one incoming edge; in such a graph
    * the non-trivial SCCs are exactly its cycles, so this means "no cycle of length >= 3".
    *
    * @param graph adjacency lists; `graph(v)` holds the targets of v's outgoing edges
    */
  def cycleCheck(graph: IndexedSeq[scala.collection.Seq[Int]]): Boolean = {
    val cmp = StronglyConnectedComponents.decompose(graph)
    // Component ids are assigned as 0, 1, 2, ... so they are all below cmp.length.
    val counts = new Array[Int](cmp.length)
    cmp.foreach(c => counts(c) += 1)
    counts.forall(_ <= 2)
  }

  /** Returns true when every weakly connected component of `graph` contains at most 2 vertices
    * without outgoing edges (leaves).
    *
    * With the graph built by `solve` (edge parent(i) -> i), this is the
    * "at most two leaves once the cycle is contracted" condition.
    *
    * @param graph adjacency lists; `graph(v)` holds the targets of v's outgoing edges
    */
  def pathCheck(graph: IndexedSeq[scala.collection.Seq[Int]]): Boolean = {
    val uf = new UnionFind(graph.size)
    for (i <- graph.indices; j <- graph(i)) uf.unite(i, j)
    // leaves(r) = number of leaves in the component whose representative is r.
    val leaves = new Array[Int](graph.size)
    for (i <- graph.indices if graph(i).isEmpty) leaves(uf.find(i)) += 1
    leaves.forall(_ <= 2)
  }

  /** Counts the weakly connected components of the graph with undirected edges i -- array(i).
    *
    * @param array 0-indexed neighbour of each vertex
    * @return (number of components with exactly 2 vertices, number of all other components).
    *         The second count is meant for components of size >= 3; a size-1 component
    *         (only possible via a self-loop array(i) == i) also lands there.
    */
  def getSize(array: Array[Int]): (Int, Int) = {
    val uf = new UnionFind(array.length)
    for (i <- array.indices) uf.unite(i, array(i))
    // One entry per component: the size of the set behind each distinct representative.
    val sizes = array.indices.map(uf.find).distinct.map(uf.partialSizeOf)
    val two = sizes.count(_ == 2)
    (two, sizes.length - two)
  }

  /** Computes the answer for one dataset, modulo `Mod`.
    *
    * @param parent 0-indexed targets; vertex i is attached to vertex parent(i)
    * @return 0 when `cycleCheck` or `pathCheck` fails. Otherwise, a closed form in `two` and
    *         `three` (the component counts from `getSize`). The formulas follow the JAG
    *         editorial and are not re-derived here.
    */
  def solve(parent: Array[Int]): Int = {
    val n = parent.length
    val combination = new Combination(n + 1, Mod)
    // Edge parent(i) -> i, so every vertex has exactly one incoming edge.
    val graph = IndexedSeq.fill(n)(new ArrayBuffer[Int]())
    parent.indices.foreach(i => graph(parent(i)).append(i))

    if (cycleCheck(graph) && pathCheck(graph)) {
      val (two, three) = getSize(parent)
      if (three == 0) {
        // No component with 3 or more vertices.
        var cur: Long = ((two + 2) / 2) % Mod
        cur = cur * Pow(two) % Mod
        cur = cur * Fact(two) % Mod
        cur.toInt
      } else {
        var ans = 0L
        for (i <- 0 to math.min(two + three - 2, two)) {
          var cur = combination.get(two + three - i - 2, two - i).toLong
          cur = cur * Fact(two) % Mod
          cur = cur * Fact(three) % Mod
          cur = cur * Pow(two + three) % Mod
          cur = cur * (two + three - i - 1) % Mod
          cur = cur * ((i + 2) / 2) % Mod
          ans = (ans + cur) % Mod
        }
        for (i <- 0 to math.min(two + three - 1, two)) {
          var cur = combination.get(two + three - i - 1, two - i).toLong
          cur = cur * Fact(two) % Mod
          cur = cur * Fact(three) % Mod
          cur = cur * Pow(two + three) % Mod
          cur = cur * ((i + 2) / 2) % Mod
          ans = (ans + cur * 2) % Mod
        }
        ans.toInt
      }
    } else {
      0
    }
  }

  /** Disjoint-set union over the elements 0 until n, with union by size and path compression. */
  class UnionFind(n: Int) {
    private val parent = (0 until n).toArray
    // sizes(r) is the set size when r is a representative, 0 after r has been merged away.
    private val sizes = Array.fill[Int](n)(1)
    private var _size = n

    /** Representative of x's set; compresses the path from x as a side effect. */
    def find(x: Int): Int = {
      if (x != parent(x)) parent(x) = find(parent(x))
      parent(x)
    }

    /** Merges the sets of a and b; returns false if they were already the same set. */
    def unite(a: Int, b: Int): Boolean = {
      val fa = find(a)
      val fb = find(b)
      if (fa == fb) {
        false
      } else {
        // Attach the smaller tree under the larger one.
        val (x, y) = if (sizes(fa) >= sizes(fb)) (fa, fb) else (fb, fa)
        parent(y) = x
        sizes(x) += sizes(y)
        sizes(y) = 0
        _size -= 1
        true
      }
    }

    def isSame(x: Int, y: Int): Boolean = find(x) == find(y)

    /** Number of elements in x's set. */
    def partialSizeOf(x: Int): Int = sizes(find(x))

    /** Current number of disjoint sets. */
    def size(): Int = _size
  }

  /** Binomial coefficients nCm modulo `mod` for 0 <= n <= max.
    *
    * Modular inverses come from the recurrence inv(i) = -(mod / i) * inv(mod % i). This
    * assumes `mod` is prime and larger than `max`.
    */
  class Combination(max: Int, mod: Int) {
    private val inv = new Array[Long](max + 1)
    private val fact = new Array[Long](max + 1)
    private val invFact = new Array[Long](max + 1)
    if (max >= 1) inv(1) = 1
    for (i <- 2 to max) inv(i) = inv(mod % i) * (mod - mod / i) % mod
    fact(0) = 1
    invFact(0) = 1
    for (i <- 1 to max) fact(i) = fact(i - 1) * i % mod
    for (i <- 1 to max) invFact(i) = invFact(i - 1) * inv(i) % mod

    /** nCm modulo `mod`; 0 when m < 0 or m > n. */
    def get(n: Int, m: Int): Int =
      if (m < 0 || n < m) 0
      else (fact(n) * invFact(m) % mod * invFact(n - m) % mod).toInt
  }

  /** Strongly connected components via Kosaraju's algorithm, with both DFS passes iterative. */
  object StronglyConnectedComponents {

    /** Labels every vertex of `graph` with a component id.
      *
      * @param graph adjacency lists; `graph(v)` holds the targets of v's outgoing edges
      * @return array `cmp` where `cmp(v)` is v's component id. Ids are 0, 1, 2, ... in the order
      *         the components are discovered by the second pass.
      */
    def decompose(graph: IndexedSeq[scala.collection.Seq[Int]]): Array[Int] = {
      val vertexCount = graph.size
      val cmp = new Array[Int](vertexCount)

      // Reverse graph: rg(to) lists every `from` with an edge from -> to.
      val rg = IndexedSeq.fill(vertexCount)(new ArrayBuffer[Int]())
      for (from <- graph.indices; to <- graph(from)) rg(to).append(from)

      // Pass 1: iterative DFS on `graph`, recording vertices in post-order.
      val visited = new Array[Boolean](vertexCount)
      val finished = new Array[Boolean](vertexCount)
      val order = new ArrayBuffer[Int]()
      for (i <- 0 until vertexCount) {
        if (!visited(i)) {
          var stack = List(i)
          while (stack.nonEmpty) {
            val v = stack.head
            visited(v) = true
            var pushed = false
            // Reversed so that the first neighbour ends up on top of the stack.
            for (u <- graph(v).reverse) {
              if (!visited(u)) {
                stack = u :: stack
                pushed = true
              }
            }
            if (!pushed) {
              stack = stack.tail
              // A vertex can sit on the stack more than once; record only its first finish.
              if (!finished(v)) {
                order.append(v)
                finished(v) = true
              }
            }
          }
        }
      }

      // Pass 2: DFS on the reverse graph in reverse post-order; each tree is one component.
      val assigned = new Array[Boolean](vertexCount)
      var k = 0
      for (i <- order.reverse) {
        if (!assigned(i)) {
          var stack = List(i)
          assigned(i) = true
          cmp(i) = k
          while (stack.nonEmpty) {
            val v = stack.head
            var pushed = false
            for (u <- rg(v)) {
              if (!assigned(u)) {
                assigned(u) = true
                cmp(u) = k
                stack = u :: stack
                pushed = true
              }
            }
            if (!pushed) stack = stack.tail
          }
          k += 1
        }
      }
      cmp
    }
  }
}