TensorFlow搭建模型方式總結

来源:https://www.cnblogs.com/lovefisho/archive/2022/05/17/16279836.html
-Advertisement-
Play Games

大家好,這篇文章分享了C程式設計(譚浩強)第五版第四章課後題答案,所有程式已經測試能夠正常運行,如果小伙伴發現有錯誤的的地方,歡迎留言告訴我,我會及時改正!感謝大家的觀看!!! ...


引言

 TensorFlow提供了多種API,使得入門者和專家可以根據自己的需求選擇不同的API搭建模型。

基於Keras Sequential API搭建模型

Sequential適用於線性堆疊的方式搭建模型,即每層只有一個輸入和輸出。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

# 使用Sequential搭建模型
# 方式一
model = tf.keras.models.Sequential([

    # 加入CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
    tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)),

    # 加入激活函數
    tf.keras.layers.Activation('relu'),

    # 加入2X2池化層, 步長為2
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),

    # 把圖像數據平鋪
    tf.keras.layers.Flatten(),

    # 加入全連接層, 設置神經元為128個, 設置relu激活函數
    tf.keras.layers.Dense(128, activation='relu'),

    # 加入全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
    tf.keras.layers.Dense(10, activation='softmax')
])

# 方式二
model2 = tf.keras.models.Sequential()
model2.add(tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)))
model2.add(tf.keras.layers.Activation('relu'))
model2.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
model2.add(tf.keras.layers.Flatten())
model2.add(tf.keras.layers.Dense(128, activation='relu'))
model2.add(tf.keras.layers.Dense(10, activation='softmax'))

# 模型概覽
model.summary()

"""
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 3)         30        

 activation (Activation)     (None, 26, 26, 3)         0         

 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               

 flatten (Flatten)           (None, 507)               0         

 dense (Dense)               (None, 128)               65024     

 dense_1 (Dense)             (None, 10)                1290      

=================================================================
Total params: 66,344
Trainable params: 66,344
"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

基於Keras 函數API搭建模型

由於Sequential是線性堆疊的,只有一個輸入和輸出,但是當我們需要搭建多輸入模型時,如輸入圖片、文本描述等,這幾類信息可能需要分別使用CNN,RNN模型提取信息,然後彙總信息到最後的神經網路中預測輸出。或者是多輸出任務,如根據音樂預測音樂類型和發行時間。亦或是一些非線性的拓撲網路結構模型,如使用殘差鏈接、Inception等。上述這些情況的網路都不是線性搭建,要搭建如此複雜的網路,需要使用函數API來搭建。

 

簡單實例

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

input_tensor = tf.keras.layers.Input(shape=(28, 28, 1))

# CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
x = tf.keras.layers.Conv2D(3, kernel_size=3, strides=1)(input_tensor)

# 激活函數
x = tf.keras.layers.Activation('relu')(x)

# 2X2池化層, 步長為2
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x)

# 把圖像數據平鋪
x = tf.keras.layers.Flatten()(x)

# 全連接層, 設置神經元為128個, 設置relu激活函數
x = tf.keras.layers.Dense(128, activation='relu')(x)

# 全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
output = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.models.Model(input_tensor, output)

# 模型概覽
model.summary()

"""
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 3)         30        
                                                                 
 activation (Activation)     (None, 26, 26, 3)         0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 507)               0         
                                                                 
 dense (Dense)               (None, 128)               65024     
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 66,344
Trainable params: 66,344
Non-trainable params: 0
_________________________________________________________________

"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

 

多輸入實例

import tensorflow as tf

# 輸入1
input_tensor1 = tf.keras.layers.Input(shape=(28,))
x1 = tf.keras.layers.Dense(16, activation='relu')(input_tensor1)
output1 = tf.keras.layers.Dense(32, activation='relu')(x1)

# 輸入2
input_tensor2 = tf.keras.layers.Input(shape=(28,))
x2 = tf.keras.layers.Dense(16, activation='relu')(input_tensor2)
output2 = tf.keras.layers.Dense(32, activation='relu')(x2)

# 合併輸入1和輸入2
concat = tf.keras.layers.concatenate([output1, output2])

