セグメント木の概念を、単純なセグメント木だけではなく双対セグメント木遅延伝搬セグメント木まで通じて整理したい

  • Pythonであることを生かして、二項演算や作用はハードコードせずに引数として渡す
  • 理解が目的なので速度のことは気にしない
  • 1-origin、サイズは2ベキ
  • Segment Tree のお勉強(2) | maspyのHPを参考にした
  • ソースコードはこちら
    • ここの解説とソースのdocが食い違っている時はソースの方が正しい、doctestで回帰テストしてるから。

下記ではセグメント木、双対セグメント木、遅延伝搬セグメント木、までを実装してAOJのテストケースでベリファイしているが、遅延伝搬セグメント木に関してはTLEしていて、柔軟な設計のせいでTLEしてるのかアルゴリズム的に何か間違えてるのかはまだ識別できてない

IDはこういう順番になっている python

>>> debugprint(range(SEGTREE_SIZE))
|                       1                       |
|           2           |           3           |
|     4     |     5     |     6     |     7     |
|  8  |  9  |  10 |  11 |  12 |  13 |  14 |  15 |
|16|17|18|19|20|21|22|23|24|25|26|27|28|29|30|31| 

配列が与えられたときに隣り合うものに二項演算をかけながらO(N)で構築できる

  • ここでは和を求めてる、この構築された木を使って範囲の和が高速に求められる python
>>> table = [None] * SEGTREE_SIZE
>>> set_items(table, range(16))
>>> full_up(table, lambda x, y: f"{x}+{y}")
>>> debugprint(table, 3)
|             0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15             |
|        0+1+2+3+4+5+6+7        |     8+9+10+11+12+13+14+15     |
|    0+1+2+3    |    4+5+6+7    |   8+9+10+11   |  12+13+14+15  |
|  0+1  |  2+3  |  4+5  |  6+7  |  8+9  | 10+11 | 12+13 | 14+15 |
| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15|

点更新

  • 末端の値に関数を掛ける(作用と呼んだりする)
    • 親に作用を伝搬していくことがlogNでできる
    • 例えば末端を+5すると、親も+5されるのでこれでセグメント木全体を正しく更新できる python
>>> point_update(table, 3, lambda x: f"f({x})")
>>> debugprint(table, 4)
|                    f(0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15)                   |
|           f(0+1+2+3+4+5+6+7)          |         8+9+10+11+12+13+14+15         |
|     f(0+1+2+3)    |      4+5+6+7      |     8+9+10+11     |    12+13+14+15    |
|   0+1   |  f(2+3) |   4+5   |   6+7   |   8+9   |  10+11  |  12+13  |  14+15  |
| 0  | 1  | 2  |f(3)| 4  | 5  | 6  | 7  | 8  | 9  | 10 | 11 | 12 | 13 | 14 | 15 |

別のパターン

  • 末端の値が、以前の値と無関係に設定される
    • この場合、親が2つの子の二項演算をやりながら上に伝搬していけばよい
  • 変化を作用で表現できないケースもこちら
    • 例えばmaxを取る演算で値が減少する時
    • 値が増加しかしない時は「新しい値とmaxを取る」という作用でできる
    • 値が減少する時max(x, y)を求めた後でxが0 になってもyの値が失われてるので戻せない python
>>> point_set(table, 3, "x", lambda x, y: f"{x}+{y}")
>>> debugprint(table, 3)
|             0+1+2+x+4+5+6+7+8+9+10+11+12+13+14+15             |
|        0+1+2+x+4+5+6+7        |     8+9+10+11+12+13+14+15     |
|    0+1+2+x    |    4+5+6+7    |   8+9+10+11   |  12+13+14+15  |
|  0+1  |  2+x  |  4+5  |  6+7  |  8+9  | 10+11 | 12+13 | 14+15 |
| 0 | 1 | 2 | x | 4 | 5 | 6 | 7 | 8 | 9 | 10| 11| 12| 13| 14| 15|

max二項演算を上に伝搬していくデモ python

