0w1

BZOJ 1036 树的统计 (tree statistics) — solved with tree square-root decomposition (樹平方分割)

Problem 1036. -- [ZJOI2008]树的统计Count
このテクすごいね、猛練習しよう。(This technique is impressive — worth practicing until it becomes second nature.)

#include <bits/stdc++.h>
using namespace std;
// Capacities for the fixed-size global tables (node ids are 1-based).
constexpr int MAXN = 30000 + 4;   // per-node arrays; sized for n up to about 3e4
constexpr int MAXQ = 200000 + 5;  // not used by the code below
constexpr int INF = 0x3f3f3f3f;   // -INF seeds every running maximum

int n;  // number of nodes in the tree

std::vector<int> es[MAXN];   // es[u]: every neighbour of u (undirected input edges)
std::vector<int> bes[MAXN];  // bes[u]: children of u that were placed in u's block

int par[MAXN];   // parent of u in the tree rooted at 1 (-1 for the root)
int dpt[MAXN];   // depth of u; the root has depth 0
int sz[MAXN];    // node count of a block, stored at the block's head
int head[MAXN];  // topmost node of u's block; 0 means "not assigned yet"
int w[MAXN];     // current weight of node u
int sum[MAXN];   // sum of weights along the in-block path head[u] .. u
int mxw[MAXN];   // maximum weight along the in-block path head[u] .. u
int BS;          // block-size threshold: a block accepts a child while size + 1 < BS

// Roots the tree at u (parent fa, depth d) and greedily packs nodes into
// blocks. A child joins its parent's block while that block still satisfies
// size + 1 < BS; otherwise the child later starts a block of its own.
// Children that joined the parent's block are recorded in bes[parent].
void build(int u, int fa, int d){
    par[u] = fa;
    dpt[u] = d;
    if (head[u] == 0) {
        // u was not absorbed into its parent's block: it heads a new one.
        head[u] = u;
        sz[u] = 1;
    }
    for (int v : es[u]) {
        if (v == fa) {
            continue;  // skip the edge back to the parent
        }
        const int h = head[u];
        if (sz[h] + 1 < BS) {
            bes[u].push_back(v);
            head[v] = h;
            ++sz[h];
        }
        build(v, u, d + 1);
    }
}

// Refreshes the in-block prefix aggregates for u and every node below it in
// the same block. s and m are the weight sum and weight maximum accumulated
// from the block head down to u's parent (0 and -INF when u is the head).
void dfs(int u, int s, int m){
    const int here_sum = s + w[u];
    const int here_max = (w[u] > m) ? w[u] : m;
    sum[u] = here_sum;
    mxw[u] = here_max;
    for (int child : bes[u]) {
        dfs(child, here_sum, here_max);
    }
}

// Sets node u's weight to v, then recomputes the prefix sum/max of u and of
// every node beneath u inside u's block.
void update(int u, int v){
    w[u] = v;
    // A block head starts from empty aggregates; any other node continues
    // from its parent, which lives in the same block.
    const bool starts_block = (head[u] == u);
    const int base_sum = starts_block ? 0 : sum[par[u]];
    const int base_max = starts_block ? -INF : mxw[par[u]];
    dfs(u, base_sum, base_max);
}

// Returns the maximum node weight on the tree path u .. v, both ends included.
// While the endpoints sit in different blocks, the one whose block head is
// deeper skips the rest of its block via the precomputed prefix maximum.
// Once both share a block, the deeper endpoint climbs one edge at a time.
int qmax(int u, int v){
    int best = -INF;
    while (u != v) {
        const bool same_block = (head[u] == head[v]);
        // Arrange for u to be the endpoint that has to move this round.
        const bool v_moves = same_block ? (dpt[u] < dpt[v])
                                        : (dpt[head[u]] < dpt[head[v]]);
        if (v_moves) {
            std::swap(u, v);
        }
        if (same_block) {
            best = std::max(best, w[u]);
            u = par[u];
        } else {
            best = std::max(best, mxw[u]);
            u = par[head[u]];
        }
    }
    // u == v is the meeting point (the LCA); its weight is still pending.
    return std::max(best, w[u]);
}

// Returns the sum of node weights on the tree path u .. v, both ends included.
// Whole block stretches are added at once via the in-block prefix sums.
// Inside a shared block, the walk proceeds edge by edge toward the meeting node.
int qsum(int u, int v){
    int total = 0;
    while (u != v) {
        if (head[u] != head[v]) {
            // Move whichever endpoint has the deeper block head, covering the
            // stretch head[u] .. u in one step.
            if (dpt[head[u]] < dpt[head[v]]) {
                std::swap(u, v);
            }
            total += sum[u];
            u = par[head[u]];
            continue;
        }
        // Same block: lift the deeper endpoint by a single edge.
        if (dpt[u] < dpt[v]) {
            std::swap(u, v);
        }
        total += w[u];
        u = par[u];
    }
    // Add the weight of the meeting node (the LCA) exactly once.
    return total + w[u];
}

// Reads n, the n - 1 tree edges, the n initial node weights and q queries
// from stdin. "CHANGE u t" sets node u's weight to t; "QMAX u v" and
// "QSUM u v" print the path maximum / path sum between u and v, one per line.
// Exits with status 1 on a malformed header, edge list, weight list or node
// id; a query list that ends early simply stops processing.
int main(){
    // Node ids are 1-based and must fit the fixed-size global tables.
    const int capacity = static_cast<int>( sizeof( es ) / sizeof( es[ 0 ] ) );
    if( scanf("%d", &n) != 1 || n < 1 || n >= capacity ) return 1;
    BS = sqrt( n ) + 1;
    for(int i = 0; i < n - 1; ++i){
        int u, v;
        if( scanf("%d%d", &u, &v) != 2 ) return 1;
        if( u < 1 || u > n || v < 1 || v > n ) return 1;
        es[ u ].push_back( v );
        es[ v ].push_back( u );
    }
    build( 1, -1, 0 );
    for(int i = 1; i <= n; ++i){
        // Named `weight` so it does not shadow the global weight table w[].
        int weight;
        if( scanf("%d", &weight) != 1 ) return 1;
        update( i, weight );
    }
    int q;
    if( scanf("%d", &q) != 1 ) return 1;
    // `> 0` keeps a negative count from looping (and overflowing) forever.
    while( q-- > 0 ){
        char op[ 10 ];
        int u, v;
        // %9s bounds the token to op's capacity (9 chars + terminator); on a
        // short read op/u/v would be unusable, so stop instead.
        if( scanf("%9s%d%d", op, &u, &v) != 3 ) break;
        const bool is_change = !strcmp( op, "CHANGE" );
        // For CHANGE the second number is a weight, not a node id.
        if( u < 1 || u > n || ( !is_change && ( v < 1 || v > n ) ) ) return 1;
        if( is_change ){
            update( u, v );
        } else if( !strcmp( op, "QMAX" ) ){
            printf("%d\n", qmax( u, v ));
        } else if( !strcmp( op, "QSUM" ) ){
            printf("%d\n", qsum( u, v ));
        }
    }
    return 0;
}