Color a Tree
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others) Total Submission(s): 854 Accepted Submission(s): 281思路:寻找最大权值,合并这个节点和他的父亲节点,记下这两个节点的拓扑序列,同时新节点的权值为这些节点的算术平均值,直到只有一个节点。因为这个节点必定是访问该节点的父节点之后第一个访问的节点。
证明:
对于每一次新的访问,我们要计算所有未被访问的点权之和。显然这个计算式很繁琐且不易处理。
ò 思考:根据访问代价的计算规则,对于根节点,它的权值需要计算1次;对于第2个访问的点,它的权值需要计算2次;对于第3个点……以此类推。
ò 发现:每个点权的计算次数只与遍历顺序有关!即访问这个点的时间戳。
ò 给定一棵N个有权值的节点的有根树(默认根节点编号为1)。每个节点的权值为Ci。
ò 现在需要遍历这棵树。每个点的访问代价为这个点的权值与访问它的时间戳的乘积。
ò 遍历只能按拓扑序,即访问i时i的父亲必须已经被访问。
ò 求最小遍历代价。遍历代价为每个点的访问代价之和。
ò 转化的好处?使代价函数的计算式更加确定且易于计算。
ò 定义访问序列为一个排列P={i1,i2,i3...in-1,in},表示节点的访问顺序。
ò 定义点权序列为T={Ci1,Ci2,Ci3…Cin-1,Cin}。
ò 由简化问题的贪心解法,我们考虑尽量访问当前未访问的点权最大节点i。然而由于拓扑序的限制,我们想要访问它,必须先访问它的父亲。
ò 猜想:当前点权最大的节点i必定是在访问它的父亲j后立刻访问,否则得不到最优答案。
ò 猜想是否正确?
令当前最大节点为i,它的父亲为j。在访问序列P中i的位置为xi,j的位置为xj,假设xi+1<xj。
ò 关键:对于所有的k满足xj<xk<xi,这些点必定不是j的祖先或i的后代。
ò 我们在序列P中交换i和k,P仍然是一个合法的访问序列。
ò 显然地,由于Ci是访问j时所有未访问点的点权最大值,那么交换后的访问序列P对应的遍历代价变小了。
ò 进一步,将i交换到xj+1的位置是最优的。
ò 于是猜想是正确的。
ò 结论1:当前最大节点Ci必定是在访问它的父亲j后立刻访问。
ò
ò 由结论1知,在访问序列中,i和j应该是相邻的,j在前,i在后。
ò 那么i和j可以合并为一个结点k,k的父节点与j的父结点相同,k的子结点是所有j的子结点和i的子结点。然后用k代替树中的j和i,这样形成一棵n-1个结点的树。
ò 合并的好处?
ò 问题的规模缩小了。
ò 新的子问题:k在新树的访问序列中的位置?
ò 思考:访问序列P中k必然在k的后代之前,那么我们需要讨论的是k与非k后代的节点的相对位置。现在有两个选择:一是先访问k,然后访问非k的后代的m个结点(i1,i2,...,im) ;二是先访问非k的后代的m个结点(i1,i2,...,im),然后访问k。
ò 我们需要知道怎样的相对位置使最终代价最小。
ò 当前决策k完成时第二种选择相对于第一种选择费用之差为:
ò F2-F1=(Ci+Cj)×m-{sigma(Cik)|k=1..m}×2
ò 也就是说,第二种方案先访问m个节点,这m个节点相比第一种方案提前了2个时间点,那么减少的费用是2×{sigma(Cik)|k=1..m};后访问i和j,i,j相比第一种方案延后了m个时间点,增加的费用是(Ci+Cj)×m。
ò 1.标记根节点为第一个访问的节点。
ò 2.求出当前未访问节点中的最大点权节点i。
ò 3.将i和它的父亲j合并为一个节点,节点权值为两者权值的算术平均数。在序列P中将j的后继置为i。同时更新树的信息。
ò 4.若当前树中节点数大于1,则转第2步。
ò 5.树的大小为1时算法结束。
ò 6.扫描求得的P序列得到答案。
ò 时间复杂度:O(N^2)。
ò 注意到我们每次操作需要得到当前最大权对应的节点i并将i的儿子的父亲改为i的父亲。
ò O(N)的扫描成为算法复杂度的瓶颈。
ò 如何高效求最大值?
ò 推荐数据结构:最大堆(O(LogN))。
ò 如何高效地将i的儿子的父亲改为i的父亲?
ò 推荐数据结构:并查集(O(α(N)))。
ò 总的时间复杂度:O(NLogN)。
#include#include #include #define FOR(i,a,b) for(int i=a;i<=b;++i)#define clr(f,z) memset(f,z,sizeof(f))using namespace std;const int mm=1009;int rt[mm],fa[mm],val[mm],to[mm],num[mm];bool vis[mm];float dw[mm];int root;int find(int x){ if(x^rt[x]) rt[x]=find(rt[x]); return rt[x];}int child(int x){ if(to[x]^x) return child(to[x]); return to[x];}int main(){ int n,a,b; while(~scanf("%d%d",&n,&root)) { if(n==0&&root==0)break; FOR(i,1,n)rt[i]=i,fa[i]=i,to[i]=i,vis[i]=0,num[i]=1; FOR(i,1,n)scanf("%d",&val[i]),dw[i]=val[i]; FOR(i,2,n)scanf("%d%d",&a,&b),fa[b]=a; vis[root]=1; FOR(i,2,n) { int id=0; FOR(j,1,n) if(!vis[j]) { if(id==0||((dw[id]/num[id])<(dw[j]/num[j]))) id=j; } int u=fa[id]; u=child(u); ///时间戳相连 to[u]=id; u=find(fa[id]); ///子节点合到父节点 num[u]+=num[id];///节点数 dw[u]+=dw[id]; rt[id]=u; vis[id]=1; } int ans=0,z=root; /// puts("++++"); FOR(i,1,n) { ans+=val[z]*i; z=to[z]; } printf("%d\n",ans); }}