【XSY1301】原题的价值 第二类斯特林数 NTT
题目描述
给你\(n,m\),求所有\(n\)个点的简单无向图中每个点度数的\(m\)次方的和。
\(n\leq {10}^9,m\leq {10}^5\)
题解
\(g_n\)为\(n\)个点的无向图个数,\(f_n\)为\(n\)个点的答案。
\[\begin{align}g_n&=2^{\binom{n}{2}}\\f_n&=ng_{n-1}\sum_{i=0}^{n-1}\binom{n-1}{i}i^m\\&=ng_{n-1}\sum_{i=0}^{n-1}\binom{n-1}{i}\sum_{j=0}^{i}\binom{i}{j}S(m,j)j!\\&=ng_{n-1}\sum_{i=0}^{n-1}\sum_{j=0}^i\binom{n-1}{i}\binom{i}{j}S(m,j)j!\\&=ng_{n-1}\sum_{i=0}^{n-1}\sum_{j=0}^i\binom{n-j}{j}\binom{n-1-i}{i-j}S(m,j)j!\\&=ng_{n-1}\sum_{j=0}^m\binom{n-1}{j}S(m,j)j!\sum_{i=j}^{n-1}\binom{n-1-j}{i-j}\\&=ng_{n-1}\sum_{j=0}^m{(n-1)}^\underline{j}S(m,j)2^{n-1-j}\\\end{align}\]
用ntt算斯特林数
时间复杂度:\(O(m\log m)\)
代码
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<ctime> #include<utility> using namespace std; typedef long long ll; typedef pair<int,int> pii; ll p=998244353; ll fp(ll a,ll b) { ll s=1; while(b) { if(b&1) s=s*a%p; a=a*a%p; b>>=1; } return s; } ll fc[300010]; ll ifc[300010]; ll a[300010]; ll b[300010]; int rev[300010]; void ntt(ll *a,int n,int t) { ll u,v,w,wn; int i,j,k; rev[0]=0; for(i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0); for(i=0;i<n;i++) if(rev[i]<i) swap(a[rev[i]],a[i]); for(i=2;i<=n;i<<=1) { if(t==1) wn=fp(3,(p-1)/i); else wn=fp(fp(3,(p-1)/i),p-2); for(j=0;j<n;j+=i) { w=1; for(k=j;k<j+i/2;k++) { u=a[k]; v=a[k+i/2]*w%p; a[k]=(u+v)%p; a[k+i/2]=(u-v)%p; w=w*wn%p; } } } if(t==-1) { ll inv=fp(n,p-2); for(i=0;i<n;i++) a[i]=a[i]*inv%p; } } ll c[300010]; int main() { // freopen("b.in","r",stdin); // freopen("b.out","w",stdout); int n,m; scanf("%d%d",&n,&m); fc[0]=fc[1]=ifc[0]=ifc[1]=1; int i; int t=min(n-1,m); for(i=2;i<=t;i++) { fc[i]=fc[i-1]*i%p; ifc[i]=ifc[i-1]*fp(i,p-2)%p; } for(i=0;i<=t;i++) { a[i]=(i&1?-1:1)*ifc[i]; b[i]=fp(i,m)*ifc[i]%p; } int k=1; while(k<=2*t) k<<=1; ntt(a,k,1); ntt(b,k,1); for(i=0;i<k;i++) a[i]=a[i]*b[i]%p; ntt(a,k,-1); for(i=0;i<k;i++) a[i]=(a[i]%p+p)%p; ll ans=0; c[0]=1; for(i=1;i<=t;i++) c[i]=c[i-1]*(n-i)%p; for(i=0;i<=t;i++) ans=(ans+c[i]%p*a[i]%p*fp(2,n-1-i)%p)%p; ans=ans*n%p*fp(2,ll(n-1)*(n-2)/2%(p-1))%p; printf("%lld\n",ans); return 0; }