# 頂層分類模型
output = tf.keras.layers.Dense(10, activation='relu')(concat)

model = tf.keras.models.Model([input_tensor1, input_tensor2], output)

# 編譯
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

 

多輸出實例

import tensorflow as tf

# 輸入
input_tensor = tf.keras.layers.Input(shape=(28,))
x = tf.keras.layers.Dense(16, activation='relu')(input_tensor)
output = tf.keras.layers.Dense(32, activation='relu')(x)


# 多個輸出
output1 = tf.keras.layers.Dense(10, activation='relu')(output)
output2 = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.models.Model(input_tensor, [output1, output2])

# 編譯
model.compile(
    optimizer='adam',
    loss=['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics=['accuracy']
)

 

子類化API

 相較於上述使用高階API,使用子類化API的方式來搭建模型,可以根據需求對模型中的任何一部分進行修改。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train / 255, x_test / 255

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=10).batch(32)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden_layer1 = tf.keras.layers.Dense(16, activation='relu')
        self.hidden_layer2 = tf.keras.layers.Dense(10, activation='softmax')

    # 定義模型
    def call(self, x):
        h = self.flatten(x)
        h = self.hidden_layer1(h)
        y = self.hidden_layer2(h)
        return y


model = MyModel()

# 損失函數 和 優化器
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# 評估指標
train_loss = tf.keras.metrics.Mean()  # 一個epoch的loss
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()  # 一個epoch的準確率

test_loss = tf.keras.metrics.Mean()
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pre = model(x)
        loss = loss_function(y, y_pre)
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    train_loss(loss)
    train_accuracy(y, y_pre)


@tf.function
def test_step(x, y):
    y_pre = model(x)
    te_loss = loss_function(y, y_pre)

    test_loss(te_loss)
    test_accuracy(y, y_pre)


epoch = 2

for i in range(epoch):

    # 重置評估指標
    train_loss.reset_states()
    train_accuracy.reset_states()

    # 按照batch size 進行訓練
    for x, y in train_data:
        train_step(x, y)

    print(f'epoch {i+1} train loss {train_loss.result()} train accuracy {train_accuracy.result()}')

 參考

TensorFlow官方文檔

 

本文來自博客園,作者:LoveFishO,轉載請註明原文鏈接:https://www.cnblogs.com/lovefisho/p/16279836.html


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 背景 在測試環境上遇到一個詭異的問題,部分業務邏輯會記錄用戶ID到資料庫,但記錄的數據會串,比如當前用戶的操作記錄會被其他用戶覆蓋, 而且這個現象是每次重啟後一小段時間內就正常 問題 線上程池內部使用了InheritableThreadLocal存放用戶登錄信息,再獲取用戶信息後,由於沒有及時rem ...
  • 今天收到一個工作4年的粉絲的面試題。 問題是: “Spring中有哪些方式可以把Bean註入到IOC容器”。 他說這道題是所有面試題裡面回答最好的,但是看面試官的表情,好像不太對。 我問他怎麼回答的,他說: “介面註入”、“Setter註入”、“構造器註入”。 為什麼不對?來看看普通人和高手的回答。 ...
  • 目錄 一.簡介 二.效果演示 三.源碼下載 四.猜你喜歡 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 基礎 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 轉場 零基礎 O ...
  • “==”和equals的區別 首先我們應該知道的是: “==”是運算符,如果是基本數據類型,則比較存儲的值;如果是引用數據類型,則比較所指向對象的地址值。 equals是Object的方法,比較的是所指向的對象的地址值,一般情況下,重寫之後比較的是對象的值。 一、對象類型不同 1、equals(): ...
  • 把destoon數據生成json,一般用於百度小程式、QQ小程式和微信小程式或者原生APP,由於系統是GB2312編碼,所以服務端編寫的時候我們進行了一些編碼轉換的處理,保證服務端訪問的編碼是UTF-8就可以。不多了,下麵乾貨來了。如果你是程式或此段代碼對你有幫助,希望收藏!! 代碼來了,在根目錄新 ...
  • 環境介紹: python 3.8 解釋器 pycharm 2021專業版 >>> 激活碼 編輯器 谷歌瀏覽器 谷歌驅動 selenium >>> 驅動 >>> 瀏覽器 模塊使用: 採集一個視頻 requests >>> pip install requests re 採集多個視頻 selenium ...
  • 1引言:這裡主要做三件事 1.1resources文件夾下創建spring-mvc.xml並配置:開啟註解驅動(mvc:annotation-driven),靜態資源過濾(mvc:default-servlet-handler),視圖解析器(InternalResourceViewResolver) ...
  • 在 JVM 進行垃圾回收之前,我們需要先判斷一個對象是否存活,判斷對象是否存活採用了兩種方法: 引用計數法 給對象中添加一個引用計數器,每引用這個對象一次,計數器 +1,當引用失效時,計數器 -1。當引用計數器為 0 時,則表示該對象可被回收。 Java 不適用原因:無法解決對象互相迴圈引用的問題 ...
