DP优化——动态dp

适用场景

动态 dp 主要用来处理动态修改点权/边权,的树形dp题 或者 区间序列上的带修改的dp。
其核心都是把 dp 变成矩乘的形式,这样修改只需要更改某个矩阵,再用线段树等数据结构维护。


以板子题为例进行讲解。

【模板】”动态 DP”&动态树分治

这道题是简单版。

简单版的前置知识:树链剖分,广义矩阵乘法。

不带修改的话,最大权独立集很简单:
\(f[u][0] = \sum _{v\in son(u)}max(f[v][0],f[v][1])\),表示不选 u 的答案。
\(f[u][1] = w[u] + \sum _{v\in son(u)}f[v][0]\),表示选 u 的答案。

注意到我们更改一个点的点权其实只会修改他到根的链上的那些 \(dp\) 值,所以全部重算一点都不划算。
考虑只去更改这条链上的dp值。
但这样当树是链时还是有可能 TLE (虽然题解区似乎有人 \(n\) 方过百万),这个时候就会想到树链剖分。
因为根到一个点最多会有 log 个不同的重链,所以可以考虑重链之间暴力修改,重链上用线段树快速维护。

这样的话我们就需要更改一下 \(f\) 的转移,使得能与树剖的性质匹配(\(f\) 的定义不变)。
\(g[u][0/1]\) 表述 \(u\) 选/不选,只考虑 \(u\) 的那些轻儿子,的答案。
那么 (\(son[u]\) 表示 \(u\) 的重儿子):
\(f[u][0] = g[u][0] + max(f[son[u][0],f[son[u][1])\)
\(f[u][1] = g[u][1] + f[son[u]][0]\)
特别的,叶子结点的 \(g[u][0]=0,g[u][1]=w[u]\) (和他的 f 相同)。

因为没了讨厌的 \(∑\),这个转移我们尝试改写成矩阵(为了后面用线段树维护)。
\(f[u][0] = max(g[u][0]+f[son[u][0] , g[u][0]+f[son[u]][1])\)
\(f[u][1] = g[u][1] + f[son[u]][0]\)
根据这个可以得到他的矩阵形式,其中 max 是矩乘的+,+是矩乘的*

\[\begin{pmatrix} f[son[u]][0] & f[son[u]][1] \\ \end{pmatrix} \times \begin{pmatrix} g[u][0] & g[u][1] \\ g[u][0] & -inf \\ \end{pmatrix} = \begin{pmatrix} f[u][0] & f[u][1] \end{pmatrix} \]

会得到转移矩阵里只跟当前点的 \(g\) 有关。
注意到转移时我们只需要重儿子的信息,以及当前点的 \(g\) 值,所以我们在线段树上维护每个点的 \(g\) 所构成的转移矩阵以及矩阵的区间乘积。
又注意到一条重链的底部一定是叶子。
所以对于一条重链的顶端他的 \(f\) 值就是这条重链的每个转移矩阵的乘积再乘一个初始矩阵(就是叶子的 \(f\) 值),这个区间乘线段树是好维护的。
所以我们只维护 \(g\) (或者其实是转移矩阵)就可以了。

修改流程如下:

  1. 当修改一个点 \(u\) 的点权时,当前点的 \(g[u][1]\) 要变一下。
  2. 然后 \(u\)\(top[u](重链顶端)\) 的所有点的 \(g\) 都是不变的,因为 \(g\) 在计算时不包含重儿子。
  3. \(top[u]\) 跳到 \(fa[top[u]]\) 时,这时因为 \(top[u]\)\(fa[top[u]]\) 的轻儿子,所以要更改 \(fa[top[u]]\)\(g\) 值。
    \(fa[top[u]\)\(g\) 值要用到 \(top[u]\)\(f\) 值,所以这个时候需要在线段树上区间查询一下。

复杂度是 \(O(n\times log^2(n))\),因为修改时要跳 \(log\) 次,每跳一次都要在线段树上查询一次。

一些细节:
矩阵乘法不满足交换律!!!所以线段树上 pushup 要从后往前合并。

code

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
inline int read(){
    int w = 1, s = 0;
    char c = getchar();
    for (; c < '0' || c > '9'; w *= (c == '-') ? -1 : 1, c = getchar());
    for (; c >= '0' && c <= '9'; s = 10 * s + (c - '0'), c = getchar());
    return s * w;
}
int n,T,w[N];
int tot,head[N],to[N<<1],Next[N<<1];
void add(int u,int v){
	to[++tot]=v,Next[tot]=head[u],head[u]=tot;
}

int top[N],down[N],rev[N],dfn[N],fa[N],son[N],Size[N],num,g[N][2],f[N][2];
//down是重链底端 
void dfs1(int u){
	Size[u]=1;
	for(int i=head[u];i;i=Next[i]){
		int v=to[i];
		if(v==fa[u]) continue;
		fa[v]=u;
		dfs1(v);
		Size[u]+=Size[v];
		if(Size[v]>Size[son[u]]) son[u]=v;
	}
}
void dfs2(int u){
	dfn[u]=++num;
	rev[num]=u;
	if(son[fa[u]]==u) top[u]=top[fa[u]];
	else top[u]=u;
	if(son[u]) dfs2(son[u]),down[u]=down[son[u]];
	else down[u]=u;
	g[u][1]=w[u];
	for(int i=head[u];i;i=Next[i]){
		int v=to[i];
		if(v==fa[u]||v==son[u]) continue;
		dfs2(v);
		g[u][0]+=max(f[v][0],f[v][1]);
		g[u][1]+=f[v][0];
	}
	f[u][0]=g[u][0]+max(f[son[u]][0],f[son[u]][1]);
	f[u][1]=g[u][1]+f[son[u]][0];
} 

struct Matrix{
	int n,m,a[3][3];
	void Init(){memset(a,-0x3f,sizeof a);}
	void Init2(){ //单位矩阵 
		for(int i=1;i<=n;i++){
			for(int j=1;j<=m;j++){
				if(i==j) a[i][j]=0;
				else a[i][j]=-0x3f3f3f3f;
			}
		}
	}
}F;
Matrix operator *(Matrix A,Matrix B){
	Matrix C; C.Init();
	C.n=A.n,C.m=B.m;
	for(int i=1;i<=C.n;i++){
		for(int j=1;j<=C.m;j++){
			for(int k=1;k<=A.m;k++){
				C.a[i][j]=max(C.a[i][j],A.a[i][k]+B.a[k][j]);
			}
		}
	}
	return C;
}

struct node{
	int l,r;
	Matrix G;
};
struct SegmentTree{
	node t[N<<2];
	void pushup(int p){
		t[p].G=t[p<<1|1].G*t[p<<1].G;
	}
	void build(int p,int l,int r){
		t[p].l=l,t[p].r=r;
		if(l==r){
			t[p].G.n=2,t[p].G.m=2;
			int u=rev[l];
			t[p].G.a[1][1]=g[u][0],t[p].G.a[1][2]=g[u][1],t[p].G.a[2][1]=g[u][0],t[p].G.a[2][2]=-0x3f3f3f3f;
			return;
		}
		int mid=(l+r)>>1;
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		pushup(p);
	}
	void change(int p,int x){
		if(t[p].l==t[p].r){
			int u=rev[x];
			t[p].G.a[1][1]=g[u][0],t[p].G.a[1][2]=g[u][1],t[p].G.a[2][1]=g[u][0],t[p].G.a[2][2]=-0x3f3f3f3f;
			return;
		}
		int mid=(t[p].l+t[p].r)>>1;
		if(x<=mid) change(p<<1,x);
		else change(p<<1|1,x);
		pushup(p);
	}
	Matrix ask(int p,int l,int r){
		if(l<=t[p].l&&t[p].r<=r) return t[p].G;
		int mid=(t[p].l+t[p].r)>>1;
		Matrix Res; Res.n=2,Res.m=2,Res.Init2();
		if(r>mid) Res=Res*ask(p<<1|1,l,r);
		if(l<=mid) Res=Res*ask(p<<1,l,r);
		return Res;
	}
}Seg;
void Init(){ //预处理:树剖,g 数组,f 数组,初始化线段树 
	dfs1(1); 
	dfs2(1);  
	
	Seg.build(1,1,n);
	F.n=1,F.m=2;
	F.a[1][1]=0,F.a[1][2]=-0x3f3f3f3f;  //初始矩阵,F 乘以叶子的转移矩阵就是叶子的 f。 
}
void change(int x,int y){
	int tmp=x;
	x=top[x];
	while(x!=1){   //先算出涉及到的点原来的 f 
		Matrix Ans=F;
		Ans=Ans * Seg.ask(1,dfn[x],dfn[down[x]]);
		f[x][0]=Ans.a[1][1],f[x][1]=Ans.a[1][2];
		g[ fa[x] ][0] -= max(f[x][0],f[x][1]);
		g[ fa[x] ][1] -= f[x][0];
		x=top[fa[x]];
	}

	x=tmp;
	g[x][1]-=w[x] , w[x]=y , g[x][1]+=w[x];
	Seg.change(1,dfn[x]);
	x=top[x];
	while(x!=1){
		Matrix Ans=F;
		Ans=Ans * Seg.ask(1,dfn[x],dfn[down[x]]);
		f[x][0]=Ans.a[1][1],f[x][1]=Ans.a[1][2];
		g[ fa[x] ][0] += max(f[x][0],f[x][1]);
		g[ fa[x] ][1] += f[x][0];
		Seg.change(1,dfn[fa[x]]);
		x=top[fa[x]];
	} 
	
}
signed main(){
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	n=read(),T=read();
	for(int i=1;i<=n;i++) w[i]=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read();
		add(u,v),add(v,u);
	}
	
	Init();
	
	while(T--){
		int x=read(),y=read();
		change(x,y);
		Matrix Ans=F;
		Ans=Ans * Seg.ask(1,dfn[1],dfn[down[1]]);
		printf("%d\n",max(Ans.a[1][1],Ans.a[1][2]));
	}
	return 0;
}
请登录后发表评论

    没有回复内容