0w1

CFR 337 D. Book of Evil ( Tree DP )

Problem - 337D - Codeforces

題意:
給一棵樹,和若干個指定節點,和 D。求有多少節點,滿足至最遠的指定節點的距離不大於 D。

數據範圍:
指定節點數 M,樹大小 N,1 ≤ M ≤ N ≤ 1e5
0 ≤ D < N

解法:
以 0 為有根樹進行 DP
dpd[ i ] : i 號節點與向下最遠的指定節點之距離
dpu[ i ] : i 號節點與向上最遠的指定節點之距離
dpd 的轉移不難。至於 dpu 需要考慮來自上方,和來自兄弟的。在節點 u 的立場,u 的孩子之間更新 dpu,兄弟的貢獻可以利用預處理記錄最大的兩個值 max1 和 max2 做更新。

時間 / 空間複雜度:
O( N )

需要注意:
dfs 需要分兩次,先 dpd 再 dpu。因為 dpu 在更新某個節點時,必須保證該節點的兄弟們的 dpd 也都正確。
想清楚什麼時候才要 dfs 下去。太早會出問題。例如第 52 行,如果再早一點點就會出事。能保證下面用來更新的東西都是對的才能執行 dfs。

int N, M, D;
vi P;
vvi G;

void init(){
    cin >> N >> M >> D;
    P = vi( M );
    for( int i = 0; i < M; ++i )
        cin >> P[ i ], --P[ i ];
    G = vvi( N );
    for( int i = 0; i < N - 1; ++i ){
        int u, v; cin >> u >> v; --u, --v;
        G[ u ].emplace_back( v );
        G[ v ].emplace_back( u );
    }
}

vi mark; // is affected
vi dpu; // distance to furthest affected, upwards
vi dpd; // downwards

void dfsd( int u, int fa ){
    if( mark[ u ] )
        upmax( dpd[ u ], 0 );
    for( int v : G[ u ] ){
        if( v == fa ) continue;
        dfsd( v, u );
        upmax( dpd[ u ], dpd[ v ] + 1 );
    }
}

void dfsu( int u, int fa ){
    if( mark[ u ] )
        upmax( dpu[ u ], 0 );
    int max1 = -INF, id1 = -1, max2 = -INF, id2 = -1;
    for( int v : G[ u ] ){
        if( v == fa ) continue;
        upmax( dpu[ v ], dpu[ u ] + 1 );
        if( max1 < dpd[ v ] )
            max2 = max1, id2 = id1,
            max1 = dpd[ v ], id1 = v;
        else if( max2 < dpd[ v ] )
            max2 = dpd[ v ], id2 = v;
    }
    for( int v : G[ u ] ){
        if( v == fa ) continue;
        if( id1 != v )
            upmax( dpu[ v ], max1 + 2 );
        else
            upmax( dpu[ v ], max2 + 2 );
    }
    for( int v : G[ u ] ){ // THIS IS THE LAST!!!
        if( v == fa ) continue;
        dfsu( v, u );
    }
}

void preprocess(){
    mark = vi( N );
    for( int i = 0; i < M; ++i )
        mark[ P[ i ] ] = 1;
    dpu = dpd = vi( N, -INF );
    dfsd( 0, -1 );
    dfsu( 0, -1 );
}

void solve(){
    int ans = 0;
    for( int i = 0; i < N; ++i )
        if( max( dpu[ i ], dpd[ i ] ) <= D )
            ++ans;
    cout << ans << endl;
}