リクルートコミュニケーションズ (RCO) におけるプログラミングコンテストの活用について
この記事は Recruit Engineers Advent Calendar 2016 の3日目の記事です。
RCO プロコン部
RCO アドテク部には、プロコン部、Kaggle 部、SET (Sushi is Everything) などのサークルがあります。プロコン部は気ままにプログラミングコンテスト(プロコン)に参加するサークル、Kaggle 部は気ままに機械学習コンペの Kaggle に参加するサークルで、SET は寿司を食べているようです。
そもそもプロコンとは
「プログラミングのコンテスト」というと範囲が広くなりますが、ここでは、「与えられたプログラミングの問題を、制限時間内に早く正確に解くコンテスト」のことを指します。2時間程度で4問ほど出題されることが多いです。様々なコンテストサービスがあり、以下のサービスを利用することが多いです。
どんな人が参加しているの?
プロコン部と名乗って入るものの、何か決まったメンバーのサークルがあるわけではなく、コンテストになると集まってくる感じです。
よくいるメンバーとしては uwi さん、shiratty8 さん、KenjiH さん、dpforest さん、iehn さん、kenkoooo などがいますが、それ以外の人たちもたまに来たりします。機械学習エンジニアよりもアプリケーション開発エンジニアの方がよく参加している印象です。
日々の活動
プロコン部の日々の活動としては以下のようなものが挙げられます。
- 社内チャットの「プロコン部ルーム」に溜まる
- コンテストに参加する
- 社内 wiki に自分が解いた問題の解説記事を投稿する
- アルゴリズムイントロダクションを輪読する
これらの活動について書いていきます。
社内チャットの「プロコン部ルーム」に溜まる
参加したコンテストの問題の感想や、好きなデータ構造について話したりします。分からない問題などについても聞けたりします。
コンテストに参加する
業務時間中にコンテストがある場合、勉強会としてみんなで集まってコンテストに参加します。TopCoder SRM や HackerRank HourRank などは1〜2時間程度で終わるので、その後1時間ほど問題について議論する時間を設けます。その日に競技プログラミングを始めた人から世界トップクラスの人まで様々なレベルの人が集まるので、解法を議論したり、時間内に通らなかったコードをレビューしたり、別解を検討したりします。
プロコン部ご飯 pic.twitter.com/Ag7azgiAz9
— 宇宙ツイッタラーX (@kenkoooo) October 13, 2016
3時間の長丁場になるので、会社から弁当が支給されます。予算は1人1500円までで、ネット注文しておくこともあれば、八重洲が近いので東京駅まで駅弁を買い出しに行くこともあります。
このように、弁当で体力をつなぎつつ、コンテストを題材にした内容の濃い勉強会を行っています。
社内 wiki に自分が解いた問題の解説記事を投稿する
自分が解いた問題についてチャットで感想を言うこともありますが、内容を残しておきたい問題などについては社内の wiki に投稿することもあります。
アルゴリズムイントロダクションを輪読する
アルゴリズムイントロダクションを読む会が行われています*2。この本は、アルゴリズムの正当性や計算量の上界の証明が丁寧に書かれている教科書で、「プログラミングコンテストチャレンジブック(蟻本)」などと違ってプロコンに直接役立つ本ではありませんが、計算量の議論などはプロコン慣れしている方がやはり分かりやすく、参加者はプロコン勢がほとんどです。
進め方としては、毎週1人担当者を決め、発表者が好きな章を選んで発表しています。最近は、重たい章だけが残ってきたので、何回かに分けて発表します。
※僕が最大フローの章を担当した時の資料
業務の役に立つのか
このように RCO ではプログラミングコンテストを仕事に取り入れていますが、業務では役に立つのでしょうか。
実装が速くなる
個人的な感覚ですが、競プロ勢は実際の業務での実装スピードがかなり速いように感じます。もちろん設計などはまた別の問題で、素早く実装したコードが必ずしも優れた実装であるとは限りませんが、少なくとも自分の知る範囲では、他の人が5営業日かかる実装をプロコン勢は1日や2日で終わらせてしまうような気がします。
これにはいくつか要因が考えられますが、以下のような感じでしょうか。
- プロコンで何度もバグを埋め込んだおかげで、自分がバグを埋めやすい箇所などを把握しているため、バグに振り回される時間が短い。
- プロコンで何度も調べたおかげで、標準ライブラリや言語仕様の知識が蓄積していて、調べる回数が少なくて済む。
- プロコンでできるだけ共通化して記述量を減らすコツが蓄積していて、記述量が少なくて済む。
- プロコン用に設定やスニペットをゴリゴリに積んだ IDE を使っているのでスピードが出る。*3
プロコン特有のものではなく、単に四六時中プログラミングしていると身につくようなことばかりですね。ただ、プロコンでは特にまっさらな状態から実装するため、できるだけ速く少なく正確に実装する訓練が詰めるのかもしれません。
より強くなってくると、実装を始める前に頭のなかでコーディングを終えてしまうらしいので、さらに速くなるのかもしれません。早くその境地に達したいですね。
計算量の大まかな感覚がつかめる
プロコンで出題される問題の多くは「データのサイズが小さければ簡単に解ける問題」であることが多いです。
例えば、次のスライドに出てくる問題を見てみましょう。
スライドの中で、この問題は N=10 程度まであれば簡単に解けるが、N=1000などは発想を変えなければならないという話をしています。
実際の業務の中で、例えば最大フローや動的計画法などのアルゴリズムを駆使して計算を高速化する場面は多くはありませんが、計算量の感覚というのは重要だと思います。特にRCOアドテク部ではリアルタイム処理を行うことがあり、そういった分野では計算量の感覚は不可欠です。
コンピュータがざっくり1秒間に10億回くらい計算できると考えて、サービスに対して秒間10万のクエリが来るとき、1クエリあたり1万回くらいの計算ならできそうということが分かります。3億件のユーザーデータを抱えているので、ユーザー数を N とするとユーザー検索は O(log N) で終わらせる必要がありますね。
この「コンピュータが1秒間に計算できる回数」、「各操作にかかる計算回数」の感覚は非常に重要で、これがチーム内で共有されていることで「この関数は毎秒10万回呼ばれるから中の操作はこのくらいにしないと」「この関数は毎秒1000回しか呼ばれないから、計算量は少し悪くても可読性を優先しよう」などを考えることができます。
おわりに
プロコンは単純にゲームとして非常に面白いと思いますが、同時にプログラマの実装力を高めるかもしれません。色んな会社にプロコンを業務として取り入れる動きが広まってほしいですね。
いつの日か、ACM-ICPC のように会社対抗プログラミングコンテストが開催されるのを楽しみにしています。
C++ (gcc) で 128 ビット整数を使う
128 ビット整数を使いたいけど boost が入ってないので多倍長整数が使いづらいという環境で __int128 なるものの存在を知ったのでメモ。
自前でパースとか出力とかを用意するのでやや面倒だが、多倍長整数ライブラリを持っていなくても安心。
#include <bits/stdc++.h> using namespace std; std::ostream &operator<<(std::ostream &dest, __int128_t value) { std::ostream::sentry s(dest); if (s) { __uint128_t tmp = value < 0 ? -value : value; char buffer[128]; char *d = std::end(buffer); do { --d; *d = "0123456789"[tmp % 10]; tmp /= 10; } while (tmp != 0); if (value < 0) { --d; *d = '-'; } int len = std::end(buffer) - d; if (dest.rdbuf()->sputn(d, len) != len) { dest.setstate(std::ios_base::badbit); } } return dest; } __int128 parse(string &s) { __int128 ret = 0; for (int i = 0; i < s.length(); i++) if ('0' <= s[i] && s[i] <= '9') ret = 10 * ret + s[i] - '0'; return ret; } int main() { string s = "187821878218782187821878218782"; __int128 x = parse(s); x *= 2; cout << x << endl; }
Codeforces Round #381 (Div. 2) D. Alyona and a tree
問題
http://codeforces.com/contest/740/problem/D
ツリーがあり、頂点1が根です。各頂点 u には数字が書き込まれています。各頂点 v について、v の部分木に含まれ、かつ を満たす頂点uの数を出力してください。
解法
DFS で辺の重みを BIT に持ちながら潜っていき、各頂点 u についてどのくらい上の頂点まで条件を満たせるか計算します。後はその数を持ちながら登っていき、いもす法の要領で無効になった瞬間に引くようにすれば良いです。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.*; /* _ooOoo_ o8888888o 88" . "88 (| -_- |) O\ = /O ____/`---'\____ .' \\| |// `. / \\||| : |||// \ / _||||| -:- |||||- \ | | \\\ - /// | | | \_| ''\---/'' | | \ .-\__ `-` ___/-. / ___`. .' /--.--\ `. . __ ."" '< `.___\_<|>_/___.' >'"". | | : `- \`.;`\ _ /`;.`/ - ` : | | \ \ `-. \_ __\ /__ _/ .-` / / ======`-.____`-.___\_____/___.-`____.-'====== `=---=' ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ pass System Test! */ public class D { private static class Task { class Edge { int to; long cost; Edge(int to, long cost) { this.to = to; this.cost = cost; } } ArrayList<Edge>[] graph; long[] a; FenwickTree bit; long[] ans; long[] imos; int dfs(int v, int d) { int x = 0; for (Edge edge : graph[v]) { int next = d + 1; bit.set(d, edge.cost); int ok = next; int ng = -1; while (ok - ng > 1) { int m = (ok + ng) / 2; long sum = bit.sum(m, next);//[m, next] if (sum <= a[edge.to]) { ok = m; } else { ng = m; } } int n = next - ok; int target = next - n; if (target <= next) imos[target]++; int y = dfs(edge.to, next); y -= imos[next]; x += y; imos[next] = 0; } ans[v] = x; return x + 1; } void solve(FastScanner in, PrintWriter out) throws Exception { int N = in.nextInt(); a = new long[N]; for (int i = 0; i < N; i++) { a[i] = in.nextInt(); } graph = new ArrayList[N]; for (int i = 0; i < N; i++) { graph[i] = new ArrayList<>(); } for (int i = 0; i < N - 1; i++) { int p = in.nextInt() - 1; int to = i + 1; long w = in.nextInt(); graph[p].add(new Edge(to, w)); } bit = new FenwickTree(N + 10); ans = new long[N]; imos = new long[N + 10]; dfs(0, 0); for (long a : ans) out.print(a + " "); out.println(); } class FenwickTree { int N; long[] data; FenwickTree(int N) { this.N = N + 1; data = new long[N + 1]; } void add(int k, long val) { for (int x = k; x < N; x |= x + 1) { data[x] += val; } } void set(int k, long val) { long now = get(k); long dv = val - now; add(k, dv); } // [0, k) long sum(int k) { if (k >= N) k = N - 1; long ret = 0; for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) { ret += data[x]; } return ret; } // [l, r) long sum(int l, int r) { if (l >= r) return 0; return sum(r) - sum(l); } long get(int k) { assert (0 <= k && k < N); return sum(k + 1) - sum(k); } int getAsSetOf(int w) { w++; if (w <= 0) return -1; int x = 0; int k = 1; while (k * 2 <= N) k *= 2; for (; k > 0; k /= 2) { if (x + k <= N && data[x + k - 1] < w) { w -= data[x + k - 1]; x += k; } } return x; } } } /** * ここから下はテンプレートです。 */ public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) array[i] = nextInt(); return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) array[i] = nextLong(); return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) array[i] = next(); return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) array[i] = next().toCharArray(); return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }
Codeforces Round #381 (Div. 2) E. Alyona and towers
問題
http://codeforces.com/contest/740/problem/E
N 個の非負の数列{a}があります。クエリが M 個あり、各クエリでは、l 番目から r 番目までの数字にそれぞれ v を足し、数列に含まれる最大の hill の幅を答えてください。
ただし、hill とはを満たす部分列です。
解法
hill の左側、すなわち単調増加する部分列の始点と終点をmapで管理し、各クエリごとに、部分列を分割したり連結したりします。右側についても同様です。
hill が組み変わるごとにセグ木に保存しておき、最大値を出力します。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.*; /* _ooOoo_ o8888888o 88" . "88 (| -_- |) O\ = /O ____/`---'\____ .' \\| |// `. / \\||| : |||// \ / _||||| -:- |||||- \ | | \\\ - /// | | | \_| ''\---/'' | | \ .-\__ `-` ___/-. / ___`. .' /--.--\ `. . __ ."" '< `.___\_<|>_/___.' >'"". | | : `- \`.;`\ _ /`;.`/ - ` : | | \ \ `-. \_ __\ /__ _/ .-` / / ======`-.____`-.___\_____/___.-`____.-'====== `=---=' ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ pass System Test! */ public class E { private static class Task { void solve(FastScanner in, PrintWriter out) throws Exception { int N = in.nextInt(); long[] a = in.nextLongArray(N); TreeMap<Integer, Integer> left = new TreeMap<>(); int from = 0; for (int i = 0; i < N; i++) { if (i == N - 1 || a[i] >= a[i + 1]) { left.put(from, i); from = i + 1; } } TreeMap<Integer, Integer> right = new TreeMap<>(); from = N - 1; for (int i = N - 2; i >= -1; i--) { if (i < 0 || a[i] <= a[i + 1]) { right.put(i + 1, from); from = i; } } RMQ rmq = new RMQ(N); for (Map.Entry<Integer, Integer> entry : left.entrySet()) { from = entry.getKey(); int to = entry.getValue(); if (!right.containsKey(to)) continue; int to2 = right.get(to); int width = to2 - from + 1; rmq.update(to, width); } long[] d = new long[N - 1]; for (int i = 0; i < N - 1; i++) { d[i] = a[i + 1] - a[i]; } int M = in.nextInt(); TreeSet<Integer> tmp = new TreeSet<>(); for (int i = 0; i < M; i++) { int l = in.nextInt() - 1; int r = in.nextInt() - 1; long v = in.nextInt(); if (l > 0) { long prev = d[l - 1]; d[l - 1] += v; if (prev <= 0 && d[l - 1] > 0) { // left merge from = left.floorKey(l - 1); int to = left.get(l); left.remove(l); rmq.update(l, 0); left.put(from, to); tmp.add(from); } if (prev < 0 && d[l - 1] >= 0) { // right cut Map.Entry<Integer, Integer> entry = right.floorEntry(l - 1); from = entry.getKey(); int to = entry.getValue(); right.put(from, l - 1); right.put(l, to); rmq.update(from, 0); tmp.add(left.floorKey(from)); tmp.add(left.floorKey(l)); } } if (r < N - 1) { long prev = d[r]; d[r] -= v; if (prev > 0 && d[r] <= 0) { // left cut Map.Entry<Integer, Integer> entry = left.floorEntry(r); from = entry.getKey(); int to = entry.getValue(); left.put(from, r); left.put(r + 1, to); rmq.update(to, 0); tmp.add(from); tmp.add(r + 1); } if (prev >= 0 && d[r] < 0) { // right merge from = right.floorKey(r); int to = right.get(r + 1); right.remove(r + 1); right.put(from, to); rmq.update(from, 0); tmp.add(left.floorKey(from)); } } for (int f : tmp) { int k = left.get(f); Integer to = right.get(k); if (to==null) continue; rmq.update(k, to - f + 1); } tmp.clear(); out.println(rmq.query(0, N + 1)); } } class RMQ { private int N; private long[] seg; RMQ(int M) { N = Integer.highestOneBit(M) * 2; seg = new long[N * 2]; } public void update(int k, long value) { seg[k += N - 1] = value; while (k > 0) { k = (k - 1) / 2; seg[k] = Math.max(seg[k * 2 + 1], seg[k * 2 + 2]); } } //[a, b) long query(int a, int b) { return query(a, b, 0, 0, N); } long query(int a, int b, int k, int l, int r) { if (r <= a || b <= l) return 0; if (a <= l && r <= b) return seg[k]; long x = query(a, b, k * 2 + 1, l, (l + r) / 2); long y = query(a, b, k * 2 + 2, (l + r) / 2, r); return Math.max(x, y); } } } /** * ここから下はテンプレートです。 */ public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) array[i] = nextInt(); return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) array[i] = nextLong(); return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) array[i] = next(); return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) array[i] = next().toCharArray(); return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }
AtCoder Regular Contest 033 C - データ構造
解法
Treap を実装した。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.NoSuchElementException; import java.util.Random; /* _ooOoo_ o8888888o 88" . "88 (| -_- |) O\ = /O ____/`---'\____ .' \\| |// `. / \\||| : |||// \ / _||||| -:- |||||- \ | | \\\ - /// | | | \_| ''\---/'' | | \ .-\__ `-` ___/-. / ___`. .' /--.--\ `. . __ ."" '< `.___\_<|>_/___.' >'"". | | : `- \`.;`\ _ /`;.`/ - ` : | | \ \ `-. \_ __\ /__ _/ .-` / / ======`-.____`-.___\_____/___.-`____.-'====== `=---=' ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ pass System Test! */ public class Main { private static class Task { void solve(FastScanner in, PrintWriter out) throws Exception { Treap treap = new Treap(); int Q = in.nextInt(); for (int i = 0; i < Q; i++) { int t = in.nextInt(); int x = in.nextInt(); if (t == 1) { treap.insert(x); } else { long key = treap.rank(x - 1); out.println(key); treap.erase(key); } } } class Treap { Random random = new Random(); class Node { Node left, right; long key; int priority; int count; Node(long key) { this.key = key; priority = random.nextInt(); left = null; right = null; count = 1; } } Node root = null; int count(Node n) { return n == null ? 0 : n.count; } void update(Node c) { c.count = 1 + count(c.left) + count(c.right); } Node leftRotate(Node c) { Node r = c.right; c.right = r.left; r.left = c; update(c); return r; } Node rightRotate(Node c) { Node l = c.left; c.left = l.right; l.right = c; update(c); return l; } Node insert(Node c, long key) { if (c == null) return new Node(key); if (c.key < key) { c.right = insert(c.right, key); if (c.right.priority < c.priority) c = leftRotate(c); } else { c.left = insert(c.left, key); if (c.left.priority < c.priority) c = rightRotate(c); } update(c); return c; } Node getMinNode(Node c) { while (c.left != null) c = c.left; return c; } Node erase(Node c, long key) { if (key == c.key) { if (c.left == null) return c.right; if (c.right == null) return c.left; Node minNode = getMinNode(c.right); c.key = minNode.key; c.right = erase(c.right, minNode.key); } else { if (c.key < key) c.right = erase(c.right, key); else c.left = erase(c.left, key); } update(c); return c; } void insert(long key) { if (contains(key)) return; root = insert(root, key); } void erase(long key) { root = erase(root, key); } int size() { return count(root); } boolean contains(long key) { return find(root, key) >= 0; } int find(long key) { return find(root, key); } int find(Node c, long key) { if (c == null) return -1; if (c.key == key) return count(c.left); if (key < c.key) return find(c.left, key); int pos = find(c.right, key); if (pos < 0) return pos; return count(c.left) + 1 + pos; } Node rank(Node c, int rank) { while (c != null) { int leftCount = count(c.left); if (leftCount == rank) return c; if (leftCount < rank) { rank -= leftCount + 1; c = c.right; } else { c = c.left; } } return c; } long rank(int rank) { return rank(root, rank).key; } } } /** * ここから下はテンプレートです。 */ public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) array[i] = nextInt(); return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) array[i] = nextLong(); return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) array[i] = next(); return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) array[i] = next().toCharArray(); return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }
RUPC 2016 Day2 L: String in String
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.*; /* _ooOoo_ o8888888o 88" . "88 (| -_- |) O\ = /O ____/`---'\____ .' \\| |// `. / \\||| : |||// \ / _||||| -:- |||||- \ | | \\\ - /// | | | \_| ''\---/'' | | \ .-\__ `-` ___/-. / ___`. .' /--.--\ `. . __ ."" '< `.___\_<|>_/___.' >'"". | | : `- \`.;`\ _ /`;.`/ - ` : | | \ \ `-. \_ __\ /__ _/ .-` / / ======`-.____`-.___\_____/___.-`____.-'====== `=---=' ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ pass System Test! */ public class Main { private static class Task { void solve(FastScanner in, PrintWriter out) throws Exception { String S = in.next(); int N = S.length(); SuffixArray sa = new SuffixArray(S); int Q = in.nextInt(); int[][] query = new int[Q][2]; ArrayList<int[]> queue = new ArrayList<>(); for (int q = 0; q < Q; q++) { int l = in.nextInt(); int r = in.nextInt(); String M = in.next(); r -= M.length() - 1; if (l > r) continue; int low = sa.lowerBound(M); int up = sa.upperBound(M); if (low == -1) continue; query[q][0] = low; query[q][1] = up; queue.add(new int[]{l, 0, q}); queue.add(new int[]{r, 2, q}); } for (int i = 0; i <= N; i++) { queue.add(new int[]{sa.sa[i], 1, i}); } Collections.sort(queue, new Comparator<int[]>() { @Override public int compare(int[] o1, int[] o2) { if (o1[0] != o2[0]) return Integer.compare(o1[0], o2[0]); if (o1[1] != o2[1]) return Integer.compare(o1[1], o2[1]); return Integer.compare(o1[2], o2[2]); } }); int[] ans = new int[Q]; FenwickTree bit = new FenwickTree(N + 1); for (int[] event : queue) { int kind = event[1]; if (kind == 1) { int saPos = event[2]; bit.add(saPos, 1); } else if (kind == 0) { int q = event[2]; ans[q] -= bit.sum(query[q][0], query[q][1]); } else { int q = event[2]; ans[q] += bit.sum(query[q][0], query[q][1]); } } for (int a : ans) out.println(a); } class FenwickTree { int N; long[] data; FenwickTree(int N) { this.N = N + 1; data = new long[N + 1]; } void add(int k, long val) { for (int x = k; x < N; x |= x + 1) { data[x] += val; } } // [0, k) long sum(int k) { if (k >= N) k = N - 1; long ret = 0; for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) { ret += data[x]; } return ret; } // [l, r) long sum(int l, int r) { return sum(r) - sum(l); } long get(int k) { assert (0 <= k && k < N); return sum(k + 1) - sum(k); } int getAsSetOf(int w) { w++; if (w <= 0) return -1; int x = 0; int k = 1; while (k * 2 <= N) k *= 2; for (; k > 0; k /= 2) { if (x + k <= N && data[x + k - 1] < w) { w -= data[x + k - 1]; x += k; } } return x; } } class SuffixArray { String S; int N, K; Integer[] sa; int[] rank; public SuffixArray(String S) { this.S = S; build(); } private void build() { N = S.length(); rank = new int[N + 1]; sa = new Integer[N + 1]; for (int i = 0; i <= N; i++) { sa[i] = i; rank[i] = i < N ? S.charAt(i) : -1; } int[] tmp = new int[N + 1]; for (int _k = 1; _k <= N; _k *= 2) { final int k = _k; Arrays.sort(sa, new Comparator<Integer>() { @Override public int compare(Integer i, Integer j) { return compareNode(i, j, k); } }); tmp[sa[0]] = 0; for (int i = 1; i <= N; i++) { tmp[sa[i]] = tmp[sa[i - 1]] + ((compareNode(sa[i - 1], sa[i], k) < 0) ? 1 : 0); } for (int i = 0; i <= N; i++) { rank[i] = tmp[i]; } } } private int compareNode(int i, int j, int k) { if (rank[i] != rank[j]) { return rank[i] - rank[j]; } else { int ri = i + k <= N ? rank[i + k] : -1; int rj = j + k <= N ? rank[j + k] : -1; return ri - rj; } } public int lowerBound(String t) { int a = -1, b = S.length(); while (b - a > 1) { int c = (a + b) / 2; String sub = S.substring(sa[c], Math.min(t.length() + sa[c], S.length())); if (sub.compareTo(t) < 0) a = c; else b = c; } String sub = S.substring(sa[b], Math.min(t.length() + sa[b], S.length())); return sub.compareTo(t) == 0 ? b : -1; } public int upperBound(String t) { int a = -1, b = S.length() + 1; while (b - a > 1) { int c = (a + b) / 2; String sub = S.substring(sa[c], Math.min(t.length() + sa[c], S.length())); if (sub.compareTo(t) <= 0) a = c; else b = c; } return b; } } } /** * ここから下はテンプレートです。 */ public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) array[i] = nextInt(); return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) array[i] = nextLong(); return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) array[i] = next(); return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) array[i] = next().toCharArray(); return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }
AOJ 2644 Longest Match
解法
SuffixArray を作ると、Sに含まれる a の SuffixArray 上での lower_bound と upper_bound を求めることができます。その中で最も前のものを求めたいですが、これはRMQで持っておけば良いです。
b については上記の問題をすべて反転させた文字列で行えば同じ問題になります。
コード
import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintWriter; import java.util.*; /* _ooOoo_ o8888888o 88" . "88 (| -_- |) O\ = /O ____/`---'\____ .' \\| |// `. / \\||| : |||// \ / _||||| -:- |||||- \ | | \\\ - /// | | | \_| ''\---/'' | | \ .-\__ `-` ___/-. / ___`. .' /--.--\ `. . __ ."" '< `.___\_<|>_/___.' >'"". | | : `- \`.;`\ _ /`;.`/ - ` : | | \ \ `-. \_ __\ /__ _/ .-` / / ======`-.____`-.___\_____/___.-`____.-'====== `=---=' ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ pass System Test! */ public class Main { private static class Task { void solve(FastScanner in, PrintWriter out) throws Exception { String S = in.next(); SuffixArray sa = new SuffixArray(S); SuffixArray reverseSA = new SuffixArray(new StringBuilder(S).reverse().toString()); int N = S.length(); RMQ rmq = new RMQ(N + 1); RMQ reverseRMQ = new RMQ(N); for (int i = 0; i <= N; i++) { rmq.update(i, sa.sa[i]); reverseRMQ.update(i, reverseSA.sa[i]); } int Q = in.nextInt(); for (int q = 0; q < Q; q++) { String x = in.next(); String y = new StringBuilder(in.next()).reverse().toString(); int low = sa.lowerBound(x); if (low == -1) { out.println(0); continue; } int up = sa.upperBound(x); int reverseLow = reverseSA.lowerBound(y); if (reverseLow == -1) { out.println(0); continue; } int reverseUp = reverseSA.upperBound(y); long s = rmq.query(low, up); long t = N - reverseRMQ.query(reverseLow, reverseUp); if (s + x.length() <= t && s <= t - y.length()) { out.println(t - s); } else { out.println(0); } } } class RMQ { private long INF = (long) 1e18; private int N; private long[] seg; RMQ(long[] array) { N = Integer.highestOneBit(array.length) * 2; seg = new long[N * 2]; Arrays.fill(seg, INF); for (int i = 0; i < array.length; i++) update(i, array[i]); } RMQ(int M) { N = Integer.highestOneBit(M) * 2; seg = new long[N * 2]; Arrays.fill(seg, INF); } void update(int k, long value) { seg[k += N - 1] = value; while (k > 0) { k = (k - 1) / 2; seg[k] = Math.min(seg[k * 2 + 1], seg[k * 2 + 2]); } } //[a, b) long query(int a, int b) { return query(a, b, 0, 0, N); } long query(int a, int b, int k, int l, int r) { if (r <= a || b <= l) return INF; if (a <= l && r <= b) return seg[k]; long x = query(a, b, k * 2 + 1, l, (l + r) / 2); long y = query(a, b, k * 2 + 2, (l + r) / 2, r); return Math.min(x, y); } } class SuffixArray { String S; int N, K; Integer[] sa; int[] rank; public SuffixArray(String S) { this.S = S; build(); } private void build() { N = S.length(); rank = new int[N + 1]; sa = new Integer[N + 1]; for (int i = 0; i <= N; i++) { sa[i] = i; rank[i] = i < N ? S.charAt(i) : -1; } int[] tmp = new int[N + 1]; for (int _k = 1; _k <= N; _k *= 2) { final int k = _k; Arrays.sort(sa, new Comparator<Integer>() { @Override public int compare(Integer i, Integer j) { return compareNode(i, j, k); } }); tmp[sa[0]] = 0; for (int i = 1; i <= N; i++) { tmp[sa[i]] = tmp[sa[i - 1]] + ((compareNode(sa[i - 1], sa[i], k) < 0) ? 1 : 0); } for (int i = 0; i <= N; i++) { rank[i] = tmp[i]; } } } private int compareNode(int i, int j, int k) { if (rank[i] != rank[j]) { return rank[i] - rank[j]; } else { int ri = i + k <= N ? rank[i + k] : -1; int rj = j + k <= N ? rank[j + k] : -1; return ri - rj; } } public int lowerBound(String t) { int a = -1, b = S.length(); while (b - a > 1) { int c = (a + b) / 2; String sub = S.substring(sa[c], Math.min(t.length() + sa[c], S.length())); if (sub.compareTo(t) < 0) a = c; else b = c; } String sub = S.substring(sa[b], Math.min(t.length() + sa[b], S.length())); return sub.compareTo(t) == 0 ? b : -1; } public int upperBound(String t) { int a = -1, b = S.length() + 1; while (b - a > 1) { int c = (a + b) / 2; String sub = S.substring(sa[c], Math.min(t.length() + sa[c], S.length())); if (sub.compareTo(t) <= 0) a = c; else b = c; } return b; } } } /** * ここから下はテンプレートです。 */ public static void main(String[] args) throws Exception { OutputStream outputStream = System.out; FastScanner in = new FastScanner(); PrintWriter out = new PrintWriter(outputStream); Task solver = new Task(); solver.solve(in, out); out.close(); } private static class FastScanner { private final InputStream in = System.in; private final byte[] buffer = new byte[1024]; private int ptr = 0; private int bufferLength = 0; private boolean hasNextByte() { if (ptr < bufferLength) { return true; } else { ptr = 0; try { bufferLength = in.read(buffer); } catch (IOException e) { e.printStackTrace(); } if (bufferLength <= 0) { return false; } } return true; } private int readByte() { if (hasNextByte()) return buffer[ptr++]; else return -1; } private static boolean isPrintableChar(int c) { return 33 <= c && c <= 126; } private void skipUnprintable() { while (hasNextByte() && !isPrintableChar(buffer[ptr])) ptr++; } boolean hasNext() { skipUnprintable(); return hasNextByte(); } public String next() { if (!hasNext()) throw new NoSuchElementException(); StringBuilder sb = new StringBuilder(); int b = readByte(); while (isPrintableChar(b)) { sb.appendCodePoint(b); b = readByte(); } return sb.toString(); } long nextLong() { if (!hasNext()) throw new NoSuchElementException(); long n = 0; boolean minus = false; int b = readByte(); if (b == '-') { minus = true; b = readByte(); } if (b < '0' || '9' < b) { throw new NumberFormatException(); } while (true) { if ('0' <= b && b <= '9') { n *= 10; n += b - '0'; } else if (b == -1 || !isPrintableChar(b)) { return minus ? -n : n; } else { throw new NumberFormatException(); } b = readByte(); } } double nextDouble() { return Double.parseDouble(next()); } double[] nextDoubleArray(int n) { double[] array = new double[n]; for (int i = 0; i < n; i++) { array[i] = nextDouble(); } return array; } double[][] nextDoubleMap(int n, int m) { double[][] map = new double[n][]; for (int i = 0; i < n; i++) { map[i] = nextDoubleArray(m); } return map; } public int nextInt() { return (int) nextLong(); } public int[] nextIntArray(int n) { int[] array = new int[n]; for (int i = 0; i < n; i++) array[i] = nextInt(); return array; } public long[] nextLongArray(int n) { long[] array = new long[n]; for (int i = 0; i < n; i++) array[i] = nextLong(); return array; } public String[] nextStringArray(int n) { String[] array = new String[n]; for (int i = 0; i < n; i++) array[i] = next(); return array; } public char[][] nextCharMap(int n) { char[][] array = new char[n][]; for (int i = 0; i < n; i++) array[i] = next().toCharArray(); return array; } public int[][] nextIntMap(int n, int m) { int[][] map = new int[n][]; for (int i = 0; i < n; i++) { map[i] = nextIntArray(m); } return map; } } }