從神經網路視角看均方誤差與交叉熵作為損失函數時的共同點

来源:http://www.cnblogs.com/dengdan890730/archive/2016/05/31/5545616.html
-Advertisement-
Play Games

以均方誤差或交叉熵誤差作為loss function的NN, 其輸出神經元的敏感度是它的激活值與目標值的差值 ...


縮寫:

  • NN: neural network, 神經網路
  • MSE: Mean Squared Error, 均方誤差
  • CEE: Cross Entropy Error, 交叉熵誤差.(此縮寫不是一個conventional縮寫)

標記符號:

  • \(net\)\(net_i\), 凈輸出值, \(net = w^Tx\)
  • \(a\)\(a_i\), 神經元的激活函數輸出值: \(a = f(net)\)

本文所有的\(x\)都是增廣後的, 即\(x_0 = 1\).


Introduction

MSE與CEE是兩種常用的loss function, 它們在形式上很不一樣, 但在使用梯度下降演算法優化它們的loss function時, 會發現它們其實是殊途同歸.

很多機器學習演算法都可以轉換成淺層神經網路模型(本文中特指全連接的MLP). 而神經網路的BP演算法最核心的一步就是計算敏感度(見BP), 採用不同損失函數和激活函數的NN在BP演算法上的差異也主要存在於敏感度上. 所以將有監督機器學習演算法轉化為神經網路模型後, 只需要計算出輸出神經元的敏感度就可以看出MSE與CEE之間的很多異同點.

在利用mini-batch SGD訓練神經網路時, 通常是先計算批次中每一個樣本產生的梯度, 然後取平均值. 所以接下來的分析中, 只關註單個訓練樣本產生的loss. 根據這個loss計算敏感度.

使用MSE的典型代表是線性回歸, 使用CEE的代表則是邏輯回歸. 這兩個演算法的一些相同點與不同點可以參考blog.

