P15093 [UOI 2025 II Stage] Odd Rows

思路

首先考虑求答案的 subtask。
我们考虑维护四个值:$jmx,jmi,omx,omi$ 分别表示奇数和偶数的最大最小可能取到的值,$cnt$ 表示本次要填入多少个 $1$,发现 $jmx+omi=n,jmi+omx=n$,所以更新时只需要更新两个最大值。
对于 $jmx$ 有三种情况:

  • $omi\ge cnt$:此时一定是将所有 $cnt$ 个 $1$ 都放在原先是偶数的位置最多,即更新为 $jmx+cnt$ 个。
  • $omx\le cnt$:此时无论怎么选,都一定会有一部分奇数被转为偶数,而所有偶数都被转为奇数,所以新的奇数最大值就是总数减去被转换的奇数数量,要使得这个东西最大,只需使得被转换的奇数最少,即偶数最多,取 $omx$,更新为 $omx+jmi-(cnt-omx)$。
  • $omi\le cnt\le omx$:这是最重要的一种情况,同时需要用到本题的关键结论:一种数一定可以取到上下界之间所有与上界(下界)奇偶性相同的所有数,也就是 $o\in [omi,omx],o\equiv omi(omx)\pmod 2$。证明下文给出,此时我们要想取到最大值,只需判断 $omi+cnt\equiv n\pmod 2$ 是否成立,若成立取 $n$,否则取 $n-1$。

证明

我们已经求出了 $omi,omx$ 以及它们的奇偶性,假设 $omi\ne omx$,考虑如何得到 $omi+2$。
易得到只需要将一个原本填在奇数行的 $1$ 挪到另一个奇数行,此时就会增加两个偶数,若不存在这么多奇数行,则也一定会超出上界。
换言之,可以通过将一些上次填入 $1$ 的奇数行含有的 $1$ 挪动到 同样数量没有填入 $1$ 的奇数行即可。

现在就可以得到一个函数,用来求出答案。

1
2
3
4
5
6
int work(int jmx,int jmi,int omx,int omi,int cnt){
if(cnt<=omi)return jmx+cnt;
if(cnt>=omx)return omx+jmi-(cnt-omx);
if(omx%2==cnt%2)return n;
else return n-1;
}

到这里,你就获得了 50pts。

由于 $n \cdot m \le 10^6$,我们可以使用 $O(nm)$ 的 DP 并利用差分数组进行优化。
设 $dp_{i,j}$ 表示前 $i$ 列能否构造出 $j$ 个奇数行。初始状态 $dp_{0,0} = 1$(第一列手动处理)。
对于第 $i$ 列,如果 $dp_{i-1, k} = 1$,那么它可以转移到的范围是一个区间 $[L, R]$,并且与 $L$ 的奇偶性相同。我们可以利用差分数组 $ch$ 来优化这个区间加操作。
为了构造矩阵,我们还需要记录 $pre_{i,j}$ 表示 $dp_{i,j}$ 是从哪个 $k$ 转移来的。为了节省空间,我们在 DP 时不记录 $pre$,而是在求出最终答案后,从最后一列倒推回去计算 $pre$ 并求出每一列的 $tar$ 值。

然后考虑构造的 subtask。
这个状态比较复杂,要得到一个既定答案的同时还要满足初始值,难以贪心地求出,所以考虑 DP。

至此,我们就得到了所有的 $tar_i$,可以通过 $tar_i-tar_{i-1}$ 求得增加或减少的奇数数量,若不等于 $cnt_i$ 则是因为一部分增加或减少的奇数行被另一部分减少或增加的偶数行抵消掉了,这两个值都是好求的。

1
2
3
4
5
6
7
rep(i,1,m){
int prv=(i==1)?0:tar[i-1];
int cur=tar[i];
int x=(cur-prv+a[i].cnt)/2;
int y=a[i].cnt-x;
update(i,x,y);
}

