這篇被選為 NeurIPS 2018 最佳論文,他將連續的概念帶入了神經網路架構中,並且善用以往解微分方程的方法來做逼近,可以做到跟原方法(倒傳遞)一樣好的程度,而參數使用複雜度卻是常數,更短的訓練時間。
核心觀念
概念上來說,就是將神經網路離散的層觀念打破,將他貫通成為連續的層的網路架構。
連續和離散的差別來自於倒傳遞的過程:
$$
\mathbb{y}_{t+1} = \mathbb{y}_t - \eta \nabla \mathcal{L}
$$
其中 $\nabla \mathcal{L}$ 就是梯度的部份,是向量的,然而我們把他簡化成純量來看的話,他不過就是
$$
\frac{d \mathcal{L}}{dt}
$$
廣義上來說,一個函數的微分,如果是離散的版本就會是
$$
\frac{dy}{dt} = \frac{y(t + \Delta) - y(t)}{\Delta}
$$
如此一來,所形成的方程式就會是差分方程,然而連續的版本就是
$$
\frac{dy}{dt} = \lim_{\Delta \rightarrow 0} \frac{y(t + \Delta) - y(t)}{\Delta}
$$
這個所形成的會是微分方程。
從離散到連續
我們可以從離散的版本
$$
\frac{dy}{dt} = \frac{y(t + \Delta) - y(t)}{\Delta}
$$
把他轉成以下的樣貌
$$
y(t + \Delta) = y(t) + \Delta \frac{dy}{dt}
$$
要將他貫通的話,我們就得由從神經網路的基礎開始,如果是一般的前回饋網路(feed-forward network)當中的隱藏層是像下列這個樣子:
$$
h_{t+1} = f(h_t, \theta)
$$
我們可以發現像是 ResNet 這類的網路有 skip connection 的設置,所以跟一般的前回饋網路不同
$$
h_{t+1} = h_t + f(h_t, \theta)
$$
而 RNN 等等有序列概念的模型也有類似的結構,就是會是前一層的結果加上通過 $f$ 運算後的結果,成為下一層的結果。
這樣的形式跟我們前面提到的形式不謀而合
$$
y(t + \Delta) = y(t) + \Delta \frac{dy}{dt}
$$
只要我們把 $\Delta = 1$ 代入,就成了
$$
y(t+1) = y(t) + \frac{dy}{dt}
$$
以下給大家比對一下
$$
h_{t+1} = h_t + f(h_t, \theta) \\
y(t+1) = y(t) + \frac{dy}{dt}
$$
也就是,我們可以讓
$$
\frac{dy}{dt} = f(h_t, \theta)
$$
神奇的事情就發生了!神經網路 $f$ 就可以被我們拿來計算微分 $\frac{dy}{dt}$!
比較精確的說法是,把神經網路的層 $f$ 拿來逼近微分項,或是說梯度。這樣我們後面就可以用數值方法來逼近解。
$$
y(t + \Delta) = y(t) + \Delta \frac{dy}{dt} \\
\downarrow \\
y(t + \Delta) = y(t) + \Delta f(t, h(t), \theta_t)
$$
要拉成連續的還有一個重要的手段,就是將不同的層 $t$ 從離散的變成連續的,所以作者將 $t$ 做了參數化,將他變成 $f$ 的參數之一,如此一來,就可以在任意的層中放入資料做運算。
最重要的概念導出了這樣的式子
$$
h(t) \rightarrow \frac{dy(t)}{dt} = f(h(t), t, \theta) \rightarrow y(t)
$$
神經網路作為一個系統的微分形式
在傳統科學或是工程領域,我們會以微分式來表達以及建構一個系統。
$$
\nu = \frac{dx}{dt} = t + 1
$$
其實在這邊是一樣的道理,整體來說,我們是換成用神經網路去描述一個微分式,其實本質上就是這樣。
原本的層的概念就是用數學函數來建立的,而層與層之間傳遞著計算的結果。
$$
\mathbb{h_1} = \sigma(W_1 \mathbb{x} + \mathbb{b_1}) \\
\mathbb{y} = \sigma(W_2 \mathbb{h_1} + \mathbb{b_2})
$$
然而變成連續之後,我們等於是用神經網路中的層去建立跟描繪微分形式。
$$
\frac{d h(t)}{dt} = \sigma(W(t) \mathbb{x}(t) + \mathbb{b(t)}) \\
\frac{d y(t)}{dt} = \sigma(W(t) \mathbb{h}(t) + \mathbb{b(t)})
$$
是不是跟如出一轍呢?
$$
\frac{dy(t)}{dt} = f(h(t), t, \theta)
$$
向前傳遞解微分式
我們可以來計算看看隱藏層是長什麼樣子的。在隱藏層的微分式中,也是利用隱藏層去計算出來的。
$$
\frac{dh(t)}{dt} = f(h(t), t, \theta)
$$
基本上,我們只要對上式做積分就可以了。
$$
h(t) = \int f(h(t), t, \theta) dt
$$
這是一個怎樣的概念呢?我們可以來看看下圖。
我們做積分這件事其實是用 $h(t_0)$ 來推斷 $h(t_1)$ 的,這跟神經網路的向前傳遞是一樣的行為。
$$
h(t_1) = F(h(t), t, \theta) \bigg|_{t=t_0}
$$
這樣的積分動作,我們可以用 $t_0$ 時間點的資訊來解 $h(t_1)$。
這樣的解法在程式上就會交由 ODE Solver 去處理。
$$
h(t_1) = ODESolve(h(t_0), t_0, t_1, \theta, f)
$$
反向傳遞解函數
$$
\mathcal{L}(t_0, t, \theta) = \mathcal{L}(ODESolve(\cdot))
$$
$$
\frac{\partial \mathcal{L}}{\partial h(t)} = -a(t)
$$
adjoint state
$$
a(t) = \int -a(t)^T \frac{\partial f}{\partial h} dt = - \frac{\partial \mathcal{L}}{\partial h(t)}
$$
$$
a(t) = \int_{t_1}^{t_0} -a(t)^T \frac{\partial f(h(t), t, \theta)}{\partial h(t)} dt
$$
擴充狀態(augmented state)
$\frac{d \theta}{dt} = 0$
$\frac{dt}{dt} = 1$
let $\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix}$ be a augmented state
augmented state function:
$$
f_{aug}(\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix}) =
\begin{bmatrix}
f(h(t), t, \theta) \\
0 \\
1
\end{bmatrix}
$$
augmented state dynamics:
$$
\frac{d}{dt}
\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix}
f_{aug}(
\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix})
$$
augmented adjoint state:
$$
\begin{bmatrix}
a \\
a_{\theta} \\
a_t
\end{bmatrix}
$$
$a = \frac{\partial \mathcal{L}}{\partial h}$
$a_{\theta} = \frac{\partial \mathcal{L}}{\partial \theta}$
$a_t = \frac{\partial \mathcal{L}}{\partial t}$
$$
\frac{d a_{aug}}{dt} = -
\begin{bmatrix}
a \frac{\partial f}{\partial h} \\
a \frac{\partial f}{\partial \theta} \\
a \frac{\partial f}{\partial t}
\end{bmatrix}
$$