原創 | 谷歌JAX 助力科學計算
谷歌最新推出的JAX,官方定義為CPU、GPU和TPU上的NumPy。它具有出色的自動微分(differentiation)功能,是可用于高性能機器學習研究的python庫。Numpy在科學計算領域十分普及,但是在深度學習領域,由于它不支持自動微分和GPU加速,所以更多的是使用Tensorflow或Pytorch這樣的深度學習框架。然而谷歌之前推出的Tensorflow API有一些比較混亂的情況,在1.x的迭代中,就存在如原子op、layers等不同層次的API。面對不同類型的用戶,使用粒度不同的多層API本身并不是什么問題。但同層次的API也有多種競品,如slim和layers等實則提高了學習成本和遷移成本。而JAX使用 XLA 在諸如GPU和TPU的加速器上編譯和運行NumPy。它與 NumPy API 非常相似, numpy 完成的事情幾乎都可以用 jax.numpy 完成,從而避免了直接定義API這件事。
下面簡要介紹JAX的幾個特性,并同時給出一些示例讓讀者能夠快速入門上手。最后我們將結合科學計算的實例,展現google JAX在科學計算方面的巨大威力。
1.JAX特性
1)自動微分:
在深度學習領域,網絡參數的優化是通過基于梯度的反向傳播算法實現的。因此能夠實現任意數值函數的微分對于機器學習有著十分重要的意義。下面結合官方文檔的例子簡要介紹這一特性。
首先介紹最簡單的grad求一階微分:可以直接通過grad函數求某一函數在某位置的梯度值
import jax.numpy as jnpfrom jax import grad, jit, vmapgrad_tanh = grad(jnp.tanh)print(grad_tanh(2.0))[OUT]:0.070650816
當然如果想對雙切正弦函數繼續求二階,三階導數,也可以這樣做:
print(grad(grad(jnp.tanh))(2.0))print(grad(grad(grad(jnp.tanh)))(2.0))[OUT]:-0.136218680.25265405
除此之外,還可以利用hessian、jacfwd 和 jacrev 等方法實現函數轉換,它們的功能分別是求解海森矩陣,以及利用前向或反向模式求解雅克比矩陣。Jacfwd和jacrev可以得到一樣的結果,但是在不同的情形下求解效率不同,這是因為兩者背后對應的微分幾何中的push forward和pull back方法。而前面提到的grad則是基于反向模式。
在一些擬牛頓法的優化算法中,常常需要利用二階的海森矩陣。為了實現海森矩陣的求解。為了實現這一目標,我們可以使用jacfwd(jacrev(f))或者jacrev(jacfwd(f))。但是前者的效率更高,因為內層的雅克比矩陣計算是通過類似于一個1維損失函數對n維向量的求導,明顯使用反向模式更為合適。外層則通常是n維函數對n維向量的求導,正向模式更有優勢。
2)向量化
無論是科學計算或者機器學習的研究中,我們都會將定義的優化目標函數應用到大量數據中,例如在神經網絡中我們去計算每一個批次的損失函數值。JAX 通過 vmap 轉換實現自動向量化,簡化了這種形式的編程。
下面結合幾個例子,說明這一用法:
vmap有3個最重要的參數:
fun: 代表需要進行向量化操作的具體函數;
in_axes:輸入格式為元組,代表fun中每個輸入參數中,使用哪一個維度進行向量化;
out_axes: 經過fun計算后,每組輸出在哪個維度輸出。
我們先來看二維情況下的一些例子:
import jax.numpy as jnpimport numpy as npimport jax
(1)先定義a,b兩個二維數組(array)
a = np.array(([1,3],[23, 5]))print(a)[out]: [[ 1 3][23 5]]b = np.array(([11,7],[19,13]))print(b)[OUT]: [[11 7][19 13]]
(2)正常的兩個矩陣element-wise的相加
print(jnp.add(a,b))#[[1+11, 3+7]]# [[23+19, 5+13]][OUT]: [[12 10][42 18]]
(3)矩陣a的行 + 矩陣b的行,然后根據out_axes=0輸出,0表示行輸出
print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=0)(a,b))#[[1+11, 3+7]]#[[23+19, 5+13]][OUT]: [[12 10][42 18]]
(4)矩陣a的行 + 矩陣b的行,然后根據out_axes=1輸出,1表示列輸出
print(jax.vmap(jnp.add, in_axes=(0,0), out_axes=1)(a,b))# [[1+11, 3+7]]#[[23+19, 5+13]] 再以列轉置輸出[OUT]: [[12 42][10 18]]
理解了上面的例子之后,現在開始增加難度,換成三維的例子:
from jax.numpy import jnpA, B, C, D = 2, 3, 4, 5def foo(tree_arg):x, (y, z) = tree_argreturn jnp.dot(x, jnp.dot(y, z))from jax import vmapK = 6 # batch sizex = jnp.ones((K, A, B)) # batch axis in different locationsy = jnp.ones((B, K, C))z = jnp.ones((C, D, K))tree = (x, (y, z))vfoo = vmap(foo, in_axes=((0, (1, 2)),))print(vfoo(tree).shape)
你能夠計算最后的輸出嗎?
讓我們一起來分析一下。在這段代碼中分別定義了三個全1矩陣x,y,z,他們的維度分別是6*2*3,3*6*4,4*5*6。而tree則控制了foo函數中矩陣連續點積的順序。根據in_axes可知,y和z的點積最后結果為6個3*5的子矩陣,這是由于y和z此時相當于6個y的子矩陣(3*4維)和6個z的子矩陣(4*5維)點積。再與x點積,得到的最終結果為(6,2,5)。
3)JIT編譯
XLA是TensorFlow底層做JIT編譯優化的工具,XLA可以對計算圖做算子Fusion,將多個GPU Kernel合并成少量的GPU Kernel,用以減少調用次數,可以大量節省GPU Memory IO時間。Jax本身并沒有重新做執行引擎層面的東西,而是直接復用TensorFlow中的XLA Backend進行靜態編譯,以此實現加速。
jit的基本使用方法非常簡單,直接調用jax.jit()或使用@jax.jit裝飾函數即可:
import jax.numpy as jnpfrom jax import jitdef slow_f(x):# Element-wise ops see a large benefit from fusionreturn x * x + x * 2.0x = jnp.ones((5000, 5000))fast_f = jax.jit(slow_f) # 靜態編譯slow_f;%timeit -n10 -r3 fast_f(x)%timeit -n10 -r3 slow_f(x)10 loops, best of 3: 24.2 ms per loop10 loops, best of 3: 82.8 ms per loop
運行時間結果:fast_f(x)是slow_f(x) 在CPU上運行速度的3.5倍!靜態編譯大大加速了程序的運行速度。如圖1 所示。
圖 1 tensorflow和JAX中的XLA backend
2.JAX在科學計算中的應用
分子動力學是現代計算凝聚態物理的重要力量。它經常用于模擬材料。下面的實例將展現JAX在以分子動力學為代表的科學計算領域的巨大潛力。
首先簡單介紹一下分子動力學。分子動力學的基本任務就是獲得研究對象在不同時刻的位置和速度,然后基于統計力學的知識獲取想得到的物理量,解釋對象的行為和性質。
它的主要步驟包括:
第一步,設置研究對象組成粒子的初始位置和速度;第二步,基于粒子的位置計算每個粒子的合力,并基于牛頓第二定計算粒子的加速度。(這里可能有小伙伴會問,如何計算?我們下文的勢函數將為大家解釋);第三步,基于加速度算下一時刻粒子速度,根據速度計算下一時刻位置。
不斷循環2-3步,得到粒子的運動軌跡。
如需要獲得所有粒子的軌跡,根據牛頓運動方程,需要知道粒子的初始位置和速度,質量以及受力。粒子的受力是勢能函數的負梯度,所以在分子動力學模擬中,必須確定所有原子之間的勢能函數,即勢能關于兩個原子之間相對位置的函數,這個勢函數我們也稱之為力場。
在分子動力學中,復雜力場的優化是一類重要的問題。ReaxFF就是其中的代表。相比于傳統力場基于靜態化學鍵以及不隨化學環境改變的靜態電荷假設,ReaxFF引入鍵級勢的概念,這允許鍵在整個模擬過程里形成和斷開,并動態地為原子分配電荷。也正是由于這些特性的存在,反應力場的形式明顯比經典力場更為復雜。這使得我們將其計算的能量等值與密度泛函或者實驗值對比得到的損失函數進行反饋優化時更為困難,如圖2 所示。
圖2 反應力場的參數構成
各種全局優化方法,例如遺傳算法,模擬退火算法,進化算法以及粒子群優化算法等等往往沒有利用任何梯度信息,這使得這些搜索成本可能會非常昂貴。而JAX的出現為這一問題的解決帶來了可能。
JAX-REAXFF:
1)流程
圖3 Jax-ReaxFF流程
圖3是Jax-ReaxFF的任務流概述,可以將其大致分為兩個階段:聚類和主優化循環。而主優化循環則分別包括利用梯度信息的能量最小化和力場參數優化。
聚類只要是根據相互作用列表進行聚類,在內存中正確對齊,以確保有效的單指令多數據(SIMD)并行化提高效率。
而主優化循環中能量最小化的過程是尋找能量最低最穩定幾何構型的過程。它的具體做法是利用JAX求體系勢能對原子坐標的梯度,進行優化。力場參數的優化在原文中則分別使用了兩種擬牛頓優化方法——L-BFGS和SLSQP。這通scipy.optimize.minimize函數實現,其中向該函數直接傳入JAX求解梯度的方法以提高效率。能量最小化和力場參數優化迭代循環。
圖4 JAX-ReaxFF主循環優化
Github地址:https://github.com/cagrikymk/JAX-ReaxFF
2)效果
作者在多個數據集上分別實現了參數的優化,可以看到相比于其他算法,利用JAX梯度信息的優化具有明顯的速度優勢。
圖5 金屬鈷數據集結果
參考文獻:https://pubs.acs.org/doi/pdf/10.1021/acs.jctc.2c00363https://jax.readthedocs.io/en/latest/faq.htmlhttps://zhuanlan.zhihu.com/p/474724292https://arxiv.org/abs/2010.09063https://mp.weixin.qq.com/s/AoygUZK886RClDBnp1v3jw
*博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。