【Python】heapqで優先度つきキューを理解する
先日のAtCoder Beginners Contest 141に参加したのだが、問題Dを解くことができなかった。
この問題を言い換えると、要は「数値型の要素で構成される配列において、最も値が大きいものを1/2にするという作業を繰り返し、最終的な要素の合計を求める」というものだった。そこで愚直に以下のようなコードを書いたら、あえなくTLEに。
N, M = list(map(int, input().split())) price = list(map(int, input().split())) total = sum(price) for i in range(M): max_index = price.index(max(price)) discount = max(price) // 2 total -= discount price[max_index] -= discount print(total)
max()
で最も値が大きい要素を毎回探しているのだが、これだと配列の長さN * 処理の回数Mで、計算量がO(NM)になってしまう。
ここで使うべきだったのが、優先度つきキューである。
優先度つきキューとは
要素を挿入していってひとつずつ取り出せるところは通常のキューと同じだが、取り出す順序が先入先出や後入先出ではない。優先度つきキューでは、名前の通り各要素に優先度がついており、その優先度が高い順に要素を取り出すことができる。
ヒープの仕組み
優先度つきキューを実装したデータ構造のうち、最もよく使われているのがヒープである。ヒープは木構造の概念を持つ。木構造は木の末端のノードを除いて子ノードを持ち、親は子よりも必ず高い(低い)優先度を持っている。このようにすることで、優先度が最も高い(低い)要素を簡単に取り出すことができる。また、新たに要素を挿入するときは、親と自分を比較することを繰り返せば良い。
この木構造にも、二分探索木や平衡探索木など様々な実装がある。
Pythonのheapq
Pythonにもヒープを実装したheapqモジュールが用意されている。通常のPythonリストと同じ感覚で使えることを重視しているらしく、以下のような特徴がある。
- indexが0始まり
heappop()
するとヒープの中で値が最小の要素が返る
このため、Pythonのヒープaでは a[k] <= a[2*k+1] and a[k] <= a[2*k+2]
が成り立ち、a[0]が最小の値を持つ。すなわち親が子よりも値の小さい要素を持つ2分探索木構造になっている。
あくまで2分探索木をリストの形で表しているだけなので、 a[len(a)-1] (リストの最後の要素)が最大の値を持つとは限らないことに注意。*1
heapqの実装
heapqにおいて、新しく値を挿入するときは以下のような実装になっている。
def heappush(heap, item): """Push item onto heap, maintaining the heap invariant.""" heap.append(item) _siftdown(heap, 0, len(heap)-1) def _siftdown(heap, startpos, pos): newitem = heap[pos] # Follow the path to the root, moving parents down until finding a place # newitem fits. while pos > startpos: parentpos = (pos - 1) >> 1 parent = heap[parentpos] if newitem < parent: heap[pos] = parent pos = parentpos continue break heap[pos] = newitem
新しい要素を挿入しうるpositionを末尾に設定し、その親にあたる要素と新しい要素を比較する。もし新しい要素の方が値が小さければ、親の要素を子階層に下ろし、親が元いた位置にpositionを移動させる、というのを繰り返している。比較の結果、新しい要素の方が値が小さいとなれば、その位置に新しい要素を挿入する。
parentpos = (pos - 1) >> 1
の部分はビット演算であり、インデックスを2進数で右に1ビットずらしている。こうすることで親のインデックスを求めることができる。
同様に、値を取り出すときは以下のような処理が行われる。
def heappop(heap): """Pop the smallest item off the heap, maintaining the heap invariant.""" lastelt = heap.pop() # raises appropriate IndexError if heap is empty if heap: returnitem = heap[0] heap[0] = lastelt _siftup(heap, 0) return returnitem return lastelt def _siftup(heap, pos): endpos = len(heap) startpos = pos newitem = heap[pos] # Bubble up the smaller child until hitting a leaf. childpos = 2*pos + 1 # leftmost child position while childpos < endpos: # Set childpos to index of smaller child. rightpos = childpos + 1 if rightpos < endpos and not heap[childpos] < heap[rightpos]: childpos = rightpos # Move the smaller child up. heap[pos] = heap[childpos] pos = childpos childpos = 2*pos + 1 # The leaf at pos is empty now. Put newitem there, and bubble it up # to its final resting place (by sifting its parents down). heap[pos] = newitem _siftdown(heap, startpos, pos)
まず heap.pop()
で最後の要素をpopさせる。次にheap[0]の値を記録したうえで、その2つの子同士を比較し、値が小さい方を親の階層へ上げるということを繰り返す。
この繰り返しの末に、最も下の階層に到達したら、値が親の階層へ上がって空((正確には空になるというのはあくまで概念にすぎず、実際には元々の値が残っている。 _siftup()
は、まずheap[0]の位置に最後の要素を仮で挿入しておき、そのうえで「2つの子同士を比べて小さい方の値で親の要素を上書きする」ということを繰り返しているにすぎない。))になった場所を起点に _siftdown()
を始め、最後の要素を入れる位置を確定させる。これでヒープ構造を保つことができる。あとはもともとのheap[0]の値を返せばよい。
また、通常のリストをヒープソートしてヒープに変換する heapify()
というメソッドも用意されている。これの実装は以下の通り。
def heapify(x): n = len(x) for i in reversed(range(n//2)): _siftup(x, i)
forループによって指定されている n // 2 - 1
というのは、2分探索木において子要素を持つ要素のうち、インデックスが最大のものである。すなわち下から2番目の階層から、自分とその子要素だけによる部分的な2分探索木をヒープ構造になるようにするのを繰り返している。
heapqで問題を解く
冒頭の問題をheapqを使って解くと、以下のように簡単に書ける。heappop()
は最小の値の要素を返すため、最大の値を返してもらうためには-1を掛けておく必要がある。
import heapq N, M = list(map(int, input().split())) price = list(map(lambda x: int(x) * (-1), input().split())) heapq.heapify(price) for i in range(M): max_price = -(heapq.heappop(price)) new_price = max_price // 2 heapq.heappush(price, -(max_price // 2)) print(-(sum(price)))
ヒープへの要素の挿入・削除の計算量はO(logN)であり、M回繰り返してもO(logN * M)なので、処理時刻は大幅に短くなる。 heapify()
の計算量も最悪でもO(N * logN)である。
*1:例えば、空のヒープに3, 2, 4, 9, 7の順で値を入れると、ヒープの中身は[2, 3, 4, 9, 7]となる