前言

$kuangbin$的$ac$自动机还是比较基础的,大多数题多是状态机模型的$dp$,其中有两道需要用到矩阵优化来记录路径数,然后是关于在tire图上找最短路问题,还有些要状压等等,用于字符串匹配的题很少,废话不多说,先大致分个类吧。


状态机$dp$

前言:

状态机$dp$​主要为建立AC自动机,开一维AC自动机的状态,在AC自动机上计数或求最大值等等。通常用记忆化能写,但记忆化不能压空间,也不能优化(比如:矩阵优化),这是只能递推写,递推时要注意设初始状态和不合法的状态。

一般在状态机求答案会有这样限制

不包含模板串的串,即在AC自动机上标记遇到这个状态不转移。

例如:G - Censored!J - DNA repairE - DNA Sequence

包含某些模式串,这时我们需要都开一维记录其选取情况,有时需要状压,或者我们可以正难则反,用总数-不包含模式串的数量。

例如:H - Wireless PasswordN - Walk Through SquaresF - 考研路茫茫——单词情结


G - Censored!(计数+大数)

题意:

n,m,p。n个字符,p个串,走m步,求不包含串的方案。

思路:

跑个ac自动机,将禁用串标记,然后$dp$​时跳过即可。

#include <iostream>
#include <cstring>
#include <queue>
#include <vector>
#include <map>
using namespace std;
typedef long long LL;
const int Base=1000;
class BigNum
{
    public:
        int num[100], len;
        BigNum():len(0) {}
        BigNum(int n):len(0)
        {
            for( ; n>0; n/=Base)
                num[len++]=n%Base;
        }
        BigNum Bigvalueof(LL n)
        {
            len=0;
            while(n)
            {
                num[len++]=n%Base;
                n/=Base;
            }
            return *this;
        }
        BigNum operator+(const BigNum& b)
        {
            BigNum c;
            int i, carry=0;
            for(i=0; i<this->len || i<b.len || carry>0; ++i)
            {
                if(i<this->len)
                    carry+=this->num[i];
                if(i<b.len)
                    carry+=b.num[i];
                c.num[i]=carry%Base;
                carry/=Base;
            }
            c.len=i;
            return c;
        }
        BigNum operator +=(const BigNum& b)
        {
            *this=*this+b;
            return *this;
        }
        void Print()
        {
            if(len==0)
            {
                puts("0");
                return ;
            }
            printf("%d", num[len-1]);
            for(int i=len-2; i>=0; --i)
            {
                for(int j=Base/10; j>0; j/=10)
                {
                    printf("%d", num[i]/j%10);
                }
            }
            puts("");
        }
};
int n,m,k,tot;
int fail[110],col[110],tr[110][60];
char buf[60];
char ta[60];
map <char,int> ma;
BigNum f[60][110];
void insert(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=ma[s[i]];
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];

    }
    col[p]=1;
}

void build(){
    queue <int> q;
    for(int i=1;i<=n;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=1;i<=n;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]|=col[fail[p]];
    }
}
int main(){
    scanf("%d %d %d",&n,&m,&k);
    scanf("%s",ta+1);
    for(int i=1;i<=n;i++) ma[ta[i]]=i;
    for(int i=1;i<=k;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf));
    }
    build();
    for(int i=0;i<=m;i++){
        for(int j=0;j<=tot;j++) f[i][j].Bigvalueof(0);
    }
    f[0][0].Bigvalueof(1);
    //f[0][0].push_back(1);
    for(int i=0;i<m;i++){
        for(int j=0;j<=tot;j++){
            for(int k=1;k<=n;k++){
                int z=tr[j][k];
                if(col[z]) continue;
                f[i+1][z]=f[i+1][z]+f[i][j];
            }
        }
    }
    BigNum res(0);
    for(int i=0;i<=tot;i++){
        if(col[i]) continue;
        res=res+f[m][i];
    }
    res.Print();

}

H - Wireless Password(计数+状压)

题意:

给你m个串,问有多少个长度为n个串包含恰好m个串里面k个串,可以重叠

思路:

