TensorFlow框架(3)之MNIST機器學習入門

来源:http://www.cnblogs.com/huliangwen/archive/2017/08/30/7455382.html
-Advertisement-
Play Games

1. MNIST數據集 1.1 概述 Tensorflow框架載tensorflow.contrib.learn.python.learn.datasets包中提供多個機器學習的數據集。本節介紹的是MNIST數據集,其功能都定義在mnist.py模塊中。 MNIST是一個入門級的電腦視覺數據集,它 ...


1. MNIST數據集

1.1 概述

  Tensorflow框架載tensorflow.contrib.learn.python.learn.datasets包中提供多個機器學習的數據集。本節介紹的是MNIST數據集,其功能都定義在mnist.py模塊中。

MNIST是一個入門級的電腦視覺數據集,它包含各種手寫數字圖片:

圖 11

  它也包含每一張圖片對應的標簽,告訴我們這個是數字幾。比如,上面這四張圖片的標簽分別是5,0,4,1

1.2 載入

  有兩種方式可以獲取MNIST數據集:

1) 自動下載

  TensorFlow框架提供了一個函數:read_data_sets,該函數能夠實現自動下載的功能。如下所示的程式,就能夠自動下載數據集。

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

//由於input_data只是對read_data_sets進行了包裝,其什麼也沒有做,所以我們可以直接使用read_data_sets.

From tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

Mnist=read_data_sets("MNIST_data",one_hot=True)

2) 手動下載

  用戶也能夠手動下載數據集,然後向read_data_sets函數傳遞所在的本地目錄,如下所示:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/MNIST_data/", False, one_hot=True)

PS:

    MNIST數據集可以Yann LeCun's website進行下載,如所示是下載後的目錄:/tmp/MNIST_data/

圖 12

1.3 結構

1) 數據分類

  自動下載方式的數據集被分成如表 11所示的三部分。這樣的切分很重要,在機器學習模型設計時必須有一個單獨的測試數據集不用於訓練,而是用來評估這個模型的性能,從而更加容易把設計的模型推廣到其他數據集上(泛化)。

表 11

數據集

目的

mnist.train

55000 圖片和標簽, 用於訓練。

mnist.test

10000 圖片和標簽, 用於最終測試訓練的準確性。

mnist.validation

5000 圖片和標簽, 用於迭代驗證訓練的準確性。

PS:

  若是手動下載則只有兩部分,即表中的train和test兩部分。

2) 數據展開

  正如前面提到的一樣,每一個MNIST數據單元有兩部分組成:一張包含手寫數字的圖片和一個對應的標簽。我們把這些圖片設為"X",把這些標簽設為"Y"。訓練數據集和測試數據集都包含X和Y,比如訓練數據集的圖片是 mnist.train.images ,訓練數據集的標簽是 mnist.train.labels。

  其中每一張圖片包含28像素*28像素。我們可以用一個數字數組來表示這張圖片:

圖 13

  TensorFlow把這個數組展開成一個向量(數組),長度是 28x28 = 784,即TensorFlow將一個二維的數組展開成一個一維的數組,從[28, 28]數組轉換為[784]數組。因此,在MNIST訓練數據集中,mnist.train.images 是一個形狀為 [60000, 784] 的張量,第一個維度數字用來索引圖片,第二個維度數字用來索引每張圖片中的像素點。在此張量里的每一個元素,都表示某張圖片里的某個像素的灰度值,值介於0和1之間,如圖 14所示的二維結構。

圖 14

  相對應的MNIST數據集的標簽是介於0到9的數字,用來描述給定圖片里表示的數字。為了用於這個教程,我們使標簽數據是"one-hot vectors"。一個one-hot向量除了某一位的數字是1以外其餘各維度數字都是0。所以在此教程中,數字n將表示成一個只有在第n維度(從0開始)數字為1的10維向量。比如,標簽0將表示成([1,0,0,0,0,0,0,0,0,0,0])。因此,mnist.train.labels 是一個 [60000, 10] 的數字矩陣。

圖 15

2. MNIST分類學習

2.1 實現理論

