學習筆記TF024:TensorFlow實現Softmax Regression(回歸)識別手寫數字

来源:http://www.cnblogs.com/libinggen/archive/2017/07/09/7143669.html
-Advertisement-
Play Games

TensorFlow實現Softmax Regression(回歸)識別手寫數字。MNIST(Mixed National Institute of Standards and Technology database),簡單機器視覺數據集,28X28像素手寫數字,只有灰度值信息,空白部分為0,筆跡根 ...


TensorFlow實現Softmax Regression(回歸)識別手寫數字。MNIST(Mixed National Institute of Standards and Technology database),簡單機器視覺數據集,28X28像素手寫數字,只有灰度值信息,空白部分為0,筆跡根據顏色深淺取[0, 1], 784維,丟棄二維空間信息,目標分0~9共10類。數據載入,data.read_data_sets, 55000個樣本,測試集10000樣本,驗證集5000樣本。樣本標註信息,label,10維向量,10種類one-hot編碼。訓練集訓練模型,驗證集檢驗效果,測試集評測模型(準確率、召回率、F1-score)。

演算法設計,Softmax Regression訓練手寫數字識別分類模型,估算類別概率,取概率最大數字作模型輸出結果。類特征相加,判定類概率。模型學習訓練調整權值。softmax,各類特征計算exp函數,標準化(所有類別輸出概率值為1)。y = softmax(Wx+b)。

NumPy使用C、fortran,調用openblas、mkl矩陣運算庫。TensorFlow密集複雜運算在Python外執行。定義計算圖,運算操作不需要每次把運算完的數據傳回Python,全部在Python外面運行。

import tensor flow as tf,載入TensorFlow庫。less = tf.InteractiveSession(),創建InteractiveSession,註冊為預設session。不同session的數據、運算,相互獨立。x = tf.placeholder(tf.float32, [None,784]),創建Placeholder 接收輸入數據,第一參數數據類型,第二參數代表tensor shape 數據尺寸。None不限條數輸入,每條輸入為784維向量。

tensor存儲數據,一旦使用掉就會消失。Variable在模型訓練迭代中持久化,長期存在,每輪迭代更新。Softmax Regression模型的Variable對象weights、biases 初始化為0。模型訓練自動學習合適值。複雜網路,初始化方法重要。w = tf.Variable(tf.zeros([784, 10])),784特征維數,10類。Label,one-hot編碼後10維向量。

Softmax Regression演算法,y = tf.nn.softmax(tf.matmul(x, W) + b)。tf.nn包含大量神經網路組件。tf.matmul,矩陣乘法函數。TensorFlow將forward、backward內容自動實現,只要定義好loss,訓練自動求導梯度下降,完成Softmax Regression模型參數自動學習。

定義loss function描述問題模型分類精度。Loss越小,模型分類結果與真實值越小,越精確。模型初始參數全零,產生初始loss。訓練目標是減小loss,找到全局最優或局部最優解。cross-entropy,分類問題常用loss function。y預測概率分佈,y'真實概率分佈(Label one-hot編碼),判斷模型對真實概率分佈預測準確度。cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))。定義placeholder,輸入真實label。tf.reduce_sum求和,tf.reduce_mean每個batch數據結果求均值。

定義優化演算法,隨機梯度下降SGD(Stochastic Gradient Descent)。根據計算圖自動求導,根據反向傳播(Back Propagation)演算法訓練,每輪迭代更新參數減小loss。提供封裝優化器,每輪迭代feed數據,TensorFlow在後臺自動補充運算操作(Operation)實現反向傳播和梯度下降。train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)。調用tf.train.GradientDescentOptimizer,設置學習速度0.5,設定優化目標cross-entropy,得到訓練操作train_step。

tf.global_variables_initializer().run()。TensorFlow全局參數初始化器tf.golbal_variables_initializer。

