TensorFlow搭建模型方式總結

来源:https://www.cnblogs.com/lovefisho/archive/2022/05/17/16279836.html
-Advertisement-
Play Games

大家好,這篇文章分享了C程式設計(譚浩強)第五版第四章課後題答案,所有程式已經測試能夠正常運行,如果小伙伴發現有錯誤的的地方,歡迎留言告訴我,我會及時改正!感謝大家的觀看!!! ...


引言

 TensorFlow提供了多種API,使得入門者和專家可以根據自己的需求選擇不同的API搭建模型。

基於Keras Sequential API搭建模型

Sequential適用於線性堆疊的方式搭建模型,即每層只有一個輸入和輸出。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

# 使用Sequential搭建模型
# 方式一
model = tf.keras.models.Sequential([

    # 加入CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
    tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)),

    # 加入激活函數
    tf.keras.layers.Activation('relu'),

    # 加入2X2池化層, 步長為2
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),

    # 把圖像數據平鋪
    tf.keras.layers.Flatten(),

    # 加入全連接層, 設置神經元為128個, 設置relu激活函數
    tf.keras.layers.Dense(128, activation='relu'),

    # 加入全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
    tf.keras.layers.Dense(10, activation='softmax')
])

# 方式二
model2 = tf.keras.models.Sequential()
model2.add(tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)))
model2.add(tf.keras.layers.Activation('relu'))
model2.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
model2.add(tf.keras.layers.Flatten())
model2.add(tf.keras.layers.Dense(128, activation='relu'))
model2.add(tf.keras.layers.Dense(10, activation='softmax'))

# 模型概覽
model.summary()

"""
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 3)         30        

 activation (Activation)     (None, 26, 26, 3)         0         

 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               

 flatten (Flatten)           (None, 507)               0         

 dense (Dense)               (None, 128)               65024     

 dense_1 (Dense)             (None, 10)                1290      

=================================================================
Total params: 66,344
Trainable params: 66,344
"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

基於Keras 函數API搭建模型

由於Sequential是線性堆疊的,只有一個輸入和輸出,但是當我們需要搭建多輸入模型時,如輸入圖片、文本描述等,這幾類信息可能需要分別使用CNN,RNN模型提取信息,然後彙總信息到最後的神經網路中預測輸出。或者是多輸出任務,如根據音樂預測音樂類型和發行時間。亦或是一些非線性的拓撲網路結構模型,如使用殘差鏈接、Inception等。上述這些情況的網路都不是線性搭建,要搭建如此複雜的網路,需要使用函數API來搭建。

 

簡單實例

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

input_tensor = tf.keras.layers.Input(shape=(28, 28, 1))

# CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
x = tf.keras.layers.Conv2D(3, kernel_size=3, strides=1)(input_tensor)

# 激活函數
x = tf.keras.layers.Activation('relu')(x)

# 2X2池化層, 步長為2
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x)

# 把圖像數據平鋪
x = tf.keras.layers.Flatten()(x)

# 全連接層, 設置神經元為128個, 設置relu激活函數
x = tf.keras.layers.Dense(128, activation='relu')(x)

# 全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
output = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.models.Model(input_tensor, output)

# 模型概覽
model.summary()

"""
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 3)         30        
                                                                 
 activation (Activation)     (None, 26, 26, 3)         0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 507)               0         
                                                                 
 dense (Dense)               (None, 128)               65024     
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 66,344
Trainable params: 66,344
Non-trainable params: 0
_________________________________________________________________

"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

 

多輸入實例

import tensorflow as tf

# 輸入1
input_tensor1 = tf.keras.layers.Input(shape=(28,))
x1 = tf.keras.layers.Dense(16, activation='relu')(input_tensor1)
output1 = tf.keras.layers.Dense(32, activation='relu')(x1)

# 輸入2
input_tensor2 = tf.keras.layers.Input(shape=(28,))
x2 = tf.keras.layers.Dense(16, activation='relu')(input_tensor2)
output2 = tf.keras.layers.Dense(32, activation='relu')(x2)

# 合併輸入1和輸入2
concat = tf.keras.layers.concatenate([output1, output2])

# 頂層分類模型
output = tf.keras.layers.Dense(10, activation='relu')(concat)

model = tf.keras.models.Model([input_tensor1, input_tensor2], output)

# 編譯
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

 

多輸出實例

import tensorflow as tf

