前向傳播(張量)- 實戰

来源:https://www.cnblogs.com/nickchen121/archive/2019/05/11/10849484.html
-Advertisement-
Play Games

[TOC] 手寫數字識別流程 MNIST手寫數字集7000 10張圖片 60k張圖片訓練,10k張圖片測試 每張圖片是28\ 28,如果是彩色圖片是28\ 28\ 3 0 255表示圖片的灰度值,0表示純白,255表示純黑 打平28 28的矩陣,得到28\ 28=784的向量 對於b張圖片得到[b, ...


目錄

手寫數字識別流程

  • MNIST手寫數字集7000*10張圖片
  • 60k張圖片訓練,10k張圖片測試
  • 每張圖片是28*28,如果是彩色圖片是28*28*3
  • 0-255表示圖片的灰度值,0表示純白,255表示純黑
  • 打平28*28的矩陣,得到28*28=784的向量
  • 對於b張圖片得到[b,784];然後對於b張圖片可以給定編碼
  • 把上述的普通編碼給定成獨熱編碼,但是獨熱編碼都是概率值,並且概率值相加為1,類似於softmax回歸
  • 套用線性回歸公式
  • X[b,784] W[784,10] b[10] 得到 [b,10]
  • 高維圖片實現非常複雜,一個線性模型無法完成,因此可以添加非線性因數
  • f(X@W+b),使用激活函數讓其非線性化,引出relu函數
  • 用了激活函數,模型還是太簡單
  • 使用工廠
    • H1 =relu(X@W1+b1)
    • H2 = relu(h1@W2+b2)
    • Out = relu(h2@W3+b3)
  • 第一步,把[1,784]變成[1,512]變成[1,256]變成[1,10]
  • 得到[1,10]後將結果進行獨熱編碼
  • 使用歐氏距離或者使用mse進行誤差度量
  • [1,784]通過三層網路輸出一個[1,10]

前向傳播(張量)- 實戰

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os
# do not print irrelevant information
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# x: [60k,28,28]
# y: [60k]
(x, y), _ = datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step
# transform Tensor
# x: [0~255] ==》 [0~1.]
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
f'x.shape: {x.shape}, y.shape: {y.shape}, x.dtype: {x.dtype}, y.dtype: {y.dtype}'
"x.shape: (60000, 28, 28), y.shape: (60000,), x.dtype: <dtype: 'float32'>, y.dtype: <dtype: 'int32'>"
f'min_x: {tf.reduce_min(x)}, max_x: {tf.reduce_max(x)}'
'min_x: 0.0, max_x: 1.0'
f'min_y: {tf.reduce_min(y)}, max_y: {tf.reduce_max(y)}'
'min_y: 0, max_y: 9'
# batch of 128
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(128)
train_iter = iter(train_db)
sample = next(train_iter)
f'batch: {sample[0].shape,sample[1].shape}'
'batch: (TensorShape([128, 28, 28]), TensorShape([128]))'
# [b,784] ==> [b,256] ==> [b,128] ==> [b,10]
# [dim_in,dim_out],[dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))
# learning rate
lr = 1e-3
for epoch in range(10):  # iterate db for 10
    # tranin every train_db
    for step, (x, y) in enumerate(train_db):
        # x: [128,28,28]
        # y: [128]

        # [b,28,28] ==> [b,28*28]
        x = tf.reshape(x, [-1, 28*28])

        with tf.GradientTape() as tape:  # only data types of tf.variable are logged
            # x: [b,28*28]
            # h1 = x@w1 + b1
            # [b,784]@[784,256]+[256] ==> [b,256] + [256] ==> [b,256] + [b,256]
            h1 = x @ w1 + tf.broadcast_to(b1, [x.shape[0], 256])
            h1 = tf.nn.relu(h1)
            # [b,256] ==> [b,128]
            # h2 = x@w2 + b2  # b2 can broadcast automatic
            h2 = h1 @ w2 + b2
            h2 = tf.nn.relu(h2)
            # [b,128] ==> [b,10]
            out = h2 @ w3 + b3

            # compute loss
            # out: [b,10]
            # y:[b] ==> [b,10]
            y_onehot = tf.one_hot(y, depth=10)

            # mse = mean(sum(y-out)^2)
            # [b,10]
            loss = tf.square(y_onehot - out)
            # mean:scalar
            loss = tf.reduce_mean(loss)

        # compute gradients
        grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
        # w1 = w1 - lr * w1_grad
        # w1 = w1 - lr * grads[0]  # not in situ update
        # in situ update
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
        w2.assign_sub(lr * grads[2])
        b2.assign_sub(lr * grads[3])
        w3.assign_sub(lr * grads[4])
        b3.assign_sub(lr * grads[5])

        if step % 100 == 0:
            print(f'epoch:{epoch}, step: {step}, loss:{float(loss)}')
epoch:0, step: 0, loss:0.5366693735122681
epoch:0, step: 100, loss:0.23276552557945251
epoch:0, step: 200, loss:0.19647717475891113
epoch:0, step: 300, loss:0.17389704287052155
epoch:0, step: 400, loss:0.1731622964143753
epoch:1, step: 0, loss:0.16157487034797668
epoch:1, step: 100, loss:0.16654588282108307
epoch:1, step: 200, loss:0.15311869978904724
epoch:1, step: 300, loss:0.14135733246803284
epoch:1, step: 400, loss:0.14423415064811707
epoch:2, step: 0, loss:0.13703864812850952
epoch:2, step: 100, loss:0.14255204796791077
epoch:2, step: 200, loss:0.1302051544189453
epoch:2, step: 300, loss:0.12224273383617401
epoch:2, step: 400, loss:0.12742099165916443
epoch:3, step: 0, loss:0.1219201311469078
epoch:3, step: 100, loss:0.12757658958435059
epoch:3, step: 200, loss:0.11587800830602646
epoch:3, step: 300, loss:0.10984969139099121
epoch:3, step: 400, loss:0.11641304194927216
epoch:4, step: 0, loss:0.11171815544366837
epoch:4, step: 100, loss:0.11717887222766876
epoch:4, step: 200, loss:0.10604140907526016
epoch:4, step: 300, loss:0.10111508518457413
epoch:4, step: 400, loss:0.10865814983844757
epoch:5, step: 0, loss:0.10434548556804657
epoch:5, step: 100, loss:0.10952303558588028
epoch:5, step: 200, loss:0.09875871241092682
epoch:5, step: 300, loss:0.09467941522598267
epoch:5, step: 400, loss:0.10282392799854279
epoch:6, step: 0, loss:0.09874211996793747
epoch:6, step: 100, loss:0.10355912148952484
epoch:6, step: 200, loss:0.09315416216850281
epoch:6, step: 300, loss:0.08971598744392395
epoch:6, step: 400, loss:0.0982089415192604
epoch:7, step: 0, loss:0.09428335726261139
epoch:7, step: 100, loss:0.09877124428749084
epoch:7, step: 200, loss:0.08866965025663376
epoch:7, step: 300, loss:0.08573523908853531
epoch:7, step: 400, loss:0.09440126270055771
epoch:8, step: 0, loss:0.09056715667247772
epoch:8, step: 100, loss:0.09483197331428528
epoch:8, step: 200, loss:0.0849832147359848
epoch:8, step: 300, loss:0.08246967941522598
epoch:8, step: 400, loss:0.09117519855499268
epoch:9, step: 0, loss:0.08741479367017746
epoch:9, step: 100, loss:0.09150294959545135
epoch:9, step: 200, loss:0.08185736835002899
epoch:9, step: 300, loss:0.07972464710474014
epoch:9, step: 400, loss:0.08842341601848602

您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 前段時間沒有好好準備,錯過了“金三銀四”,因此最近開始惡補各方面知識,決定先從JVM記憶體結構和GC開始。 JVM記憶體結構分為如下幾部分(前兩項為線程共用,後三項為線程私有的): 1、方法區:存儲已經被虛擬機載入的類信息、常量、JIT(及時編譯器Just In Time)編譯後的代碼以及類變數(sta ...
  • 1. 多進程與多線程 (1)背景:為何需要多進程或者多線程:在同一時間里,同一個電腦系統中如果允許兩個或者兩個以上的進程處於運行狀態,這便是多任務。多任務會帶來的好處例如用戶邊聽歌、邊上網、邊列印,而這些任務之間絲毫不會互相干擾。使用多進程技術,可大大提高電腦的運算速率。 (2)多進程與多線程的 ...
  • 1.項目結構 2.代碼展示 1.pom.xml 2.application.properties 3.實體類test 4.mapper層(介面和映射文件) 介面 映射文件 5.業務層 介面(TestService) 實現類(TestServiceImpl) 6.表示層(controller) 7.啟 ...
  • [學習筆記] 1.Eureca Server的Helloworld例子:做個普通的maven project,quickstart archetype。改成jdk.8。下麵Camden.SR1是版本名,springcloud的版本名稱很奇特,它是按照倫敦地鐵站的名稱命名的。馬 克-to-win@馬克 ...
  • 1.複習 2.匿名函數 3.作用域 4.函數式編程 4.map函數 5.filter函數 6.reduce函數 7.小結 8.內置函數 ...
  • 美食排行榜網站上線後,為了快速提升流量,需要製造一個引流機會。 我的想法是開闢一個專欄,按照菜品和地區,讓用戶自發投票給自己喜歡的餐館,最終形成一個年度/月度 等的美食排行榜 比如 成都川菜美食排行榜 這個頁面,目前是按照數據入庫的先後時間排序,並不是用戶真實的排行,怎麼才能做到真實排行呢? 這就需 ...
  • 基於flask的網頁聊天室(三) 前言 繼續上一次的內容,今天完成了csrf防禦的添加,用戶頭像的存儲以及用戶的登錄狀態 具體內容 首先是添加csrf的防禦,為整個app添加防禦: from flask_wtf.csrf import CSRFProtect CSRFProtect(app) 這個添 ...
  • 1.Equals 很多人對equals方法的用法有些模糊,這裡來為大家梳理下: 字元串中的equals方法,該方法用來判斷兩個字元串的內容是否相同。 例1: 從例1中我們可以看出,兩個字元串之間的比較,無論用”==”號還是equals來進行,只要內容相同,結果就為True,內容不同,結果就為Fals ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...