0w1

Template RBST by Value

A set that supports k-th element queries (kth). Duplicate values are not stored: insert ignores a value that is already present, and remove deletes its single copy.

// Node of a randomized binary search tree (RBST) whose in-order traversal is
// sorted by value. Each node caches aggregates of the subtree rooted at it.
struct rbst{
    int val, size;   // value stored in this node; number of nodes in this subtree
    ll sum;          // sum of all values in this subtree
    rbst *lc, *rc;   // left / right children, NULL when absent
    // Initializers are listed in declaration order (val, size, sum, lc, rc);
    // members are always initialized in that order, and -Wreorder flags a mismatch.
    rbst( int _v ) : val( _v ), size( 1 ), sum( _v ), lc( NULL ), rc( NULL ){}
    // Recomputes `size` and `sum` from this node and its children.
    // Must be called whenever `lc` or `rc` changes.
    void pull(){
        sum = val;
        sum += lc ? lc->sum : 0;
        sum += rc ? rc->sum : 0;
        size = 1;
        size += lc ? lc->size : 0;
        size += rc ? rc->size : 0;
    }
};

// Number of nodes in the subtree rooted at `t`; an empty tree (NULL) has size 0.
int get_size( rbst *t ){
    if( t == NULL ) return 0;
    return t->size;
}

// Splits the tree rooted at `t` by position: afterwards `a` holds the first
// k nodes (in-order) and `b` holds the rest. Either output may be NULL.
// Nodes are relinked in place; no allocation happens.
void split( rbst *t, int k, rbst *&a, rbst *&b ){
    if( !t ){
        a = b = NULL;
        return;
    }
    const int left_size = get_size( t->lc );
    if( k > left_size ){
        // The root and its whole left subtree belong to the first part;
        // the remainder of the cut happens inside the right subtree.
        a = t;
        split( t->rc, k - left_size - 1, a->rc, b );
        a->pull();
    } else {
        // The root belongs to the second part; cut inside the left subtree.
        b = t;
        split( t->lc, k, a, b->lc );
        b->pull();
    }
}

// Joins two trees where every element of `a` precedes every element of `b`
// in order, and returns the new root. The root is taken from `a` with
// probability a->size / (a->size + b->size), which is what keeps an RBST
// randomly balanced.
rbst* merge( rbst *a, rbst *b ){
    if( not a or not b ) return a ? a : b;
    // rand() is only guaranteed to reach RAND_MAX, which may be as small as
    // 32767. A single call cannot choose fairly among more positions than that,
    // and the root choice would degenerate for large trees. Combine two calls
    // into a wider random number; they are sequenced explicitly.
    unsigned long long r = ( unsigned long long )rand();
    r = r * ( ( unsigned long long )RAND_MAX + 1 ) + ( unsigned long long )rand();
    const unsigned long long total = ( unsigned long long )( a->size + b->size );
    if( r % total < ( unsigned long long )a->size ){
        a->rc = merge( a->rc, b );
        a->pull();
        return a;
    }
    b->lc = merge( a, b->lc );
    b->pull();
    return b;
}

// Returns how many stored values are strictly less than `v`. This equals the
// 0-based in-order index of the first value >= v, or the tree size if none exists.
int lower_bound( rbst *t, int v ){
    int below = 0;
    while( t ){
        if( t->val < v ){
            // This node and its entire left subtree are < v.
            below += get_size( t->lc ) + 1;
            t = t->rc;
        } else {
            t = t->lc;
        }
    }
    return below;
}

// Returns the value at 0-based in-order position `k`.
// Precondition: 0 <= k < get_size( t ); otherwise a NULL node is dereferenced.
// The tree is split around position k and merged back, so the root may
// change (hence `t` is passed by reference).
int kth( rbst *&t, int k ){
    rbst *before, *target, *after;
    split( t, k, before, target );
    split( target, 1, target, after );
    const int found = target->val;
    t = merge( merge( before, target ), after );
    return found;
}