update 函数该如何写呢?只需定义两个双端队列,分别存此时是奇数和偶数的值,每次从队头弹出一些值到另一个的队尾,因为 $cnt_i\le n$ 所以也不会出现一个数被两个队列互相弹的情况,将每次弹的值的位置置为 $1$ 即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#include<bits/stdc++.h>
#define fir first
#define sec second
#define int long long
#define pii pair<int,int>
#define fep(i,s,e) for(int i=s;i<e;i++)
#define pef(i,s,e) for(int i=s;i>e;i--)
#define rep(i,s,e) for(int i=s;i<=e;i++)
#define per(i,s,e) for(int i=s;i>=e;i--)
namespace FastIO{
template<typename T>inline void read(T &x){
x=0;int f=1;char c=getchar();
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48);x*=f;
}
template<typename T,typename...Args>
inline void read(T &x,Args&...args){
read(x);
read(args...);
}
template<typename T>void print(T x){
if(x<0)x=-x,putchar('-');
if(x>9)print(x/10);
putchar((x%10)^48);
}
}
using namespace std;
using namespace FastIO;
const int N=5e6+5;
struct node{
int cnt,id;
vector<int>v;
}a[N];
int n,m,tar[N],mx[N],mi[N];
int work(int jmx,int jmi,int omx,int omi,int cnt){
if(cnt<=omi)return jmx+cnt;
if(cnt>=omx)return omx+jmi-(cnt-omx);
if(omx%2==cnt%2)return n;
else return n-1;
}
deque<int>J,O;
void update(int M,int ou,int ji){
rep(i,1,ou){
J.push_back(O.front());
a[M].v[O.front()]=1;
O.pop_front();
}
rep(i,1,ji){
O.push_back(J.front());
a[M].v[J.front()]=1;
J.pop_front();
}
}
void solve(){
read(n,m);
rep(i,1,m)read(a[i].cnt),a[i].id=i;
int jmx=a[1].cnt,jmi=a[1].cnt;
int omx=n-a[1].cnt,omi=n-a[1].cnt;
rep(i,1,m)a[i].v.resize(n+1,0);
mx[1]=mi[1]=a[1].cnt;
rep(i,2,m){
int j1,o1;
j1=work(jmx,jmi,omx,omi,a[i].cnt);
o1=work(omx,omi,jmx,jmi,a[i].cnt);
jmx=j1;omx=o1;jmi=n-omx;omi=n-jmx;
mx[i]=jmx;mi[i]=jmi;
}
print(jmx);puts("");
vector<vector<bool>>dp(m+1,vector<bool>(n+1,false));
vector<vector<int>>pre(m+1,vector<int>(n+1,-1));
vector<int>ch(n+2,0);
vector<vector<pii>>add(m+1),del(m+1);
dp[0][0]=true;
rep(i,1,m){
fill(ch.begin(),ch.end(),0);
rep(k,0,n){
if(!dp[i-1][k])continue;
int c=a[i].cnt;
int L=max({0ll,(k+c-n+1)/2,c-n+k});
int R=min(k,c);
if(L<=R){
int left=k+c-2*R;
int right=k+c-2*L;
if(left<=right){
ch[left]++;
if(right+2<=n)ch[right+2]--;
}
}
}
int sum=0;
for(int j=0;j<=n;j+=2){
sum+=ch[j];
if(sum>0)dp[i][j]=true;
}
sum=0;
for(int j=1;j<=n;j+=2){
sum+=ch[j];
if(sum>0)dp[i][j]=true;
}
}
int w=0;
per(j,n,0){
if(dp[m][j]){
w=j;
break;
}
}
int nw=w;
per(i,m,1){
rep(k,0,n){
if(!dp[i-1][k])continue;
int c=a[i].cnt;
int L=max({0LL,(k+c-n+1)/2,c-n+k});
int R=min(k,c);
if(L<=R){
int left=k+c-2*R;
int right=k+c-2*L;
if(left<=nw&&nw<=right&&((nw-left)%2==0)){
pre[i][nw]=k;
nw=k;
break;
}
}
}
}
tar[m]=w;
per(i,m-1,1)tar[i]=pre[i+1][tar[i+1]];
tar[1]=a[1].cnt;
rep(i,1,n)O.push_back(i);
rep(i,1,m){
int prv=(i==1)?0:tar[i-1];
int cur=tar[i];
int x=(cur-prv+a[i].cnt)/2;
int y=a[i].cnt-x;
update(i,x,y);
}
rep(i,1,n){
rep(j,1,m){
print(a[j].v[i]);
putchar(' ');
}puts("");
}
}
int T;
signed main(){
solve();
}