Submission #14654208
Source Code Expand
Copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
"""
Numba-compatible RBST (randomized binary search tree).
Test data: AtCoder ABC170E.
"""
from collections import defaultdict
from heapq import heappush, heappop
import sys
import numpy as np
def main(N, Q, data):
    """Answer every query of AtCoder ABC170E ("Smart Infants").

    The current maximum rate of every non-empty kindergarten is stored in an
    RBST used as a sorted multiset. Each kindergarten also keeps a max-heap
    of its members (rates negated; stale entries are discarded lazily).
    After each transfer, the smallest of those maxima (the leftmost tree
    value) is recorded.

    Parameters
    ----------
    N : int
        Number of infants.
    Q : int
        Number of transfer queries.
    data : 1-D int64 array of length 2 * (N + Q)
        ``A_1 B_1 ... A_N B_N`` (rate and kindergarten of infant i, infants
        are 1-origin) followed by ``C_1 D_1 ... C_Q D_Q`` (infant C_j moves
        to kindergarten D_j). Kindergarten ids must not exceed ``MAX_K``.

    Returns
    -------
    numpy.ndarray
        The Q answers, one per query.
    """
    # --- RBST implementation
    INF = 10 ** 9 + 1
    SUM_UNITY = 0
    random_state = np.array([123456789, 362436069, 521288629, 88675123])
    # Node 0 is the null node and erase() never recycles nodes, so every
    # insert() consumes a fresh index. Inserts are bounded by one per
    # initially non-empty kindergarten (<= N) plus at most two per query
    # (a successor in the source, a new maximum in the destination).
    # A fixed 5 * 10 ** 5 pool overflows when N = Q = 2 * 10 ** 5.
    MAX_NODES = N + 2 * Q + 1
    values = np.repeat(SUM_UNITY, MAX_NODES)
    sizes = np.zeros(MAX_NODES, dtype=np.int32)
    sums = np.repeat(SUM_UNITY, MAX_NODES)
    lefts = np.zeros(MAX_NODES, dtype=np.int32)
    rights = np.zeros(MAX_NODES, dtype=np.int32)
    node_id = 1
    ret_left = 0
    ret_right = 0
    root = 0

    def randInt():
        # xorshift step over the 4-word state held in `random_state`.
        tx, ty, tz, tw = random_state
        tt = tx ^ (tx << 11)
        random_state[0] = ty
        random_state[1] = tz
        random_state[2] = tw
        random_state[3] = tw = (tw ^ (tw >> 19)) ^ (tt ^ (tt >> 8))
        return tw

    def create_node(v):
        nonlocal node_id
        i = node_id
        values[i] = v
        sizes[i] = 1
        sums[i] = v
        lefts[i] = 0
        rights[i] = 0
        node_id += 1
        return i

    def update(node):
        # Recompute the subtree aggregates of `node` from its children.
        sizes[node] = sizes[lefts[node]] + sizes[rights[node]] + 1
        sums[node] = sums[lefts[node]] + sums[rights[node]] + values[node]
        # add extra code here
        return node

    def push(node):
        # Hook for lazy propagation; currently a no-op.
        if not node:
            return
        # add extra code here

    def lower_bound(node, val):
        # Number of stored values strictly less than `val`.
        ret = 0
        while True:
            push(node)
            if not node:
                return ret
            if val <= values[node]:
                node = lefts[node]
            else:
                ret += sizes[lefts[node]] + 1
                node = rights[node]

    def upper_bound(node, val):
        # Number of stored values less than or equal to `val`.
        ret = 0
        while True:
            push(node)
            if not node:
                return ret
            if val >= values[node]:
                ret += sizes[lefts[node]] + 1
                node = rights[node]
            else:
                node = lefts[node]

    def merge(left, right):
        # Iterative RBST merge: walk down recording the path, then rebuild
        # bottom-up (avoids recursion).
        is_left = []
        left_snapshot = []
        right_snapshot = []
        ret = 0
        while True:
            push(left)
            push(right)
            if not left or not right:
                if left:
                    ret = left
                else:
                    ret = right
                break
            if randInt() % (sizes[left] + sizes[right]) < sizes[left]:
                is_left.append(True)
                left_snapshot.append(left)
                right_snapshot.append(right)
                left = rights[left]
            else:
                is_left.append(False)
                left_snapshot.append(left)
                right_snapshot.append(right)
                right = lefts[right]
        for i in range(len(is_left) - 1, -1, -1):
            x = is_left[i]
            left = left_snapshot[i]
            right = right_snapshot[i]
            if x:
                rights[left] = ret
                ret = update(left)
            else:
                lefts[right] = ret
                ret = update(right)
        return ret

    def split(node, k):
        "split tree into [0, k) and [k, n)"
        # Results are returned through `ret_left` / `ret_right`.
        nonlocal ret_left, ret_right
        is_left = []
        node_snapshot = []
        while True:
            push(node)
            if not node:
                ret_left = 0
                ret_right = 0
                break
            if k <= sizes[lefts[node]]:
                is_left.append(True)
                node_snapshot.append(node)
                node = lefts[node]
            else:
                is_left.append(False)
                node_snapshot.append(node)
                k -= sizes[lefts[node]] + 1
                node = rights[node]
        for i in range(len(is_left) - 1, -1, -1):
            x = is_left[i]
            node = node_snapshot[i]
            if x:
                lefts[node] = ret_right
                ret_right = update(node)
            else:
                rights[node] = ret_left
                ret_left = update(node)

    def count(val):
        return upper_bound(root, val) - lower_bound(root, val)

    def insert(val):
        nonlocal root, ret_left, ret_right
        split(root, lower_bound(root, val))
        r = merge(ret_left, create_node(val))
        r = merge(r, ret_right)
        root = r

    def erase(val):
        # Remove one occurrence of `val`; its node is not reused.
        nonlocal root, ret_left, ret_right
        if count(val) == 0:
            return  # erasing absent item
        split(root, lower_bound(root, val))
        lhs = ret_left
        split(ret_right, 1)
        rhs = ret_right
        root = merge(lhs, rhs)
    # --- end RBST implementation

    # --- ABC170E implementation
    # k: kindergarten, p: person
    p_to_rate = [0] * (N + 1)  # 1-origin
    p_to_k = [0] * (N + 1)  # 1-origin
    # One max-heap per kindergarten, entries (-rate, person).
    MAX_K = 200000
    # Each list starts with a dummy (-INF, 0) tuple that is popped at once.
    # Presumably this lets numba infer the element type of the otherwise
    # empty lists; TODO confirm before simplifying.
    k_to_ps = [[(-INF, 0)]]
    for i in range(MAX_K):
        x = [(-INF, 0)]
        k_to_ps.append(x)
        x.pop()
    k_to_ps[0].pop()

    AB = data[:2 * N].reshape(-1, 2)
    CD = data[2 * N:].reshape(-1, 2)
    for i in range(N):
        A, B = AB[i]
        I = i + 1
        p_to_rate[I] = A
        p_to_k[I] = B
        heappush(k_to_ps[B], (-A, I))

    # construct RBST from each non-empty kindergarten's maximum
    for i in range(MAX_K + 1):
        k = k_to_ps[i]
        if k:
            insert(-k[0][0])

    answers = [0] * Q
    for t in range(Q):
        C, D = CD[t]
        src = p_to_k[C]
        dst = D
        rateC = p_to_rate[C]
        p_to_k[C] = dst

        # remove from `src`: only matters if C was its current maximum.
        # The heap top is always a current member, because stale entries
        # are discarded below whenever they surface.
        neg_rate, max_p = k_to_ps[src][0]
        if max_p == C:
            erase(-neg_rate)
            heappop(k_to_ps[src])
            # Skip entries of infants who have since moved away; the first
            # current member found is src's new maximum.
            while k_to_ps[src]:
                neg_rate, max_p = k_to_ps[src][0]
                if p_to_k[max_p] == src:
                    insert(-neg_rate)
                    break
                heappop(k_to_ps[src])

        # move to `dst`
        if not k_to_ps[dst]:
            # destination was empty: C becomes its maximum
            insert(rateC)
        else:
            neg_rate, max_p = k_to_ps[dst][0]
            if -neg_rate < rateC:
                # C replaces the destination's maximum
                erase(-neg_rate)
                insert(rateC)
        heappush(k_to_ps[dst], (-rateC, C))

        # The answer is the smallest maximum, i.e. the leftmost RBST value.
        # The tree is non-empty since every infant belongs somewhere.
        cur = root
        while lefts[cur]:
            cur = lefts[cur]
        answers[t] = values[cur]
    return np.array(answers)
