學習筆記TF016:CNN實現、數據集、TFRecord、載入圖像、模型、訓練、調試

来源:http://www.cnblogs.com/libinggen/archive/2017/05/30/6921285.html
-Advertisement-
Play Games

AlexNet(Alex Krizhevsky,ILSVRC2012冠軍)適合做圖像分類。層自左向右、自上向下讀取,關聯層分為一組,高度、寬度減小,深度增加。深度增加減少網路計算量。 訓練模型數據集 Stanford電腦視覺站點Stanford Dogs http://vision.stanfor ...


AlexNet(Alex Krizhevsky,ILSVRC2012冠軍)適合做圖像分類。層自左向右、自上向下讀取,關聯層分為一組,高度、寬度減小,深度增加。深度增加減少網路計算量。

訓練模型數據集 Stanford電腦視覺站點Stanford Dogs http://vision.stanford.edu/aditya86/ImageNetDogs/ 。數據下載解壓到模型代碼同一路徑imagenet-dogs目錄下。包含的120種狗圖像。80%訓練,20%測試。產品模型需要預留原始數據交叉驗證。每幅圖像JPEG格式(RGB),尺寸不一。

圖像轉TFRecord文件,有助加速訓練,簡化圖像標簽匹配,圖像分離利用檢查點文件對模型進行不間斷測試。轉換圖像格式把顏色空間轉灰度,圖像修改統一尺寸,標簽除上每幅圖像。訓練前只進行一次預處理,時間較長。

glob.glob 枚舉指定路徑目錄,顯示數據集文件結構。“*”通配符可以實現模糊查找。文件名中8個數字對應ImageNet類別WordNetID。ImageNet網站可用WordNetID查圖像細節: http://www.image-net.org/synset?wnid=n02085620 。

文件名分解為品種和相應的文件名,品種對應文件夾名稱。依據品種對圖像分組。枚舉每個品種圖像,20%圖像劃入測試集。檢查每個品種測試圖像是否至少有全部圖像的18%。目錄和圖像組織到兩個與每個品種相關的字典,包含各品種所有圖像。分類圖像組織到字典中,簡化選擇分類圖像及歸類過程。

預處理階段,依次遍歷所有分類圖像,打開列表中文件。用dataset圖像填充TFRecord文件,把類別包含進去。dataset鍵值對應文件列表標簽。record_location 存儲TFRecord輸出路徑。枚舉dataset,當前索引用於文件劃分,每隔100m幅圖像,訓練樣本信息寫入新的TFRecord文件,加快寫操作進程。無法被TensorFlow識別為JPEG圖像,用try/catch忽略。轉為灰度圖減少計算量和記憶體占用。tf.cast把RGB值轉換到[0,1)區間內。標簽按字元串存儲較高效,最好轉換為整數索引或獨熱編碼秩1張量。

打開每幅圖像,轉換為灰度圖,調整尺寸,添加到TFRecord文件。tf.image.resize_images函數把所有圖像調整為相同尺寸,不考慮長寬比,有扭曲。裁剪、邊界填充能保持圖像長寬比。

按照TFRecord文件讀取圖像,每次載入少量圖像及標簽。修改圖像形狀有助訓練和輸出可視化。匹配所有在訓練集目錄下TFRecord文件載入訓練圖像。每個TFRecord文件包含多幅圖像。tf.parse_single_example只從文件提取單個樣本。批運算可同時訓練多幅圖像或單幅圖像,需要足夠系統記憶體。

圖像轉灰度值為[0,1)浮點類型,匹配convolution2d期望輸入。捲積輸出第1維和最後一維不改變,中間兩維發生變化。tf.contrib.layers.convolution2d創建模型第1層。weights_initializer設置正態隨機值,第一組濾波器填充正態分佈隨機數。濾波器設置trainable,信息輸入網路,權值調整,提高模型準確率。
max_pool把輸出降採樣。ksize、strides ([1,2,2,1]),捲積輸出形狀減半。輸出形狀減小,不改變濾波器數量(輸出通道)或圖像批數據尺寸。減少分量,與圖像(濾波器)高度、寬度有關。更多輸出通道,濾波器數量增加,2倍於第一層。多個捲積和池化層減少輸入高度、寬度,增加深度。很多架構,捲積層和池化層超過5層。訓練調試時間更長,能匹配更多更複雜模式。
圖像每個點與輸出神經元建立全連接。softmax,全連接層需要二階張量。第1維區分圖像,第2維輸入張量秩1張量。tf.reshape 指示和使用其餘所有維,-1把最後池化層調整為巨大秩1張量。
池化層展開,網路當前狀態與預測全連接層整合。weights_initializer接收可調用參數,lambda表達式返回截斷正態分佈,指定分佈標準差。dropout 削減模型中神經元重要性。tf.contrib.layers.fully_connected 輸出前面所有層與訓練中分類的全連接。每個像素與分類關聯。網路每一步將輸入圖像轉化為濾波減小尺寸。濾波器與標簽匹配。減少訓練、測試網路計算量,輸出更具一般性。

