11221: 【原1221】bst
题目
题目描述
author: DS TA 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/1221
Description
实现二叉查找树,支持插入,删除,查找,删除小于 x 的所有元素,删除大于 x 的所有元素,删除大于 a 且小于 b 的所有元素,查找第 i 小的元素。
Input Format
第 1 行: 一个数, n, 表示总操作个数。
接下来 n 行: 每行首先一个单词, 表示操作的名称, 这一行接下来的格式每种操作不同:
"insert": 插入, 接下来一个整数, x, 表示被插入的元素
"delete": 删除, 接下来一个整数, x, 表示被删除的元素(若树中有多个 x, 只删除其中任意一个)
"delete_less_than": 删除小于 x 的所有元素, 接下来一个整数, x
"delete_greater_than": 删除大于 x 的所有元素, 接下来一个整数, x
"delete_interval": 删除大于 a 且小于 b 的所有元素, 接下来两个整数, a, b
"find": 查找, 接下来一个整数, x, 表示被查找的元素
"find_ith": 查找第 i 小的元素, 接下来一个整数, i
Output Format
对于每个 "find" 和 "find_ith" 操作,输出一行。
其中对于 "find" 操作,输出 Y/N,表示是否查找到询问的元素。
对于 "find_ith" 操作,输出第 i 小的元素值,若不存在输出 N。
注意: 对于数列 1, 2, 2, 3, 我们认为第 1 小的元素是 1, 第 2 小的元素是 2, 第 3 小的元素还是 2, 第 4 小的元素是 3
Sample Input
22
insert 42
insert 42
insert 43
find 42
find 44
find_ith 2
find_ith 4
delete 42
delete_greater_than 42
find_ith 1
insert 1
insert 2
insert 3
insert 4
insert 5
delete_less_than 2
delete_interval 3 5
find 1
find 2
find 3
find 4
find 5
Sample Output
Y
N
42
N
42
N
Y
Y
N
Y
Limits
保证任意时刻树中元素不超过 5000 个,操作数小于300000
BugenZhao's solution
//
// Created by BugenZhao on 2019/4/25.
//
/// Plain key/value pair stored in every BBinarySearchTree node.
template<typename Key, typename Val>
class BPair {
public:
    Key key;
    Val val;

    BPair(Key key, Val val) : key(key), val(val) {}

    BPair() = delete;
};

/// Unbalanced binary search tree with unique keys.
///
/// find_ith() treats each node's `val` as a multiplicity: a key is counted
/// `val` times when ranking. A node whose count is 0 stays in the tree but
/// contributes nothing to any rank.
template<typename Key, typename Val>
class BBinarySearchTree {
    class Node {
    public:
        BPair<Key, Val> data;
        Node *left;
        Node *right;

        explicit Node(const BPair<Key, Val> &data, Node *left = nullptr, Node *right = nullptr) :
                data(data), left(left), right(right) {}
    };

    Node *root;
    int size;               // number of nodes currently allocated
    int find_ith_curCount;  // rank still to be consumed by the running find_ith
    Node *find_ith_ans;     // node found by the running find_ith (nullptr while searching)

private:
    // Returns the pair stored under `key` in the subtree at `node`, or nullptr.
    BPair<Key, Val> *get(Node *node, const Key &key) {
        while (node != nullptr) {
            if (node->data.key == key)
                return &node->data;
            node = (key < node->data.key) ? node->left : node->right;
        }
        // The previous version returned `&node->data` with node == nullptr (undefined behaviour).
        return nullptr;
    }

    // Inserts (key, val), or overwrites the value if the key already exists.
    void put(Node *&node, const Key &key, const Val &val) {
        if (node == nullptr) {
            node = new Node({key, val});
            ++size;
        } else if (key < node->data.key) {
            put(node->left, key, val);
        } else if (key > node->data.key) {
            put(node->right, key, val);
        } else {
            node->data.val = val;
        }
    }

    // Removes the node with the smallest key in the subtree at `node`.
    void removeMin(Node *&node) {
        if (node == nullptr) return;
        if (node->left == nullptr) {
            Node *oldNode = node;
            node = node->right;
            delete oldNode;  // previously the promoted right child was deleted instead
            --size;
        } else {
            removeMin(node->left);
        }
    }

    // Removes the node holding `key` (if any) from the subtree at `node`.
    void remove(Node *&node, const Key &key) {
        if (node == nullptr)
            return;
        if (key < node->data.key) {
            remove(node->left, key);
        } else if (key > node->data.key) {
            remove(node->right, key);
        } else if (node->left != nullptr && node->right != nullptr) {
            // Two children: copy the in-order successor here, then delete it on the right.
            Node *p = node->right;
            while (p->left != nullptr) p = p->left;
            node->data = p->data;
            remove(node->right, node->data.key);
        } else {
            // At most one child: splice it (possibly nullptr) into this slot.
            Node *oldNode = node;
            node = (node->left != nullptr) ? node->left : node->right;
            delete oldNode;
            --size;
        }
    }

    // Frees the whole subtree and resets the link so nothing dangles.
    void clear(Node *&node) {
        if (node == nullptr) return;
        clear(node->left);
        clear(node->right);
        delete node;
        node = nullptr;
        --size;
    }

    // Removes every key < `key`. Only the left spine needs to be visited.
    void delete_less_than(Node *&node, const Key &key) {
        while (node != nullptr && node->data.key < key) {
            // node and its entire left subtree are below key; its right subtree replaces it.
            clear(node->left);
            Node *oldNode = node;
            node = node->right;
            delete oldNode;
            --size;
        }
        if (node != nullptr) delete_less_than(node->left, key);
    }

    // Removes every key > `key`, mirroring delete_less_than.
    void delete_greater_than(Node *&node, const Key &key) {
        while (node != nullptr && node->data.key > key) {
            clear(node->right);
            Node *oldNode = node;
            node = node->left;
            delete oldNode;
            --size;
        }
        if (node != nullptr) delete_greater_than(node->right, key);
    }

    // Removes every key k with a < k < b.
    void delete_interval(Node *&node, const Key &a, const Key &b) {
        if (node == nullptr) return;
        if (!(node->data.key < b)) {
            // key >= b: the right subtree holds only larger keys.
            delete_interval(node->left, a, b);
            return;
        }
        if (!(node->data.key > a)) {
            // key <= a: the left subtree holds only smaller keys.
            delete_interval(node->right, a, b);
            return;
        }
        delete_interval(node->left, a, b);
        delete_interval(node->right, a, b);
        // Both subtrees are now free of (a, b), so whichever node replaces this one is outside it.
        remove(node, node->data.key);
    }

