$a$와 $b$가 정해지면 $c^2$을 $n$으로 나눈 나머지가 정해지니까, $0\le x<n$에 대해
\begin{equation*} f(x) = c^2\text{을 } n\text{으로 나눈 나머지가 } x\text{인 } c\text{의 개수(단, } 1\le c<n\text{)} \end{equation*}
로 정의하면 구해야 하는 값은 다음과 같이 쓸 수 있습니다.
\begin{equation*} \sum_{i=1}^{n-1} \sum_{j=i}^{n-1} f((i^2 + j^2) \% n) \end{equation*}
$j$가 $i$에서 시작하는 것이 불편하므로 식을 변형합니다(번사이드 보조정리).
\begin{equation*} \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f((i^2+j^2)\%n) + \sum_{i=1}^{n-1} f(2i^2\%n) \right] \end{equation*}
여기서 좀만 더 생각해보면 아래 식으로 바꿀 수 있습니다.
\begin{align*}
& \frac{1}{2} \left[ \sum_{i=1}^{n-1} \sum_{j=1}^{n-1} f(i)f(j)f((i+j)\%n) + \sum_{i=1}^{n-1} f(i)f(2i\%n) \right] \\
&\qquad = \frac{1}{2}\sum_{i=1}^{n-1} \left[ f(i) \sum_{j=1}^{n-1} f(j)f((i+j)\%n) + f(i)f(2i\%n) \right]
\end{align*}
이제 $g(i)$를
\begin{equation*} g(i) = \sum_{j=1}^{n-1} f(j)f((i+j)\%n) \end{equation*}
으로 정의하면 왠지 합성곱 냄새가 납니다. 원래 FFT의 방식대로 $f(x)$를 주기 $n$짜리 수열로 생각하면 $f((i+j)\%n)=f(i+j)$로 간단하게 표현되고, $j=k-i$를 대입하면
\begin{equation*} g(i) = \sum_{k=i+1}^{n-i+1} f(k)f(k-i) = \sum_{k=1}^{n-1} f(k)f(k-i) \end{equation*}
구한 식이 합성곱이랑 약간 다른데, $f'(x)=f(-x)$라는 새 수열을 정의하면
\begin{equation*} g(i) = \sum_{k=1}^{n-1} f(k) f'(i-k) \end{equation*}
가 되어 $g$는 $f$와 $f'$의 합성곱이 됩니다. 이제 $g$를 구하고,
\begin{equation*} \frac{1}{2}\sum_{i=1}^{n-1} \left[ f(i)g(i) + f(i)f(2i\%n) \right] \end{equation*}
로 정답을 계산합니다. 전체 시간 복잡도는 $O(n\log n)$입니다.
#include<bits/stdc++.h>
using namespace std;
using ull = unsigned long long;
using cdbl = complex<double>;
const double PI = acos(-1.);
inline unsigned bitreverse(const unsigned n, const unsigned k) {
unsigned r, i;
for (r = 0, i = 0; i < k; ++i)
r |= ((n >> i) & 1) << (k - i - 1);
return r;
}
void fft(vector<cdbl> &a, bool is_reverse=false) {
const unsigned n = a.size(), k = __builtin_ctz(n);
unsigned s, i, j;
for (i = 0; i < n; i++) {
j = bitreverse(i, k);
if (i < j)
swap(a[i], a[j]);
}
for (s = 2; s <= n; s *= 2) {
double t = 2*PI/s * (is_reverse? -1 : 1);
cdbl ws(cos(t), sin(t));
for (i = 0; i < n; i += s) {
cdbl w(1);
for (j = 0; j < s/2; j++) {
cdbl tmp = a[i + j + s/2] * w;
a[i + j + s/2] = a[i + j] - tmp;
a[i + j] += tmp;
w *= ws;
}
}
}
if (is_reverse)
for (i = 0; i < n; i++)
a[i] /= n;
}
int main(void) {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
unsigned n, np;
cin >> n;
for (np = 1; np < 2*n; np *= 2);
vector<unsigned> f(np, 0);
for (ull x = 1; x < n; x++)
f[x * x % n]++;
vector<cdbl> g(f.begin(), f.end());
vector<cdbl> fp(np, 0);
for (unsigned i = 0; i < n; i++)
fp[i] = fp[np-n+i] = f[(n - i) % n];
fft(g);
fft(fp);
for (unsigned i = 0; i < np; i++)
g[i] *= fp[i];
fft(g, true);
ull res = 0;
for (unsigned i = 0; i < n; i++)
res += f[i] * ((ull)(g[i].real() + 0.5) + f[2*i % n]);
cout << res / 2;
return 0;
}