# Point update, range max

>>> table = [0] * SEGTREE_SIZE
>>> set_items(table, range(16))
>>> full_up(table, max)
>>> debugprint(table)
|                       15                      |
|           7           |           15          |
|     3     |     7     |     11    |     15    |
|  1  |  3  |  5  |  7  |  9  |  11 |  13 |  15 |
|0 |1 |2 |3 |4 |5 |6 |7 |8 |9 |10|11|12|13|14|15|

>>> set_item(table, 5, 99)
>>> debugprint(table)
|                       15                      |
|           7           |           15          |
|     3     |     7     |     11    |     15    |
|  1  |  3  |  5  |  7  |  9  |  11 |  13 |  15 |
|0 |1 |2 |3 |4 |99|6 |7 |8 |9 |10|11|12|13|14|15|

>>> up_propagate_from_leaf(table, 5, max)
>>> debugprint(table)
|                       99                      |
|           99          |           15          |
|     3     |     99    |     11    |     15    |
|  1  |  3  |  99 |  7  |  9  |  11 |  13 |  15 |
|0 |1 |2 |3 |4 |99|6 |7 |8 |9 |10|11|12|13|14|15|

範囲縮約

  • 範囲に対して二項演算で縮約できる
    • Pythonだとreduce
    • Haskellだとfold系、Rubyだとinjectとか
    • セグメント木文脈ではRange sumとかRange maxみたいに具体的な二項演算を特定した表現をよく見る
    • 「範囲内の複数の値に対して繰り返し二項演算を適用して1つの値にすること」を指す言葉が見つからなかったのでここでは範囲縮約と呼んでおく
      • メジャーな言い方ではなさげ
      • 「畳み込み」と呼ぶ人もいる(多義語…)
  • 二項演算は結合法則を満たす必要がある
    • 加算、乗算、maxはOK
  • 交換法則は満たさなくて良い
    • …ように実装できるが整数の加算、乗算、maxでは交換法則が成り立つので、成り立たなくても大丈夫な実装にしてない人もいるかもね
    • 交換法則を満たさないケース
      • 値が行列で二項演算が行列積の場合
      • ARC008_D
        • 二項演算がではなく なので交換できない python
>>> set_items(table, range(16))
>>> full_up(table, lambda x, y: f"{x}+{y}")
>>> range_reduce(table, 3, 11, lambda x, y: f"{x}+{y}", unity="0")
'0+3+4+5+6+7+8+9+10+0'

# Bin-op must be associative
>>> range_reduce(table, 3, 11, lambda x, y: f"({x}+{y})", unity="0")
'(((0+3)+4+5+6+7)+(8+9+(10+0)))'

点更新と範囲縮約の組み合わせで点更新と範囲和ができる python

>>> table = [0] * SEGTREE_SIZE
>>> set_items(table, range(16))
>>> full_up(table, lambda x, y: f"{x}+{y}")
>>> debugprint(table)
|     0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15     |
|    0+1+2+3+4+5+6+7    | 8+9+10+11+12+13+14+15 |
|  0+1+2+3  |  4+5+6+7  | 8+9+10+11 |12+13+14+15|
| 0+1 | 2+3 | 4+5 | 6+7 | 8+9 |10+11|12+13|14+15|
|0 |1 |2 |3 |4 |5 |6 |7 |8 |9 |10|11|12|13|14|15|

>>> point_update(table, 5, lambda x: f"{x}+99")
>>> debugprint(table)
|                    0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+99                   |
|           0+1+2+3+4+5+6+7+99          |         8+9+10+11+12+13+14+15         |
|      0+1+2+3      |     4+5+6+7+99    |     8+9+10+11     |    12+13+14+15    |
|   0+1   |   2+3   |  4+5+99 |   6+7   |   8+9   |  10+11  |  12+13  |  14+15  |
| 0  | 1  | 2  | 3  | 4  |5+99| 6  | 7  | 8  | 9  | 10 | 11 | 12 | 13 | 14 | 15 |

