AtCoder ABC #136 F - Enclosed Points

F - Enclosed Points

問題

TODO

BIT

解説の前に今回のBITの使いどころについて軽く整理。

いまx座標で昇順に並べた点(例:  (0,1), (1,4), (2,2), (3,3), (4,0)) があるとする。

BITを利用すると各点を原点として特定の象限に含まれる点の数を順に求めることができる。

例えば第3象限の場合は以下をx座標昇順に繰り返せば良い。

  1. BIT.sum(対象点のy座標)が点の第3象限に含まれる点の数となる
  2. BIT.add(対象点のy座標, 1)

第2象限のばあいは例の場合であれば1.が対象のx座標 - BIT.sum(対象点のy座標)となる。 (x座標と対象点の左側にある点の数が一致するため)

第1,4象限の場合は最初にBITの葉を1で初期化しておきBIT.add(対象点のy座標, -1)とすることで同様に求めることができる。

実装

BITの実装は以下の通り。スニペットにしてある

#[derive(Debug)]
struct FenwickTree {
    n: usize,
    bit: Vec<isize>,
}

impl FenwickTree {
    fn new(n: usize) -> Self {
        FenwickTree {
            n: n,
            bit: vec![0; n],
        }
    }

    // [0, i)
    fn sum(&self, i: usize) -> isize {
        let mut s = 0;
        let mut i = i as isize;
        while i > 0 {
            i -= 1;
            s += self.bit[i as usize];
            i &= i + 1;
        }
        s
    }

    fn add(&mut self, i: usize, x: isize) {
        let mut i = i;
        while i < self.n {
            self.bit[i] += x;
            i |= i + 1;
        }
    }
}

解説

コードは以下の通り

  • p2に予め2の累乗のMODを記録しておく
  • Vec::binary_search()を用いて各点の座標を0〜N-1に圧縮
  • 2つのBIT(左象限用、右象限用)を作成し右象限用は1で初期化
  • r1r4は対象点を原点とした各象限の点数
  • pr1pr4は各象限の点が取りうる場合の数(の剰余)

あとは公式解説に従って24通りについて場合を求めて足し合わせている。

これでACにはなるのだが、長い、長すぎる。他の回答はこれほどは長くはないので、何かしらもう少し効率よく記述できると思われるがとりあえずTODOとしておく。

static MOD: isize = 998244353;

fn solve(xys: Vec<(isize, isize)>) -> isize {
    let n = xys.len() as isize;
    let mut p2 = vec![0 as isize; 200002];
    p2[0] = 1;
    for x in 1..p2.len() {
        p2[x] = (p2[x - 1] * 2) % MOD;
    }

    let mut sortedx = xys.iter().map(|&(x, _)| x).collect::<Vec<_>>();
    sortedx.sort();

    let mut sortedy = xys.iter().map(|&(_, y)| y).collect::<Vec<_>>();
    sortedy.sort();

    let mut compressed = vec![(0, 0); xys.len()];
    for (i, &(x, y)) in xys.iter().enumerate() {
        let ix = sortedx.binary_search(&x).unwrap() as isize;
        let iy = sortedy.binary_search(&y).unwrap() as isize;
        compressed[i] = (ix, iy);
    }

    compressed.sort_by_key(|&(x, _)| x);

    let mut lft = FenwickTree::new(xys.len());
    let mut rft = FenwickTree::new(xys.len());
    for i in 0..xys.len() {
        rft.add(i, 1);
    }

    let mut ans = 0;

    for (x, y) in compressed {
        let r3 = lft.sum(y as usize) as usize;
        let r2 = (x - r3 as isize) as usize;
        let r4 = rft.sum(y as usize) as usize;
        let r1 = (n - x - 1 - r4 as isize) as usize;

        let pr1 = p2[r1] - 1;
        let pr2 = p2[r2] - 1;
        let pr3 = p2[r3] - 1;
        let pr4 = p2[r4] - 1;

        // 各象限と自身のみ
        ans += pr1 + pr2 + pr3 + pr4;
        ans %= MOD;

        // 隣り合う象限と自身
        ans += if r1 > 0 && r2 > 0 {
            (pr1 * pr2) % MOD
        } else {
            0
        } + if r2 > 0 && r3 > 0 {
            (pr2 * pr3) % MOD
        } else {
            0
        } + if r3 > 0 && r4 > 0 {
            (pr3 * pr4) % MOD
        } else {
            0
        } + if r4 > 0 && r1 > 0 {
            (pr4 * pr1) % MOD
        } else {
            0
        };
        ans %= MOD;

        // r1+r3, r2+r4 (自身を含む場合と含まない場合で2倍)
        ans += if r1 > 0 && r3 > 0 {
            (pr1 * pr3) * 2 % MOD
        } else {
            0
        } + if r2 > 0 && r4 > 0 {
            (pr2 * pr4) * 2 % MOD
        } else {
            0
        };
        ans %= MOD;

        // 3象限を含む場合
        ans += if r1 > 0 && r2 > 0 && r3 > 0 {
            (((pr1 * pr2) % MOD) * pr3 * 2) % MOD
        } else {
            0
        } + if r2 > 0 && r3 > 0 && r4 > 0 {
            (((pr2 * pr3) % MOD) * pr4 * 2) % MOD
        } else {
            0
        } + if r3 > 0 && r4 > 0 && r1 > 0 {
            (((pr3 * pr4) % MOD) * pr1 * 2) % MOD
        } else {
            0
        } + if r4 > 0 && r1 > 0 && r2 > 0 {
            (((pr4 * pr1) % MOD) * pr2 * 2) % MOD
        } else {
            0
        };
        ans %= MOD;

        // 全て含む場合(自身を含む場合と含まない場合で2倍)
        ans += if r1 > 0 && r2 > 0 && r3 > 0 && r4 > 0 {
            (((((pr1 * pr2) % MOD) * pr3) % MOD) * pr4 * 2) % MOD
        } else {
            0
        };
        ans %= MOD;

        lft.add(y as usize, 1);
        rft.add(y as usize, -1);
    }

    ans += n;
    ans %= MOD;

    ans
}