0%

CF1408I Bitwise Magic

给定 \(n,k,c\),以及长度为 \(n\) 的序列 \(a\)(保证元素互不相同)。

操作 \(k\) 次,每次随机选择一个 \(a_i\),然后将其 \(−1\)

对于 \(x=0,1\dots2^c−1\) 输出最后序列的异或和为 \(x\) 的概率。

答案对 \(998244353\) 取模。

\(k,c\leqslant 16\)\(a_i\in[k,2^c)\)\(n\leqslant 2^c−k\)

首先有一个结论,对 \(x\in[k,2^c)\)\(k\) 元组 \((x\oplus (x-1),x\oplus(x-2),\dots,x\oplus(x-k))\) 只有 \(O(kc)\) 种本质不同的情况,当 \(k=c=16\) 时,有 \(192\) 种。

我们考虑证明。我们发现,\(x\oplus(x-1)\) 等于把 \(x\) 在二进制下最低位的 \(1\) 变为 \(0\),比它更低的位变为 \(1\)。我们设 \(t=\log \text{lowbit}(x)\)

  • \(t>\log k\) 时,我们发现 \(x\oplus(x-1),x\oplus(x-2),\dots,x\oplus(x-k)\) 这些数比第 \(t\) 位高的地方都是 \(0\),且比第 \(t\) 位低的地方与 \(x\) 无关(因为都是从全 \(1\) 开始减的),这一部分有 \(O(c)\) 种。
  • \(t\leqslant \log k\) 时,我们设 \(r\)\(x\)\(\lfloor\log k\rfloor\) 高的第一个为 \(1\) 的位,则 \(r\) 以上的位都为 \(0\),而 \(r\) 以下的位的情况只与后 \(\log k\) 位的情况有关,所以这一部分一共有 \(O(c2^{\log k})=O(ck)\) 种。

\(d_{i,j}=a_i\oplus(a_i-j)\),我们考虑枚举答案分别关于结果和操作次数的二元 EGF,则有 \[ F(x,y)=\prod\limits_{i=1}^n(\sum\limits_{j=0}^k \frac{x^{d_{i,j}}y^j}{j!}) \] 其中 \(x\) 维是异或卷积,\(y\) 维为加法卷积。则 \(q![x^p][y^q]F(x,y)\) 就是 \(q\) 次操作后结果为 \(p\oplus a_1\oplus a_2\oplus\dots\oplus a_n\) 的方案数。

我们现在考虑计算出这个多项式,我们先计算每种本质不同的 \(k\) 元组的出现次数,设 \(G_i(x,y)\) 为第 \(i\) 种本质不同的 \(k\) 元组的 EGF,\(r_i\) 为其出现次数,则 \(F(x,y)=\prod G_i^{r_i}(x,y)\),如果我们对每一种 \(k\) 元组都做 \(O(k^2)\) 的暴力 \(\ln,\exp\) 来多项式快速幂(牛顿迭代由于常数大可能会更慢),这样的总复杂度是 \(O(2^cck^3)\) 的。

我们发现,如果我们对某个 \(k\) 元组固定 \(y\),再对 \(x\) 做 FWT,此时做 FWT 的序列一定形如只有一项为 \(\dfrac{y^j}{j!}\),其余项都为 \(0\)

我们考虑 FWT 的定义式:\(\hat{a}_i=\sum\limits_j(-1)^{\text{popcount}(i\&j)}a_j\)。如果我们认为当前序列只有一个元素 \(a_j\)\(1\),则 \(\hat{a}_i=(-1)^{\text{popcount}(i\&j)}\)。所以,对于 \(y^j\) 来说,FWT 后任何一个 \(x\) 都只可能是 \(\pm\dfrac{y^j}{j!}\)

而对于一个 \(k\) 元组做完所有 FWT 后,其关于 \(x\) 的某一位可以表示为一个长为 \(k+1\)\(\pm 1\) 序列,序列的每一位表示 \(\dfrac{y^j}{j!}\) 的正负。我们注意到,如果两个 \(x\) 相同的多项式对应的 \(\pm1\) 序列也相同,则我们只需要把它们的指数加到一起然后一起做快速幂,这样能够显著减少 \(\ln,\exp\) 次数。