2.1.1 M-P神經元模型

  傳統的M-P神經元模型中,每個神經元都接收來自n個其它神經元傳遞過來的輸入信號,這些輸入信號通過待權重的連接進行傳遞,神經元接收到的總輸入值將與神經元的閾值進行比較,然後通過"激活函數"處理以產生神經元的輸出。如圖 21所示的一個神經元模型。

圖 21

PS:

    圖中所示的值都為real value,並非為向量或矩陣。

2.1.2 softmax函數

  softmax函數與sigmoid函數類似都可以作為神經網路的激活函數。sigmoid將一個real value映射到(0,1)的區間(當然也可以是(-1,1)),這樣可以用來做二分類。

而softmax把一個k維的real value向量[a1,a2,a3,a4….]映射成一個[b1,b2,b3,b4….],其中bi是一個0-1的常數,然後可以根據bi的大小來進行多分類的任務,如取權重最大的一維。

所以對於MNIST分類任務是多分類類型,所以需要使用softmax函數作為神經網路的激活函數。

2.1.3 MNIST模型分析

  正如圖 14所分析的,輸入的訓練圖片或測試圖片為一個[60000, 784]的矩陣,每張圖片都是一個[784]的向量;輸出為一個[60000, 10]的矩陣,每張圖片都對應有一個[10]的向量標簽。所以對於每張圖片的輸入和每個標簽的輸出,其神經網路模型可表示為錯誤! 未找到引用源。所示的簡化版本,圖中所有值都為read value。

圖 22 前饋神經網路(一個帶有10個神經元的隱藏層)

 

  如果把它寫成一個等式,我們可以得到:

我們也可以用向量表示這個計算過程:用矩陣乘法和向量相加。這有助於提高計算效率。(也是一種更有效的思考方式)

更進一步,可以寫成更加緊湊的方式:

式中, B和Y都為一個[10]類型的向量,X為一個[784]類型的向量,W是一個[10,784]類型的矩陣。

 

2.2 TensorFlow實現

  對於機器學習中的監督學習任務可以分四個步驟完成,如下所示:

  1. 模型選擇:選擇一個estimator對象;
  2. 模型訓練:根據訓練數據集來訓練模型;
  3. 模型測試:測量模型的泛化能力,即對其評分;
  4. 模型應用:進行實際預測或應用。

2.2.1 模型選擇

  由於我們已經選擇神經網路為監督學習任務的模型,即式(3)所示的等式,我們可通過使用下標來表明等式中變數的維數,如下所示:

所以在TensorFlow中的實現,就需要定義相應的變數和等式。但是式(4)中的X是一個[784]的向量,而實際待訓練的輸入數據是一個[60000, 784] 的矩陣,所以需要對式(4)進行稍微的變形,使其滿足數據輸入和數據輸出的要求,即如下所示的等式:

  • X是輸入參數,為訓練數據,即多張圖像;
  • WB是未知參數,即通過神經網路來訓練的數據;
  • Y為輸出參數,為圖像標簽,將使用該值與已知標簽進行比較。

 

如下所示是TensorFlow的實現:

# Create the model

x = tf.placeholder("float", [None, 784])

