學習筆記GAN004:DCGAN main.py

来源:http://www.cnblogs.com/libinggen/archive/2017/09/14/7518769.html
-Advertisement-
Play Games

Scipy 高端科學計算:http://blog.chinaunix.net/uid-21633169-id-4437868.html import os #引用操作系統函數文件 import scipy.misc #引用scipy包misc模塊 圖像形式存取數組 import numpy as n ...


Scipy 高端科學計算:http://blog.chinaunix.net/uid-21633169-id-4437868.html

import os #引用操作系統函數文件
import scipy.misc #引用scipy包misc模塊 圖像形式存取數組
import numpy as np #引用numpy包 矩陣計算
from model import DCGAN #引用model文件DCGAN類
from utils import pp, visualize, to_json, show_all_variables #引用utils文件pp對象,visualize, to_json, show_all_variables方法
import tensorflow as tf #引用tensorflow
flags = tf.app.flags #接受命令行傳遞參數,相當於接受argv。第一個是參數名稱,第二個參數是預設值,第三個是參數描述
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") #訓練輪數 25
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") #adam優化器 學習速率 0.0002
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") #adam優化器 動量(參數移動平均數) 0.5
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]") #訓練畫像尺寸,預設無限大正數
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") #圖像批大小 64
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]") #輸入圖像高度 108 均衡的縮放圖像(保持圖像原始比例),使圖片的兩個坐標(寬、高)都大於等於 相應的視圖坐標(負的內邊距)。圖像則位於視圖的中央。
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") #輸入圖像寬度,None與高度相同
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") #輸出圖像高度 64
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]") #輸出圖像寬度,None與高度相同
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") #數據集名稱 celebA mnist lsun
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") #圖片文件名的搜索擴展名
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") #檢查點目錄名
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") #圖片樣本保存目錄名
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") #訓練流程開關
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") #訓練流程開關
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") #可視化開關
FLAGS = flags.FLAGS
def main(_): #主程式
  pp.pprint(flags.FLAGS.__flags) #列印命令行參數
  if FLAGS.input_width is None: #如果沒有配置輸入圖像寬度
    FLAGS.input_width = FLAGS.input_height #把輸入圖像高度作為寬度
  if FLAGS.output_width is None: #如果沒有配置輸出圖像寬度
    FLAGS.output_width = FLAGS.output_height #把輸出圖像高度作為寬度
  if not os.path.exists(FLAGS.checkpoint_dir): #如果檢查點目錄不存在
    os.makedirs(FLAGS.checkpoint_dir) #創建檢查點目錄
  if not os.path.exists(FLAGS.sample_dir): #如果樣本目錄不存在
    os.makedirs(FLAGS.sample_dir) #創建樣本目錄
  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) #設置GPU顯存占用比例
  run_config = tf.ConfigProto() #獲取配置對象
  run_config.gpu_options.allow_growth = True #GPU顯存占用按需增加
  with tf.Session(config=run_config) as sess: #指定配置構建會話
    if FLAGS.dataset == 'mnist': #如果指定數據集為mnist
      dcgan = DCGAN( #構建DCGAN
          sess, #提定會話
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10, #標簽維度為10
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir)
    else:
      dcgan = DCGAN( #構建DCGAN,不指定標簽維度
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir)
    show_all_variables() #顯示所有參數
    if FLAGS.train: #如果是訓練
      dcgan.train(FLAGS) #指定參數執行構建DCGAN 訓練方法
    else: #如果是測試
      if not dcgan.load(FLAGS.checkpoint_dir)[0]: #在檢查點目錄沒有檢查點文件,即沒有已訓練好的模型
        raise Exception("[!] Train a model first, then run test mode") #拋出異常:請先訓練模型再執行測試
      
    # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0], #JSON格式化:w,b,gbn
    #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
    #                 [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
    #                 [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
    #                 [dcgan.h4_w, dcgan.h4_b, None])
    # Below is codes for visualization
    OPTION = 1
    visualize(sess, dcgan, FLAGS, OPTION) #執行可視化方法,傳入會話、DCGAN、配置參數,選項
if __name__ == '__main__': #如果直接執行本腳本文件,運行以下代碼,一般作調試用。如果作為其它腳本模塊引入,則不執行以下代碼
  tf.app.run() #運行APP.run 解析FLAGS,執行main方法

歡迎付費咨詢(150元每小時),我的微信:qingxingfengzi

我創建GAN日報群,以每天各報各的進度為主。把正在研究GAN的人聚在一起,互相鼓勵,一起前進。加我微信拉群,請註明:加入GAN日報群。


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 在這之前,我已經分享過一個webpack的全系列,相對於webpack, gulp使用和配置起來非常的簡單. gulp是什麼? gulp 是基於 node 實現 Web 前端自動化開發的工具,利用它能夠極大的提高開發效率。在 Web 前端開發工作中有很多“重覆工作”,比如壓縮CSS/JS文件。而這些 ...
  • 本文介紹如何獲取視頻中某個時間點的數據 調用以下方法即可,特別註意,在獲取圖片時的參數單位為微秒,不是毫秒 如果錯用了毫秒會一直獲取第一幀的畫面 ...
  • UIScrollViewDelegate - (void)scrollViewDidScroll:(UIScrollView *)scrollView;//scrollview 滾動的時候調用該方法,任何 offset 值改變都會調用該方法. - (void)scrollViewDidZoom:(U ...
  • 最近閑來無事,整理一下UICollectionView的相關方法以備使用 UICollectionViewFlowLayout和UICollectionViewLayout UICollectionViewFlowLayout是UICollectionViewLayout是一個子類,我們通常用的比較 ...
  • 作為從安卓的的入門選手,第一次看到還以為是個第三方呢,從github下來之後感覺不對啊,這麼多東西,後來一搜原來是個插件,而且不用從github上下載。 安裝的方法很簡單。 第一步:打開安卓studio的配置,找到Plugins,在右邊搜索ButterKnife ,你就會看到下麵這個界面。沒有錯,這 ...
  • 最近閑來無事,總結一下 UITableViewDataSource和 UITableViewDelegate方法 UITableViewDataSource @required - (NSInteger)tableView:(UITableView *)tableView numberOfRowsI ...
  • 作為安卓入門選手,在導入第三方的時候才發現居然有兩個build.gradle,我說咋不對啊,原來是導錯了(可能是因為我沒有看安卓培訓的視頻吧)。 那麼就說一下這兩個的作用(一個Project的,一個Module的): 簡單一點來說Project中的gradle是聲明的資源包括依賴項、第三方插件、ma ...
  • Android記憶體泄漏是一個經常要遇到的問題,程式在記憶體泄漏的時候很容易導致OOM的發生。那麼如何查找記憶體泄漏和避免記憶體泄漏就是需要知曉的一個問題,首先我們需要知道一些基礎知識。 Java的四種引用 強引用: 強引用是Java中最普通的引用,隨意創建一個對象然後在其他的地方引用一下,就是強引用,強引 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...