技術メモ

役に立てる技術的な何か、時々自分用の覚書。幅広く色々なことに興味があります。

3次スプライン補間を実装する

3次スプライン補間

補間とは?

補間とは、点群に対してその点群を通る連続な関数を当てはめること。
もっとも単純なものは線形補間で、いわゆる折れ線グラフのようなもの。
特に3次スプライン補間とは点群間を3次多項式で表し2階微分まで連続にするような補間方法を指す。
他にもラグランジュ補間などがあり、これはN個の点に対してN次関数で補間する方法を言う。

ちなみに、点群に異常値が含まれていたり点の数が多い場合、必ずしも点を通る関数である必要はない。
そのように必ずしも点を通らないN次関数を求めることは回帰と言ったりする。(100個の点に対して3次関数でフィッティングするような場合)

計算方法

3次スプライン補間はなぜ解けるか

そもそも、3次スプライン補間はなぜ解けるか。
結論から言うと、3次スプライン補間のアルゴリズムは単純な連立方程式に帰着する。
求める変数の数分だけの制約条件(方程式)があるから解が存在する。


各点
 (x_i, y_i),  i=[1,n+1]
区間を曲線
y=S_i(x), i=[1,n]
で補完するとする。
求めるべき変数は各曲線に対し4個あり、4N個
一方制約条件として以下がある。
1. S_i(x)は各点群を通る → 2N個の方程式
2. (両端を除く)各点で一階微分が連続 → N-1個の方程式
3. (両端を除く)各点で二階微分が連続 → N-1個の方程式

合計4N-2個の制約条件がある。
変数4N個に対し、方程式が4N-2個
このままでは解けないので、両端の点において2次の微分係数が0という制約を加える。
こうすることで連立方程式の問題に帰着される。

具体的なアルゴリズム

上の説明を数式で表現する。
変数a_{ij}を用いて各曲線を以下で表す
$$S_i(x) = \sum_{j=0}^{3}a_{ij}x^{j}$$
ここで、以下の制約条件を用いて
1. S_i(x)は各点群を通る
2. (両端を除く)各点で一階微分が連続
3. (両端を除く)各点で二階微分が連続
4. 両端の点において2次の微分係数が0

$$ y_i = S_i(x_i),i=[1,n]\\
y_{i+1} = S_i(x_{i+1}) ,i=[1,n]\\
S'_i(x_i) = S'_{i+1}(x_i) ,i=[1,n-1]\\
S''_i(x_i) = S''_{i+1}(x_i), i=[1,n-1] \\
S''_1(x_1) = 0 , S''_{n}(x_{n+1}) = 0 $$
を得る。
a_{ij}を用いて表すと
$$y_i = a_{i3}x^3_{i} + a_{i2}x^2_{i} + a_{i1}x_i + a_{i0}\\
y_{i+1} = a_{i3}x^3_{i+1} + a_{i2}x^2_{i+1} + a_{i1}x_{i+1} + a_{i0}\\
3a_{i3}x^2_{i+1} + 2a_{i2}x_{i+1} + a_{i1} = 3a_{(i+1)3}x^2_{i+1} + 2a_{(i+1)2}x_{i+1} + a_{(i+1)1}\\
6a_{i3}x_{i+1} + 2a_{i2} = 6a_{(i+1)3}x_{i+1} + 2a_{(i+1)2}\\
6a_{13}x_1 + 2a_{12} = 0 \\
6a_{n3}x_{n+1} + 2a_{n2} = 0$$
となる。
これで解くべき方程式が求まった。

あとは行列形式に変換し逆行列を求めることになる。
連立方程式を解くことは逆行列を求めることと同じ。

Pythonコード

以上のアルゴリズムPythonで実装してみた。
逆行列を求めるのには、逆行列を計算するのに便利なnumpyを使った。

import numpy as np
import sys
import matplotlib.pyplot as plt

