學習筆記TF024:TensorFlow實現Softmax Regression(回歸)識別手寫數字

来源:http://www.cnblogs.com/libinggen/archive/2017/07/09/7143669.html
-Advertisement-
Play Games

TensorFlow實現Softmax Regression(回歸)識別手寫數字。MNIST(Mixed National Institute of Standards and Technology database),簡單機器視覺數據集,28X28像素手寫數字,只有灰度值信息,空白部分為0,筆跡根 ...


TensorFlow實現Softmax Regression(回歸)識別手寫數字。MNIST(Mixed National Institute of Standards and Technology database),簡單機器視覺數據集,28X28像素手寫數字,只有灰度值信息,空白部分為0,筆跡根據顏色深淺取[0, 1], 784維,丟棄二維空間信息,目標分0~9共10類。數據載入,data.read_data_sets, 55000個樣本,測試集10000樣本,驗證集5000樣本。樣本標註信息,label,10維向量,10種類one-hot編碼。訓練集訓練模型,驗證集檢驗效果,測試集評測模型(準確率、召回率、F1-score)。

演算法設計,Softmax Regression訓練手寫數字識別分類模型,估算類別概率,取概率最大數字作模型輸出結果。類特征相加,判定類概率。模型學習訓練調整權值。softmax,各類特征計算exp函數,標準化(所有類別輸出概率值為1)。y = softmax(Wx+b)。

NumPy使用C、fortran,調用openblas、mkl矩陣運算庫。TensorFlow密集複雜運算在Python外執行。定義計算圖,運算操作不需要每次把運算完的數據傳回Python,全部在Python外面運行。

import tensor flow as tf,載入TensorFlow庫。less = tf.InteractiveSession(),創建InteractiveSession,註冊為預設session。不同session的數據、運算,相互獨立。x = tf.placeholder(tf.float32, [None,784]),創建Placeholder 接收輸入數據,第一參數數據類型,第二參數代表tensor shape 數據尺寸。None不限條數輸入,每條輸入為784維向量。

tensor存儲數據,一旦使用掉就會消失。Variable在模型訓練迭代中持久化,長期存在,每輪迭代更新。Softmax Regression模型的Variable對象weights、biases 初始化為0。模型訓練自動學習合適值。複雜網路,初始化方法重要。w = tf.Variable(tf.zeros([784, 10])),784特征維數,10類。Label,one-hot編碼後10維向量。

Softmax Regression演算法,y = tf.nn.softmax(tf.matmul(x, W) + b)。tf.nn包含大量神經網路組件。tf.matmul,矩陣乘法函數。TensorFlow將forward、backward內容自動實現,只要定義好loss,訓練自動求導梯度下降,完成Softmax Regression模型參數自動學習。

定義loss function描述問題模型分類精度。Loss越小,模型分類結果與真實值越小,越精確。模型初始參數全零,產生初始loss。訓練目標是減小loss,找到全局最優或局部最優解。cross-entropy,分類問題常用loss function。y預測概率分佈,y'真實概率分佈(Label one-hot編碼),判斷模型對真實概率分佈預測準確度。cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))。定義placeholder,輸入真實label。tf.reduce_sum求和,tf.reduce_mean每個batch數據結果求均值。

定義優化演算法,隨機梯度下降SGD(Stochastic Gradient Descent)。根據計算圖自動求導,根據反向傳播(Back Propagation)演算法訓練,每輪迭代更新參數減小loss。提供封裝優化器,每輪迭代feed數據,TensorFlow在後臺自動補充運算操作(Operation)實現反向傳播和梯度下降。train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)。調用tf.train.GradientDescentOptimizer,設置學習速度0.5,設定優化目標cross-entropy,得到訓練操作train_step。

tf.global_variables_initializer().run()。TensorFlow全局參數初始化器tf.golbal_variables_initializer。