問題描述:

  • 給定:
    • 訓練集\(D = \{(x^{(1)}, y^{(1)}), \dots, (x^{(i)}, y^{(i)}), \dots, (x^{(N)}, y^{(N)})\}\), \(x^{(N)} \in \chi\), \(\chi : R^d\), \(y^{(i)}\in R\)
    • model family \(f(x)\)
  • 目標: 利用\(D\)學習一個具體的\(f(x)\)用於對新樣本\(x'\)進行預測:\(y' = f(x')\)

註意, 線性回歸的\(f(x)與y\)取的是連續值, 而邏輯回歸則是代表類別的離散值.

均方誤差---線性回歸

線性回歸使用均方誤差(Mean Squared Error, MSE)作為loss function.
將線性回歸問題\(f(x) = w^Tx\)轉換成神經網路模型:

  • 輸入層: \(d\)個神經元, \(d\)\(x\)的維度.
  • 輸出層: \(1\)個神經元, 激活函數為identical, 即\(a = net = w^Tx\).
  • 隱層: 無

在樣本\((x, y)\)上的損失:
\[ J(w) = \frac 12 (a - y)^2 = \frac 12 (net - y)^2 \]
輸出神經元的敏感度:
\[ \delta = \frac {\partial J}{\partial net} = a - y = net - y \]

交叉熵---邏輯回歸

邏輯回歸使用最大似然方法估計參數.

二分類邏輯回歸

先說二分類邏輯回歸, 即\(y = \{0, 1\}\). 將它轉換成神經網路模型, 拓撲結構與線性回歸一致. 不同的是輸入神經元的激活函數為\(a = sigmoid(net)\). 把\(a\)看作\(y=1\)的概率值: \(P(y =1 | x) = a\). 分類依據是根據選擇的閾值, 例如\(0.5\), 當\(a\)不小於它時\(y=1\), 否則\(y = 0\).
樣本\((x, y)\)出現的概率, 即likelihood function:
\[ l(w) = a^y(1-a)^(1-y) \]
log-likelihood:
\[ L(w) = ln l(w) = ylna + (1-y)ln(1-a) \]
最大化\(L(w)\)就是最小化\(-L(w)\), 所以它的loss為:
\[ J(w) = - L(w) = -ylna - (1 - y) ln(1 - a) \]
這實際上就是二分類問題的交叉熵loss. 如blog所示, 當\(a=0.5\)時, loss最大.
輸出神經元的敏感度:
\[ \delta = \frac {\partial J}{\partial net} = \frac {\partial J} {\partial a} \frac {\partial a} {\partial net} = \frac {a-y}{(1-a)a} (1-a)a = a - y \]
相信你已經看出來了, 線性回歸NN的敏感度\(net - y\)實際上也是激活值與目標值的差. 也就是說, 雖然邏輯回歸與線性回歸使用了不同的loss function, 但它們倆反向傳播的敏感度在形式上是一致的, 都是激活值\(a\)與目標值\(y\)的差值.

多分類邏輯回歸

先將多分類邏輯回歸轉換成神經網路模型:

  • 輸入層: 同上
  • 輸出層: 有多少種類別, 就有多少個輸出神經元. 用\(C\)來表示類別數目, 所以輸出層有\(C\)個神經元. 激活函數為softmax函數. 輸出值和二分類邏輯回歸一樣被當成概率作為分類依據.
  • 隱層: 無

依然只考慮單個樣本\((x, y)\).
\(y\)的預測值\(f(x)\)為輸出值最大的那個神經元代表的類別, 即:
\[ f(x) = arg\max_i a_i, \forall i \in \{1,\dots, C\} \]
而第\(i\)個輸出神經元的激活值為:
\[ a_i = \frac {e^{net_i}}{\sum_{j=1}^N e^{net^j}} \]
它代表\(x\)的類別為\(i\)的概率.
為方便寫出它的似然函數, 先對\(y\)變成一個向量:
\[ y \gets (y_1, \dots, y_i, \dots, y_C)^T \]
其中,
\[ y_i = \begin{cases} 1&, i = y \\ 0&, i \neq y \end{cases} \]
它實際上代表第\(i\)個神經元的目標值.
所以樣本\((x,y)\)出現的概率, 即它的似然函數為:
\[ l(W) = \prod_{i=1}^{C} a_i^{y_i} \]
註意, 這裡的權值\(W\)已經是一個\(C\times d\)的矩陣, 而不是一個列向量.
log似然函數:
\[ L(W) = ln l(W) = \sum_{i=1}^{C} y_i ln a_i \]
\(L(W)\)的長相也可以看出, 二分類的邏輯回歸只是多分類邏輯回歸的一種特殊形式. 也就是說, 二分類的邏輯回歸也可以轉換成有兩個輸出神經元的NN.
同樣的, 最大化\(L(w)\)就是最小化\(-L(w)\), 所以它的loss為:
\[ J(W) = -L(W) = - \sum_{j=1}^{C} y_j ln a_j \]
這是更一般化的交叉熵. 代入softmax函數, 即\(a_j = \frac {e^{net_j}}{\sum_{k=1}^C e^{net_k}}\), 得到:
\[ J(W) = \sum_{j=1}^{C} y_j (ln \sum_{k=1}^C e^{net_k} - net_j) \]
\(i\)個神經元的敏感度:
\[ \delta_i = \frac {\partial J}{\partial net_i} = \sum_{j=1}^C y_j (\frac {\sum_{k=1}^C e^{net_k} \frac {\partial net_k}{\partial net_i}}{\sum_{k=1}^C e^{net_k}} - \frac {\partial net_j}{\partial net_i}) = \sum_{j=1}^C y_j\frac{e^{net_i}}{\sum_{k=1}^C e^{net_k}} - \sum_{j=1}^C y_j \frac {\partial net_j}{\partial net_i} = a_i - y_i \]
很神奇的一幕又出現了. 上面說過, 把目標值向量化後, \(y_i = 0,1\)實際上代表第\(i\)個神經元的目標值. 所以, 在這裡, 輸出神經元的敏感度也是它的激活值與目標值的差值.

總結與討論

主要結論:

  • 以均方誤差或交叉熵誤差作為loss function的NN, 其輸出神經元的敏感度是它的激活值與目標值的差值

比較有用的by-product:

  • 很多機器學習演算法都可以轉換成淺層神經網路模型
  • softmax與sigmoid 函數的導數形式: \(s' = s(1-s)\)
  • 最大似然估計的loss function 是交叉熵
  • 深度學習中常用的softmax loss其實也是交叉熵.

您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 一,效果圖。 二,工程圖。 三,代碼。 RootViewController.h #import <UIKit/UIKit.h> @interface RootViewController : UIViewController <UITableViewDataSource,UITableViewDe ...
  • 命名有些錯誤,但功能實現,以後註意下命名規範 WJViewGroup.h #import <UIKit/UIKit.h> @interface WJViewGroup : UIView { NSInteger _width; NSInteger _height; } @property (nonat ...
  • 分區: (1).一種分區技術,可以在創建表時應用分區技術,將數據以分區形式保存。 (2).可以將巨型表或索引分割成相對較小的、可獨立管理的部分。 (3).表分區時必須為表中的每一條記錄指定所屬分區。 對錶進行分區優點: 增強可用性; 維護方便; 均衡I/O; 改善查詢性能。 創建分區表 分區方法:範 ...
  • db.集合名稱.remove({query}, justOne)query:過濾條件,可選justOne:是否只刪除查詢到的第一條數據,值為true或者1時,只刪除一條數據,預設為false,可選。 準備數據:把_id為1和2的age都變成28 1、使用兩個參數:刪除age=28的第一條數據 2、使 ...
  • 報錯內容是:SQL Server 阻止了對組件 'Ad Hoc Distributed Queries' 的 STATEMENT'OpenRowset/OpenDatasource' 的訪問,因為此組件已作為此伺服器安全配置的一部分而被關閉。系統管理員可以通過使用 sp_configure 啟用 ' ...
  • mariadb的查詢流程圖 select語句的從句分析順序:from(過濾表)-->where(過濾行)-->group by(分組)-->having(分組過濾)-->order by(排序)-- >select(選取欄位)-->limit(查詢限制)-->最終結果 DISTINCT: 數據去重 ...
  • 使用VBScript腳本從Excel文件中導入PowerDesigner的物理模型。 ...
  • 使用 mysqladmin 刪除資料庫 使用普通用戶登陸mysql伺服器,你可能需要特定的許可權來創建或者刪除 MySQL 資料庫。 所以我們這邊使用root用戶登錄,root用戶擁有最高許可權,可以使用 mysql mysqladmin 命令來創建資料庫。 在刪除資料庫過程中,務必要十分謹慎,因為在執 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...