1 条题解

  • 0
    @ 2022-12-9 7:53:48

    C++ :

    #pragma GCC optimize ("O2")
    #include<bits/stdc++.h>
    #include<tr1/unordered_map>
    // Shorthand macros used throughout this solution.
    #define pb push_back
    #define pii pair<int,int>
    // Pair of coefficient vectors; this is the return type of Solve().
    #define pvv pair<vector<int>,vector<int>>
    // Fx / Gx are aliases for the first / second member of a pair.
    #define Fx first
    #define Gx second
    // Container size converted to int. The argument is not parenthesized,
    // so it should only be given a simple expression.
    #define SZ(x) (int)x.size()
    // On Linux every getchar() call is redirected to getchar_unlocked().
    #ifdef __linux__
    #define getchar getchar_unlocked
    #endif
    #define mem(a,b) memset(a,b,sizeof(a))
    // Inclusive integer loops: For counts i upward from a to b, rFor counts
    // downward from a to b. The bound b is evaluated once, into a helper
    // variable whose name is i with an E appended.
    #define  For(i,a,b) for(int i=a,i##E=b;i<=i##E;++i)
    #define rFor(i,a,b) for(int i=a,i##E=b;i>=i##E;--i)
    typedef long long LL;
    using namespace std;
    // Length of the large global arrays (rev, a, fac, ifac, NTT buffers, w, iw).
    const int N=200010;
    // Dimension of the small tables S, A and B.
    const int M=110;
    const int inf=0x3f3f3f3f;
    // Modulus applied to all arithmetic in this file.
    const int mod=998244353;
    // Lowers a to b when b is strictly smaller; returns whether a was changed.
    template<typename T>inline bool chkmin(T &a,const T &b)
    {
    	if(a>b)
    	{
    		a=b;
    		return true;
    	}
    	return false;
    }
    // Raises a to b when b is strictly larger; returns whether a was changed.
    template<typename T>inline bool chkmax(T &a,const T &b)
    {
    	if(a<b)
    	{
    		a=b;
    		return true;
    	}
    	return false;
    }
    // Reads the next decimal integer from stdin into x.
    //
    // Every non-digit character before the number is skipped; if any of the
    // skipped characters is '-', the parsed value is negated. The first
    // character after the digit run is consumed as well.
    //
    // If end of input is reached before any digit is seen, x is left at 0
    // instead of looping forever. The character is kept as the int returned
    // by getchar() so that EOF is distinguishable and isdigit() never
    // receives a negative char value.
    template<typename T>inline void read(T &x)
    {
    	x=0;
    	bool negative=false;
    	int ch=getchar();
    	// Skip the separator run, remembering whether a minus sign was seen.
    	while(ch!=EOF&&!isdigit(ch))negative|=(ch=='-'),ch=getchar();
    	// Accumulate the digit run.
    	while(ch!=EOF&&isdigit(ch))x=x*10+(ch-'0'),ch=getchar();
    	if(negative)x=-x;
    }
    // Local-debugging hook: when the build defines ztzshiwo, stdin is reopened
    // on in.txt and stdout on out.txt. Without that define this does nothing.
    inline void file()
    {
    #if defined(ztzshiwo)
    	freopen("in.txt","r",stdin);
    	freopen("out.txt","w",stdout);
    #endif
    }
    // n and m are the first two integers read in main().
    int n,m;
    // Bit-reversal index table for the current transform size; filled by init_NTT().
    int rev[N];
    // S: table filled in init() with S[0][0]=1 and
    //    S[i][j]=j*S[i-1][j]+S[i-1][j-1] (mod mod), i.e. the recurrence of the
    //    Stirling numbers of the second kind.
    // A, B: coefficient tables derived in init() from rows m and 2*m of S.
    // a: the n input values read in main(), stored 1-indexed.
    int S[M][M],A[M],B[M],a[N];
    // Factorials and inverse factorials modulo mod; filled by init() up to index n.
    int fac[N],ifac[N];
    // LGx/LFx/RGx/RFx: scratch buffers that Solve() fills via Copy() and
    //    transforms in place with NTT().
    // w[i]: i-th power of qpow(3,mod/tot); iw[i]=w[tot-i]. NTT() uses w for the
    //    forward transform and iw for the inverse one.
    // tot: length of the w/iw tables, set to 1<<16 in init().
    int LGx[N],LFx[N],RGx[N],RFx[N],w[N],iw[N],tot;
    // In-place modular addition: x becomes (x+y) reduced by at most one
    // subtraction of mod. Presumably both operands are already in [0,mod];
    // callers in this file pass values in that range.
    inline void Add(int &x,const int&y)
    {
    	x+=y;
    	if(x>=mod)x-=mod;
    }
    // Binary exponentiation: returns a raised to the power b, modulo mod.
    // Products are widened to long long before reduction.
    inline int qpow(int a,int b)
    {
    	int result=1;
    	while(b)
    	{
    		if(b&1)result=1ll*result*a%mod;
    		a=1ll*a*a%mod;
    		b>>=1;
    	}
    	return result;
    }
    // Modular inverse of x, computed as qpow(x,mod-2); this relies on mod
    // being prime (Fermat's little theorem).
    // The expansion carries no trailing semicolon and parenthesizes its
    // argument, so Inv(...) can be used inside larger expressions; call sites
    // that end the statement with their own ';' are unaffected.
    #define Inv(x) qpow((x),mod-2)
    inline void init()
    {
    	fac[0]=ifac[0]=1;
    	For(i,1,n)fac[i]=1ll*fac[i-1]*i%mod;
    	ifac[n]=Inv(fac[n]);
    	rFor(i,n,2)ifac[i-1]=1ll*ifac[i]*i%mod;
    	S[0][0]=1;
    	For(i,1,m<<1)For(j,0,i)S[i][j]=(1ll*S[i-1][j]*j%mod+(j?S[i-1][j-1]:0))%mod;
    	For(i,0,m)A[i]=(1ll*S[m][i]*fac[i]%mod+1ll*S[m][i+1]*fac[i+1]%mod)*ifac[i]%mod;
    	int k=m<<1;
    	For(i,0,k)B[i]=(1ll*S[k][i]*fac[i]%mod+1ll*S[k][i+1]*fac[i+1]%mod)*ifac[i]%mod;
    	tot=1<<16;
    	LL ret=qpow(3,mod/tot);
    	w[0]=1;
    	For(i,1,tot)w[i]=1ll*w[i-1]*ret%mod;
    	For(i,0,tot)iw[i]=w[tot-i];
    }
    inline void Copy(int *x,const vector<int>&S,const int&Base)
    {
    	For(i,0,SZ(S)-1)x[i]=S[i];
    	//memset(x+SZ(S),0,(Base-SZ(S))*sizeof(int));
    	fill(x+SZ(S),x+Base,0);
    }
    inline int init_NTT(const int&n)
    {
    	int Base=1,len=0;
    	for(;Base<=n;Base<<=1)len++;
    	For(i,1,Base)rev[i]=(rev[i>>1]>>1)|(i&1)<<(len-1);
    	return Base;
    }
    inline void NTT(int *x,const int&Base,const int&flag)
    {
    	For(i,0,Base-1)if(i<rev[i])swap(x[i],x[rev[i]]);
    	for(int K=2;K<=Base;K<<=1)
    	{
    		int len=tot/K;
    		for(int M=0;M<Base;M+=K)
    		{
    			int *W=flag?iw:w,l,r,*a=x+M,*b=x+M+(K>>1);
    			For(i,0,(K>>1)-1)
    			{
    				l=*a,r=1ll**b**W%mod;
    				Add(*a=l,r),Add(*b=l,mod-r);
    				++a,++b,W+=len;
    			}
    		}
    	}
    	if(flag){int ret=Inv(Base);For(i,0,Base-1)x[i]=1ll*x[i]*ret%mod;}
    }
    // Divide and conquer over the input positions l..r.
    // Returns a pair (F,G) of coefficient vectors:
    //  - leaf l: F[i]=A[i]*a[l]^i for i<=m and G[i]=B[i]*a[l]^i for i<=2*m;
    //  - inner node: F = F_left*F_right and G = G_left*F_right + F_left*G_right,
    //    both products computed with NTT and truncated to at most n-1 terms.
    // The global buffers LFx/LGx/RFx/RGx are used as scratch space only after
    // both recursive calls have returned.
    inline pvv Solve(int l,int r)
    {
    	if(l==r)
    	{
    		const int top=m<<1;
    		pvv leaf;
    		leaf.first.resize(m+1);
    		leaf.second.resize(top+1);
    		int power=1; // a[l]^i
    		for(int i=0;i<=top;++i)
    		{
    			if(i<=m)leaf.first[i]=1ll*A[i]*power%mod;
    			leaf.second[i]=1ll*B[i]*power%mod;
    			power=1ll*power*a[l]%mod;
    		}
    		return leaf;
    	}

    	const int mid=(l+r)>>1;
    	pvv L=Solve(l,mid);
    	pvv R=Solve(mid+1,r);

    	const int lf=(int)L.first.size(),lg=(int)L.second.size();
    	const int rf=(int)R.first.size(),rg=(int)R.second.size();
    	const int lenF=lf+rf;
    	const int lenG=max(lf+rg,lg+rf);
    	const int Base=init_NTT(lenG);

    	// Load all four operands and move them to the transform domain.
    	Copy(LFx,L.first,Base);
    	Copy(LGx,L.second,Base);
    	Copy(RFx,R.first,Base);
    	Copy(RGx,R.second,Base);
    	NTT(LFx,Base,0);
    	NTT(LGx,Base,0);
    	NTT(RFx,Base,0);
    	NTT(RGx,Base,0);

    	// Pointwise: G <- GL*FR + FL*GR, then F <- FL*FR (F is updated last
    	// because the G formula still needs the old FL).
    	for(int i=0;i<Base;++i)
    	{
    		const int glfr=1ll*LGx[i]*RFx[i]%mod;
    		const int flgr=1ll*LFx[i]*RGx[i]%mod;
    		LGx[i]=glfr;
    		Add(LGx[i],flgr);
    		LFx[i]=1ll*LFx[i]*RFx[i]%mod;
    	}

    	NTT(LFx,Base,1);
    	NTT(LGx,Base,1);

    	pvv merged;
    	merged.first.assign(LFx,LFx+min(n-1,lenF));
    	merged.second.assign(LGx,LGx+min(n-1,lenG));
    	return merged;
    }
    int main()
    {
    	file();
    	read(n),read(m);
    	For(i,1,n)read(a[i]);
    	init();
    	pvv Ans=Solve(1,n);
    	vector<int>G=Ans.Gx;
    	int ans=0,ret=1,S=0;
    	For(i,1,n)Add(S,a[i]);
    	For(i,0,n-2)
    	{
    		Add(ans,1ll*ret*ifac[i]%mod*G[n-2-i]%mod);
    		ret=1ll*ret*S%mod;
    	}
    	For(i,1,n)ans=1ll*ans*a[i]%mod;
    	printf("%lld\n",1ll*ans*fac[n-2]%mod);
    	//cerr<<(double)clock()/CLOCKS_PER_SEC<<endl;
    	return 0;
    }
    
    • 1

    #2320. 「清华集训 2017」生成树计数

    信息

    ID
    156
    时间
    7000ms
    内存
    1024MiB
    难度
    10
    标签
    递交数
    1
    已通过
    0
    上传者