batch_xs,batch_ys = mnist.train.next_batch(100)。訓練操作train_step。每次隨機從訓練集抽取100條樣本構成mini-batch,feed給 placeholder,調用train_step訓練樣本。使用小部分樣本訓練,隨機梯度下降,收斂速度更快。每次訓練全部樣本,計算量大,不容易跳出局部最優。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmzx(y_,1)),驗證模型準確率。tf.argmax從tensor尋找最大值序號,tf.argmax(y,1)求預測數字概率最大,tf.argmax(y_,1)找樣本真實數字類別。tf.equal判斷預測數字類別是否正確,返回計算分類操作是否正確。

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)),統計全部樣本預測正確度。tf.cast轉化correct_prediction輸出值類型。

print(accuracy.eval({x: mnist.test.images,y_: mnist.test.labels}))。測試數據特征、Label輸入評測流程,計算模型測試集準確率。Softmax Regression MNIST數據分類識別,測試集平均準確率92%左右。

TensorFlow 實現簡單機器演算法步驟:
1、定義演算法公式,神經網路forward計算。
2、定義loss,選定優化器,指定優化器優化loss。
3、迭代訓練數據。
4、測試集、驗證集評測準確率。

定義公式只是Computation Graph,只有調用run方法,feed數據,計算才執行。

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    print(mnist.train.images.shape, mnist.train.labels.shape)
    print(mnist.test.images.shape, mnist.test.labels.shape)
    print(mnist.validation.images.shape, mnist.validation.labels.shape)
    import tensorflow as tf
    sess = tf.InteractiveSession()
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    tf.global_variables_initializer().run()
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train_step.run({x: batch_xs, y_: batch_ys})
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

 


參考資料:
《TensorFlow實踐》

歡迎付費咨詢(150元每小時),我的微信:qingxingfengzi


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 一、準備 組件化 + 隨著業務需求的增長,在單工程 MVC 模式下,app 代碼逐漸變得龐大,面對的高耦合的代碼和複雜的功能模塊,我們或許就需要進行重構了,以組件化的形式,將需要的組件以 pod 私有庫的形式安裝到最後的主工程中,組件間各自獨立、解耦,僅依賴中間件進行通信,這或許就是極好的架構形式。 ...
  • 1. 下載MySQL Yum Repository http://dev.mysql.com/downloads/repo/yum/ 2. 本地安裝MySQL Yum Repository sudo yum localinstall platform-and-version-specific-pac ...
  • MySQL的簡單使用 1. 使用MySQL命令行工具 Windows 用戶使用: MySQL Client, 輸入密碼 Linux: mysql u用戶名 p密碼 mysql uroot p 2. 顯示資料庫命令 show databases; 3. 創建資料庫命令 create database ...
  • 錯誤截圖如下: 步驟1. 打開瀏覽器,輸入http://www.adobe.com/cn/ 步驟2. 點擊菜單,打開下拉的列表,找到並點擊Adobe Flash Player 步驟3. 把可選程式的勾“√”去掉,否則會安裝可選程式,然後點擊立即安裝按鈕 步驟4. 上一步下載的文件還不是Adobe F ...
  • 使用nmcli命令配置網路 NetworkManager是管理和監控網路設置的守護進程,設備既就是網路介面,連接是對網路介面的配置,一個網路介面可以有多個連接配置,但同時只有一個連接配置生效。 1 配置主機名 CentOS6 之前主機配置文件:/etc/sysconfig/network CentO ...
  • 命令歷史、文件類查看工具、文件和目錄類管理工具、通配符、IO重定向 ...
  • 發佈遇到的問題1: HTTP 錯誤 404.17 - Not Found 請求的內容似乎是腳本,因而將無法由靜態文件處理程式來處理。 最終解決時IIS的設置情況: 1、應用程式池的高級設置中 啟用32位應用程式: True 托管管道模式: Classic 載入用戶配置文件: True 2、選中發佈的 ...
  • 主題: 需求: 用戶角色,講師\學員, 用戶登陸後根據角色不同,能做的事情不同,分別如下講師視圖 管理班級,可創建班級,根據學員qq號把學員加入班級 可創建指定班級的上課紀錄,註意一節上課紀錄對應多條學員的上課紀錄, 即每節課都有整班學員上, 為了紀錄每位學員的學習成績,需在創建每節上課紀錄是,同時 ...
