博客專欄

        EEPW首頁 > 博客 > 用 Java 訓練深度學習模型,原來這么簡單

        用 Java 訓練深度學習模型,原來這么簡單

        發布人:AI科技大本營 時間:2020-11-10 來源:工程師 發布文章

        以下文章來源于HelloGitHub ,作者Keerthan&Lanking

        前言

        很長時間以來,Java 都是一個很受企業歡迎的編程語言。得益于豐富的生態以及完善維護的包和框架,Java 擁有著龐大的開發者社區。盡管深度學習應用的不斷演進和落地,提供給 Java 開發者的框架和庫卻十分短缺。現今主要流行的深度學習模型都是用 Python 編譯和訓練的。對于 Java 開發者而言,如果要進軍深度學習界,就需要重新學習并接受一門新的編程語言同時還要學習深度學習的復雜知識。這使得大部分 Java 開發者學習和轉型深度學習開發變得困難重重。

        為了減少 Java 開發者學習深度學習的成本,AWS 構建了 Deep Java Library (DJL),一個為 Java 開發者定制的開源深度學習框架。它為 Java 開發者對接主流深度學習框架提供了一個橋梁。

        1604994462100236.jpg

        在這篇文章中,我們會嘗試用 DJL 構建一個深度學習模型并用它訓練 MNIST 手寫數字識別任務。

        什么是深度學習?

        在我們正式開始之前,我們先來了解一下機器學習和深度學習的基本概念。

        機器學習是一個通過利用統計學知識,將數據輸入到計算機中進行訓練并完成特定目標任務的過程。這種歸納學習的方法可以讓計算機學習一些特征并進行一系列復雜的任務,比如識別照片中的物體。由于需要寫復雜的邏輯以及測量標準,這些任務在傳統計算科學領域中很難實現。

        深度學習是機器學習的一個分支,主要側重于對于人工神經網絡的開發。人工神經網絡是通過研究人腦如何學習和實現目標的過程中歸納而得出一套計算邏輯。它通過模擬部分人腦神經間信息傳遞的過程,從而實現各類復雜的任務。深度學習中的“深度”來源于我們會在人工神經網絡中編織構建出許多層(layer)從而進一步對數據信息進行更深層的傳導。深度學習技術應用范圍十分廣泛,現在被用來做目標檢測、動作識別、機器翻譯、語意分析等各類現實應用中。

        訓練 MNIST 手寫數字識別

        3.1 項目配置

        你可以用如下的 gradle 配置來引入依賴項。在這個案例中,我們用 DJL 的 api 包 (核心 DJL 組件) 和 basicdataset 包 (DJL 數據集) 來構建神經網絡和數據集。這個案例中我們使用了 MXNet 作為深度學習引擎,所以我們會引入 mxnet-engine 和 mxnet-native-auto 兩個包。這個案例也可以運行在 PyTorch 引擎下,只需要替換成對應的軟件包即可。

        plugins {

        id 'java'

        }

        repositories {                           

        jcenter()

        }

        dependencies {

        implementation platform("ai.djl:bom:0.8.0")

        implementation "ai.djl:api"

        implementation "ai.djl:basicdataset"

        // MXNet

        runtimeOnly "ai.djl.mxnet:mxnet-engine"

        runtimeOnly "ai.djl.mxnet:mxnet-native-auto"

        }

        3.2 NDArray 和 NDManager

        NDArray 是 DJL 存儲數據結構和數學運算的基本結構。一個 NDArray 表達了一個定長的多維數組。NDArray 的使用方法類似于 Python 中的 numpy.ndarray。

        NDManager 是 NDArray 的老板。它負責管理 NDArray 的產生和回收過程,這樣可以幫助我們更好的對 Java 內存進行優化。每一個 NDArray 都會是由一個 NDManager 創造出來,同時它們會在 NDManager 關閉時一同關閉。NDManager 和 NDArray 都是由 Java 的 AutoClosable 構建,這樣可以確保在運行結束時及時進行回收。

        Model

        在 DJL 中,訓練和推理都是從 Model class 開始構建的。我們在這里主要講訓練過程中的構建方法。下面我們為 Model 創建一個新的目標。因為 Model 也是繼承了 AutoClosable 結構體,我們會用一個 try block 實現:

        try (Model model = Model.newInstance()) {

            ...

        // 主體訓練代碼

            ...

        }

        準備數據

        MNIST(Modified National Institute of Standards and Technology)數據庫包含大量手寫數字的圖,通常被用來訓練圖像處理系統。DJL 已經將 MNIST 的數據集收錄到了 basicdataset 數據集里,每個 MNIST 的圖的大小是 28 x 28。如果你有自己的數據集,你也可以通過 DJL 數據集導入教程來導入數據集到你的訓練任務中。

        數據集導入教程: 

        http://docs.djl.ai/docs/development/how_to_use_dataset.html#how-to-create-your-own-dataset

        int batchSize = 32; // 批大小

        Mnist trainingDataset = Mnist.builder()

                .optUsage(Usage.TRAIN) // 訓練集

                .setSampling(batchSize, true)

                .build();

        Mnist validationDataset = Mnist.builder()

                .optUsage(Usage.TEST) // 驗證集

                .setSampling(batchSize, true)

                .build();

        這段代碼分別制作出了訓練和驗證集。同時我們也隨機排列了數據集從而更好的訓練。除了這些配置以外,你也可以添加對于圖片的進一步處理,比如設置圖片大小,對圖片進行歸一化等處理。

        制作 model(建立 Block)

        當你的數據集準備就緒后,我們就可以構建神經網絡了。在 DJL 中,神經網絡是由 Block(代碼塊)構成的。一個 Block 是一個具備多種神經網絡特性的結構。它們可以代表 一個操作, 神經網絡的一部分,甚至是一個完整的神經網絡。然后 Block 可以順序執行或者并行。同時 Block 本身也可以帶參數和子 Block。這種嵌套結構可以幫助我們構造一個復雜但又不失維護性的神經網絡。在訓練過程中,每個 Block 中附帶的參數會被實時更新,同時也包括它們的各個子 Block。這種遞歸更新的過程可以確保整個神經網絡得到充分訓練。

        當我們構建這些 Block 的過程中,最簡單的方式就是將它們一個一個的嵌套起來。直接使用準備好 DJL 的 Block 種類,我們就可以快速制作出各類神經網絡。

        根據幾種基本的神經網絡工作模式,我們提供了幾種 Block 的變體。SequentialBlock 是為了應對順序執行每一個子 Block 構造而成的。它會將前一個子 Block 的輸出作為下一個 Block 的輸入 繼續執行到底。與之對應的,是 ParallelBlock 它用于將一個輸入并行輸入到每一個子 Block 中,同時將輸出結果根據特定的合并方程合并起來。最后我們說一下 LambdaBlock,它是幫助用戶進行快速操作的一個 Block,其中并不具備任何參數,所以也沒有任何部分在訓練過程中更新。

        1604994548329472.jpg

        我們來嘗試創建一個基本的 多層感知機(MLP)神經網絡吧。多層感知機是一個簡單的前向型神經網絡,它只包含了幾個全連接層 (LinearBlock)。那么構建這個網絡,我們可以直接使用 SequentialBlock。

        int input = 28 * 28; // 輸入層大小

        int output = 10; // 輸出層大小

        int[] hidden = new int[] {128, 64}; // 隱藏層大小

        SequentialBlock sequentialBlock = new SequentialBlock();

        sequentialBlock.add(Blocks.batchFlattenBlock(input));

        for (int hiddenSize : hidden) {

        // 全連接層

            sequentialBlock.add(Linear.builder().setUnits(hiddenSize).build());

        // 激活函數

            sequentialBlock.add(activation);

        }

        sequentialBlock.add(Linear.builder().setUnits(output).build());

        當然 DJL 也提供了直接就可以拿來用的 MLP Block :

        Block block = new Mlp(

                Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,

                Mnist.NUM_CLASSES,

        new int[] {128, 64});

        訓練

        當我們準備好數據集和神經網絡之后,就可以開始訓練模型了。在深度學習中,一般會由下面幾步來完成一個訓練過程:

        1.jpg

        初始化:我們會對每一個 Block 的參數進行初始化,初始化每個參數的函數都是由 設定的 Initializer 決定的。

        前向傳播:這一步將輸入數據在神經網絡中逐層傳遞,然后產生輸出數據。

        計算損失:我們會根據特定的損失函數 Loss 來計算輸出和標記結果的偏差。

        反向傳播:在這一步中,你可以利用損失反向求導算出每一個參數的梯度。

        更新權重:我們會根據選擇的優化器(Optimizer)更新每一個在 Block 上參數的值。

        DJL 利用了 Trainer 結構體精簡了整個過程。開發者只需要創建 Trainer 并指定對應的 Initializer、Loss 和 Optimizer 即可。這些參數都是由 TrainingConfig 設定的。下面我們來看一下具體的參數設置:

        TrainingListener:這個是對訓練過程設定的監聽器。它可以實時反饋每個階段的訓練結果。這些結果可以用于記錄訓練過程或者幫助 debug 神經網絡訓練過程中的問題。用戶也可以定制自己的 TrainingListener 來對訓練過程進行監聽。

        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())

            .addEvaluator(new Accuracy())

            .addTrainingListeners(TrainingListener.Defaults.logging());

        try (Trainer trainer = model.newTrainer(config)){

        // 訓練代碼

        }

        當訓練器產生后,我們可以定義輸入的 Shape。之后就可以調用 fit 函數來進行訓練。fit 函數會對輸入數據,訓練多個 epoch 是并最終將結果存儲在本地目錄下。

        /*

         * MNIST 包含 28x28 灰度圖片并導入成 28 * 28 NDArray。

         * 第一個維度是批大小, 在這里我們設置批大小為 1 用于初始化。

         */

        Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);

        int numEpoch = 5;

        String outputDir = "/build/model";

        // 用輸入初始化 trainer

        trainer.initialize(inputShape);

        TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");

        這就是訓練過程的全部流程了!用 DJL 訓練是不是還是很輕松的?之后看一下輸出每一步的訓練結果。如果你用了我們默認的監聽器,那么輸出是類似于下圖:

        [INFO ] - Downloading libmxnet.dylib ...

        [INFO ] - Training on: cpu().

        [INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.

        Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/sec

        Validating:  100% |████████████████████████████████████████|

        [INFO ] - Epoch 1 finished.

        [INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24

        [INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14

        Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/sec

        Validating:  100% |████████████████████████████████████████|

        [INFO ] - Epoch 2 finished.NG [1m 41s]

        [INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10

        [INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09

        [INFO ] - train P50: 12.756 ms, P90: 21.044 ms

        [INFO ] - forward P50: 0.375 ms, P90: 0.607 ms

        [INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms

        [INFO ] - backward P50: 0.608 ms, P90: 0.973 ms

        [INFO ] - step P50: 0.543 ms, P90: 0.869 ms

        [INFO ] - epoch P50: 35.989 s, P90: 35.989 s

        當訓練結果完成后,我們可以用剛才的模型進行推理來識別手寫數字。如果剛才的內容哪里有不是很清楚的,可以參照下面兩個鏈接直接嘗試訓練。

        手寫數據集訓練:

        https://docs.djl.ai/examples/docs/train_mnist_mlp.html

        手寫數據集推理:

        https://docs.djl.ai/jupyter/tutorial/03_image_classification_with_your_model.html

        最后

        在這個文章中,我們介紹了深度學習的基本概念,同時還有如何優雅的利用 DJL 構建深度學習模型并進行訓練。DJL 也提供了更加多樣的數據集和神經網絡。

        Deep Java Library(DJL)是一個基于 Java 的深度學習框架,同時支持訓練以及推理。DJL 博取眾長,構建在多個深度學習框架之上 (TenserFlow、PyTorch、MXNet 等) 也同時具備多個框架的優良特性。你可以輕松使用 DJL 來進行訓練然后部署你的模型。

        它同時擁有著強大的模型庫支持:只需一行便可以輕松讀取各種預訓練的模型。現在 DJL 的模型庫同時支持高達 70 多個來自 GluonCV、 HuggingFace、TorchHub 以及 Keras 的模型。

        項目地址:https://github.com/awslabs/djl/

        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞:

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 瓮安县| 会宁县| 油尖旺区| 四川省| 香格里拉县| 宁陕县| 安达市| 六盘水市| 东光县| 黄山市| 奎屯市| 都昌县| 民权县| 神木县| 灵石县| 洛宁县| 大邑县| 乐昌市| 张家港市| 辰溪县| 霞浦县| 泰来县| 古交市| 中卫市| 天柱县| 绥宁县| 北辰区| 叶城县| 自贡市| 金溪县| 重庆市| 大余县| 绥德县| 汾西县| 连平县| 什邡市| 肃南| 河曲县| 孙吴县| 永兴县| 茌平县|