class  cubicSpline:
    def __init__(self, x,y):
        # Xが昇順になるようにソートする
        points = sorted(zip(x, y))
        self.x, self.y = np.array(list(zip(*points)), dtype=float)
        # フィッティングのために各係数を求めておく
        self.__initialize(self.x, self.y)

    def __initialize(self,x,y):
        xlen = len(x) # 点の数
        N = xlen - 1 # 求めるべき変数の数(=方程式の数)

        # Xが一致する値を持つ場合例外を発生させる
        if xlen != len(set(x)): raise ValueError("x must be different values")

        matrix = np.zeros([4*N, 4*N])
        Y = np.zeros([4*N])

        equation = 0
        for i in range(N):
            for j in range(4):
                matrix[equation, 4*i+j] = pow(x[i], j)
            Y[equation] = y[i]
            equation += 1
        for i in range(N):
            for j in range(4):
                matrix[equation, 4*i+j] = pow(x[i+1], j)
            Y[equation] = y[i+1]
            equation += 1
        for i in range(N-1):
            for j in range(4):
                matrix[equation, 4*i+j] = j*pow(x[i+1], j-1)
                matrix[equation, 4*(i+1)+j] = -j*pow(x[i+1], j-1)
            equation += 1
        for i in range(N-1):
            matrix[equation, 4*i+3] = 3*x[i+1]
            matrix[equation, 4*i+2] = 1
            matrix[equation, 4*(i+1)+3] = -3*x[i+1]
            matrix[equation, 4*(i+1)+2] = -1
            equation += 1
        matrix[equation,3] = 3*x[0]
        matrix[equation,2] = 1
        equation += 1
        matrix[4*N-1,4*N-1] = 3*x[N]
        matrix[4*N-1,4*N-2] = 1

        # Wa=Y => a=W^(-1)Yとして変数の行列を求める
        # その際、逆行列を求めるのにnp.linalg.invを使う
        self.variables = np.dot(np.linalg.inv(matrix),Y)

    def fit(self, x):
        """
        引数xが該当する区間を見つけてきて補間後の値を返す
        """
        xlen = len(self.x)
        for index,j in enumerate(self.x):
            if x < j:
                index -= 1
                break
        if index == -1:
            index += 1
        elif index == xlen-1:
            index -= 1
        a3 = self.variables[4*index + 3]
        a2 =  self.variables[4*index + 2]
        a1 = self.variables[4*index + 1]
        a0 = self.variables[4*index + 0]

        result = a3*pow(x,3) + a2*pow(x,2) + a1*x + a0
        return result


if __name__=="__main__":
    X = np.array([0,2,5,7,10])
    y = np.array([2,10,3,2,4])
    cubic = cubicSpline(X, y)
    plt.scatter(X, y)

    x = np.arange(-0.5,10.5,0.01)
    plt.plot(x, list(map(cubic.fit,x)))
    plt.show()

f:id:swdrsker:20180806185505p:plain

補足:元々用意されている関数を使う

実際にPythonで3次スプライン補間するには、scipyでそのための関数が用意されているので便利。

from scipy.interpolate import interp1d
import numpy as np
import matplotlib.pyplot as plt

X = np.array([0,2,5,7,10])
y = np.array([2,10,3,2,4])
plt.scatter(X,y)

x = np.arange(0,10,0.01)
cubic = interp1d(X, y, kind="cubic")
plt.plot(x, cubic(x))

plt.show()

補足:回帰(最小二乗法)との比較

3次スプライン補間のような補間と最小二乗法による回帰との大きな違いは もともとの点群を通るかどうかだ。
したがってもともとの点自体が重要になるような場合は、補間を使う方が適していると言える。
しかし点の数が増えたり誤差や異常値が含まれている場合は回帰を使う方が良い。

from scipy.interpolate import interp1d
import numpy as np
import matplotlib.pyplot as plt

X = np.array([0,2,5,6,7,8,10])
y = np.array([0,4,32,33,45,70,98])
plt.scatter(X,y)

x = np.arange(0,10,0.01)
cubic = interp1d(X, y, kind="cubic")
plt.plot(x, cubic(x) , c="r", label="cubic spline")


fit_params = np.polyfit(X, y, 2)
plt.plot(x, np.poly1d(fit_params)(x), color="g", label="polyfit")

plt.legend()
plt.show()

f:id:swdrsker:20180806185508p:plain
2次関数にフィッティングする線形回帰と3次スプライン補間の比較。
誤差を認める場合は線形回帰を使う方が良い。