博客專欄

        EEPW首頁 > 博客 > 征程 6EM 常見 QConfig 配置解讀與示例

        征程 6EM 常見 QConfig 配置解讀與示例

        發布人:地平線開發者 時間:2025-06-01 來源:工程師 發布文章

        一、引言

        在工具鏈用戶手冊《量化感知訓練(QAT)-開發指南-QConfig 詳解》章節專門介紹了在 J6EM 上 qconfig 是怎么回事,從經歷看,大家可能會存在看了依舊不懂,或懂了不知道怎么配置的情況,特別是一些 OE 包中示例沒有的配置,例如固定某節點 scale、配置 linear weight int16 等操作。

        qconfig 控制了模型所有節點的量化類型,例如是采用 int8 還是 int16 量化,是固定校準階段的 scale 去 qat 還是不固定 scale 去 qat。

        提供的模板可分為三類:基礎模板、敏感度模板、自定義模板。本文將常見配置通過示例方式進行呈現。

        二、基礎模板

        基礎模板中 calibration / qat / qat_fixed_act_scale 區別在于使用的 observer 類型和 scale 更新邏輯,分別用于校準,不固定 activation scaleqat 訓練,固定 activation scale qat 訓練。

        default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 會做三件事:

        首先,將可以設置的高精度輸出都設置上,對于不支持高精度的輸出將給出提示;

        然后,從 grid sample 算子的 grid 輸入向前搜索,直到出現第一個 gemm 類算子或者 QuantStub,將中間的所有算子都設置為 int16。根據經驗這里的 grid 一般表達范圍較寬,int8 有較大可能不滿足精度需求;

        最后,將其余算子設置為 int8。

        int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 會做兩件事:

        首先,將可以設置的高精度輸出都設置上,對于不支持高精度的輸出將給出提示;

        其次,將其余算子設置為 int16。

        from horizon_plugin_pytorch.quantization.qconfig_template import (
           default_calibration_qconfig_setter,
           default_qat_qconfig_setter,
           default_qat_fixed_act_qconfig_setter,
           qat_8bit_weight_16bit_act_qconfig_setter,
           qat_8bit_weight_16bit_fixed_act_qconfig_setter,
           calibration_8bit_weight_16bit_act_qconfig_setter,
        )
        qat_or_calib_model = prepare(
           float_model,
           example_inputs=example_inputs,  # 用來感知圖結構
           qconfig_setter=(

               default_qat_qconfig_setter,    # 根據需要配置setter模板
           ),
        )

        三、敏感度模板

        敏感度模板有三個:

        sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter
        sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter
        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter

        三者的區別和基礎模板中三者的區別類似,也是分別用于校準,不固定 activation scale qat 訓練,固定 activation scale qat 訓練。

        敏感度模板的第一個輸入是精度 debug 工具產生的敏感度結果,第二個參數可以指定 ratio 或 topk,敏感度模板會根據配置,將量化敏感度最高的 topk 個算子設置為 int16。搭配固定模板,可以實現混合精度調優。

        若模型有多個輸出,每個輸出都會產生一個敏感度表,您可以設置多個敏感度模版。示例如下:

        from horizon_plugin_pytorch.quantization.qconfig_template import (
           default_calibration_qconfig_setter,
           sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
           sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
           sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
        )

        # 這兩個pt文件是通過debug工具得到的
        table1 = torch.load("output_0-0_L1_sensitive_ops.pt")
        table2 = torch.load("output_0-1_L1_sensitive_ops.pt")

        calibration_model = prepare(
           float_model,
           example_inputs=example_input,
           qconfig_setter=(
               sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2),
               sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2),
               default_calibration_qconfig_setter,
           ),
        )

        四、自定義模板

        自定義模板為 ModuleNameQconfigSetter,需要傳入模塊名和對應自定義的 qconfig,一般用于設置 fixed scale、配置 linear weight int16 等特殊需求,可以和固定模板,敏感度模板搭配使用。示例如下:

        from horizon_plugin_pytorch.quantization.qconfig_template import (
           calibration_8bit_weight_16bit_act_qconfig_setter,
           ModuleNameQconfigSetter,
        )
        from horizon_plugin_pytorch.quantization.qconfig import (
           get_qconfig,
           MSEObserver,
           MinMaxObserver,
           FixedScaleObserver,
           QConfig,
        )
        from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize

        # 手動設置某個算子的輸出scale
        op_name_output_fix_scale_qconfig = QConfig(
           output=FakeQuantize.with_args(
               observer=FixedScaleObserver,
               dtype=qint16,
               scale=0.0625,
           )
        )

        # 設置某個算子weight與輸出activation的量化類型
        # 校準時用MSEObserver,qat時用MinMaxObserver
        # 沒有weight的算子,配置了weight_dtype也不會起作用
        calib_weight_act_both_int16_qconfig = get_qconfig(
           observer=MSEObserver,
           weight_dtype=qint16,
           out_dtype=qint16,
        )

        calib_weight_act_both_int8_qconfig = get_qconfig(
           observer=MSEObserver,
           weight_dtype=qint8,
           out_dtype=qint8,
        )

        qat_weight_act_both_int16_qconfig = get_qconfig(
           observer=MinMaxObserver,
           weight_dtype=qint16,
           out_dtype=qint16,
           fix_scale=True,    # 是否固定scale
        )

        放在一塊簡單示例如下:

        from horizon_plugin_pytorch.quantization.qconfig_template import (
           default_qat_qconfig_setter,
           sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
           ModuleNameQconfigSetter,
        )

        table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")

        # 自動替換生成的算子只能通過 ModuleNameQconfigSetter 配置自定義 qconfig。
        module_name_to_qconfig = {
           "_generated_add_0": op_name_output_fix_scale_qconfig ,
        }

        qat_model = prepare(
           float_model,
           example_inputs=example_input,
           qconfig_setter=(
               ModuleNameQconfigSetter(module_name_to_qconfig),
               sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
               default_qat_qconfig_setter,
           ),
        )

        五、可運行的示例

        將網絡中 linear2 的 weight 配置為 int16 量化、輸入配置為 int8 量化、輸出配置為 int16 量化,其他算子激活使用 int16 量化,weight 使用 int8 量化。

        import torch
        from horizon_plugin_pytorch import set_march, March
        set_march(March.NASH_M)
        from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
        from horizon_plugin_pytorch.quantization import QuantStub
        from horizon_plugin_pytorch.quantization.hbdk4 import export
        from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
        from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
        from horizon_plugin_pytorch.dtype import qint8, qint16
        from torch.quantization import DeQuantStub
        import torch.nn as nn


        # 定義網絡結構
        class SmallModel(nn.Module):
           def __init__(self):
               super(SmallModel, self).__init__()
               # 第一個 Linear: 輸入 [2, 100, 256] -> 輸出 [2, 100, 256]
               self.linear1 = nn.Linear(256, 256)
               self.layernorm = nn.LayerNorm(256)  # 對最后一維進行歸一化
               self.relu = nn.ReLU()
               # 第二個 Linear: 輸入 [2, 100, 256] -> 輸出 [2, 100, 60]
               self.linear2 = nn.Linear(256, 60)
               # 第三個 Linear: 輸入 [2, 100, 60] -> 輸出 [2, 100, 60]
               self.linear3 = nn.Linear(60, 60)
               self.quant = QuantStub()
               self.dequant = DeQuantStub()

           def forward(self, x):
               x = self.quant(x)
               # 第一個 Linear
               x = self.linear1(x)  # [2, 100, 256]
               x = self.layernorm(x)  # [2, 100, 256]
               x = self.relu(x)  # [2, 100, 256]
               # 第二個 Linear
               x = self.linear2(x)  # [2, 100, 60]
               # 第三個 Linear
               x = self.linear3(x)
               x = self.dequant(x)
               return x

        example_input = torch.randn(2, 100, 256)
        model = SmallModel()

        # 前向傳播
        output = model(example_input)
        print("輸出形狀:", output.shape)

        # A global march indicating the target hardware version must be setted before prepare qat.
        set_march(March.NASH_M)

        calib_weight_act_both_int16_qconfig = get_qconfig(
           observer=MSEObserver,
           weight_dtype=qint16,
           out_dtype=qint16,
        )

        # layernorm沒有weight,配置了weight_dtype也不會起作用
        calib_weight_act_both_int8_qconfig = get_qconfig(
           observer=MSEObserver,
           weight_dtype=qint8,
           out_dtype=qint8,
        )

        qat_weight_act_both_int16_qconfig = get_qconfig(
           observer=MinMaxObserver,
           weight_dtype=qint16,
           out_dtype=qint16,
           fix_scale=True,
        )
        # 節點名稱,可以從model_check_result.txt中獲取,也可以從敏感度文件中獲取
        module_name_to_qconfig = {
           "layernorm": calib_weight_act_both_int8_qconfig,
           "linear2": calib_weight_act_both_int16_qconfig,  
        }

        calib_model = prepare(model.eval(), example_input,
                             qconfig_setter=(
                                 ModuleNameQconfigSetter(module_name_to_qconfig),
                                 calibration_8bit_weight_16bit_act_qconfig_setter,
                                 ),
                             )

        calib_model.eval()
        set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
        calib_model(example_input)

        calib_model.eval()                            
        set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
        calib_out = calib_model(example_input)

        qat_bc = export(calib_model, example_input)

        配置 add 單算子輸入和輸出均使用固定 scale

        import torch
        from horizon_plugin_pytorch import set_march, March
        set_march(March.NASH_E)
        from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
        from horizon_plugin_pytorch.quantization import QuantStub
        from horizon_plugin_pytorch.quantization.hbdk4 import export
        from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
        from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver, FixedScaleObserver, QConfig
        from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
        from horizon_plugin_pytorch.dtype import qint8, qint16
        from torch.quantization import DeQuantStub
        import torch.nn as nn


        class AddNet(nn.Module):
           def __init__(self):
               super(AddNet, self).__init__()
               self.quant_x = QuantStub()
               self.quant_y = QuantStub()
               self.dequant = DeQuantStub()

           def forward(self, x, y):
               x = self.quant_x(x)
               y = self.quant_y(y)
               z = torch.add(x, y)
               z = self.dequant(z)
               return z

        # 創建模型
        model = AddNet()

        # 生成兩個相同形狀的輸入張量
        torch.manual_seed(42)
        x = torch.randn(1, 1, 2, 6)
        y = torch.randn(1, 2, 2, 6)
        example_input = (x,y)

        # 前向傳播
        output = model(example_input[0], example_input[1])
        print("float輸出數據:", output)
        print("輸入形狀:", example_input[0].shape)
        print("輸出形狀:", output.shape)

        # A global march indicating the target hardware version must be setted before prepare qat.
        set_march(March.NASH_E)

        add_input_fix_scale_qconfig = QConfig(
           output=FakeQuantize.with_args(
               observer=FixedScaleObserver,
               dtype=qint16,
               scale=0.03125,
           )
        )
        add_output_fix_scale_qconfig = QConfig(
           output=FakeQuantize.with_args(
               observer=FixedScaleObserver,
               dtype=qint16,
               scale=0.0625,
           )
        )

        # 節點名稱,可以從model_check_result.txt中獲取,也可以從敏感度文件中獲取
        module_name_to_qconfig = {
           "quant_x": add_input_fix_scale_qconfig,

           "quant_y": add_input_fix_scale_qconfig,

           "_generated_add_0": add_output_fix_scale_qconfig,
        }

        calib_model = prepare(model.eval(), example_input,
                             qconfig_setter=(
                                 ModuleNameQconfigSetter(module_name_to_qconfig),
                                 calibration_8bit_weight_16bit_act_qconfig_setter,
                                 ),
                             )

        calib_model.eval()
        set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
        calib_model(example_input[0], example_input[1])

        calib_model.eval()                            
        set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
        calib_out = calib_model(example_input[0], example_input[1])
        print("calib輸出數據:", calib_out)

        qat_bc = export(calib_model, example_input)

        六、凍結部分網絡結構 qat 的配置

        補充常見凍結網絡結構,去進行 qat 的做法

        from horizon_plugin_pytorch.quantization import (
           QuantStub,
           prepare,
           set_fake_quantize,
           FakeQuantState,
        )
        #prepare QAT模型
        qat_model = prepare(
           model,
           example_inputs=xxx,
           qconfig_setter=(
               xxx,
           )
        )
        #加載calib權重
        qat_model.load_state_dict(torch.load("calib-checkpoint.ckpt"))
        #QAT訓練
        qat_model.train()
        #固定backbone部分的權重,requires_grad不影響drop bn的行為,需要與eval聯合用
        for param in qat_model.backbone.parameters():
           param.requires_grad = False
        #固定backbone部分的scale,eval只影響drop bn的行為,如果發生了backward仍然會改變權重,需要與requires_grad聯合使用
        qat_model.backbone.eval()
        set_fake_quantize(qat_model.backbone, FakeQuantState.VALIDATION)
        #配置head的FakeQuant為QAT狀態
        set_fake_quantize(qat_model.head, FakeQuantState.QAT)


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。




        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 闽侯县| 江源县| 保德县| 仁怀市| 彭州市| 手游| 望江县| 成安县| 烟台市| 武鸣县| 仁布县| 金华市| 安阳县| 冷水江市| 新乡县| 延寿县| 大城县| 土默特左旗| 荔波县| 聊城市| 德安县| 镇坪县| 沅江市| 齐齐哈尔市| 新建县| 绿春县| 安陆市| 乌恰县| 靖西县| 石林| 福清市| 凤冈县| 称多县| 灌云县| 兴山县| 新乡市| 中牟县| 当涂县| 莆田市| 山阴县| 丹阳市|