batch_xs,batch_ys = mnist.train.next_batch(100)。訓練操作train_step。每次隨機從訓練集抽取100條樣本構成mini-batch,feed給 placeholder,調用train_step訓練樣本。使用小部分樣本訓練,隨機梯度下降,收斂速度更快。每次訓練全部樣本,計算量大,不容易跳出局部最優。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmzx(y_,1)),驗證模型準確率。tf.argmax從tensor尋找最大值序號,tf.argmax(y,1)求預測數字概率最大,tf.argmax(y_,1)找樣本真實數字類別。tf.equal判斷預測數字類別是否正確,返回計算分類操作是否正確。

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)),統計全部樣本預測正確度。tf.cast轉化correct_prediction輸出值類型。

print(accuracy.eval({x: mnist.test.images,y_: mnist.test.labels}))。測試數據特征、Label輸入評測流程,計算模型測試集準確率。Softmax Regression MNIST數據分類識別,測試集平均準確率92%左右。

TensorFlow 實現簡單機器演算法步驟:
1、定義演算法公式,神經網路forward計算。
2、定義loss,選定優化器,指定優化器優化loss。
3、迭代訓練數據。
4、測試集、驗證集評測準確率。

定義公式只是Computation Graph,只有調用run方法,feed數據,計算才執行。

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    print(mnist.train.images.shape, mnist.train.labels.shape)
    print(mnist.test.images.shape, mnist.test.labels.shape)
    print(mnist.validation.images.shape, mnist.validation.labels.shape)
    import tensorflow as tf
    sess = tf.InteractiveSession()
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    tf.global_variables_initializer().run()
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train_step.run({x: batch_xs, y_: batch_ys})
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

 


參考資料:
《TensorFlow實踐》

歡迎付費咨詢(150元每小時),我的微信:qingxingfengzi


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 一、準備 組件化 + 隨著業務需求的增長,在單工程 MVC 模式下,app 代碼逐漸變得龐大,面對的高耦合的代碼和複雜的功能模塊,我們或許就需要進行重構了,以組件化的形式,將需要的組件以 pod 私有庫的形式安裝到最後的主工程中,組件間各自獨立、解耦,僅依賴中間件進行通信,這或許就是極好的架構形式。 ...
  • 1. 下載MySQL Yum Repository http://dev.mysql.com/downloads/repo/yum/ 2. 本地安裝MySQL Yum Repository sudo yum localinstall platform-and-version-specific-pac ...
  • MySQL的簡單使用 1. 使用MySQL命令行工具 Windows 用戶使用: MySQL Client, 輸入密碼 Linux: mysql u用戶名 p密碼 mysql uroot p 2. 顯示資料庫命令 show databases; 3. 創建資料庫命令 create database ...
  • 錯誤截圖如下: 步驟1. 打開瀏覽器,輸入http://www.adobe.com/cn/ 步驟2. 點擊菜單,打開下拉的列表,找到並點擊Adobe Flash Player 步驟3. 把可選程式的勾“√”去掉,否則會安裝可選程式,然後點擊立即安裝按鈕 步驟4. 上一步下載的文件還不是Adobe F ...
  • 使用nmcli命令配置網路 NetworkManager是管理和監控網路設置的守護進程,設備既就是網路介面,連接是對網路介面的配置,一個網路介面可以有多個連接配置,但同時只有一個連接配置生效。 1 配置主機名 CentOS6 之前主機配置文件:/etc/sysconfig/network CentO ...
  • 命令歷史、文件類查看工具、文件和目錄類管理工具、通配符、IO重定向 ...
  • 發佈遇到的問題1: HTTP 錯誤 404.17 - Not Found 請求的內容似乎是腳本,因而將無法由靜態文件處理程式來處理。 最終解決時IIS的設置情況: 1、應用程式池的高級設置中 啟用32位應用程式: True 托管管道模式: Classic 載入用戶配置文件: True 2、選中發佈的 ...
  • 主題: 需求: 用戶角色,講師\學員, 用戶登陸後根據角色不同,能做的事情不同,分別如下講師視圖 管理班級,可創建班級,根據學員qq號把學員加入班級 可創建指定班級的上課紀錄,註意一節上課紀錄對應多條學員的上課紀錄, 即每節課都有整班學員上, 為了紀錄每位學員的學習成績,需在創建每節上課紀錄是,同時 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...