【Python】heapqで優先度つきキューを理解する

先日のAtCoder Beginners Contest 141に参加したのだが、問題Dを解くことができなかった。
この問題を言い換えると、要は「数値型の要素で構成される配列において、最も値が大きいものを1/2にするという作業を繰り返し、最終的な要素の合計を求める」というものだった。そこで愚直に以下のようなコードを書いたら、あえなくTLEに。

N, M = list(map(int, input().split()))
price = list(map(int, input().split()))
total = sum(price)
for i in range(M):
    max_index = price.index(max(price))
    discount = max(price) // 2
    total -= discount
    price[max_index] -= discount
print(total)

max() で最も値が大きい要素を毎回探しているのだが、これだと配列の長さN * 処理の回数Mで、計算量がO(NM)になってしまう。
ここで使うべきだったのが、優先度つきキューである。

優先度つきキューとは

要素を挿入していってひとつずつ取り出せるところは通常のキューと同じだが、取り出す順序が先入先出や後入先出ではない。優先度つきキューでは、名前の通り各要素に優先度がついており、その優先度が高い順に要素を取り出すことができる。

ヒープの仕組み

優先度つきキューを実装したデータ構造のうち、最もよく使われているのがヒープである。ヒープは木構造の概念を持つ。木構造は木の末端のノードを除いて子ノードを持ち、親は子よりも必ず高い(低い)優先度を持っている。このようにすることで、優先度が最も高い(低い)要素を簡単に取り出すことができる。また、新たに要素を挿入するときは、親と自分を比較することを繰り返せば良い。
この木構造にも、二分探索木や平衡探索木など様々な実装がある。

Pythonのheapq

Pythonにもヒープを実装したheapqモジュールが用意されている。通常のPythonリストと同じ感覚で使えることを重視しているらしく、以下のような特徴がある。

  • indexが0始まり
  • heappop() するとヒープの中で値が最小の要素が返る

このため、Pythonのヒープaでは a[k] <= a[2*k+1] and a[k] <= a[2*k+2] が成り立ち、a[0]が最小の値を持つ。すなわち親が子よりも値の小さい要素を持つ2分探索木構造になっている。
あくまで2分探索木をリストの形で表しているだけなので、 a[len(a)-1] (リストの最後の要素)が最大の値を持つとは限らないことに注意*1

heapqの実装

heapqにおいて、新しく値を挿入するときは以下のような実装になっている。

def heappush(heap, item):
    """Push item onto heap, maintaining the heap invariant."""
    heap.append(item)
    _siftdown(heap, 0, len(heap)-1)

def _siftdown(heap, startpos, pos):
    newitem = heap[pos]
    # Follow the path to the root, moving parents down until finding a place
    # newitem fits.
    while pos > startpos:
        parentpos = (pos - 1) >> 1
        parent = heap[parentpos]
        if newitem < parent:
            heap[pos] = parent
            pos = parentpos
            continue
        break
    heap[pos] = newitem

新しい要素を挿入しうるpositionを末尾に設定し、その親にあたる要素と新しい要素を比較する。もし新しい要素の方が値が小さければ、親の要素を子階層に下ろし、親が元いた位置にpositionを移動させる、というのを繰り返している。比較の結果、新しい要素の方が値が小さいとなれば、その位置に新しい要素を挿入する。 parentpos = (pos - 1) >> 1 の部分はビット演算であり、インデックスを2進数で右に1ビットずらしている。こうすることで親のインデックスを求めることができる。

同様に、値を取り出すときは以下のような処理が行われる。

def heappop(heap):
    """Pop the smallest item off the heap, maintaining the heap invariant."""
    lastelt = heap.pop()    # raises appropriate IndexError if heap is empty
    if heap:
        returnitem = heap[0]
        heap[0] = lastelt
        _siftup(heap, 0)
        return returnitem
    return lastelt

def _siftup(heap, pos):
    endpos = len(heap)
    startpos = pos
    newitem = heap[pos]
    # Bubble up the smaller child until hitting a leaf.
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child.
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        # Move the smaller child up.
        heap[pos] = heap[childpos]
        pos = childpos
        childpos = 2*pos + 1
    # The leaf at pos is empty now.  Put newitem there, and bubble it up
    # to its final resting place (by sifting its parents down).
    heap[pos] = newitem
    _siftdown(heap, startpos, pos)

まず heap.pop() で最後の要素をpopさせる。次にheap[0]の値を記録したうえで、その2つの子同士を比較し、値が小さい方を親の階層へ上げるということを繰り返す。
この繰り返しの末に、最も下の階層に到達したら、値が親の階層へ上がって空((正確には空になるというのはあくまで概念にすぎず、実際には元々の値が残っている。 _siftup() は、まずheap[0]の位置に最後の要素を仮で挿入しておき、そのうえで「2つの子同士を比べて小さい方の値で親の要素を上書きする」ということを繰り返しているにすぎない。))になった場所を起点に _siftdown() を始め、最後の要素を入れる位置を確定させる。これでヒープ構造を保つことができる。あとはもともとのheap[0]の値を返せばよい。

また、通常のリストをヒープソートしてヒープに変換する heapify() というメソッドも用意されている。これの実装は以下の通り。

def heapify(x):
    n = len(x)
    for i in reversed(range(n//2)):
        _siftup(x, i)

forループによって指定されている n // 2 - 1 というのは、2分探索木において子要素を持つ要素のうち、インデックスが最大のものである。すなわち下から2番目の階層から、自分とその子要素だけによる部分的な2分探索木をヒープ構造になるようにするのを繰り返している。

heapqで問題を解く

冒頭の問題をheapqを使って解くと、以下のように簡単に書ける。heappop() は最小の値の要素を返すため、最大の値を返してもらうためには-1を掛けておく必要がある。

import heapq

N, M = list(map(int, input().split()))
price = list(map(lambda x: int(x) * (-1), input().split()))

heapq.heapify(price)

for i in range(M):
    max_price = -(heapq.heappop(price))
    new_price = max_price // 2
    heapq.heappush(price, -(max_price // 2))

print(-(sum(price)))

ヒープへの要素の挿入・削除の計算量はO(logN)であり、M回繰り返してもO(logN * M)なので、処理時刻は大幅に短くなる。 heapify() の計算量も最悪でもO(N * logN)である。

*1:例えば、空のヒープに3, 2, 4, 9, 7の順で値を入れると、ヒープの中身は[2, 3, 4, 9, 7]となる