TensorFlow搭建模型方式總結

来源:https://www.cnblogs.com/lovefisho/archive/2022/05/17/16279836.html
-Advertisement-
Play Games

大家好,這篇文章分享了C程式設計(譚浩強)第五版第四章課後題答案,所有程式已經測試能夠正常運行,如果小伙伴發現有錯誤的的地方,歡迎留言告訴我,我會及時改正!感謝大家的觀看!!! ...


引言

 TensorFlow提供了多種API,使得入門者和專家可以根據自己的需求選擇不同的API搭建模型。

基於Keras Sequential API搭建模型

Sequential適用於線性堆疊的方式搭建模型,即每層只有一個輸入和輸出。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

# 使用Sequential搭建模型
# 方式一
model = tf.keras.models.Sequential([

    # 加入CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
    tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)),

    # 加入激活函數
    tf.keras.layers.Activation('relu'),

    # 加入2X2池化層, 步長為2
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),

    # 把圖像數據平鋪
    tf.keras.layers.Flatten(),

    # 加入全連接層, 設置神經元為128個, 設置relu激活函數
    tf.keras.layers.Dense(128, activation='relu'),

    # 加入全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
    tf.keras.layers.Dense(10, activation='softmax')
])

# 方式二
model2 = tf.keras.models.Sequential()
model2.add(tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)))
model2.add(tf.keras.layers.Activation('relu'))
model2.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
model2.add(tf.keras.layers.Flatten())
model2.add(tf.keras.layers.Dense(128, activation='relu'))
model2.add(tf.keras.layers.Dense(10, activation='softmax'))

# 模型概覽
model.summary()

"""
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 3)         30        

 activation (Activation)     (None, 26, 26, 3)         0         

 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               

 flatten (Flatten)           (None, 507)               0         

 dense (Dense)               (None, 128)               65024     

 dense_1 (Dense)             (None, 10)                1290      

=================================================================
Total params: 66,344
Trainable params: 66,344
"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

基於Keras 函數API搭建模型

由於Sequential是線性堆疊的,只有一個輸入和輸出,但是當我們需要搭建多輸入模型時,如輸入圖片、文本描述等,這幾類信息可能需要分別使用CNN,RNN模型提取信息,然後彙總信息到最後的神經網路中預測輸出。或者是多輸出任務,如根據音樂預測音樂類型和發行時間。亦或是一些非線性的拓撲網路結構模型,如使用殘差鏈接、Inception等。上述這些情況的網路都不是線性搭建,要搭建如此複雜的網路,需要使用函數API來搭建。

 

簡單實例

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

input_tensor = tf.keras.layers.Input(shape=(28, 28, 1))

# CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
x = tf.keras.layers.Conv2D(3, kernel_size=3, strides=1)(input_tensor)

# 激活函數
x = tf.keras.layers.Activation('relu')(x)

# 2X2池化層, 步長為2
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x)

# 把圖像數據平鋪
x = tf.keras.layers.Flatten()(x)

# 全連接層, 設置神經元為128個, 設置relu激活函數
x = tf.keras.layers.Dense(128, activation='relu')(x)

# 全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
output = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.models.Model(input_tensor, output)

# 模型概覽
model.summary()

"""
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 3)         30        
                                                                 
 activation (Activation)     (None, 26, 26, 3)         0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 507)               0         
                                                                 
 dense (Dense)               (None, 128)               65024     
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 66,344
Trainable params: 66,344
Non-trainable params: 0
_________________________________________________________________

"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

 

多輸入實例

import tensorflow as tf

# 輸入1
input_tensor1 = tf.keras.layers.Input(shape=(28,))
x1 = tf.keras.layers.Dense(16, activation='relu')(input_tensor1)
output1 = tf.keras.layers.Dense(32, activation='relu')(x1)

# 輸入2
input_tensor2 = tf.keras.layers.Input(shape=(28,))
x2 = tf.keras.layers.Dense(16, activation='relu')(input_tensor2)
output2 = tf.keras.layers.Dense(32, activation='relu')(x2)

# 合併輸入1和輸入2
concat = tf.keras.layers.concatenate([output1, output2])

# 頂層分類模型
output = tf.keras.layers.Dense(10, activation='relu')(concat)

model = tf.keras.models.Model([input_tensor1, input_tensor2], output)

# 編譯
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

 

多輸出實例

import tensorflow as tf

# 輸入
input_tensor = tf.keras.layers.Input(shape=(28,))
x = tf.keras.layers.Dense(16, activation='relu')(input_tensor)
output = tf.keras.layers.Dense(32, activation='relu')(x)