    // In-order walk that subtracts multiplicities until the requested rank is consumed.
    void find_ith(Node *node, int i) {
        if (find_ith_ans != nullptr || node == nullptr) return;
        find_ith(node->left, i);
        if (find_ith_ans != nullptr) return;
        if ((find_ith_curCount -= node->data.val) <= 0) {
            find_ith_ans = node;
            return;
        }
        find_ith(node->right, i);
    }

public:
    BBinarySearchTree() : root(nullptr), size(0), find_ith_curCount(0), find_ith_ans(nullptr) {}

    // The tree owns raw nodes; a shallow copy would double-free them.
    BBinarySearchTree(const BBinarySearchTree &) = delete;
    BBinarySearchTree &operator=(const BBinarySearchTree &) = delete;

    /// Returns a pointer to the stored pair for `key`, or nullptr if absent.
    BPair<Key, Val> *get(const Key &key) {
        return get(root, key);
    }

    /// Inserts or overwrites `key`.
    void put(const Key &key, const Val &val) {
        put(root, key, val);
    }

    /// Removes `key` if present.
    void remove(const Key &key) {
        remove(root, key);
    }

    /// Frees every node; the tree is empty (and reusable) afterwards.
    void clear() {
        clear(root);
    }

    virtual ~BBinarySearchTree() {
        clear();
    }

    void delete_less_than(const Key &key) {
        delete_less_than(root, key);
    }

    void delete_greater_than(const Key &key) {
        delete_greater_than(root, key);
    }

    /// Removes keys strictly between a and b.
    void delete_interval(const Key &a, const Key &b) {
        delete_interval(root, a, b);
    }

    /// Returns the pair at 1-based rank i (counting multiplicities), or nullptr.
    BPair<Key, Val> *find_ith(int i) {
        if (i <= 0) return nullptr;
        find_ith_curCount = i;
        find_ith_ans = nullptr;
        find_ith(root, i);
        return find_ith_ans ? &find_ith_ans->data : nullptr;
    }
};
#include <iostream>
#include <string>
using std::ios, std::cin, std::cout, std::endl, std::string;
using ll = long long;
// Driver: each distinct key is stored once and its value is the key's multiplicity.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    BBinarySearchTree<int, int> bst;
    int n;
    string cmd;
    int a, b;
    cin >> n;
    while (n--) {
        cin >> cmd;
        if (cmd == "insert") {
            cin >> a;
            auto p = bst.get(a);
            if (p) ++(p->val);
            else bst.put(a, 1);
        } else if (cmd == "delete") {
            cin >> a;
            auto p = bst.get(a);
            // Never decrement below zero: a negative count would make "find"
            // report a fully deleted key as present and corrupt find_ith ranks.
            if (p && p->val > 0) --(p->val);
        } else if (cmd == "find") {
            cin >> a;
            auto p = bst.get(a);
            if (p && p->val > 0) cout << "Y\n";
            else cout << "N\n";
        } else if (cmd == "delete_less_than") {
            cin >> a;
            bst.delete_less_than(a);
        } else if (cmd == "delete_greater_than") {
            cin >> a;
            bst.delete_greater_than(a);
        } else if (cmd == "delete_interval") {
            cin >> a >> b;
            bst.delete_interval(a, b);
        } else if (cmd == "find_ith") {
            cin >> a;
            auto p = bst.find_ith(a);
            if (p) cout << p->key << '\n';
            else cout << "N\n";
        }
    }
}
ligongzzz's solution
#include "iostream"
#include "cstring"
using namespace std;
/// Unbalanced BST over ints with parent links and an in-order iterator.
/// Equal values are stored in the right subtree.
class bst {
public:
    /// Tree node; `parent` lets iterators walk upward.
    class node {
    public:
        int val = 0;
        node* lchild = nullptr, * rchild = nullptr;
        node* parent = nullptr;
    };

    /// In-order forward iterator; a null cur_pos represents end().
    class iterator {
    public:
        node* cur_pos = nullptr;

        // ++i: move to the in-order successor.
        iterator& operator++() {
            if (cur_pos == nullptr)
                return *this;
            if (cur_pos->rchild != nullptr) {
                // Successor is the leftmost node of the right subtree.
                for (cur_pos = cur_pos->rchild; cur_pos->lchild != nullptr;)
                    cur_pos = cur_pos->lchild;
            }
            else if (cur_pos->parent == nullptr) {
                // Root without a right subtree: nothing follows.
                cur_pos = nullptr;
            }
            else if (cur_pos->parent->lchild == cur_pos) {
                // Left child without a right subtree: its parent follows.
                cur_pos = cur_pos->parent;
            }
            else {
                // Right child without a right subtree: climb until arriving from a left child.
                auto last = cur_pos;
                for (cur_pos = cur_pos->parent;
                     cur_pos != nullptr && cur_pos->rchild == last;
                     cur_pos = cur_pos->parent)
                    last = cur_pos;
            }
            return *this;
        }

        // i++: same traversal as ++i, returns the previous position.
        iterator operator++(int) {
            iterator temp = *this;
            ++*this;
            return temp;
        }

        bool operator==(const iterator& other) const {
            return cur_pos == other.cur_pos;
        }

        bool operator!=(const iterator& other) const {
            return cur_pos != other.cur_pos;
        }

        int& operator*() const {
            return cur_pos->val;
        }
    };

    node* root = nullptr;

    /// Iterator to the smallest element, or end() for an empty tree.
    iterator begin() {
        iterator temp;
        auto p = root;
        // The original dereferenced root unconditionally and crashed on an empty tree.
        if (p != nullptr)
            for (; p->lchild; p = p->lchild);
        temp.cur_pos = p;
        return temp;
    }

    iterator end() {
        return iterator();
    }

    /// Iterator to some node holding `val`, or end() if none.
    iterator find(const int& val) {
        iterator ans;
        for (auto p = root; p;) {
            if (p->val == val) {
                ans.cur_pos = p;
                return ans;
            }
            p = (val < p->val) ? p->lchild : p->rchild;
        }
        return ans;  // cur_pos is still nullptr, i.e. end()
    }

    /// Inserts `val`. A duplicate is spliced in as the right child of the first equal node.
    void insert(const int& val) {
        if (!root) {
            root = new node;
            root->val = val;
            return;
        }
        auto p = root;
        for (; p;) {
            if (p->val == val) {
                // Place the duplicate directly under p on the right, above p's old right subtree.
                auto fresh = new node;
                fresh->val = val;
                fresh->parent = p;
                fresh->rchild = p->rchild;
                if (p->rchild)
                    p->rchild->parent = fresh;
                p->rchild = fresh;
                return;
            }
            if (val < p->val) {
                if (p->lchild)
                    p = p->lchild;
                else {
                    p->lchild = new node;
                    p->lchild->parent = p;
                    p = p->lchild;
                    break;
                }
            }
            else {
                if (p->rchild)
                    p = p->rchild;
                else {
                    p->rchild = new node;
                    p->rchild->parent = p;
                    p = p->rchild;
                    break;
                }
            }
        }
        p->val = val;
    }

