新聞中心

        EEPW首頁 > 智能計(jì)算 > 設(shè)計(jì)應(yīng)用 > 掌握AI尚方寶劍:注意力機(jī)制

        掌握AI尚方寶劍:注意力機(jī)制

        作者:高煥堂 時(shí)間:2024-04-12 來源:EEPW 收藏


        本文引用地址:http://www.104case.com/article/202404/457480.htm

        1 前言

        經(jīng)過上一期的范例和解說,您對(duì)于的計(jì)算,已經(jīng)建立良好的基礎(chǔ)了。就可以輕易地來理解和掌握注意力(Attention) 機(jī)制。這項(xiàng)機(jī)制在許多大語言模型( 如ChatGPT、Gemma等) 里,都扮演了極為關(guān)鍵性的角色。再看看最近聲勢(shì)非常浩大的Sora,其關(guān)鍵技術(shù)——DiT(Diffusion Transformer) 的核心也是

        于是,本文就從上一期介紹的(Similarity) 基礎(chǔ),繼續(xù)延伸到。此外,更重要的是:此項(xiàng)機(jī)制也是可以學(xué)習(xí)的(Learnable),于是就來把它包裝于NN模型里,成為可以訓(xùn)練的注意力模型(Attention model)。

        典型的Attention 模型, 包括兩種: 交叉注意力(CrossAttention) 和自注意力(SelfAttention)。本文就先來說明SelfAttention 模型的計(jì)算邏輯,及其訓(xùn)練方法。

        2 以“企業(yè)經(jīng)營”來做比喻

        首先來做個(gè)比喻。例如,一個(gè)公司有三個(gè)部門,其投資額( 以X表示),經(jīng)過一年的經(jīng)營績效比率( 以W表示),其營收額( 以V 表示),如圖1所示。

        1712880714669803.png

        圖1

        這三部門投資額是:X=[10, 6, 2.5],其單位是---百萬元。經(jīng)過一年的經(jīng)營,其營收比率是:W=[2.0],就可以計(jì)算出營收金額是:V=[20, 12, 5]。

        接下來,公司的經(jīng)營團(tuán)隊(duì)開始規(guī)畫下一年度的投資方案,針對(duì)未來新的商業(yè)投資獲利注意點(diǎn),擬定一個(gè)投資預(yù)算分配表( 即注意力表),然后計(jì)算出新年度的投資預(yù)算金額( 單位:百萬元),如圖2所示。

        1712880832677517.png

        圖2

        其中的預(yù)算分配表,可以是矩陣(Similarity matrix),亦即經(jīng)由相似度的計(jì)算而來。現(xiàn)在,就來理解上圖的計(jì)算邏輯,請(qǐng)觀摩一個(gè)Python的實(shí)現(xiàn)代碼:

        # ax01.py

        import numpy as np

        import torch


        X = torch.tensor([[10.0],[6.0],[2.5]]) # 投資額

        W = torch.tensor([[2.0]])                 # 經(jīng)營績效

        V = X.matmul(W)                            # 計(jì)算營收

        A = torch.tensor(

              [[1.0, 0., 0.],

              [0.9, 0.1, 0.],

              [0.6, 0.3, 0.1]])         # 預(yù)算分配表

        Z = A.matmul(V)          # 計(jì)算分配額

        print(‘n 投資預(yù)算額Z:’)

        print(Z) #np.round(Z.detach().numpy()))

        #END

        接著,就執(zhí)行這個(gè)程序。此時(shí)就輸入X和W,計(jì)算出V值。然后輸入相似度表A,計(jì)算出新年度的投資預(yù)算額,并輸出如下:

        1712880962852160.png

        3 使用Attention計(jì)算公式

        在上一期里,已經(jīng)說明了,相似度矩陣是直接計(jì)算向量的點(diǎn)積(Dot-product),即將兩向量的對(duì)應(yīng)元素相乘再相加。然后,這相似度矩陣再除以它們的歐氏長度的乘積,將相似度的值正規(guī)化,就得到余弦(Cosine)相似度。而且,如果將上述的相似度矩陣,在經(jīng)由Softmax() 函數(shù)的運(yùn)算,就得到注意力矩陣(Attention weights) 了。例如,有兩個(gè)矩陣:Q 和K,就能計(jì)算出注意力矩陣,如圖3 所示。

        1712881038618246.png

        圖3

        那么,就可以繼續(xù)思考一個(gè)重要問題,就是:如何計(jì)算出Q和K矩陣呢? 答案是:可以由SelfAttention模型來預(yù)測(cè)出來。也就是,由輸入數(shù)據(jù)X來與SelfAttention模型的權(quán)重Wq相乘而得到Q。同時(shí),也由輸入數(shù)據(jù)X來與這模型的權(quán)重Wk 相乘而得到K,如圖4所示。

        1712881106266155.png

        圖4

        當(dāng)我們把上圖里的Wq、Wk和Wv權(quán)重都放入SelfAttention模型里, 就能進(jìn)行機(jī)器學(xué)習(xí)(Machine learning) 來找出最佳的權(quán)重值( 即Wq、Wk 和Wv),就能預(yù)測(cè)出Q、K 和V 了。并且可繼續(xù)計(jì)算出A 了。

        4 訓(xùn)練SelfAttention模型

        現(xiàn)在就把Wq、Wk 和Wv 都放入SelfAttention 模型里。請(qǐng)觀摩這個(gè)SelfAttention 模型的代碼范例,如下:

        # ax02.py

        import numpy as np

        import torch

        import torch.nn as nn

        import torch.nn.functional as F


        class SelfAttention(nn.Module): # 定義模型

           def __init__(self):

              super(SelfAttention, self).__init__()

              self.Wq = nn.Linear(1, 2, bias=False)

              self.Wk = nn.Linear(1, 2, bias=False)

              self.Wv = nn.Linear(1, 1, bias=False)

              def forward(self, x):

                  Q = self.Wq(x)

                  K = self.Wk(x)

                  V = self.Wv(x)

                  Scores = Q.matmul(K.T)

                  A = F.softmax(Scores, dim=-1) # Attention_weights

                  Z = A.matmul(V) # 計(jì)算Z

                   return Z, A, V


           model = SelfAttention() # 建立模型

           criterion = nn.MSELoss()

           optimizer = torch.optim.Adam(model.parameters(),lr=0.004)


           # 輸入X

           X = torch.tensor([[10.0],[6.0],[2.5]])

           # 設(shè)定Target Z

           target_attn = torch.tensor([[20.0],[19.0],[16.0]])


           print(‘展開訓(xùn)練1800 回合...’)

           for epoch in range(1800+1):

              Z, A, V = model(X) # 正向傳播

              loss = criterion(Z, target_attn) # 計(jì)算損失

              optimizer.zero_grad() # 反向傳播和優(yōu)化

              loss.backward()

              optimizer.step()

              if(epoch%600 == 0):

                 print(‘ep=’, epoch,‘loss=’, loss.item())


        # 進(jìn)行預(yù)測(cè)

        Z, A, V = model(X)

        print(‘n----- 預(yù)算分配表A -----’)

        print(np.round(A.detach().numpy(), 1))


        print(‘n----- 投資預(yù)算額Z -----’)

        print(np.round(Z.detach().numpy()))

        #END

        然后就執(zhí)行這個(gè)程序,此時(shí)會(huì)展開1800 回合的訓(xùn)練。在訓(xùn)練過程中,回持續(xù)修正模型里的權(quán)重( 即Wq、Wk和Wv),并且其損失(Loss) 值會(huì)持續(xù)下降,如下:

        1712881209642952.png

        一旦訓(xùn)練完成了,就可以展開預(yù)測(cè)(Prediction)。此時(shí),就計(jì)算出Q、K 和V,然后繼續(xù)計(jì)算出A 和Z 值。

        5 結(jié)束語

        本期基于相似度計(jì)算,繼續(xù)說明的計(jì)算邏輯,建立SelfAttention模型,并且訓(xùn)練1800 回合,然后進(jìn)行預(yù)測(cè)。

        從這范例中,可以領(lǐng)會(huì)到SelfAttention模型能順利捕捉到企業(yè)的經(jīng)營規(guī)律,并進(jìn)行準(zhǔn)確的預(yù)測(cè)。

        (本文來源于《EEPW》2024.4)



        評(píng)論


        相關(guān)推薦

        技術(shù)專區(qū)

        關(guān)閉
        主站蜘蛛池模板: 汽车| 丰宁| 正安县| 广灵县| 同德县| 白沙| 洞口县| 准格尔旗| 辽宁省| 长垣县| 托克逊县| 利津县| 酒泉市| 扶余县| 伊宁县| 汝城县| 资兴市| 汕头市| 乐至县| 古丈县| 湘潭市| 岳西县| 西乌珠穆沁旗| 清水河县| 历史| 万全县| 永平县| 长垣县| 卓资县| 会东县| 衡阳市| 临漳县| 丁青县| 东辽县| 丰都县| 锦州市| 崇仁县| 彭州市| 碌曲县| 清原| 东乌珠穆沁旗|