基於tensorflow實現mnist手寫識別 (多層神經網路)

来源:https://www.cnblogs.com/imae/archive/2019/04/01/10634192.html
-Advertisement-
Play Games

標題黨其實也不多,一個輸入層,三個隱藏層,一個輸出層 老樣子先上代碼 導入mnist的路徑很長,現在還記不住 設置輸入層,X為樣本數據,y是標簽值 X 784是因為28*28,None是因為不知道需要用多少樣本 Y 10是因為 0~9的預測輸出,None理由同上 3層這樣寫有點啰嗦 下一版有個用函數 ...


標題黨其實也不多,一個輸入層,三個隱藏層,一個輸出層

老樣子先上代碼

導入mnist的路徑很長,現在還記不住

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
from time import time
mnist =  input_data.read_data_sets("data/",one_hot = True)
#導入Tensorflwo和mnist數據集等 常用庫

  

設置輸入層,X為樣本數據,y是標簽值

X 784是因為28*28,None是因為不知道需要用多少樣本

Y 10是因為 0~9的預測輸出,None理由同上

x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y')

3層這樣寫有點啰嗦 下一版有個用函數實現的比這個好。

tf.truncated_normal([784,H1_NN],stddev = 0.1)以截斷正態分佈的隨機初始化,數學原理不解釋(budong),反正大小控制在stddev裡面 方便後面訓練
H1_NN = 256 #第一層神經元節點數
H2_NN = 64 #第二層神經元節點數
H3_NN = 32 #第三層神經元節點數
#第一層
W1 = tf.Variable(tf.truncated_normal([784,H1_NN],stddev = 0.1))
b1 = tf.Variable(tf.zeros(H1_NN))
#第二層
W2 = tf.Variable(tf.truncated_normal([H1_NN,H2_NN],stddev = 0.1))
b2 = tf.Variable(tf.zeros(H2_NN))
#第三層
W3 = tf.Variable(tf.truncated_normal([H2_NN,H3_NN],stddev = 0.1))
b3 = tf.Variable(tf.zeros(H3_NN))
#輸出層
W4 = tf.Variable(tf.truncated_normal([H3_NN,10],stddev = 0.1)) 
b4 = tf.Variable(tf.zeros(10))

輸出 不多講了 前三層使用了Relu,最後輸出因為是10分類所有使用了softmax

(今天寫的時候記錯了pred輸出使用了loss函數的softmax計算導致程式報錯,先記下來)

Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #使用Relu當作激活函數
Y2 = tf.nn.relu(tf.matmul(Y1,W2)+b2)#使用Relu當作激活函數
Y3 = tf.nn.relu(tf.matmul(Y2,W3)+b3)#使用Relu當作激活函數
forward = tf.matmul(Y3,W4)+b4 
pred = tf.nn.softmax(forward)#輸出層分類應用使用softmax當作激活函數

沒錯上面說的就是這個tf.nn.sofmax_cross_entropy_with_logits,不使用這個使用第一版的

#損失函數使用交叉熵
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = forward,labels = y))

會因為log為0導致梯度爆炸 數學原理不太懂 以後補一下,會了再來填充

loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),
                                              reduction_indices=1))

超參數設置啥的沒啥好說,值得一提total_batch 好像是類似一個洗牌的函數

#設置訓練參數
train_epochs = 50
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size) #隨機抽取樣本
learning_rate = 0.01
display_step = 1

優化器,(反向傳播?)不確定 反正用來調整最優的w和b

#優化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)

利用argmax對比預測結果和標簽值,方便後面統計準確率

#定義準確率
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

開始訓練,後面會補一個保存和調用模型的代碼不然以後模型大了 不保存都要程式跑一次才能用太費時間,這裡print我把用format的刪了 因為不太會用

#開始訓練
sess = tf.Session()
init = tf.global_variables_initializer()
startTime = time()
sess.run(init)
for epochs in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)#讀取批次數據
        sess.run(opimizer,feed_dict={x:xs,y:ys})#執行批次數據訓練
    
    #total_batch個批次訓練完成後,使用驗證數據計算誤差與準確率
    loss,acc =  sess.run([loss_function,accuracy],
                        feed_dict={
                            x:mnist.validation.images,
                            y:mnist.validation.labels})
    #輸出訓練情況
    if(epochs+1) % display_step == 0:
        epochs += 1 
        print("Train Epoch:",epochs,
               "Loss=",loss,"Accuracy=",acc)
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#顯示預測耗時

最後50輪訓練後準確率是0.97左右 已經收斂了

使用測試集評估模型

#評估模型
accu_test =  sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("model accuracy:",accu_test)

  準確率0.9714,還行

