技術メモ

役に立てる技術的な何か、時々自分用の覚書。幅広く色々なことに興味があります。

ダイクストラ法がやっと理解できた

ダイクストラ法というアルゴリズム
重み付きの経路の最短経路を求めるアルゴリズムとして有名である。
そのダイクストラ法というアルゴリズム、存在自体は知っていたが何度か自分で実装しようとしてはなぜそれが最短経路を求めるアルゴリズムになるのか腹落ちしないまま諦めかけていた。
だが恥ずかしながらようやくわかったので、わかった過程と実装を紹介する。

直感的に理解するには身近な問題に置き換えれば良い。
最短経路を求める問題は紐付きのビー玉を持ち上げる問題と同じなのだ。
そう考えればようやくわかってきた。
問題はこうだ

Q. それぞれが大小様々な紐でつながったビー玉がN個ある。(それぞれのビー玉は必ずしもお互いにつながっている必要はないが、どのビー玉にも繋がっていないものはないとする)
その中の1個を持ち上げることを考える。
この時、最後に持ち上がるビー玉の高さはどれくらいか?

このビー玉持ち上げ問題を考える時はこう考えるのがわかりやすい

ビー玉をそれぞれa0, a1, a2, ....aNとし、aiとajに繋がる紐の長さをLijとする。
最初にa0を持ち上げ始める。
すると次に地面から持ち上がるビー玉はなんだろう。
それは、a0に繋がっている紐のうち最小のものと繋がったビー玉aiになる。

ではさらにその次に地面から持ち上がるビー玉はなんだろう。
a0と紐でつながっているビー玉かaiと紐でつながっているビー玉のうち、L0j もしくは L0i + Lijが最小になるj、ajである。

ではさらににその次は...?
ここで仮にL0j > L0i + Lij だったとしよう。
すると今度は L0sもしくはL0i + LisもしくはL0i + Lij + Ljs が最小になるasになる。

さらにその次は...?

同じようにして考えいくとアルゴリズムの考え方の道筋が見えてくる。
そして最短経路はそれぞれのビー玉に対してどの紐によって持ち上がったかを記憶しておけば最短経路を辿ることができる。
(ここからはグラフの用語にしたがってノードと辺という言葉を使う。)

1. 基準となるノードから最小の辺に繋がるノードを選択し、「探索済み」としておく。探索済みのノードは同時にそれぞれ最小距離を保持しておく。
2. 探索済みのノードから繋がっている全てのノードに対して、最短距離を計算する。この時最短距離の計算は以下の通り。
  最短距離=全ての探索済みノードに対して「探索済みのノードが保持している最短距離+そのノードとの辺の重み」の最小値
3. 2で求めた最短距離が最小になるノードを選択肢「探索済み」としておく。
4. 全てのノードが探索済みになるまで2,3を繰り返す。


これをもう少し効率的に改変してみよう。
2の計算において全ての探索済みノードを毎回取り出してその度に最短距離を計算するのでは効率が悪い。
各ノードに対して「探索済みノードからの最短距離」を持っておいて、ノードが選択される毎にその最短距離を更新すれば良いのではないか。そうすればノードが選択されるたびに全探索済みノードに繋がっている辺を見に行かなくても選択されたノードから繋がっている辺を見るだけで済む。
ということで効率的にしたアルゴリズムはこのようになる

1. 各ノードに対して最短距離を表す配列Dを用意する。0番目を基準のノードとするとD=[0, ∞, ∞, ... ∞]とする。
2. Dの中で最小の値をとるノードiを選択し、iを探索済みとして記録する。
3. iから繋がる全ての辺Lijに対して、以下の規則でDを更新する。
  D[i] + Lij < D[j] ならば D[j] ← D[i] + Lij
4. 全てのノードが探索済みになるまで2,3を繰り返す。


これであれば実装できそうだ。
...ということでPythonで簡単に実装してみる。

# データ
links = {
    (0,1): 1.4,
    (0,2): 3.1,
    (0,3): 4.6,
    (1,2): 2.9,
    (2,3): 1.1,
    (1,3): 3.1,
    (1,4): 2.7,
    (4,1): 2.3,
    (2,4): 2.1,
    (3,4): 1.8,
}
nodes = 5
d = [0] + [1000] * (nodes -1)
prev = [0] * nodes
completed = [False] * nodes

# dの最小値の位置を求める
def argmin(d):
    idx = None
    min_val = 10000
    for i, j in enumerate(d):
        if j < min_val and completed[i] == False:
            min_val = j
            idx = i
    return idx

# 全てが探索済みになるまで繰り返す
while not all(completed):
    selectd = argmin(d)
    completed[selectd] = True
    for i, j in links:
        if i==selectd:
            if d[i] + links[(i,j)] < d[j]:
                d[j] = d[i] + links[(i,j)]
                prev[j] = i
        
print(d)
print(prev)