這篇被選為 NeurIPS 2018 最佳論文,他將連續的概念帶入了神經網路架構中,並且善用以往解微分方程的方法來做逼近,可以做到跟原方法(倒傳遞)一樣好的程度,而參數使用複雜度卻是常數,更短的訓練時間。

核心觀念

概念上來說,就是將神經網路離散的層觀念打破,將他貫通成為連續的層的網路架構。

連續和離散的差別來自於倒傳遞的過程:

$$
\mathbb{y}_{t+1} = \mathbb{y}_t - \eta \nabla \mathcal{L}
$$

其中 $\nabla \mathcal{L}$ 就是梯度的部份,是向量的,然而我們把他簡化成純量來看的話,他不過就是

$$
\frac{d \mathcal{L}}{dt}
$$

廣義上來說,一個函數的微分,如果是離散的版本就會是

$$
\frac{dy}{dt} = \frac{y(t + \Delta) - y(t)}{\Delta}
$$

如此一來,所形成的方程式就會是差分方程,然而連續的版本就是

$$
\frac{dy}{dt} = \lim_{\Delta \rightarrow 0} \frac{y(t + \Delta) - y(t)}{\Delta}
$$

這個所形成的會是微分方程。

從離散到連續

我們可以從離散的版本

$$
\frac{dy}{dt} = \frac{y(t + \Delta) - y(t)}{\Delta}
$$

把他轉成以下的樣貌

$$
y(t + \Delta) = y(t) + \Delta \frac{dy}{dt}
$$

要將他貫通的話,我們就得由從神經網路的基礎開始,如果是一般的前回饋網路(feed-forward network)當中的隱藏層是像下列這個樣子:

$$
h_{t+1} = f(h_t, \theta)
$$

我們可以發現像是 ResNet 這類的網路有 skip connection 的設置,所以跟一般的前回饋網路不同

$$
h_{t+1} = h_t + f(h_t, \theta)
$$

而 RNN 等等有序列概念的模型也有類似的結構,就是會是前一層的結果加上通過 $f$ 運算後的結果,成為下一層的結果。

這樣的形式跟我們前面提到的形式不謀而合

$$
y(t + \Delta) = y(t) + \Delta \frac{dy}{dt}
$$

只要我們把 $\Delta = 1$ 代入,就成了

$$
y(t+1) = y(t) + \frac{dy}{dt}
$$

以下給大家比對一下

$$
h_{t+1} = h_t + f(h_t, \theta) \\
y(t+1) = y(t) + \frac{dy}{dt}
$$

也就是,我們可以讓

$$
\frac{dy}{dt} = f(h_t, \theta)
$$

神奇的事情就發生了!神經網路 $f$ 就可以被我們拿來計算微分 $\frac{dy}{dt}$!

比較精確的說法是,把神經網路的層 $f$ 拿來逼近微分項,或是說梯度。這樣我們後面就可以用數值方法來逼近解。

$$
y(t + \Delta) = y(t) + \Delta \frac{dy}{dt} \\
\downarrow \\
y(t + \Delta) = y(t) + \Delta f(t, h(t), \theta_t)
$$

要拉成連續的還有一個重要的手段,就是將不同的層 $t$ 從離散的變成連續的,所以作者將 $t$ 做了參數化,將他變成 $f$ 的參數之一,如此一來,就可以在任意的層中放入資料做運算。

最重要的概念導出了這樣的式子

$$
h(t) \rightarrow \frac{dy(t)}{dt} = f(h(t), t, \theta) \rightarrow y(t)
$$

神經網路作為一個系統的微分形式

在傳統科學或是工程領域,我們會以微分式來表達以及建構一個系統。

$$
\nu = \frac{dx}{dt} = t + 1
$$

其實在這邊是一樣的道理,整體來說,我們是換成用神經網路去描述一個微分式,其實本質上就是這樣。

原本的層的概念就是用數學函數來建立的,而層與層之間傳遞著計算的結果。

$$
\mathbb{h_1} = \sigma(W_1 \mathbb{x} + \mathbb{b_1}) \\
\mathbb{y} = \sigma(W_2 \mathbb{h_1} + \mathbb{b_2})
$$

然而變成連續之後,我們等於是用神經網路中的層去建立跟描繪微分形式。

$$
\frac{d h(t)}{dt} = \sigma(W(t) \mathbb{x}(t) + \mathbb{b(t)}) \\
\frac{d y(t)}{dt} = \sigma(W(t) \mathbb{h}(t) + \mathbb{b(t)})
$$

是不是跟如出一轍呢?

$$
\frac{dy(t)}{dt} = f(h(t), t, \theta)
$$

向前傳遞解微分式

我們可以來計算看看隱藏層是長什麼樣子的。在隱藏層的微分式中,也是利用隱藏層去計算出來的。

$$
\frac{dh(t)}{dt} = f(h(t), t, \theta)
$$

基本上,我們只要對上式做積分就可以了。

$$
h(t) = \int f(h(t), t, \theta) dt
$$

這是一個怎樣的概念呢?我們可以來看看下圖。

我們做積分這件事其實是用 $h(t_0)$ 來推斷 $h(t_1)$ 的,這跟神經網路的向前傳遞是一樣的行為。

$$
h(t_1) = F(h(t), t, \theta) \bigg|_{t=t_0}
$$

這樣的積分動作,我們可以用 $t_0$ 時間點的資訊來解 $h(t_1)$。

這樣的解法在程式上就會交由 ODE Solver 去處理。

$$
h(t_1) = ODESolve(h(t_0), t_0, t_1, \theta, f)
$$

反向傳遞解函數

$$
\mathcal{L}(t_0, t, \theta) = \mathcal{L}(ODESolve(\cdot))
$$

$$
\frac{\partial \mathcal{L}}{\partial h(t)} = -a(t)
$$

adjoint state

$$
a(t) = \int -a(t)^T \frac{\partial f}{\partial h} dt = - \frac{\partial \mathcal{L}}{\partial h(t)}
$$

$$
a(t) = \int_{t_1}^{t_0} -a(t)^T \frac{\partial f(h(t), t, \theta)}{\partial h(t)} dt
$$

擴充狀態(augmented state)

$\frac{d \theta}{dt} = 0$

$\frac{dt}{dt} = 1$

let $\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix}$ be a augmented state

augmented state function:

$$
f_{aug}(\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix}) =
\begin{bmatrix}
f(h(t), t, \theta) \\
0 \\
1
\end{bmatrix}
$$

augmented state dynamics:

$$
\frac{d}{dt}
\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix}

f_{aug}(
\begin{bmatrix}
h \\
\theta \\
t
\end{bmatrix})
$$

augmented adjoint state:

$$
\begin{bmatrix}
a \\
a_{\theta} \\
a_t
\end{bmatrix}
$$

$a = \frac{\partial \mathcal{L}}{\partial h}$

$a_{\theta} = \frac{\partial \mathcal{L}}{\partial \theta}$

$a_t = \frac{\partial \mathcal{L}}{\partial t}$

$$
\frac{d a_{aug}}{dt} = -
\begin{bmatrix}
a \frac{\partial f}{\partial h} \\
a \frac{\partial f}{\partial \theta} \\
a \frac{\partial f}{\partial t}
\end{bmatrix}
$$