博客專欄

        EEPW首頁 > 博客 > 從零自制深度學(xué)習(xí)推理框架: 計(jì)算圖中的表達(dá)式講解(1)

        從零自制深度學(xué)習(xí)推理框架: 計(jì)算圖中的表達(dá)式講解(1)

        發(fā)布人:計(jì)算機(jī)視覺工坊 時(shí)間:2023-04-23 來源:工程師 發(fā)布文章
        項(xiàng)目主頁

        https://github.com/zjhellofss/KuiperInfer 感謝大家點(diǎn)贊和PR, 這是對我最大的鼓勵(lì), 謝謝.

        什么是表達(dá)式

        表達(dá)式就是一個(gè)計(jì)算過程,類似于如下:

        output_mid = input1 + input2
        output = output_mid * input3 

        用圖形來表達(dá)就是這樣的.

        圖片image-20230113160348886

        但是在PNNXExpession Layer中給出的是一種抽象表達(dá)式,會(huì)對計(jì)算過程進(jìn)行折疊,消除中間變量. 并且將具體的輸入張量替換為抽象輸入@0,@1等.對于上面的計(jì)算過程,PNNX生成的抽象表達(dá)式是這樣的.

        add(@0,mul(@1,@2)) 抽象的表達(dá)式重新變回到一個(gè)方便后端執(zhí)行的計(jì)算過程(抽象語法樹來表達(dá),在推理的時(shí)候我們會(huì)把它轉(zhuǎn)成逆波蘭式)。

        其中addmul表示我們上一節(jié)中說到的RuntimeOperator@0@1表示我們上一節(jié)課中說道的RuntimeOperand. 這個(gè)抽象表達(dá)式看起來比較簡單,但是實(shí)際上情況會(huì)非常復(fù)雜,我們給出一個(gè)復(fù)雜的例子:

        add(add(mul(@0,@1),mul(@2,add(add(add(@0,@2),@3),@4))),@5)

        這就要求我們需要一個(gè)魯棒的表達(dá)式解析和語法樹構(gòu)建功能.

        我們的工作詞法解析

        詞法解析的目的就是將add(@0,mul(@1,@2))拆分為多個(gè)token,token依次為add ( @0 , mul等.代碼如下:

        enum class TokenType {
          TokenUnknown = -1,
          TokenInputNumber = 0,
          TokenComma = 1,
          TokenAdd = 2,
          TokenMul = 3,
          TokenLeftBracket = 4,
          TokenRightBracket = 5,
        };

        struct Token {
          TokenType token_type = TokenType::TokenUnknown;
          int32_t start_pos = 0//詞語開始的位置
          int32_t end_pos = 0// 詞語結(jié)束的位置
          Token(TokenType token_type, int32_t start_pos, int32_t end_pos): token_type(token_type), start_pos(start_pos), end_pos(end_pos) {

          }
        };

        我們在TokenType中規(guī)定了Token的類型,類型有輸入、加法、乘法以及左右括號等.Token類中記錄了類型以及Token在字符串的起始和結(jié)束位置.

        如下的代碼是具體的解析過程,我們將輸入存放在statement_中,首先是判斷statement_是否為空, 隨后刪除表達(dá)式中的所有空格和制表符.

          if (!need_retoken && !this->tokens_.empty()) {
            return;
          }

          CHECK(!statement_.empty()) << "The input statement is empty!";
          statement_.erase(std::remove_if(statement_.begin(), statement_.end(), [](char c) {
            return std::isspace(c);
          }), statement_.end());
          CHECK(!statement_.empty()) << "The input statement is empty!";

        下面的代碼中,我們先遍歷表達(dá)式輸入

         for (int32_t i = 0; i < statement_.size();) {
            char c = statement_.at(i);
            if (c == 'a') {
              CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'd')
                      << "Parse add token failed, illegal character: " << c;
              CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'd')
                      << "Parse add token failed, illegal character: " << c;
              Token token(TokenType::TokenAdd, i, i + 3);
              tokens_.push_back(token);
              std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
              token_strs_.push_back(token_operation);
              i = i + 3;
            } 
         }

        char c是當(dāng)前的字符,當(dāng)c等于字符a的時(shí)候,我們的詞法規(guī)定在token中以a作為開始的情況只有add. 所以我們判斷接下來的兩個(gè)字符必須是d和 d.如果不是的話就報(bào)錯(cuò),如果是i的話就初始化一個(gè)新的token并進(jìn)行保存.

        舉個(gè)簡單的例子只有可能是add,沒有可能是axc之類的組合.

        else if (c == 'm') {
             CHECK(i + 1 < statement_.size() && statement_.at(i + 1) == 'u')
                      << "Parse add token failed, illegal character: " << c;
              CHECK(i + 2 < statement_.size() && statement_.at(i + 2) == 'l')
                      << "Parse add token failed, illegal character: " << c;
              Token token(TokenType::TokenMul, i, i + 3);
              tokens_.push_back(token);
              std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
              token_strs_.push_back(token_operation);
              i = i + 3;

        同理當(dāng)c等于字符m的時(shí)候,我們的語法規(guī)定token中以m作為開始的情況只有mul. 所以我們判斷接下來的兩個(gè)字必須是ul. 如果不是的話,就報(bào)錯(cuò),是的話就初始化一個(gè)mul token進(jìn)行保存.

           } else if (c == '@') {
              CHECK(i + 1 < statement_.size() && std::isdigit(statement_.at(i + 1)))
                      << "Parse number token failed, illegal character: " << c;
              int32_t j = i + 1;
              for (; j < statement_.size(); ++j) {
                if (!std::isdigit(statement_.at(j))) {
                  break;
                }
              }
              Token token(TokenType::TokenInputNumber, i, j);
              CHECK(token.start_pos < token.end_pos);
              tokens_.push_back(token);
              std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);
              token_strs_.push_back(token_input_number);
              i = j;
            } else if (c == ',') {
              Token token(TokenType::TokenComma, i, i + 1);
              tokens_.push_back(token);
              std::string token_comma = std::string(statement_.begin() + i, statement_.begin() + i + 1);
              token_strs_.push_back(token_comma);
              i += 1;
            }

        當(dāng)輸入為ant時(shí)候,我們對ant之后的所有數(shù)字進(jìn)行讀取,如果其之后不是操作數(shù),則報(bào)錯(cuò).當(dāng)字符等于(或者,的時(shí)候就直接保存為對應(yīng)的token,不需要對往后的字符進(jìn)行探查, 直接保存為對應(yīng)類型的Token.


        *博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。



        關(guān)鍵詞: AI

        相關(guān)推薦

        技術(shù)專區(qū)

        關(guān)閉
        主站蜘蛛池模板: 波密县| 南木林县| 辽源市| 奉化市| 上高县| 兴安盟| 茌平县| 清新县| 晴隆县| 靖宇县| 绥化市| 新乡市| 阳朔县| 龙陵县| 桦南县| 奉化市| 萝北县| 大庆市| 恭城| 朝阳区| 富源县| 巴中市| 博白县| 柳河县| 甘德县| 织金县| 池州市| 娱乐| 合川市| 大田县| 庄浪县| 石台县| 桃园市| 山东省| 泾源县| 托克托县| 肥城市| 沾益县| 青河县| 岢岚县| 宝清县|