dwjshift's Blog

bzoj3159 决战

| Comments

这题的做法挺好玩的,所以就来写写题解。
弄个LCT,只维护形态,方便提取路径。对于LCT的每条链,维护一棵平衡树(存权值),使得这棵平衡树按照中序遍历对应的是这条链上每个节点的权值。
接下来要解决的就是怎么维护这个平衡树了。首先是access的时候,需要把链断开或者把链接上,那么对应的就要把平衡树前面的若干个节点split出来或者merge两棵平衡树。LCT换根的时候,除了LCT那里要打个翻转tag,对应的平衡树那也要打个翻转tag。接着执行操作的时候直接提取出链来,然后在对应的平衡树上执行操作就好了。
平衡树要支持mergesplit,比较好的选择应该是splay或者treap。不过用treap就是,用splay大概只有一个log吧(其实我一直都不会势能分析T_T
葱娘的官方题解写的是维护两个LCT,个人觉得这样的说法不太确切,因为维护权值的那一堆平衡树相互之间是独立的。另外题解里还说到只能翻转深度单调递增的链,大概是搞错了吧。

code

#include <cstdio>
#include <algorithm>
#define repu(i,x,y) for (int i=x; i<=y; ++i)
#define repd(i,x,y) for (int i=x; i>=y; --i)
using namespace std;

typedef long long LL;
int n,m,r,u,v,w;
char opt[10];
struct node2
{
    int w,min,max,s,tag;
    LL sum;
    bool rev;
    node2 *lc,*rc,*fa;
} pool2[50100],*tp2=pool2,*nul2=tp2++;
struct node1
{
    int s;
    bool rev;
    node1 *lc,*rc,*fa;
    node2 *node;
} pool1[50100],*tp1=pool1,*nul1=tp1++,*idx[50100];

inline void update(node2 *x)
{   
    x->sum=x->lc->sum+x->rc->sum+x->w;
    x->max=max(max(x->lc->max,x->rc->max),x->w);
    x->min=min(min(x->lc->min,x->rc->min),x->w);
    x->s=x->lc->s+x->rc->s+1;
}

inline void add(node2 *x,int w)
{
    if (x==nul2)
        return;
    x->w+=w,x->tag+=w,x->min+=w,x->max+=w,x->sum+=LL(w)*x->s;
}

inline void push_down(node2 *x)
{
    if (x->tag)
    {   
        add(x->lc,x->tag),add(x->rc,x->tag);
        x->tag=0;
    }
    if (!x->rev)
        return;
    swap(x->lc,x->rc);
    x->lc->rev^=1,x->rc->rev^=1,x->rev=0;
}

void clear(node2 *x)
{
    if (x->fa)
        clear(x->fa);
    push_down(x);
}

inline void rotate(node2 *y)
{
    node2 *x=y->fa;
    if (x->fa)
        (x->fa->lc==x?x->fa->lc:x->fa->rc)=y;
    y->fa=x->fa;
    if (y==x->lc)
        (x->lc=y->rc)->fa=x,(y->rc=x)->fa=y;
    else
        (x->rc=y->lc)->fa=x,(y->lc=x)->fa=y;
    update(x),update(y);
}

void splay(node2 *x)
{
    node2 *y;
    for (clear(x); x->fa; rotate(x))
        if ((y=x->fa)->fa)
            rotate(x==y->lc ^ y==y->fa->lc?x:y);
}

node2 *search(node2 *now,int k)
{
    push_down(now);
    if (now->lc->s+1==k)
        return now;
    if (k<=now->lc->s)
        return search(now->lc,k);
    return search(now->rc,k-now->lc->s-1);
}

inline bool isroot(node1 *x)
{
    return !x->fa || x->fa->lc!=x && x->fa->rc!=x;
}

inline void update(node1 *x)
{
    x->s=x->lc->s+x->rc->s+1;
}

inline void rotate(node1 *y)
{
    node1 *x=y->fa;
    if (!isroot(x))
        (x->fa->lc==x?x->fa->lc:x->fa->rc)=y;
    else
        y->node=x->node;
    y->fa=x->fa;
    if (y==x->lc)
        (x->lc=y->rc)->fa=x,(y->rc=x)->fa=y;
    else
        (x->rc=y->lc)->fa=x,(y->lc=x)->fa=y;
    y->s=x->s,update(x);
}

inline void push_down(node1 *x)
{
    if (!x->rev)
        return;
    swap(x->lc,x->rc);
    x->lc->rev^=1,x->rc->rev^=1,x->rev=0;
}

void clear(node1 *x)
{
    if (!isroot(x))
        clear(x->fa);
    push_down(x);
}

void splay(node1 *x)
{
    node1 *y;
    for (clear(x); !isroot(x); rotate(x))
        if (!isroot(y=x->fa))
            rotate(x==y->lc ^ y==y->fa->lc?x:y);
}

void cut(node1 *x)
{
    splay(x);
    if (x->rc==nul1)
        return;
    splay(x->node=search(x->node,x->lc->s+1));
    x->rc->node=x->node->rc,x->rc=nul1,update(x);
    x->node->rc->fa=NULL,x->node->rc=nul2,update(x->node);
}

void access(node1 *x)
{
    for (cut(x); x->fa; x=x->fa)
    {
        cut(x->fa);
        splay(x->fa->node=search(x->fa->node,x->fa->s));
        x->fa->node->rc=x->node,x->node->fa=x->fa->node,update(x->fa->node);
        x->fa->rc=x,update(x->fa);
    }
}

void modify_root(node1 *x)
{
    access(x),splay(x),x->rev^=1,x->node->rev^=1;
}
int main()
{
    nul2->max=-(nul2->min=1<<30);
    scanf("%d%d%d",&n,&m,&r);
    repu(i,1,n)
    {
        tp2->lc=tp2->rc=nul2,tp2->s=1,tp1->node=tp2++;
        tp1->lc=tp1->rc=nul1,tp1->s=1,idx[i]=tp1++;
    }
    repu(i,1,n-1)
    {
        scanf("%d%d",&u,&v);
        modify_root(idx[u]),idx[u]->fa=idx[v];
    }
    while (m--)
    {
        scanf("%s%d%d",opt,&u,&v);
        modify_root(idx[u]),access(idx[v]),splay(idx[v]);
        if (opt[2]=='c')
            scanf("%d",&w),add(idx[v]->node,w);
        if (opt[2]=='m')
            printf("%lld\n",idx[v]->node->sum);
        if (opt[2]=='j')
            printf("%d\n",idx[v]->node->max);
        if (opt[2]=='n')
            printf("%d\n",idx[v]->node->min);
        if (opt[2]=='v')
            idx[v]->node->rev^=1;
    }
    return 0;
}

Comments

comments powered by Disqus