博客專欄

        EEPW首頁 > 博客 > 地平線 3D 目標檢測 Bevformer 參考算法 V2.0

        地平線 3D 目標檢測 Bevformer 參考算法 V2.0

        發布人:地平線開發者 時間:2025-02-08 來源:工程師 發布文章

        該示例為參考算法,僅作為在 征程 6 上模型部署的設計參考,非量產算法

        簡介

        BEVFormer 是當前熱門的自動駕駛系統中的 3D 視覺感知任務模型。BEVFormer 是一個端到端的框架,BEVFormer 可以直接從原始圖像數據生成 BEV 特征,無需依賴于傳統的圖像處理流程。它通過利用 Transformer 架構和注意力機制,有效地從多攝像頭圖像中學習生成高質量的鳥瞰圖(Bird's-Eye-View, BEV)特征表示。相較于其他的 BEV 轉換方式:

        1. 時空注意力機制:模型結合了空間交叉注意力(Spatial Cross-Attention, SCA)和時間自注意力(Temporal Self-Attention, TSA),使網絡能夠同時考慮空間和時間維度上的信息。融合歷史 bev 特征來提升預設的 BEV 空間中的 query 的自學能力,得到 bev 特征。

        2. Deformable attn:通過對每個目標生成幾個采樣點和采樣點的 offset 來提取采樣點周圍的重要特征,即只關注和目標相關的特征,減少計算量。

        3. transformer 架構:能夠有效捕捉序列中的長期依賴關系,適用于處理圖像序列。

        性能精度指標

        模型參數:

        圖片


        性能精度表現:

        image.png

        模型介紹

        圖片

        ·公版 BEVFormer 模型主要可以分為以下幾個關鍵部分:

        1. Backbone 網絡:用于從多視角攝像頭圖像中提取特征,本文為 tiny 版本,因此為 ResNet50。

        2. 時空特征提取:BEVFormer 通過引入時間和空間特征來學習 BEV 特征。具體來說,模型包括:

        3. Temporal Self-Attention(時間自注意力):利用前一時刻的 BEV 特征作為歷史特征,通過自注意力機制來計算當前時刻的 BEV 特征。

        4. Spatial Cross-Attention(空間交叉注意力):進行空間特征注意力,融合多視角圖像特征。

        5. Deformable Attention(可變形注意力):BEVFormer 使用可變形注意力機制來加速運算,提高模型對不同視角圖像特征的適應性。

        6. BEV 特征生成:通過時空特征的融合,完成環視圖像特征向 BEV 特征的建模。

        7. Decoder:設計用于 3D 物體檢測的端到端網絡結構,基于 2D 檢測器 Deformable DETR 進行改進,以適應 3D 空間的檢測任務。

        地平線部署說明

        公版 bevformer 在 征程 6 上部署相比于 征程 5 來說更簡單了,需要考慮的因素更少。征程 6 對非 4 維的支持可以和 4 維的同等效率,因此 征程 6 支持公版的注意力實現,不再限制維度,因此無需對維度做 Reshape,可直接支持公版寫法。但需注意的是公版的 bev_mask 會導致動態 shape。征程 6 不支持動態輸入,因此 bev_mask 無法使用。在精度上,我們修復了公版的 bug 已獲得了精度上的提升,同時通過對關鍵層做 int16 的量化精度配置以保障 1%以內的量化精度損失。

        下面將部署優化對應的改動點以及量化配置依次說明。

        性能優化

        改動點 1:

        將 attention 層的 mean 替換為 conv 計算,使性能上獲得提升。

        /usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/attention.py

        self.query_reduce_mean = nn.Conv2d(
           self.num_bev_queue * self.reduce_align_num,
           self.reduce_align_num,
           1,
           bias=False,)

        # init query_reduce_mean weight
        query_reduce_mean_weight = torch.zeros(
           self.query_reduce_mean.weight.size(),      
           dtype=self.query_reduce_mean.weight.dtype,
        )
        for i in range(self.reduce_align_num):
           for j in range(self.num_bev_queue):
               query_reduce_mean_weight[i, j * self.reduce_align_num + i] = (
                   1 / self.num_bev_queue
               )
        self.query_reduce_mean.weight = torch.nn.Parameter(
           query_reduce_mean_weight, requires_grad=False
        )

        改動點 2:

        公版中,在 Encoder 的空間融合模塊,會根據 bev_mask 計算有效的 query 和 reference_points,輸出 queries_rebatch 和 reference_points_rebatch,作用為減少交互的數據量,提升模型運行性能。對于稀疏的 query 做 crossattn 后再將 query 放回到 bev_feature 中。

        以上提取稀疏 query 步驟的主要算子為 gather,放回 bev_feature 步驟的主要算子為 scatter。由于工具鏈對這兩個算子暫未支持(gather 算子 930 已支持)而且 bev_mask 為動態的,為了提升模型的運行性能,工具鏈提供了 gridsample 算子的替換方式,index 計算只與內外參有關,因此作為前處理,將計算好的 index 作為模型輸入即可。

        gather

        gather 為根據 bevmask 來提取稀疏 query,降低 cross attn 的數據量,提升運行效率。

        代碼路徑:<code>/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/<span style="caret-color: #000000; color: #000000; font-family: monospace; font-size: medium; font-style: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: auto; text-align: start; text-indent: 0px; text-transform: none; white-space: normal; widows: auto; word-spacing: 0px; -webkit-text-stroke-width: 0px; background-color: #e8e8e8; text-decoration: none; display: inline !important; float: none;">view_transformer.py</span>

                reference_points_cam = torch.clamp(
                   reference_points_cam, min=-2.1, max=2.1
               )
               reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
               bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
               bev_mask_ori = bev_mask.clone()
               max_len = self.virtual_bev_h * self.virtual_bev_w
               queries_rebatch_grid = reference_points_cam.new_zeros(
                   [B * self.numcam, self.virtual_bev_h, self.virtual_bev_w, 2]
               )
               for camera_idx, mask_per_img_bs in enumerate(bev_mask):
                   for bs_id, mask_per_img in enumerate(mask_per_img_bs):
                       temp_grid = (
                           torch.zeros(
                               (max_len, 2),
                               device=queries_rebatch_grid.device,
                               dtype=torch.float32,
                           )
                           - 1.5
                       )
                       index_query_per_img = (
                           mask_per_img.sum(-1).nonzero().squeeze(-1)
                       )
                       num_bev_points = index_query_per_img.shape[0]
                       camera_idx_tensor_x = index_query_per_img % self.bev_w
                       camera_idx_tensor_y = index_query_per_img // self.bev_w
                       index_grid = torch.stack(
                           [
                               camera_idx_tensor_x / (self.bev_w - 1),
                               camera_idx_tensor_y / (self.bev_h - 1),
                           ],
                           dim=-1,
                       )
                       index_grid = index_grid * 2 - 1
                       temp_grid[:num_bev_points] = index_grid
                       temp_grid = temp_grid.reshape(
                           self.virtual_bev_h, self.virtual_bev_w, 2
                       )
                       queries_rebatch_grid[
                           bs_id * self.numcam + camera_idx
                       ] = temp_grid
               reference_points_rebatch = (
                   reference_points_cam.flatten(-2)
                   .permute(1, 0, 3, 2)
                   .flatten(0, 1)
                   .reshape(B * self.numcam, D * 2, self.bev_h, self.bev_w)
               )
               reference_points_rebatch = (
                   F.grid_sample(
                       reference_points_rebatch,
                       queries_rebatch_grid,
                       mode="nearest",
                       align_corners=True,
                   )
                   .flatten(-2)
                   .permute(0, 2, 1)
                   .reshape(B * self.numcam, max_len, D, 2)
               )

        query_rebatch

                queries_rebatch = (
                   query.unsqueeze(1)
                   .repeat(1, self.num_cams, 1, 1)
                   .reshape(
                       bs * self.num_cams, self.bev_h, self.bev_w, self.embed_dims
                   )
                   .permute(0, 3, 1, 2)
               )
               queries_rebatch = F.grid_sample(
                   queries_rebatch,
                   queries_rebatch_grid,
                   mode="nearest",
                   align_corners=True,
               )
               queries_rebatch = queries_rebatch.flatten(-2).permute(0, 2, 1)
               reference_points_rebatch = reference_points_rebatch.flatten(
                   -2
               ).unsqueeze(-2)

        scatter

        scatter 操作對經過 deformable_attention 后的 query 放入到 bevfeature 中,然后求平均。

        代碼路徑為:/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/attention.py

                slots = self.restore_outputs(
                   restore_bev_grid,
                   queries_out,
                   bev_pillar_counts,
                   bs,
                   queries_rebatch_grid,
               )
           def restore_outputs(
               self,
               restore_bev_grid: Tensor,
               queries_out: Tensor,
               counts: Tensor,
               bs: int,
               queries_rebatch_grid: Tensor,
           ):
               """Restore outputs to bev feature."""
               queries_out = queries_out.reshape(
                   bs, self.num_cams, self.embed_dims, -1
               )
               queries_out = queries_out.permute(0, 2, 1, 3)
               queries_out = queries_out.reshape(
                   bs,
                   self.embed_dims,
                   self.num_cams * queries_rebatch_grid.shape[1],
                   queries_rebatch_grid.shape[2],
               )
               bev_queries = F.grid_sample(
                   queries_out, restore_bev_grid, mode="nearest", align_corners=True
               )
               bev_queries = bev_queries.reshape(bs, -1, self.bev_h, self.bev_w)
               slots = self.query_reduce_sum(bev_queries).flatten(-2).permute(0, 2, 1)
               slots = self.mul_pillarweight.mul(slots, counts)
               return slots

        其中 restore_bev_grid,根據 bevmask 反算回 bev_feature 的位置:

                restore_bev_grid = (
                   reference_points_cam.new_zeros(
                       B, self.max_camoverlap_num * self.bev_h, self.bev_w, 2
                   )
                   - 1.5
               )
               for bs_id, bev_mask_ in enumerate(bev_mask):
                   bev_pillar_num_map = torch.zeros(
                       (self.bev_h, self.bev_w), device=bev_mask_.device
                   )
                   count = bev_mask_.sum(-1) > 0
                   camera_idxs, bev_pillar_idxs = torch.where(count)
                   camera_idx_offset = 0
                   for cam_id in range(self.numcam):
                       camera_idx = torch.where(camera_idxs == cam_id)
                       bev_pillar_idx_cam = bev_pillar_idxs[camera_idx[0]]
                       num_camera_idx = len(camera_idx[0])
                       camera_idx_tmp = camera_idx[0] - camera_idx_offset
                       camare_tmp_idx_x = camera_idx_tmp % self.virtual_bev_w
                       camare_tmp_idx_y = camera_idx_tmp // self.virtual_bev_w
                       grid_x = camare_tmp_idx_x
                       grid_y = cam_id * self.virtual_bev_h + camare_tmp_idx_y
                       bev_pillar_idx_cam_x = bev_pillar_idx_cam % self.bev_w
                       bev_pillar_idx_cam_y = bev_pillar_idx_cam // self.bev_w
                       bev_pillar_num_map_tmp = bev_pillar_num_map[
                           bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
                       ]
                       grid_h = (
                           bev_pillar_num_map_tmp * self.bev_h + bev_pillar_idx_cam_y
                       ).to(torch.int64)
                       grid_w = (bev_pillar_idx_cam_x).to(torch.int64)
                       restore_bev_grid[bs_id, grid_h, grid_w, 0] = grid_x / (
                           self.virtual_bev_w - 1
                       )
                       restore_bev_grid[bs_id, grid_h, grid_w, 1] = grid_y / (
                           self.numcam * self.virtual_bev_h - 1
                       )
                       bev_pillar_num_map[
                           bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
                       ] = (
                           bev_pillar_num_map[
                               bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
                           ]
                           + 1
                       )
                       camera_idx_offset = camera_idx_offset + num_camera_idx
               restore_bev_grid = restore_bev_grid * 2 - 1
        精度優化浮點精度

        改動點 3:

        公版通過 can_bus 初始化 ref 來做時序融合,然而這個時候 bev feat 并沒有對齊,在 attention 計算時不能簡單的 concat 起來。因此我們換了一種時序對齊的方式,通過前后兩幀的 ego2global 坐標系轉換矩陣將當前幀的 bev 特征和上一幀對齊,此時 ref 都是一樣的。(非 征程 6 不支持,為公版 bug),精度上獲得提升。

        /usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/view_transformer.py` `get_prev_bev` `get_fusion_ref
        pre_scene = prev_meta["scene_token"]
        for i in range(bs):
           if pre_scene[i] != cur_meta["meta"][i]["scene"]:
               prev_bev[i] = torch.zeros(
                   (self.bev_h * self.bev_w, self.embed_dims),
                   dtype=torch.float32,
                   device=device,
               )
        ##公版:
        shift_ref_2d = ref_2d.clone()
        shift_ref_2d += shift[:, None, None, :]
        bs, len_bev, num_bev_level, _ = ref_2d.shape
        hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
           bs*2, len_bev, num_bev_level, 2)
        ##地平線版本
        shift_ref_2d = ref_2d.clone()
        bs, len_bev, num_bev_level, _ = ref_2d.shape
        hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
           bs * 2, len_bev, num_bev_level, 2
        )


        改動點 4:

        修復了個 tsa 公版的 batchsize 不等于 1 的 bug。BEVFormer/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py at master · fundament

        量化精度

        為量化精度保證,我們將以下的算子配置為 int16 或 int32 輸出:

        view_transformer:輸入節點做 int16 量化:

                int16_models = [
                   self.quant_hybird_ref_2d,
                   self.quant_norm_coords,
                   self.quant_restore_bev_grid,
                   self.quant_reference_points_rebatch,
                   self.quant_queries_rebatch_grid,
               ]
               for m in int16_models:
                   m.qconfig = qconfig_manager.get_qconfig(
                       activation_qat_qkwargs={"dtype": qint16},
                       activation_calibration_qkwargs={
                           "dtype": qint16,
                       },
                       activation_calibration_observer="mix",
                   )

        attention 層:最后兩個 conv 和 add 開啟 int16

            def set_qconfig(self) -> None:
               """Set the quantization configuration."""
               from hat.utils import qconfig_manager
               int16_module = [
                   self.output_proj,
                   self.add_res,
               ]

        decoder 層:cls_branches、reg_branches 的 conv 配置為 int32 輸出;sigmoid 和 reference_points 配置為 int16

            def set_qconfig(self) -> None:
                """Set the quantization configuration."""
                from hat.utils import qconfig_manager
                for _, m in enumerate(self.cls_branches):
                    m[0].qconfig = qconfig_manager.get_qconfig(
                        activation_qat_qkwargs={"dtype": qint16},
                        activation_calibration_qkwargs={
                            "dtype": qint16,
                        },
                        activation_calibration_observer="mix",
                    )
                    m[3].qconfig = qconfig_manager.get_qconfig(
                        activation_qat_qkwargs={"dtype": qint16},
                        activation_calibration_qkwargs={
                            "dtype": qint16,
                        },
                        activation_calibration_observer="mix",
                    )
                    m[-1].qconfig = qconfig_manager.get_default_qat_out_qconfig()
                self.reg_branches[-1][
                    -1
                ].qconfig = qconfig_manager.get_default_qat_out_qconfig()
                self.query_embedding.qconfig = None
                int16_module = [
                    self.reference_points,
                    self.sigmoid,
                ]
                for m in int16_module:
                    m.qconfig = qconfig_manager.get_qconfig(
                        activation_qat_qkwargs={"dtype": qint16},
                        activation_calibration_qkwargs={
                            "dtype": qint16,
                        },
                        activation_calibration_observer="mix",
                    )
        總結與建議訓練建議
        • 浮點和公版一致即可

        • qat 訓練需要將 lr 降低,下降策略建議使用 StepDecayLrUpdater。

        部署建議
        • 建議 bev size 的選擇考慮性能影響。征程 6 相比于 征程 5 帶寬增大,但仍需注意 bevsize 過大導致訪存時間過長對性能的影響,建議考慮實際部署情況選擇合適的 bevsize 做性能驗證。

        • 使用 bevmask 來提升運行性能,可參考 4.1 章節使用 gridsample 替換不支持的 scatter。

        • 在注意力機制中存在一些 ElementWise 操作,對于導致性能瓶頸的可以考慮 conv 替換,對于造成量化風險的可以根據敏感度分析結果合理選擇更高的量化精度,以確保注意力機制的部署。


        本文通過對 Bevformer 在地平線征程 6 上量化部署的優化,使得模型在該計算方案上用低于 1%的量化精度損失,得到 latency 為 45.74ms 的部署性能,同時,通過 Bevformer 的部署經驗,可以推廣到其他模型部署優化,例如包含 MSDA 模型結構、transformer-based BEV 的部署。

        附錄
        1. 論文:https://arxiv.org/pdf/2203.17270

        2. 公版代碼:https://github.com/fundamentalvision/BEVFormer


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。




        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 彝良县| 广河县| 昌吉市| 闸北区| 屯门区| 博兴县| 东乌| 新乡市| 北川| 台安县| 临沭县| 电白县| 雅江县| 武功县| 平塘县| 融水| 兴海县| 台东市| 定兴县| 堆龙德庆县| 兰溪市| 满城县| 札达县| 洱源县| 漯河市| 南宫市| 长武县| 福清市| 巴楚县| 南川市| 冀州市| 喜德县| 内丘县| 临武县| 临城县| 桦南县| 太湖县| 泸定县| 吴旗县| 林周县| 玛沁县|