给定 \(n,k,c\),以及长度为 \(n\) 的序列 \(a\)(保证元素互不相同)。
操作 \(k\) 次,每次随机选择一个 \(a_i\),然后将其 \(−1\)。
对于 \(x=0,1\dots2^c−1\) 输出最后序列的异或和为 \(x\) 的概率。
答案对 \(998244353\) 取模。
\(k,c\leqslant 16\),\(a_i\in[k,2^c)\),\(n\leqslant 2^c−k\)。
首先有一个结论,对 \(x\in[k,2^c)\),\(k\) 元组 \((x\oplus (x-1),x\oplus(x-2),\dots,x\oplus(x-k))\) 只有 \(O(kc)\) 种本质不同的情况,当 \(k=c=16\) 时,有 \(192\) 种。
我们考虑证明。我们发现,\(x\oplus(x-1)\) 等于把 \(x\) 在二进制下最低位的 \(1\) 变为 \(0\),比它更低的位变为 \(1\)。我们设 \(t=\log \text{lowbit}(x)\)。
- 当 \(t>\log k\) 时,我们发现 \(x\oplus(x-1),x\oplus(x-2),\dots,x\oplus(x-k)\) 这些数比第 \(t\) 位高的地方都是 \(0\),且比第 \(t\) 位低的地方与 \(x\) 无关(因为都是从全 \(1\) 开始减的),这一部分有 \(O(c)\) 种。
- 当 \(t\leqslant \log k\) 时,我们设 \(r\) 为 \(x\) 比 \(\lfloor\log k\rfloor\) 高的第一个为 \(1\) 的位,则 \(r\) 以上的位都为 \(0\),而 \(r\) 以下的位的情况只与后 \(\log k\) 位的情况有关,所以这一部分一共有 \(O(c2^{\log k})=O(ck)\) 种。
设 \(d_{i,j}=a_i\oplus(a_i-j)\),我们考虑枚举答案分别关于结果和操作次数的二元 EGF,则有 \[
F(x,y)=\prod\limits_{i=1}^n(\sum\limits_{j=0}^k \frac{x^{d_{i,j}}y^j}{j!})
\] 其中 \(x\) 维是异或卷积,\(y\) 维为加法卷积。则 \(q![x^p][y^q]F(x,y)\) 就是 \(q\) 次操作后结果为 \(p\oplus a_1\oplus a_2\oplus\dots\oplus a_n\) 的方案数。
我们现在考虑计算出这个多项式,我们先计算每种本质不同的 \(k\) 元组的出现次数,设 \(G_i(x,y)\) 为第 \(i\) 种本质不同的 \(k\) 元组的 EGF,\(r_i\) 为其出现次数,则 \(F(x,y)=\prod G_i^{r_i}(x,y)\),如果我们对每一种 \(k\) 元组都做 \(O(k^2)\) 的暴力 \(\ln,\exp\) 来多项式快速幂(牛顿迭代由于常数大可能会更慢),这样的总复杂度是 \(O(2^cck^3)\) 的。
我们发现,如果我们对某个 \(k\) 元组固定 \(y\),再对 \(x\) 做 FWT,此时做 FWT 的序列一定形如只有一项为 \(\dfrac{y^j}{j!}\),其余项都为 \(0\)。
我们考虑 FWT 的定义式:\(\hat{a}_i=\sum\limits_j(-1)^{\text{popcount}(i\&j)}a_j\)。如果我们认为当前序列只有一个元素 \(a_j\) 是 \(1\),则 \(\hat{a}_i=(-1)^{\text{popcount}(i\&j)}\)。所以,对于 \(y^j\) 来说,FWT 后任何一个 \(x\) 都只可能是 \(\pm\dfrac{y^j}{j!}\)。
而对于一个 \(k\) 元组做完所有 FWT 后,其关于 \(x\) 的某一位可以表示为一个长为 \(k+1\) 的 \(\pm 1\) 序列,序列的每一位表示 \(\dfrac{y^j}{j!}\) 的正负。我们注意到,如果两个 \(x\) 相同的多项式对应的 \(\pm1\) 序列也相同,则我们只需要把它们的指数加到一起然后一起做快速幂,这样能够显著减少 \(\ln,\exp\) 次数。
代码(卡着时限过的):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
| #include<bits/stdc++.h> #define pb push_back #define count __builtin_popcount using namespace std; int const p=998244353; typedef vector<int> vec; int pw(int x,int y) { int res=1; while(y) { if(y&1)res=1ll*res*x%p; x=1ll*x*x%p; y>>=1; } return res; } int a[65537],iv[65537],b[65537],sum[195],inv[65537],cnt[131075], st[195],top,tmp[20],tmp2[20],tmp3[20],f[65537]; long long tmp4[20]; vec v[65537]; map<vec,int>mp; void getexp(int *f,int *g,int n) { g[0]=1; for(int i=1;i<n;i++) { g[i]=0;f[i]=1ll*f[i]*i%p; for(int j=0;j<i;j++) g[i]=(g[i]+1ll*f[j+1]*g[i-j-1])%p; g[i]=1ll*g[i]*iv[i]%p; } } void getln(int *f,int *g,int n) { g[0]=0; for(int i=1;i<n;i++) { g[i]=0; for(int j=1;j<i;j++) g[i]=(g[i]+1ll*f[j]*g[i-j])%p; g[i]=(1ll*i*f[i]+p-g[i])%p; } for(int i=1;i<n;i++)g[i]=1ll*g[i]*iv[i]%p; } int mod(int x){return x>=p?x-p:x;} int main() { iv[1]=inv[0]=inv[1]=1; for(int i=2;i<65536;i++)iv[i]=1ll*(p-p/i)*iv[p%i]%p,inv[i]=1ll*inv[i-1]*iv[i]%p; int n,K,c,xorsum=0; scanf("%d%d%d",&n,&K,&c); for(int i=1;i<=n;i++)scanf("%d",&a[i]),xorsum^=a[i]; for(int i=1;i<=n;i++) { for(int j=0;j<=K;j++) v[i].pb(a[i]^(a[i]-j)); mp[v[i]]++; } int t=0,fac=1; for(auto r:mp)t++,v[t]=r.first,sum[t]=r.second; for(int i=2;i<=K;i++)fac=1ll*fac*i%p; for(int s=0;s<(1<<c);s++) { for(int i=1;i<=t;i++) { int d=0; for(int j=0;j<=K;j++) d|=((count(s&v[i][j])&1)<<j); if(!cnt[d])st[++top]=d; cnt[d]+=sum[i]; } memset(tmp,0,sizeof(tmp)); tmp[0]=1; for(int i=1;i<=top;i++) { for(int j=0;j<=K;j++) if(st[i]&(1<<j))tmp2[j]=p-inv[j]; else tmp2[j]=inv[j]; getln(tmp2,tmp3,K+1); for(int j=0;j<=K;j++)tmp3[j]=1ll*tmp3[j]*cnt[st[i]]%p; getexp(tmp3,tmp2,K+1); memset(tmp4,0,sizeof(tmp4)); for(int j=0;j<=K;j++) for(int k=0;k<=K-j;k++) tmp4[j+k]+=1ll*tmp2[j]*tmp[k]; for(int j=0;j<=K;j++)tmp[j]=tmp4[j]%p; cnt[st[i]]=0; } top=0; f[s]=1ll*tmp[K]*fac%p; } for(int len=2;len<=(1<<c);len<<=1) for(int i=0;i<(1<<c);i+=len) for(int j=i;j<i+(len>>1);j++) { int t=f[j]; f[j]=mod(f[j]+f[j+(len>>1)]); f[j+(len>>1)]=mod(t-f[j+(len>>1)]+p); } int d=1ll*pw(1<<c,p-2)*pw(n,p-1-K)%p; for(int i=0;i<(1<<c);i++)printf("%lld ",1ll*f[i^xorsum]*d%p); return 0; }
|