Scipy 高端科學計算:http://blog.chinaunix.net/uid-21633169-id-4437868.html import os #引用操作系統函數文件 import scipy.misc #引用scipy包misc模塊 圖像形式存取數組 import numpy as n ...
Scipy 高端科學計算:http://blog.chinaunix.net/uid-21633169-id-4437868.html
import os #引用操作系統函數文件
import scipy.misc #引用scipy包misc模塊 圖像形式存取數組
import numpy as np #引用numpy包 矩陣計算
from model import DCGAN #引用model文件DCGAN類
from utils import pp, visualize, to_json, show_all_variables #引用utils文件pp對象,visualize, to_json, show_all_variables方法
import tensorflow as tf #引用tensorflow
flags = tf.app.flags #接受命令行傳遞參數,相當於接受argv。第一個是參數名稱,第二個參數是預設值,第三個是參數描述
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]") #訓練輪數 25
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") #adam優化器 學習速率 0.0002
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") #adam優化器 動量(參數移動平均數) 0.5
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]") #訓練畫像尺寸,預設無限大正數
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") #圖像批大小 64
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]") #輸入圖像高度 108 均衡的縮放圖像(保持圖像原始比例),使圖片的兩個坐標(寬、高)都大於等於 相應的視圖坐標(負的內邊距)。圖像則位於視圖的中央。
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") #輸入圖像寬度,None與高度相同
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") #輸出圖像高度 64
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]") #輸出圖像寬度,None與高度相同
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]") #數據集名稱 celebA mnist lsun
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") #圖片文件名的搜索擴展名
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") #檢查點目錄名
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") #圖片樣本保存目錄名
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") #訓練流程開關
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") #訓練流程開關
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") #可視化開關
FLAGS = flags.FLAGS
def main(_): #主程式
pp.pprint(flags.FLAGS.__flags) #列印命令行參數
if FLAGS.input_width is None: #如果沒有配置輸入圖像寬度
FLAGS.input_width = FLAGS.input_height #把輸入圖像高度作為寬度
if FLAGS.output_width is None: #如果沒有配置輸出圖像寬度
FLAGS.output_width = FLAGS.output_height #把輸出圖像高度作為寬度
if not os.path.exists(FLAGS.checkpoint_dir): #如果檢查點目錄不存在
os.makedirs(FLAGS.checkpoint_dir) #創建檢查點目錄
if not os.path.exists(FLAGS.sample_dir): #如果樣本目錄不存在
os.makedirs(FLAGS.sample_dir) #創建樣本目錄
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) #設置GPU顯存占用比例
run_config = tf.ConfigProto() #獲取配置對象
run_config.gpu_options.allow_growth = True #GPU顯存占用按需增加
with tf.Session(config=run_config) as sess: #指定配置構建會話
if FLAGS.dataset == 'mnist': #如果指定數據集為mnist
dcgan = DCGAN( #構建DCGAN
sess, #提定會話
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10, #標簽維度為10
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
else:
dcgan = DCGAN( #構建DCGAN,不指定標簽維度
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
show_all_variables() #顯示所有參數
if FLAGS.train: #如果是訓練
dcgan.train(FLAGS) #指定參數執行構建DCGAN 訓練方法
else: #如果是測試
if not dcgan.load(FLAGS.checkpoint_dir)[0]: #在檢查點目錄沒有檢查點文件,即沒有已訓練好的模型
raise Exception("[!] Train a model first, then run test mode") #拋出異常:請先訓練模型再執行測試
# to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0], #JSON格式化:w,b,gbn
# [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
# [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2],
# [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3],
# [dcgan.h4_w, dcgan.h4_b, None])
# Below is codes for visualization
OPTION = 1
visualize(sess, dcgan, FLAGS, OPTION) #執行可視化方法,傳入會話、DCGAN、配置參數,選項
if __name__ == '__main__': #如果直接執行本腳本文件,運行以下代碼,一般作調試用。如果作為其它腳本模塊引入,則不執行以下代碼
tf.app.run() #運行APP.run 解析FLAGS,執行main方法
歡迎付費咨詢(150元每小時),我的微信:qingxingfengzi
我創建GAN日報群,以每天各報各的進度為主。把正在研究GAN的人聚在一起,互相鼓勵,一起前進。加我微信拉群,請註明:加入GAN日報群。