由于问你要包含m个串里面k个,那么我们要将其m串的选取情况状压再开一维到$dp$里面,然后就常规的跑ac自动机,将ac自动机每个位置选取情况标记,然后dp计数即可。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
const int mod=20090717;
int n,m,k,tot;
int col[110],tr[110][26],fail[110];
int f[30][105][1100];
char buf[15];
void init(){

    for(int i=0;i<=tot;i++){
        col[i]=fail[i]=0;
        memset(tr[i],0,sizeof tr[i]);
    }
    tot=0;
    memset(f,-1,sizeof f);
}
void insert(char s[],int sz,int x){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]=col[p]|(1<<x);
}
void build(){
    queue <int> q;
    for(int i=0;i<26;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<26;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]|=col[fail[p]];
    }
}
int check(int x){
    int sum=0;
    while(x){
        sum=sum+(x&1);
        x>>=1;
    }
    return sum>=k;
}
int dfs(int p,int s,int st){

    if(p>n){
        if(check(st)) return 1;
        return 0;
    }
    if(f[p][s][st]!=-1) return f[p][s][st];
    int ans=0;
    for(int i=0;i<26;i++){
        ans=(ans+dfs(p+1,tr[s][i],st|col[tr[s][i]]))%mod;
    }
    return f[p][s][st]=ans;
}
void solve(){
    init();
    for(int i=1;i<=m;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf),i-1);
    }
    build();
    printf("%d\n",dfs(1,0,0));
}
int main(){
    while(scanf("%d %d %d",&n,&m,&k)&&(n||m||k)) solve();

}

I - Ring

题意:

给出m个串以及每个串的价值,现在让你找到一个长度不超过n的串,使得其价值最大。若有多个串满足,输出所有串中最短的,若依旧有多个串满足,输出字典序最小的串。

思路:

跑ac自动机,在自动机上标记每个点的价值,最后dp一边找个max,由于问的是长度不超过n的,需要从$[1,n]$都跑一边,或者再开一维,每次都返回是可以的。

代码:
#include <iostream>
#include <cstring>
#include <queue>
using namespace std;
const int N=55,M=1200;
int n,m,tot,T,w[110];
int f[N][M];
int tr[M][26],fail[M],col[M];
char buf[110][20];
void init(){
    for(int i=0;i<=tot;i++){
        memset(tr[i],0,sizeof tr[i]);
        fail[i]=col[i]=0;
    }
    memset(f,-1,sizeof f);
    tot=0;
}
void insert(char s[],int sz,int v){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]=v;
}
void build(){
    queue <int> q;
    for(int i=0;i<26;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<26;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];

                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]=col[p]+col[fail[p]];
    }
}
int dfs(int p,int s){
    if(p>n){
        return 0;
    }
    if(f[p][s]!=-1) return f[p][s];
    int ans=0;
    for(int i=0;i<26;i++){
        int v=tr[s][i];
        ans=max(ans,dfs(p+1,v)+col[v]);
    }
    return f[p][s]=ans;
}
void print(int p,int s){
    if(p>n) return;
    for(int i=0;i<26;i++){
        int v=tr[s][i];
        if(f[p][s]==dfs(p+1,v)+col[v]){
            printf("%c",'a'+i);
            print(p+1,v);
            break;
        }
    }
}
void solve(){
    scanf("%d %d",&n,&m);
    init();
    for(int i=1;i<=m;i++){
        scanf("%s",buf[i]);
    }
    for(int i=1;i<=m;i++){
        scanf("%d",&w[i]);
        insert(buf[i],strlen(buf[i]),w[i]);
    }
    build();
    int ans=dfs(1,0);
    //cout<<"ans="<<ans<<"\n";
    if(ans==0){
        puts("");
        return;
    }
    int idx=0;
    for(int i=n;i>=1;i--){
        if(dfs(i,0)==ans){
            idx=i;
            break;
        }
    }

    print(idx,0);
    puts("");

}
int main(){
    scanf("%d",&T);
    while(T--) solve();
}

J - DNA repair

题意:

有n个模式串,一个目标串,问如果要修改目标串使目标串内没有模式串,问最少修改多少个字符。

思路:

跑ac自动机,标记,然后跑dp,被标记过跳过

