TensorFlow搭建模型方式總結

来源:https://www.cnblogs.com/lovefisho/archive/2022/05/17/16279836.html
-Advertisement-
Play Games

大家好,這篇文章分享了C程式設計(譚浩強)第五版第四章課後題答案,所有程式已經測試能夠正常運行,如果小伙伴發現有錯誤的的地方,歡迎留言告訴我,我會及時改正!感謝大家的觀看!!! ...


引言

 TensorFlow提供了多種API,使得入門者和專家可以根據自己的需求選擇不同的API搭建模型。

基於Keras Sequential API搭建模型

Sequential適用於線性堆疊的方式搭建模型,即每層只有一個輸入和輸出。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

# 使用Sequential搭建模型
# 方式一
model = tf.keras.models.Sequential([

    # 加入CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
    tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)),

    # 加入激活函數
    tf.keras.layers.Activation('relu'),

    # 加入2X2池化層, 步長為2
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),

    # 把圖像數據平鋪
    tf.keras.layers.Flatten(),

    # 加入全連接層, 設置神經元為128個, 設置relu激活函數
    tf.keras.layers.Dense(128, activation='relu'),

    # 加入全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
    tf.keras.layers.Dense(10, activation='softmax')
])

# 方式二
model2 = tf.keras.models.Sequential()
model2.add(tf.keras.layers.Conv2D(3, kernel_size=3, strides=1, input_shape=(28, 28, 1)))
model2.add(tf.keras.layers.Activation('relu'))
model2.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
model2.add(tf.keras.layers.Flatten())
model2.add(tf.keras.layers.Dense(128, activation='relu'))
model2.add(tf.keras.layers.Dense(10, activation='softmax'))

# 模型概覽
model.summary()

"""
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 3)         30        

 activation (Activation)     (None, 26, 26, 3)         0         

 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               

 flatten (Flatten)           (None, 507)               0         

 dense (Dense)               (None, 128)               65024     

 dense_1 (Dense)             (None, 10)                1290      

=================================================================
Total params: 66,344
Trainable params: 66,344
"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

基於Keras 函數API搭建模型

由於Sequential是線性堆疊的,只有一個輸入和輸出,但是當我們需要搭建多輸入模型時,如輸入圖片、文本描述等,這幾類信息可能需要分別使用CNN,RNN模型提取信息,然後彙總信息到最後的神經網路中預測輸出。或者是多輸出任務,如根據音樂預測音樂類型和發行時間。亦或是一些非線性的拓撲網路結構模型,如使用殘差鏈接、Inception等。上述這些情況的網路都不是線性搭建,要搭建如此複雜的網路,需要使用函數API來搭建。

 

簡單實例

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train/255, x_test/255

input_tensor = tf.keras.layers.Input(shape=(28, 28, 1))

# CNN層(2D), 使用了3個捲積核, 捲積核的尺寸為3X3, 步長為1, 輸入圖像的維度為28X28X1
x = tf.keras.layers.Conv2D(3, kernel_size=3, strides=1)(input_tensor)

# 激活函數
x = tf.keras.layers.Activation('relu')(x)

# 2X2池化層, 步長為2
x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x)

# 把圖像數據平鋪
x = tf.keras.layers.Flatten()(x)

# 全連接層, 設置神經元為128個, 設置relu激活函數
x = tf.keras.layers.Dense(128, activation='relu')(x)

# 全連接層(輸出層), 設置輸出數量為10, 設置softmax激活函數
output = tf.keras.layers.Dense(10, activation='softmax')(x)

model = tf.keras.models.Model(input_tensor, output)

# 模型概覽
model.summary()

"""
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 3)         30        
                                                                 
 activation (Activation)     (None, 26, 26, 3)         0         
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 3)        0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 507)               0         
                                                                 
 dense (Dense)               (None, 128)               65024     
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 66,344
Trainable params: 66,344
Non-trainable params: 0
_________________________________________________________________

