• 再帰呼び出しで実装するのが自然な問題がある
    • 例えば、たどる順序が自明でない動的計画法をメモ化再帰呼び出しで実装するシチュエーション
  • Pythonは数値演算が重いので演算の多い動的計画法はTLEしやすい
  • PyPyだと関数呼び出しが遅いので再帰呼び出しを酷使するとTLEしやすい
  • 普段はNumbaでAOTコンパイルして高速化してるのだけど…
    • Numbaは関数内関数の再帰呼び出しをサポートしてない
    • 関数外に置いてもAOTではできない
    • JITならできるが実行フェーズにコンパイルするので速度上不利
  • Cythonだと関数を「Cからしか呼ばれない」と明示してコンパイルできるので有効打になり得る?
    • →現時点で最速

ターゲット問題としてグラフの最長パスを見つけるG - Longest Pathを使う

  • 速度の比較にはAtCoderのサーバ上でのテストケース1_17, 1_01の実行速度を使う
    • 1_01が一番大きいが、タイムアウトすると時間が分からないので一回り小さい1_17も併用した

実行時間まとめ | 1_17 | 1_01 | | — | — | — | — | | 653 ms | TLE | Code1 Python | 素朴なメモ化再帰 | | 422 ms | TLE | Code1 PyPy | | 735 ms | TLE | Code2 Python | メモ化 dict→list | | 378 ms | 485 ms | Code2 PyPy | | 434 ms | 565 ms | Code3 Python | 関数呼び出し前に計算済みかをチェック | | 498 ms | 352 ms | Code3 PyPy | | 169 ms | 223 ms | Code4 Cython | | 142 ms | 177 ms | Code5 Cython | メモ化 array→Cの配列 | | 148 ms | 164 ms (best) | Code6 Cython | 探索終了条件分岐を再帰の外で先に済ませる | | 1209 ms | 1225 ms | Code7 Numba | JIT | | 295 ms | 368 ms | Code8 Python | 深さ優先探索の帰りがけ順で処理 | | 253 ms | 256 ms | Code8 PyPy |

Code1: 素朴なメモ化再帰

from collections import defaultdict
import sys

sys.setrecursionlimit(10**6)


def solve(N, M, edges):
    longest = {}

    def get_longest(start):
        if start in longest:
            return longest[start]

        next_edges = edges.get(start)
        if not next_edges:
            ret = 0
        else:
            ret = max(get_longest(v) for v in edges[start]) + 1
        longest[start] = ret
        return ret

    return max(get_longest(v) for v in edges)


def main():
    N, M = map(int, input().split())
    edges = defaultdict(set)
    for i in range(M):
        v1, v2 = map(int, input().split())
        edges[v1].add(v2)

    print(solve(N, M, edges))


main()

Code2: メモ化にdictを使ってることを疑問に思う人がいるかもしれないのでlistに変えたバージョン

def solve(N, M, edges):
    longest = [-1] * (N + 1)
    for i in range(N + 1):
        if not edges[i]:
            longest[i] = 0

    def get_longest(start):
        ret = longest[start]
        if ret != -1:
            return ret

        next_edges = edges.get(start)
        if not next_edges:
            ret = 0
        else:
            ret = max(get_longest(v) for v in edges[start]) + 1
        longest[start] = ret
        return ret

    return max(get_longest(v) for v in edges)

Code3 関数呼び出し前に計算済みかどうかをチェックするバージョン

Code4 Cython cython

cdef get_longest(long start, dict edges, long[:] longest):
    if longest[start] != -1:
        return longest[start]

    cdef list next_edges
    next_edges = edges.get(start)
    if not next_edges:
        ret = 0
    else:
        #ret = max(get_longest(v, edges, longest) for v in edges[start]) + 1
        ret = 0
        for v in edges[start]:
            x = get_longest(v, edges, longest) + 1
            if x > ret:
                ret = x

    longest[start] = ret
    return ret


def solve(N, M, edges):
    cdef array.array longest = pyarray.array('l', [-1] * (N + 1))
    return max(get_longest(v, edges, longest) for v in edges)	 

Code5 longestをCの配列としてグローバルに置く

Code6 出て行く辺があるかどうかの条件分岐を再帰の外で先に済ませる

Numba

  • 速度はさておき何もしなくてもそのまま動くCythonと比べると、Numbaは動くようにするまでが大変
    • ret = max(get_longest(v) for v in edges[start]) + 1
      • The use of yield in a closure is unsupported.
    • https://gist.github.com/nishio/dd3013df3e88ef1afb0d41d5980a3882
      • Compilation is falling back to object mode WITH looplifting enabled because Function "get_longest" failed type inference due to: non-precise type pyobject
      • 単純にnumba.jitをつけるアプローチではうまくいかない、型推論ができるオブジェクトを引数にする必要がある
  • 隣接リスト形式のグラフの扱いが難関
    • Pythonで気楽に作るとdafaultdict(list)だが、dictもlistも適切でない
    • 可変長なので雑にnp.arrayにすると空間がN * M
    • 整数列として渡してNumba世界で連結リストにするのが一番マシかなと思う
      • 今回の問題ではグラフの辺の最大数が決まってるのでそのサイズのnp.array を確保する

Code7: Numba JIT

  • 1209 ms / 1225 ms https://atcoder.jp/contests/dp/submissions/14915733
  • 今までの中で格段に遅い、よく通ったもんだ
    • しかし二つの問題にかかった時間の差は16msecで、一番速いCython実装と同じぐらい
    • つまりJITであるせいで実行フェーズでコンパイルしてしまい時間を食っているが、コンパイル済みのものは高速ということ
    • 僕が普段NumbaをAOTで使うのもこれが理由。コンパイルフェーズでAOTコンパイルできるなら実行フェーズでJITコンパイルする必要は皆無
  • Numba AOTはコンパイル時にエラーになる
    • Untyped global name 'get_longest': cannot determine Numba type of <class 'function'>
      • これは再帰呼び出しだと人間が認識しているものが、Pythonにとっては「グローバル変数を取得してそれを呼ぶ」だから。
        • 関数定義の後に同じ名前で別のものに名前が再束縛されるかもしれないので、取得したグローバル変数の型が不明なのである。
        • とはいえこれは不便なので、例えばコンパイル時に「この関数内での同名変数のコールはこの関数自体のコールである」と人間が宣言することで再帰呼び出しを可能にするとかが、将来のNumbaには入るかもしれない、入るといいなぁ
        • JITではできてるんだからAOTできてもいいじゃんー
    • というわけで今のところ末尾再帰でない再帰関数をNumbaでAOTする方法は見つけられていない

Code8: 深さ優先探索の帰りがけ順で処理をする案

  • そもそも再帰呼び出しを使わなくてもコールツリーを深さ優先で探索して帰りに処理をすれば良い、という発想 python
def solve(N, M, edges):
    longest = {}

    stack = [v for v in edges]

    while stack:
        v = stack.pop()
        if v > 0:
            if v in longest:
                continue
            next_edges = edges.get(v)
            stack.append(-v)
            if next_edges:
                stack.extend(next_edges)
        else:
            next_edges = edges.get(-v)
            if not next_edges:
                ret = 0
            else:
                ret = max(longest[x] for x in next_edges) + 1
            longest[-v] = ret

    return max(longest[v] for v in edges)

Python 295 ms / 368 ms https://atcoder.jp/contests/dp/submissions/14916269 PyPy 253 ms / 256 ms https://atcoder.jp/contests/dp/submissions/14916290