代码:
#include <iostream>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=55;
const int inf=0x3f3f3f3f;
int n,m,o,tot,ma[500];
int tr[1100][4],col[1100],fail[1100];
int f[5][1100];
char buf[N],a[1010];
void init(){
    ma['A']=0,ma['G']=1,ma['C']=2,ma['T']=3;
    for(int i=0;i<=tot;i++){
        memset(tr[i],0,sizeof tr[i]);
        col[i]=fail[i]=0;
    }
    memset(f,inf,sizeof f);
    tot=0;
}
void insert(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=ma[s[i]];
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]=1;
}
void build(){
    queue <int> q;
    for(int i=0;i<4;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<4;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
            col[p]|=col[fail[p]];
        }
    }

}
void solve(){
    init();
    for(int i=1;i<=n;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf));
    }
    scanf("%s",a+1);
    m=strlen(a+1);
    build();
    f[0&1][0]=0;
    for(int i=0;i<m;i++){
        for(int j=0;j<=tot;j++) f[i+1&1][j]=inf;
        for(int j=0;j<=tot;j++){
            for(int k=0;k<4;k++){
                int z=tr[j][k];
                if(col[z]) continue;
                f[i+1&1][z]=min(f[i+1&1][z],f[i&1][j]+(k!=ma[a[i+1]]));
            }
        }
    }
    int ans=inf;
    for(int i=0;i<=tot;i++){
        ans=min(ans,f[m&1][i]);
    }

    if(ans==inf) cout<<"Case "<<++o<<": "<<"-1\n";
    else cout<<"Case "<<++o<<": "<<ans<<"\n";

}
int main(){
    while(scanf("%d",&n)&&n) solve();

}

L - Lost's revenge

题意:

给出一个n个模式串,一个目标串,问把目标串重新排位最多能产生多少个模式串,可以重叠且所有串只包含A C G T。

思路:

首先跑ac自动机,然后在ac自动机上标记,由于题目所问目标串排列,统计目标串ACGT的数量,之后。

设:$dp[s][a][b][c][d]$​​​为在$ac$自动机s状态,$abcd$分别为$ACGT$​的选取情况,之后记忆化搜索。

$ps:$这里由于题目所给$abcd$的数量和为$40$,所以直接开数组会$MLE$,需要对$abcd$​状态类是哈希处理,具体处理看代码。

代码:
#include <cstdio>
#include <algorithm>
#include <queue>
#include <stack>
#include <string>
#include <string.h>
#include <map>
#include <iostream>
using namespace std;
const int maxn = 550;
const int mod = 20090717;
int INF = 1e9;
int nex[maxn][4], Exits[maxn], fail[maxn], last[maxn], cnt;
int num0, num1, num2, num3;
char t[maxn];
int n, m;
int getId(char ch){
    if(ch == 'A') return 0;
    else if(ch == 'T') return 1;
    else if(ch == 'C') return 2;
    else return 3;
}
void insert(char *s, int len){
    int p = 0;
    for(int i = 0; i < len; i++){
        int x = getId(s[i]);
        if(nex[p][x] == 0){
            memset(nex[cnt], 0, sizeof(nex[cnt]));
            Exits[cnt] = 0;
            last[cnt] = 0;
            fail[cnt] = 0;
            nex[p][x] = cnt++;
        }
        p = nex[p][x];
    }
    Exits[p]++;
}

queue<int> que;
void Build(){
    for(int i = 0; i < 4; i++){
        if(nex[0][i]) que.push(nex[0][i]);
    }
    while(que.size()){
        int p = que.front();
        que.pop();
        Exits[p] += Exits[fail[p]];
        for(int i = 0; i < 4; i++){
            int u = nex[p][i];
            if(u){
                fail[u] = nex[fail[p]][i];
                last[u] = Exits[fail[u]] ? fail[u] : last[fail[u]];
                que.push(u);
            } else {
                nex[p][i] = nex[fail[p]][i];
            }
        }
    }
}

int getState(int c0, int c1, int c2, int c3){
    return c0 * (num1 + 1) * (num2 + 1) * (num3 + 1) + c1 * (num2 + 1) * (num3 + 1) + c2 * (num3 + 1) + c3;
}
int dp[505][15005];
int dfs(int st, int c0, int c1, int c2, int c3){
    int state = getState(c0, c1, c2, c3);
    if(dp[st][state] != -1) return dp[st][state];
    int ans = 0;
    if(c0){
        int u = nex[st][0];
        ans = max(ans, dfs(u, c0 - 1, c1, c2, c3) + Exits[u]);
    }
    if(c1){
        int u = nex[st][1];
        ans = max(ans, dfs(u, c0, c1 - 1, c2, c3) + Exits[u]);
    }
    if(c2){
        int u = nex[st][2];
        ans = max(ans, dfs(u, c0, c1, c2 - 1, c3) + Exits[u]);
    }
    if(c3){
        int u = nex[st][3];
        ans = max(ans, dfs(u, c0, c1, c2, c3 - 1) + Exits[u]);
    }
    dp[st][state] = ans;
    return ans;
}
int main(int argc, char const *argv[])
{
    int ca = 0;
    while(1){
        scanf("%d", &n);
        if(n == 0) break;
        cnt = 1;
        memset(nex[0], 0, sizeof(nex[0]));
        num0 = num1 = num2 = num3 = 0;
        for(int i = 1; i <= n; i++){
            char s[25];
            scanf("%s", s);
            insert(s, strlen(s));
        }
        scanf("%s", t);
        int len = strlen(t);
        for(int i = 0; i < len; i++){
            if(getId(t[i]) == 0) num0++;
            else if(getId(t[i]) == 1) num1++;
            else if(getId(t[i]) == 2) num2++;
            else num3++;
        }
        Build();
        memset(dp, -1, sizeof(dp));
        int ans = dfs(0, num0, num1, num2, num3);
        printf("Case %d: %d\n", ++ca, ans);
    }
    return 0;
}