"""

# 編譯 為模型加入優化器, 損失函數, 評估指標
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 訓練模型, 2個epoch, batch size為100
model.fit(x_train, y_train, epochs=2, batch_size=100)

 

多輸入實例

import tensorflow as tf

# 輸入1
input_tensor1 = tf.keras.layers.Input(shape=(28,))
x1 = tf.keras.layers.Dense(16, activation='relu')(input_tensor1)
output1 = tf.keras.layers.Dense(32, activation='relu')(x1)

# 輸入2
input_tensor2 = tf.keras.layers.Input(shape=(28,))
x2 = tf.keras.layers.Dense(16, activation='relu')(input_tensor2)
output2 = tf.keras.layers.Dense(32, activation='relu')(x2)

# 合併輸入1和輸入2
concat = tf.keras.layers.concatenate([output1, output2])

# 頂層分類模型
output = tf.keras.layers.Dense(10, activation='relu')(concat)

model = tf.keras.models.Model([input_tensor1, input_tensor2], output)

# 編譯
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

 

多輸出實例

import tensorflow as tf

# 輸入
input_tensor = tf.keras.layers.Input(shape=(28,))
x = tf.keras.layers.Dense(16, activation='relu')(input_tensor)
output = tf.keras.layers.Dense(32, activation='relu')(x)


# 多個輸出
output1 = tf.keras.layers.Dense(10, activation='relu')(output)
output2 = tf.keras.layers.Dense(1, activation='sigmoid')(output)

model = tf.keras.models.Model(input_tensor, [output1, output2])

# 編譯
model.compile(
    optimizer='adam',
    loss=['sparse_categorical_crossentropy', 'binary_crossentropy'],
    metrics=['accuracy']
)

 

子類化API

 相較於上述使用高階API,使用子類化API的方式來搭建模型,可以根據需求對模型中的任何一部分進行修改。

import tensorflow as tf

# 導入手寫數字數據集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 數據標準化
x_train, x_test = x_train / 255, x_test / 255

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=10).batch(32)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden_layer1 = tf.keras.layers.Dense(16, activation='relu')
        self.hidden_layer2 = tf.keras.layers.Dense(10, activation='softmax')

    # 定義模型
    def call(self, x):
        h = self.flatten(x)
        h = self.hidden_layer1(h)
        y = self.hidden_layer2(h)
        return y


model = MyModel()

# 損失函數 和 優化器
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

# 評估指標
train_loss = tf.keras.metrics.Mean()  # 一個epoch的loss
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()  # 一個epoch的準確率

test_loss = tf.keras.metrics.Mean()
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pre = model(x)
        loss = loss_function(y, y_pre)
    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    train_loss(loss)
    train_accuracy(y, y_pre)


@tf.function
def test_step(x, y):
    y_pre = model(x)
    te_loss = loss_function(y, y_pre)

    test_loss(te_loss)
    test_accuracy(y, y_pre)


epoch = 2

for i in range(epoch):

    # 重置評估指標
    train_loss.reset_states()
    train_accuracy.reset_states()

    # 按照batch size 進行訓練
    for x, y in train_data:
        train_step(x, y)

    print(f'epoch {i+1} train loss {train_loss.result()} train accuracy {train_accuracy.result()}')

 參考

TensorFlow官方文檔

 

本文來自博客園,作者:LoveFishO,轉載請註明原文鏈接:https://www.cnblogs.com/lovefisho/p/16279836.html


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 背景 在測試環境上遇到一個詭異的問題,部分業務邏輯會記錄用戶ID到資料庫,但記錄的數據會串,比如當前用戶的操作記錄會被其他用戶覆蓋, 而且這個現象是每次重啟後一小段時間內就正常 問題 線上程池內部使用了InheritableThreadLocal存放用戶登錄信息,再獲取用戶信息後,由於沒有及時rem ...
  • 今天收到一個工作4年的粉絲的面試題。 問題是: “Spring中有哪些方式可以把Bean註入到IOC容器”。 他說這道題是所有面試題裡面回答最好的,但是看面試官的表情,好像不太對。 我問他怎麼回答的,他說: “介面註入”、“Setter註入”、“構造器註入”。 為什麼不對?來看看普通人和高手的回答。 ...
  • 目錄 一.簡介 二.效果演示 三.源碼下載 四.猜你喜歡 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 基礎 零基礎 OpenGL (ES) 學習路線推薦 : OpenGL (ES) 學習目錄 >> OpenGL ES 轉場 零基礎 O ...
  • “==”和equals的區別 首先我們應該知道的是: “==”是運算符,如果是基本數據類型,則比較存儲的值;如果是引用數據類型,則比較所指向對象的地址值。 equals是Object的方法,比較的是所指向的對象的地址值,一般情況下,重寫之後比較的是對象的值。 一、對象類型不同 1、equals(): ...
  • 把destoon數據生成json,一般用於百度小程式、QQ小程式和微信小程式或者原生APP,由於系統是GB2312編碼,所以服務端編寫的時候我們進行了一些編碼轉換的處理,保證服務端訪問的編碼是UTF-8就可以。不多了,下麵乾貨來了。如果你是程式或此段代碼對你有幫助,希望收藏!! 代碼來了,在根目錄新 ...
  • 環境介紹: python 3.8 解釋器 pycharm 2021專業版 >>> 激活碼 編輯器 谷歌瀏覽器 谷歌驅動 selenium >>> 驅動 >>> 瀏覽器 模塊使用: 採集一個視頻 requests >>> pip install requests re 採集多個視頻 selenium ...
  • 1引言:這裡主要做三件事 1.1resources文件夾下創建spring-mvc.xml並配置:開啟註解驅動(mvc:annotation-driven),靜態資源過濾(mvc:default-servlet-handler),視圖解析器(InternalResourceViewResolver) ...
  • 在 JVM 進行垃圾回收之前,我們需要先判斷一個對象是否存活,判斷對象是否存活採用了兩種方法: 引用計數法 給對象中添加一個引用計數器,每引用這個對象一次,計數器 +1,當引用失效時,計數器 -1。當引用計數器為 0 時,則表示該對象可被回收。 Java 不適用原因:無法解決對象互相迴圈引用的問題 ...
一周排行
    -Advertisement-
    Play Games
  • .Net8.0 Blazor Hybird 桌面端 (WPF/Winform) 實測可以完整運行在 win7sp1/win10/win11. 如果用其他工具打包,還可以運行在mac/linux下, 傳送門BlazorHybrid 發佈為無依賴包方式 安裝 WebView2Runtime 1.57 M ...
  • 目錄前言PostgreSql安裝測試額外Nuget安裝Person.cs模擬運行Navicate連postgresql解決方案Garnet為什麼要選擇Garnet而不是RedisRedis不再開源Windows版的Redis是由微軟維護的Windows Redis版本老舊,後續可能不再更新Garne ...
  • C#TMS系統代碼-聯表報表學習 領導被裁了之後很快就有人上任了,幾乎是無縫銜接,很難讓我不想到這早就決定好了。我的職責沒有任何變化。感受下來這個系統封裝程度很高,我只要會調用方法就行。這個系統交付之後不會有太多問題,更多應該是做小需求,有大的開發任務應該也是第二期的事,嗯?怎麼感覺我變成運維了?而 ...
  • 我在隨筆《EAV模型(實體-屬性-值)的設計和低代碼的處理方案(1)》中介紹了一些基本的EAV模型設計知識和基於Winform場景下低代碼(或者說無代碼)的一些實現思路,在本篇隨筆中,我們來分析一下這種針對通用業務,且只需定義就能構建業務模塊存儲和界面的解決方案,其中的數據查詢處理的操作。 ...
  • 對某個遠程伺服器啟用和設置NTP服務(Windows系統) 打開註冊表 HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Services\W32Time\TimeProviders\NtpServer 將 Enabled 的值設置為 1,這將啟用NTP伺服器功 ...
  • title: Django信號與擴展:深入理解與實踐 date: 2024/5/15 22:40:52 updated: 2024/5/15 22:40:52 categories: 後端開發 tags: Django 信號 松耦合 觀察者 擴展 安全 性能 第一部分:Django信號基礎 Djan ...
  • 使用xadmin2遇到的問題&解決 環境配置: 使用的模塊版本: 關聯的包 Django 3.2.15 mysqlclient 2.2.4 xadmin 2.0.1 django-crispy-forms >= 1.6.0 django-import-export >= 0.5.1 django-r ...
  • 今天我打算整點兒不一樣的內容,通過之前學習的TransformerMap和LazyMap鏈,想搞點不一樣的,所以我關註了另外一條鏈DefaultedMap鏈,主要調用鏈為: 調用鏈詳細描述: ObjectInputStream.readObject() DefaultedMap.readObject ...
  • 後端應用級開發者該如何擁抱 AI GC?就是在這樣的一個大的浪潮下,我們的傳統的應用級開發者。我們該如何選擇職業或者是如何去快速轉型,跟上這樣的一個行業的一個浪潮? 0 AI金字塔模型 越往上它的整個難度就是職業機會也好,或者說是整個的這個運作也好,它的難度會越大,然後越往下機會就會越多,所以這是一 ...
  • @Autowired是Spring框架提供的註解,@Resource是Java EE 5規範提供的註解。 @Autowired預設按照類型自動裝配,而@Resource預設按照名稱自動裝配。 @Autowired支持@Qualifier註解來指定裝配哪一個具有相同類型的bean,而@Resourc... ...