Template: RBST (randomized binary search tree) keyed by value
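An ordered set of ints that supports k-th smallest element (0-indexed) and range-sum queries. Duplicates are not allowed: inserting a value that is already present does nothing, and removing a value that is not present does nothing.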
// Randomized binary search tree (RBST) ordered by value, used as a set of
// distinct ints. Each node caches the size of its subtree (for order
// statistics) and the sum of the values in its subtree (for range sums).
struct rbst {
    int val, size;
    long long sum;
    rbst *lc, *rc;
    // Initializers listed in declaration order (val, size, sum, lc, rc).
    rbst(int v) : val(v), size(1), sum(v), lc(NULL), rc(NULL) {}
    // Recompute size and sum from the children; call after relinking them.
    void pull() {
        size = 1 + (lc ? lc->size : 0) + (rc ? rc->size : 0);
        sum = val;
        if (lc) sum += lc->sum;
        if (rc) sum += rc->sum;
    }
};

// Number of nodes in t (0 for an empty tree).
int get_size(rbst *t) { return t ? t->size : 0; }

// Split t by position: the first k nodes in order go to a, the rest to b.
void split(rbst *t, int k, rbst *&a, rbst *&b) {
    if (!t) {
        a = b = NULL;
        return;
    }
    if (k <= get_size(t->lc)) {
        b = t;
        split(t->lc, k, a, b->lc);
        b->pull();
    } else {
        a = t;
        split(t->rc, k - (get_size(t->lc) + 1), a->rc, b);
        a->pull();
    }
}

// xorshift64 generator used for merge decisions. rand() is avoided because
// RAND_MAX may be as small as 32767, which skews the size-proportional
// choice in merge once subtrees grow past that size.
static unsigned long long rbst_rand() {
    static unsigned long long x = 88172645463325252ULL;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    return x;
}

// Concatenate a and b (every value in a must precede every value in b).
// The root is picked at random with probability proportional to size.
rbst *merge(rbst *a, rbst *b) {
    if (!a || !b) return a ? a : b;
    unsigned long long total = (unsigned long long)(a->size + b->size);
    if (rbst_rand() % total < (unsigned long long)a->size) {
        a->rc = merge(a->rc, b);
        a->pull();
        return a;
    }
    b->lc = merge(a, b->lc);
    b->pull();
    return b;
}

// Number of stored values strictly less than v.
int lower_bound(rbst *t, int v) {
    int res = 0;
    while (t) {
        if (t->val >= v) {
            t = t->lc;
        } else {
            res += get_size(t->lc) + 1;
            t = t->rc;
        }
    }
    return res;
}

// Number of stored values less than or equal to v. This replaces
// lower_bound(t, v + 1), which overflows when v == INT_MAX.
int upper_bound(rbst *t, int v) {
    int res = 0;
    while (t) {
        if (t->val > v) {
            t = t->lc;
        } else {
            res += get_size(t->lc) + 1;
            t = t->rc;
        }
    }
    return res;
}

// True if v is stored in t. Plain BST search; the tree is not modified.
static bool contains(rbst *t, int v) {
    while (t && t->val != v) t = v < t->val ? t->lc : t->rc;
    return t != NULL;
}

// k-th smallest value, 0-indexed. Read-only walk; t is not modified.
// Requires 0 <= k < get_size(t); aborts otherwise.
int kth(rbst *t, int k) {
    if (k < 0 || k >= get_size(t)) std::abort();
    while (true) {
        int left = get_size(t->lc);
        if (k < left) {
            t = t->lc;
        } else if (k == left) {
            return t->val;
        } else {
            k -= left + 1;
            t = t->rc;
        }
    }
}

// Insert v. Returns false (and changes nothing) if v is already present.
bool insert(rbst *&t, int v) {
    if (contains(t, v)) return false;
    rbst *lhs, *rhs;
    split(t, lower_bound(t, v), lhs, rhs);
    t = merge(merge(lhs, new rbst(v)), rhs);
    return true;
}

// Remove v. Returns false if v is not present. The detached node is freed;
// it was allocated with new in insert().
bool remove(rbst *&t, int v) {
    if (!contains(t, v)) return false;
    rbst *lhs, *mid, *rhs;
    split(t, lower_bound(t, v), lhs, mid);
    split(mid, 1, mid, rhs);  // mid is now the single node holding v
    delete mid;
    t = merge(lhs, rhs);
    return true;
}

// Number of occurrences of v (0 or 1, since duplicates are rejected).
int count(rbst *t, int v) {
    return upper_bound(t, v) - lower_bound(t, v);
}

// Sum of stored values in the closed range [ql, qr]; 0 if ql > qr.
long long get_sum(rbst *&t, int ql, int qr) {
    if (ql > qr) return 0;
    int lidx = lower_bound(t, ql);
    int ridx = upper_bound(t, qr);
    // Positions [lidx, ridx) hold exactly the values in [ql, qr].
    rbst *lhs, *mid, *rhs;
    split(t, lidx, lhs, mid);
    split(mid, ridx - lidx, mid, rhs);
    long long res = mid ? mid->sum : 0;
    t = merge(lhs, merge(mid, rhs));
    return res;
}

// Print all values in increasing order, space separated (debug helper).
void print_all(rbst *t) {
    if (!t) return;
    print_all(t->lc);
    std::cout << t->val << " ";
    print_all(t->rc);
}

const int MOD = 1e9 + 7;

// Read N commands from stdin and apply them to an initially empty set:
//   INSERT v, REMOVE v, FIND v (prints "Found" / "Not found"),
//   GET_SUM ql qr (prints the sum of stored values in [ql, qr]).
// Every input value is first reduced with % MOD.
void solve() {
    rbst *root = NULL;
    int N;
    std::cin >> N;
    for (int i = 0; i < N; ++i) {
        std::string op;
        std::cin >> op;
        if (op == "INSERT") {
            int v;
            std::cin >> v;
            v = (1LL * v) % MOD;
            insert(root, v);
        } else if (op == "REMOVE") {
            int v;
            std::cin >> v;
            v = (1LL * v) % MOD;
            remove(root, v);
        } else if (op == "FIND") {
            int v;
            std::cin >> v;
            v = (1LL * v) % MOD;
            if (count(root, v))
                std::cout << "Found" << '\n';
            else
                std::cout << "Not found" << '\n';
        } else if (op == "GET_SUM") {
            int ql, qr;
            std::cin >> ql >> qr;
            ql = (1LL * ql) % MOD;
            qr = (1LL * qr) % MOD;
            std::cout << get_sum(root, ql, qr) << '\n';
        }
    }
}