# 輸入
input_tensor = tf.keras.layers.Input(shape=(28,))
x = tf.keras.layers.Dense(16, activation='relu')(input_tensor)
output = tf.keras.layers.Dense(32, activation='relu')(x)


# 多個輸出
output1 = tf.keras.layers.Dense(10, activation='relu')(output)
output2 = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.models.Model(input_tensor, [output1, output2])

# 編譯
model.compile(
    optimizer='adam',
    loss=['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics=['accuracy']
)

 

子類化API

 相較於上述使用高階API,使用子類化API的方式來搭建模型,可以根據需求對模型中的任何一部分進行修改。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train / 255, x_test / 255

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=10).batch(32)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden_layer1 = tf.keras.layers.Dense(16, activation='relu')
        self.hidden_layer2 = tf.keras.layers.Dense(10, activation='softmax')

    # 定義模型
    def call(self, x):
        h = self.flatten(x)
        h = self.hidden_layer1(h)
        y = self.hidden_layer2(h)
        return y


model = MyModel()

# 損失函數 和 優化器
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# 評估指標
train_loss = tf.keras.metrics.Mean()  # 一個epoch的loss
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()  # 一個epoch的準確率

test_loss = tf.keras.metrics.Mean()
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pre = model(x)
        loss = loss_function(y, y_pre)
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    train_loss(loss)
    train_accuracy(y, y_pre)


@tf.function
def test_step(x, y):
    y_pre = model(x)
    te_loss = loss_function(y, y_pre)

    test_loss(te_loss)
    test_accuracy(y, y_pre)


epoch = 2

for i in range(epoch):

    # 重置評估指標
    train_loss.reset_states()
    train_accuracy.reset_states()

    # 按照batch size 進行訓練
    for x, y in train_data:
        train_step(x, y)

    print(f'epoch {i+1} train loss {train_loss.result()} train accuracy {train_accuracy.result()}')

 參考

TensorFlow官方文檔

 

本文來自博客園,作者:LoveFishO,轉載請註明原文鏈接:https://www.cnblogs.com/lovefisho/p/16279836.html


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 背景 在測試環境上遇到一個詭異的問題,部分業務邏輯會記錄用戶ID到資料庫,但記錄的數據會串,比如當前用戶的操作記錄會被其他用戶覆蓋, 而且這個現象是每次重啟後一小段時間內就正常 問題 線上程池內部使用了InheritableThreadLocal存放用戶登錄信息,再獲取用戶信息後,由於沒有及時rem ...
  • 今天收到一個工作4年的粉絲的面試題。 問題是: “Spring中有哪些方式可以把Bean註入到IOC容器”。 他說這道題是所有面試題裡面回答最好的,但是看面試官的表情,好像不太對。 我問他怎麼回答的,他說: “介面註入”、“Setter註入”、“構造器註入”。 為什麼不對?來看看普通人和高手的回答。 ...
  • 目錄 一.簡介 二.效果演示 三.源碼下載 四.猜你喜歡 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 基礎 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 轉場 零基礎 O ...
  • “==”和equals的區別 首先我們應該知道的是: “==”是運算符,如果是基本數據類型,則比較存儲的值;如果是引用數據類型,則比較所指向對象的地址值。 equals是Object的方法,比較的是所指向的對象的地址值,一般情況下,重寫之後比較的是對象的值。 一、對象類型不同 1、equals(): ...
  • 把destoon數據生成json,一般用於百度小程式、QQ小程式和微信小程式或者原生APP,由於系統是GB2312編碼,所以服務端編寫的時候我們進行了一些編碼轉換的處理,保證服務端訪問的編碼是UTF-8就可以。不多了,下麵乾貨來了。如果你是程式或此段代碼對你有幫助,希望收藏!! 代碼來了,在根目錄新 ...
  • 環境介紹: python 3.8 解釋器 pycharm 2021專業版 >>> 激活碼 編輯器 谷歌瀏覽器 谷歌驅動 selenium >>> 驅動 >>> 瀏覽器 模塊使用: 採集一個視頻 requests >>> pip install requests re 採集多個視頻 selenium ...
  • 1引言:這裡主要做三件事 1.1resources文件夾下創建spring-mvc.xml並配置:開啟註解驅動(mvc:annotation-driven),靜態資源過濾(mvc:default-servlet-handler),視圖解析器(InternalResourceViewResolver) ...
  • 在 JVM 進行垃圾回收之前,我們需要先判斷一個對象是否存活,判斷對象是否存活採用了兩種方法: 引用計數法 給對象中添加一個引用計數器,每引用這個對象一次,計數器 +1,當引用失效時,計數器 -1。當引用計數器為 0 時,則表示該對象可被回收。 Java 不適用原因:無法解決對象互相迴圈引用的問題 ...