訓練數據真實標簽和模型預測結果,輸入到訓練優化器(優化每層權值)計算模型損失。數次迭代,每次提升模型準確率。大部分分類函數(tf.nn.softmax)要求數值類型標簽。每個標簽轉換代表包含所有分類列表索引整數。tf.map_fn 匹配每個標簽並返回類別列表索引。map依據目錄列表創建包含分類列表。tf.map_fn 可用指定函數對數據流圖張量映射,生成僅包含每個標簽在所有類標簽列表索引秩1張量。tf.nn.softmax用索引預測。

調試CNN,觀察濾波器(捲積核)每輪迭代變化。設計良好CNN,第一個捲積層工作,輸入權值被隨機初始化。權值通過圖像激活,激活函數輸出(特征圖)隨機。特征圖可視化,輸出外觀與原始圖相似,被施加靜力(static)。靜力由所有權值的隨機激發。經過多輪迭代,權值被調整擬合訓練反饋,濾波器趨於一致。網路收斂,濾波器與圖像不同細小模式類似。tf.image_summary得到訓練後的濾波器和特征圖簡單視圖。數據流圖圖像概要輸出(image summary output)從整體瞭解所使用的濾波器和輸入圖像特征圖。TensorDebugger,迭代中以GIF動畫查看濾波器變化。

文本輸入存儲在SparseTensor,大部分分量為0。CNN使用稠密輸入,每個值都重要,輸入大部分分量非0。

 

    import tensorflow as tf
    import glob
    from itertools import groupby
    from collections import defaultdict
    sess = tf.InteractiveSession()
    image_filenames = glob.glob("./imagenet-dogs/n02*/*.jpg")
    image_filenames[0:2]
    training_dataset = defaultdict(list)
    testing_dataset = defaultdict(list)
    image_filename_with_breed = map(lambda filename: (filename.split("/")[2], filename), image_filenames)
    for dog_breed, breed_images in groupby(image_filename_with_breed, lambda x: x[0]):
        for i, breed_image in enumerate(breed_images):
            if i % 5 == 0:
                testing_dataset[dog_breed].append(breed_image[1])
            else:
                training_dataset[dog_breed].append(breed_image[1])
        breed_training_count = len(training_dataset[dog_breed])
        breed_testing_count = len(testing_dataset[dog_breed])
        breed_training_count_float = float(breed_training_count)
        breed_testing_count_float = float(breed_testing_count)
        assert round(breed_testing_count_float / (breed_training_count_float + breed_testing_count_float), 2) > 0.18, "Not enough testing images."
    print "training_dataset testing_dataset END ------------------------------------------------------"
    def write_records_file(dataset, record_location):
        writer = None
        current_index = 0
        for breed, images_filenames in dataset.items():
            for image_filename in images_filenames:
                if current_index % 100 == 0:
                    if writer:
                        writer.close()
                    record_filename = "{record_location}-{current_index}.tfrecords".format(
                        record_location=record_location,
                        current_index=current_index)
                    writer = tf.python_io.TFRecordWriter(record_filename)
                    print record_filename + "------------------------------------------------------" 
                current_index += 1
                image_file = tf.read_file(image_filename)
                try:
                    image = tf.image.decode_jpeg(image_file)
                except:
                    print(image_filename)
                    continue
                grayscale_image = tf.image.rgb_to_grayscale(image)
                resized_image = tf.image.resize_images(grayscale_image, [250, 151])
                image_bytes = sess.run(tf.cast(resized_image, tf.uint8)).tobytes()
                image_label = breed.encode("utf-8")
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
                }))
                writer.write(example.SerializeToString())
        writer.close()
    write_records_file(testing_dataset, "./output/testing-images/testing-image")
    write_records_file(training_dataset, "./output/training-images/training-image")
    print "write_records_file testing_dataset training_dataset END------------------------------------------------------"
    filename_queue = tf.train.string_input_producer(
    tf.train.match_filenames_once("./output/training-images/*.tfrecords"))
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    features = tf.parse_single_example(
    serialized,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string),
        })
    record_image = tf.decode_raw(features['image'], tf.uint8)
    image = tf.reshape(record_image, [250, 151, 1])
    label = tf.cast(features['label'], tf.string)
    min_after_dequeue = 10
    batch_size = 3
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)
    print "load image from TFRecord END------------------------------------------------------"
    float_image_batch = tf.image.convert_image_dtype(image_batch, tf.float32)
    conv2d_layer_one = tf.contrib.layers.convolution2d(
        float_image_batch,
        num_outputs=32,
        kernel_size=(5,5),
        activation_fn=tf.nn.relu,
        weights_initializer=tf.random_normal,
        stride=(2, 2),
        trainable=True)
    pool_layer_one = tf.nn.max_pool(conv2d_layer_one,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME')
    conv2d_layer_one.get_shape(), pool_layer_one.get_shape()
    print "conv2d_layer_one pool_layer_one END------------------------------------------------------"
    conv2d_layer_two = tf.contrib.layers.convolution2d(
        pool_layer_one,
        num_outputs=64,
        kernel_size=(5,5),
        activation_fn=tf.nn.relu,
        weights_initializer=tf.random_normal,
        stride=(1, 1),
        trainable=True)
    pool_layer_two = tf.nn.max_pool(conv2d_layer_two,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME')
    conv2d_layer_two.get_shape(), pool_layer_two.get_shape()
    print "conv2d_layer_two pool_layer_two END------------------------------------------------------"
    flattened_layer_two = tf.reshape(
        pool_layer_two,
        [
            batch_size,
            -1
        ])
    flattened_layer_two.get_shape()
    print "flattened_layer_two END------------------------------------------------------"
    hidden_layer_three = tf.contrib.layers.fully_connected(
        flattened_layer_two,
        512,
        weights_initializer=lambda i, dtype: tf.truncated_normal([38912, 512], stddev=0.1),
        activation_fn=tf.nn.relu
    )
    hidden_layer_three = tf.nn.dropout(hidden_layer_three, 0.1)
    final_fully_connected = tf.contrib.layers.fully_connected(
        hidden_layer_three,
        120,
        weights_initializer=lambda i, dtype: tf.truncated_normal([512, 120], stddev=0.1)
    )
    print "final_fully_connected END------------------------------------------------------"
    labels = list(map(lambda c: c.split("/")[-1], glob.glob("./imagenet-dogs/*")))
    train_labels = tf.map_fn(lambda l: tf.where(tf.equal(labels, l))[0,0:1][0], label_batch, dtype=tf.int64)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            final_fully_connected, train_labels))
    batch = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(
        0.01,
        batch * 3,
        120,
        0.95,
        staircase=True)
    optimizer = tf.train.AdamOptimizer(
        learning_rate, 0.9).minimize(
        loss, global_step=batch)
    train_prediction = tf.nn.softmax(final_fully_connected)
    print "train_prediction END------------------------------------------------------"
    filename_queue.close(cancel_pending_enqueues=True)
    coord.request_stop()
    coord.join(threads)
    print "END------------------------------------------------------"

 