W = tf.Variable(tf.zeros([784,10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

2.2.2 模型訓練

1) 模型評估

  我們可以創建一個模型(model),但我們仍然不知道模型的好壞。為了評估一個TensorFlow模型的性能,我們可以提供一個期望值,然後比較模型產生值和期望值之差來進行評估。

傳統方法採用"均分誤差"法評估一個模型的性能:,首先提供一個期望向量,然後對產生值(f(x))和期望值(y)兩個向量的每個元素進行取平方差,然後求出每個元素的總和。

    由於傳遞神經網路採用梯度下降法來逐漸調整式(4)中的W和B參數,即逐步減少均分誤差的值;然而若以"均分誤差"為標準逐步調整參數,其歸約的速度非常慢。所以提出以"交叉熵"法為標準評估模型的值,如下所示:

如下所示的TensorFlow實現:

# Define loss and optimizer

y_ = tf.placeholder("float", [None,10])

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

 

2) 訓練過程

  TensorFlow提供多個優化器來逐步優化模型,即逐步優化未知參數。優化器以用戶指定的評估的誤差為優化目標,即最小化模型評估的誤差,或最大化模型評估的誤差。

優化器基於梯度下降法自動修改神經網路的訓練參數,即W和b的值。

如下是以GradientDescentOptimizer優化器為示例的訓練過程:

# Train

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

 

2.2.3 模型測試

  為了評估模型的泛化性能,我們通過比較產生值(f(x))和期望值(y)之間的差異來進行評測性能。

由於本節的MNIST數據標簽(輸出值)是一個one-hot的便利,向量中的元素直郵一個為"1",所以使用特性的比較方式,如下所示是TensorFlow的實現:

# Test trained model

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

其中:

  • tf.argmax:能給出某個tensor對象在某一維上的其數據最大值所在的索引值。由於標簽向量是由0,1組成,因此最大值1所在的索引位置就是類別標簽。
  • tf.cast:類型轉換,將一個tensor對象的所有元素類型轉換為另一種類型。即上述將tf.equal方法生成的布爾值轉換成浮點數。
  • tf.reduce_mean:求矩陣或向量的平均值。若x=[[1., 1.] [2., 2.]],則tf.reduce_mean(x) ==> 1.5=1+1+2+2/4

 

上述三個小節的完整程式如下所示:

from __future__ import print_function

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

 

# Import data

mnist = input_data.read_data_sets("/tmp/MNIST_data/", False, one_hot=True)

 

# Create the model

x = tf.placeholder("float", [None, 784])

W = tf.Variable(tf.zeros([784,10]))

b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

 

# Define loss and optimizer

y_ = tf.placeholder("float", [None,10])

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

 

init = tf.initialize_all_variables()

sess = tf.Session()

sess.run(init)

 

# Train

for i in range(1000):

batch_xs, batch_ys = mnist.train.next_batch(100)

sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

 

# Test trained model

#下述y的值是在上述訓練最後一步已經計算獲得,所以能夠與原始標簽y_進行比較

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

3. 參考文獻

  1. TensorFlow中文社區
  2. sigmoidsoftmax總結
  3. 交叉熵代價函數(作用及公式推導)
  4.  

您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 1. jar包下載 下載地址:http://ormlite.com/releases/,一般用core和android包即可。 如果使用的是android studio,也可以直接通過module settings加入依賴。 2. 實體類 使用OrmLite創建表不需要寫任何SQL語句,而是通過創建 ...
  • Android中,子線程使用主線程中的組件出現問題的解決方法 ...
  • 資源使用 Android 中支持三種格式的點陣圖文件:.png(首選), .jpg(可接受),.gif(不建議) 為什麼首推 PNG 呢? 官網的描述如下: 註:在構建過程中,可通過 aapt 工具自動優化點陣圖文件,對圖像進行無損壓縮。例如,不需要超過 256 色的真彩色 PNG 可通過調色板轉換為 ...
  • 在開始之前,我們需要創建一個DrawRectView 其初始代碼為 在ViewController中使用(尺寸為100x100並居中) 顯示效果如下(用紅色邊框顯示邊界) 修改DrawRectView.m代碼如下 其實就添加了下麵的繪圖代碼而已,繪製7條線條,每條線條的寬度為0.5 效果如下 將圖片 ...
  • 官方鏈接: https://developer.apple.com/app-store/review/guidelines/cn/ 1.條款和條件 1.1為App Store開發程式,開發者必須遵守Program License Agreement (PLA)、人機交互指南(HIG)以及開發者和蘋果 ...
  • 手機歸屬地查詢 效果圖: 分析: 1、傳遞多個參數,用一個類就好 2、打開資料庫 private SQLiteDatabase database; database=SQLiteDatabase.openOrCreateDatabase(file, null); file是資料庫的路徑 3、在邏輯中 ...
  • 首先,需要添加com.android.support:percent:24.1.1 包,版本隨意。 } 這個包給我們提供了PercentRelativeLayout以及PercentFrameLayout兩種佈局, 支持的屬性有layout_widthPercent、layout_heightPer ...
  • eclipse中打字中文突然變成繁體 在用eclipse做android項目的時候,發現打出來的字全部是繁體,而且QQ等其他位置又是簡體。 原因:eclipse的快捷點ctrl+alt+f(format代碼) 和搜狗裡面的切換簡繁體的快捷鍵一樣了。 所以也會導致在eclipse中ctrl+alt+f ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...