USE_NUMBA = True
# Compile step: `python Main.py ONLINE_JUDGE` (only when USE_NUMBA) or
# `python Main.py -c` builds the AOT extension module `my_module`.
if (USE_NUMBA and sys.argv[-1] == 'ONLINE_JUDGE') or sys.argv[-1] == '-c':
    print("compiling")
    from numba.pycc import CC
    cc = CC('my_module')
    # b1: bool, i4: int32, i8: int64, double: f8, [:], [:, :]
    cc.export('main', 'i8[:](i8,i8,i8[::1])')(main)
    cc.compile()
    sys.exit()

# Choose the implementation: the compiled module unless -p (pure python).
if USE_NUMBA and sys.argv[-1] != '-p':
    from my_module import main  # pylint: disable=all
elif sys.argv[-1] == "-t":
    # NOTE(review): `_test` is not defined in this file; this branch raises
    # NameError unless a `_test` function is added.
    _test()
    sys.exit()

# Choose the input source independently of the implementation, so a single
# file argument also works with the compiled module (previously the file
# branch was unreachable whenever USE_NUMBA was True).
if len(sys.argv) == 2 and sys.argv[1] not in ('-p', '-t', 'ONLINE_JUDGE'):
    with open(sys.argv[1], 'rb') as input_file:
        first_line = input_file.readline()
        rest = input_file.read()