    /// Erases the element at `pos` and returns an iterator to its in-order successor.
    iterator erase(const iterator& pos) {
        iterator ans = pos;
        auto p = pos.cur_pos;
        if (p == nullptr)
            return ans;  // erasing end() is a no-op
        if (!p->lchild && !p->rchild) {
            // Leaf: compute the successor first, then unlink.
            ++ans;
            if (p == root)
                root = nullptr;
            else if (p->parent->lchild == p)
                p->parent->lchild = nullptr;
            else
                p->parent->rchild = nullptr;
            delete p;
        }
        else if (p->lchild && !p->rchild) {
            // Only a left child: it takes p's place.
            ++ans;
            if (p == root) {
                root = p->lchild;
                p->lchild->parent = nullptr;
            }
            else if (p->parent->lchild == p) {
                p->parent->lchild = p->lchild;
                p->lchild->parent = p->parent;
            }
            else {
                p->parent->rchild = p->lchild;
                p->lchild->parent = p->parent;
            }
            delete p;
        }
        else {
            // Has a right child: move the successor's value into p and unlink the
            // successor. `ans` keeps pointing at p, which now holds the next value.
            auto q = p->rchild;
            for (; q->lchild; q = q->lchild);
            p->val = q->val;
            if (q->parent->lchild == q)
                q->parent->lchild = q->rchild;
            else
                q->parent->rchild = q->rchild;
            if (q->rchild)
                q->rchild->parent = q->parent;
            delete q;
        }
        return ans;
    }

    /// Erases one occurrence of `val`, if present.
    void erase(const int& val) {
        iterator it = find(val);
        if (it != end())
            erase(it);
    }
};
// Driver: answers each command by walking the tree with its in-order iterator.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    bst setData;
    int num;
    cin >> num;
    for (; num > 0; num--) {
        char op[100];
        cin.width(sizeof op);  // never write past the end of op
        cin >> op;
        if (strcmp(op, "insert") == 0) {
            int temp;
            cin >> temp;
            setData.insert(temp);
        }
        else if (strcmp(op, "delete") == 0) {
            int temp;
            cin >> temp;
            if (auto iter = setData.find(temp); iter != setData.end())
                setData.erase(iter);
        }
        else if (strcmp(op, "delete_less_than") == 0) {
            int temp;
            cin >> temp;
            // Elements come out in ascending order, so stop at the first one >= temp.
            for (auto p = setData.begin(); p != setData.end(); ) {
                if (*p < temp)
                    p = setData.erase(p);
                else break;
            }
        }
        else if (strcmp(op, "delete_greater_than") == 0) {
            int temp;
            cin >> temp;
            for (auto p = setData.begin(); p != setData.end(); ) {
                if (*p > temp)
                    p = setData.erase(p);
                else ++p;
            }
        }
        else if (strcmp(op, "delete_interval") == 0) {
            int l, r;
            cin >> l >> r;
            for (auto p = setData.begin(); p != setData.end();) {
                if (*p > l && (*p) < r)
                    p = setData.erase(p);
                else if (*p >= r)
                    break;
                else ++p;
            }
        }
        else if (strcmp(op, "find") == 0) {
            int temp;
            cin >> temp;
            // '\n' instead of endl: up to 3e5 lines, flushing each one is very slow.
            if (setData.find(temp) != setData.end())
                cout << "Y" << '\n';
            else cout << "N" << '\n';
        }
        else if (strcmp(op, "find_ith") == 0) {
            int temp;
            bool flag = false;
            cin >> temp;
            int i = 1;
            for (auto p = setData.begin(); p != setData.end(); ++p, ++i) {
                if (i == temp) {
                    flag = true;
                    cout << *p << '\n';
                    break;
                }
            }
            if (!flag)
                cout << "N" << '\n';
        }
    }
    return 0;
}
Neight99's solution
#include <cstring>
#include <iostream>
using namespace std;
// Globals shared by BinarySearchTree::find_ith and main():
// findN counts nodes visited in order; flag is set once the answer was printed.
int findN;
bool flag;

/// Unbalanced binary search tree over ints. Duplicates are separate nodes:
/// a value <= a node's key goes into that node's left subtree. `sum` is the
/// number of nodes.
class BinarySearchTree {
    struct Node {
        int data;
        Node *left;
        Node *right;
        Node(int x = 0, Node *l = nullptr, Node *r = nullptr)
            : data(x), left(l), right(r) {}
    };

    Node *root;
    int sum;
    void clear(Node *&rhs);
    void insert(int x, Node *&rhs);
    bool find(int x, Node *&rhs);
    void find_ith(int x, Node *&rhs);
    void deleteNode(int x, Node *&rhs);
    void deleteLess(int x, Node *&rhs);
    void deleteGreater(int x, Node *&rhs);
    void deleteInterval(int low, int high, Node *&rhs);

   public:
    BinarySearchTree() : root(nullptr), sum(0) {}
    ~BinarySearchTree() { clear(root); }
    // The tree owns raw nodes; a shallow copy would double-free them.
    BinarySearchTree(const BinarySearchTree &) = delete;
    BinarySearchTree &operator=(const BinarySearchTree &) = delete;

    void insert(int x) { insert(x, root); }
    bool find(int x) { return find(x, root); }

