// Code
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
/**
 * Reads N points given as "row column" pairs and prints an accumulated count.
 *
 * <p>For every pair of non-empty rows {@code top < bottom}, it counts the pairs of distinct
 * occupied columns {@code a < b}. Columns are drawn from all points in rows {@code top..bottom},
 * and a pair counts when the closed range {@code [a, b]} holds at least one column of row
 * {@code top} and at least one column of row {@code bottom}.
 *
 * <p>Input coordinates are 1-based. Both must lie in [1, {@link #MAX}] because, after
 * subtracting 1, they index arrays of size {@code MAX}.
 */
public class Main {
    /** Number of rows and of columns; input coordinates must lie in [1, MAX]. */
    private static final int MAX = 2500;

    /**
     * Counts pairs of occupied columns {@code a < b} that fall into the same gap between
     * consecutive entries of {@code lines}. Equivalently, it counts the pairs whose range
     * {@code [a, b]} contains no element of {@code lines}.
     *
     * @param lines ascending column indices in [0, MAX) that split [0, MAX) into gaps
     * @param bit   per-column counts over [0, MAX); in {@code solve} each column holds 0 or 1
     * @return the sum over all gaps of C(k, 2), where k is the count inside the gap
     */
    private long intervalCount(List<Integer> lines, FenwickTree bit) {
        long ans = 0;
        for (int i = -1; i < lines.size(); i++) {
            // Gap i is the half-open range [lines[i] + 1, lines[i + 1]). Virtual boundaries
            // lines[-1] = -1 and lines[size] = MAX cover both ends.
            int x1 = i >= 0 ? lines.get(i) + 1 : 0;
            int x2 = i + 1 < lines.size() ? lines.get(i + 1) : MAX;
            if (x2 - x1 >= 2) { // a gap narrower than two columns cannot hold a pair
                long k = bit.sum(x1, x2);
                ans += k * (k - 1) / 2;
            }
        }
        return ans;
    }

    /**
     * Merges two ascending lists into one ascending list. When the current heads of both lists
     * are equal, the value is emitted once and both heads advance. A value present in both lists
     * therefore appears once per matched pair.
     */
    private static List<Integer> mergeSorted(List<Integer> a, List<Integer> b) {
        List<Integer> merged = new ArrayList<>(a.size() + b.size());
        int i = 0;
        int j = 0;
        while (i < a.size() || j < b.size()) {
            // MAX is a sentinel larger than every valid column index.
            int headA = i < a.size() ? a.get(i) : MAX;
            int headB = j < b.size() ? b.get(j) : MAX;
            if (headA == headB) {
                merged.add(headA);
                i++;
                j++;
            } else if (headA < headB) {
                merged.add(headA);
                i++;
            } else {
                merged.add(headB);
                j++;
            }
        }
        return merged;
    }

    /** Reads the points, runs the row-pair sweep, and prints the total. */
    private void solve(FastScanner in, PrintWriter out) {
        int N = in.nextInt();
        // One sorted column list per row. A List of Lists avoids a raw generic array.
        List<List<Integer>> xs = new ArrayList<>(MAX);
        for (int i = 0; i < MAX; i++) {
            xs.add(new ArrayList<>());
        }
        for (int i = 0; i < N; i++) {
            int y = in.nextInt() - 1;
            int x = in.nextInt() - 1;
            xs.get(y).add(x);
        }
        for (List<Integer> row : xs) {
            Collections.sort(row);
        }

        long ans = 0;
        for (int left = 0; left < MAX; left++) {
            List<Integer> top = xs.get(left);
            if (top.isEmpty()) {
                continue;
            }
            // Distinct columns seen in rows left..right, rebuilt for each top row.
            boolean[] added = new boolean[MAX];
            FenwickTree bit = new FenwickTree(MAX);
            int addedCount = 0;
            for (int right = left; right < MAX; right++) {
                List<Integer> bottom = xs.get(right);
                if (bottom.isEmpty()) {
                    continue;
                }
                for (int x : bottom) {
                    if (!added[x]) {
                        added[x] = true;
                        bit.add(x, 1);
                        addedCount++;
                    }
                }
                if (right == left) {
                    continue;
                }
                // Inclusion-exclusion over column pairs a < b:
                //   all pairs
                //   - pairs whose range holds no top-row point
                //   - pairs whose range holds no bottom-row point
                //   + pairs whose range holds neither.
                ans += (long) addedCount * (addedCount - 1) / 2;
                ans -= intervalCount(top, bit);
                ans -= intervalCount(bottom, bit);
                ans += intervalCount(mergeSorted(top, bottom), bit);
            }
        }
        out.println(ans);
    }

    /**
     * 0-indexed Fenwick (binary indexed) tree of longs with point add and prefix sums.
     * Static because it needs no access to the enclosing {@code Main} instance.
     */
    static class FenwickTree {
        /** Length of {@code data}; the valid positions are [0, N). */
        final int N;
        final long[] data;

