博客專欄

        EEPW首頁 > 博客 > 熱文 | 卷積神經網絡入門案例,輕松實現花朵分類(1)

        熱文 | 卷積神經網絡入門案例,輕松實現花朵分類(1)

        發布人:AI科技大本營 時間:2021-05-15 來源:工程師 發布文章

        前言

        本文介紹卷積神經網絡的入門案例,通過搭建和訓練一個模型,來對幾種常見的花朵進行識別分類;使用到TF的花朵數據集,它包含5類,即:“雛菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 張彩色圖片;通過搭建和訓練卷積神經網絡模型,對圖像進行分類,能識別出圖像是“蒲公英”,或“玫瑰”,還是其它。

        1.png

        本篇文章主要的意義是帶大家熟悉卷積神經網絡的開發流程,包括數據集處理、搭建模型、訓練模型、使用模型等;更重要的是解在訓練模型時遇到“過擬合”,如何解決這個問題,從而得到“泛化”更好的模型。

        思路流程

        • 導入數據集

        • 探索集數據,并進行數據預處理

        • 構建模型(搭建神經網絡結構、編譯模型)

        • 訓練模型(把數據輸入模型、評估準確性、作出預測、驗證預測)  

        • 使用訓練好的模型

        • 優化模型、重新構建模型、訓練模型、使用模型

        目錄

        • 導入數據集

        • 探索集數據,并進行數據預處理

        • 構建模型

        • 訓練模型

        • 使用模型

        • 優化模型、重新構建模型、訓練模型、使用模型(過擬合、數據增強、正則化、重新編譯和訓練模型、預測新數據)

        導入數據集

        使用到TF的花朵數據集,它包含5類,即:“雛菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 張彩色圖片;數據集包含5個子目錄,每個子目錄種存放一個類別的花朵圖片。

        # 下載數據集
        import pathlib
        dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
        data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
        data_dir = pathlib.Path(data_dir)
        # 查看數據集圖片的總數量
        image_count = len(list(data_dir.glob('*/*.jpg')))
        print(image_count)

        探索集數據,并進行數據預處理

        查看一張郁金香的圖片: 

        # 查看郁金香tulips目錄下的第1張圖片;
        tulips = list(data_dir.glob('tulips/*'))
        PIL.Image.open(str(tulips[0]))

        2.png

        加載數據集的圖片,使用keras.preprocessing從磁盤上加載這些圖像。

        # 定義加載圖片的一些參數,包括:批量大小、圖像高度、圖像寬度
        batch_size = 32
        img_height = 180
        img_width = 180
        # 將80%的圖像用于訓練
        train_ds = tf.keras.preprocessing.image_dataset_from_directory(
          data_dir,
          validation_split=0.2,
          subset="training",
          seed=123,
          image_size=(img_height, img_width),
          batch_size=batch_size)
        # 將20%的圖像用于驗證
        val_ds = tf.keras.preprocessing.image_dataset_from_directory(
          data_dir,
          validation_split=0.2,
          subset="validation",
          seed=123,
          image_size=(img_height, img_width),
          batch_size=batch_size)
        # 打印數據集中花朵的類別名稱,字母順序對應于目錄名稱
        class_names = train_ds.class_names
        print(class_names)

        查看一下訓練數據集中的9張圖像

        # 查看一下訓練數據集中的9張圖像
        import matplotlib.pyplot as plt
        plt.figure(figsize=(10, 10))
        for images, labels in train_ds.take(1):
          for i in range(9):
            ax = plt.subplot(3, 3, i + 1)
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")

        圖像形狀

        傳遞這些數據集來訓練模型model.fit,可以手動遍歷數據集并檢索成批圖像:

        for image_batch, labels_batch in train_ds:
          print(image_batch.shape)
          print(labels_batch.shape)
          break

        能看到輸出:(32, 180, 180, 3)   (32,)

        image_batch是圖片形狀的張量(32, 180, 180, 3)。32是指批量大??;180,180分別表示圖像的高度、寬度,3是顏色通道RGB。32張圖片組成一個批次。

        label_batch是形狀的張量(32,),對應32張圖片的標簽。

        數據集預處理

        下面進行數據集預處理,將像素的值標準化至0到1的區間內:

        # 將像素的值標準化至0到1的區間內。
        normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

        為什么是除以255呢?由于圖片的像素范圍是0~255,我們把它變成0~1的范圍,于是每張圖像(訓練集、測試集)都除以255。

        標準化數據

        # 調用map將其應用于數據集:
        normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
        image_batch, labels_batch = next(iter(normalized_ds))
        first_image = image_batch[0]
        # Notice the pixels values are now in `[0,1]`.
        print(np.min(first_image), np.max(first_image))


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: 深度學習

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 蒙山县| 丰都县| 合山市| 凌源市| 淮北市| 泊头市| 涞源县| 郴州市| 桂阳县| 定结县| 抚顺市| 新密市| 凌云县| 阿坝县| 宁城县| 仪征市| 鄂州市| 莱州市| 即墨市| 万安县| 繁昌县| 平陆县| 黑河市| 股票| 湟源县| 迁西县| 顺义区| 疏勒县| 广州市| 双城市| 沙湾县| 平邑县| 衡水市| 弋阳县| 九龙坡区| 鸡泽县| 绥棱县| 龙门县| 雅江县| 丘北县| 名山县|