博客專欄

        EEPW首頁 > 博客 > ?注意力機(jī)制中的掩碼詳解

        ?注意力機(jī)制中的掩碼詳解

        發(fā)布人:數(shù)據(jù)派THU 時間:2023-07-17 來源:工程師 發(fā)布文章
        注意力機(jī)制的掩碼允許我們發(fā)送不同長度的批次數(shù)據(jù)一次性的發(fā)送到transformer中。在代碼中是通過將所有序列填充到相同的長度,然后使用“attention_mask”張量來識別哪些令牌是填充的來做到這一點(diǎn),本文將詳細(xì)介紹這個掩碼的原理和機(jī)制。


        圖片
        我們先介紹下如果不使用掩碼,是如何運(yùn)行的。這里用GPT-2每次使用一個序列來執(zhí)行推理,因為每次只有一個序列,所以速度很慢:

         from transformers import GPT2LMHeadModel, GPT2Tokenizer  tokenizer = GPT2Tokenizer.from_pretrained('gpt2') gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')  context = tokenizer('It will rain in the', return_tensors='pt')  prediction = gpt2.generate(**context, max_length=10) tokenizer.decode(prediction[0]) # prints 'It will rain in the morning, and the rain'


        在顯存允許的情況下,使用批處理輸入的速度更快,因為我們在一次推理的過程可以同時處理多個序列。對許多樣本執(zhí)行推理要快得多,但也稍微復(fù)雜一些,下面是使用transformer庫進(jìn)行推理的代碼:

         tokenizer.padding_side = "left" tokenizer.pad_token = tokenizer.eos_token  sentences = ["It will rain in the",            "I want to eat a big bowl of",            "My dog is"] inputs = tokenizer(sentences, return_tensors="pt", padding=True)  output_sequences = gpt2.generate(**inputs)  for seq in output_sequences:    print(tokenizer.decode(seq))


        transformer庫幫我們處理了很多細(xì)節(jié),我們現(xiàn)在詳細(xì)的介紹它里面到底做了什么。
        我們將令牌輸入到語言模型中,如GPT-2和BERT,作為張量進(jìn)行推理。張量就像一個python列表,但有一些額外的特征和限制。比如說,對于一個2+維的張量,該維中的所有向量必須是相同的長度。例如,

         from torch import tensor  tensor([[1,2], [3,4]]) # ok tensor([[1,2], [3]])   # error!


        當(dāng)我們對輸入進(jìn)行標(biāo)記時,它將被轉(zhuǎn)換為序列的張量,每個整數(shù)對應(yīng)于模型詞表中的一個項。以下是GPT-2中的標(biāo)記化示例:
        圖片
        如果我們想在輸入中包含第二個序列:
        圖片
        因為這兩個序列有不同的長度,所以不能把它們組合成一個張量。這時就需要用虛擬標(biāo)記填充較短的序列,以便每個序列具有相同的長度。因為我們想讓模型繼續(xù)向序列的右側(cè)添加,我們將填充較短序列的左側(cè)。
        圖片
        這就是注意力掩碼的一個應(yīng)用。注意力掩碼告訴模型哪些令牌是填充的,在填充令牌的位置放置0,在實際令牌的位置放置1。現(xiàn)在我們理解了這一點(diǎn),讓我們逐行查看代碼。

         tokenizer.padding_side = "left"


        這一行告訴標(biāo)記器從左邊開始填充(默認(rèn)是右邊),因為最右邊標(biāo)記的logits將用于預(yù)測未來的標(biāo)記。

         tokenizer.pad_token = tokenizer.eos_token


        這一行指定將使用哪個令牌進(jìn)行填充。選擇哪一個并不重要,這里我們選擇的是“序列結(jié)束”標(biāo)記。

         sentences = ["It will rain in the",            "I want to eat a big bowl of",            "My dog is"]


        上面這三個序列在標(biāo)記時都有不同的長度,我們使用下面的方法填充:

         inputs = tokenizer(sentences, return_tensors="pt", padding=True)


        在進(jìn)行表計劃和添加填充后,得到了以下的結(jié)果:

         {'input_ids': tensor([    [50256, 50256, 50256, 1026,   481, 6290,   287,   262],    [   40,   765,   284, 4483,   257, 1263, 9396,   286],    [50256, 50256, 50256, 50256, 50256, 3666, 3290,   318]  ]), 'attention_mask': tensor([    [0, 0, 0, 1, 1, 1, 1, 1],    [1, 1, 1, 1, 1, 1, 1, 1],    [0, 0, 0, 0, 0, 1, 1, 1]  ])}


        可以看到,第一個和第三個序列在開始時進(jìn)行了填充,并且attention_mask參數(shù)標(biāo)記了這個填充的位置。
        現(xiàn)在讓我們將這個輸入傳遞給模型來生成新的文本:

         output_sequences = gpt2.generate(**inputs)


        如果你不熟悉函數(shù)調(diào)用的**kwargs語法,它是將輸入字典作為命名參數(shù)傳入,使用鍵作為參數(shù)名,并使用值作為相應(yīng)的實參值。
        我們只需要循環(huán)遍歷每個生成的序列并以人類可讀的形式打印出結(jié)果,使用decode()函數(shù)將令牌id轉(zhuǎn)換為字符串。

         for seq in output_sequences:    print(tokenizer.decode(seq))


        在注意力掩碼中,我們的輸入是0和1,但是在最終的計算時,會將在將無效位置的注意力權(quán)重設(shè)置為一個很小的值,通常為負(fù)無窮(-inf),以便在計算注意力分?jǐn)?shù)時將其抑制為接近零的概率。
        這時因為,在計算注意力權(quán)重時,需要進(jìn)行Softmax的計算:
        Softmax函數(shù)的性質(zhì):注意力機(jī)制通常使用Softmax函數(shù)將注意力分?jǐn)?shù)轉(zhuǎn)化為注意力權(quán)重,Softmax函數(shù)對輸入值進(jìn)行指數(shù)運(yùn)算,然后進(jìn)行歸一化。當(dāng)輸入值非常小或負(fù)無窮時,經(jīng)過指數(shù)運(yùn)算后會接近零。因此,將掩碼設(shè)置為負(fù)無窮可以確保在Softmax函數(shù)計算時,對應(yīng)位置的注意力權(quán)重趨近于零。
        排除無效位置的影響:通過將無效位置的注意力權(quán)重設(shè)置為負(fù)無窮,可以有效地將這些位置的權(quán)重壓低。在計算注意力權(quán)重時,負(fù)無窮的權(quán)重會使對應(yīng)位置的注意力權(quán)重接近于零,從而模型會忽略無效位置的影響。這樣可以確保模型更好地關(guān)注有效的信息,提高模型的準(zhǔn)確性和泛化能力。
        但是負(fù)無窮并不是唯一的選擇。有時也可以選擇使用一個很大的負(fù)數(shù),以達(dá)到相似的效果。具體的選擇可以根據(jù)具體的任務(wù)和模型的需求來確定。



        *博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。



        關(guān)鍵詞: AI

        相關(guān)推薦

        技術(shù)專區(qū)

        關(guān)閉
        主站蜘蛛池模板: 社旗县| 乌拉特前旗| 衡水市| 将乐县| 凯里市| 贵南县| 弥勒县| 和顺县| 嵊泗县| 武宁县| 嘉善县| 信阳市| 保德县| 敦煌市| 云和县| 青龙| 高唐县| 唐海县| 漳州市| 昌平区| 万载县| 新营市| 乾安县| 绿春县| 财经| 鄂伦春自治旗| 苏尼特左旗| 祁连县| 呼伦贝尔市| 青川县| 丰都县| 明星| 哈巴河县| 二连浩特市| 侯马市| 桃园县| 琼海市| 宝兴县| 海安县| 霍山县| 石嘴山市|