BZOJ3509

这是蒟蒻第一次独立用多项式算法解决问题qwq

题意很简单,给你一个序列,问你其中形成等差序列的3元子序列有多少。

一个很直接的想法是固定中间一个元素aja_j之后看看左右有多少两边有多少对满足和等于2aj2a_j。这样可以想到从左到右枚举j,同时维护j左边的桶和j右边的桶,
每次把这两个桶卷起来。可这样是O(n2logn)O(n^2logn)的,无法通过。

考虑分块,先把至少有两个元素在同一个块内的答案处理完后,就可以只优化刚才那个过程,从左到右枚举块,只用做块个数次FFT。同一个块内的答案很好计算,所以就做完啦
需要调调块大小才能过,块个数要少点

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int N=2e5+50,L=30001;
const long double Pi=acos(-1);
int n,sq,bel[N],c[N],p[N],lim=1,len=-1,num1[N],num2[N];LL ans;
struct node{
    long double x,y;
    node operator +(const node &b){return (node){x+b.x,y+b.y};}
    node operator -(const node &b){return (node){x-b.x,y-b.y};}
    node operator *(const node &b){return (node){x*b.x-y*b.y,x*b.y+y*b.x};}
}a[N],b[N];
int read(){
    int x=0,c;
    while(!isdigit(c=getchar()));
    while(isdigit(c))x=x*10+(c^48),c=getchar();
    return x;
}
void FFT(node *a,int opt){
    for(int i=0;i<lim;i++)if(i<p[i])swap(a[i],a[p[i]]);
    for(int i=1;i<lim;i<<=1){
        node Wn=(node){cos(Pi/i),opt*sin(Pi/i)};
        for(int t=i<<1,j=0;j<lim;j+=t){
            node w=(node){1,0};
            for(int k=0;k<i;k++,w=w*Wn){
                node x=a[j+k],y=a[i+j+k]*w;
                a[j+k]=x+y;a[i+j+k]=x-y;
            }
        }
    }
    if(opt<0)for(int i=0;i<lim;i++)a[i].x/=lim;
}
int main(){
    while(lim<L*2)lim<<=1,len++;
    for(int i=1;i<lim;i++)p[i]=p[i>>1]>>1^(i&1)<<len;
    n=read();sq=sqrt(n)*3;
    for(int i=1;i<=n;i++)num1[c[i]=read()]++;
    for(int i=1;i<=n;i++)bel[i]=(i-1)/sq+1;
    for(int i=1,j;i<=n;i++){
        j=i;
        while(j<n&&bel[j+1]==bel[i])j++;
        for(int k=i;k<=j;k++){
            num1[c[k]]--;
            for(int l=i;l<k;l++)
                if(2*c[k]>=c[l])ans+=num1[2*c[k]-c[l]];
            for(int l=j;l>k;l--)
                if(2*c[k]>=c[l])ans+=num2[2*c[k]-c[l]];
        }
        for(int k=i;k<=j;k++)num2[c[k]]++;
        i=j;
    }
    for(int i=1,j;i<=n;i++){
        j=i;num2[c[j]]--;
        while(j<n&&bel[j+1]==bel[j])num2[c[++j]]--;
        for(int k=0;k<lim;k++)a[k]=(node){num1[k],0},b[k]=(node){num2[k],0};
        FFT(a,1);FFT(b,1);
        for(int k=0;k<lim;k++)a[k]=a[k]*b[k];
        FFT(a,-1);
        for(int k=i;k<=j;k++)ans+=(LL)(a[c[k]*2].x+0.5),num1[c[k]]++;
        i=j;
    }
    printf("%lld\n",ans);
}