2016 多校第 2 场

1004 Differencia

题目链接

题目描述

有两个序列:$\displaystyle \big\lbrace a_1, a_2, \cdots, a_n \big\rbrace$,$\displaystyle \big\lbrace b_1, b_2, \cdots, b_n \big\rbrace$
有两种操作:

  • $+lr~x$ 将所有的 $a_i(l \leqslant i\leqslant r)$ 置为 $x$
  • $?lr$ 询问 $l\leqslant i\leqslant r$ 中有多少个 $i$ 满足 $a_i \geqslant b_i$

数据范围:$1\leqslant n\leqslant 10^5$,$3\times 10^6$ 次询问,强制在线。

题目简析

将 $B$ 数组建成归并树(用线段树实现即可),并预处理出初始的 $a_i \geqslant b_i$ 的前缀和,用该线段树维护。
那么,对于每次查询就是简单的线段树区间求和问题,复杂度为 $O(n\log n)$。
由于修改操作是将一个区间内所有的 $a_i$ 置为 $x$,区间修改将影响线段树中 $O(\log n)$ 个节点;对于每个节点,直接二分即可知道有多少这个区间内有多少个点满足 $b_i \leqslant x$ 了。
这么做的时间复杂度是 $O(n\log n + q\log^2 n)$ 的,遗憾的是,出题人只给 $O(n\log n)$ 以下的复杂度过。
如果预处理初每个节点所维护的区间中每个节点在左右子节点中的 $rank$,这个可以线扫,总复杂度为 $O(n\log n)$;不难发现,每次在节点所维护的区间内查找有多少个点小于等于 $x$ 操作仅需在根节点处二分一次,之后 $O(1)$ 转移,复杂度降为 $O((n+q)\log n)$。

程序实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include <cstdio>/*{{{*/
#include <cstring>
#include <iostream>
#include <algorithm>

inline int read() {
bool positive = true;
char c = getchar();
int s = 0;
for(; c < '0' || c > '9'; c=getchar())
if( c == '-' ) positive = false;
for(; c >= '0' && c <= '9'; c=getchar())
s = s*10 + c-'0';
return positive? s: -s;
}

namespace solve {
const int MAXN = 100000+10;
int A[MAXN], B[20][MAXN], L[20][MAXN], R[20][MAXN];
int sumv[MAXN<<2], setv[MAXN<<2], posv[MAXN<<2];

inline void build(int o, int lft, int rht, int cur=0) {
setv[o] = 0;
if( lft == rht ) {
B[cur][lft] = read();
sumv[o] = A[lft] >= B[cur][lft]? 1: 0;
} else {
int mid = lft+rht >> 1;
build(o<<1, lft, mid, cur+1);
build(o<<1|1, mid+1, rht, cur+1);

int tot = lft, i = lft, j = mid+1;
for(; i <= mid && j <= rht;)
if( B[cur+1][i] <= B[cur+1][j] ) B[cur][tot++] = B[cur+1][i++];
else B[cur][tot++] = B[cur+1][j++];
for(; i <= mid;) B[cur][tot++] = B[cur+1][i++];
for(; j <= rht;) B[cur][tot++] = B[cur+1][j++];
sumv[o] = sumv[o<<1] + sumv[o<<1|1];

// 计算区间 [lft,rht] 的每个节点在左右子节点中的 rank
L[cur][lft] = lft;
R[cur][lft] = mid+1;
for(int& l=L[cur][lft]; l <= mid && B[cur+1][l] <= B[cur][lft]; ++l);
for(int& r=R[cur][lft]; r <= rht && B[cur+1][r] <= B[cur][lft]; ++r);
--L[cur][lft];
--R[cur][lft];
for(int i=lft+1; i <= rht; ++i) {
L[cur][i] = L[cur][i-1]+1;
R[cur][i] = R[cur][i-1]+1;
for(int& l=L[cur][i]; l <= mid && B[cur+1][l] <= B[cur][i]; ++l);
for(int& r=R[cur][i]; r <= rht && B[cur+1][r] <= B[cur][i]; ++r);
--L[cur][i];
--R[cur][i];
}
}
}

inline void pushdown(int o, int lft, int rht, int cur) {
int lc = o<<1, rc = o<<1|1, mid = lft+rht>>1;
setv[lc] = setv[o]; posv[lc] = posv[o] >= lft? L[cur][posv[o]]: lft-1;
setv[rc] = setv[o]; posv[rc] = posv[o] >= lft? R[cur][posv[o]]: mid;
sumv[lc] = posv[lc]-lft+1;
sumv[rc] = posv[rc]-mid;
setv[o] = 0;
}

int ul, ur, uv;
inline void update(int o, int lft, int rht, int pos, int cur=0) {
if( lft == rht ) {
sumv[o] = uv >= B[cur][lft]? 1: 0;
return ;
}
if( ul <= lft && rht <= ur ) {
int mid = lft+rht >> 1;
setv[o] = uv;
posv[o] = pos;
sumv[o] = pos-lft+1;
} else {
if( setv[o] ) pushdown(o, lft, rht, cur);
int mid = lft+rht >> 1;
if( ul <= mid ) update(o<<1, lft, mid, pos >= lft? L[cur][pos]: lft-1, cur+1);
if( mid < ur ) update(o<<1|1, mid+1, rht, pos >= lft? R[cur][pos]: mid, cur+1);
sumv[o] = sumv[o<<1] + sumv[o<<1|1];
}
}

int ql, qr;
inline int query(int o, int lft, int rht, int cur=0) {
if( ql <= lft && rht <= qr ) return sumv[o];
if( setv[o] ) pushdown(o, lft, rht, cur);
int mid = lft+rht >> 1;
int ans = 0;
if( ql <= mid ) ans += query(o<<1, lft, mid, cur+1);
if( mid < qr ) ans += query(o<<1|1, mid+1, rht, cur+1);
return ans;
}
};

typedef long long LL;
const int MOD = 1000000000+7;
const int C = ~(1<<31);
const int M = (1<<16)-1;

int n, m, A, B, a, b, last;

inline int rnd(int last) {
a = (36969 + (last >> 3)) * (a & M) + (a >> 16);
b = (18000 + (last >> 3)) * (b & M) + (b >> 16);
return (C & ((a << 16) + b)) % 1000000000;
}

int main()
{
int T_T = read();
for(int kase=1; kase <= T_T; ++kase) {
scanf("%d%d%d%d", &n, &m, &A, &B);
for(int i=1; i <= n; ++i) solve:: A[i] = read();
solve:: build(1, 1, n);

LL ans = 0LL;
a = A, b = B, last = 0;
for(int i=1; i <= m; ++i) {
int l = rnd(last) % n + 1;
int r = rnd(last) % n + 1;
int x = rnd(last) + 1;
if( l > r ) std:: swap(l, r);
if( (l + r + x) & 1 ) {
solve:: ul = l;
solve:: ur = r;
solve:: uv = x;
int pos = std:: upper_bound(solve:: B[0]+1, solve:: B[0]+n+1, x)-solve:: B[0]-1;
solve:: update(1, 1, n, pos);
} else {
solve:: ql = l;
solve:: qr = r;
last = solve:: query(1, 1, n);
ans = (ans + (LL) i * last) % MOD;
}
}
printf("%d\n", ans);
}

return 0;
}/*}}}*/