【CF860E】Arkady and a Nobody-men 长链剖分

【CF860E】Arkady and a Nobody-men

题意:给你一棵n个点的有根树。如果b是a的祖先,定义$r(a,b)$为b的子树中深度小于等于a的深度的点的个数(包括a)。定义$z(a)=\sum\limits r(a,b)$(b是a的祖先)。要你求出每个点的z值。

$n\le 5\times 10^5$

题解:一开始naive的思路:将所有点按深度排序,将深度相同的点统一处理,统计答案时相当于链加,链求和,用树剖+树状数组搞一搞,时间复杂度$O(n\log^2n)$。

后来看题解发现我这个想法简直菜爆了。我们先从树形DP的角度去想,先给出转移方程:

$ans(x)=ans(fa(x))+dep(x)+ans'(x)$,ans'(x)表示与a深度相同的点 对a的贡献。

现在问题变成了求ans',我们考虑在每个点对的lca处统计贡献。具体地,我们对于每个点x,维护若干个三元组(d,a,cnt)表示x的子树中有cnt个d级子孙,其中一个子孙为a。DP的过程就相当于在父亲节点处将所有儿子节点的三元组合并,在合并时顺便统计贡献。

具体地,合并方式如下:假如x有两个儿子,它们有三元组$(d,a,cnt_a)$和$(d,b,cnt_b)$,则:

1.ans'(a)+=dep(x)\times cnt_b
2.ans'(b)+=dep(x)\times cnt_a
3.得到新三元组(d,a,cnt_a+cnt_b)

但是后面的点 对b的贡献呢?我们发现后面的点 对a和b的贡献就是相同的了,所以我们建一个新图,在新图中从a到b连一条长度为$ans'(b)-ans'(a)$的边,最后在新图上DFS一下,最最后统计一下ans数组即可。

以上过程采用长链剖分优化,由于一开始的三元组个数为n,则每次合并都会减少一个三元组,所以时间复杂度O(n)。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
const int maxn=500010;
typedef long long ll;
ll ans[maxn];
struct node
{
	int v,x;
	node() {}
	node(int a,int b) {v=a,x=b;}
}mem[maxn<<1],*f[maxn],*now=mem;
int n,cnt,Cnt,rt;
int to[maxn],nxt[maxn],head[maxn],dep[maxn],md[maxn],son[maxn],fa[maxn];
bool vis[maxn];
int To[maxn],Nxt[maxn],Head[maxn];
ll Val[maxn];
inline void add(int a,int b)
{
	to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++;
}
inline void Add(int a,int b,ll c)
{
	To[Cnt]=b,Val[Cnt]=c,Nxt[Cnt]=Head[a],Head[a]=Cnt++;
}
void dfs1(int x)
{
	md[x]=0;
	for(int i=head[x];i!=-1;i=nxt[i])
	{
		dep[to[i]]=dep[x]+1,dfs1(to[i]);
		if(md[to[i]]+1>md[x])	md[x]=md[to[i]]+1,son[x]=to[i];
	}
}
void dfs2(int x)
{
	if(f[x]==NULL)	f[x]=now,now+=md[x]+2;
	if(son[x])	f[son[x]]=f[x]+1,dfs2(son[x]);
	f[x][0]=node(1,x);
	for(int i=head[x];i!=-1;i=nxt[i])	if(to[i]!=son[x])
	{
		dfs2(to[i]);
		for(int j=0;j<=md[to[i]];j++)
		{
			node a=f[x][j+1],b=f[to[i]][j];
			ans[b.x]+=1ll*dep[x]*a.v;
			ans[a.x]+=1ll*dep[x]*b.v;
			Add(a.x,b.x,ans[b.x]-ans[a.x]);
			f[x][j+1]=node(a.v+b.v,a.x);
		}
	}
}
void dfs3(int x)
{
	for(int i=Head[x];i!=-1;i=Nxt[i])	ans[To[i]]=ans[x]+Val[i],dfs3(To[i]);
}
void dfs4(int x)
{
	for(int i=head[x];i!=-1;i=nxt[i])	ans[to[i]]+=ans[x]+dep[x],dfs4(to[i]);
}
inline int rd()
{
	int ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')	f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+(gc^'0'),gc=getchar();
	return ret*f;
}
int main()
{
	n=rd();
	int i;
	memset(head,-1,sizeof(head)),memset(Head,-1,sizeof(Head));
	for(i=1;i<=n;i++)
	{
		fa[i]=rd();
		if(!fa[i])	rt=i;
		else	add(fa[i],i);
	}
	dep[rt]=1,dfs1(rt);
	dfs2(rt);
	for(i=0;i<=md[rt];i++)	dfs3(f[rt][i].x);
	dfs4(rt);
	for(i=1;i<=n;i++)	printf("%lld ",ans[i]);
	return 0;
}