Up

BOJ 5051:: 피타고라스의 정리

$a$와 $b$가 정해지면 $c^2$을 $n$으로 나눈 나머지가 정해지니까, $0\le x<n$에 대해

\begin{equation*} f(x) = c^2\text{을 } n\text{으로 나눈 나머지가 } x\text{인 } c\text{의 개수(단, } 1\le c<n\text{)} \end{equation*}

로 정의하면 구해야 하는 값은 다음과 같이 쓸 수 있습니다.

\begin{equation*} \sum_{i=1}^{n-1} \sum_{j=i}^{n-1} f((i^2 + j^2) \% n) \end{equation*}

$j$가 $i$에서 시작하는 것이 불편하므로 식을 변형합니다(번사이드 보조정리).

\begin{equation*} \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f((i^2+j^2)\%n) + \sum_{i=1}^{n-1} f(2i^2\%n) \right] \end{equation*}

여기서 좀만 더 생각해보면 아래 식으로 바꿀 수 있습니다.

\begin{align*} & \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f(i)f(j)f((i+j)\%n) + \sum_{i=1}^{n-1} f(i)f(2i\%n) \right] \\ &\qquad = \frac{1}{2}\sum_{i=1}^{n-1} \left[ f(i) \sum_{j=1}^{n-1} f(j)f((i+j)\%n) + f(i)f(2i\%n) \right] \end{align*}

이제 $g(i)$를

\begin{equation*} g(i) = \sum_{j=1}^{n-1} f(j)f((i+j)\%n) \end{equation*}

으로 정의하면 왠지 합성곱 냄새가 납니다. 원래 FFT의 방식대로 $f(x)$를 주기 $n$짜리 수열로 생각하면 $f((i+j)\%n)=f(i+j)$로 간단하게 표현되고, $j=k-i$를 대입하면

\begin{equation*} g(i) = \sum_{k=i+1}^{n-i+1} f(k)f(k-i) = \sum_{k=1}^{n-1} f(k)f(k-i) \end{equation*}

구한 식이 합성곱이랑 약간 다른데, $f'(x)=f(-x)$라는 새 수열을 정의하면

\begin{equation*} g(i) = \sum_{k=1}^{n-1} f(k) f'(i-k) \end{equation*}

가 되어 $g$는 $f$와 $f'$의 합성곱이 됩니다. 이제 $g$를 구하고,

\begin{equation*} \frac{1}{2}\sum_{i=1}^{n-1} \left[ f(i)g(i) + f(i)f(2i\%n) \right] \end{equation*}

로 정답을 계산합니다. 전체 시간 복잡도는 $O(n\log n)$입니다.

#include<bits/stdc++.h>

using namespace std;
using ull = unsigned long long;
using cdbl = complex<double>;

const double PI = acos(-1.);

inline unsigned bitreverse(const unsigned n, const unsigned k) {
    unsigned r, i;
    for (r = 0, i = 0; i < k; ++i)
        r |= ((n >> i) & 1) << (k - i - 1);
    return r;
}

void fft(vector<cdbl> &a, bool is_reverse=false) {
    const unsigned n = a.size(), k = __builtin_ctz(n);
    unsigned s, i, j;
    for (i = 0; i < n; i++) {
        j = bitreverse(i, k);
        if (i < j)
            swap(a[i], a[j]);
    }
    for (s = 2; s <= n; s *= 2) {
        double t = 2*PI/s * (is_reverse? -1 : 1);
        cdbl ws(cos(t), sin(t));
        for (i = 0; i < n; i += s) {
            cdbl w(1);
            for (j = 0; j < s/2; j++) {
                cdbl tmp = a[i + j + s/2] * w;
                a[i + j + s/2] = a[i + j] - tmp;
                a[i + j] += tmp;
                w *= ws;
            }
        }
    }
    if (is_reverse)
        for (i = 0; i < n; i++)
            a[i] /= n;
}

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    unsigned n, np;
    cin >> n;
    for (np = 1; np < 2*n; np *= 2);

    vector<unsigned> f(np, 0);
    for (ull x = 1; x < n; x++)
        f[x * x % n]++;

    vector<cdbl> g(f.begin(), f.end());
    vector<cdbl> fp(np, 0);
    for (unsigned i = 0; i < n; i++)
        fp[i] = fp[np-n+i] = f[(n - i) % n];

    fft(g);
    fft(fp);
    for (unsigned i = 0; i < np; i++)
        g[i] *= fp[i];
    fft(g, true);

    ull res = 0;
    for (unsigned i = 0; i < n; i++)
        res += f[i] * ((ull)(g[i].real() + 0.5) + f[2*i % n]);

    cout << res / 2;

    return 0;
}