P3714 [BJOI2017]树的难题

首先吐槽一下自己刚开始理解错题意了,以为题目中的按顺序可以按任意顺序.

这道题是一道关于树上路径的问题,很明显可以想到点分治,考虑当前的分治中心为 xx.

那么答案可以分为下面四种情况

1.序列一端为 xx, 另一端在子树内.

2.序列两端在两个不同的子树内,且两个子树到分治中心的颜色相同

3.序列两端在两个不同的子树内,且两个子树到分治中心的颜色不同

4.序列两端在一个子树内

对于第4种情况,很好解决,直接递归就行,第1种情况也是,重点是第2种和第3种.

因为我们需要将所有可能的情况都存起来,所以考虑用线段树, xx 的所有儿子按边的颜色排序,然后再来枚举,

那么就只需要开两棵线段树了,一棵存了之前的有儿子,一棵只存了当前颜色的儿子,在 dfsdfs 的时候分别查询一下就是了,

只需要在查询存的是当前颜色的儿子时需要减去一个 valval.

代码

#include<iostream>
#include<cstdio>
#include<queue>
#define INF (0x3f3f3f3f)
using namespace std;
const int N = 2e5 + 5;
typedef pair<int, int> PII;

int n, m, ql, qr;
int idx, rt1, rt2, tot, val[N], ans = -INF;
int head[2 * N], net[2 * N], edge[2 * N], ver[2 * N];
struct tree
{
int l, r, v, maxn;
void init()
{
l = r = 0;
v = maxn = -INF;
}
} tr[80 * N];
bool st[N];

void pushup(int p)
{
tr[p].maxn = max(tr[tr[p].l].maxn, tr[tr[p].r].maxn);
}

int merge(int p, int q, int l, int r)
{
if (!p || !q)
return p + q;
if (l == r)
{
tr[p].maxn = tr[p].v = max(tr[p].v, tr[q].v);
tr[q].init();
return p;
}
int mid = (l + r) >> 1;
tr[p].l = merge(tr[p].l, tr[q].l, l, mid);
tr[p].r = merge(tr[p].r, tr[q].r, mid + 1, r);
tr[q].init();
pushup(p);
return p;
}

void insert(int p, int l, int r, int d, int x)
{
if (l == r)
{
tr[p].maxn = tr[p].v = max(tr[p].v, x);
return;
}
int mid = (l + r) >> 1;
if (d <= mid)
{
if (!tr[p].l)
tr[p].l = ++tot;
insert(tr[p].l, l, mid, d, x);
}
else
{
if (!tr[p].r)
tr[p].r = ++tot;
insert(tr[p].r, mid + 1, r, d, x);
}
pushup(p);
}

int query(int p, int L, int R, int l, int r)
{
if (L >= l && R <= r)
return tr[p].maxn;
int mid = (L + R) >> 1, res = -INF;
if (l <= mid && tr[p].l)
res = max(res, query(tr[p].l, L, mid, l, r));
if (r > mid && tr[p].r)
res = max(res, query(tr[p].r, mid + 1, R, l, r));
return res;
}

void add(int a, int b, int c)
{
net[++idx] = head[a], ver[idx] = b;
edge[idx] = c, head[a] = idx;
}

int get_siz(int u, int fa)
{
int res = 1;
for (int i = head[u]; i; i = net[i])
if (ver[i] != fa && !st[ver[i]])
res += get_siz(ver[i], u);
return res;
}

int get_wc(int u, int fa, int tot, int &wc)
{
int sum = 1, ms = 0;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (st[v] || v == fa)
continue;
int t = get_wc(v, u, tot, wc);
sum += t;
ms = max(ms, t);
}
ms = max(ms, tot - sum);
if (ms <= tot / 2)
wc = u;
return sum;
}

void get_dist(int rt, int u, int fa, int vals, int len, int last, int ned)
{
if (len > qr)
return;
ans = max(ans, query(rt1, 1, qr, max(1, ql - len), max(1, qr - len)) + vals - ned);
ans = max(ans, query(rt2, 1, qr, max(1, ql - len), max(1, qr - len)) + vals);
if (len >= ql)
ans = max(ans, vals);
insert(rt, 1, qr, len, vals);
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (v == fa || st[v])
continue;
if (edge[i] == last)
get_dist(rt, v, u, vals, len + 1, last, ned);
else
get_dist(rt, v, u, vals + val[edge[i]], len + 1, edge[i], ned);
}
}

void clear(int x)
{
if (tr[x].l)
clear(tr[x].l);
if (tr[x].r)
clear(tr[x].r);
tr[x].init();
}

void calc(int u)
{
get_wc(u, -1, get_siz(u, -1), u);
priority_queue<PII> q;
st[u] = true;
for (int i = head[u]; i; i = net[i])
{
int v = ver[i];
if (st[v])
continue;
q.push(make_pair(edge[i], v));
}
int last = 0;
rt2 = ++tot;
while (!q.empty())
{
int v = q.top().second, tp = q.top().first;
q.pop();
if (tp != last)
{
merge(rt2, rt1, 1, qr);
rt1 = ++tot;
}
int rt = ++tot;
get_dist(rt, v, u, val[tp], 1, tp, val[tp]);
merge(rt1, rt, 1, qr);
last = tp;
}
tot = 0, clear(rt1), clear(rt2);
for (int i = head[u]; i; i = net[i])
if (!st[ver[i]])
calc(ver[i]);
}

int main()
{
for (int i = 0; i < 10 * N; i++)
tr[i].v = tr[i].maxn = -INF;
scanf("%d%d%d%d", &n, &m, &ql, &qr);
for (int i = 1; i <= m; i++)
scanf("%d", &val[i]);
for (int i = 1; i < n; i++)
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
calc(1);
printf("%d", ans);
return 0;
}