myy的论文中的毒瘤优化。。多项式乘法可以做到2次或1.5次FFT。。
若要做 AABB 卷积
构造
P=A+BiP=A+Bi
Q=ABiQ=A-Bi
推推式子发现 Q(ωk)=P(ωlenk)Q(\omega^k)=P(\omega^{len-k})的共轭复数。于是只要把 PP 点值求出,QQ 点值就可以 O(n)O(n) 求出,从而解出 AABB 的点值。于是做到了两次FFT。一次半那个。。咕了,没用啊,而且这个两次的肯定也没用啊。。

#include<bits/stdc++.h>
using namespace std;
const int N=3e6+50;
const double pi=acos(-1);
int n,m,a[N],b[N],lim=1,l=-1,r[N];
struct cp{
    double x,y;
    cp operator +(const cp &b){return (cp){x+b.x,y+b.y};}
    cp operator -(const cp &b){return (cp){x-b.x,y-b.y};}
    cp operator *(const cp &b){return (cp){x*b.x-y*b.y,x*b.y+y*b.x};}
    cp operator ~(){return (cp){x,-y};}
}P[N],Q[N],A[N],B[N],rt[N],irt[N];
int read(){
    int x=0,c;
    while(!isdigit(c=getchar()));
    while(isdigit(c))x=x*10+c-48,c=getchar();
    return x;
}
void init(int n){
    while(lim<n)lim<<=1,l++;
    for(int i=0;i<lim;i++)r[i]=r[i>>1]>>1^(i&1)<<l;
    rt[lim>>1]=(cp){1,0};cp w=(cp){cos(2*pi/lim),sin(2*pi/lim)};
    for(int i=(lim>>1)+1;i<lim;i++)rt[i]=rt[i-1]*w;
    for(int i=(lim>>1)-1;i;i--)rt[i]=rt[i<<1];
    for(int i=1;i<lim;i++)irt[i]=~rt[i];
}
void FFT(cp *a,cp *w){
    for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(int i=1;i<lim;i<<=1)
        for(int t=i<<1,j=0;j<lim;j+=t)
            for(int k=0;k<i;k++){
                cp x=a[j+k],y=a[i+j+k]*w[i|k];
                a[j+k]=x+y;a[i+j+k]=x-y;
            }
}
int main(){
    scanf("%d%d",&n,&m);init(n+m+1);
    for(int i=0;i<=n;i++)P[i].x=read();
    for(int i=0;i<=m;i++)P[i].y=read();
    FFT(P,rt);Q[0]=~P[0];
    for(int i=1;i<lim;i++)Q[i]=~P[lim-i];
    for(int i=0;i<lim;i++)A[i]=(P[i]+Q[i])*(cp){1.0/2,0},B[i]=(P[i]-Q[i])*(cp){0,-1.0/2};
    for(int i=0;i<lim;i++)A[i]=A[i]*B[i];
    FFT(A,irt);
    for(int i=0;i<=n+m;i++)printf("%d ",(int)(A[i].x/lim+0.5));
    return 0;
}

任意模数NTT

三模数NTT没有写,直接上拆系数FFT吧。
我们知道卷积出的值大小是 102310^{23} 级别,直接FFT精度不够。那么考虑把每个数表示成 aM+ba*M+b,其中 M=215M=2^{15}
这样我们做4个卷积,每个卷积的精度都够了。如果加上上面提到的FFT次数优化,可以做到只用4次FFT。

然后有一个困扰我比较久的问题,看起来在计算过程中的数可能达到 101910^19 级别,爆精度了,但是因为计算结果为 101410^14 级别,所以爆掉的精度都在小数上,没有关系啦。

下面这份代码是封装得很好看的多项式求逆。看来不怕写任意模数了呢,其实特别好写。