代码(卡着时限过的):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include<bits/stdc++.h>
#define pb push_back
#define count __builtin_popcount
using namespace std;
int const p=998244353;
typedef vector<int> vec;
int pw(int x,int y)
{
int res=1;
while(y)
{
if(y&1)res=1ll*res*x%p;
x=1ll*x*x%p;
y>>=1;
}
return res;
}
int a[65537],iv[65537],b[65537],sum[195],inv[65537],cnt[131075],
st[195],top,tmp[20],tmp2[20],tmp3[20],f[65537];
long long tmp4[20];
vec v[65537];
map<vec,int>mp;
void getexp(int *f,int *g,int n)//x^0-x^{n-1}
{
g[0]=1;
for(int i=1;i<n;i++)
{
g[i]=0;f[i]=1ll*f[i]*i%p;
for(int j=0;j<i;j++)
g[i]=(g[i]+1ll*f[j+1]*g[i-j-1])%p;
g[i]=1ll*g[i]*iv[i]%p;
}
}
void getln(int *f,int *g,int n)
{
g[0]=0;
for(int i=1;i<n;i++)
{
g[i]=0;
for(int j=1;j<i;j++)
g[i]=(g[i]+1ll*f[j]*g[i-j])%p;
g[i]=(1ll*i*f[i]+p-g[i])%p;
}
for(int i=1;i<n;i++)g[i]=1ll*g[i]*iv[i]%p;
}
int mod(int x){return x>=p?x-p:x;}
int main()
{
iv[1]=inv[0]=inv[1]=1;
for(int i=2;i<65536;i++)iv[i]=1ll*(p-p/i)*iv[p%i]%p,inv[i]=1ll*inv[i-1]*iv[i]%p;
int n,K,c,xorsum=0;
scanf("%d%d%d",&n,&K,&c);
for(int i=1;i<=n;i++)scanf("%d",&a[i]),xorsum^=a[i];
for(int i=1;i<=n;i++)
{
for(int j=0;j<=K;j++)
v[i].pb(a[i]^(a[i]-j));
mp[v[i]]++;
}
int t=0,fac=1;
for(auto r:mp)t++,v[t]=r.first,sum[t]=r.second;
for(int i=2;i<=K;i++)fac=1ll*fac*i%p;
for(int s=0;s<(1<<c);s++)
{
for(int i=1;i<=t;i++)
{
int d=0;
for(int j=0;j<=K;j++)
d|=((count(s&v[i][j])&1)<<j);
if(!cnt[d])st[++top]=d;
cnt[d]+=sum[i];
}
memset(tmp,0,sizeof(tmp));
tmp[0]=1;
for(int i=1;i<=top;i++)
{
for(int j=0;j<=K;j++)
if(st[i]&(1<<j))tmp2[j]=p-inv[j];
else tmp2[j]=inv[j];
getln(tmp2,tmp3,K+1);
for(int j=0;j<=K;j++)tmp3[j]=1ll*tmp3[j]*cnt[st[i]]%p;
getexp(tmp3,tmp2,K+1);
memset(tmp4,0,sizeof(tmp4));
for(int j=0;j<=K;j++)
for(int k=0;k<=K-j;k++)
tmp4[j+k]+=1ll*tmp2[j]*tmp[k];
for(int j=0;j<=K;j++)tmp[j]=tmp4[j]%p;
cnt[st[i]]=0;
}
top=0;
f[s]=1ll*tmp[K]*fac%p;
}
for(int len=2;len<=(1<<c);len<<=1)
for(int i=0;i<(1<<c);i+=len)
for(int j=i;j<i+(len>>1);j++)
{
int t=f[j];
f[j]=mod(f[j]+f[j+(len>>1)]);
f[j+(len>>1)]=mod(t-f[j+(len>>1)]+p);
}
int d=1ll*pw(1<<c,p-2)*pw(n,p-1-K)%p;
for(int i=0;i<(1<<c);i++)printf("%lld ",1ll*f[i^xorsum]*d%p);
return 0;
}