# 多個輸出
output1 = tf.keras.layers.Dense(10, activation='relu')(output)
output2 = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.models.Model(input_tensor, [output1, output2])

# 編譯
model.compile(
    optimizer='adam',
    loss=['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics=['accuracy']
)

 

子類化API

 相較於上述使用高階API,使用子類化API的方式來搭建模型,可以根據需求對模型中的任何一部分進行修改。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train / 255, x_test / 255

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=10).batch(32)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden_layer1 = tf.keras.layers.Dense(16, activation='relu')
        self.hidden_layer2 = tf.keras.layers.Dense(10, activation='softmax')

    # 定義模型
    def call(self, x):
        h = self.flatten(x)
        h = self.hidden_layer1(h)
        y = self.hidden_layer2(h)
        return y


model = MyModel()

# 損失函數 和 優化器
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# 評估指標
train_loss = tf.keras.metrics.Mean()  # 一個epoch的loss
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()  # 一個epoch的準確率

test_loss = tf.keras.metrics.Mean()
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pre = model(x)
        loss = loss_function(y, y_pre)
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    train_loss(loss)
    train_accuracy(y, y_pre)


@tf.function
def test_step(x, y):
    y_pre = model(x)
    te_loss = loss_function(y, y_pre)

    test_loss(te_loss)
    test_accuracy(y, y_pre)


epoch = 2

for i in range(epoch):

    # 重置評估指標
    train_loss.reset_states()
    train_accuracy.reset_states()

    # 按照batch size 進行訓練
    for x, y in train_data:
        train_step(x, y)

    print(f'epoch {i+1} train loss {train_loss.result()} train accuracy {train_accuracy.result()}')

 參考

TensorFlow官方文檔

 

本文來自博客園,作者:LoveFishO,轉載請註明原文鏈接:https://www.cnblogs.com/lovefisho/p/16279836.html


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 背景 在測試環境上遇到一個詭異的問題,部分業務邏輯會記錄用戶ID到資料庫,但記錄的數據會串,比如當前用戶的操作記錄會被其他用戶覆蓋, 而且這個現象是每次重啟後一小段時間內就正常 問題 線上程池內部使用了InheritableThreadLocal存放用戶登錄信息,再獲取用戶信息後,由於沒有及時rem ...
  • 今天收到一個工作4年的粉絲的面試題。 問題是: “Spring中有哪些方式可以把Bean註入到IOC容器”。 他說這道題是所有面試題裡面回答最好的,但是看面試官的表情,好像不太對。 我問他怎麼回答的,他說: “介面註入”、“Setter註入”、“構造器註入”。 為什麼不對?來看看普通人和高手的回答。 ...
  • 目錄 一.簡介 二.效果演示 三.源碼下載 四.猜你喜歡 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 基礎 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 轉場 零基礎 O ...
  • “==”和equals的區別 首先我們應該知道的是: “==”是運算符,如果是基本數據類型,則比較存儲的值;如果是引用數據類型,則比較所指向對象的地址值。 equals是Object的方法,比較的是所指向的對象的地址值,一般情況下,重寫之後比較的是對象的值。 一、對象類型不同 1、equals(): ...
  • 把destoon數據生成json,一般用於百度小程式、QQ小程式和微信小程式或者原生APP,由於系統是GB2312編碼,所以服務端編寫的時候我們進行了一些編碼轉換的處理,保證服務端訪問的編碼是UTF-8就可以。不多了,下麵乾貨來了。如果你是程式或此段代碼對你有幫助,希望收藏!! 代碼來了,在根目錄新 ...
  • 環境介紹: python 3.8 解釋器 pycharm 2021專業版 >>> 激活碼 編輯器 谷歌瀏覽器 谷歌驅動 selenium >>> 驅動 >>> 瀏覽器 模塊使用: 採集一個視頻 requests >>> pip install requests re 採集多個視頻 selenium ...
  • 1引言:這裡主要做三件事 1.1resources文件夾下創建spring-mvc.xml並配置:開啟註解驅動(mvc:annotation-driven),靜態資源過濾(mvc:default-servlet-handler),視圖解析器(InternalResourceViewResolver) ...
  • 在 JVM 進行垃圾回收之前,我們需要先判斷一個對象是否存活,判斷對象是否存活採用了兩種方法: 引用計數法 給對象中添加一個引用計數器,每引用這個對象一次,計數器 +1,當引用失效時,計數器 -1。當引用計數器為 0 時,則表示該對象可被回收。 Java 不適用原因:無法解決對象互相迴圈引用的問題 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...