else:
    first_line = sys.stdin.buffer.readline()
    rest = sys.stdin.buffer.read()

# read parameter
N, Q = [int(x) for x in first_line.split()]
data = np.int64(rest.split())
print(*main(N, Q, data), sep="\n")
Submission Info

| Field | Value |
| --- | --- |
| Submission Time | |
| Task | E - Smart Infants |
| User | nishiohirokazu |
| Language | Python (3.8.2) |
| Score | 500 |
| Code Size | 8934 Byte |
| Status | AC |
| Exec Time | 2983 ms |
| Memory | 99736 KB |
Judge Result
Set Name |
Sample |
All |
Score / Max Score |
0 / 0 |
500 / 500 |
Status |
AC
|
|
Set Name |
Test Cases |
Sample |
|
All |
handmade02, handmade03, handmade04, handmade05, handmade06, handmade07, handmade08, handmade09, random10, sample00, sample01 |
| Case Name | Status | Exec Time | Memory |
| --- | --- | --- | --- |
| handmade02 | AC | 164 ms | 62520 KB |
| handmade03 | AC | 160 ms | 62584 KB |
| handmade04 | AC | 167 ms | 62476 KB |
| handmade05 | AC | 432 ms | 69216 KB |
| handmade06 | AC | 444 ms | 87932 KB |
| handmade07 | AC | 459 ms | 83184 KB |
| handmade08 | AC | 2737 ms | 96220 KB |
| handmade09 | AC | 2748 ms | 96436 KB |
| random10 | AC | 2983 ms | 99736 KB |
| sample00 | AC | 158 ms | 62424 KB |
| sample01 | AC | 162 ms | 62492 KB |