
        Deformable DETR Explained Like You're Teaching Your Girlfriend: Paper Close Reading + Code Walkthrough (2)

        Published by: 計算機視覺工坊 | Date: 2023-04-23 | Source: Engineer
        4.4 Decoder

        The fully annotated code is below. This is where the iterative bounding box refinement and two-stage tricks are switched on or off. Iterative bounding box refinement simply fine-tunes the positions of the reference points after every decoder layer. The two-stage approach generates anchors directly from the reference points but keeps only the top few by confidence, which are then fed into the decoder for further adjustment. The intermediate list is a trick: every decoder layer can output bbox and classification predictions, and if all of them contribute to the loss we get the auxiliary loss. A short schematic of the refinement update follows this paragraph.
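
        As a preview of the logic implemented in DeformableTransformerDecoder below, here is a minimal three-line schematic of one refinement step (not runnable on its own; delta stands for the unnormalized offset predicted by bbox_embed):

        # Schematic of one iterative-refinement step, distilled from the decoder loop below
        delta = bbox_embed(output)                      # per-layer offset prediction (unnormalized)
        ref = (delta + inverse_sigmoid(ref)).sigmoid()  # update reference points in logit space
        ref = ref.detach()                              # cut gradients between layers, Cascade R-CNN style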

        class DeformableTransformerDecoderLayer(nn.Module):
           def __init__(self, d_model=256, d_ffn=1024,
                        dropout=0.1, activation="relu",
                        n_levels=4, n_heads=8, n_points=4):
               super().__init__()

               # cross attention
               self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
               self.dropout1 = nn.Dropout(dropout)
               self.norm1 = nn.LayerNorm(d_model)

               # self attention
               self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
               self.dropout2 = nn.Dropout(dropout)
               self.norm2 = nn.LayerNorm(d_model)

               # ffn
               self.linear1 = nn.Linear(d_model, d_ffn)
               self.activation = _get_activation_fn(activation)
               self.dropout3 = nn.Dropout(dropout)
               self.linear2 = nn.Linear(d_ffn, d_model)
               self.dropout4 = nn.Dropout(dropout)
               self.norm3 = nn.LayerNorm(d_model)

           @staticmethod
           def with_pos_embed(tensor, pos):
               return tensor if pos is None else tensor + pos

           def forward_ffn(self, tgt):
               tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
               tgt = tgt + self.dropout4(tgt2)
               tgt = self.norm3(tgt)
               return tgt

           def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
               # self attention: standard multi-head attention among the object queries
               q = k = self.with_pos_embed(tgt, query_pos)
               tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
               tgt = tgt + self.dropout2(tgt2)
               tgt = self.norm2(tgt)

               # cross attention: deformable attention in which each query samples multi-scale encoder features around its reference point
               tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                                      reference_points,
                                      src, src_spatial_shapes, level_start_index, src_padding_mask)
               tgt = tgt + self.dropout1(tgt2)
               tgt = self.norm1(tgt)

               # ffn
               tgt = self.forward_ffn(tgt)

               return tgt


        class DeformableTransformerDecoder(nn.Module):
           def __init__(self, decoder_layer, num_layers, return_intermediate=False):
               super().__init__()
               self.layers = _get_clones(decoder_layer, num_layers)
               self.num_layers = num_layers
               self.return_intermediate = return_intermediate
               # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
               self.bbox_embed = None
               self.class_embed = None

           def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
                       query_pos=None, src_padding_mask=None):
               output = tgt

               # store intermediate decoder outputs; used when training with the auxiliary loss
               intermediate = []
               intermediate_reference_points = []
               for lid, layer in enumerate(self.layers):
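                   # scale the normalized reference points by each level's valid ratio so the
                   # sampling locations stay inside the un-padded region of every feature map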
                   if reference_points.shape[-1] == 4:
                       reference_points_input = reference_points[:, :, None] \
                                                * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
                   else:
                       assert reference_points.shape[-1] == 2
                       reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
                   output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)

                   # hack implementation for iterative bounding box refinement
                   # iterative refinement fine-tunes the reference points between decoder layers, similar in spirit to Cascade R-CNN
                   if self.bbox_embed is not None:
                       tmp = self.bbox_embed[lid](output)
                       if reference_points.shape[-1] == 4:
                           new_reference_points = tmp + inverse_sigmoid(reference_points)
                           new_reference_points = new_reference_points.sigmoid()
                       else:
                           assert reference_points.shape[-1] == 2
                           new_reference_points = tmp
                           new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                           new_reference_points = new_reference_points.sigmoid()
                       reference_points = new_reference_points.detach()

                   if self.return_intermediate:
                       intermediate.append(output)
                       intermediate_reference_points.append(reference_points)

               if self.return_intermediate:
                   return torch.stack(intermediate), torch.stack(intermediate_reference_points)

               return output, reference_points
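
        One helper worth noting: inverse_sigmoid is not defined in this excerpt. In the official repository it lives in util/misc.py as a numerically clamped logit; a sketch of that definition:

        def inverse_sigmoid(x, eps=1e-5):
           # clamped logit: keeps log() away from exactly 0 or 1
           x = x.clamp(min=0, max=1)
           x1 = x.clamp(min=eps)
           x2 = (1 - x).clamp(min=eps)
           return torch.log(x1 / x2)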
        4.5 Deformable Transformer

        The code for the complete module is as follows:

        class DeformableTransformer(nn.Module):
           def __init__(self, d_model=256, nhead=8,
                        num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                        activation="relu", return_intermediate_dec=False,
                        num_feature_levels=4, dec_n_points=4,  enc_n_points=4,
                        two_stage=False, two_stage_num_proposals=300):
               super().__init__()

               self.d_model = d_model
               self.nhead = nhead
               self.two_stage = two_stage
               self.two_stage_num_proposals = two_stage_num_proposals

               encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                                 dropout, activation,
                                                                 num_feature_levels, nhead, enc_n_points)
               self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

               decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                                 dropout, activation,
                                                                 num_feature_levels, nhead, dec_n_points)
               self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)

               self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

               if two_stage:
                   self.enc_output = nn.Linear(d_model, d_model)
                   self.enc_output_norm = nn.LayerNorm(d_model)
                   self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
                   self.pos_trans_norm = nn.LayerNorm(d_model * 2)
               else:
                   self.reference_points = nn.Linear(d_model, 2)

               self._reset_parameters()

           def _reset_parameters(self):
               for p in self.parameters():
                   if p.dim() > 1:
                       nn.init.xavier_uniform_(p)
               for m in self.modules():
                   if isinstance(m, MSDeformAttn):
                       m._reset_parameters()
               if not self.two_stage:
                   xavier_uniform_(self.reference_points.weight.data, gain=1.0)
                   constant_(self.reference_points.bias.data, 0.)
               normal_(self.level_embed)

           def get_proposal_pos_embed(self, proposals):
               num_pos_feats = 128
               temperature = 10000
               scale = 2 * math.pi

               dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
               dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
               # N, L, 4
               proposals = proposals.sigmoid() * scale
               # N, L, 4, 128
               pos = proposals[:, :, :, None] / dim_t
               # N, L, 4, 64, 2
               pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
               return pos

           def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
               N_, S_, C_ = memory.shape
               base_scale = 4.0
               proposals = []
               _cur = 0
               for lvl, (H_, W_) in enumerate(spatial_shapes):
                   mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
                   valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
                   valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

                   grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                                   torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
                   grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

                   scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
                   grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
                   wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
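                   # each proposal is (cx, cy, w, h) in normalized coordinates; the base size 0.05 doubles at every level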
                   proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
                   proposals.append(proposal)
                   _cur += (H_ * W_)
               output_proposals = torch.cat(proposals, 1)
               output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
               output_proposals = torch.log(output_proposals / (1 - output_proposals))
               output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
               output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

               output_memory = memory
               output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
               output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
               output_memory = self.enc_output_norm(self.enc_output(output_memory))
               return output_memory, output_proposals

           def get_valid_ratio(self, mask):
               _, H, W = mask.shape
               valid_H = torch.sum(~mask[:, :, 0], 1)
               valid_W = torch.sum(~mask[:, 0, :], 1)
               valid_ratio_h = valid_H.float() / H
               valid_ratio_w = valid_W.float() / W
               valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
               return valid_ratio

           def forward(self, srcs, masks, pos_embeds, query_embed=None):
               assert self.two_stage or query_embed is not None

               # prepare input for encoder
               src_flatten = []
               mask_flatten = []
               lvl_pos_embed_flatten = []
               spatial_shapes = []
               for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
                   # get the batch size, channels, height, and width of this level's feature map
                   bs, c, h, w = src.shape
                   spatial_shape = (h, w)
                   spatial_shapes.append(spatial_shape)
                   # flatten this level's feature map, mask, and positional encoding and collect them
                   src = src.flatten(2).transpose(1, 2)
                   mask = mask.flatten(1)
                   pos_embed = pos_embed.flatten(2).transpose(1, 2)
                   # add a learnable per-level embedding to the positional encoding, giving a 3D-position-like signal
                   lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
                   lvl_pos_embed_flatten.append(lvl_pos_embed)
                   src_flatten.append(src)
                   mask_flatten.append(mask)
               # concatenate across levels along the token dimension (dim 1)
               src_flatten = torch.cat(src_flatten, 1)
               mask_flatten = torch.cat(mask_flatten, 1)
               lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
               spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
               # record each level's start index and valid H/W ratios (padding masks mean the raw
               # image resolutions can differ within a batch); see get_valid_ratio
               # prod(1) computes h*w per level; cumsum(0) gives the prefix sums
               level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
               valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

               # encoder
               memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

               # prepare input for decoder
               bs, _, c = memory.shape
               # whether to run in two-stage mode
               if self.two_stage:
                   output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

                   # hack implementation for two-stage Deformable DETR
                   enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
                   enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

                   topk = self.two_stage_num_proposals
                   topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
                   topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
                   topk_coords_unact = topk_coords_unact.detach()
                   reference_points = topk_coords_unact.sigmoid()
                   init_reference_out = reference_points
                   pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
                   query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
               else:
                   # the single-stage (non-two-stage) version of Deformable DETR
                   # split query_embed into query_embed (positional part) and tgt (content part)
                   query_embed, tgt = torch.split(query_embed, c, dim=1)
                   # replicate once per sample in the batch
                   query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
                   tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
                   # an nn.Linear predicts each object query's reference point: this is how the decoder obtains its reference points
                   reference_points = self.reference_points(query_embed).sigmoid()
                   init_reference_out = reference_points

               # decoder
               hs, inter_references = self.decoder(tgt, reference_points, memory,
                                                   spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)

               inter_references_out = inter_references
               if self.two_stage:
                   return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
               return hs, init_reference_out, inter_references_out, None, None
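
        To make the expected shapes concrete, here is a hypothetical smoke test of the single-stage path. It assumes the MSDeformAttn CUDA op has been compiled, and uses made-up sizes (batch of 2, four feature levels, 300 queries); in the real model, query_embed comes from an nn.Embedding(num_queries, 2 * d_model).

        import torch

        model = DeformableTransformer(d_model=256, nhead=8, num_feature_levels=4).cuda()
        sizes = [64, 32, 16, 8]                                                  # per-level H = W
        srcs = [torch.randn(2, 256, s, s).cuda() for s in sizes]                 # projected backbone features
        masks = [torch.zeros(2, s, s, dtype=torch.bool).cuda() for s in sizes]   # False = valid pixel
        pos = [torch.randn(2, 256, s, s).cuda() for s in sizes]                  # positional encodings
        query_embed = torch.randn(300, 512).cuda()                               # split into query_pos and tgt

        hs, init_ref, inter_refs, _, _ = model(srcs, masks, pos, query_embed)
        print(hs.shape, init_ref.shape)  # expect (2, 300, 256) and (2, 300, 2)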
        5 Experiment

        [Figure 4. Performance comparison of Deformable DETR and DETR]

        As Figure 4 shows, Deformable DETR not only converges faster than DETR but also achieves much higher accuracy on small objects.

        6 Conclusion

        Deformable DETR is efficient and converges quickly. Its core is the multi-scale deformable attention module, which resolves DETR's slow convergence and weak small-object performance.

        References

        Deformable DETR: https://arxiv.org/pdf/2010.04159v4

        Official code repository: https://github.com/fundamentalvision/Deformable-DETR

        DCNv2: https://arxiv.org/pdf/2008.13535v2.pdf