N - Walk Through Squares

题意:

在一个矩阵内从左上角走到右下角,向右走得到一个R向下走得到一个D,问最后有几种走到右下角时得到的字符串包含题中给出的两个字符串

思路:

跑ac自动机,标记,之后开一维状压其选取情况。

代码:
#include <iostream>
#include <queue>
#include <cstring>
using namespace std;
typedef long long ll;
const int mod=1000000007;
int T,n,m,tot;
int tr[210][2],fail[210],col[210];
int f[105][105][205][5];
char buf[110];
void init(){
    for(int i=0;i<=tot;i++){
        memset(tr[i],0,sizeof tr[i]);
        fail[i]=col[i]=0;
    }
    tot=0;
}
int get(char ch){
    if(ch=='R') return 0;
    else return 1;
}
void insert(char s[],int sz,int id){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=get(s[i]);
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]=(1<<id);
}
void build(){
    queue <int> q;
    for(int i=0;i<2;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<2;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]|=col[fail[p]];
    }
}
int dfs(int a,int b,int s,int lim){
    if(a==0&&b==0){
        if(lim==(1<<2)-1) return 1;
        return 0;
    }
    if(f[a][b][s][lim]!=-1) return f[a][b][s][lim];
    ll ans=0;
    for(int i=0;i<2;i++){
        int j=tr[s][i];
        if(i==0&&a){
            ans=(ans+dfs(a-1,b,j,lim|col[j]))%mod;
        }else if(i==1&&b){
            ans=(ans+dfs(a,b-1,j,lim|col[j]))%mod;
        }
    }
    return f[a][b][s][lim]=ans;
}
void solve(){
    scanf("%d %d",&n,&m);
    init();
    for(int i=0;i<2;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf),i);
    }
    build();
    for(int i=0;i<=n+1;i++){
        for(int j=0;j<=m+1;j++){
            for(int k=0;k<=tot;k++){
                f[i][j][k][0]=f[i][j][k][1]=f[i][j][k][2]=-1;
                f[i][j][k][3]=-1;
            }
        }
    }
    printf("%d\n",dfs(n,m,0,0));
}
int main(){
    scanf("%d",&T);
    while(T--) solve();
}

E - DNA Sequence

题意:

有m种DNA序列是致病的,问长为n且不包含致病序列的DNA有多少种

思路:

建ac自动机,将致病标记。

设状态$f[i][j]$​​为长度为$i$​的字符串在状态机$j$​号点不包含致病序列的DNA种类。

转移:$f[i][k]+=f[i][j]$,且$k=tr[j][c]$,其中包含致病串不转移。​

我们看到n是非常大的,这题问我们包含致病序列,我们可以用矩阵来递推优化,类似计数路径数。

我们需要根据AC自动机构造出矩阵(类是构造邻接矩阵),最后矩阵快速幂计数答案。

#include <iostream>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll mod=100000;
const int N=101;
int m,n,tot,ma[500],col[N],tr[N][4],fail[N];
ll A[N][N];
char buf[N];
void qu(ll &x){
    x=x-x/mod*mod;
}
void mul(ll c[],ll a[],ll b[][N]){
    ll t[N]={0};
    for(int i=0;i<=tot;i++){
        for(int j=0;j<=tot;j++){
            t[i]=t[i]+a[j]*b[j][i];
        }
        //qu(t[i]);
        t[i]=t[i]%mod;
    }
    memcpy(c,t,sizeof t);
}
void mul(ll c[][N],ll a[][N],ll b[][N]){
    ll t[N][N]={0};
    for(int i=0;i<=tot;i++){
        for(int j=0;j<=tot;j++){
            for(int k=0;k<=tot;k++){
                t[i][j]=t[i][j]+a[i][k]*b[k][j];
            }
            //qu(t[i][j]);
            t[i][j]=t[i][j]%mod;
        }
    }
    memcpy(c,t,sizeof t);
}

