博客專欄

        EEPW首頁 > 博客 > 換臉火了,我用 python 快速入門生成模型

        換臉火了,我用 python 快速入門生成模型

        發布人:AI科技大本營 時間:2021-04-28 來源:工程師 發布文章

        引言:

        近幾年來,GAN生成對抗式應用十分火熱,不論是抖音上大火的“螞蟻牙黑”還是B站上的“復原老舊照片”以及換臉等功能,都是基于GAN生成對抗式的模型。但是GAN算法對于大多數而言上手較難,故今天我們將使用最少的代碼,簡單入門“生成對抗式網絡”,實現用GAN生成數字。

        其中生成的圖片效果如下可見:

        1.png

        01 模型建立

        1.1 環境要求

        本次環境使用的是python3.6.5+windows平臺

        主要用的庫有:

        • OS模塊用來對本地文件讀寫刪除、查找到等文件操作

        • numpy模塊用來矩陣和數據的運算處理,其中也包括和深度學習框架之間的交互等

        • Keras模塊是一個由Python編寫的開源人工神經網絡庫,可以作為Tensorflow、Microsoft-CNTK和Theano的高階應用程序接口,進行深度學習模型的設計、調試、評估、應用和可視化 。在這里我們用來搭建網絡層和直接讀取數據集操作,簡單方便

        • Matplotlib模塊用來可視化訓練效果等數據圖的制作

        1.2 GAN簡單介紹

        GAN 由生成器 (Generator)和判別器 (Discriminator) 兩個網絡模型組成,這兩個模型作用并不相同,而是相互對抗。我們可以很簡單的理解成,Generator是造假的的人,Discriminator是負責鑒寶的人。正是因為生成模型和對抗模型的相互對抗關系才稱之為生成對抗式。

        那我們為什么不適用VAE去生成模型呢,又怎么知道GAN生成的圖片會比VAE生成的更優呢?問題就在于VAE模型作用是使得生成效果越相似越好,但事實上僅僅是相似卻只是依葫蘆畫瓢。而 GAN 是通過 discriminator 來生成目標,而不是像 VAE線性般的學習。

        這個項目里我們目標是訓練神經網絡生成新的圖像,這些圖像與數據集中包含的圖像盡可能相近,而不是簡單的復制粘貼。神經網絡學習什么是圖像的“本質”,然后能夠從一個隨機的數字數組開始創建它。其主要思想是讓兩個獨立的神經網絡,一個產生器和一個鑒別器,相互競爭。生成器會創建與數據集中的圖片盡可能相似的新圖像。判別器試圖了解它們是原始圖片還是合成圖片。

        2.gif

        1.3 模型初始化

        在這里我們初始化需要使用到的變量,以及優化器、對抗式模型等。

        def __init__(self, width=28, height=28, channels=1):
            self.width = width
            self.height = height
            self.channels = channels
            self.shape = (self.width, self.height, self.channels)
            self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)
            self.G = self.__generator()
            self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)
            self.D = self.__discriminator()
            self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
            self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
            self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)

        1.4 生成器模型的搭建

        這里我們盡可能簡單的搭建一個生成器模型,3個完全連接的層,使用sequential標準化。神經元數分別是256,512,1024等:

         

        def __generator(self):
                """ Declare generator """
                model = Sequential()
                model.add(Dense(256, input_shape=(100,)))
                model.add(LeakyReLU(alpha=0.2))
                model.add(BatchNormalization(momentum=0.8))
                model.add(Dense(512))
                model.add(LeakyReLU(alpha=0.2))
                model.add(BatchNormalization(momentum=0.8))
                model.add(Dense(1024))
                model.add(LeakyReLU(alpha=0.2))
                model.add(BatchNormalization(momentum=0.8))
                model.add(Dense(self.width  * self.height * self.channels, activation='tanh'))
                model.add(Reshape((self.width, self.height, self.channels)))
                return model

        1.5 判別器模型的搭建

        在這里同樣簡單搭建判別器網絡層,和生成器模型類似:

        def __discriminator(self):
            """ Declare discriminator """
            model = Sequential()
            model.add(Flatten(input_shape=self.shape))
            model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(np.int64((self.width * self.height * self.channels)/2)))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dense(1, activation='sigmoid'))
            model.summary()
            return model

        1.6 對抗式模型的搭建

        這里是較為難理解的部分。讓我們創建一個對抗性模型,簡單來說這只是一個后面跟著一個鑒別器的生成器。注意,在這里鑒別器的權重被凍結了,所以當我們訓練這個模型時,生成器層將不受影響,只是向上傳遞梯度。代碼很簡單如下:

        def __stacked_generator_discriminator(self):
            self.D.trainable = False
            model = Sequential()
            model.add(self.G)
            model.add(self.D)
            return model

        02 模型的訓練使用

        2.1 模型的訓練

        在這里,我們并沒有直接去訓練生成器。而是通過對抗性模型間接地訓練它。我們將噪聲傳遞給了對抗模型,并將所有從數據庫中獲取的圖像標記為負標簽,而它們將由生成器生成。

        對真實圖像進行預先訓練的鑒別器把不能合成的圖像標記為真實圖像,所犯的錯誤將導致由損失函數計算出的損失越來越高。這就是反向傳播發揮作用的地方。由于鑒別器的參數是凍結的,在這種情況下,反向傳播不會影響它們。相反,它會影響生成器的參數。所以優化對抗性模型的損失函數意味著使生成的圖像盡可能的相似,鑒別器將識別為真實的。這既是生成對抗式的神奇之處!

        故訓練階段結束時,我們的目標是對抗性模型的損失值很小,而鑒別器的誤差盡可能高,這意味著它不再能夠分辨出差異。

        最終在我門的訓練結束時,鑒別器損失約為0.73。考慮到我們給它輸入了50%的真實圖像和50%的合成圖像,這意味著它有時無法識別假圖像。這是一個很好的結果,考慮到這個例子絕對不是優化的結果。要知道確切的百分比,我可以在編譯時添加一個精度指標,這樣它可能得到很多更好的結果實現更復雜的結構的生成器和判別器。

        代碼如下,這里legit_images是指原始訓練的圖像,而syntetic_images是生成的圖像。:

        def train(self, X_train, epochs=20000, batch = 32, save_interval = 100):
            for cnt in range(epochs):
                ## train discriminator
                random_index = np.random.randint(0, len(X_train) - np.int64(batch/2))
                legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels)
                gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100))
                syntetic_images = self.G.predict(gen_noise)
                x_combined_batch = np.concatenate((legit_images, syntetic_images))
                y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1))))
                d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
                # train generator
                noise = np.random.normal(0, 1, (batch, 100))
                y_mislabled = np.ones((batch, 1))
                g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled)
                print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
                if cnt % save_interval == 0:
                    self.plot_images(save2file=True, step=cnt)

        2.2 可視化

        使用matplotlib來可視化模型訓練效果。

        def plot_images(self, save2file=False, samples=16, step=0):
            ''' Plot and generated images '''
            if not os.path.exists("./images"):
                os.makedirs("./images")
            filename = "./images/mnist_%d.png" % step
            noise = np.random.normal(0, 1, (samples, 100))
            images = self.G.predict(noise)
            plt.figure(figsize=(10, 10))
            for i in range(images.shape[0]):
                plt.subplot(4, 4, i+1)
                image = images[i, :, :, :]
                image = np.reshape(image, [self.height, self.width])
                plt.imshow(image, cmap='gray')
                plt.axis('off')
            plt.tight_layout()
            if save2file:
                plt.savefig(filename)
                plt.close('all')
            else:
                plt.show()

        3.png

        03 使用方法

        考慮到代碼較少,下述代碼復制粘貼即可運行。

        # -*- coding: utf-8 -*-
        import os
        import numpy as np
        from IPython.core.debugger import Tracer
        from keras.datasets import mnist
        from keras.layers import Input, Dense, Reshape, Flatten, Dropout
        from keras.layers import BatchNormalization
        from keras.layers.advanced_activations import LeakyReLU
        from keras.models import Sequential
        from keras.optimizers import Adam
        import matplotlib.pyplot as plt
        plt.switch_backend('agg')   # allows code to run without a system DISPLAY
        class GAN(object):
            """ Generative Adversarial Network class """
            def __init__(self, width=28, height=28, channels=1):
                self.width = width
                self.height = height
                self.channels = channels
                self.shape = (self.width, self.height, self.channels)
                self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)
                self.G = self.__generator()
                self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)
                self.D = self.__discriminator()
                self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
                self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
                self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)
            def __generator(self):
                """ Declare generator """
                model = Sequential()
                model.add(Dense(256, input_shape=(100,)))
                model.add(LeakyReLU(alpha=0.2))
                model.add(BatchNormalization(momentum=0.8))
                model.add(Dense(512))
                model.add(LeakyReLU(alpha=0.2))
                model.add(BatchNormalization(momentum=0.8))
                model.add(Dense(1024))
                model.add(LeakyReLU(alpha=0.2))
                model.add(BatchNormalization(momentum=0.8))
                model.add(Dense(self.width  * self.height * self.channels, activation='tanh'))
                model.add(Reshape((self.width, self.height, self.channels)))
                return model
            def __discriminator(self):
                """ Declare discriminator """
                model = Sequential()
                model.add(Flatten(input_shape=self.shape))
                model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape))
                model.add(LeakyReLU(alpha=0.2))
                model.add(Dense(np.int64((self.width * self.height * self.channels)/2)))
                model.add(LeakyReLU(alpha=0.2))
                model.add(Dense(1, activation='sigmoid'))
                model.summary()
                return model
            def __stacked_generator_discriminator(self):
                self.D.trainable = False
                model = Sequential()
                model.add(self.G)
                model.add(self.D)
                return model
            def train(self, X_train, epochs=20000, batch = 32, save_interval = 100):
                for cnt in range(epochs):
                    ## train discriminator
                    random_index = np.random.randint(0, len(X_train) - np.int64(batch/2))
                    legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels)
                    gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100))
                    syntetic_images = self.G.predict(gen_noise)
                    x_combined_batch = np.concatenate((legit_images, syntetic_images))
                    y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1))))
                    d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
                    # train generator
                    noise = np.random.normal(0, 1, (batch, 100))
                    y_mislabled = np.ones((batch, 1))
                    g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled)
                    print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
                    if cnt % save_interval == 0:
                        self.plot_images(save2file=True, step=cnt)
            def plot_images(self, save2file=False, samples=16, step=0):
                ''' Plot and generated images '''
                if not os.path.exists("./images"):
                    os.makedirs("./images")
                filename = "./images/mnist_%d.png" % step
                noise = np.random.normal(0, 1, (samples, 100))
                images = self.G.predict(noise)
                plt.figure(figsize=(10, 10))
                for i in range(images.shape[0]):
                    plt.subplot(4, 4, i+1)
                    image = images[i, :, :, :]
                    image = np.reshape(image, [self.height, self.width])
                    plt.imshow(image, cmap='gray')
                    plt.axis('off')
                plt.tight_layout()
                if save2file:
                    plt.savefig(filename)
                    plt.close('all')
                else:
                    plt.show()
        if __name__ == '__main__':
            (X_train, _), (_, _) = mnist.load_data()
            # Rescale -1 to 1
            X_train = (X_train.astype(np.float32) - 127.5) / 127.5
            X_train = np.expand_dims(X_train, axis=3)
            gan = GAN()
            gan.train(X_train)


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: Python

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 当雄县| 尚义县| 灵台县| 临颍县| 合江县| 嫩江县| 虎林市| 泸西县| 浙江省| 枣阳市| 泸定县| 调兵山市| 石景山区| 临汾市| 衡山县| 百色市| 田东县| 霍邱县| 简阳市| 西华县| 聊城市| 上杭县| 同江市| 黔西| 阳朔县| 富阳市| 巴马| 丰都县| 卢湾区| 酒泉市| 宿迁市| 华亭县| 杭州市| 忻城县| 南平市| 陆丰市| 布尔津县| 乐清市| 东乡县| 西充县| 江川县|