參考書 《TensorFlow:實戰Google深度學習框架》(第2版) 例子:從一個張量創建一個數據集,遍歷這個數據集,並對每個輸入輸出y = x^2 的值。 運行結果: 數據是文本文件:創建數據集。 運行結果: 數據是TFRecord文件:創建TFRecord測試文件。 運行結果: ...
參考書
《TensorFlow:實戰Google深度學習框架》(第2版)
例子:從一個張量創建一個數據集,遍歷這個數據集,並對每個輸入輸出y = x^2 的值。
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: dataset_test1.py @time: 2019/2/10 10:52 @desc: 例子:從一個張量創建一個數據集,遍歷這個數據集,並對每個輸入輸出y = x^2 的值。 """ import tensorflow as tf # 從一個數組創建數據集。 input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_tensor_slices(input_data) # 定義一個迭代器用於遍曆數據集。因為上面定義的數據集沒有用placeholder作為輸入參數 # 所以這裡可以使用最簡單的one_shot_iterator iterator = dataset.make_one_shot_iterator() # get_next() 返回代表一個輸入數據的張量,類似於隊列的dequeue()。 x = iterator.get_next() y = x * x with tf.Session() as sess: for i in range(len(input_data)): print(sess.run(y))
運行結果:
數據是文本文件:創建數據集。
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: dataset_test2.py @time: 2019/2/10 11:03 @desc: 數據是文本文件 """ import tensorflow as tf # 從文本文件創建數據集。假定每行文字是一個訓練例子。註意這裡可以提供多個文件。 input_files = ['./input_file11', './input_file22'] dataset = tf.data.TextLineDataset(input_files) # 定義迭代器用於遍曆數據集 iterator = dataset.make_one_shot_iterator() # 這裡get_next()返回一個字元串類型的張量,代表文件中的一行。 x = iterator.get_next() with tf.Session() as sess: for i in range(4): print(sess.run(x))
運行結果:
數據是TFRecord文件:創建TFRecord測試文件。
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: dataset_createdata.py @time: 2019/2/10 13:59 @desc: 創建樣例文件 """ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np import time # 生成整數型的屬性。 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) # 生成字元串型的屬性。 def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) a = [11, 21, 31, 41, 51] b = [22, 33, 44, 55, 66] # 輸出TFRecord文件的地址 filename = './input_file2' # 創建一個writer來寫TFRecord文件 writer = tf.python_io.TFRecordWriter(filename) for index in range(len(a)): aa = a[index] bb = b[index] # 將一個樣例轉化為Example Protocol Buffer,並將所有的信息寫入這個數據結構。 example = tf.train.Example(features=tf.train.Features(feature={ 'feat1': _int64_feature(aa), 'feat2': _int64_feature(bb) })) # 將一個Example寫入TFRecord文件中。 writer.write(example.SerializeToString()) writer.close()
運行結果:
數據是TFRecord文件:創建數據集。(使用最簡單的one_hot_iterator來遍曆數據集)
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: dataset_test3.py @time: 2019/2/10 13:16 @desc: 數據是TFRecord文件 """ import tensorflow as tf # 解析一個TFRecord的方法。record是從文件中讀取的一個樣例。前面介紹瞭如何解析TFRecord樣例。 def parser(record): # 解析讀入的一個樣例 features = tf.parse_single_example( record, features={ 'feat1': tf.FixedLenFeature([], tf.int64), 'feat2': tf.FixedLenFeature([], tf.int64), } ) return features['feat1'], features['feat2'] # 從TFRecord文件創建數據集。 input_files = ['./input_file1', './input_file2'] dataset = tf.data.TFRecordDataset(input_files) # map()函數表示對數據集中的每一條數據進行調用相應方法。使用TFRecordDataset讀出的是二進位的數據。 # 這裡需要通過map()函數來調用parser()對二進位數據進行解析。類似的,map()函數也可以用來完成其他的數據預處理工作。 dataset = dataset.map(parser) # 定義遍曆數據集的迭代器 iterator = dataset.make_one_shot_iterator() # feat1, feat2是parser()返回的一維int64型張量,可以作為輸入用於進一步的計算。 feat1, feat2 = iterator.get_next() with tf.Session() as sess: for i in range(10): f1, f2 = sess.run([feat1, feat2]) print(f1, f2)
運行結果:
數據是TFRecord文件:創建數據集。(使用placeholder和initializable_iterator來動態初始化數據集)
#!/usr/bin/env python # -*- coding: UTF-8 -*- # coding=utf-8 """ @author: Li Tian @contact: [email protected] @software: pycharm @file: dataset_test4.py @time: 2019/2/10 13:44 @desc: 用initializable_iterator來動態初始化數據集的例子 """ import tensorflow as tf from figuredata_deal.dataset_test3 import parser # 解析一個TFRecord的方法。與上面的例子相同不再重覆。 # 從TFRecord文件創建數據集,具體文件路徑是一個placeholder,稍後再提供具體路徑。 input_files = tf.placeholder(tf.string) dataset = tf.data.TFRecordDataset(input_files) dataset = dataset.map(parser) # 定義遍歷dataset的initializable_iterator iterator = dataset.make_initializable_iterator() feat1, feat2 = iterator.get_next() with tf.Session() as sess: # 首先初始化iterator,並給出input_files的值。 sess.run(iterator.initializer, feed_dict={input_files: ['./input_file1', './input_file2']}) # 遍歷所有數據一個epoch,當遍歷結束時,程式會拋出OutOfRangeError while True: try: sess.run([feat1, feat2]) except tf.errors.OutOfRangeError: break
運行結果: