0w1

Educational Codeforces Round 2, Problem E: Lomsat gelral (Smaller-to-Larger, or Timestamps + Mo's Algorithm)

Problem - E - Codeforces
The official solution simply merges subtrees from the leaves up to the root. With the smaller-to-larger strategy, where the smaller subtree's color map is always merged into the larger one, at most O( n lg n ) entries are moved from one map to another in total. Each move is an O( lg n ) map operation, so the overall time complexity is O( n lg^2 n ). I confused a node's index with its color when merging and spent a lot of time debugging. It is also important to be clear about which values must be swapped together with the maps when applying the smaller-to-larger strategy: here, the child's dominating-color sum and its maximum occurrence count.

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
const int MAXN = 1e5 + 5;
const int MAXC = MAXN;

int n;
int col[ MAXN ];         // col[ u ]: colour of node u (1-based)
vector<int> es[ MAXN ];  // adjacency lists of the tree

ll ans[ MAXN ];          // ans[ u ]: sum of all dominating colours in u's subtree

int Zmax_occ[ MAXN ];    // Zmax_occ[ u ]: occurrence count of the dominating colour(s) in u's subtree

// Post-order DFS over the subtree rooted at u (fa = parent, -1 for the root).
// Returns map[ c ] = number of nodes of colour c in u's subtree, and stores the
// subtree's statistics in ans[ u ] and Zmax_occ[ u ].
//
// Children are merged smaller-into-larger: when a child's map is bigger, it is adopted
// with an O( 1 ) swap together with that child's already-computed ans / Zmax_occ, and
// the smaller map is then inserted entry by entry.
// NOTE: recursion depth equals the tree height (up to n), so a large stack is assumed.
map<int, int> dfs(int u, int fa){
    map<int, int> col_occ;    // colour -> occurrence count
    col_occ[ col[ u ] ] = 1;
    int max_occ = 1;          // current dominating occurrence count
    ll cur_sum = col[ u ];    // sum of the colours that reach max_occ
    for(int v: es[ u ]){
        if( v == fa ) continue;
        map<int, int> son_occ = dfs( v, u );
        if( col_occ.size() < son_occ.size() ){
            // Adopt the larger child map; its statistics are already in ans[ v ] / Zmax_occ[ v ].
            swap( col_occ, son_occ );
            cur_sum = ans[ v ];
            max_occ = Zmax_occ[ v ];
        }
        for(const auto &entry: son_occ){
            // Single lookup per entry (operator[] inserts 0 for an unseen colour);
            // std::map references stay valid across later insertions.
            int &cnt = col_occ[ entry.first ];
            cnt += entry.second;
            if( cnt > max_occ ){
                max_occ = cnt;
                cur_sum = entry.first;
            } else if( cnt == max_occ ){
                cur_sum += entry.first; // this colour has just reached the maximum
            }
        }
    }
    ans[ u ] = cur_sum;
    Zmax_occ[ u ] = max_occ;
    return col_occ;
}

// Fills ans[] with the smaller-to-larger DFS rooted at node 1, then prints
// ans[ 1 .. n ] separated by spaces and terminated by a newline.
void solve(){
    dfs( 1, -1 ); // -1: the root has no parent; the returned map is not needed here
    for(int u = 1; u <= n; ++u){
        const char sep = ( u < n ) ? ' ' : '\n';
        printf("%lld%c", ans[ u ], sep);
    }
}

// Reads n, the colours of nodes 1..n, and the n - 1 undirected tree edges, then solves.
int main(){
    scanf("%d", &n);
    for(int node = 1; node <= n; ++node)
        scanf("%d", col + node); // nodes are 1-based
    for(int edge = 1; edge < n; ++edge){ // a tree on n nodes has n - 1 edges
        int a, b;
        scanf("%d%d", &a, &b);
        es[ a ].push_back( b );
        es[ b ].push_back( a );
    }
    solve();
    return 0;
}

Another way to solve this problem is to use Mo's algorithm. First flatten the tree with DFS timestamps. A node's pre-visit stamp is the left bound of its segment. The largest stamp assigned inside its subtree, recorded when the node is post-visited, is the right bound. As a result, every subtree becomes a contiguous segment. Note that the input col[ ] array must be re-mapped onto these segment positions for the queries to make sense. Adding or removing a single element takes O( 1 ) if we maintain three things: the frequency of each color, the number of colors having each frequency, and the sum of the colors having each frequency (the answer for that frequency).
Therefore, with n queries over n positions, Mo's algorithm gives an overall time complexity of O( n ^ 1.5 ).

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int MAXN = 1e5 + 5;
constexpr int MAXC = MAXN;

int n;
int _col[ MAXN ];        // colours as read from the input, indexed by node
int col[ MAXN ];         // colours re-indexed by DFS entry time: col[ in[ u ] ] == _col[ u ]
vector<int> es[ MAXN ];  // adjacency lists of the tree

int dfs_clock;                // last DFS entry time handed out
int in[ MAXN ], out[ MAXN ];  // the subtree of u occupies positions [ in[ u ], out[ u ] ]

// Pre-order DFS from u (parent fa) that assigns 1-based entry times and copies each
// node's colour to its entry position. out[ u ] is the largest entry time inside u's
// subtree, so that subtree is exactly the contiguous segment [ in[ u ], out[ u ] ].
void buildTimeStamp(int u, int fa){
    const int stamp = ++dfs_clock;
    in[ u ] = stamp;
    col[ stamp ] = _col[ u ];
    for(size_t k = 0; k < es[ u ].size(); ++k){
        const int child = es[ u ][ k ];
        if( child != fa )
            buildTimeStamp( child, u );
    }
    out[ u ] = dfs_clock;
}

int blk[ MAXN ];
struct Query{
    int id, ql, qr;
    Query(int _i = -1, int _l = -1, int _r = -1): id(_i), ql(_l), qr(_r){}
    bool operator < (const Query &oth) const{
        if( blk[ ql ] != blk[ oth.ql ] ) return blk[ ql ] < blk[ oth.ql ];
        return qr < oth.qr;
    }
} qry[ MAXN ];

ll ans[ MAXN ];

void mosAlgorithm(){
    int lb = 2, rb = 1; // [ lb, rb ]
    int max_occ = 0;
    vector<int> occ( MAXN ), occ_cnt( MAXN );
    vector<ll> sum_occ_cnt( MAXN ); // sum_occ_cnt[ i ]: sum of a particular occ_cnt
    for(int i = 1; i <= n; ++i){
        while( rb < qry[ i ].qr ){
            ++rb;
            sum_occ_cnt[ occ[ col[ rb ] ] ] -= col[ rb ];
            --occ_cnt[ occ[ col[ rb ] ] ];

            max_occ = max<int>( max_occ, ++occ[ col[ rb ] ] );

            ++occ_cnt[ occ[ col[ rb ] ] ];
            sum_occ_cnt[ occ[ col[ rb ] ] ] += col[ rb ];
        }
        while( rb > qry[ i ].qr ){
            sum_occ_cnt[ occ[ col[ rb ] ] ] -= col[ rb ];
            --occ_cnt[ occ[ col[ rb ] ] ];

            if( occ_cnt[ max_occ ] == 0 ) --max_occ;
            --occ[ col[ rb ] ];

            ++occ_cnt[ occ[ col[ rb ] ] ];
            sum_occ_cnt[ occ[ col[ rb ] ] ] += col[ rb ];
            --rb;
        }
        while( lb > qry[ i ].ql ){
            --lb;
            sum_occ_cnt[ occ[ col[ lb ] ] ] -= col[ lb ];
            --occ_cnt[ occ[ col[ lb ] ] ];
            
            max_occ = max<int>( max_occ, ++occ[ col[ lb ] ] );

            ++occ_cnt[ occ[ col[ lb ] ] ];
            sum_occ_cnt[ occ[ col[ lb ] ] ] += col[ lb ];
        }
        while( lb < qry[ i ].ql ){
            sum_occ_cnt[ occ[ col[ lb ] ] ] -= col[ lb ];
            --occ_cnt[ occ[ col[ lb ] ] ];

            if( occ_cnt[ max_occ ] == 0 ) --max_occ;
            --occ[ col[ lb ] ];

            ++occ_cnt[ occ[ col[ lb ] ] ];
            sum_occ_cnt[ occ[ col[ lb ] ] ] += col[ lb ];
            ++lb;
        }
        assert( occ_cnt[ max_occ ] > 0 );
        assert( sum_occ_cnt[ max_occ ] >  0 );
        ans[ qry[ i ].id ] = sum_occ_cnt[ max_occ ];
    }
}

void solve(){
    buildTimeStamp( 1, -1 );
    assert( dfs_clock == n ); // range[ 1 ]: [ 1, n ]
    for(int i = 1; i <= n; ++i)
        qry[ i ] = Query( i, in[ i ], out[ i ] );
    int BLK_SIZE = sqrt( n ) + 1;
    for(int i = 1; i <= n; ++i)
        blk[ i ] = ( i - 1 ) / BLK_SIZE;
    // sort( qry, qry + n ); OOPS
    sort( qry + 1, qry + 1 + n );
    mosAlgorithm();
    for(int i = 1; i <= n; ++i)
        cout << ans[ i ] << ( i == n ? '\n' : ' ' );
}

// Reads n, the input colours of nodes 1..n, and the n - 1 undirected tree edges, then solves.
int main(){
    ios::sync_with_stdio( false );
    cin >> n;
    for(int node = 1; node <= n; ++node)
        cin >> _col[ node ]; // nodes are 1-based
    for(int edge = 1; edge < n; ++edge){ // a tree on n nodes has n - 1 edges
        int a, b;
        cin >> a >> b;
        es[ a ].push_back( b );
        es[ b ].push_back( a );
    }
    solve();
    return 0;
}