博客專欄

        EEPW首頁 > 博客 > pytorch可視化教程:訓練過程+網絡結構(2)

        pytorch可視化教程:訓練過程+網絡結構(2)

        發布人:計算機視覺工坊 時間:2022-09-26 來源:工程師 發布文章
        2.2 HiddenLayer可視化訓練過程

        tensorboard的圖像很華麗,但是使用過程相較于其他的工具包較為繁瑣,所以小網絡一般沒必要使用tensorboard。

         import hiddenlayer as hl
         import time
         
         # 記錄訓練過程的指標
         history = hl.History()
         # 使用canvas進行可視化
         canvas = hl.Canvas()
         
         # 獲取優化器和損失函數
         optimizer = torch.optim.Adam(MyConvNet.parameters(), lr=3e-4)
         loss_func = nn.CrossEntropyLoss()
         log_step_interval = 100      # 記錄的步數間隔
         
         for epoch in range(5):
             print("epoch:", epoch)
             # 每一輪都遍歷一遍數據加載器
             for step, (x, y) in enumerate(train_loader):
                 # 前向計算->計算損失函數->(從損失函數)反向傳播->更新網絡
                 predict = MyConvNet(x)
                 loss = loss_func(predict, y)
                 optimizer.zero_grad()   # 清空梯度(可以不寫)
                 loss.backward()     # 反向傳播計算梯度
                 optimizer.step()    # 更新網絡
                 global_iter_num = epoch * len(train_loader) + step + 1  # 計算當前是從訓練開始時的第幾步(全局迭代次數)
                 if global_iter_num % log_step_interval == 0:
                     # 控制臺輸出一下
                     print("global_step:{}, loss:{:.2}".format(global_iter_num, loss.item()))
                     # 在測試集上預測并計算正確率
                     test_predict = MyConvNet(test_data_x)
                     _, predict_idx = torch.max(test_predict, 1)  # 計算softmax后的最大值的索引,即預測結果
                     acc = accuracy_score(test_data_y, predict_idx)
         
                     # 以epoch和step為索引,創建日志字典
                     history.log((epoch, step),
                                 train_loss=loss,
                                 test_acc=acc,
                                 hidden_weight=MyConvNet.fc[2].weight)
         
                     # 可視化
                     with canvas:
                         canvas.draw_plot(history["train_loss"])
                         canvas.draw_plot(history["test_acc"])
                         canvas.draw_image(history["hidden_weight"])

        不同于tensorboard,hiddenlayer會在程序運行的過程中動態生成圖像,而不是模型訓練完后

        下面為模型訓練的某一時刻的截圖:

        圖片

        三、使用Visdom進行可視化

        Visdom是Facebook為pytorch開發的一塊可視化工具。類似于tensorboard,visdom也是通過在本地啟動前端服務器來實現可視化的,而在具體操作上,visdom又類似于matplotlib.pyplot。所以使用起來很靈活。

        首先先安裝visdom庫,然后補坑。由于啟動前端服務器需要大量依賴項,所以在第一次啟動時可能會很慢(需要下載前端三板斧的依賴項),解決方法請見這里。

        先導入需要的第三方庫:

         from visdom import Visdom
         from sklearn.datasets import  load_iris
         import torch
         import numpy as np
         from PIL import Image

        matplotlib里,用戶繪圖可以通過plt這個對象來繪圖,在visdom中,同樣需要一個繪圖對象,我們通過vis = Visdom()來獲取。具體繪制時,由于我們會一次畫好幾張圖,所以visdom要求用戶在繪制時指定當前繪制圖像的窗口名字(也就是win這個參數);除此之外,為了到時候顯示的分塊,用戶還需要指定繪圖環境env,這個參數相同的圖像,最后會顯示在同一張頁面上。

        繪制線圖(相當于matplotlib中的plt.plot)

         # 繪制圖像需要的數據
         iris_x, iris_y = load_iris(return_X_y=True)
         
         # 獲取繪圖對象,相當于plt
         vis = Visdom()
         
         # 添加折線圖
         x = torch.linspace(-66100).view([-11])
         sigmoid = torch.nn.Sigmoid()
         sigmoid_y = sigmoid(x)
         tanh = torch.nn.Tanh()
         tanh_y = tanh(x)
         relu = torch.nn.ReLU()
         relu_y = relu(x)
         # 連接三個張量
         plot_x = torch.cat([x, x, x], dim=1)
         plot_y = torch.cat([sigmoid_y, tanh_y, relu_y], dim=1)
         # 繪制線性圖
         vis.line(X=plot_x, Y=plot_y, win="line plot", env="main",
                  opts={
                      "dash" : np.array(["solid""dash""dashdot"]),
                      "legend" : ["Sigmoid""Tanh""ReLU"]
                  })

        繪制散點圖:

         # 繪制2D和3D散點圖
         # 參數Y用來指定點的分布,win指定圖像的窗口名稱,env指定圖像所在的環境,opts通過字典來指定一些樣式
         vis.scatter(iris_x[ : , 0 : 2], Y=iris_y+1, win="windows1", env="main")
         vis.scatter(iris_x[ : , 0 : 3], Y=iris_y+1, win="3D scatter", env="main",
                     opts={
                         "markersize" : 4,   # 點的大小
                         "xlabel" : "特征1",
                         "ylabel" : "特征2"
                     })

        繪制莖葉圖:

         # 添加莖葉圖
         x = torch.linspace(-66100).view([-11])
         y1 = torch.sin(x)
         y2 = torch.cos(x)
         
         # 連接張量
         plot_x = torch.cat([x, x], dim=1)
         plot_y = torch.cat([y1, y2], dim=1)
         # 繪制莖葉圖
         vis.stem(X=plot_x, Y=plot_y, win="stem plot", env="main",
                  opts={
                      "legend" : ["sin""cos"],
                      "title" : "莖葉圖"
                  })

        繪制熱力圖:

         # 計算鳶尾花數據集特征向量的相關系數矩陣
         iris_corr = torch.from_numpy(np.corrcoef(iris_x, rowvar=False))
         # 繪制熱力圖
         vis.heatmap(iris_corr, win="heatmap", env="main",
                     opts={
                         "rownames" : ["x1""x2""x3""x4"],
                         "columnnames" : ["x1""x2""x3""x4"],
                         "title" : "熱力圖"
                     })

        可視化圖片,這里我們使用自定義的env名MyPlotEnv

         # 可視化圖片
         img_Image = Image.open("./example.jpg")
         img_array = np.array(img_Image.convert("L"), dtype=np.float32)
         img_tensor = torch.from_numpy(img_array)
         print(img_tensor.shape)
         
         # 這次env自定義
         vis.image(img_tensor, win="one image", env="MyPlotEnv",
                   opts={
                       "title" : "一張圖像"
                   })

        可視化文本,同樣在MyPlotEnv中繪制:

         # 可視化文本
         text = "hello world"
         vis.text(text=text, win="text plot", env="MyPlotEnv",
                  opts={
                      "title" : "可視化文本"
                  })
         

        運行上述代碼,再通過在終端中輸入python3 -m visdom.server啟動服務器,然后根據終端返回的URL,在谷歌瀏覽器中訪問這個URL,就可以看到圖像了。

        圖片圖片

        在Environment中輸入不同的env參數可以看到我們在不同環境下繪制的圖片。對于分類圖集特別有用。

        在終端中按下Ctrl+C可以終止前端服務器。

        進一步

        需要注意,如果你的前端服務器停掉了,那么所有的圖片都會丟失,因為此時的圖像的數據都是駐留在內存中,而并沒有dump到本地磁盤。那么如何保存當前visdom中的可視化結果,并在將來復用呢?其實很簡單,比如我現在有一堆來之不易的Mel頻譜圖:

        圖片

        點擊Manage Views

        圖片

        點擊fork->save:(此處我只保存名為normal的env)

        圖片

        接著,在你的User目錄下(Windows是C:\Users\賬戶.visdom文件夾,Linux是在~.visdom文件夾下),可以看到保存好的env:

        圖片

        它是以json文件格式保存的,那么如果你保存完后再shut down當前的前端服務器,圖像數據便不會丟失。

        好的,現在在保存完你珍貴的數據后,請關閉你的visdom前端服務器。然后再啟動它。

        如何查看保存的數據呢?很簡答,下次打開visdom前端后,visdom會在.visdom文件夾下讀取所有的保存數據完成初始化,這意味著,你直接啟動visdom,其他什么也不用做就可以看到之前保存的數據啦!

        那么如何服用保存的數據呢?既然你都知道了visdom保存的數據在哪里,那么直接通過python的json包來讀取這個數據文件,然后做解析就可以了,這是方法一,演示如下:

        import json

        with open(r"...\.visdom\normal.json""r", encoding="utf-8"as f:
            dataset : dict = json.load(f)

        jsons : dict = dataset["jsons"]      # 這里存著你想要恢復的數據
        reload : dict = dataset["reload"]    # 這里存著有關窗口尺寸的數據 

        print(jsons.keys())     # 查看所有的win

        out:

        dict_keys(['jsons''reload'])
        dict_keys(['1.wav''2.wav''3.wav''4.wav''5.wav''6.wav''7.wav''8.wav''9.wav''10.wav''11.wav''12.wav''13.wav''14.wav'])

        但這么做不是很優雅,所以visdom封裝了第二種方法。你當然可以通過訪問文件夾.visdom來查看當前可用的env,但是也可以這么做:

        from visdom import Visdom

        vis = Visdom()
        print(vis.get_env_list())

        out:

        Setting up a new session...
        ['main''normal']

        在獲取了可用的環境名后,你可以通過get_window_data方法來獲取指定env、指定win下的圖像數據。請注意,該方法返回str,故需要通過json來解析:

        from visdom import Visdom
        import json

        vis = Visdom()

        window = vis.get_window_data(win="1.wav", env="normal")    
        window = json.loads(window)         # window 是 str,需要解析為字典

        content = window["content"]
        data = content["data"][0]
        print(data.keys())

        out:

        Setting up a new session...
        dict_keys(['z''x''y''zmin''zmax''type''colorscale'])

        通過索引這些keys,相信想復用原本的圖像數據并不困難。


        本文僅做學術分享,如有侵權,請聯系刪文。


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 疏附县| 虹口区| 铁岭市| 太仓市| 绥宁县| 邯郸市| 资溪县| 夹江县| 尚志市| 莲花县| 镇远县| 英超| 东丽区| 清丰县| 厦门市| 双辽市| 萝北县| 高雄县| 武冈市| 乾安县| 开江县| 夏津县| 彩票| 东明县| 新营市| 靖江市| 宁晋县| 濮阳县| 福泉市| 文安县| 将乐县| 凯里市| 大石桥市| 乐昌市| 高台县| 稷山县| 陈巴尔虎旗| 大厂| 永济市| 杭州市| 阳高县|