动态 DP

NOIP居然会考这种东西,所以不得不来学一下

结合着上面这道题,可以看出,动态DP就是一个动态规划问题加上了修改操作,

如果每一次修改我们都去跑一遍动态规划,时间复杂度直接起飞,所以这时候就要想办法优化。

首先看一下这道题如果不带修改操作该怎么做

fi,0/1f_{i,0/1} 表示以 ii 为子树不选/选 ii 的最大点权独立集

那么状态转移方程就为 {fu,0=max(fv,0,fv,1)fu,1=au+fv,0\begin{cases}f_{u,0}=\sum\max(f_{v,0},f_{v,1})\\f_{u,1}=a_u+\sum f_{v,0}\end{cases}

一个简单的树上DP,可以看出,每个结点只会对他的父亲结点造成影响.

这时候如果我们修改了其中一个结点,那么他只会对在他这条链上的祖先结点造成影响,

如果这条树的高度比较平均,那么就只需要 log(n)log(n) 次,可惜如果这棵树退化成链,那么一次修改就需要 nn次,

显然是不行的.既然是跟树有关,可以想到一个数据结构,树链剖分.因为我们的DP是从叶子结点往上转移,

而树链剖分中每条重链的链尾都是叶子结点,这就可以让我们很好的进行DP的转移,

同时树链剖分可以让我们快速的进行修改操作,那么我们该怎么把DP与树链剖分结合?

将DP的转移式稍微变一下形,令 gi,0g_{i,0} 表示所有 ii 的所有轻儿子可取可不取的最大值, gi,1g_{i,1} 表示 ii 的所有轻儿子都不去并取 ii 的最大值

{fu,0=gu,0+max(fv,1,fv,0)fu,1=gu,1+fv,0\begin{cases}f_{u,0}=g_{u,0}+max(f_{v,1},f_{v,0})\\f_{u,1}=g_{u,1}+f_{v,0}\end{cases}

这个东西并不好直接快速在树上转移,这个形式可以考虑用矩阵加速.

重定义矩阵乘法 ci,j=max(ai,k+bk,j)c_{i,j}=max(a_{i,k}+b_{k,j})

这个为什么可以套用在上面的转移式?

再将上面的转移式改一下,{fu,0=max(fv,1+gu,0,fv,0+gu,0)fu,1=max(gu,1+fv,0,inf)\begin{cases}f_{u,0}=max(f_{v,1}+g_{u,0},f_{v,0}+g_{u,0})\\f_{u,1}=max(g_{u,1}+f_{v,0},-\inf)\end{cases}

这样就可套用矩阵乘法了.可以直接构造出一个矩阵

fv,0fv,1×gu,0gu,1gu,0inf=fu,0fv,0\begin{vmatrix}f_{v,0}&f_{v,1}\end{vmatrix}\times\begin{vmatrix}g_{u,0}&g_{u,1}\\g_{u,0}&-\inf\end{vmatrix}=\begin{vmatrix}f_{u,0}&f_{v,0}\end{vmatrix}

所以我们只需要在线段树中维护一个转移矩阵,最后求一下所有转移矩阵的积就行.

注意线段树做乘法时的顺序是从父亲结点到叶子节点,所以我们要交换一下矩阵的顺序

gu,0gu,1gu,1inf×fv,0fv,1=fu,0fu,1\begin{vmatrix}g_{u,0}&g_{u,1}\\g_{u,1}&-\inf\end{vmatrix}\times\begin{vmatrix}f_{v,0}\\f_{v,1}\end{vmatrix}=\begin{vmatrix}f_{u,0}\\f_{u,1}\end{vmatrix}

这样就可以快速进行转移了

代码

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N = 1e5 + 5;

struct matrix
{
int data[2][2];
matrix() { memset(data, -0x3f, sizeof(data)); };
matrix operator * (const matrix a) const
{
matrix c;
for (int i = 0; i < 2; i++)
for (int j = 0; j < 2; j++)
for (int k = 0; k < 2; k++)
c.data[i][j] = max(c.data[i][j], data[i][k] + a.data[k][j]);
return c;
}
} it[N];
struct tree
{
int l, r;
matrix mx;
} tr[4 * N];
int n, m, a[N], fa[N], siz[N], son[N];
int head[N], ver[2 * N], net[2 * N], idx, ed[N];
int top[N], tot, id[N], f[N][2], dfsn[N];

void add(int a, int b)
{
net[++idx] = head[a], ver[idx] = b, head[a] = idx;
}

void dfs1(int u, int f)
{
siz[u] = 1, fa[u] = f;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (v == f)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
}
}

void dfs2(int u, int t)
{
dfsn[u] = ++tot, id[tot] = u, top[u] = t;
f[u][0] = 0, f[u][1] = a[u], ed[t] = max(ed[t], tot);
it[u].data[0][0] = it[u].data[0][1] = 0;
it[u].data[1][0] = a[u];
if (!son[u])
return;
dfs2(son[u], t);
f[u][0] += max(f[son[u]][0], f[son[u]][1]);
f[u][1] += f[son[u]][0];
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (v == fa[u] || v == son[u])
continue;
dfs2(v, v);
f[u][0] += max(f[v][0], f[v][1]);
f[u][1] += f[v][0];
it[u].data[0][0] += max(f[v][0], f[v][1]);
it[u].data[0][1] = it[u].data[0][0];
it[u].data[1][0] += f[v][0];
}
}

void pushup(int p)
{
tr[p].mx = tr[p << 1].mx * tr[p << 1 | 1].mx;
}

void build(int l, int r, int p)
{
tr[p].l = l, tr[p].r = r;

if (l == r)
{
tr[p].mx = it[id[l]];
return;
}
int mid = (l + r) >> 1;
build(l, mid, p << 1);
build(mid + 1, r, p << 1 | 1);
pushup(p);
}

void update_tree(int x, int p)
{
if (tr[p].l == tr[p].r)
{
tr[p].mx = it[id[x]];
return;
}
int mid = (tr[p].l + tr[p].r) >> 1;
if (x <= mid)
update_tree(x, p << 1);
else
update_tree(x, p << 1 | 1);
pushup(p);
}

matrix query(int l, int r, int p)
{
if (tr[p].l >= l && tr[p].r <= r)
return tr[p].mx;
int mid = (tr[p].l + tr[p].r) >> 1;
matrix res;
if (r <= mid)
return query(l, r, p << 1);
else if (l > mid)
return query(l, r, p << 1 | 1);
else
return query(l, r, p << 1) * query(l, r, p << 1 | 1);
return res;
}

void update_path(int u, int w)
{
it[u].data[1][0] += w - a[u], a[u] = w;
matrix ta, tb;
while (u)
{
ta = query(dfsn[top[u]], ed[top[u]], 1);
update_tree(dfsn[u], 1);
tb = query(dfsn[top[u]], ed[top[u]], 1);
u = fa[top[u]];
it[u].data[0][0] += max(tb.data[0][0], tb.data[1][0]) - max(ta.data[0][0], ta.data[1][0]);
it[u].data[0][1] = it[u].data[0][0];
it[u].data[1][0] += tb.data[0][0] - ta.data[0][0];
}
}

int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs1(1, 0);
dfs2(1, 1);
build(1, n, 1);
while (m--)
{
int x, y;
scanf("%d%d", &x, &y);
update_path(x, y);
matrix ans = query(dfsn[1], ed[1], 1);
printf("%d\n", max(ans.data[1][0], ans.data[0][0]));
}
return 0;
}