【原題見
這里】
本題是Splay Tree處理序列問題(也就是當線段樹用)的一個典型例題。
Splay Tree之所以可以當線段樹用,是因為它可以支持一個序列,然后用“左端前趨伸展到根,右端后繼伸展到根的右子結點,取根的右子結點的左子結點”這種伸展方法,對一個序列中的一整段進行整體操作。由于要防止出現前趨或后繼不存在的情況,需要在這個序列的兩端加入兩個邊界結點,要求其值不能影響到結點各種記載信息的維護(多取0、∞或-∞)。這兩個邊界結點在樹中永遠存在,不會被刪除。
(1)結點的引用:
在當線段樹用的Splay Tree中,真正的關鍵字是下標而不是值,因此,“序列中第i個結點”實際上對應的是“樹中第(i+1)小的結點”(因為左邊還有一個邊界結點),這就說明在對結點引用時需要找第K小的操作。因此,下面的“結點x”指的是“樹中第(x+1)小的結點”。
(2)標記:
在線段樹中,如果對一個結點所表示的線段整體進行了某種操作,需要在這個結點上打上一個標記,在下一次再找到這個結點時,其標記就會下放到其兩個子結點上。在Splay Tree中也可以引入標記。比如要對[2, 6]這一段進行整體操作,就將結點1伸展到根的位置,將結點7伸展到根的右子樹的位置,然后結點7的左子樹就表示[2, 6]這一段,對這棵子樹的根結點打上標記并立即生效(必須是立即生效,而不是等下一次引用再生效),也就是立即改變該結點記錄的一些信息的值。如果下次再次引用到這個結點,就要將其標記下放到其兩個子結點處;
需要注意的一點是,如果要伸展某個結點x到r的子結點的位置,就必須保證從x原來的位置到r的這個子結點(x伸展后的位置)上的所有結點上均沒有標記,否則就會導致標記混亂。因此,必須首先找到這個結點x,在此過程中不斷下放標記。
(3)自底向上維護的信息:
標記可以看成一種自頂向下維護的信息。除了標記以外,作為“線段樹”,往往還要維護一些自底向上維護的信息。比如在sequence這題中,就有lmax(左段連續最大和)、rmax(右段連續最大和)、midmax(全段連續最大和)以及sum(全段總和)等信息要維護。對于這類東東其實也很好辦,因為子樹大小(sz域)就是一種自底向上維護的信息,因此對于這些信息只要按照維護sz域的辦法維護即可(統一寫在upd函數里)。唯一的不同點是打標記時它們的值可能要改變。
(4)對連續插入的結點建樹:
本題的插入不是一個一個插入,而是一下子插入一整段,因此需要先將它們建成一棵樹。一般建樹操作都是遞歸的,這里也一樣。設目前要對A[l..r]建樹(A為待插入序列),若l>r則退出,否則找到位于中間的元素mid = l + r >> 1,將A[mid]作根,再對A[l..mid-1]建左子樹,對A[mid+1..r]建右子樹即可。這樣可以保證一開始建的就是一棵平衡樹,減小常數因子。
(5)回收空間:
根據本題的數據范圍提示,插入的結點總數最多可能達到4000000,但在任何時刻樹中最多只有500002個結點(包括兩個邊界),此時為了節省空間,可以采用循環隊列回收空間的方法。即:一開始將所有的可用空間(可用下標,本題為1~500002)存在循環隊列Q里,同時設立頭尾指針front和rear,每次如果有新結點插入,就取出Q[front]并作為新結點的下標,如果有結點要刪除(本題是一次刪除整棵子樹,因此在刪除后需要分別回收它們的空間),則從rear開始,將每個刪除的結點的下標放回到Q里。當然,這種方法是要犧牲一定的時間的,因此在空間不是特別吃緊的情況下不要用。
【2012年1月16日更新】
今天重寫sequence的時候,禿然發現加入的邊界點可能會對lmax、rmax、midmax等的維護造成影響:當序列中所有的值都是負數時,若邊界點的值設為0,將使這3個值也為0,所以,邊界點的值應設為-INF(不會影響到sum,因為可以單獨調出[l, r]的sum,避開邊界)。這就說明并非所有這樣的題中都可以設置邊界點(比如HFTSC2011的那題就不行),如果邊界點會對維護的信息造成影響,就不能設置邊界點,在各個操作中,分4種情況判斷。(代碼已經修改)
下面上代碼了:
#include <iostream>
#include <stdio.h>
using namespace std;
#define re(i, n) for (int i=0; i<n; i++)
#define re1(i, n) for (int i=1; i<=n; i++)
const int MAXN = 500002, NOSM = -2000, INF = ~0U >> 2;
struct node {
int v, c[2], p, sz, sum, lmax, rmax, midmax, sm;
bool rev, d;
} T[MAXN + 1];
int root, Q[MAXN + 1], front, rear, a[MAXN], len, res;
int max(int SS0, int SS1)
{
return SS0 >= SS1 ? SS0 : SS1;
}
int max(int SS0, int SS1, int SS2)
{
int M0 = SS0 >= SS1 ? SS0 : SS1; return M0 >= SS2 ? M0 : SS2;
}
void newnode(int n, int _v)
{
T[n].v = T[n].sum = T[n].lmax = T[n].rmax = T[n].midmax = _v; T[n].c[0] = T[n].c[1] = 0; T[n].sz = 1; T[n].sm = NOSM; T[n].rev = 0;
}
void sc(int _p, int _c, bool _d)
{
T[_p].c[_d] = _c; T[_c].p = _p; T[_c].d = _d;
}
void sm_opr(int x, int SM)
{
T[x].sum = T[x].sz * SM;
if (SM > 0) T[x].lmax = T[x].rmax = T[x].midmax = T[x].sum; else T[x].lmax = T[x].rmax = T[x].midmax = SM;
}
void rev_opr(int x)
{
int c0 = T[x].c[0], c1 = T[x].c[1]; sc(x, c0, 1); sc(x, c1, 0);
int tmp = T[x].lmax; T[x].lmax = T[x].rmax; T[x].rmax = tmp;
}
void dm(int x)
{
int SM0 = T[x].sm;
if (SM0 != NOSM) {
T[x].v = T[T[x].c[0]].sm = T[T[x].c[1]].sm = SM0; T[x].sm = NOSM;
sm_opr(T[x].c[0], SM0); sm_opr(T[x].c[1], SM0);
}
if (T[x].rev) {
T[T[x].c[0]].rev = !T[T[x].c[0]].rev; T[T[x].c[1]].rev = !T[T[x].c[1]].rev; T[x].rev = 0;
rev_opr(T[x].c[0]); rev_opr(T[x].c[1]);
}
}
void upd(int x)
{
int c0 = T[x].c[0], c1 = T[x].c[1];
T[x].sz = T[c0].sz + T[c1].sz + 1;
T[x].sum = T[c0].sum + T[c1].sum + T[x].v;
T[x].lmax = max(T[c0].lmax, T[c0].sum + T[x].v + max(T[c1].lmax, 0));
T[x].rmax = max(T[c1].rmax, max(T[c0].rmax, 0) + T[x].v + T[c1].sum);
T[x].midmax = max(T[c0].midmax, T[c1].midmax, max(T[c0].rmax, 0) + T[x].v + max(T[c1].lmax, 0));
}
void rot(int x)
{
int y = T[x].p; bool d = T[x].d;
if (y == root) {root = x; T[root].p = 0;} else sc(T[y].p, x, T[y].d);
sc(y, T[x].c[!d], d); sc(x, y, !d); upd(y);
}
void splay(int x, int r)
{
int p; while ((p = T[x].p) != r) if (T[p].p == r) rot(x); else if (T[x].d == T[p].d) {rot(p); rot(x);} else {rot(x); rot(x);} upd(x);
}
int Find_Kth(int K)
{
int i = root, S0;
while (i) {
dm(i); S0 = T[T[i].c[0]].sz + 1;
if (K == S0) break; else if (K < S0) i = T[i].c[0]; else {K -= S0; i = T[i].c[1];}
}
return i;
}
int mkt(int l, int r)
{
if (l > r) return 0;
int n0 = Q[front], mid = l + r >> 1; if (front == MAXN) front = 1; else front++;
newnode(n0, a[mid]); int l_r = mkt(l, mid - 1), r_r = mkt(mid + 1, r);
sc(n0, l_r, 0); sc(n0, r_r, 1); upd(n0); return n0;
}
void ins(int pos)
{
int P0 = Find_Kth(pos); splay(P0, 0); int P1 = Find_Kth(pos + 1); splay(P1, root); sc(P1, mkt(0, len - 1), 0); upd(P1); upd(P0);
}
void era(int x)
{
if (!x) return;
if (rear == MAXN) rear = 1; else rear++; Q[rear] = x;
era(T[x].c[0]); era(T[x].c[1]);
}
void del(int l, int r)
{
int P0 = Find_Kth(l - 1); splay(P0, 0); int P1 = Find_Kth(r + 1); splay(P1, root);
int root0 = T[P1].c[0]; sc(P1, 0, 0); upd(P1); upd(P0); era(root0);
}
void mksame(int l, int r, int x)
{
int P0 = Find_Kth(l - 1); splay(P0, 0); int P1 = Find_Kth(r + 1); splay(P1, root);
int n = T[P1].c[0]; T[n].sm = x; sm_opr(n, x); upd(P1); upd(P0);
}
void reve(int l, int r)
{
int P0 = Find_Kth(l - 1); splay(P0, 0); int P1 = Find_Kth(r + 1); splay(P1, root);
int n = T[P1].c[0]; T[n].rev = !T[n].rev; rev_opr(n); upd(P1); upd(P0);
}
int get_sum(int l, int r)
{
int P0 = Find_Kth(l - 1); splay(P0, 0); int P1 = Find_Kth(r + 1); splay(P1, root);
int n = T[P1].c[0]; return T[n].sum;
}
int max_sum()
{
return T[root].midmax;
}
void prepare()
{
T[0].sz = T[0].sum = T[0].lmax = T[0].rmax = T[0].midmax = 0;
front = 3; rear = MAXN; re1(i, MAXN) Q[i] = i;
newnode(1, -INF); newnode(2, -INF); sc(1, 2, 1); root = 1; T[root].p = 0;
}
int main()
{
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
prepare();
int m, l, r, x;
scanf("%d%d", &len, &m); char ch = getchar(), str[1000];
re(i, len) scanf("%d", &a[i]); ins(1);
re(i, m) {
scanf("%s", str);
if (!strcmp(str, "INSERT")) {scanf("%d%d", &l, &len); re(i, len) scanf("%d", &a[i]); ins(++l);}
if (!strcmp(str, "DELETE")) {scanf("%d%d", &l, &r); r += l++; del(l, r);}
if (!strcmp(str, "MAKE-SAME")) {scanf("%d%d%d", &l, &r, &x); r += l++; mksame(l, r, x);}
if (!strcmp(str, "REVERSE")) {scanf("%d%d", &l, &r); r += l++; reve(l, r);}
if (!strcmp(str, "GET-SUM")) {scanf("%d%d", &l, &r); r += l++; printf("%d\n", get_sum(l, r));}
if (!strcmp(str, "MAX-SUM")) printf("%d\n", max_sum());
ch = getchar();
}
fclose(stdin); fclose(stdout);
return 0;
}
最后把我的這個代碼與BYVoid神犇的本題代碼進行測試比較,結果(BYVoid神犇的代碼見
這里):
BYVoid神犇的:

本沙茶的:

【相關論文】
運用伸展樹解決數列維護問題 by JZP
【感謝】
JZP神犇(提供論文)
BYVoid神犇(提供標程)