#include <cstdio>
#include <iostream>
// Kept because the functions below call unqualified std names (e.g. max).
using namespace std;

// For every vertex i the program prints sum over all j of s(i, j), where
// s(i, j) is the number of distinct colours on the tree path i .. j
// (j = i included), using centroid decomposition.

const int N = 200000 + 5;  // array capacity; add() uses 2 edge slots per tree edge
typedef long long ll;

int n;    // number of vertices (1-indexed)
int top;  // number of colours recorded in cut[] for the current centroid
int Y;    // set by calc: endpoints on the centroid's side while one child subtree is updated

// Adjacency list ("linked forward star"): head[u] is the first edge id of u,
// net[e] the next edge id (0 terminates), ver[e] the edge target.
int head[N], ver[N], net[N], idx;

// Sum of cnt[] over all colours. It is a sum of subtree sizes and can reach
// about n*n/4 (e.g. a long path of distinct colours), so it must be 64-bit;
// an int here overflows.
ll Sum;

int c[N];     // colour of each vertex
int cnt[N];   // cnt[col]: paths from the current centroid that contain colour col
int tim[N];   // tim[col]: occurrences of col on the current DFS path
int siz[N];   // subtree sizes, rooted at the most recent get_siz root
int cut[N];   // colours touched while processing the current centroid (for clearing)
bool st[N];   // st[u]: u has already been used as a centroid (removed)
bool use[N];  // use[col]: col has already been appended to cut[]
ll sum[N];    // sum[u]: accumulated answer for vertex u
// Insert the directed edge a -> b at the front of a's adjacency list.
// Edge ids start at 1, so 0 can mark the end of a list.
void add(int a, int b)
{
    ++idx;
    ver[idx] = b;
    net[idx] = head[a];
    head[a] = idx;
}
// Locate a centroid of the component containing u, skipping removed
// vertices (st[]). tot is the component size.
// Every vertex whose largest remaining piece has at most tot / 2 vertices
// is written to wc, so the last such vertex in post-order is what the
// caller sees. Returns the size of u's subtree when hanging from fa.
int get_wc(int u, int fa, int tot, int &wc)
{
    int sub_total = 1;
    int biggest = 0;
    for (int e = head[u]; e; e = net[e]) {
        const int child = ver[e];
        if (child != fa && !st[child]) {
            const int part = get_wc(child, u, tot, wc);
            sub_total += part;
            if (part > biggest) biggest = part;
        }
    }
    const int above = tot - sub_total;  // the piece on the parent's side
    if (above > biggest) biggest = above;
    if (biggest <= tot / 2) wc = u;
    return sub_total;
}
// Fill siz[] for the subtree rooted at u (entered from fa), ignoring
// vertices that were already removed as centroids (st[]).
void get_siz(int u, int fa)
{
    siz[u] = 1;
    for (int e = head[u]; e != 0; e = net[e]) {
        const int child = ver[e];
        if (child == fa || st[child]) continue;
        get_siz(child, u);
        siz[u] += siz[child];
    }
}
void up_col(int u, int fa, int tp) { if (!tim[c[u]]) { cnt[c[u]] += tp * siz[u];//当前颜色的贡献 Sum += tp * siz[u];//总贡献 } if (!use[c[u]] && tp == 1) cut[++top] = c[u], use[c[u]] = true;//记录一下一共用了哪些颜色 tim[c[u]]++;//次数加一 for (int i = head[u]; i; i = net[i]) if (ver[i] != fa && !st[ver[i]]) up_col(ver[i], u, tp); tim[c[u]]--;//还原 }
void update(int u, int fa, int num, ll tot)//更新子树内每个点的答案 { if (!tim[c[u]]) num++, tot += cnt[c[u]]; tim[c[u]]++; sum[u] += Sum - tot + num * Y; for (int i = head[u]; i; i = net[i]) if (ver[i] != fa && !st[ver[i]]) update(ver[i], u, num, tot); tim[c[u]]--; }
void calc(int u) { get_wc(u, -1, siz[u], u);//找出重心 st[u] = true, Sum = top = 0; get_siz(u, -1), up_col(u, -1, 1);//求出子树大小与 $cnt$ for (int i = head[u]; i; i = net[i]) { int v = ver[i]; if (!st[v]) { tim[c[u]]++, up_col(v, u, -1); cnt[c[u]] -= siz[v], Sum -= siz[v];//先减去当前子树对 $cnt,Sum$ 的贡献 Y = siz[u] - siz[v], update(v, u, 0, 0); up_col(v, u, 1), tim[c[u]]--; cnt[c[u]] += siz[v], Sum += siz[v];//加回来 } } sum[u] += Sum - cnt[c[u]] + siz[u]; for (int i = 1; i <= top; i++) use[cut[i]] = false, cnt[cut[i]] = 0;//清空数组 for (int i = head[u]; i; i = net[i]) if (!st[ver[i]]) calc(ver[i]);//进行下一层点分 }
int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%d", &c[i]); for (int i = 1; i < n; i++) { int a, b; scanf("%d%d", &a, &b); add(a, b), add(b, a); } get_siz(1, -1);//求出每个子树的大小 calc(1);//点分治 for (int i = 1; i <= n; i++) printf("%lld\n", sum[i]); return 0; }
|