0w1

Aho-Corasick, Multiple Pattern Matching

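Counts the number of occurrences of each pattern in the text. O( | T | + SIGMA( | P | ) + occ ), where occ is the total number of reported matches (ignoring the map lookups).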

// Text that is scanned for pattern occurrences; read from stdin in solve().
string T;
// Patterns in input order; resized and filled in solve().
// NOTE(review): vs is declared outside this section - presumably vector< string >.
vs pattern;
// Maps a character to its slot in node::nxt. solve() registers
// 'A'=1, 'C'=2, 'T'=3, 'G'=4; any other character looked up through
// operator[] gets the default value 0.
map< char, int > mci;

// One state of the trie / automaton.
//
// Both members use 0 as the "nothing here" marker, so every node must start
// out zeroed. The default member initializers guarantee that for any way a
// node is created, not only for value-initialized vector elements.
struct node{
    // Id written by insert() for the pattern ending at this state
    // (ids start at 1); 0 means no pattern ends here.
    int val = 0;
    // Outgoing edges indexed by the symbol code from mci; 0 means "no edge"
    // (state 0 is the root). get_fail() only scans slots 0..4, so slot 5 is
    // not used by the code in this section.
    int nxt[ 6 ] = {};
};

// Pool of automaton states; index 0 is the root.
vector< node > tree;
// Index of the next unused slot in tree (the root already occupies slot 0).
int cnt = 1;
// Number of patterns inserted so far; insert() stores ++num as the pattern id.
int num = 0;
// Not referenced by the code in this section.
int n;
// f[ u ]: failure link of state u, computed by get_fail().
vi f;
// lst[ u ]: nearest state on u's failure chain where a pattern ends (0 if none).
vi lst;
// occ_cnt[ id ]: number of occurrences counted for the pattern with that id.
vi occ_cnt;
// BFS work queue used by get_fail().
queue< int > que;
// Pattern text -> 1-based input position, filled in solve().
map< string, int > m;

int insert( const string &s ){
    int t = 0;
    for( int i = 0; i < s.size(); ++i ){
        int x = mci[ s[ i ] ];
        t = not tree[ t ].nxt[ x ] ? tree[ t ].nxt[ x ] = cnt++ : tree[ t ].nxt[ x ];
    }
    tree[ t ].val = ++num;
    return t;
}

// Computes failure links (f) and output links (lst) for every state by a
// breadth-first walk over the trie, using the global queue que.
//
// While visiting a state, each absent edge is overwritten with the matching
// edge of its failure state. After this function returns, a non-zero nxt
// entry therefore no longer implies a real trie child.
// NOTE(review): intended to run once, after all patterns have been inserted.
void get_fail(){
    f[ 0 ] = 0;
    // Depth-1 states always fall back to the root and have no output link.
    for( int sym = 0; sym < 5; ++sym ){
        const int child = tree[ 0 ].nxt[ sym ];
        if( child == 0 )
            continue;
        f[ child ] = 0;
        que.push( child );
        lst[ child ] = 0;
    }
    while( not que.empty() ){
        const int cur = que.front();
        que.pop();
        for( int sym = 0; sym < 5; ++sym ){
            const int child = tree[ cur ].nxt[ sym ];
            if( child == 0 ){
                // No real child: borrow the transition of the failure state.
                tree[ cur ].nxt[ sym ] = tree[ f[ cur ] ].nxt[ sym ];
                continue;
            }
            que.push( child );
            // Follow failure links until a state with an edge on sym, or the root.
            int anc = f[ cur ];
            while( anc != 0 and tree[ anc ].nxt[ sym ] == 0 )
                anc = f[ anc ];
            const int fail = tree[ anc ].nxt[ sym ];
            f[ child ] = fail;
            // Output link: the failure state itself if a pattern ends there,
            // otherwise whatever the failure state's output link is.
            if( tree[ fail ].val != 0 )
                lst[ child ] = fail;
            else
                lst[ child ] = lst[ fail ];
        }
    }
}

// Credits one occurrence to the pattern ending at state j and to every
// pattern reachable from j through the output links (lst). Stops when the
// chain reaches 0.
void add( int j ){
    for( ; j != 0; j = lst[ j ] )
        ++occ_cnt[ tree[ j ].val ];
}

void match( const string &s ){
    int j = 0;
    for( int i = 0; i < s.size(); ++i ){
        int t = mci[ s[ i ] ];
        while( j and not tree[ j ].nxt[ t ] )
            j = f[ j ];
        j = tree[ j ].nxt[ t ];
        if( tree[ j ].val )
            add( j );
        else if( lst[ j ] )
            add( lst[ j ] );
    }
}

void solve(){
    mci[ 'A' ] = 1, mci[ 'C' ] = 2, mci[ 'T' ] = 3, mci[ 'G' ] = 4;
    cin >> T;
    int N; cin >> N;
    pattern = vs( N );
    tree = vector< node >( 10000001 );
    f = vi( 10000001 );
    lst = vi( 10000001 );
    occ_cnt = vi( 10000001 );
    for( int i = 0; i < N; ++i )
        cin >> pattern[ i ],
        insert( pattern[ i ] ),
        m[ pattern[ i ] ] = i + 1;
    get_fail();
    match( T );
    for( int i = 0; i < N; ++i )
        cout << occ_cnt[ m[ pattern[ i ] ] ] << " \n"[ i + 1 == N ];
}