ll ksm(ll p){
    ll F0[N]={1};
    while(p){
        if(p&1) mul(F0,F0,A);
        mul(A,A,A);
        p>>=1;
    }
    ll ans=0;
    for(int i=0;i<=tot;i++){
        //if(!col[i])
        ans=ans+F0[i];
    }
    //qu(ans);
    ans=ans%mod;
    return ans;
}

void insert(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=ma[s[i]];
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]++;
}

void build(){
    queue <int> q;
    for(int i=0;i<4;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<4;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]+=col[fail[p]];
    }
}

int main(){
    scanf("%d %d",&m,&n);
    ma['A']=0,ma['C']=1,ma['T']=2,ma['G']=3;
    for(int i=1;i<=m;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf));
    }
    build();
    for(int i=0;i<=tot;i++){
        if(col[i]) continue;
        for(int j=0;j<4;j++){
            if(col[tr[i][j]]) continue;
            A[i][tr[i][j]]++;
        }
    }

    printf("%lld\n",ksm(n));

}

F - 考研路茫茫——单词情结

题意:

给定n个模板串,问有多少个长度不超过L的含至少一个模板串的字符串。

思路:

这题和上题一样n很大,但这题问至少含有一个模板串数量,这题如果从正面求数量我们转移会有点复杂,我们可以考虑从反面求,求构造总共数量-不包含模式串

如果求长度为L不包含模板串的数量和上题一样,

但是这里求的是长度为[1,L]不包含模板串的数量的和,我们可以设$S_i$​为和$[1,i]$的和,即$S_i=f[0]+f[1]+f[2]..+f[i]+s[i-1]$​,我们也能用矩阵快速幂求,即在最后列开一维全部为1

代码:
#include <iostream>
#include <cstring>
#include <queue>
using namespace std;
typedef unsigned long long ll;
const int N=50;
int n,m,tot;
int col[N],fail[N],tr[N][26];
char buf[10];
struct node{
    ll e[N][N];
    int n;
    node(int _){
        n=_;
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                e[i][j]=0;
            }
        }
    }
    void build(){
        for(int i=0;i<n;i++){
            e[i][i]=1;
        }
    }
};
node operator * (const node& x,const node& y){
    node res=node(x.n);
    for(int i=0;i<res.n;i++){
        for(int j=0;j<res.n;j++){
            res.e[i][j]=0;
            for(int k=0;k<res.n;k++){
                res.e[i][j]=res.e[i][j]+x.e[i][k]*y.e[k][j];
            }
        }
    }
    return res;
}
void init(){
    memset(col,0,sizeof col);
    memset(fail,0,sizeof fail);
    memset(tr,0,sizeof tr);
    tot=0;
}
void insert(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]++;
}
void build(){
    queue <int> q;
    for(int i=0;i<26;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<26;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]+=col[fail[p]];
    }

}
ll ksm(node A,int p){
    node F=node(A.n);
    F.build();
    while(p){
        if(p&1) F=F*A;
        A=A*A;
        p>>=1;
    }

    ll res=0;
    for(int i=0;i<F.n;i++){
        res=res+F.e[0][i];
    }
    return res;
}
ll ks(node AB,int p){
    node F=node(2);
    F.e[0][0]=26,F.e[0][1]=1;
    while(p){
        if(p&1) F=F*AB;
        AB=AB*AB;
        p>>=1;
    }
    return F.e[0][0];
}

void solve(){
    init();
    for(int i=1;i<=n;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf));
    }
    build();
    node A=node(tot+2);
    for(int i=0;i<A.n-1;i++){
        if(col[i]) continue;
        for(int c=0;c<26;c++){
            int  k=tr[i][c];
            if(col[k]) continue;
            A.e[i][k]++;
        }
    }
    for(int i=0;i<A.n;i++){
        A.e[i][A.n-1]=1;
    }
    node AB=node(2);
    AB.e[0][0]=AB.e[1][0]=26;
    AB.e[1][1]=1;
    //cout<<ks(AB,m-1)<<" "<<(ksm(A,m)-1)<<" "<<m<<"\n";
    printf("%llu\n",ks(AB,m-1)-(ksm(A,m)-1));
}
int main(){
    while(scanf("%d %d",&n,&m)!=EOF) solve();
}

最短路

O - 小明系列故事——女友的考验(DAG+最短路)

题意:

让你求从1走到n的最短路,但是有些路径是不能走的,且走到每次走只能走比当前点大的点