到這裡就結束了,最後把完整代碼放上來 方便以後看

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
from time import time
mnist =  input_data.read_data_sets("data/",one_hot = True)
#導入Tensorflwo和mnist數據集等 常用庫
x = tf.placeholder(tf.float32,[None,784],name='X')
y = tf.placeholder(tf.float32,[None,10],name='Y')
H1_NN = 256 #第一層神經元節點數
H2_NN = 64 #第二層神經元節點數
H3_NN = 32 #第三層神經元節點數
#第一層
W1 = tf.Variable(tf.truncated_normal([784,H1_NN],stddev = 0.1))
b1 = tf.Variable(tf.zeros(H1_NN))
#第二層
W2 = tf.Variable(tf.truncated_normal([H1_NN,H2_NN],stddev = 0.1))
b2 = tf.Variable(tf.zeros(H2_NN))
#第三層
W3 = tf.Variable(tf.truncated_normal([H2_NN,H3_NN],stddev = 0.1))
b3 = tf.Variable(tf.zeros(H3_NN))
#輸出層
W4 = tf.Variable(tf.truncated_normal([H3_NN,10],stddev = 0.1)) 
b4 = tf.Variable(tf.zeros(10))
#計算結果
Y1 = tf.nn.relu(tf.matmul(x,W1)+b1) #使用Relu當作激活函數
Y2 = tf.nn.relu(tf.matmul(Y1,W2)+b2)#使用Relu當作激活函數
Y3 = tf.nn.relu(tf.matmul(Y2,W3)+b3)#使用Relu當作激活函數
forward = tf.matmul(Y3,W4)+b4 
pred = tf.nn.softmax(forward)#輸出層分類應用使用softmax當作激活函數
#損失函數使用交叉熵
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = forward,labels = y))
#設置訓練參數
train_epochs = 50
batch_size = 50
total_batch = int(mnist.train.num_examples/batch_size) #隨機抽取樣本
learning_rate = 0.01
display_step = 1
#優化器
opimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss_function)
#定義準確率
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(pred,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#開始訓練
sess = tf.Session()
init = tf.global_variables_initializer()
startTime = time()
sess.run(init)
for epochs in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)#讀取批次數據
        sess.run(opimizer,feed_dict={x:xs,y:ys})#執行批次數據訓練
    
    #total_batch個批次訓練完成後,使用驗證數據計算誤差與準確率
    loss,acc =  sess.run([loss_function,accuracy],
                        feed_dict={
                            x:mnist.validation.images,
                            y:mnist.validation.labels})
    #輸出訓練情況
    if(epochs+1) % display_step == 0:
        epochs += 1 
        print("Train Epoch:",epochs,
               "Loss=",loss,"Accuracy=",acc)
duration = time()-startTime
print("Trian Finshed takes:","{:.2f}".format(duration))#顯示預測耗時
#評估模型
accu_test =  sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("model accuracy:",accu_test)
全部代碼

 


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 前言 開心一刻 一隻被二哈帶偏了的柴犬,我只想弄死隔壁的二哈 what:是什麼 BeanFactoryPostProcessor介面很簡單,只包含一個方法 推薦大家直接去讀它的源碼註釋,說的更詳細、更好理解 簡單來說,BeanFactoryPostProcessor是spring對外提供的介面,用來 ...
  • 題意 "題目鏈接" Sol 神仙題。。Orz yyb 考慮點分治,那麼每次我們只需要統計以當前點為$LCA$的點對之間的貢獻以及$LCA$到所有點的貢獻。 一個很神仙的思路是,對於任意兩個點對的路徑上的顏色,我們只統計里根最近的那個點的貢獻。 有了這個思路我們就可以瞎搞了,具體的細節很繁瑣,但是大概 ...
  • 測試代碼筆記如下: 附: ...
  • 1、棧和隊列 操作 增查改刪重點 插入刪除先進先出 -->隊列先進後出 -->棧2、鏈表 寫之前先畫圖存儲數據的方式 通過指針將所有的數據鏈在一起數據結構的目的 管理存儲數據 方便快速查找使用 鏈表定義 鏈式存儲的線性表 一對一的關係結構體 指針 函數 迴圈 結構體複習:struct 點運算符(結構 ...
  • SpringApplication SpringApplication類提供了一種方便的方法來引導從main()方法啟動的Spring應用程式 SpringBoot 包掃描註解源碼分析 我們來看下spring boot裡面是怎麼創建applicationContext的: 我們來看下webAppli ...
  • 什麼是析構函數 創建對對象時,系統會自動調用構造函數為我們進行初始化,同樣,在銷毀對象時也會自動調用一個函數為我們收尾,如釋放記憶體等,這個函數是析構函數。 析構函數也是一種特殊的成員函數。 特點 析構函數的名稱和類的名稱相同,在前面加 析構函數沒有返回值,無參數 析構函數只能在類中使用,且只有一個參 ...
  • Dubbo provider啟動原理: 當我們的dubbo啟動我們的spring容器時spring 初始化容器的時候會查找META-INF/spring.handles文件查找對應的NamespaceHandle,dubbo在其jar包下配置了DubboNamespaceHandle,該類下有以下配 ...
  • OCR,全稱Optical character recognition,或者optical character reader,中文譯名叫做光學文字識別。它是把圖像文件中的手寫文本,列印文本轉換為機器編碼文本的一種方法。 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...