>>> range_reduce(table, 3, 11, lambda x, y: f"{x}+{y}", "0")
'0+3+4+5+6+7+99+8+9+10+0'

双対セグメント木

  • 範囲更新、点取得にできる

範囲更新

  • left以上right未満の範囲に対して作用を掛ける
    • 末端の要素に対してかけないことでlogNオーダーになっている python
>>> table = [""] * SEGTREE_SIZE
>>> set_items(table, range(16))
>>> range_update(table, 1, 11, lambda x: f"f")
>>> debugprint(table)
|                                               |
|                       |                       |
|           |     f     |           |           |
|     |  f  |     |     |  f  |     |     |     |
|0 |f |2 |3 |4 |5 |6 |7 |8 |9 |f |11|12|13|14|15|

>>> range_update(table, 3, 15, lambda x: f"{x}g")
>>> debugprint(table)
|                                                               |
|                               |                               |
|               |       fg      |       g       |               |
|       |   f   |       |       |   f   |       |   g   |       |
| 0 | f | 2 | 3g| 4 | 5 | 6 | 7 | 8 | 9 | f | 11| 12| 13|14g| 15|
  • これを使って、範囲更新で(+1)と(+2)の作用を範囲に適用する python
>>> table = [0] * SEGTREE_SIZE
>>> range_update(table, 1, 11, lambda x: x + 1)
>>> debugprint(table, maxsize=4)
|               0               |
|       0       |       0       |
|   0   |   1   |   0   |   0   |
| 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
|0|1|0|0|0|0|0|0|0|0|1|0|0|0|0|0|

>>> range_update(table, 3, 15, lambda x: x + 2)
>>> debugprint(table, maxsize=4)
|               0               |
|       0       |       0       |
|   0   |   3   |   2   |   0   |
| 0 | 1 | 0 | 0 | 1 | 0 | 2 | 0 |
|0|1|0|2|0|0|0|0|0|0|1|0|0|0|2|0|
  • 5番目の値が欲しいとき、親からそこまで下向きに作用を伝搬させれば良い python
>>> down_propagate_to_leaf(table, 5, add, 0)
>>> debugprint(table, maxsize=4)
|               0               |
|       0       |       0       |
|   0   |   0   |   2   |   0   |
| 0 | 1 | 0 | 3 | 1 | 0 | 2 | 0 |
|0|1|0|2|3|3|0|0|0|0|1|0|0|0|2|0|

>>> down_propagate_to_leaf(table, 9, add, 0)
>>> debugprint(table, maxsize=4)
|               0               |
|       0       |       0       |
|   0   |   0   |   0   |   0   |
| 0 | 1 | 0 | 3 | 0 | 2 | 2 | 0 |
|0|1|0|2|3|3|0|0|3|3|1|0|0|0|2|0|

Verified

def main():
    N, Q = map(int, input().split())
    depth = N.bit_length() + 1
    set_depth(depth)
    table = [INF] * SEGTREE_SIZE

    for _q in range(Q):
        q, x, y = map(int, input().split())
        if q == 0:
            # update
            point_set(table, x, y, min)
        else:
            # find
            print(range_reduce(table, x, y + 1, min, INF))
def main():
    from operator import add
    N, Q = map(int, input().split())
    depth = N.bit_length() + 1
    set_depth(depth)
    table = [0] * SEGTREE_SIZE

    for _q in range(Q):
        q, x, y = map(int, input().split())
        if q == 0:
            # add
            point_update(table, x, lambda x: x + y)
        else:
            # sum
            print(range_reduce(table, x, y + 1, add, 0))
def main():
    from operator import add
    N, Q = map(int, input().split())
    depth = N.bit_length() + 1
    set_depth(depth)
    table = [0] * SEGTREE_SIZE

    for time in range(Q):
        q, *args = map(int, input().split())
        if q == 0:
            # update
            s, t, value = args
            range_update(table, s, t + 1, lambda x: x + value)

        else:
            # find
            print(down_propagate_to_leaf(
                table, args[0], add, 0))