博客專欄

        EEPW首頁 > 博客 > 用Transformer思想的分類器進行小樣本分割

        用Transformer思想的分類器進行小樣本分割

        發布人:計算機視覺工坊 時間:2022-07-23 來源:工程師 發布文章

        來源丨 GiantPandaCV
        文章目錄

        • 1 前言
        • 2 CWT-for-FSS 整體架構
        • 3 求解方法
        • 4 實驗結果分析
        • 5 代碼和可視化
        • 6 總結
        • 7 參考鏈接

          1 前言

        圖片

        之前寫了幾篇醫學圖像分割相關的論文閱讀筆記,這次打算開個小樣本語義分割的新坑。這篇閱讀筆記中介紹的論文也是很久之前讀過的,接受在 ICCV 上,思路值得借鑒。代碼也已經跑過了,但是一直沒來得及整理,arXiv:https://arxiv.org/pdf/2108.03032.pdf 。

        針對小樣本語義分割問題,這篇論文提出一種更加簡潔的元學習范式,即只對分類器進行元學習,對特征編碼****采用常規分割模型訓練方式。也就是說只對 Classifier Weight Transformer(后面都簡稱 CWT)進行元學習的訓練,使得 CWT 可以動態地適應測試樣本,從而提高分割準確率。

        先來介紹下背景,傳統的語義分割通常由三部分組成:一個 CNN 編碼器,一個 CNN ****和一個區分前景像素與背景像素的簡單的分類器。

        當模型學習識別一個沒見過的新類時,需要分別訓練這三個部分進行元學習,如果新類別中圖象太少,那么同時訓練三個模塊就十分困難。

        在這篇文論文中, 提出一種新的訓練方法,在面對新類時只關注模型中最簡單的分類器。就像文中假設一個學習了大量圖片和信息的傳統分割網絡已經能夠從任何一張圖片中捕捉到充分的,有利于區分背景和前景的信息,無論訓練時是否遇到了同類的圖。那么面對少樣本的新類時,只要對分類器進行元學習即可。

        這篇閱讀筆記首先概述了 CWT-for-FSS 的整體結構,再介紹了訓練方法,然后分析了實驗結果,最后對代碼訓練做了簡單的指南。

        2 CWT-for-FSS 整體架構

        一個小樣本分類系統一般由三部分構成:編碼器,****和分類器。

        其中,前兩個模塊模型比較復雜,最后一個分類器結構簡單。小樣本分類方法通常在元學習的過程中更新所有模塊或者除編碼器外的模塊,而所利用更新模塊的數據僅僅有幾個樣本。

        在這樣的情況下,模型更新的參數量相比于數據提供的信息量過多,從而不足以優化模型參數。基于此分析,文章中提出了一個全新的元學習訓練范式,即只對分類器進行元學習。兩種方式的對比,如下圖:

        圖片

        值得注意的是,我們知道在 Support set 上迭代的模型往往不能很好地作用在 Query set 上,因為同類別的圖像也可能存在差異。

        利用 CWT 來解決這個問題,就是這篇論文的重點。也就是說,可以動態地利用 Query set 的特征信息來進一步更新分類器,來提高分割的精準度。整體架構如下圖:

        圖片

        借助 Transformer 的思想,將分類器權重轉化為 Query,將 Query set 提取出來的特征轉化為 Key 和 Value,然后根據這三個值調整分類器權重,最后通過殘差連接,與原分類器參數求和。

        3 求解方法

        首先,對網絡進行預訓練,這里就不再贅述。然后就是對 CWT 進行元學習,分兩步,第一步是內循環,和預訓練一樣,根據支持集上的圖片和 mask 進行訓練,不過只修改分類器參數。

        當新類樣本數夠大時,只使用外循環,即只更新分類器,就能匹敵 SOTA,但是當面對小樣本時,表現就不盡如人意。第二步是外循環,根據每一副查詢圖片,微調分類器參數。

        微調后的參數只針對這一張查詢圖片,不能用于其他查詢圖象,也不覆蓋修改原分類器參數。

        假設一張查詢圖像,提取出的特征為F,形狀為 n × d,n為單通道的像素數,d為通道數,則全連接分類器參數 w 形狀為  2 × d。參照 Transformer,令 Query = w × Wq, Key = F × Wk, Value = F × Wv,其中 Wq、Wk 和Wv 均為可學習的 d × da 矩陣,d 為維度數,da 為人為規定的隱藏層維度數,本文將其設置為了 2048。根據這三個數,以及殘差鏈接,可求得新分類器權重為:

        圖片

        其中,Ψ 是一個線性層,輸入維度為 da,輸出維度為 d。softmax 針對的維度為行。求出每一張查詢集對應的權重后,只需要把特征 F 塞進 w* 就好。

        4 實驗結果分析

        這部分展示論文中的實驗結果,在兩個標準小樣本分割數據集 PASCAL 和 COCO 上,文中的方法在大多數情況下取得了最優的結果。

        圖片圖片

        此外,文中實驗在一種跨數據集的情景下測試了模型的性能,可以看出 CWT-for-FSS 方法具有了很好的魯棒性。

        圖片

        最后,可視化結果如下:

        圖片5 代碼和可視化

        代碼已經開開源在 https://github.com/lixiang007666/CWT-for-FSS 上,最后我們簡單看下如何使用。倉庫提供了訓練腳本:

        sh scripts/train.sh pascal 0 [0] 50 1

        后面幾個參數依次為數據集指定、split 數、gpus、layers 和 k-shots。如果需要多卡訓練,gpus 為[0,1,3,4,5,6,7],layers 除了 50 還可以指定為 101,說明 backbone 為 resnet101。對應的,測試的腳本為 scripts/test.sh。

        此外,倉庫中的代碼并沒有提供可視化腳本。如果需要可視化分割結果,可以參考下面的代碼。首先將以下內容插入主 test.py 腳本(在 classes.append() 下方):

                        logits_q[i] = pred_q.detach()
                        gt_q[i, 0] = q_label
                        classes.append([class_.item() for class_ in subcls])
                        # Insert visualization routine here 
                        if args.visualize:
                            output = {}
                            output['query'], output['support'] = {}, {}
                            output['query']['gt'], output['query']['pred'] =     vis_res(qry_oris[0][0],      qry_oris[1], F.interpolate(pred_q, size=q_label.size()[1:], mode='bilinear', align_corners=True).squeeze().detach().cpu().numpy())
                            spprt_label = torch.cat(spprt_oris[1], 0)
                            output['support']['gt'], output['support']['pred'] = vis_res(spprt_oris[0][0][0],spprt_label, output_support.squeeze().detach().cpu().numpy())

                            save_image = np.concatenate((output['support']['gt'], output['query']['gt'], output['query']['pred']), 1)
                            cv2.imwrite('./analysis/' + qry_oris[0][0].split('/')[-1] ,   save_image)

        主要可視化函數vis_res如下:

        def resize_image_label(image, label, size = 473):
            import cv2
            def find_new_hw(ori_h, ori_w, test_size):
                if ori_h >= ori_w:
                    ratio = test_size * 1.0 / ori_h
                    new_h = test_size
                    new_w = int(ori_w * ratio)
                elif ori_w > ori_h:
                    ratio = test_size * 1.0 / ori_w
                    new_h = int(ori_h * ratio)
                    new_w = test_size

                if new_h % 8 != 0:
                    new_h = (int(new_h / 8)) * 8
                else:
                    new_h = new_h
                if new_w % 8 != 0:
                    new_w = (int(new_w / 8)) * 8
                else:
                    new_w = new_w
                return new_h, new_w

            # Step 1: resize while keeping the h/w ratio. The largest side (i.e height or width) is reduced to $size.
            #                                             The other is reduced accordingly
            test_size = size
            new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size)

            image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)),
                                    interpolation=cv2.INTER_LINEAR)

            # Step 2: Pad wtih 0 whatever needs to be padded to get a ($size, $size) image
            back_crop = np.zeros((test_size, test_size, 3))

            back_crop[:new_h, :new_w, :] = image_crop
            image = back_crop

            # Step 3: Do the same for the label (the padding is 255)
            s_mask = label
            new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size)
            s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),
                                interpolation=cv2.INTER_NEAREST)
            back_crop_s_mask = np.ones((test_size, test_size)) * 255
            back_crop_s_mask[:new_h, :new_w] = s_mask
            label = back_crop_s_mask

            return image, label
        def vis_res(image_path, label, pred):

            import cv2
            def read_image(path):
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = np.float32(image)
                return image

            def label_to_image(label):
                label = label == 1.
                label = np.float32(label) * 255.
                placeholder = np.zeros_like(label)
                label = np.concatenate((label, placeholder), 0)
                label = np.concatenate((label, placeholder), 0)
                label = np.transpose(label, (1,2,0))
                return label

            def blend_image_label(image, label):
                result = 0.5 * image + 0.5 * label
                result = np.float32(result)
                result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

                return result

            def pred_to_image(label):
                label = np.float32(label) * 255.
                placeholder = np.zeros_like(label)
                placeholder = np.concatenate((placeholder, placeholder), 0)
                label = np.concatenate((placeholder, label), 0)
                label = np.transpose(label, (1,2,0))
                return label

            image = read_image(image_path)
            label = label.squeeze().detach().cpu().numpy()
            image, label = resize_image_label(image, label)
            label = label_to_image(np.expand_dims(label, 0))
            out_image_gt = blend_image_label(image, label)
            #cv2.imwrite('./analysis/' + image_path.split('/')[-1][:-4] +  '_gt.jpg',   out_image)

            pred  = np.argmax(pred, 0)
            pred = np.expand_dims(pred, 0)
            pred = pred_to_image(pred)
            out_image_pred = blend_image_label(image, pred)
            #cv2.imwrite('./analysis/' + image_path.split('/')[-1][:-4] +  '_pred.jpg',   out_image)

            return out_image_gt, out_image_pred

        注意,是在每次測試迭代結束時可視化分割結果。

        6 總結

        這篇閱讀筆記介紹了一種新的元學習訓練范式來解決小樣本語義分割問題。相比于現有的方法,這種方法更加簡潔有效,只對分類器進行元學習。

        重要的是,為了解決類內差異問題,提出 Classifier Weight Transformer 利用 Query 特征信息來迭代訓練分類器,從而獲得更加魯棒和精準的分割效果。

        7 參考鏈接

        • https://github.com/zhiheLu/CWT-for-FSS
        • https://arxiv.org/pdf/2108.03032.pdf


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 仁寿县| 新津县| 三台县| 白城市| 高邮市| 开阳县| 浙江省| 台东县| 周口市| 沙雅县| 开化县| 遂平县| 恩平市| 龙门县| 盐亭县| 平远县| 济源市| 宁德市| 海伦市| 秦皇岛市| 广德县| 平陆县| 临西县| 崇信县| 綦江县| 上饶县| 建宁县| 康保县| 临夏县| 庐江县| 门源| 天门市| 河南省| 疏附县| 克东县| 赣州市| 姚安县| 岗巴县| 枣强县| 诸暨市| 永年县|