// Inserts `v` at its sorted position. Returns false (leaving the tree
// unchanged) if `v` is already present, true otherwise. The new node is
// allocated with `new`.
bool insert( rbst *&t, int v ){
    const int pos = lower_bound( t, v );
    // pos indexes the first value >= v; if that value is v itself, it is a duplicate.
    const bool present = pos < get_size( t ) and kth( t, pos ) == v;
    if( present ) return false;
    rbst *left;
    split( t, pos, left, t );
    rbst *node = new rbst( v );
    t = merge( merge( left, node ), t );
    return true;
}

// Removes `v` from the tree if present. Returns true if a node was removed,
// false if `v` was not found. Only one matching node is removed.
// The detached node is freed; this assumes every node was allocated with
// `new` (as insert does) — NOTE(review): confirm no other allocator is used.
bool remove( rbst *&t, int v ){
    int idx = lower_bound( t, v );
    // idx indexes the first value >= v; v is present only if that value equals v.
    if( idx == get_size( t ) or kth( t, idx ) != v ) return false; // not found
    rbst *tl, *node, *tr;
    split( t, idx, tl, node );   // tl: values < v; node: everything from v on
    split( node, 1, node, tr );  // node: the single node holding v
    delete node;                 // release it so removal does not leak memory
    t = merge( tl, tr );
    return true;
}

// Returns whether `v` is stored in the tree rooted at `t`. Values are unique,
// so this is the count of `v` (0 or 1). A plain BST search avoids computing
// `v + 1`, which would be signed-overflow UB for v == INT_MAX. It also never
// restructures the tree.
bool count( rbst *t, int v ){
    while( t ){
        if( v == t->val ) return true;
        t = v < t->val ? t->lc : t->rc;
    }
    return false;
}

// Returns the sum of stored values v with ql <= v <= qr (0 if ql > qr).
// Computed as sum(values <= qr) - sum(values < ql) using the cached subtree
// sums. This walks two root-to-leaf paths without restructuring the tree and
// avoids `qr + 1`, which would overflow for qr == INT_MAX. `t` stays a
// reference only for interface compatibility; it is not modified.
ll get_sum( rbst *&t, int ql, int qr ){
    if( ql > qr ) return 0;
    ll upto_qr = 0;   // sum of values <= qr
    for( rbst *cur = t; cur; ){
        if( cur->val <= qr ){
            // cur and its whole left subtree are <= qr (in-order is sorted).
            upto_qr += cur->val + ( cur->lc ? cur->lc->sum : 0 );
            cur = cur->rc;
        } else {
            cur = cur->lc;
        }
    }
    ll below_ql = 0;  // sum of values < ql
    for( rbst *cur = t; cur; ){
        if( cur->val < ql ){
            below_ql += cur->val + ( cur->lc ? cur->lc->sum : 0 );
            cur = cur->rc;
        } else {
            cur = cur->lc;
        }
    }
    return upto_qr - below_ql;
}

// Debug helper: writes every stored value in sorted (in-order) order to
// stdout, each followed by a space.
void print_all( rbst *t ){
    // Recurse into left subtrees; walk the right spine iteratively.
    for( ; t; t = t->rc ){
        print_all( t->lc );
        cout << t->val << " ";
    }
}

const int MOD = 1000000007; // 1e9 + 7, written as an exact integer literal

// Reads N operations from stdin and applies them to a single set:
//   INSERT v, REMOVE v, FIND v  -- v is first reduced with `% MOD`
//   GET_SUM ql qr               -- both bounds reduced with `% MOD`
// FIND prints "Found"/"Not found"; GET_SUM prints the range sum.
// Unrecognized operation names are ignored.
void solve(){
    rbst *root = NULL;
    int N;
    cin >> N;
    for( int i = 0; i < N; ++i ){
        string op;
        cin >> op;
        if( op == "GET_SUM" ){
            int ql, qr;
            cin >> ql >> qr;
            ql = ( 1LL * ql ) % MOD;
            qr = ( 1LL * qr ) % MOD;
            cout << get_sum( root, ql, qr ) << endl;
            continue;
        }
        if( op != "INSERT" and op != "REMOVE" and op != "FIND" ) continue;
        // The remaining operations all take one value argument.
        int v;
        cin >> v;
        v = ( 1LL * v ) % MOD;
        if( op == "INSERT" ) insert( root, v );
        else if( op == "REMOVE" ) remove( root, v );
        else cout << ( count( root, v ) ? "Found" : "Not found" ) << endl;
    }
}