splaytree.cpp (3)
Bài toán
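Cài đặt splay tree quản lý một dãy số với ba loại truy vấn: gán giá trị cho một phần tử (S), đảo ngược một đoạn (R) và tính tổng một đoạn (Q).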
Độ phức tạp
O(n) để dựng cây ban đầu; mỗi truy vấn mất O(log n) khấu hao.
Code này của Đỗ Ngọc Khánh
#include<iostream>
#include<cassert>
using namespace std;
// One position of the sequence kept in the splay tree.
struct Node {
    Node * child[2];  // child[0] is the left subtree, child[1] the right one
    Node * parent;    // parent node; the `nil` sentinel for a tree root
    bool reverse;     // lazy flag: this subtree still has to be mirrored
    int value;        // value stored at this position
    int size;         // number of nodes in this subtree, itself included
    long long sum;    // sum of `value` over this subtree
};
Node * nil;   // sentinel used in place of every missing child or parent
Node * root;  // root of the tree holding the sequence (nil when empty)
void initTree() {
nil = new Node();
nil->child[0] = nil->child[1] = nil->parent = nil;
nil->value = nil->size = nil->sum = 0;
nil->reverse = false;
root = nil;
}
// Recomputes x's cached size and sum from its two children and its own value.
// Children must already hold correct aggregates.
void update(Node * x) {
    Node * const lhs = x->child[0];
    Node * const rhs = x->child[1];
    x->size = lhs->size + rhs->size + 1;
    x->sum = lhs->sum + rhs->sum + x->value;
}
// Applies x's pending reversal one level down: swaps x's children, toggles
// the flag on both of them and clears it on x. No-op for the sentinel or
// when no reversal is pending.
void pushDown(Node * x) {
    if (x == nil || !x->reverse) return;

    Node * const formerLeft = x->child[0];
    x->child[0] = x->child[1];
    x->child[1] = formerLeft;

    for (int d = 0; d < 2; ++d) {
        Node * const c = x->child[d];
        c->reverse = !c->reverse;
    }
    x->reverse = false;
}
// Makes y the child of x on side d (0 = left, 1 = right) and points y's
// parent back at x. Cached aggregates are not refreshed here.
void setLink(Node * x, Node * y, int d) {
    y->parent = x;
    x->child[d] = y;
}
// Returns 0 when y is x's left child, 1 otherwise. It does not verify that
// y really is x's right child in the latter case.
int getDir(Node * x, Node * y) {
    if (x->child[0] == y) return 0;
    return 1;
}
// Rotates the edge between x and its child on side d: that child takes x's
// place under x's former parent and x becomes its child on side d ^ 1.
void rotate(Node * x, int d) {
    Node * const rising = x->child[d];
    Node * const above = x->parent;
    const int slot = getDir(above, x);  // where x hangs under its parent

    // The rising node's inner subtree moves across to x.
    setLink(x, rising->child[d ^ 1], d);
    setLink(rising, x, d ^ 1);
    setLink(above, rising, slot);

    // x is now the lower node, so refresh it before its new parent.
    update(x);
    update(rising);
}
// Brings x to the root of its tree with zig, zig-zig and zig-zag steps.
// No lazy flags are pushed down here; nodeAt() does that on its way down
// before calling this function.
void splay(Node * x) {
    for (Node * y = x->parent; y != nil; y = x->parent) {
        Node * const z = y->parent;
        const int dy = getDir(y, x);

        if (z == nil) {
            // zig: x's parent is the root
            rotate(y, dy);
            continue;
        }

        const int dz = getDir(z, y);
        if (dy == dz) {
            // zig-zig: rotate the grandparent first
            rotate(z, dz);
            rotate(y, dy);
        } else {
            // zig-zag: rotate the parent first
            rotate(y, dy);
            rotate(z, dz);
        }
    }
}
// Finds the node at 0-based position pos in the tree rooted at x, splays it
// to the root and returns it. Pending reversals are pushed down along the
// search path.
//
// Precondition: 0 <= pos < x->size. An out-of-range position would step into
// the sentinel and never terminate, so it is rejected by an assertion.
Node * nodeAt(Node * x, int pos) {
    assert(0 <= pos && pos < x->size);

    for (;;) {
        pushDown(x);  // child order must be final before sizes are compared
        const int leftSize = x->child[0]->size;
        if (pos == leftSize) break;
        if (pos < leftSize) {
            x = x->child[0];
        } else {
            // Skip the left subtree and x itself.
            pos -= leftSize + 1;
            x = x->child[1];
        }
    }

    splay(x);
    return x;
}
// Splits the tree rooted at x into t1 (its first `left` elements) and t2
// (the rest). Either result may be the sentinel. Expects 0 <= left <= size
// of x; this is not checked here.
void split(Node * x, int left, Node * &t1, Node * &t2) {
    if (left == 0) {
        t1 = nil;
        t2 = x;
        return;
    }

    // Splay the last node of the left part to the root, then cut off
    // everything to its right.
    t1 = nodeAt(x, left - 1);
    t2 = t1->child[1];
    t2->parent = nil;
    t1->child[1] = nil;
    update(t1);
}
// Concatenates two trees, all of x before all of y, and returns the new root.
Node * join(Node * x, Node * y) {
    if (x != nil) {
        // Splay x's last element to the root; its right slot is then free.
        Node * const last = nodeAt(x, x->size - 1);
        setLink(last, y, 1);
        update(last);
        return last;
    }
    return y;
}
void queryAssign(int pos, int value) {
root = nodeAt(root, pos);
root->value = value;
update(root);
}
void queryReverse(int x, int y) {
Node * t1, * t2, * t3;
split(root, y, t1, t3);
split(t1, x, t1, t2);
t2->reverse = !t2->reverse;
root = join(join(t1, t2), t3);
}
long long querySum(int x, int y) {
Node * t1, * t2, * t3;
split(root, y, t1, t3);
split(t1, x, t1, t2);
long long res = t2->sum;
root = join(join(t1, t2), t3);
return res;
}
// Capacity of the input buffer.
const int N = 100000;
// Initial sequence; main() fills it and buildTree() reads it.
int a[N];
// Builds a tree over a[l .. r-1] with the middle element as the subtree
// root and returns that root (the sentinel for an empty range).
Node * buildTree(int l, int r) {
    if (l == r) return nil;

    const int mid = (l + r) >> 1;

    Node * const node = new Node();
    node->parent = nil;
    node->reverse = false;
    node->value = a[mid];

    Node * const lower = buildTree(l, mid);
    setLink(node, lower, 0);
    Node * const upper = buildTree(mid + 1, r);
    setLink(node, upper, 1);

    update(node);
    return node;
}
int main() {
ios::sync_with_stdio(false);
int n; cin >> n;
for(int i = 0; i < n; ++i)
cin >> a[i];
initTree();
root = buildTree(0, n);
int q; cin >> q;
for(int i = 0; i < q; ++i) {
char type; int x, y;
cin >> type >> x >> y; --x;
if(type == 'S') queryAssign(x, y);
else if(type == 'R') queryReverse(x, y);
else if(type == 'Q') cout << querySum(x, y) << '\n';
else assert(false);
}
return 0;
}
Nhận xét
Cảm ơn bạn Đỗ Ngọc Khánh đã giúp tôi hoàn thành bài viết này.