一周排行
    -Advertisement-
    Play Games
  • Dapr Outbox 是1.12中的功能。 本文只介紹Dapr Outbox 執行流程,Dapr Outbox基本用法請閱讀官方文檔 。本文中appID=order-processor,topic=orders 本文前提知識:熟悉Dapr狀態管理、Dapr發佈訂閱和Outbox 模式。 Outbo ...
  • 引言 在前幾章我們深度講解了單元測試和集成測試的基礎知識,這一章我們來講解一下代碼覆蓋率,代碼覆蓋率是單元測試運行的度量值,覆蓋率通常以百分比表示,用於衡量代碼被測試覆蓋的程度,幫助開發人員評估測試用例的質量和代碼的健壯性。常見的覆蓋率包括語句覆蓋率(Line Coverage)、分支覆蓋率(Bra ...
  • 前言 本文介紹瞭如何使用S7.NET庫實現對西門子PLC DB塊數據的讀寫,記錄了使用電腦模擬,模擬PLC,自至完成測試的詳細流程,並重點介紹了在這個過程中的易錯點,供參考。 用到的軟體: 1.Windows環境下鏈路層網路訪問的行業標準工具(WinPcap_4_1_3.exe)下載鏈接:http ...
  • 從依賴倒置原則(Dependency Inversion Principle, DIP)到控制反轉(Inversion of Control, IoC)再到依賴註入(Dependency Injection, DI)的演進過程,我們可以理解為一種逐步抽象和解耦的設計思想。這種思想在C#等面向對象的編 ...
  • 關於Python中的私有屬性和私有方法 Python對於類的成員沒有嚴格的訪問控制限制,這與其他面相對對象語言有區別。關於私有屬性和私有方法,有如下要點: 1、通常我們約定,兩個下劃線開頭的屬性是私有的(private)。其他為公共的(public); 2、類內部可以訪問私有屬性(方法); 3、類外 ...
  • C++ 訪問說明符 訪問說明符是 C++ 中控制類成員(屬性和方法)可訪問性的關鍵字。它們用於封裝類數據並保護其免受意外修改或濫用。 三種訪問說明符: public:允許從類外部的任何地方訪問成員。 private:僅允許在類內部訪問成員。 protected:允許在類內部及其派生類中訪問成員。 示 ...
  • 寫這個隨筆說一下C++的static_cast和dynamic_cast用在子類與父類的指針轉換時的一些事宜。首先,【static_cast,dynamic_cast】【父類指針,子類指針】,兩兩一組,共有4種組合:用 static_cast 父類轉子類、用 static_cast 子類轉父類、使用 ...
  • /******************************************************************************************************** * * * 設計雙向鏈表的介面 * * * * Copyright (c) 2023-2 ...
  • 相信接觸過spring做開發的小伙伴們一定使用過@ComponentScan註解 @ComponentScan("com.wangm.lifecycle") public class AppConfig { } @ComponentScan指定basePackage,將包下的類按照一定規則註冊成Be ...
  • 操作系統 :CentOS 7.6_x64 opensips版本: 2.4.9 python版本:2.7.5 python作為腳本語言,使用起來很方便,查了下opensips的文檔,支持使用python腳本寫邏輯代碼。今天整理下CentOS7環境下opensips2.4.9的python模塊筆記及使用 ...