A New Re-parameterization Method: DBB, the Upgraded ACNet

Published by 計算機視覺工坊, 2022-10-19

Author: ChaucerG

Source: 手寫AI

1. Opening Notes

Knowledge point 1:

Features extracted by parallel multi-branch structures have stronger representational power;

for details, see DenseNet, VOVNet, Res2Net, and PeleeNet (each will be covered in a later post).

Knowledge point 2:

Parallel multi-branch structures incur a larger computational overhead;

for details, see CSPNet's analysis of this trade-off.

Knowledge point 3:

ACNet's re-parameterization, which replaces a single 3×3 conv with 1×3 conv + 3×1 conv + 3×3 conv during training, is effective;

for details, see ACNet's analysis, and the numerical sketch just below.
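For a concrete feel of why this merge is lossless, here is a quick numerical check (an illustrative sketch, not from the original post): by the linearity of convolution, zero-padding the 1×3 and 3×1 kernels to 3×3 and summing them yields a single equivalent 3×3 kernel.

        import torch
        import torch.nn.functional as F

        # Random input and random kernels for the three ACNet branches
        x = torch.randn(1, 8, 32, 32)
        k3x3 = torch.randn(16, 8, 3, 3)
        k1x3 = torch.randn(16, 8, 1, 3)
        k3x1 = torch.randn(16, 8, 3, 1)

        # Multi-branch output (paddings chosen so every branch keeps the spatial size)
        y_branches = (F.conv2d(x, k3x3, padding=1)
                      + F.conv2d(x, k1x3, padding=(0, 1))
                      + F.conv2d(x, k3x1, padding=(1, 0)))

        # Merged kernel: zero-pad the asymmetric kernels to 3x3, then add
        k_merged = k3x3 + F.pad(k1x3, [0, 0, 1, 1]) + F.pad(k3x1, [1, 1, 0, 0])
        y_merged = F.conv2d(x, k_merged, padding=1)

        print(torch.allclose(y_branches, y_merged, atol=1e-4))  # True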
Knowledge point 4: is there an even better re-parameterization?
Answer: yes. DBB is essentially ACNet v2, a comprehensive upgrade!
2. DBB in Brief

Diverse Branch Block (DBB) is the next exploration of structural re-parameterization after ACNet, i.e., ACNet v2. DBB designs an Inception-like block that enriches a convolutional block's feature space through a multi-branch structure; the branches include average pooling, multi-scale convolution, and so on. Before inference, the multi-branch structure is re-parameterized and fused into a single main branch, speeding up inference while also picking up some extra accuracy.

(Figure: schematic of the DBB block design)

As the figure shows, DBB, like Inception, augments the original K×K convolution with combinations such as 1×1, 1×1-K×K, and 1×1-AVG branches. For the 1×1-K×K branch, the number of internal channels is set equal to the number of input channels, and the 1×1 convolution is initialized as an identity matrix; the other branches use standard initialization. In addition, a BN layer is added after every convolution to provide training-time nonlinearity, which is essential for the performance gain.
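Since every branch ends in BN, the first step of the fusion is folding each conv-BN pair into a single conv with bias. A minimal sketch of this fold (transform I in the paper; the helper below is an illustrative stand-in for the transI_fusebn used later, assuming inference-time BN with running statistics):

        import torch
        import torch.nn as nn

        def fuse_conv_bn(weight, bn):
            # y = gamma * (conv(x) - mean) / std + beta
            #   = conv(x; w * gamma / std) + (beta - mean * gamma / std)
            std = (bn.running_var + bn.eps).sqrt()
            scale = (bn.weight / std).reshape(-1, 1, 1, 1)  # one scale per output channel
            return weight * scale, bn.bias - bn.running_mean * bn.weight / std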

3. Implementing DBB

Below is a PyTorch implementation of DBB:

        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import numpy as np  # used by IdentityBasedConv1x1 below

        from dbb_transforms import *


        def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                    padding_mode='zeros'):

            conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                   stride=stride, padding=padding, dilation=dilation, groups=groups,
                                   bias=False, padding_mode=padding_mode)
            bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
            se = nn.Sequential()
            se.add_module('conv', conv_layer)
            se.add_module('bn', bn_layer)
            return se


        class IdentityBasedConv1x1(nn.Conv2d):
            def __init__(self, channels, groups=1):
                super(IdentityBasedConv1x1, self).__init__(in_channels=channels,
                                                           out_channels=channels,
                                                           kernel_size=1,
                                                           stride=1,
                                                           padding=0,
                                                           groups=groups,
                                                           bias=False)

                assert channels % groups == 0
                input_dim = channels // groups
                id_value = np.zeros((channels, input_dim, 1, 1))
                for i in range(channels):
                    id_value[i, i % input_dim, 0, 0] = 1
                self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
                nn.init.zeros_(self.weight)

            def forward(self, input):
                kernel = self.weight + self.id_tensor.to(self.weight.device)
                result = F.conv2d(input,
                                  kernel,
                                  None,
                                  stride=1,
                                  padding=0,
                                  dilation=self.dilation,
                                  groups=self.groups)
                return result

            def get_actual_kernel(self):
                return self.weight + self.id_tensor.to(self.weight.device)


        class BNAndPadLayer(nn.Module):
            def __init__(self,
                         pad_pixels,
                         num_features,
                         eps=1e-5,
                         momentum=0.1,
                         affine=True,
                         track_running_stats=True):

                super(BNAndPadLayer, self).__init__()
                self.bn = nn.BatchNorm2d(num_features,
                                         eps,
                                         momentum,
                                         affine,
                                         track_running_stats)
                self.pad_pixels = pad_pixels

            def forward(self, input):
                output = self.bn(input)
                if self.pad_pixels > 0:
                    if self.bn.affine:
                        pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(
                            self.bn.running_var + self.bn.eps)
                    else:
                        pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
                    output = F.pad(output, [self.pad_pixels] * 4)
                    pad_values = pad_values.view(1, -1, 1, 1)
                    output[:, :, 0:self.pad_pixels, :] = pad_values
                    output[:, :, -self.pad_pixels:, :] = pad_values
                    output[:, :, :, 0:self.pad_pixels] = pad_values
                    output[:, :, :, -self.pad_pixels:] = pad_values
                return output

            @property
            def weight(self):
                return self.bn.weight

            @property
            def bias(self):
                return self.bn.bias

            @property
            def running_mean(self):
                return self.bn.running_mean

            @property
            def running_var(self):
                return self.bn.running_var

            @property
            def eps(self):
                return self.bn.eps


        class DiverseBranchBlock(nn.Module):

            def __init__(self,
                         in_channels,
                         out_channels,
                         kernel_size,
                         stride=1,
                         padding=0,
                         dilation=1,
                         groups=1,
                         internal_channels_1x1_3x3=None,
                         deploy=False,
                         nonlinear=None,
                         single_init=False):

                super(DiverseBranchBlock, self).__init__()
                self.deploy = deploy

                if nonlinear is None:
                    self.nonlinear = nn.Identity()
                else:
                    self.nonlinear = nonlinear

                self.kernel_size = kernel_size
                self.out_channels = out_channels
                self.groups = groups
                assert padding == kernel_size // 2

                if deploy:
                    self.dbb_reparam = nn.Conv2d(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=kernel_size,
                                                 stride=stride,
                                                 padding=padding,
                                                 dilation=dilation,
                                                 groups=groups,
                                                 bias=True)

                else:
                    self.dbb_origin = conv_bn(in_channels=in_channels,
                                              out_channels=out_channels,
                                              kernel_size=kernel_size,
                                              stride=stride,
                                              padding=padding,
                                              dilation=dilation,
                                              groups=groups)

                    self.dbb_avg = nn.Sequential()
                    if groups < out_channels:
                        self.dbb_avg.add_module('conv',
                                                nn.Conv2d(in_channels=in_channels,
                                                          out_channels=out_channels,
                                                          kernel_size=1,
                                                          stride=1,
                                                          padding=0,
                                                          groups=groups,
                                                          bias=False))

                        self.dbb_avg.add_module('bn',
                                                BNAndPadLayer(pad_pixels=padding,
                                                              num_features=out_channels))

                        self.dbb_avg.add_module('avg',
                                                nn.AvgPool2d(kernel_size=kernel_size,
                                                             stride=stride,
                                                             padding=0))

                        self.dbb_1x1 = conv_bn(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=1,
                                               stride=stride,
                                               padding=0,
                                               groups=groups)
                    else:
                        self.dbb_avg.add_module('avg',
                                                nn.AvgPool2d(kernel_size=kernel_size,
                                                             stride=stride,
                                                             padding=padding))

                    self.dbb_avg.add_module('avgbn',
                                            nn.BatchNorm2d(out_channels))

                    if internal_channels_1x1_3x3 is None:
                        # For mobilenet, it is better to have 2X internal channels
                        internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

                    self.dbb_1x1_kxk = nn.Sequential()
                    if internal_channels_1x1_3x3 == in_channels:
                        self.dbb_1x1_kxk.add_module('idconv1',
                                                    IdentityBasedConv1x1(channels=in_channels, groups=groups))
                    else:
                        self.dbb_1x1_kxk.add_module('conv1',
                                                    nn.Conv2d(in_channels=in_channels,
                                                              out_channels=internal_channels_1x1_3x3,
                                                              kernel_size=1,
                                                              stride=1,
                                                              padding=0,
                                                              groups=groups,
                                                              bias=False))
                    self.dbb_1x1_kxk.add_module('bn1',
                                                BNAndPadLayer(pad_pixels=padding,
                                                              num_features=internal_channels_1x1_3x3,
                                                              affine=True))
                    self.dbb_1x1_kxk.add_module('conv2',
                                                nn.Conv2d(in_channels=internal_channels_1x1_3x3,
                                                          out_channels=out_channels,
                                                          kernel_size=kernel_size,
                                                          stride=stride,
                                                          padding=0,
                                                          groups=groups,
                                                          bias=False))
                    self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

                #   The experiments reported in the paper used the default initialization of bn.weight (all as 1).
                #   But changing the initialization may be useful in some cases.
                if single_init:
                    #   Initialize the bn.weight of dbb_origin as 1 and others as 0.
                    #   This is not the default setting.
                    self.single_init()

            def get_equivalent_kernel_bias(self):
                k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                                   self.dbb_origin.bn)

                if hasattr(self, 'dbb_1x1'):
                    # transform I: fuse conv + bn
                    k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                                 self.dbb_1x1.bn)
                    # transform VI: zero-pad the fused 1x1 kernel up to kernel_size
                    k_1x1 = transVI_multiscale(k_1x1,
                                               self.kernel_size)
                else:
                    k_1x1, b_1x1 = 0, 0

                if hasattr(self.dbb_1x1_kxk, 'idconv1'):
                    k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
                else:
                    k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
                # transform I: fuse conv + bn
                k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first,
                                                                 self.dbb_1x1_kxk.bn1)
                # transform I: fuse conv + bn
                k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                                   self.dbb_1x1_kxk.bn2)
                # transform III: merge the fused 1x1 conv into the fused kxk conv
                k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                                      b_1x1_kxk_first,
                                                                      k_1x1_kxk_second,
                                                                      b_1x1_kxk_second,
                                                                      groups=self.groups)

                k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
                # transform I: fuse conv + bn
                k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                                   self.dbb_avg.avgbn)
                if hasattr(self.dbb_avg, 'conv'):
                    # transform I: fuse conv + bn
                    k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight,
                                                                     self.dbb_avg.bn)
                    # transform III: merge the 1x1 conv into the kxk conv
                    k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                          b_1x1_avg_first,
                                                                          k_1x1_avg_second,
                                                                          b_1x1_avg_second,
                                                                          groups=self.groups)
                else:
                    k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
                # transform II: merge all parallel branches by adding kernels and biases
                return transII_addbranch((k_origin,
                                          k_1x1,
                                          k_1x1_kxk_merged,
                                          k_1x1_avg_merged),
                                         (b_origin,
                                          b_1x1,
                                          b_1x1_kxk_merged,
                                          b_1x1_avg_merged))

            def switch_to_deploy(self):
                if hasattr(self, 'dbb_reparam'):
                    return
                kernel, bias = self.get_equivalent_kernel_bias()
                self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels,
                                             out_channels=self.dbb_origin.conv.out_channels,
                                             kernel_size=self.dbb_origin.conv.kernel_size,
                                             stride=self.dbb_origin.conv.stride,
                                             padding=self.dbb_origin.conv.padding,
                                             dilation=self.dbb_origin.conv.dilation,
                                             groups=self.dbb_origin.conv.groups, bias=True)
                self.dbb_reparam.weight.data = kernel
                self.dbb_reparam.bias.data = bias
                for para in self.parameters():
                    para.detach_()
                self.__delattr__('dbb_origin')
                self.__delattr__('dbb_avg')
                if hasattr(self, 'dbb_1x1'):
                    self.__delattr__('dbb_1x1')
                self.__delattr__('dbb_1x1_kxk')

            def forward(self, inputs):

                if hasattr(self, 'dbb_reparam'):
                    return self.nonlinear(self.dbb_reparam(inputs))

                out = self.dbb_origin(inputs)
                if hasattr(self, 'dbb_1x1'):
                    out += self.dbb_1x1(inputs)
                out += self.dbb_avg(inputs)
                out += self.dbb_1x1_kxk(inputs)
                return self.nonlinear(out)

            def init_gamma(self, gamma_value):
                if hasattr(self, "dbb_origin"):
                    torch.nn.init.constant_(self.dbb_origin.bn.weight,
                                            gamma_value)
                if hasattr(self, "dbb_1x1"):
                    torch.nn.init.constant_(self.dbb_1x1.bn.weight,
                                            gamma_value)
                if hasattr(self, "dbb_avg"):
                    torch.nn.init.constant_(self.dbb_avg.avgbn.weight,
                                            gamma_value)
                if hasattr(self, "dbb_1x1_kxk"):
                    torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight,
                                            gamma_value)

            def single_init(self):
                self.init_gamma(0.0)
                if hasattr(self, "dbb_origin"):
                    torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)

Without further ado, just compare the ONNX outputs before and after the fusion and see for yourself!

(Figure: comparison of the exported ONNX outputs before and after re-parameterization)
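Before exporting to ONNX, you can also sanity-check the equivalence directly in PyTorch. A minimal sketch (the channel and input sizes are arbitrary illustrative choices):

        # eval() matters: BN must use running statistics for the fusion to be exact
        block = DiverseBranchBlock(in_channels=16, out_channels=32,
                                   kernel_size=3, padding=1)
        block.eval()

        x = torch.randn(1, 16, 56, 56)
        y_multi = block(x)              # multi-branch (training-time) structure
        block.switch_to_deploy()
        y_fused = block(x)              # single re-parameterized conv

        print((y_multi - y_fused).abs().max())  # expect ~1e-5 or smaller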

4. References

[1] https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/diversebranchblock.py

This article is shared for academic purposes only. If it infringes any rights, please contact us to have it removed.

