【学习笔记】Segment Tree Beats/吉司机线段树

链接

区间最值操作

HDU-5306

支持对区间取 \(\min\),维护区间 \(\max\),查询区间和。

很容易想到一个暴力,我们每一次找出这个区间的最大值 \(mx\),如果 \(mx>x\),那么暴力修改这个位置的值,否则已经修改完毕,退出,时间复杂度为 \(O(n^2 \log n)\)

打一打补丁,对线段树上的每一个区间维护区间最大值 \(mx\),这个区间中最大值出现的次数 \(t\),区间次大值 \(se\),当然还要维护区间和 \(sum\)

现在考虑打上区间取 \(\min\) 标记

  • 如果 \(mx\le x\),那么对 \(sum\) 就没有修改。
  • 如果 \(se<x<mx\),那么 \(sum=sum-(mx-x)\times t\)
  • 如果 \(x\le se<mx\),此时无法直接更新节点信息,故向下左右子树递归。我们分别 DFS 这个节点的两个孩子,如果当前 DFS 的过程中遇到了前两种情况,就直接修改打上标记然后退出,否则就继续 DFS。

点击查看代码

#include<bits/stdc++.h>
using namespace std;

#define ls p<<1
#define rs p<<1|1
#define ll long long
inline char gc()
{
    static char buf[1 << 20/*这里很玄学,改成其他数字可能更快*/], *p1 = buf, *p2 = buf;
    return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 20/*改成和上面一样的数字*/, stdin), p1 == p2) ? EOF : *p1 ++;
}

inline void read(int &n) // 用法 read(n);
{
    bool w = 0;
    char c = gc();
    for(; c < 48 || c > 57; c = gc())
        w = c == 45;
    for(n = 0; c >= 48 && c <= 57; c = gc())
        n = n * 10 + c - 48;
    n = w ? -n : n;
}
const int N=1e6+7;

int T,n,m,a[N],mx[N<<2],se[N<<2],cnt[N<<2],tag[N<<2];
ll sum[N<<2];

inline void pushup(int p){
    sum[p]=sum[ls]+sum[rs];
    if(mx[ls]==mx[rs]){
        mx[p]=mx[ls],se[p]=max(se[ls],se[rs]);
        cnt[p]=cnt[ls]+cnt[rs];
    }
    else if(mx[ls]>mx[rs]){
        mx[p]=mx[ls],se[p]=max(se[ls],mx[rs]);
        cnt[p]=cnt[ls];
    }
    else{
        mx[p]=mx[rs],se[p]=max(se[rs],mx[ls]);
        cnt[p]=cnt[rs];
    }
    return;
}

inline void build(int p,int l,int r){
    tag[p]=-1;
    if(l==r){
        sum[p]=mx[p]=a[l];
        cnt[p]=1,se[p]=-1;
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid),build(rs,mid+1,r);
    pushup(p);
    return;
}

inline void pushtag(int p,int tg){
    if(mx[p]<=tg)return;
    sum[p]+=(ll)(tg-mx[p])*cnt[p];
    mx[p]=tag[p]=tg;
    return;
}

inline void pushdown(int p){
    if(tag[p]==-1)return;
    pushtag(ls,tag[p]),pushtag(rs,tag[p]);
    tag[p]=-1;
    return;
}

inline void update(int p,int l,int r,int s,int t,int val){
    if(mx[p]<=val)return;
    if(s<=l&&r<=t&&se[p]<val){
        pushtag(p,val);
        return;
    }
    int mid=(l+r)>>1;
    pushdown(p);
    if(s<=mid)update(ls,l,mid,s,t,val);
    if(t>mid)update(rs,mid+1,r,s,t,val);
    pushup(p);
    return;
}

inline int querymax(int p,int l,int r,int s,int t){
    if(s<=l&&r<=t)return mx[p];
    int mid=(l+r)>>1,res=-1;
    pushdown(p);
    if(s<=mid)res=max(res,querymax(ls,l,mid,s,t));
    if(t>mid)res=max(res,querymax(rs,mid+1,r,s,t));
    return res;
}

inline ll querysum(int p,int l,int r,int s,int t){
    if(s<=l&&r<=t)return sum[p];
    int mid=(l+r)>>1;ll res=0;
    pushdown(p);
    if(s<=mid)res+=querysum(ls,l,mid,s,t);
    if(t>mid)res+=querysum(rs,mid+1,r,s,t);
    return res;
}

inline void solve(){
    read(n); read(m);
    for(int i=1;i<=n;i++)read(a[i]);
    build(1,1,n);
    for(int i=1;i<=m;i++){
        int op,l,r,val;
        read(op); read(l); read(r);
        if(!op){
            read(val);
            update(1,1,n,l,r,val);
        }
        else if(op==1)printf("%d\n",querymax(1,1,n,l,r));
        else printf("%lld\n",querysum(1,1,n,l,r));
    }
    return;
}

int main(){
    scanf("%d",&T);
    while(T--)solve();
    return 0;
}

请登录后发表评论

    没有回复内容