思路:

首先将不能跑的点建$ac$​​自动机标记,由于是$DAG$​​,我们可以记忆化求最短路,之后跑一边跑图,一边跑$ac$自动机,遇到$ac$自动机上标记的点跳过。

代码:
#include <iostream>
#include <cstring>
#include <cmath>
#include <queue>
#include <algorithm>
#define bug cout<<"....\n"
#define fi first
#define se second
using namespace std;
typedef pair <int,int> pll;
typedef pair <double,pll> pdd;
const int N=55,M=510;
int n,m,tot;
int buf[N],tr[M][N],fail[M],col[M];
int vis[N],state[N];
double x[N],y[N],dis[N],dp[N][M];
void init(){
    for(int i=0;i<=tot;i++){
        memset(tr[i],0,sizeof tr[i]);
        fail[i]=col[i]=0;
    }
    for(int i=0;i<N;i++){
        for(int j=0;j<M;j++){
            dp[i][j]=-1;
        }
    }
    tot=0;
}
double gao(int i,int j){
    return sqrt((x[i]-x[j])*(x[i]-x[j])+(y[i]-y[j])*(y[i]-y[j]));
}
void insert(int s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i];
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    col[p]=1;
}
int build(){
    queue <int> q;
    for(int i=1;i<=50;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=1;i<=50;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        col[p]|=col[fail[p]];
    }
}
double dfs(int p,int s){
    if(p==n){
        return 0;
    }
    if(dp[p][s]!=-1) return dp[p][s];
    double ans=1e18;
    for(int i=p+1;i<=n;i++){
        double w=gao(p,i);
        if(col[tr[s][i]]) continue;
        ans=min(ans,dfs(i,tr[s][i])+w);
    }
    return dp[p][s]=ans;
}
void solve(){
    init();
    for(int i=1;i<=n;i++){
        scanf("%lf %lf",&x[i],&y[i]);
    }
    for(int i=1;i<=m;i++){
        int sz;
        scanf("%d",&sz);
        for(int j=0;j<sz;j++){
            scanf("%d",&buf[j]);
        }
        insert(buf,sz);
    }
    build();
    if(col[tr[0][1]]){
        puts("Can not be reached!");
        return;
    }
    double ans=dfs(1,tr[0][1]);
    if(ans==1e18){
        puts("Can not be reached!");
    }else{
        printf("%.2f\n",ans);
    }
}
int main(){
    while(scanf("%d %d",&n,&m)&&(n||m)) solve();
}

M - Resource Archiver

题意:

给你$n$个资源串和$m$个病毒串然后让你求最少需要多少个$01$字符才能做出不含病毒串并且含有所有的资源串的字符串。

思路:

注意到资源串的数量$(n\le10)$​比较少,我们可以先建AC自动机,在建立的AC自动机上对每个资源串跑遍bfs,病毒串跳过,处理出资源串两两之间的距离。

最后状压$dp$,算出跑出经过$n$个串的最短路(类是哈密顿图)。

代码:
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
using namespace std;
const int N=15,M=60005;
int n,m,tot;
int tr[M][2],ban[M],num[N],fail[M];
int dis[M],vis[M],c[N][N];
int f[15][1<<12];
char buf[50005];
void init(){
    for(int i=0;i<=tot;i++){
        memset(tr[i],0,sizeof tr[i]);
        ban[i]=num[i]=fail[i]=0;
    }
    tot=0;
    memset(c,0,sizeof c);
    memset(f,-1,sizeof f);
}
void insert(char s[],int sz,int id){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'0';
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    num[id]=p;
}
void add(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'0';
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
    }
    ban[p]=1;
}
void build(){
    queue <int> q;
    for(int i=0;i<2;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<2;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
        ban[p]|=ban[fail[p]];
    }
}
void bfs(int k){
    queue <int> q;
    for(int i=0;i<=tot;i++) vis[i]=dis[i]=0;
    q.push(num[k]),vis[num[k]]=1;
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<2;i++){
            int v=tr[p][i];
            if(!vis[v]&&!ban[v]){
                vis[v]=1;
                dis[v]=dis[p]+1;
                q.push(tr[p][i]);
            }
        }
    }
    for(int i=0;i<=n;i++) c[k][i]=dis[num[i]];
}

