學習筆記TF051:生成式對抗網路

来源:http://www.cnblogs.com/libinggen/archive/2017/08/24/7421110.html
-Advertisement-
Play Games

生成式對抗網路(gennerative adversarial network,GAN),谷歌2014年提出網路模型。靈感自二人博弈的零和博弈,目前最火的非監督深度學習。GAN之父,Ian J.Goodfellow,公認人工智慧頂級專家。 原理。生成式對搞網路包含一個生成模型(generative ...


生成式對抗網路(gennerative adversarial network,GAN),谷歌2014年提出網路模型。靈感自二人博弈的零和博弈,目前最火的非監督深度學習。GAN之父,Ian J.Goodfellow,公認人工智慧頂級專家。

原理。
生成式對搞網路包含一個生成模型(generative model,G)和一個判別模型(discriminative model,D)。Ian J.Goodfellow、Jean Pouget-Abadie、Mehdi Mirza、Bing Xu、David Warde-Farley、Sherjil Ozair、Aaron Courville、Yoshua Bengio論文,《Generative Adversarial Network》,https://arxiv.org/abs/1406.2661 。
生成式對抗網路結構:
雜訊數據->生成模型->假圖片---|
|->判別模型->真/假
打亂訓練數據->訓練集->真圖片-|
生成式對抗網路主要解決如何從訓練樣本中學習出新樣本。生成模型負責訓練出樣本的分佈,如果訓練樣本是圖片就生成相似的圖片,如果訓練樣本是文章名子就生成相似的文章名子。判別模型是一個二分類器,用來判斷輸入樣本是真實數據還是訓練生成的樣本。
生成式對抗網路優化,是一個二元極小極大博弈(minimax two-player game)問題。使生成模型輸出在輸入給判別模型時,判斷模型秀難判斷是真實數據還是虛似數據。訓練好的生成模型,能把一個雜訊向量轉化成和訓練集類似的樣本。Argustus Odena、Christopher Olah、Jonathon Shlens論文《Coditional Image Synthesis with Auxiliary Classifier GANs》。
輔助分類器生成式對抗網路(auxiliary classifier GAN,AC-GAN)實現。

生成式對抗網路應用。生成數字,生成人臉圖像。

生成式對抗網路實現。https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py 。
Augustus Odena、Chistopher Olah和Jonathon Shlens 論文《Conditional Image Synthesis With Auxiliary Classifier GANs》。
通過雜訊,讓生成模型G生成虛假數據,和真實數據一起送到判別模型D,判別模型一方面輸出數據真/假,一方面輸出圖片分類。
首先定義生成模型,目的是生成一對(z,L)數據,z是雜訊向量,L是(1,28,28)的圖像空間。

def build_generator(latent_size):
cnn = Sequential()
cnn.add(Dense(1024, input_dim=latent_size, activation='relu'))
cnn.add(Dense(128 * 7 * 7, activation='relu'))
cnn.add(Reshape((128, 7, 7)))
#上採樣,圖你尺寸變為 14X14
cnn.add(UpSampling2D(size=(2,2)))
cnn.add(Convolution2D(256, 5, 5, border_mode='same', activation='relu', init='glorot_normal'))
#上採樣,圖像尺寸變為28X28
cnn.add(UpSampling2D(size=(2,2)))
cnn.add(Convolution2D(128, 5, 5, border_mode='same', activation='relu', init='glorot_normal'))
#規約到1個通道
cnn.add(Convolution2D(1, 2, 2, border_mode='same', activation='tanh', init='glorot_normal'))
#生成模型輸入層,特征向量
latent = Input(shape=(latent_size, ))
#生成模型輸入層,標記
image_class = Input(shape=(1,), dtype='int32')
cls = Flatten()(Embedding(10, latent_size, init='glorot_normal')(image_class))
h = merge([latent, cls], mode='mul')
fake_image = cnn(h) #輸出虛假圖片
return Model(input=[latent, image_class], output=fake_image)
定義判別模型,輸入(1,28,28)圖片,輸出兩個值,一個是判別模型認為這張圖片是否是虛假圖片,另一個是判別模型認為這第圖片所屬分類。

def build_discriminator();
#採用激活函數Leaky ReLU來替換標準的捲積神經網路中的激活函數
cnn = Wequential()
cnn.add(Convolution2D(32, 3, 3, border_mode='same', subsample=(2, 2), input_shape=(1, 28, 28)))
cnn.add(LeakyReLU())
cnn.add(Dropout(0.3))
cnn.add(Convolution2D(64, 3, 3, border_mode='same', subsample=(1, 1)))
cnn.add(LeakyReLU())
cnn.add(Dropout(0.3))
cnn.add(Convolution2D(128, 3, 3, border_mode='same', subsample=(1, 1)))
cnn.add(LeakyReLU())
cnn.add(Dropout(0.3))
cnn.add(Convolution2D(256, 3, 3, border_mode='same', subsample=(1, 1)))
cnn.add(LeakyReLU())
cnn.add(Dropout(0.3))
cnn.add(Flatten())
image = Input(shape=(1, 28, 28))
features = cnn(image)
#有兩個輸出
#輸出真假值,範圍在0~1
fake = Dense(1, activation='sigmoid',name='generation')(features)
#輔助分類器,輸出圖片分類
aux = Dense(10, activation='softmax', name='auxiliary')(features)
return Model(input=image, output=[fake, aux])
訓練過程,50輪(epoch),把權重保存,每輪把虛假數據生成圖處保存,觀察虛假數據演化過程。

if __name__ =='__main__':
#定義超參數
nb_epochs = 50
batch_size = 100
latent_size = 100
#優化器學習率
adam_lr = 0.0002
adam_beta_l = 0.5
#構建判別網路
discriminator = build_discriminator()
discriminator.compile(optimizer=adam(lr=adam_lr, beta_l=adam_beta_l), loss='binary_crossentropy')
latent = Input(shape=(lastent_size, ))
image_class = Input(shape-(1, ), dtype='int32')
#生成組合模型
discriminator.trainable = False
fake, aux = discriminator(fake)
combined = Model(input=[latent, image_class], output=[fake, aux])
combined.compile(optimizer=Adam(lr=adam_lr, beta_l=adam_beta_1), loss=['binary_crossentropy', 'sparse_categorical_crossentropy'])
#將mnist數據轉化為(...,1,28,28)維度,取值範圍為[-1,1]
(X_train,y_train),(X_test,y_test) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=1)
X_test = (X_test.astype(np.float32) - 127.5) / 127.5
X_test = np.expand_dims(X_test, axis=1)
num_train, num_test = X_train.shape[0], X_test.shape[0]
train_history = defaultdict(list)
test_history = defaultdict(list)
for epoch in range(epochs):
print('Epoch {} of {}'.format(epoch + 1, epochs))
num_batches = int(X_train.shape[0] / batch_size)
progress_bar = Progbar(target=num_batches)
epoch_gen_loss = []
epoch_disc_loss = []
for index in range(num_batches):
progress_bar.update(index)
#產生一個批次的雜訊數據
noise = np.random.uniform(-1, 1, (batch_size, latent_size))
# 獲取一個批次的真實數據
image_batch = X_train[index * batch_size:(index + 1) * batch_size]
label_batch = y_train[index * batch_size:(index + 1) * batch_size]
# 生成一些雜訊標記
sampled_labels = np.random.randint(0, 10, batch_size)
# 產生一個批次的虛假圖片
generated_images = generator.predict(
[noise, sampled_labels.reshape((-1, 1))], verbose=0)
X = np.concatenate((image_batch, generated_images))
y = np.array([1] * batch_size + [0] * batch_size)
aux_y = np.concatenate((label_batch, sampled_labels), axis=0)
epoch_disc_loss.append(discriminator.train_on_batch(X, [y, aux_y]))
# 產生兩個批次雜訊和標記
noise = np.random.uniform(-1, 1, (2 * batch_size, latent_size))
sampled_labels = np.random.randint(0, 10, 2 * batch_size)
# 訓練生成模型來欺騙判別模型,輸出真/假都設為真
trick = np.ones(2 * batch_size)
epoch_gen_loss.append(combined.train_on_batch(
[noise, sampled_labels.reshape((-1, 1))],
[trick, sampled_labels]))
print('\nTesting for epoch {}:'.format(epoch + 1))
# 評估測試集,產生一個新批次雜訊數據
noise = np.random.uniform(-1, 1, (num_test, latent_size))
sampled_labels = np.random.randint(0, 10, num_test)
generated_images = generator.predict(
[noise, sampled_labels.reshape((-1, 1))], verbose=False)
X = np.concatenate((X_test, generated_images))
y = np.array([1] * num_test + [0] * num_test)
aux_y = np.concatenate((y_test, sampled_labels), axis=0)
# 判別模型是否能判別
discriminator_test_loss = discriminator.evaluate(
X, [y, aux_y], verbose=False)
discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)
# 創建兩個批次新雜訊數據
noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))
sampled_labels = np.random.randint(0, 10, 2 * num_test)
trick = np.ones(2 * num_test)
generator_test_loss = combined.evaluate(
[noise, sampled_labels.reshape((-1, 1))],
[trick, sampled_labels], verbose=False)
generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)
# 損失值等性能指標記錄下來,並輸出
train_history['generator'].append(generator_train_loss)
train_history['discriminator'].append(discriminator_train_loss)
test_history['generator'].append(generator_test_loss)
test_history['discriminator'].append(discriminator_test_loss)
print('{0:<22s} | {1:4s} | {2:15s} | {3:5s}'.format(
'component', *discriminator.metrics_names))
print('-' * 65)
ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.2f} | {3:<5.2f}'
print(ROW_FMT.format('generator (train)',
*train_history['generator'][-1]))
print(ROW_FMT.format('generator (test)',
*test_history['generator'][-1]))
print(ROW_FMT.format('discriminator (train)',
*train_history['discriminator'][-1]))
print(ROW_FMT.format('discriminator (test)',
*test_history['discriminator'][-1]))
# 每個epoch保存一次權重
generator.save_weights(
'params_generator_epoch_{0:03d}.hdf5'.format(epoch), True)
discriminator.save_weights(
'params_discriminator_epoch_{0:03d}.hdf5'.format(epoch), True)
# 生成一些可視化虛假數字看演化過程
noise = np.random.uniform(-1, 1, (100, latent_size))
sampled_labels = np.array([
[i] * 10 for i in range(10)
]).reshape(-1, 1)
generated_images = generator.predict(
[noise, sampled_labels], verbose=0)
# 整理到一個方格
img = (np.concatenate([r.reshape(-1, 28)
for r in np.split(generated_images, 10)
], axis=-1) * 127.5 + 127.5).astype(np.uint8)
Image.fromarray(img).save(
'plot_epoch_{0:03d}_generated.png'.format(epoch))
pickle.dump({'train': train_history, 'test': test_history},
open('acgan-history.pkl', 'wb'))