一周排行
    -Advertisement-
    Play Games
  • 示例項目結構 在 Visual Studio 中創建一個 WinForms 應用程式後,項目結構如下所示: MyWinFormsApp/ │ ├───Properties/ │ └───Settings.settings │ ├───bin/ │ ├───Debug/ │ └───Release/ ...
  • [STAThread] 特性用於需要與 COM 組件交互的應用程式,尤其是依賴單線程模型(如 Windows Forms 應用程式)的組件。在 STA 模式下,線程擁有自己的消息迴圈,這對於處理用戶界面和某些 COM 組件是必要的。 [STAThread] static void Main(stri ...
  • 在WinForm中使用全局異常捕獲處理 在WinForm應用程式中,全局異常捕獲是確保程式穩定性的關鍵。通過在Program類的Main方法中設置全局異常處理,可以有效地捕獲並處理未預見的異常,從而避免程式崩潰。 註冊全局異常事件 [STAThread] static void Main() { / ...
  • 前言 給大家推薦一款開源的 Winform 控制項庫,可以幫助我們開發更加美觀、漂亮的 WinForm 界面。 項目介紹 SunnyUI.NET 是一個基於 .NET Framework 4.0+、.NET 6、.NET 7 和 .NET 8 的 WinForm 開源控制項庫,同時也提供了工具類庫、擴展 ...
  • 說明 該文章是屬於OverallAuth2.0系列文章,每周更新一篇該系列文章(從0到1完成系統開發)。 該系統文章,我會儘量說的非常詳細,做到不管新手、老手都能看懂。 說明:OverallAuth2.0 是一個簡單、易懂、功能強大的許可權+可視化流程管理系統。 有興趣的朋友,請關註我吧(*^▽^*) ...
  • 一、下載安裝 1.下載git 必須先下載並安裝git,再TortoiseGit下載安裝 git安裝參考教程:https://blog.csdn.net/mukes/article/details/115693833 2.TortoiseGit下載與安裝 TortoiseGit,Git客戶端,32/6 ...
  • 前言 在項目開發過程中,理解數據結構和演算法如同掌握蓋房子的秘訣。演算法不僅能幫助我們編寫高效、優質的代碼,還能解決項目中遇到的各種難題。 給大家推薦一個支持C#的開源免費、新手友好的數據結構與演算法入門教程:Hello演算法。 項目介紹 《Hello Algo》是一本開源免費、新手友好的數據結構與演算法入門 ...
  • 1.生成單個Proto.bat內容 @rem Copyright 2016, Google Inc. @rem All rights reserved. @rem @rem Redistribution and use in source and binary forms, with or with ...
  • 一:背景 1. 講故事 前段時間有位朋友找到我,說他的窗體程式在客戶這邊出現了卡死,讓我幫忙看下怎麼回事?dump也生成了,既然有dump了那就上 windbg 分析吧。 二:WinDbg 分析 1. 為什麼會卡死 窗體程式的卡死,入口門檻很低,後續往下分析就不一定了,不管怎麼說先用 !clrsta ...
  • 前言 人工智慧時代,人臉識別技術已成為安全驗證、身份識別和用戶交互的關鍵工具。 給大家推薦一款.NET 開源提供了強大的人臉識別 API,工具不僅易於集成,還具備高效處理能力。 本文將介紹一款如何利用這些API,為我們的項目添加智能識別的亮點。 項目介紹 GitHub 上擁有 1.2k 星標的 C# ...