当前位置:网站首页>007_SSSSS_ Neural Ordinary Differential Equtions
007_SSSSS_ Neural Ordinary Differential Equtions
2022-07-21 21:43:00 【Artificial Idiots】
Neural Ordinary Differential Equtions
本文是NeurIPS 2018 最佳文章, 作者的想法可以用《赤壁賦》中的一句話來微妙的體現:蓋將自其變者而觀之,則天地曾不能以一瞬, 自其不變者而觀之, 則物與我皆無盡也。
作者提出了一種新的深度神經網絡模型, 主要的思路是:原來的神經網絡通常需要依靠若幹多個殘差網絡堆疊起來, 而本文提出的NeuralODE可以僅用一個與之前殘差塊類似的網絡來達到若幹多個殘差塊堆疊的效果, 從而大大降低了網絡的參數量和並且可以看作只需要常數的存儲空間.
1. Introduction
NeuralODE的想法歸根結底來自於傳統的神經網絡。在傳統的的殘差網絡(Residual Networks), 循環神經網絡(Recurrent Neural Network), 標准化流(Normalizing Flow)等經典的網絡中, 數據都可以用一個式子來錶示:
h t + 1 = h t + f ( h t , θ t ) h_{t+1} = h_t + f(h_t, \theta_t) ht+1=ht+f(ht,θt)
其中 h t ∈ R D h_t \in R^D ht∈RD, t ∈ { 0 , 1 , . . . T } t \in \lbrace 0,1,...T \rbrace t∈{ 0,1,...T}. 這些堆疊起來的網絡塊可以看做是連續變換的歐拉離散化. 於是上式其實還可以看成:
h t + 1 = h t + f ( h t , θ t ) d t h_{t+1} = h_t + f(h_t, \theta_t) dt ht+1=ht+f(ht,θt)dt
只是這裏的 d t = 1 dt = 1 dt=1. 這兩種看法, 第一種是將網絡輸出看作是變化量, 而後一種則是看作變化率, 形式上相近, 但是意義卻有很大的不同.
那麼當 T T T 無限大, 即 d t dt dt 無線小, 也就是說有無窮多的網絡塊連接的時候, 上式就可以看成是一種由神經網絡决定的常微分方程的形式:
d h ( t ) d t = f ( h ( t ) , t , θ ) \frac{dh(t)}{dt} = f(h(t), t, \theta) dtdh(t)=f(h(t),t,θ)
這個式子也就是本文的根源. 這樣有了初始的狀態 h ( 0 ) h(0) h(0), 就可以通過求解常微分方程(Ordinary Differential Equation, ODE)來求解出任意時刻的輸出 h ( T ) h(T) h(T), 如果中間求解過程用的是歐拉離散化求解的話, 那麼離散化的步驟有多少步, 其實就可以看作是多少個塊的堆疊.
簡單一句話, 雖然並不全面, 但是可以幫助理解, NeuralODE就是殘差網絡的連續化.
作者的主要貢獻包括:
- 提出了新的深度神經網絡模型NeuralODE, 擁有 O ( 1 ) O(1) O(1) 的空間複雜度. 而且作者給出了NeuralODE反向傳播更新參數的方法.
- 作者將NeuralODE用在了Normalizing Flow中, 提出了一種連續的流模型(Continuous Normalizing Flow).
- 作者將NeuralODE用在了時間序列模型中.
2. NeuralODE反向傳播
現在的神經網絡基本都需要完成兩個內容, 首先是前向傳播得到結果, 然後通過得到的結果反向傳播來更新網絡參數. 那麼NeuralODE作為一種新的網絡結構, 也需要具備這兩個基本的功能.
NeuralODE的前向過程已經很明了, 就是求解上面已經給出的常微分方程, 求解的方法可以用現有的求解器ODE Solver, 目前已經有多種ODE Solver可以直接使用. 最簡單的就是歐拉離散化來求解, 也就是每次取很小的 d t dt dt 然後不斷的累加即可.
但是NeuralODE如何進行反向傳播?
首先來考慮優化通過損失函數 L ( ) L() L() 得到的損失向量, 損失函數的輸入便是從初始的輸入 z ( t 0 ) z(t_0) z(t0) 通過前向過程的ODE Solver得到的輸出結果 z ( t 1 ) z(t_1) z(t1):
那麼為了優化參數, 就需要求出損失 L L L 關於參數 θ \theta θ 的梯度, 也就是 d L d θ \frac{dL}{d\theta} dθdL.
這裏再强調一下, 傳統的殘差網絡, 每個殘差塊有自己的參數, 這些殘差塊串行的連接在一起, 前向的時候每個殘差塊只運行了一次, 所以反向傳播的時候, 可以直接串行的計算出損失 L L L 關於中間任意的第 t t t 層輸出結果 h ( t ) h(t) h(t) 的梯度 d L / d h ( t ) dL/dh(t) dL/dh(t) , 然後根據這個梯度來計算損失關於第 t t t 層的參數 θ t \theta_t θt 的梯度 d L / d θ t dL / d\theta_t dL/dθt. 也就是最簡單的鏈式法則.
但是NeuralODE只有一個類似於殘差網絡中的一個殘差塊的網絡, 而NeuralODE的這一個塊在前向的ODE Solver中被運行了多少次是未知的, 這個值通常很大, 那麼用傳統的鏈式法則求梯度顯然是不可行的, 於是就需要通過其他方法來求梯度.
(這個部分原文裏邊用 z ( t ) z(t) z(t) 來錶示連續情况下中間的結果, 用 h ( t ) h(t) h(t) 錶示離散情况下的中間結果, 意義是一樣的, 為方便理解可以將兩個符號互換來看.)
直接用鏈式法則不可行, 但是鏈式法則的這種思路是可以借鑒的. 也就是求出損失關於中間結果的梯度, 作者用式子來錶示 a ( t ) = d L / d z ( t ) a(t) = dL / dz(t) a(t)=dL/dz(t) 稱其為adjoint. 那麼這個adjoint也滿足一個ODE:
推導過程如下(也可以直接參考原文附錄), 也是利用了鏈式法則:
再利用一次鏈式法則, 就可以得到:
推導過程如下(也可以直接參考原文附錄)
如果對推導過程不感興趣, 可以直接記住結論, 也就是反向傳播的過程也需要求解一個ODE. 而這個ODE需要和伴隨狀態一塊求.
簡單的小結一下, NeuralODE的前向過程是用ODE從初始時刻輸入求解終止時刻的輸出, 反向過程求解一個反向的ODE, 從終止時刻開始到初始時刻.而且從以上內容可以看出, NeuralODE的網絡輸入和輸出的維度必須是固定的, 所以就不能像殘差網絡一樣可以在殘差塊之間加入上下采樣, 這也是NeuralODE的一個缺點.
3. 用NeuralODE代替殘差網絡
作者做了用NeuralODE代替殘差網絡的實驗, 當然NeuralODE的維度必須是固定的, 所以用NeuralODE替換的是原來的殘差網絡中上下采樣之後的塊.
作者在MNIST數據集上進行了簡單的實驗:
4. 連續標准化流Continuous Normalizing Flow
有一類稱作planar normalizing flow的流模型, 其錶示形如:
這種錶示與之前看到的殘差錶示相似, 於是也可以轉換成ODE的形式,
這裏本人對流模型了解並不多, 詳細內容請參考作者原文.
5. 時間序列模型time-series model
作者還提出了一種時間序列模型, 基本的思想就是: 給定已經觀測到的時間點 t 0 , t 1 , . . . , t N t_0, t_1, ..., t_N t0,t1,...,tN, 以及初始狀態 z t 0 z_{t_0} zt0 , 那麼可以用ODE來預測 N N N 之後的時間的狀態.
而初始的狀態 z t 0 z_{t_0} zt0 則需要利用一個RNN encoder來得到.
6. 小結
總結一下, 本文作者的三個貢獻, 其中最重要也是本文核心的就是NeuralODE及其反向傳播的過程, 不過一個弊端就是Neural ODE的輸入和輸出維度是固定的, 這在一些場合顯得不太靈活. 之後作者又提出了連續流和基於ODE的時序模型, 都可以看做是NeuralODE在不同模型上的變種, 其本質還是相同的. NeuralODE最大的優點是 O ( 1 ) O(1) O(1) 的空間複雜度, 這樣可以將這一個塊設計的比傳統殘差塊更複雜, 從而有更好的效果, 但是相應的NeuralODE可以看做是拿時間來換空間, 利用ODE Slover求解的時間代價比傳統殘差網絡要高.
边栏推荐
- Flutter实战-WidgetsFlutterBinding
- 【乐视云学习笔记】关于Letv乐视云点播的视频暂停之后,按home回到桌面后重新onResume回到Activity,视频自动播放的情况
- Multi camera fusion in mobile phones
- 高通 Camx debug log控制
- 阿里云机器学习平台PAI与华东师范大学论文入选SIGIR 2022
- Dart pragma annotation VM: entry point
- 交流角度看电源完整性
- CameraX extensions, easier to implement camera features
- YOLO7 姿势识别实例
- CameraX Extensions , 相机特性实现更简单了
猜你喜欢
Broadcasts
Kubernetes 资源编排系列之二: Helm 篇
Flutter实战-StatefulWidget
你的第一个 Jenkins 项目,从这里开始
Flutter实战-自定义键盘(一)
Kotlin入门
Introduction to three distribution strategies (hash, round_robin or replicated) often used for data in azure data warehouse tables
吴恩达机器学习系列课程汇总(视频+部分汉化+讲义+作业)
Data storage scheme (II) -sqlite database storage
CameraX extensions, easier to implement camera features
随机推荐
Common sense of cross platform framework fluent and RN
【机器学习入门】机器人养成记-边玩游戏边学机器学习
Impala-shell exports the more than 9 million level table on kudu (below)_ Complete transmission
CameraX extensions, easier to implement camera features
[SQLite3 database]
Actual combat of flutter - customized keyboard (I)
【乐视云学习笔记】关于Letv乐视云点播的视频暂停之后,按home回到桌面后重新onResume回到Activity,视频自动播放的情况
Dart pragma annotation VM: entry point
Do you know what are the schemes of the list paging interface?
Sketch map of shutter curves animation curve
Camera2 闪光灯梳理
一文搞懂静态库/动态库链接问题
Introduction to three distribution strategies (hash, round_robin or replicated) often used for data in azure data warehouse tables
Manually build APK process
关于Sensor和ISP,对输出图像做Crop和Downscale的注意事项
直流角度看电源完整性
[MySQL] the background of multi table connection in mysql, the error of Cartesian product and how to correctly query multiple tables
Understand the problem of static library / dynamic library link in one article
【RPG Maker MV】RPG游戏《机器人养成记》制作笔记 - 制作背景和引擎选择
[wechat applet learning notes] two positions of pop-up window