    // Prints the pos-th smallest element (1-based) plus '\n' and sets `flag`;
    // prints nothing and leaves `flag` false when there is no such element.
    void find_ith(int pos) {
        flag = false;
        // pos <= 0 on an empty tree used to reach the traversal with a null root.
        if (pos <= 0 || pos > sum) {
            return;
        }
        findN = 0;
        find_ith(pos, root);
    }
    void deleteEqual(int x) {
        deleteNode(x, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
    void deleteLess(int x) {
        deleteLess(x, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
    void deleteGreater(int x) {
        deleteGreater(x, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
    void deleteInterval(int low, int high) {
        deleteInterval(low, high, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
};

// Frees the subtree and resets the link so no dangling pointer survives.
void BinarySearchTree::clear(Node *&rhs) {
    if (rhs == nullptr) {
        return;
    }
    clear(rhs->left);
    clear(rhs->right);
    delete rhs;
    rhs = nullptr;
    sum--;
}

void BinarySearchTree::insert(int x, Node *&rhs) {
    if (rhs == nullptr) {
        sum++;
        rhs = new Node(x);
    } else if (x <= rhs->data) {
        insert(x, rhs->left);
    } else {
        insert(x, rhs->right);
    }
}

bool BinarySearchTree::find(int x, Node *&rhs) {
    if (rhs == nullptr) {
        return false;
    } else if (rhs->data == x) {
        return true;
    } else if (x < rhs->data) {
        return find(x, rhs->left);
    } else {
        return find(x, rhs->right);
    }
}

// In-order walk that prints the x-th visited node. It stops on an empty subtree,
// once the answer was printed, or once more than x nodes were counted.
void BinarySearchTree::find_ith(int x, Node *&rhs) {
    if (rhs == nullptr || flag || findN > x) {
        return;
    }
    find_ith(x, rhs->left);
    if (flag) {
        return;
    }
    if (x == ++findN) {
        std::cout << rhs->data << '\n';
        flag = true;
        return;
    }
    find_ith(x, rhs->right);
}

// Deletes one node holding x from the subtree.
void BinarySearchTree::deleteNode(int x, Node *&rhs) {
    if (rhs == nullptr) {
        return;
    }
    if (x < rhs->data) {
        deleteNode(x, rhs->left);
    } else if (x > rhs->data) {
        deleteNode(x, rhs->right);
    } else if (rhs->left != nullptr && rhs->right != nullptr) {
        // Two children: take the in-order successor's value, then delete that value on the right.
        Node *p = rhs->right;
        while (p->left != nullptr) {
            p = p->left;
        }
        rhs->data = p->data;
        deleteNode(rhs->data, rhs->right);
    } else {
        Node *clean = rhs;
        rhs = (rhs->left != nullptr) ? rhs->left : rhs->right;
        delete clean;
        sum--;
    }
}

// Deletes every value < x.
void BinarySearchTree::deleteLess(int x, Node *&rhs) {
    // While the subtree root is < x, it and its whole left subtree go.
    while (rhs != nullptr && x > rhs->data) {
        clear(rhs->left);
        Node *temp = rhs->right;
        delete rhs;
        rhs = temp;
        sum--;
    }
    // Now rhs->data >= x: only its left subtree can still hold smaller values.
    if (rhs != nullptr) {
        deleteLess(x, rhs->left);
    }
}

// Deletes every value > x (mirror of deleteLess).
void BinarySearchTree::deleteGreater(int x, Node *&rhs) {
    while (rhs != nullptr && x < rhs->data) {
        clear(rhs->right);
        Node *temp = rhs->left;
        delete rhs;
        rhs = temp;
        sum--;
    }
    if (rhs != nullptr) {
        deleteGreater(x, rhs->right);
    }
}

// Deletes every value v with low < v < high.
void BinarySearchTree::deleteInterval(int low, int high, Node *&rhs) {
    if (low >= high || rhs == nullptr) {
        return;
    }
    while (rhs != nullptr && rhs->data < high && rhs->data > low) {
        deleteNode(rhs->data, rhs);
    }
    // rhs is now outside the interval, so only one side can still contain matches.
    if (rhs != nullptr && rhs->data >= high) {
        deleteInterval(low, high, rhs->left);
    }
    if (rhs != nullptr && rhs->data <= low) {
        deleteInterval(low, high, rhs->right);
    }
}
// Driver: dispatches each command to the BinarySearchTree.
// find_ith prints the answer itself and reports success through the global `flag`.
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    BinarySearchTree bst;
    char order[100] = {0};
    int n1, n2;
    int times = 0;
    cin >> times;
    for (int i = 0; i < times; i++) {
        cin.width(sizeof order);  // bound the extraction to the buffer
        cin >> order;
        if (!strcmp(order, "insert")) {
            cin >> n1;
            bst.insert(n1);
        } else if (!strcmp(order, "delete")) {
            cin >> n1;
            bst.deleteEqual(n1);
        } else if (!strcmp(order, "delete_less_than")) {
            cin >> n1;
            bst.deleteLess(n1);
        } else if (!strcmp(order, "delete_greater_than")) {
            cin >> n1;
            bst.deleteGreater(n1);
        } else if (!strcmp(order, "delete_interval")) {
            cin >> n1 >> n2;
            bst.deleteInterval(n1, n2);
        } else if (!strcmp(order, "find")) {
            cin >> n1;
            // Named `found` so it no longer shadows the global `flag` used by find_ith.
            bool found = bst.find(n1);
            cout << (found ? "Y" : "N") << '\n';
        } else if (!strcmp(order, "find_ith")) {
            cin >> n1;
            bst.find_ith(n1);
            if (!flag) {
                cout << "N\n";
            }
        }
    }
    return 0;
}
skyzh's solution
#include <iostream>
#include <climits>
#include <cstring>
using namespace std;
/// Unbalanced binary search tree. Duplicates are kept as separate nodes and
/// go to the left (x <= node goes left, x > node goes right).
template<typename T>
struct BST {
    struct Node {
        T x;
        Node *l, *r;
        Node(Node *l = nullptr, Node *r = nullptr) : l(l), r(r) {}
        Node(const T &x, Node *l = nullptr, Node *r = nullptr) : x(x), l(l), r(r) {}

        // Flip to true to dump subtrees from debug(); off by default so nothing is printed.
        static constexpr bool kDebug = false;

        void debug(int depth = 0) {
            if (!kDebug) return;
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << x << std::endl;
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << "L" << std::endl;
            if (l) l->debug(depth + 1);
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << "R" << std::endl;
            if (r) r->debug(depth + 1);
        }
    } *root;

    BST() : root(nullptr) {}

    // Frees an entire subtree.
    void clear(Node *ptr) {
        if (!ptr) return;
        clear(ptr->l);
        clear(ptr->r);
        delete ptr;
    }

    // Follows the BST ordering (O(height)); the original searched both subtrees (O(n)).
    bool find(Node *ptr, const T &x) {
        while (ptr) {
            if (ptr->x == x) return true;
            ptr = (x < ptr->x) ? ptr->l : ptr->r;
        }
        return false;
    }

    bool find(const T &x) {
        return find(root, x);
    }

    Node *insert(Node *ptr, const T &x) {
        if (!ptr) return new Node(x);
        if (x <= ptr->x) ptr->l = insert(ptr->l, x);
        else ptr->r = insert(ptr->r, x);
        return ptr;
    }

    void insert(const T &x) {
        root = insert(root, x);
    }

    // In-order search for the i-th node (1-based); `i` is consumed as nodes are passed.
    Node *find_ith(Node *ptr, int &i) {
        if (!ptr) return nullptr;
        if (Node *l = find_ith(ptr->l, i)) return l;
        if (i == 1) return ptr;
        --i;
        return find_ith(ptr->r, i);
    }

    Node *find_ith(int i) {
        return find_ith(root, i);
    }

    /// Deletes every element <= x.
    void delete_less_than(const T &x) {
        delete_interval(INT_MIN, x);
    }

    /// Deletes every element >= x.
    void delete_greater_than(const T &x) {
        delete_interval(x, INT_MAX);
    }

    // Detaches `ptr` and returns the subtree that should take its place (ptr is not freed).
    // With a left child, the in-order predecessor is promoted.
    Node *delete_node_at(Node *ptr) {
        if (!ptr->l) return ptr->r;
        Node *prev = nullptr, *c = ptr->l;
        while (c->r) {
            prev = c;
            c = c->r;
        }
        if (prev) {
            // c is the maximum of the left subtree, so it has no right child:
            // its left subtree simply takes its slot under prev.
            prev->r = c->l;
            c->l = ptr->l;
        }
        c->r = ptr->r;
        return c;
    }

    Node *delete_node(Node *ptr, const T &x) {
        if (!ptr) return nullptr;
        if (x < ptr->x) {
            ptr->l = delete_node(ptr->l, x);
        } else if (x == ptr->x) {
            Node *result = delete_node_at(ptr);
            delete ptr;
            return result;
        } else {
            ptr->r = delete_node(ptr->r, x);
        }
        return ptr;
    }

    void delete_node(const T &x) {
        root = delete_node(root, x);
    }

    // Deletes every element in [t1, t2] from a subtree whose keys all lie in [x1, x2].
    Node *delete_interval(Node *ptr, const T &x1, const T &x2, const T &t1, const T &t2) {
        if (!ptr) return nullptr;
        if (t1 <= x1 && x2 <= t2) {
            // The subtree is entirely inside the target range.
            clear(ptr);
            return nullptr;
        }
        if (x2 < t1 || t2 < x1) return ptr;  // disjoint: nothing to delete here
        ptr->l = delete_interval(ptr->l, x1, ptr->x, t1, t2);
        // A non-empty right subtree implies ptr->x < x2, so ptr->x + 1 cannot overflow.
        if (ptr->r) ptr->r = delete_interval(ptr->r, ptr->x + 1, x2, t1, t2);
        if (t1 <= ptr->x && ptr->x <= t2) {
            Node *tmp = delete_node_at(ptr);
            delete ptr;
            return tmp;
        }
        return ptr;
    }

    /// Deletes every element in [t1, t2] and returns the new root.
    Node *delete_interval(const T &t1, const T &t2) {
        root = delete_interval(root, INT_MIN, INT_MAX, t1, t2);
        return root;  // the original fell off the end of a non-void function (UB)
    }
};
int main() {
/*
"insert": 插入, 接下来一个整数, x, 表示被插入的元素
"delete": 删除, 接下来一个整数, x, 表示被删除的元素(若树中有重复删除任意一个)
"delete_less_than": 删除小于 x 的所有元素, 接下来一个整数, x
"delete_greater_than": 删除大于 x 的所有元素, 接下来一个整数, x
"delete_interval": 删除大于 a 且小于 b 的所有元素, 接下来两个整数, a, b
"find": 查找, 接下来一个整数, x, 表示被查找的元素
"find_ith": 查找第 i 小的元素, 接下来一个整数, i
*/
char cmd[100];
int N;
int op1, op2;
BST<int> tree;
cin >> N;
for (int i = 0; i < N; i++) {
cin >> cmd;
if (strcmp(cmd, "insert") == 0) {
cin >> op1;
tree.insert(op1);
tree.root->debug();
} else if (strcmp(cmd, "delete") == 0) {
cin >> op1;
tree.delete_node(op1);
tree.root->debug();
} else if (strcmp(cmd, "delete_less_than") == 0) {
cin >> op1;
tree.delete_less_than(op1 - 1);
tree.root->debug();
} else if (strcmp(cmd, "delete_greater_than") == 0) {
cin >> op1;
tree.delete_greater_than(op1 + 1);
tree.root->debug();
} else if (strcmp(cmd, "delete_interval") == 0) {
cin >> op1 >> op2;
tree.delete_interval(op1 + 1, op2 - 1);
tree.root->debug();
} else if (strcmp(cmd, "find") == 0) {
cin >> op1;
if (tree.find(op1)) cout << "Y" << endl; else cout << "N" << endl;
} else if (strcmp(cmd, "find_ith") == 0) {
cin >> op1;
BST<int>::Node *ith = tree.find_ith(op1);
if (ith) cout << ith->x << endl; else cout << "N" << endl;
}
}
return 0;
}
yyong119's solution
#include <cstdio>
#include <cstring>
/// Binary search tree of distinct values, each with a multiplicity count.
/// head_ is a sentinel (value 0x7fffffff); the real tree hangs off head_->l_son_.
class SearchTree {
public:
    struct Node {
        Node(int data = 0x7fffffff) : data_(data), number_(1), l_son_(nullptr), r_son_(nullptr) {}
        // Frees the left subtree and unlinks it.
        void DeleteLSon() {
            if (l_son_) {
                l_son_->DeleteSon();
                delete l_son_;
                l_son_ = nullptr;
            }
        }
        // Frees the right subtree and unlinks it.
        void DeleteRSon() {
            if (r_son_) {
                r_son_->DeleteSon();
                delete r_son_;
                r_son_ = nullptr;
            }
        }
        // Frees both subtrees; the node itself is kept.
        void DeleteSon() {
            DeleteLSon();
            DeleteRSon();
        }
        int data_, number_;
        Node *l_son_, *r_son_;
    };

    SearchTree() : head_(new Node()) {}
    ~SearchTree() {
        head_->DeleteSon();
        delete head_;
    }
    // The tree owns raw nodes; a shallow copy would double-free them.
    SearchTree(const SearchTree &) = delete;
    SearchTree &operator=(const SearchTree &) = delete;

    /// Adds one copy of `data`, bumping the count if the value already exists.
    void Insert(int data) {
        Node *p = head_, *q = head_->l_son_;
        while (q) {
            if (q->data_ == data) {
                ++q->number_;
                return;
            }
            p = q;
            q = (q->data_ < data) ? q->r_son_ : q->l_son_;
        }
        // The sentinel's maximal value sends the first node to head_->l_son_.
        if (p->data_ >= data) p->l_son_ = new Node(data);
        else p->r_son_ = new Node(data);
    }

    /// Removes one copy of `data` (no-op if absent).
    void Remove(int data) {
        Node *p = head_, *q = head_->l_son_;
        while (q && q->data_ != data) {
            p = q;
            q = (q->data_ < data) ? q->r_son_ : q->l_son_;
        }
        Remove(p, q);
    }

    // Removes node q (child of p). Unless `total`, only one copy is dropped.
    void Remove(Node *p, Node *q, bool total = 0) {
        if (q == nullptr) return;
        if (!total && q->number_ > 1) {
            --q->number_;
            return;
        }
        if (q->l_son_ && q->r_son_) {
            // Two children: pull up the in-order predecessor (max of the left subtree).
            Node *p2 = q, *q2 = q->l_son_;
            for (; q2->r_son_; p2 = q2, q2 = q2->r_son_);
            q->data_ = q2->data_;
            q->number_ = q2->number_;
            if (p2->l_son_ == q2) p2->l_son_ = q2->l_son_;
            else p2->r_son_ = q2->l_son_;
            delete q2;
        }
        else {
            Node *new_node = q->l_son_ ? q->l_son_ : q->r_son_;
            if (p->l_son_ == q)
                p->l_son_ = new_node;
            else
                p->r_son_ = new_node;
            delete q;
        }
    }

    /// Removes every value < data.
    void RemoveLessThan(int data) {
        RemoveLessThan(head_, head_->l_son_, data);
    }

    // Walks down from q (child of p), cutting off everything below `data`.
    void RemoveLessThan(Node *p, Node *q, int data) {
        Node *tmp;
        while (q) {
            if (q->data_ <= data) {
                q->DeleteLSon();
                if (q->data_ == data) return;
                // q itself is too small: replace it by its right subtree.
                tmp = q;
                q = q->r_son_;
                if (p->l_son_ == tmp) p->l_son_ = q;
                else p->r_son_ = q;
                delete tmp;
            }
            else {
                p = q;
                q = q->l_son_;
            }
        }
    }

    /// Removes every value > data.
    void RemoveGreaterThan(int data) {
        RemoveGreaterThan(head_, head_->l_son_, data);
    }

    // Mirror of RemoveLessThan.
    void RemoveGreaterThan(Node *p, Node *q, int data) {
        Node *tmp;
        while (q) {
            if (q->data_ >= data) {
                q->DeleteRSon();
                if (q->data_ == data)
                    return;
                tmp = q;
                q = q->l_son_;
                if (p->l_son_ == tmp) p->l_son_ = q;
                else p->r_son_ = q;
                delete tmp;
            }
            else {
                p = q;
                q = q->r_son_;
            }
        }
    }

    /// Removes every value v with lower < v < upper.
    void RemoveInvterval(int lower, int upper) {
        Node *p = head_, *q = head_->l_son_;
        while (q) {
            if (q->data_ < lower) {
                p = q;
                q = q->r_son_;
            }
            else if (q->data_ > upper) {
                p = q;
                q = q->l_son_;
            }
            else {
                // q is the split node: trim both sides, then drop q unless it is a bound.
                RemoveGreaterThan(q, q->l_son_, lower);
                RemoveLessThan(q, q->r_son_, upper);
                if (q->data_ == lower || q->data_ == upper) return;
                Remove(p, q, 1);
                return;
            }
        }
    }

    bool Find(int data) {
        for (Node *p = head_->l_son_; p;) {
            if (p->data_ == data) return true;
            p = (p->data_ < data) ? p->r_son_ : p->l_son_;
        }
        return false;
    }

    /// Node holding the i-th smallest element (1-based, counting copies), or nullptr.
    Node *FindIth(int i) {
        // Ranks start at 1; previously i <= 0 returned the minimum instead of "none".
        if (i <= 0) return nullptr;
        return FindIth(head_->l_son_, i);
    }

    // In-order search; `i` is reduced by each passed node's multiplicity.
    Node *FindIth(Node *node, int &i) {
        if (!node) return nullptr;
        Node *tmp = FindIth(node->l_son_, i);
        if (tmp) return tmp;
        if (node->number_ >= i) return node;
        i -= node->number_;
        return FindIth(node->r_son_, i);
    }

    // Debug dump of (address, value, count) in order.
    void Output() {
        if (head_->l_son_) Output(head_->l_son_);
        printf("\n");
    }

    void Output(Node *p) {
        if (p->l_son_) Output(p->l_son_);
        // %p with void*: passing a pointer for %d was undefined behaviour.
        printf("%p,%d,%d ", static_cast<void *>(p), p->data_, p->number_);
        if (p->r_son_) Output(p->r_son_);
    }

    Node *head_;
};
SearchTree tree;
int total_command_number, tmp1, tmp2;
char command[30];
SearchTree::Node *tmp;

// Driver: every command carries at least one integer, so it is read with the command word.
int main() {
    if (scanf("%d", &total_command_number) != 1) return 0;
    for (int i = 0; i < total_command_number; ++i) {
        // %29s keeps the word inside command[30]; stop on truncated/malformed input.
        if (scanf("%29s%d", command, &tmp1) != 2) break;
        if (!strcmp(command, "insert"))
            tree.Insert(tmp1);
        else if (!strcmp(command, "delete"))
            tree.Remove(tmp1);
        else if (!strcmp(command, "delete_less_than"))
            tree.RemoveLessThan(tmp1);
        else if (!strcmp(command, "delete_greater_than"))
            tree.RemoveGreaterThan(tmp1);
        else if (!strcmp(command, "delete_interval")) {
            if (scanf("%d", &tmp2) != 1) break;
            tree.RemoveInvterval(tmp1, tmp2);
        }
        else if (!strcmp(command, "find"))
            printf("%c\n", tree.Find(tmp1) ? 'Y' : 'N');
        else if (!strcmp(command, "find_ith")) {
            tmp = tree.FindIth(tmp1);
            if (tmp) printf("%d\n", tmp->data_);
            else printf("N\n");
        }
    }
    return 0;
}
zqy2018's solution
/*
See the editorial at https://github.com/zqy1018/tutorials/blob/master/bst_tutorial/bst.pdf
*/
#include <bits/stdc++.h>
#define INF 2000000000
using namespace std;
typedef long long ll;
/// Reads the next decimal integer from stdin, skipping any non-digit prefix.
/// A '-' counts as the sign only when it directly precedes the first digit.
/// Returns 0 if end of input is reached before any digit. The original looped
/// forever at EOF, and it also stored getchar() in a char, which cannot hold EOF reliably.
int read(){
    int f = 1, x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        f = (c == '-') ? -1 : 1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return f * x;
}
// Treap node kept in a static pool.
// siz  = number of nodes in this subtree,
// v    = key,
// prio = random heap priority (see tree_new/Merge),
// lch / rch = pool indices of the children (0 means "no child").
struct Tr {
    int siz, v, prio, lch, rch;
};
// Node pool. Index 0 is the empty sentinel: it is never written, so tr[0].siz stays 0.
// Slots are never reused (S only grows), so the pool needs one slot per insert;
// presumably sized for < 300000 operations.
Tr tr[400005];
// S = last allocated pool index; root = pool index of the treap root (0 = empty treap).
int S = 0, root = 0;
// Recomputes the subtree size of node x from its children (index 0 has size 0).
void maintain(int x){
    int left_size = tr[tr[x].lch].siz;
    int right_size = tr[tr[x].rch].siz;
    tr[x].siz = left_size + right_size + 1;
}
int tree_new(int k){
++S;
tr[S].siz = 1, tr[S].v = k,
tr[S].prio = rand(),
tr[S].lch = tr[S].rch = 0;
return S;
}
// Pair of treap roots (pool indices) returned by the split operations.
struct pair_of_int{
    int x, y;
    pair_of_int(int first, int second) : x(first), y(second) {}
};
// Splits the treap rooted at `now` by key.
// Returns (x, y): x holds every node with v <= k, y every node with v > k.
// Sizes of the nodes on the split path are refreshed via maintain().
pair_of_int Split(int now, int k){
    if (!now) return pair_of_int(0, 0);
    else {
        int x, y;
        if (tr[now].v <= k){
            // now and its left subtree belong to x; split the right subtree.
            x = now;
            pair_of_int res = Split(tr[now].rch, k);
            tr[now].rch = res.x;
            y = res.y;
        }else {
            // now and its right subtree belong to y; split the left subtree.
            y = now;
            pair_of_int res = Split(tr[now].lch, k);
            x = res.x;
            tr[now].lch = res.y;
        }
        maintain(now);
        return pair_of_int(x, y);
    }
}
// Splits the treap rooted at `now` by rank.
// Returns (x, y): x holds the first k nodes in in-order, y holds the rest.
pair_of_int Split_K(int now, int k){
    if (!now) return pair_of_int(0, 0);
    else {
        int x, y;
        if (k > tr[tr[now].lch].siz){
            // The whole left subtree plus `now` fit into the first k nodes;
            // take the remaining k - lsize - 1 from the right subtree.
            x = now;
            pair_of_int res = Split_K(tr[now].rch, k - tr[tr[now].lch].siz - 1);
            tr[now].rch = res.x;
            y = res.y;
        }else {
            // The cut falls inside the left subtree.
            y = now;
            pair_of_int res = Split_K(tr[now].lch, k);
            x = res.x;
            tr[now].lch = res.y;
        }
        maintain(now);
        return pair_of_int(x, y);
    }
}
// Merges treaps x and y, assuming every key in x precedes every key in y.
// The root with the smaller prio stays on top. Returns the merged root.
int Merge(int x, int y){
    // At least one side is empty (0), so the sum is the other root.
    if (!x || !y) return x + y;
    if (tr[x].prio < tr[y].prio){
        tr[x].rch = Merge(tr[x].rch, y);
        maintain(x);
        return x;
    }else{
        tr[y].lch = Merge(x, tr[y].lch);
        maintain(y);
        return y;
    }
}
void Insert(int k){
int z = tree_new(k);
pair_of_int res = Split(root, k);
root = Merge(Merge(res.x, z), res.y);
}
void Del(int k){
pair_of_int res1 = Split(root, k - 1);
pair_of_int res2 = Split_K(res1.y, 1);
root = Merge(res1.x, res2.y);
}
// Returns whether key k occurs in the treap (plain BST descent).
bool Lookup(int k){
    for (int t = root; t != 0; ) {
        if (tr[t].v == k) return true;
        t = (tr[t].v < k) ? tr[t].rch : tr[t].lch;
    }
    return false;
}
void Del_less(int k){
pair_of_int res = Split(root, k - 1);
root = res.y;
}
void Del_greater(int k){
pair_of_int res = Split(root, k);
root = res.x;
}
void Del_interval(int l, int r){
int x, y, w, z;
pair_of_int res1 = Split(root, l);
pair_of_int res2 = Split(res1.y, r - 1);
root = Merge(res1.x, res2.y);
}
// Stores the k-th smallest key (1-based) in res and returns true; returns false if out of range.
bool Kth(int k, int &res){
    if (k <= 0 || k > tr[root].siz) return false;
    int t = root;
    while (true) {
        int left_size = tr[tr[t].lch].siz;
        if (k <= left_size) {
            t = tr[t].lch;
            continue;
        }
        if (k == left_size + 1) break;
        k -= left_size + 1;
        t = tr[t].rch;
    }
    res = tr[t].v;
    return true;
}
// Number of operations still to process.
int Q;
// Reads the operation count, then seeds rand() (used for treap priorities) from the clock.
void init(){
    Q = read();
    unsigned seed = static_cast<unsigned>(time(nullptr));
    srand(seed);
}
// Processes the Q operations.
// Commands are told apart by characters at fixed positions of the word:
// "delete" has o[6] == '\0'; the delete_* variants are distinguished by o[7];
// "find" has o[4] == '\0'.
void solve(){
    char o[30];
    while (Q--){
        // Bounded read; on EOF/error stop instead of replaying the previous command.
        if (scanf("%29s", o) != 1) break;
        if (o[0] == 'i'){
            int k = read();
            Insert(k);
        } else if (o[0] == 'd'){
            if (o[6] == '\0'){
                int k = read();
                if (Lookup(k))
                    Del(k);
            } else if (o[7] == 'l'){
                int k = read();
                Del_less(k);
            } else if (o[7] == 'g'){
                int k = read();
                Del_greater(k);
            } else if (o[7] == 'i'){
                int l = read();
                int r = read();
                Del_interval(l, r);
            }
        } else if (o[0] == 'f'){
            if (o[4] == '\0'){
                int k = read();
                printf("%s\n", (Lookup(k) ? "Y": "N"));
            } else {
                int k = read(), res;
                if (Kth(k, res)) printf("%d\n", res);
                else printf("N\n");
            }
        }
    }
}
// Entry point: read the operation count and seed the RNG (init), then process every operation (solve).
int main(){
    init();
    solve();
    return 0;
}
Zsi-r's solution
#include <iostream>
#include <cstring>
using namespace std;
/// Node of the binary search tree.
struct bstnode
{
    int data;
    bstnode *left, *right;
    bstnode(int d, bstnode *l = nullptr, bstnode *r = nullptr) : data(d), left(l), right(r){};
    bstnode(){};
};

/// Frame of the iterative in-order traversal. TimesPop counts how often the
/// frame was popped: the first pop schedules the left subtree, the second visits the node.
struct StNode{
    bstnode *node;
    int TimesPop;
    StNode(bstnode *N = nullptr) : node(N), TimesPop(0){};
};

/// Singly linked list cell used by `stack`.
struct stacknode
{
    StNode data;
    stacknode *next;
    stacknode(const StNode d, stacknode *n = nullptr):data(d),next(n){};
    stacknode():next(nullptr){};
    ~stacknode(){};
};

/// Minimal LIFO stack of StNode values; owns its cells.
class stack
{
private:
    stacknode *top_p;
public:
    stack() { top_p = nullptr; }
    // Owns raw cells; a shallow copy would double-free them.
    stack(const stack &) = delete;
    stack &operator=(const stack &) = delete;
    ~stack()
    {
        while (top_p != nullptr)
        {
            stacknode *temp = top_p;
            top_p = top_p->next;
            delete temp;
        }
    }
    bool isempty() { return top_p == nullptr; }
    void push(const StNode&x)
    {
        top_p = new stacknode(x, top_p);
    }
    // Precondition: !isempty().
    StNode pop()
    {
        stacknode *temp = top_p;
        StNode value = temp->data;
        top_p = top_p->next;
        delete temp;
        return value;
    }
};

/// Unbalanced BST over ints; duplicates are separate nodes placed in the left subtree.
class bst
{
public:
    bstnode *root;

    void insert(int d,bstnode* &n)
    {
        if (n == nullptr)
            n = new bstnode(d);
        else if (n->data<d)
            insert(d, n->right);
        else
            insert(d, n->left);
    }

    // Removes one node holding d from the subtree n.
    void remove(int d,bstnode* &n)
    {
        if (n==nullptr)
            return;
        else if (n->data>d)
            remove(d, n->left);
        else if (n->data<d)
            remove(d, n->right);
        else if (n->left!=nullptr&&n->right!=nullptr)
        {
            // Two children: copy the in-order successor, then delete it from the right subtree.
            bstnode *temp = n->right;
            while (temp->left!=nullptr)
                temp = temp->left;
            n->data = temp->data;
            remove(n->data,n->right);
        }
        else
        {
            bstnode *temp = n;
            n = (n->left != nullptr) ? n->left : n->right;
            delete temp;
        }
    }

    // Prints Y if x occurs in the subtree n, N otherwise.
    void find (int x,bstnode *n) const
    {
        while (n != nullptr && n->data != x)
            n = (n->data < x) ? n->right : n->left;
        // '\n' instead of endl: flushing on each of up to 3e5 queries is slow.
        std::cout << (n != nullptr ? 'Y' : 'N') << '\n';
    }

    bst() { root = nullptr; }
    ~bst() { clear(root); }
    // Owns raw nodes; a shallow copy would double-free them.
    bst(const bst &) = delete;
    bst &operator=(const bst &) = delete;

    void find(int x)
    {
        find(x, root);
    }

    // Prints the i-th smallest element (1-based), or N if there is none.
    void find_ith(int i)
    {
        // The traversal below dereferences every popped node; an empty tree used to crash it.
        if (root == nullptr)
        {
            std::cout << 'N' << '\n';
            return;
        }
        int count = 0;
        stack s;
        s.push(StNode(root));
        while (!s.isempty())
        {
            StNode current = s.pop();
            if (++current.TimesPop == 2)
            {
                // Second pop: the left subtree is finished, visit this node.
                if (++count == i)
                {
                    std::cout << current.node->data << '\n';
                    return;
                }
                if (current.node->right != nullptr)
                    s.push(StNode(current.node->right));
            }
            else
            {
                // First pop: revisit later, descend into the left subtree now.
                s.push(current);
                if (current.node->left != nullptr)
                    s.push(StNode(current.node->left));
            }
        }
        std::cout << 'N' << '\n';
    }

    // Deletes every value > x.
    void delete_greater_than(int x,bstnode *&n)
    {
        if (n==nullptr)
            return;
        if (n->data<=x)
            delete_greater_than(x, n->right);
        else
        {
            delete_greater_than(x, n->right);
            delete_greater_than(x, n->left);
            remove(n->data, n);
        }
    }

    // Deletes every value < x.
    void delete_less_than(int x,bstnode *&n)
    {
        if (n==nullptr)
            return;
        delete_less_than(x, n->left);
        if (n->data >= x)
            return;  // the right subtree holds only values > n->data >= x
        delete_less_than(x, n->right);
        remove(n->data, n);
    }

    // Deletes every value v with a < v < b.
    void delete_interval(int a ,int b,bstnode *&n)
    {
        if (n==nullptr)
            return;
        // Left values are <= n->data and right values are > n->data, so one side may be skipped.
        if (n->data > a)
            delete_interval(a, b, n->left);
        if (n->data < b)
            delete_interval(a, b, n->right);
        if (n->data>a && n->data<b)
            remove(n->data, n);
    }

private:
    // Frees the subtree n and resets the link.
    void clear(bstnode *&n)
    {
        if (n == nullptr)
            return;
        clear(n->left);
        clear(n->right);
        delete n;
        n = nullptr;
    }
};
// Driver: reads each command word and dispatches to the bst, which prints query answers itself.
int main()
{
    int n,num,a,b;
    // "delete_greater_than" alone fills 20 bytes, so leave headroom and bound the read.
    char s[32];
    bst tree;
    cin >> n;
    for (int i = 0; i < n;i++)
    {
        cin.width(sizeof s);
        cin >> s;
        if(strcmp(s,"insert")==0){
            cin >> num;
            tree.insert(num, tree.root);
        }
        else if (strcmp(s,"delete")==0)
        {
            cin >> num;
            tree.remove(num, tree.root);
        }
        else if (strcmp(s,"find")==0)
        {
            cin >> num;
            tree.find(num);
        }
        else if (strcmp(s,"find_ith")==0)
        {
            cin >> num;
            tree.find_ith(num);
        }
        else if (strcmp(s,"delete_greater_than")==0)
        {
            cin >> num;
            tree.delete_greater_than(num, tree.root);
        }
        else if (strcmp(s,"delete_less_than")==0)
        {
            cin >> num;
            tree.delete_less_than(num, tree.root);
        }
        else if (strcmp(s,"delete_interval")==0)
        {
            cin >> a >> b;
            tree.delete_interval(a, b, tree.root);
        }
    }
    return 0;
}