參考資料:
《面向機器智能的TensorFlow實踐》

歡迎加我微信交流:qingxingfengzi
我的微信公眾號:qingxingfengzigz
我老婆張幸清的微信公眾號:qingqingfeifangz


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 可以編寫angularjs的自定義指令來實現驗證文本框填入的數值是來為小數。 airExpressApp.directive('validateDecimalCharacters', function () { var REQUIRED_PATTERNS = [ /^-?[0-9]\d*(\.\d+ ...
  • 前幾天Insus.NET有寫過一篇《angularjs自定義指令Directive》http://www.cnblogs.com/insus/p/6908815.html 僅是在程式中指定某些來值來匹配。為你的數據表準備一個存儲過程: 判斷是否已經存在此值。只需寫SELECT語句。如果沒有記錄返回, ...
  • // 使用原生js 封裝ajax // 相容xhr對象 function createXHR(){ if(typeof XMLHttpRequest != "undefined"){ // 非IE6瀏覽器 return new XMLHttpRequest(); }else if(typeof Ac... ...
  • SpringMVC中,如何處理請求是很重要的任務。請求映射都會使用@RequestMapping標註。其中,類上的標註相當於一個首碼,表示該處理器是處理同一類請求;方法上的標註則更加細化。如,類的標註可能是“user”,表示全部都是與用戶相關的操作;具體到方法可能有“create”“update”“ ...
  • 寫代碼的不要耍小聰明,認認真真的敲。 代碼是給人看的,只是順便給機器去執行。 ...
  • 歡迎大家每天前來打卡~ 訓練營規則 每天出一道練習題,請大家自己完成編碼 第二天的文章中會告訴大家一種或幾種經典解決方法 完成練習的同學,歡迎大家把代碼貼在留言中 如果有問題,也請留言,我會找機會集中解答 希望這種手把手的方式能夠幫助大家儘快掌握C語言編程。 1. 例題 今天我們先來講解一道C語言的 ...
  • 要處理XML文檔,就要先解析(parse)他,解析器時這樣一個程式,讀入一個文件,確認整個文件具有正確的格式,然後將其分解成各種元素,使得程式員能夠訪問這些元素,Java庫提供了兩種XML解析器: 像文檔對象模型(Document Object Model,DOM)解析器這樣的樹型解析器,他們將讀入... ...
  • CASE WHEN 條件 THEN 改變的值 END 1.簡單case函數,使用表達式確定返回值: 語法: CASE title WHEN expression1 THEN result1 WHEN expression2 THEN result2 ... WHEN expressionN THEN ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...