博客專欄

        EEPW首頁 > 博客 > 從零自制深度學習推理框架: 計算圖中的表達式講解(2)

        從零自制深度學習推理框架: 計算圖中的表達式講解(2)

        發布人:計算機視覺工坊 時間:2023-04-23 來源:工程師 發布文章
        語法解析

        當得到token數組之后,我們對語法進行分析,并得到最終產物抽象語法樹(不懂的請自己百度,這是編譯原理中的概念).語法解析的過程是遞歸向下的,定義在Generate_函數中.

        struct TokenNode {
          int32_t num_index = -1;
          std::shared_ptr<TokenNode> left = nullptr;
          std::shared_ptr<TokenNode> right = nullptr;
          TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);
          TokenNode() = default;
        };

        抽象語法樹由一個二叉樹組成,其中存儲它的左子節點和右子節點以及對應的操作編號num_indexnum_index為正, 則表明是輸入的編號,例如@0,@1中的num_index依次為1和2.  如果num_index為負數則表明當前的節點是mul或者add等operator.

        std::shared_ptr<TokenNode> ExpressionParser::Generate_(int32_t &index) {
          CHECK(index < this->tokens_.size());
          const auto current_token = this->tokens_.at(index);
          CHECK(current_token.token_type == TokenType::TokenInputNumber
             || current_token.token_type == TokenType::TokenAdd || current_token.token_type == TokenType::TokenMul);

        因為是一個遞歸函數,所以index指向token數組中的當前處理位置.current_token表示當前處理的token,它作為當前遞歸層的第一個Token, 必須是以下類型的一種.

        TokenInputNumber = 0,
        TokenAdd = 2,
        TokenMul = 3,

        如果當前token類型是輸入數字類型, 則直接返回一個操作數token作為一個葉子節點,不再向下遞歸, 也就是在add(@0,@1)中的@0@1,它們在前面的詞法分析中被歸類為TokenInputNumber類型.

          if (current_token.token_type == TokenType::TokenInputNumber) {
            uint32_t start_pos = current_token.start_pos + 1;
            uint32_t end_pos = current_token.end_pos;
            CHECK(end_pos > start_pos);
            CHECK(end_pos <= this->statement_.length());
            const std::string &str_number =
                std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
            return std::make_shared<TokenNode>(std::stoi(str_number), nullptrnullptr);

          } 
        else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd) {
            std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();
            current_node->num_index = -int(current_token.token_type);

            index += 1;
            CHECK(index < this->tokens_.size());
            // 判斷add之后是否有( left bracket
            CHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);

            index += 1;
            CHECK(index < this->tokens_.size());
            const auto left_token = this->tokens_.at(index);
         // 判斷當前需要處理的left token是不是合法類型
            if (left_token.token_type == TokenType::TokenInputNumber
                || left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul) {
              // (之后進行向下遞歸得到@0
                current_node->left = Generate_(index);
            } else {
              LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);
            }
         }

        如果當前Token類型是mul或者add. 那么我們需要向下遞歸構建對應的左子節點和右子節點.

        例如對于add(@1,@2),再遇到add之后,我們需要先判斷是否存在left bracket, 然后再向下遞歸得到@1, 但是@1所代表的 數字類型,不會再繼續向下遞歸.

        當左子樹構建完畢之后,我們將左子樹連接到current_nodeleft指針中,隨后我們開始構建右子樹.此處描繪的過程體現在current_node->left = Generate_(index);中.

            index += 1
         // 當前的index指向add(@1,@2)中的逗號
            CHECK(index < this->tokens_.size());
            // 判斷是否是逗號
            CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);

            index += 1;
            CHECK(index < this->tokens_.size());
            // current_node->right = Generate_(index);構建右子樹
            const auto right_token = this->tokens_.at(index);
            if (right_token.token_type == TokenType::TokenInputNumber
                || right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul) {
              current_node->right = Generate_(index);
            } else {
              LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);
            }

            index += 1;
            CHECK(index < this->tokens_.size());
            CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);
            return current_node;

        例如對于add(@1,@2),index當前指向逗號的位置,所以我們需要先判斷是否存在comma, 隨后開始構建右子樹.右子樹中的向下遞歸分析中得到了@2. 當右子樹構建完畢后,我們將它(Generate_返回的節點,此處返回的是一個葉子節點,其中的數據是@2) 放到current_noderight指針中.

        串聯起來的例子

        簡單來說,我們復盤一下add(@0,@1)這個例子.輸入到Generate_函數中, 是一個token數組.

        • add
        • (
        • @0
        • ,
        • @1
        • )

        Generate_數組首先檢查第一個輸入是否為add,mul或者是input number中的一種.

        CHECK(current_token.token_type == TokenType::TokenInputNumber|| 
        current_token.token_type == TokenType::TokenAdd || current_token.token_type == TokenType::TokenMul);

        第一個輸入add,所以我們需要判斷其后是否是left bracket來判斷合法性, 如果合法則構建左子樹.

           else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd) {
            std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();
            current_node->num_index = -int(current_token.token_type);

            index += 1;
            CHECK(index < this->tokens_.size());
            CHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);

            index += 1;
            CHECK(index < this->tokens_.size());
            const auto left_token = this->tokens_.at(index);

            if (left_token.token_type == TokenType::TokenInputNumber
                || left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul) {
              current_node->left = Generate_(index);
            }

        處理下一個token, 構建左子樹.

          if (current_token.token_type == TokenType::TokenInputNumber) {
            uint32_t start_pos = current_token.start_pos + 1;
            uint32_t end_pos = current_token.end_pos;
            CHECK(end_pos > start_pos);
            CHECK(end_pos <= this->statement_.length());
            const std::string &str_number =
                std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
            return std::make_shared<TokenNode>(std::stoi(str_number), nullptrnullptr);

          } 

        遞歸進入左子樹后,判斷是TokenType::TokenInputNumber則返回一個新的TokenNode到add token成為左子樹.

        檢查下一個token是否為逗號,也就是在add(@0,@1)的@0是否為,

            CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);

            index += 1;
            CHECK(index < this->tokens_.size());

        下一步是構建add token的右子樹

            index += 1;
            CHECK(index < this->tokens_.size());
            const auto right_token = this->tokens_.at(index);
            if (right_token.token_type == TokenType::TokenInputNumber
                || right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul) {
              current_node->right = Generate_(index);
            } else {
              LOG(FATAL) << "Unknown token type: " << int(left_token.token_type);
            }

            index += 1;
            CHECK(index < this->tokens_.size());
            CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);
            return current_node;
        current_node->right = Generate_(index); /// 構建add(@0,@1)中的右子樹

        Generate_(index)遞歸進入后遇到的token是@1 token,因為是Input Number類型所在構造TokenNode后返回.

          if (current_token.token_type == TokenType::TokenInputNumber) {
            uint32_t start_pos = current_token.start_pos + 1;
            uint32_t end_pos = current_token.end_pos;
            CHECK(end_pos > start_pos);
            CHECK(end_pos <= this->statement_.length());
            const std::string &str_number =
                std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
            return std::make_shared<TokenNode>(std::stoi(str_number), nullptrnullptr);

          } 

        至此, add語句的抽象語法樹構建完成.

        struct TokenNode {
          int32_t num_index = -1;
          std::shared_ptr<TokenNode> left = nullptr;
          std::shared_ptr<TokenNode> right = nullptr;
          TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);
          TokenNode() = default;
        };

        在上述結構中, left存放的是@0表示的節點, right存放的是@1表示的節點.


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 日喀则市| 湘乡市| 桓台县| 靖边县| 府谷县| 柘荣县| 洪江市| 融水| 兴国县| 额济纳旗| 南投县| 大田县| 凤山县| 商洛市| 灵台县| 天祝| 大连市| 和平区| 土默特右旗| 扶沟县| 三门县| 玉屏| 垫江县| 繁昌县| 姚安县| 丁青县| 绥芬河市| 娱乐| 新宁县| 广河县| 广昌县| 嵩明县| 古交市| 长垣县| 康马县| 加查县| 绥德县| 宽甸| 淮滨县| 新化县| 冷水江市|