myy的论文中的毒瘤优化。。多项式乘法可以做到2次或1.5次FFT。。
若要做 与 卷积
构造
推推式子发现 的共轭复数。于是只要把 点值求出, 点值就可以 求出,从而解出 和 的点值。于是做到了两次FFT。一次半那个。。咕了,没用啊,而且这个两次的肯定也没用啊。。
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+50;
const double pi=acos(-1);
int n,m,a[N],b[N],lim=1,l=-1,r[N];
struct cp{
double x,y;
cp operator +(const cp &b){return (cp){x+b.x,y+b.y};}
cp operator -(const cp &b){return (cp){x-b.x,y-b.y};}
cp operator *(const cp &b){return (cp){x*b.x-y*b.y,x*b.y+y*b.x};}
cp operator ~(){return (cp){x,-y};}
}P[N],Q[N],A[N],B[N],rt[N],irt[N];
int read(){
int x=0,c;
while(!isdigit(c=getchar()));
while(isdigit(c))x=x*10+c-48,c=getchar();
return x;
}
void init(int n){
while(lim<n)lim<<=1,l++;
for(int i=0;i<lim;i++)r[i]=r[i>>1]>>1^(i&1)<<l;
rt[lim>>1]=(cp){1,0};cp w=(cp){cos(2*pi/lim),sin(2*pi/lim)};
for(int i=(lim>>1)+1;i<lim;i++)rt[i]=rt[i-1]*w;
for(int i=(lim>>1)-1;i;i--)rt[i]=rt[i<<1];
for(int i=1;i<lim;i++)irt[i]=~rt[i];
}
void FFT(cp *a,cp *w){
for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int i=1;i<lim;i<<=1)
for(int t=i<<1,j=0;j<lim;j+=t)
for(int k=0;k<i;k++){
cp x=a[j+k],y=a[i+j+k]*w[i|k];
a[j+k]=x+y;a[i+j+k]=x-y;
}
}
int main(){
scanf("%d%d",&n,&m);init(n+m+1);
for(int i=0;i<=n;i++)P[i].x=read();
for(int i=0;i<=m;i++)P[i].y=read();
FFT(P,rt);Q[0]=~P[0];
for(int i=1;i<lim;i++)Q[i]=~P[lim-i];
for(int i=0;i<lim;i++)A[i]=(P[i]+Q[i])*(cp){1.0/2,0},B[i]=(P[i]-Q[i])*(cp){0,-1.0/2};
for(int i=0;i<lim;i++)A[i]=A[i]*B[i];
FFT(A,irt);
for(int i=0;i<=n+m;i++)printf("%d ",(int)(A[i].x/lim+0.5));
return 0;
}
任意模数NTT
三模数NTT没有写,直接上拆系数FFT吧。
我们知道卷积出的值大小是 级别,直接FFT精度不够。那么考虑把每个数表示成 ,其中 。
这样我们做4个卷积,每个卷积的精度都够了。如果加上上面提到的FFT次数优化,可以做到只用4次FFT。
然后有一个困扰我比较久的问题,看起来在计算过程中的数可能达到 级别,爆精度了,但是因为计算结果为 级别,所以爆掉的精度都在小数上,没有关系啦。
下面这份代码是封装得很好看的多项式求逆。看来不怕写任意模数了呢,其实特别好写。
#include<bits/stdc++.h>
#define LL long long
#define double long double
using namespace std;
const int N=4e5+50,M=32768,mod=1e9+7;
const double Pi=acos(-1);
int n,r[N],lo[N];
struct node{
double x,y;
node operator +(const node &b){return (node){x+b.x,y+b.y};}
node operator -(const node &b){return (node){x-b.x,y-b.y};}
node operator *(const node &b){return (node){x*b.x-y*b.y,x*b.y+y*b.x};}
node operator ~(){return (node){x,-y};}
}a[N],b[N],a1[N],b1[N],rt[N],irt[N],I,O,c[N],d[N];
int read(){
int x=0,c;
while(!isdigit(c=getchar()));
while(isdigit(c))x=x*10+c-48,c=getchar();
return x;
}
int Glim(int n){int lim=1;while(lim<n)lim<<=1;return lim;}
void getr(int n){for(int i=1;i<n;i++)r[i]=r[i>>1]>>1^(i&1)<<lo[n];}
int power(int x,int y){
int z=1;
for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)z=1ll*z*x%mod;
return z;
}
void init(int n){
int lim=Glim(n);
for(int i=2,j=0;i<=lim;i<<=1,j++)lo[i]=j;
for(int i=1;i<lim;i<<=1)for(int k=0;k<i;k++)
rt[i+k]=(node){cos(Pi*k/i),sin(Pi*k/i)};
for(int i=1;i<lim;i++)irt[i]=~rt[i];
}
void FFT(node *a,int n,int op){
for(int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int i=1;i<n;i<<=1)
for(int t=i<<1,j=0;j<n;j+=t)
for(int k=0;k<i;k++){
node x=a[j+k],y=a[i+j+k]*(op>0?rt[i+k]:irt[i+k]);
a[j+k]=x+y;a[i+j+k]=x-y;
}
if(op<0)for(int i=0;i<n;i++)a[i].x/=n,a[i].y/=n;
}
LL get(double x){return x>0?(LL)(x+0.5):(LL)(x-0.5);}
void MTT(node *a,node *b,int n){
for(int i=0;i<n;i++)
a[i]=(node){get(a[i].x)/M,get(a[i].x)%M},
b[i]=(node){get(b[i].x)/M,get(b[i].x)%M};
FFT(a,n,1);FFT(b,n,1);
for(int i=0;i<n;i++)a1[i]=~a[i?n-i:i],b1[i]=~b[i?n-i:i];
for(int i=0;i<n;i++){
node p=a[i],q=a1[i],s=b[i],t=b1[i];
a[i]=(p+q)*(node){0.5,0};
b[i]=(s+t)*(node){0.5,0};
a1[i]=(q-p)*(node){0,0.5};
b1[i]=(t-s)*(node){0,0.5};
}
for(int i=0;i<n;i++)a[i]=a[i]*b[i]+(a[i]*b1[i]+b[i]*a1[i])*I,b[i]=a1[i]*b1[i];
FFT(a,n,-1);FFT(b,n,-1);
for(int i=0;i<n;i++)a[i].x=(get(a[i].x)%mod*M%mod*M+get(a[i].y)%mod*M+get(b[i].x))%mod;
}
void Inv(node *a,node *b,int n){
if(n==1){b[0].x=power(get(a[0].x),mod-2);return;}
int m=(n+1)>>1,lim=Glim(n+m);Inv(a,b,m);
for(int i=0;i<lim;i++)c[i]=i<n?a[i]:O,d[i]=i<m?b[i]:O;
getr(lim);MTT(c,d,lim);
for(int i=0;i<lim;i++)d[i]=i<m?b[i]:O;
MTT(c,d,lim);
for(int i=m;i<n;i++)b[i].x=-c[i].x;
}
int main(){
scanf("%d",&n);init(2*n);I=(node){0,1};O=(node){0,0};
for(int i=0;i<n;i++)a[i].x=read();
Inv(a,b,n);
for(int i=0;i<n;i++)printf("%d ",(get(b[i].x)+mod)%mod);
return 0;
}