博客專欄

        EEPW首頁 > 博客 > 混合密度網絡(MDN)進行多元回歸詳解和代碼示例(1)

        混合密度網絡(MDN)進行多元回歸詳解和代碼示例(1)

        發布人:數據派THU 時間:2022-03-13 來源:工程師 發布文章
        來源:Deephub Imba


        回歸


        “回歸預測建模是逼近從輸入變量 (X) 到連續輸出變量 (y) 的映射函數 (f) [...] 回歸問題需要預測具體的數值。具有多個輸入變量的問題通常被稱為多元回歸問題 例如,預測房屋價值,可能在 100,000 美元到 200,000 美元之間
        這是另一個區分分類問題和回歸問題的視覺解釋如下:
        圖片
        另外一個例子

        圖片

        密度


        DENSITY “密度” 是什么意思?這是一個快速的通俗示例:
        假設正在為必勝客運送比薩。現在記錄剛剛進行的每次交付的時間(以分鐘為單位)。交付 1000 次后,將數據可視化以查看工作表現如何。這是結果:圖片
        這是披薩交付時間數據分布的“密度”。平均而言,每次交付需要 30 分鐘(圖中的峰值)。它還表示,在 95% 的情況下(2 個標準差2sd ),交付需要 20 到 40 分鐘才能完成。密度種類代表時間結果的“頻率”。“頻率”和“密度”的區別在于:

        • 頻率:如果你在這條曲線下繪制一個直方圖并對所有的 bin 進行計數,它將求和為任何整數(取決于數據集中捕獲的觀察總數)。

        • 密度:如果你在這條曲線下繪制一個直方圖并計算所有的 bin,它總和為 1。我們也可以將此曲線稱為概率密度函數 (pdf)。

        • 用統計術語來說,這是一個漂亮的正態/高斯分布。這個正態分布有兩個參數:


        均值


        • 標準差:“標準差是一個數字,用于說明一組測量值如何從平均值(平均值)或預期值中展開。低標準偏差意味著大多數數字接近平均值。高標準差意味著數字更加分散。“


        均值和標準差的變化會影響分布的形狀。例如:
        圖片
        有許多具有不同類型參數的各種不同分布類型。例如:

        圖片

        混合密度


        現在讓我們看看這 3 個分布:
        圖片
        如果我們采用這種雙峰分布(也稱為一般分布):

        圖片
        混合密度網絡使用這樣的假設,即任何像這種雙峰分布的一般分布都可以分解為正態分布的混合(該混合也可以與其他類型的分布一起定制 例如拉普拉斯):
        圖片

        網絡架構


        混合密度網絡也是一種人工神經網絡。這是神經網絡的經典示例:
        圖片
        輸入層(黃色)、隱藏層(綠色)和輸出層(紅色)。
        如果我們將神經網絡的目標定義為學習在給定一些輸入特征的情況下輸出連續值。在上面的例子中,給定年齡、性別、教育程度和其他特征,那么神經網絡就可以進行回歸的運算。
        圖片


        密度網絡


        圖片
        密度網絡也是神經網絡,其目標不是簡單地學習輸出單個連續值,而是學習在給定一些輸入特征的情況下輸出分布參數(此處為均值和標準差)。在上面的例子中,給定年齡、性別、教育程度等特征,神經網絡學習預測期望工資分布的均值和標準差。預測分布比預測單個值具有很多的優勢,例如能夠給出預測的不確定性邊界。這是解決回歸問題的“貝葉斯”方法。下面是預測每個預期連續值的分布的一個很好的例子:

        圖片
        下面的圖片向我們展示了每個預測實例的預期值分布:

        圖片


        混合密度網絡


        最后回到正題,混合密度網絡的目標是在給定特定輸入特征的情況下,學習輸出混合在一般分布中的所有分布的參數(此處為均值、標準差和 Pi)。新參數“Pi”是混合參數,它給出最終混合中給定分布的權重/概率。
        圖片
        最終結果如下:

        圖片

        示例1:單變量數據的 MDN 類


        上面的定義和理論基礎已經介紹完畢,下面我們開始代碼的演示:

        import numpy as np
        import pandas as pd

        from mdn_model import MDN

        from sklearn.datasets import make_moons
        from sklearn.preprocessing import StandardScaler
        import matplotlib.pyplot as plt
        import seaborn as sns

        from sklearn.linear_model import LinearRegression
        from sklearn.kernel_ridge import KernelRidge

        plt.style.use('ggplot')


        生成著名的“半月”型的數據集:

        X, y = make_moons(n_samples=2500, noise=0.03)
        y = X[:, 1].reshape(-1,1)
        X = X[:, 0].reshape(-1,1)

        x_scaler = StandardScaler()
        y_scaler = StandardScaler()

        X = x_scaler.fit_transform(X)
        y = y_scaler.fit_transform(y)

        plt.scatter(X, y, alpha = 0.3)

        圖片


        繪制目標值 (y) 的密度分布:

        sns.kdeplot(y.ravel(), shade=True)

        通過查看數據,我們可以看到有兩個重疊的簇:


        圖片


        這時一個很好的多模態分布(一般分布)。如果我們在這個數據集上嘗試一個標準的線性回歸來用 X 預測 y:

        model = LinearRegression()
        model.fit(X.reshape(-1,1), y.reshape(-1,1))
        y_pred = model.predict(X.reshape(-1,1))

        plt.scatter(X, y, alpha = 0.3)
        plt.scatter(X,y_pred)
        plt.title('Linear Regression')

        圖片

        sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'Linear Pred dist')      
        sns.kdeplot(y.ravel(), shade=True, label = 'True dist')

        圖片


        效果必須不好!現在讓嘗試一個非線性模型(徑向基函數核嶺回歸):


        model = KernelRidge(kernel = 'rbf')
        model.fit(X, y)
        y_pred = model.predict(X)


        plt.scatter(X, y, alpha = 0.3)
        plt.scatter(X,y_pred)
        plt.title('Non Linear Regression')

        圖片

        sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'NonLinear Pred dist')      
        sns.kdeplot(y.ravel(), shade=True, label = 'True dist')

        圖片
        雖然結果也不盡如人意,但是比上面的線性回歸要好很多了。


        *博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。



        關鍵詞: AI

        相關推薦

        技術專區

        關閉
        主站蜘蛛池模板: 寻甸| 东阿县| 延津县| 沧州市| 库尔勒市| 迁安市| 南丹县| 新干县| 凤冈县| 天台县| 三亚市| 景泰县| 盱眙县| 乐陵市| 凤阳县| 嘉祥县| 项城市| 广德县| 涞水县| 瓦房店市| 原平市| 寻甸| 合作市| 荥经县| 崇义县| 沂源县| 晋中市| 彝良县| 南丰县| 西乌| 蒙山县| 晋城| 宜都市| 两当县| 漠河县| 清远市| 五家渠市| 南阳市| 凯里市| 临汾市| 宁海县|