#include<bits/stdc++.h>
#define LL long long
#define double long double
using namespace std;
const int N=4e5+50,M=32768,mod=1e9+7;
const double Pi=acos(-1);
int n,r[N],lo[N];
struct node{
	double x,y;
	node operator +(const node &b){return (node){x+b.x,y+b.y};}
	node operator -(const node &b){return (node){x-b.x,y-b.y};}
	node operator *(const node &b){return (node){x*b.x-y*b.y,x*b.y+y*b.x};}
	node operator ~(){return (node){x,-y};}
}a[N],b[N],a1[N],b1[N],rt[N],irt[N],I,O,c[N],d[N];
int read(){
	int x=0,c;
	while(!isdigit(c=getchar()));
	while(isdigit(c))x=x*10+c-48,c=getchar();
	return x;
}
int Glim(int n){int lim=1;while(lim<n)lim<<=1;return lim;}
void getr(int n){for(int i=1;i<n;i++)r[i]=r[i>>1]>>1^(i&1)<<lo[n];}
int power(int x,int y){
	int z=1;
	for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;
	return z;
}
void init(int n){
	int lim=Glim(n);
	for(int i=2,j=0;i<=lim;i<<=1,j++)lo[i]=j;
	for(int i=1;i<lim;i<<=1)for(int k=0;k<i;k++)
		rt[i+k]=(node){cos(Pi*k/i),sin(Pi*k/i)};
	for(int i=1;i<lim;i++)irt[i]=~rt[i];
}
void FFT(node *a,int n,int op){
	for(int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
	for(int i=1;i<n;i<<=1)
		for(int t=i<<1,j=0;j<n;j+=t)
			for(int k=0;k<i;k++){
				node x=a[j+k],y=a[i+j+k]*(op>0?rt[i+k]:irt[i+k]);
				a[j+k]=x+y;a[i+j+k]=x-y;
			}
	if(op<0)for(int i=0;i<n;i++)a[i].x/=n,a[i].y/=n;
}
LL get(double x){return x>0?(LL)(x+0.5):(LL)(x-0.5);}
void MTT(node *a,node *b,int n){
	for(int i=0;i<n;i++)
		a[i]=(node){get(a[i].x)/M,get(a[i].x)%M},
		b[i]=(node){get(b[i].x)/M,get(b[i].x)%M};
	FFT(a,n,1);FFT(b,n,1);
	for(int i=0;i<n;i++)a1[i]=~a[i?n-i:i],b1[i]=~b[i?n-i:i];
	for(int i=0;i<n;i++){
		node p=a[i],q=a1[i],s=b[i],t=b1[i];
		a[i]=(p+q)*(node){0.5,0};
		b[i]=(s+t)*(node){0.5,0};
		a1[i]=(q-p)*(node){0,0.5};
		b1[i]=(t-s)*(node){0,0.5};
	}
	for(int i=0;i<n;i++)a[i]=a[i]*b[i]+(a[i]*b1[i]+b[i]*a1[i])*I,b[i]=a1[i]*b1[i];
	FFT(a,n,-1);FFT(b,n,-1);
	for(int i=0;i<n;i++)a[i].x=(get(a[i].x)%mod*M%mod*M+get(a[i].y)%mod*M+get(b[i].x))%mod;
}
void Inv(node *a,node *b,int n){
	if(n==1){b[0].x=power(get(a[0].x),mod-2);return;}
	int m=(n+1)>>1,lim=Glim(n+m);Inv(a,b,m);
	for(int i=0;i<lim;i++)c[i]=i<n?a[i]:O,d[i]=i<m?b[i]:O;
	getr(lim);MTT(c,d,lim);
	for(int i=0;i<lim;i++)d[i]=i<m?b[i]:O;
	MTT(c,d,lim);
	for(int i=m;i<n;i++)b[i].x=-c[i].x;
}
int main(){
	scanf("%d",&n);init(2*n);I=(node){0,1};O=(node){0,0};
	for(int i=0;i<n;i++)a[i].x=read();
	Inv(a,b,n);
	for(int i=0;i<n;i++)printf("%d ",(get(b[i].x)+mod)%mod);
	return 0;
}