訓練結束,創建3類文件。params_discriminator_epoch_{{epoch_number}}.hdf5,判別模型權重參數。params_generator_epoch_{{epoch_number}}.hdf5,生成模型權重參數。plot_epoch_{{epoch_number}}_generated.png 。

生成式對抗網路改進。生成式對抗網路(generative adversarial network,GAN)在無監督學習非常有效。常規生成式對抗網路判別器使用Sigmoid交叉熵損失函數,學習過程梯度消失。Wasserstein生成式對抗網路(Wasserstein generative adversarial network,WGAN),使用Wasserstein距離度量,而不是Jensen-Shannon散度(Jensen-Shannon divergence,JSD)。使用最小二乘生成式對抗網路(least squares generative adversarial network,LSGAN),判別模型用最小平方損失小函數(least squares loss function)。Sebastian Nowozin、Botond Cseke、Ryota Tomioka論文《f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization》。

參考資料:
《TensorFlow技術解析與實戰》

歡迎付費咨詢(150元每小時),我的微信:qingxingfengzi


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • select * from 表名 as of timestamp to_timestamp('2016-02-23 23:59:59','yyyy-mm-dd hh24:mi:ss') ; 該語句表示查詢某一個時間點時該表的數據,通過修改時間,你可以查詢刪除之前時的數據。篩選出來重新插入即可。 ...
  • 在MySQL中如何給普通用戶授予查看所有用戶線程/連接的許可權,當然,預設情況下show processlist是可以查看當前用戶的線程/連接的。 mysql> grant process on MyDB.* to test; ERROR 1221 (HY000): Incorrect usage o... ...
  • 具體報錯如下: Table '.\mysql\proc' is marked as crashed and should be repaired 我的解決辦法: 找到mysql的安裝目錄的bin/myisamchk工具,右擊【以管理員身份運行】修複下即可。 網上解決辦法: 找到mysql的安裝目錄的 ...
  • ALTER PROCEDURE [dbo].[sp_GetClassCountData] @BatchId NVARCHAR(50), @ExamId VARCHAR(100), @ClassId VARCHAR(100), @SubjectId NVARCHAR(50)ASBEGIN DECLAR ...
  • [20170824]11G備庫啟用DRCP連接.txt--//參考鏈接:http://blog.itpub.net/267265/viewspace-2099397/blogs.oracle.com/database4cn/adg%e5%a4%87%e5%ba%93%e7%9a%84drcp%e8% ...
  • 概述: 視圖即是虛擬表,也稱為派生表,因為它們的內容都派生自其它表的查詢結果。雖然視圖看起來感覺和基本表一樣,但是它們不是基本表。基本表的內容是持久的,而視圖的內容是在使用過程中動態產生的。——摘自《SQLite權威指南》 使用視圖的優點: 1.可靠的安全性 2.查詢性能提高 3.有效應對靈活性的功 ...
  • 創建序列 create sequence seq_student start with 6 increment by 1 maxvalue 500 nominvalue nocycle nocache; 創建觸發器 create or replace trigger trigger_student ... ...
  • HDFS ,Hadoop Distribute File System,hadoop分散式文件系統。 主從架構,分主節點NameNode,從節點DataNode.當然還有個SecondaryName,但這不是淺析里的點.這裡主要講下namenode和datanode的基本概念, 並描述下讀寫過程. ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...