        /** Creates a tree with positions [0, size]; all sums start at zero. */
        FenwickTree(int size) {
            this.N = size + 1;
            data = new long[size + 1];
        }

        /** Adds {@code val} to position {@code k}, which must be in [0, N). */
        void add(int k, long val) {
            for (int x = k; x < N; x |= x + 1) {
                data[x] += val;
            }
        }

        /**
         * Returns the sum over positions [0, k). A {@code k} above {@code N} is clamped to
         * {@code N}, so every position that {@link #add} can write is included. A
         * non-positive {@code k} yields 0.
         */
        long sum(int k) {
            if (k > N) {
                k = N;
            }
            long ret = 0;
            for (int x = k - 1; x >= 0; x = (x & (x + 1)) - 1) {
                ret += data[x];
            }
            return ret;
        }

        /** Returns the sum over positions [l, r). */
        long sum(int l, int r) {
            return sum(r) - sum(l);
        }
    }

    public static void main(String[] args) {
        // try-with-resources guarantees the writer is flushed and closed even if solve throws.
        try (PrintWriter out = new PrintWriter(System.out)) {
            new Main().solve(new FastScanner(), out);
        }
    }

    /**
     * Buffered token reader over {@code System.in}, captured when the scanner is created.
     * Tokens are separated by bytes outside the printable ASCII range 33..126.
     */
    private static class FastScanner {
        private final InputStream in = System.in;
        private final byte[] buffer = new byte[1024];
        private int ptr = 0;
        private int bufferLength = 0;

        /**
         * Ensures at least one unread byte is buffered, refilling from the stream if needed.
         *
         * @return {@code false} once the stream reports end of input
         * @throws UncheckedIOException if reading the underlying stream fails
         */
        private boolean hasNextByte() {
            if (ptr < bufferLength) {
                return true;
            }
            ptr = 0;
            try {
                bufferLength = in.read(buffer);
            } catch (IOException e) {
                // Propagate with the cause. Swallowing it would leave bufferLength at its old
                // value, and the scanner would then re-read stale bytes.
                throw new UncheckedIOException(e);
            }
            return bufferLength > 0;
        }

        /** Returns the next raw byte, or -1 at end of input. */
        private int readByte() {
            if (hasNextByte()) {
                return buffer[ptr++];
            } else {
                return -1;
            }
        }

        private static boolean isPrintableChar(int c) {
            return 33 <= c && c <= 126;
        }

        /** Advances past any separator bytes. */
        private void skipUnprintable() {
            while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
                ptr++;
            }
        }

        /** Returns whether another token is available. */
        boolean hasNext() {
            skipUnprintable();
            return hasNextByte();
        }

        /**
         * Returns the next token.
         *
         * @throws NoSuchElementException if no token remains
         */
        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            StringBuilder sb = new StringBuilder();
            int b = readByte();
            while (isPrintableChar(b)) {
                sb.appendCodePoint(b);
                b = readByte();
            }
            return sb.toString();
        }

        /**
         * Parses the next token as an optionally negative decimal integer.
         *
         * @throws NoSuchElementException if no token remains
         * @throws NumberFormatException  if the token contains a non-digit character
         */
        long nextLong() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            long n = 0;
            boolean minus = false;
            int b = readByte();
            if (b == '-') {
                minus = true;
                b = readByte();
            }
            if (b < '0' || '9' < b) {
                throw new NumberFormatException();
            }
            while (true) {
                if ('0' <= b && b <= '9') {
                    n *= 10;
                    n += b - '0';
                } else if (b == -1 || !isPrintableChar(b)) {
                    return minus ? -n : n;
                } else {
                    throw new NumberFormatException();
                }
                b = readByte();
            }
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        double[] nextDoubleArray(int n) {
            double[] array = new double[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextDouble();
            }
            return array;
        }

        double[][] nextDoubleMap(int n, int m) {
            double[][] map = new double[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextDoubleArray(m);
            }
            return map;
        }

        /** Parses the next token as a long and narrows it to int. */
        public int nextInt() {
            return (int) nextLong();
        }

        public int[] nextIntArray(int n) {
            int[] array = new int[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextInt();
            }
            return array;
        }

        public long[] nextLongArray(int n) {
            long[] array = new long[n];
            for (int i = 0; i < n; i++) {
                array[i] = nextLong();
            }
            return array;
        }

        public String[] nextStringArray(int n) {
            String[] array = new String[n];
            for (int i = 0; i < n; i++) {
                array[i] = next();
            }
            return array;
        }

        public char[][] nextCharMap(int n) {
            char[][] array = new char[n][];
            for (int i = 0; i < n; i++) {
                array[i] = next().toCharArray();
            }
            return array;
        }

        public int[][] nextIntMap(int n, int m) {
            int[][] map = new int[n][];
            for (int i = 0; i < n; i++) {
                map[i] = nextIntArray(m);
            }
            return map;
        }
    }
}