Codeforces Educational Round 2, Problem E: Lomsat gelral (Smaller-to-Larger Merging, or Timestamps + Mo's Algorithm)
Problem - E - Codeforces
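The official solution simply merges subtrees from the leaves up to the root. It uses the smaller-to-larger strategy, where the smaller set is always merged into the larger one. The overall time complexity is O( n lg n * lg n ) = O( n lg^2 n ). Each time an element moves, the set containing it at least doubles in size, so each element moves at most lg n times, for at most n lg n moves in total. Each move costs O( lg n ) in a std::map. I confused a node's index with its color while merging and spent a lot of time debugging. It is also important to be clear about which values must be swapped along with the sets when applying the smaller-to-larger strategy: the cached maximum frequency and the dominating-color sum must follow the map they describe.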
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

const int MAXN = 1e5 + 5;

int n;
int col[ MAXN ];          // col[ u ]: color of node u (1-based)
vector<int> es[ MAXN ];   // adjacency lists of the undirected tree
ll ans[ MAXN ];           // ans[ u ]: sum of all dominating colors in subtree u
int Zmax_occ[ MAXN ];     // Zmax_occ[ u ]: highest color frequency in subtree u

// Builds the color multiset of the subtree rooted at u, returned as
// map[ color ] = number of nodes of that color in the subtree.
// Side effects: sets ans[ u ] and Zmax_occ[ u ].
// Children are merged small-to-large: whichever map is larger is kept and the
// smaller one is folded into it entry by entry.
// NOTE(review): recursion depth equals the tree height (up to n); this assumes
// a large stack such as the judge provides — confirm for other environments.
map<int, int> dfs(int u, int fa){
    map<int, int> col_occ;              // color -> occurrences in subtree u
    int max_occ = col_occ[ col[ u ] ] = 1;
    ll cur_sum = col[ u ];              // sum of colors whose count == max_occ
    for(int v: es[ u ]){
        if( v == fa ) continue;
        map<int, int> son_occ = dfs( v, u );
        if( col_occ.size() < son_occ.size() ){
            // Keep the child's larger map. Its cached statistics (ans[ v ],
            // Zmax_occ[ v ]) describe exactly that map, so they must move with
            // it; the old accumulation is re-merged below.
            swap( col_occ, son_occ );
            cur_sum = ans[ v ];
            max_occ = Zmax_occ[ v ];
        }
        for(const auto &kv: son_occ){
            // A single map lookup per merged color (previously up to three).
            int &cnt = col_occ[ kv.first ];
            cnt += kv.second;
            if( cnt > max_occ ){
                // New strictly-highest frequency: this color alone dominates.
                max_occ = cnt;
                cur_sum = kv.first;
            } else if( cnt == max_occ ){
                // kv.second >= 1, so this color was below max_occ before and
                // has not been counted in cur_sum yet.
                cur_sum += kv.first;
            }
        }
    }
    ans[ u ] = cur_sum;
    Zmax_occ[ u ] = max_occ;
    return col_occ;
}

// Runs the DFS from root 1 and prints ans[ 1 .. n ] separated by spaces.
void solve(){
    dfs( 1, -1 );
    for(int i = 1; i <= n; ++i)
        printf("%lld%c", ans[ i ], i == n ? '\n' : ' ');
}

// Reads n, the colors of nodes 1..n, and n - 1 undirected edges, then solves.
int main(){
    scanf("%d", &n);
    for(int i = 1; i <= n; ++i) // 1 based idx
        scanf("%d", &col[ i ]);
    for(int i = 0; i < n - 1; ++i){
        int x, y;
        scanf("%d%d", &x, &y);
        es[ x ].push_back( y );
        es[ y ].push_back( x );
    }
    solve();
    return 0;
}
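Another way to solve this problem is with Mo's algorithm. Assign DFS timestamps to every node. A node's pre-visit (entry) stamp is the left bound of its subtree's segment. The last stamp assigned inside its subtree, recorded when the node is exited, is the right bound. Every subtree therefore becomes a contiguous range of the visiting order. Note that the col[ ] array from the input must be re-indexed by timestamp, so that position i of the segment holds the color of the i-th visited node; otherwise the ranges are meaningless. Both expanding and shrinking the window by one element take O( 1 ). This requires maintaining the frequency of each color, the number of colors having each frequency, the sum of the colors having each frequency, and the current maximum frequency.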
Therefore, with Mo's algorithm the overall time complexity is O( n ^ 1.5 ), i.e. O( n sqrt n ), since there are n queries over a segment of length n.
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

const int MAXN = 1e5 + 5;

int n;
int _col[ MAXN ];  // input col, indexed by node
int col[ MAXN ];   // col mapped from tree to segment (indexed by entry time)
vector<int> es[ MAXN ];

int dfs_clock;
int in[ MAXN ], out[ MAXN ]; // TIMESTAMP: subtree u -> [ in[ u ], out[ u ] ]

// Assigns 1-based entry timestamps in pre-order and records in out[ u ] the
// last timestamp used inside u's subtree, so the subtree of u occupies the
// contiguous segment [ in[ u ], out[ u ] ]. Copies each node's color to its
// segment position.
void buildTimeStamp(int u, int fa){
    in[ u ] = ++dfs_clock;
    col[ in[ u ] ] = _col[ u ];
    for(int v: es[ u ]){
        if( v == fa ) continue;
        buildTimeStamp( v, u );
    }
    out[ u ] = dfs_clock;
}

int blk[ MAXN ]; // blk[ i ]: Mo block index of segment position i

// One subtree query: the answer for node `id` is computed over [ ql, qr ].
struct Query{
    int id, ql, qr;
    Query(int _i = -1, int _l = -1, int _r = -1): id(_i), ql(_l), qr(_r){}
    // Mo's order: by block of the left end, then by right end.
    bool operator < (const Query &oth) const{
        if( blk[ ql ] != blk[ oth.ql ] ) return blk[ ql ] < blk[ oth.ql ];
        return qr < oth.qr;
    }
} qry[ MAXN ];

ll ans[ MAXN ];

// Answers the already-sorted queries qry[ 1 .. n ] with Mo's algorithm and
// stores the sum of dominating colors for query id in ans[ id ].
void mosAlgorithm(){
    int lb = 1, rb = 0;             // current window [ lb, rb ], initially empty
    int max_occ = 0;                // highest color frequency inside the window
    vector<int> occ( MAXN );        // occ[ c ]: occurrences of color c
    vector<int> occ_cnt( MAXN );    // occ_cnt[ k ]: number of colors with occ == k
    vector<ll> sum_occ_cnt( MAXN ); // sum_occ_cnt[ k ]: sum of colors with occ == k
    // Bucket 0 is updated too but never read while the window is non-empty.

    // Adds segment position pos: its color moves from bucket k to bucket k + 1.
    auto insertAt = [&](int pos){
        int c = col[ pos ];
        sum_occ_cnt[ occ[ c ] ] -= c;
        --occ_cnt[ occ[ c ] ];
        max_occ = max( max_occ, ++occ[ c ] );
        ++occ_cnt[ occ[ c ] ];
        sum_occ_cnt[ occ[ c ] ] += c;
    };
    // Removes segment position pos: its color moves from bucket k to k - 1.
    // One removal lowers the true maximum by at most one, so a single
    // decrement keeps max_occ exact.
    auto eraseAt = [&](int pos){
        int c = col[ pos ];
        sum_occ_cnt[ occ[ c ] ] -= c;
        --occ_cnt[ occ[ c ] ];
        if( occ_cnt[ max_occ ] == 0 ) --max_occ;
        --occ[ c ];
        ++occ_cnt[ occ[ c ] ];
        sum_occ_cnt[ occ[ c ] ] += c;
    };

    for(int i = 1; i <= n; ++i){
        // Grow both ends before shrinking either, so the window never
        // becomes inverted (which would drive counts negative) regardless of
        // how the queries are ordered.
        while( lb > qry[ i ].ql ) insertAt( --lb );
        while( rb < qry[ i ].qr ) insertAt( ++rb );
        while( lb < qry[ i ].ql ) eraseAt( lb++ );
        while( rb > qry[ i ].qr ) eraseAt( rb-- );
        assert( occ_cnt[ max_occ ] > 0 );
        assert( sum_occ_cnt[ max_occ ] > 0 );
        ans[ qry[ i ].id ] = sum_occ_cnt[ max_occ ];
    }
}

// Builds timestamps, turns each node into a segment query, sorts the queries
// in Mo's order and prints the answers for nodes 1..n.
void solve(){
    buildTimeStamp( 1, -1 );
    assert( dfs_clock == n ); // range[ 1 ]: [ 1, n ]
    for(int i = 1; i <= n; ++i)
        qry[ i ] = Query( i, in[ i ], out[ i ] );
    int BLK_SIZE = sqrt( n ) + 1;
    for(int i = 1; i <= n; ++i)
        blk[ i ] = ( i - 1 ) / BLK_SIZE;
    sort( qry + 1, qry + 1 + n ); // queries are stored 1-based
    mosAlgorithm();
    for(int i = 1; i <= n; ++i)
        cout << ans[ i ] << ( i == n ? '\n' : ' ' );
}

// Reads n, the colors of nodes 1..n, and n - 1 undirected edges, then solves.
int main(){
    ios::sync_with_stdio( false );
    cin.tie( nullptr );
    cin >> n;
    for(int i = 1; i <= n; ++i)
        cin >> _col[ i ];
    for(int i = 0; i < n - 1; ++i){
        int x, y;
        cin >> x >> y;
        es[ x ].push_back( y );
        es[ y ].push_back( x );
    }
    solve();
    return 0;
}