int dfs(int p,int s){
    if(s==((1<<(n+1))-1)) return 0;
    if(f[p][s]!=-1) return f[p][s];
    int ans=1e9;
    for(int i=1;i<=n;i++){
        if(s&(1<<i)) continue;
        ans=min(ans,dfs(i,s|(1<<i))+c[p][i]);
    }
    return f[p][s]=ans;
}
void solve(){
    init();
    for(int i=1;i<=n;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf),i);
    }
    for(int i=1;i<=m;i++){
        scanf("%s",buf);
        add(buf,strlen(buf));
    }
    build();
    for(int i=0;i<=n;i++){
        bfs(i);
    }
    printf("%d\n",dfs(0,1));
}
int main(){
    while(scanf("%d %d",&n,&m)&&(n||m)) solve();
}

多模式串匹配:

题意:

给你一个长度为n的单词表,一个文本串,问你这个文本串中出现了单词表中多少个单词;

思路:

建AC自动机,将文本串放入匹配。记得将匹配过的点标记,只算一次答案。

代码:
#include <iostream>
#include <queue>
#include <cstring>
using namespace std;
const int N=10005,M=1000007;
int n,tot,T,col[M],fail[M],tire[M][26];
char a[N][100],b[N];
void init(){
    for(int i=0;i<=tot;i++){
        fail[i]=col[i]=0;
        memset(tire[i],0,sizeof tire[i]);
    }
    tot=0;
}
void insert(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        if(!tire[p][ch]) tire[p][ch]=++tot;
        p=tire[p][ch];
    }
    //?
    col[p]++;
}
void build(){
    queue <int> q;
    for(int i=0;i<26;i++){
        if(tire[0][i]){
            fail[tire[0][i]]=0;
            q.push(tire[0][i]);
        }
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<26;i++){
            if(tire[p][i]){
                fail[tire[p][i]]=tire[fail[p]][i];
                q.push(tire[p][i]);
            }else{
                tire[p][i]=tire[fail[p]][i];
            }
        }
    }
}
int ask(char s[],int sz){
    int p=0,ans=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        p=tire[p][ch];
        for(int j=p;j&&col[j]!=-1;j=fail[j]){
            ans+=col[j];
            col[j]=-1;
        }
    }
    return ans;
}


int main(){
    scanf("%d",&T);
    while(T--){

        scanf("%d",&n);
        init();
        for(int i=1;i<=n;i++){
            scanf("%s",a[i]);
            insert(a[i],strlen(a[i]));
        }
        scanf("%s",b);
        build();
        printf("%d\n",ask(b,strlen(b)));
    }

}

B - 病毒侵袭

题意:

给出N个病毒的字符串,再给出M个网站的字符串,求字符串中含有病毒的数量和一共含有病毒的网址,

思路:

将病毒串建AC自动机,然后将网站跑AC自动机进行匹配,每次记得将匹配过的点标记,只算一次答案。

代码:
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <vector>
#define fi first
#define se second
#define bug cout<<"哈?\n"
using namespace std;
typedef pair <int,int> pll;
int n,m,tot;
int fail[100007],col[100007],tire[100007][128];
char buf[10010];
vector <int> ans;
vector <pll> d;
void insert(char s[],int sz,int id){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i];
        if(!tire[p][ch]) tire[p][ch]=++tot;
        p=tire[p][ch];
    }
    col[p]=id;
}
void build(){
    queue <int> q;
    for(int i=0;i<128;i++){
        if(tire[0][i]){
            fail[tire[0][i]]=0;
            q.push(tire[0][i]);
        }
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<128;i++){
            if(tire[p][i]){
                fail[tire[p][i]]=tire[fail[p]][i];
                q.push(tire[p][i]);
            }else{
                tire[p][i]=tire[fail[p]][i];
            }
        }
    }
}
void query(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i];
        p=tire[p][ch];
        for(int j=p;j&&col[j]!=-1;j=fail[j]){
            if(col[j]) ans.push_back(col[j]);
            d.push_back(pll(j,col[j]));
            col[j]=-1;
        }
    }
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%s",buf);
        insert(buf,strlen(buf),i);
    }
    build();
    scanf("%d",&m);
    int res=0;
    for(int i=1;i<=m;i++){
        scanf("%s",buf);
        query(buf,strlen(buf));
        if(ans.size()){
            printf("web %d:",i);
            res++;
            sort(ans.begin(),ans.end());
            for(auto p:ans){
                printf(" %d",p);
            }
            puts("");
            for(auto it:d){
                col[it.fi]=it.se;
            }
            d.clear();
            ans.clear();
        }
    }
    printf("total: %d\n",res);
}

C - 病毒侵袭持续中

题意:

给你n个字符串,再给你个长串,问你n个字符串每个字符串出现的次数

