BOJ 5051:: 피타고라스의 정리 – FFT

https://www.acmicpc.net/problem/5051

$O(n^3)$

무식하게 모든 $a$, $b$, $c$를 다 시도해봅니다.

#include<bits/stdc++.h>

using namespace std;
using ull = unsigned long long;

int main(void) {
    ull n;
    cin >> n;

    ull res = 0;
    for (ull i = 1; i < n; i++)
        for (ull j = i; j < n; j++)
            for (ull k = 1; k < n; k++)
                if ((i*i + j*j) % n == k*k % n)
                    res++;
    
    cout << res;

    return 0;
}

$O(n^2)$

$a$와 $b$가 정해지면 $c^2$을 $n$으로 나눈 나머지가 정해지니까, $0\le x < n$에 대해 $f(x)$를

$f(x) := c^2$을 $n$으로 나눈 나머지가 $x$인 $c$의 개수(단, $1 \le c < n$)

로 정의하면 구해야 하는 값은

\[ \sum_{i=1}^{n-1} \sum_{j=i}^{n-1} f((i^2 + j^2) \% n) \]

으로 쓸 수 있습니다. 그런데 $a \le b$라는 조건 때문에 두 번째 $\sum$가 $j=i$부터라 계산하기 귀찮으니, 일단 $a \le b$란 조건을 없애봅시다. 아무 조건 없이 계산한 값에서 $a=b$인 경우를 빼면 $a \neq b$인 경우를 셀 수 있고, 그걸 반으로 나누면 $a < b$인 경우를 셀 수 있습니다. 여기에 다시 $a=b$인 경우를 더해주면 되죠. (또는 번사이드 레마를 써도 됩니다.)

\[ \begin{align}
& \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f((i^2+j^2)\%n) – \sum_{i=1}^{n-1} f(2i^2\%n) \right] + \sum_{i=1}^{n-1} f(2i^2\%n) \\
& = \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f((i^2+j^2)\%n) + \sum_{i=1}^{n-1} f(2i^2\%n) \right]
\end{align} \]

여기서 좀만 더 생각해보면 아래 식으로 바꿀 수 있습니다.

\[ \begin{align}
& \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f(i)f(j)f((i+j)\%n) + \sum_{i=1}^{n-1} f(i)f(2i\%n) \right] \\
& = \frac{1}{2}\sum_{i=1}^{n-1} \left[ f(i) \sum_{j=1}^{n-1} f(j)f((i+j)\%n) + f(i)f(2i\%n) \right]
\end{align} \]

$O(n\log n)$

$g(i)$를

\[ g(i) = \sum_{j=1}^{n-1} f(j)f((i+j)\%n) \]

으로 정의하면 왠지 합성곱 냄새가 납니다. 원래 FFT의 방식대로 $f(x)$를 주기 $n$짜리 수열로 생각하면 $f((i+j)\%n)=f(i+j)$로 간단하게 표현되고, $j=k-i$를 대입하면

\[ g(i) = \sum_{k=i+1}^{n-i+1} f(k)f(k-i) = \sum_{k=1}^{n-1} f(k)f(k-i) \]

구한 식이 합성곱이랑 약간 다른데, $f'(x) = f(-x)$라는 새 수열을 정의하면

\[ g(i) = \sum_{k=1}^{n-1} f(k) f'(i-k) \]

가 되어 $g$는 $f$와 $f’$의 합성곱이 됩니다. 이제 $g$를 구하고,

\[ \frac{1}{2}\sum_{i=1}^{n-1} \left[ f(i)g(i) + f(i)f(2i\%n) \right] \]

로 정답을 계산합니다.

#include<bits/stdc++.h>

using namespace std;
using ull = unsigned long long;
using cdbl = complex<double>;

const double PI = acos(-1.);

inline unsigned bitreverse(const unsigned n, const unsigned k) {
    unsigned r, i;
    for (r = 0, i = 0; i < k; ++i)
        r |= ((n >> i) & 1) << (k - i - 1);
    return r;
}

void fft(vector<cdbl> &a, bool is_reverse=false) {
    const unsigned n = a.size(), k = __builtin_ctz(n);
    unsigned s, i, j;
    for (i = 0; i < n; i++) {
        j = bitreverse(i, k);
        if (i < j)
            swap(a[i], a[j]);
    }
    for (s = 2; s <= n; s *= 2) {
        double t = 2*PI/s * (is_reverse? -1 : 1);
        cdbl ws(cos(t), sin(t));
        for (i = 0; i < n; i += s) {
            cdbl w(1);
            for (j = 0; j < s/2; j++) {
                cdbl tmp = a[i + j + s/2] * w;
                a[i + j + s/2] = a[i + j] - tmp;
                a[i + j] += tmp;
                w *= ws;
            }
        }
    }
    if (is_reverse)
        for (i = 0; i < n; i++)
            a[i] /= n;
}

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    unsigned n, np;
    cin >> n;
    for (np = 1; np < 2*n; np *= 2);

    vector<unsigned> f(np, 0);
    for (ull x = 1; x < n; x++)
        f[x * x % n]++;

    vector<cdbl> g(f.begin(), f.end());
    vector<cdbl> fp(np, 0);
    for (unsigned i = 0; i < n; i++)
        fp[i] = fp[np-n+i] = f[(n - i) % n];

    fft(g);
    fft(fp);
    for (unsigned i = 0; i < np; i++)
        g[i] *= fp[i];
    fft(g, true);

    ull res = 0;
    for (unsigned i = 0; i < n; i++)
        res += f[i] * ((ull)(g[i].real() + 0.5) + f[2*i % n]);

    cout << res / 2;

    return 0;
}