一周排行
    -Advertisement-
    Play Games
  • 一:背景 準備開個系列來聊一下 PerfView 這款工具,熟悉我的朋友都知道我喜歡用 WinDbg,這東西雖然很牛,但也不是萬能的,也有一些場景他解決不了或者很難解決,這時候藉助一些其他的工具來輔助,是一個很不錯的主意。 很多朋友喜歡在項目中以記錄日誌的方式來監控項目的流轉情況,其實 CoreCL ...
  • 本來閑來無事,準備看看Dapper擴展的源碼學習學習其中的編程思想,同時整理一下自己代碼的單元測試,為以後的進一步改進打下基礎。 突然就發現問題了,源碼也不看了,開始改代碼,改了好久。 測試Dapper.LiteSql數據批量插入的時候,耗時20秒,感覺不正常,於是我測試了非Dapper版的Lite ...
  • 需求如下,在DEV框架項目中,需要在表格中增加一列顯示圖片,並且能編輯該列圖片,然後進行保存等操作,最終效果如下 這裡使用的是PictureEdit控制項來實現,打開DEV GridControl設計器,在ColumnEdit選擇PictureEdit: 綁定圖片代碼如下: DataTable dtO ...
  • 前兩天微軟偷偷更新了Visual Studio 2022 正式版版本 17.3 發佈,發佈摘要: MAUI 工作負荷 GA 生成 MAUI/Blazor CSS 熱重載支持 現在,你將能夠使用我們的新增功能在 Visual Studio 中使用每個更新試用一系列新功能。 選擇每個功能以瞭解有關特定功 ...
  • 航天和軍工領域的數字化轉型和建設正在積極推進,在與航天二院、航天三院、航天六院、航天九院、無線電廠、兵工廠等單位交流的過程中,用戶更聚焦試驗和生產過程中的痛點,迫切需要解決軟體平臺統一監測和控制設備及軟體與設備協同的問題。 ...
  • .NET 項目預設情況下 日誌是使用的 ILogger 介面,預設提供一下四種日誌記錄程式: 控制台 調試 EventSource EventLog 這四種記錄程式都是預設包含在 .NET 運行時庫中。關於這四種記錄程式的詳細介紹可以直接查看微軟的官方文檔 https://docs.microsof ...
  • 一:背景 上一篇我們聊到瞭如何去找 熱點函數,這一篇我們來看下當你的程式出現了 非托管記憶體泄漏 時如何去尋找可疑的代碼源頭,其實思路很簡單,就是在 HeapAlloc 或者 VirtualAlloc 時做 Hook 攔截,記錄它的調用棧以及分配的記憶體量, PerfView 會將這個 分配量 做成一個 ...
  • 背景 在 CI/CD 流程當中,測試是 CI 中很重要的部分。跟開發人員關係最大的就是單元測試,單元測試編寫完成之後,我們可以使用 IDE 或者 dot cover 等工具獲得單元測試對於業務代碼的覆蓋率。不過我們需要一個獨立的 CLI 工具,這樣我們才能夠在 Jenkins 的 CI 流程集成。 ...
  • 一、應用場景 大家在使用Mybatis進行開發的時候,經常會遇到一種情況:按照月份month將數據放在不同的表裡面,查詢數據的時候需要跟不同的月份month去查詢不同的表。 但是我們都知道,Mybatis是ORM持久層框架,即:實體關係映射,實體Object與資料庫表之間是存在一一對應的映射關係。比 ...
  • 我國目前並未出台專門針對網路爬蟲技術的法律規範,但在司法實踐中,相關判決已屢見不鮮,K 哥特設了“K哥爬蟲普法”專欄,本欄目通過對真實案例的分析,旨在提高廣大爬蟲工程師的法律意識,知曉如何合法合規利用爬蟲技術,警鐘長鳴,做一個守法、護法、有原則的技術人員。 案情介紹 深圳市快鴿互聯網科技有限公司 2 ...