思路:

拿n个字符串建AC自动机,然后匹配长串,匹配时记录每个节点出现次数。

代码:
#include <iostream>
#include <cstring>
#include <queue>
using namespace std;
typedef long long ll;
const int M=50007,N=2000007;
int n,tot;
int col[M],tire[M][128],fail[M],Ans[M],ma[M];
char a[N],buf[1007][100];
void init(){
    for(int i=0;i<=tot;i++){
        fail[i]=col[i]=Ans[i]=0;
        memset(tire[i],0,sizeof tire[i]);
    }
    for(int i=1;i<=n;i++) ma[i]=0;
   tot=0;
}
void insert(char s[],int sz,int id){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i];
        if(!tire[p][ch]) tire[p][ch]=++tot;
        p=tire[p][ch];
    }
    ma[id]=p;
    col[p]++;
}
void build(){
    queue <int> q;
    for(int i=0;i<128;i++){
        if(tire[0][i]){
            fail[tire[0][i]]=0;
            q.push(tire[0][i]);
        }
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<128;i++){
            if(tire[p][i]){
                fail[tire[p][i]]=tire[fail[p]][i];
                q.push(tire[p][i]);
            }else{
                tire[p][i]=tire[fail[p]][i];
            }
        }

    }
}
void query(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i];
        p=tire[p][ch];
        for(int j=p;j;j=fail[j]) Ans[j]+=col[j];
    }
}
int main(){
    while(scanf("%d",&n)!=EOF){
        init();
        for(int i=1;i<=n;i++){
            scanf("%s",buf[i]);
            insert(buf[i],strlen(buf[i]),i);
        }
        build();
        scanf("%s",a);
        query(a,strlen(a));
        for(int i=1;i<=n;i++){
            if(Ans[ma[i]]) printf("%s: %d\n",buf[i],Ans[ma[i]]);
        }
    }
}

ZOJ - 3228

题意:

给一个字符串s和n次询问,每次询问一个字符串在s中出现的次数,ord为0表示允许重叠,ord为1表示不允许重叠。

思路:

第一个询问和上题一样,第二个问不允许重叠,我们可以记录个上次匹配在s串的位置ed,和字符串的长度len,在每次匹配的时候根据ed和len判断与上次匹配是否重叠。

代码:
#include <iostream>
#include <cstring>
#include <queue>
using namespace std;
const int N=1e5+10;
int n,m,o,tot;
int len[6*N],ed[6*N],tr[6*N][26],fail[6*N];
int ans[2][6*N],tp[N],num[N];
char a[N],buf[10];
void init(){
    for(int i=0;i<=tot;i++){
        memset(tr[i],0,sizeof tr[i]);
        fail[i]=ed[i]=len[i]=0;
        ans[0][i]=ans[1][i]=0;
    }
    tot=0;
}
void insert(char s[],int sz,int id){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        if(!tr[p][ch]) tr[p][ch]=++tot;
        p=tr[p][ch];
        ed[p]=-1;
    }
    len[p]=sz;num[id]=p;
}
void build(){
    queue <int> q;
    for(int i=0;i<26;i++){
        if(tr[0][i]) q.push(tr[0][i]);
    }
    while(!q.empty()){
        int p=q.front();q.pop();
        for(int i=0;i<26;i++){
            if(tr[p][i]){
                fail[tr[p][i]]=tr[fail[p]][i];
                q.push(tr[p][i]);
            }else{
                tr[p][i]=tr[fail[p]][i];
            }
        }
    }
}
int query(char s[],int sz){
    int p=0;
    for(int i=0;i<sz;i++){
        int ch=s[i]-'a';
        p=tr[p][ch];
        for(int j=p;j;j=fail[j]){
            if(len[j]) ans[0][j]++;
            if(len[j]&&(i-ed[j]>=len[j])){
                ans[1][j]++;ed[j]=i;
            }
        }
    }
}
int main(){
    while(scanf("%s",a)!=EOF){
        scanf("%d",&m);
        init();
        for(int i=1;i<=m;i++){
            scanf("%d %s",&tp[i],buf);
            insert(buf,strlen(buf),i);
        }
        build();
        query(a,strlen(a));
        printf("Case %d\n",++o);
        for(int i=1;i<=m;i++){
            if(tp[i]==0) printf("%d\n",ans[0][num[i]]);
            else printf("%d\n",ans[1][num[i]]);
        }
        puts("");
    }
}

最后修改:2023 年 01 月 10 日
如果觉得我的文章对你有用,请随意赞赏