TensorFlow框架(5)之機器學習實踐

来源:http://www.cnblogs.com/huliangwen/archive/2017/09/01/7464673.html
-Advertisement-
Play Games

1. Iris data set Iris數據集是常用的分類實驗數據集,由Fisher, 1936收集整理。Iris也稱鳶尾花卉數據集,是一類多重變數分析的數據集。數據集包含150個數據集,分為3類,每類50個數據,每個數據包含4個屬性。可通過花萼長度,花萼寬度,花瓣長度,花瓣寬度4個屬性預測鳶尾花 ...


1. Iris data set

  Iris數據集是常用的分類實驗數據集,由Fisher, 1936收集整理。Iris也稱鳶尾花卉數據集,是一類多重變數分析的數據集。數據集包含150個數據集,分為3類,每類50個數據,每個數據包含4個屬性。可通過花萼長度,花萼寬度,花瓣長度,花瓣寬度4個屬性預測鳶尾花卉屬於(Setosa,Versicolour,Virginica)三個種類中的哪一類。

該數據集包含了5個屬性:

  • Sepal.Length(花萼長度),單位是cm;
  • Sepal.Width(花萼寬度),單位是cm;
  • Petal.Length(花瓣長度),單位是cm;
  • Petal.Width(花瓣寬度),單位是cm;
  • species (種類)Iris Setosa(山鳶尾)、Iris Versicolour(雜色鳶尾),以及Iris Virginica(維吉尼亞鳶尾)。

 

如表 11所示的iris部分數據集。

表 11

6.4

2.8

5.6

2.2

2

5

2.3

3.3

1

1

4.9

2.5

4.5

1.7

2

4.9

3.1

1.5

0.1

0

5.7

3.8

1.7

0.3

0

4.4

3.2

1.3

0.2

0

5.4

3.4

1.5

0.4

0

6.9

3.1

5.1

2.3

2

6.7

3.1

4.4

1.4

1

5.1

3.7

1.5

0.4

0

5.2

2.7

3.9

1.4

1

6.9

3.1

4.9

1.5

1

5.8

4

1.2

0.2

0

5.4

3.9

1.7

0.4

0

7.7

3.8

6.7

2.2

2

6.3

3.3

4.7

1.6

1

2. Neural Network

2.1 Perform

  TensorFlow提供一個高水平的機器學習 API (tf.contrib.learn),使得容易配置(configure)、訓練(train)和評估(evaluate)各種機器學習模型。tf.contrib.learn庫的使用可以概括為五個步驟,如下所示:

  1) Load CSVs containing Iris training/test data into a TensorFlow Dataset

  2) Construct a neural network classifier

  3) Fit the model using the training data

  4) Evaluate the accuracy of the model

  5)Classify new samples

2.2 Code

  本節以對 Iris 數據集進行分類為例進行介紹,如下所示是完整的TensorFlow程式:

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

 

import os

import urllib

 

import numpy as np

import tensorflow as tf

 

# Data sets

IRIS_TRAINING = "iris_training.csv"

IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

 

IRIS_TEST = "iris_test.csv"

IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

 

def main():

# If the training and test sets aren't stored locally, download them.

if not os.path.exists(IRIS_TRAINING):

raw = urllib.urlopen(IRIS_TRAINING_URL).read()

with open(IRIS_TRAINING, "w") as f:

f.write(raw)

 

if not os.path.exists(IRIS_TEST):

raw = urllib.urlopen(IRIS_TEST_URL).read()

with open(IRIS_TEST, "w") as f:

f.write(raw)

 

# Load datasets.

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING,

target_dtype=np.int,

features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST,

target_dtype=np.int,

features_dtype=np.float32)

 

# Specify that all features have real-value data

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

 

# Build 3 layer DNN with 10, 20, 10 units respectively.

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model")

# Define the training inputs

def get_train_inputs():

x = tf.constant(training_set.data)

y = tf.constant(training_set.target)

return x, y

 

# Fit model.

classifier.fit(input_fn=get_train_inputs, steps=2000)

 

# Define the test inputs

def get_test_inputs():

x = tf.constant(test_set.data)

y = tf.constant(test_set.target)

return x, y

 

# Evaluate accuracy.

accuracy_score = classifier.evaluate(input_fn=get_test_inputs,

steps=1)["accuracy"]

 

print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

 

# Classify two new flower samples.

def new_samples():

return np.array(

[[6.4, 3.2, 4.5, 1.5],

[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)

 

predictions = list(classifier.predict(input_fn=new_samples))

 

print(

"New Samples, Class Predictions: {}\n"

.format(predictions))

 

if __name__ == "__main__":

main()

3. Analysis

3.1 Load data

  對於本文的程式,Iris數據集被分為兩部分:

  • 訓練集:有120個樣例,保存在iris_training.csv文件中;
  • 測試集:有30個樣例,保存在iris_test.csv文件中。
1) import module

  首先程式引入必要module,然後定義了數據集的本地路徑和網路路徑;

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

 

import os

import urllib

 

import tensorflow as tf

import numpy as np

 

IRIS_TRAINING = "iris_training.csv"

IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

 

IRIS_TEST = "iris_test.csv"

IRIS_TEST_URL = http://download.tensorflow.org/data/iris_test.csv

2) Open File

  若本地路徑上不存在數據集指定的文件,則通過網上下載。

if not os.path.exists(IRIS_TRAINING):

raw = urllib.urlopen(IRIS_TRAINING_URL).read()

with open(IRIS_TRAINING,'w') as f:

f.write(raw)

 

if not os.path.exists(IRIS_TEST):

raw = urllib.urlopen(IRIS_TEST_URL).read()

with open(IRIS_TEST,'w') as f:

f.write(raw)

3) load Dataset

  接著將Iris數據集載入到TensorFlow框架中,使其TensorFlow能夠直接使用。這其中使用了learn.datasets.base模塊的load_csv_with_header()函數。該方法有三個參數:

  • filename:指定了CSV文件的名字;
  • target_dtype:指定了數據集中目標數據類型,其為numpy datatype類型;
  • features_dtype:指定了數據集中特征向量的數據類型,其為numpy datatype類型。

如表 11所示,Iris數據中的目標值為:0~2,所以可以定義為整型數據就可以了,即np.int,如下所示:

# Load datasets.

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING,

target_dtype=np.int,

features_dtype=np.float32)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST,

target_dtype=np.int,

features_dtype=np.float32)

 

  由於tf.contrib.learn中的數據類型(Datasets)是以元祖類型定義的,所以用戶可以通過data 和 target兩個域屬性訪問特征向量數據和目標數據。即training_set.data 和 training_set.target為訓練數據集中的特征向量和目標數據。

3.2 Construct Estimator

  tf.contrib.learn預定義了許多模型,稱為:Estimators。用戶以黑箱模型使用Estimator來訓練和評估數據。本節使用tf.contrib.learn.DNNClassifier來訓練數據,如下所示:

# Specify that all features have real-value data

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

 

# Build 3 layer DNN with 10, 20, 10 units respectively.

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model")

  首先程式定義了模型的feature columns,其指定了數據集中特征向量的數據類型。每種類型都有一個名字,由於本節的數據是實數型,所以這裡使用.real_valued_column類型。該類型第一個參數指定了列名字,第二個參數指定了列的數量。其中所有的特征類型都定義在:tensorflow/contrib/layers/python/layers/feature_column.py.

    然後程式創建了DNNClassifier模型,

  • feature_columns=feature_columns:指定所創建的特征向量類型;
  • hidden_units=[10, 20, 10]:設置隱藏層的層數,並指定每層神經元的數據量;
  • n_classes=3:指定目標類型的數量,Iris數據有三類,所以這裡為3
  • model_dir=/tmp/iris_model:指定模型在訓練期間保存的路徑。

3.3 Describe pipeline

  TensorFlow框架的數據都是以Tensor對象存在,即要麼是constant、placeholder或Variable類型。通常訓練數據是以placeholder類型定義,然後用戶訓練時,傳遞所有的數據。本節則將訓練數據存儲在constant類型中。如下所示:

# Define the training inputs

def get_train_inputs():

x = tf.constant(training_set.data)

y = tf.constant(training_set.target)

 

return x, y

 

3.4 Fit DNNClassifier

  創建分類器後,就可以調用神經網路中DNNClassifier模型的fit()函數來訓練模型了,如下所示:

# Fit model.

classifier.fit(input_fn=get_train_inputs, steps=2000)

通過向fit傳遞get_train_inputs函數返回的訓練數據,並指定訓練的步數為2000步。

3.5 Evaluate Model

  訓練模型後,就可以通過evaluate()函數來評估模型的泛化能力了。與fit函數類似,evaluate函數的輸入數據也需為Tensor類型,所以定義了get_test_inputs()函數來轉換數據。

# Define the test inputs

def get_test_inputs():

x = tf.constant(test_set.data)

y = tf.constant(test_set.target)

return x, y

 

# Evaluate accuracy.

accuracy_score = classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"]

 

print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

註意:

    由於evaluate函數的返回值是一個Map類型(即dict類型),所以直接根據"accuracy"鍵獲取值:accuracy_score。

3.6 Classify Samples

  在訓練模型後,就可以使用estimator模型的predict()函數來預測樣例。如表 31有所示的兩個樣例,希望預測其為什麼類型。

表 31

Sepal Length

Sepal Width

Petal Length

Petal Width

6.4

3.2

4.5

1.5

5.8

3.1

5

1.7

 

如下所示的程式:

# Classify two new flower samples.

def new_samples():

return np.array(

[[6.4, 3.2, 4.5, 1.5],

[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)

 

predictions = list(classifier.predict(input_fn=new_samples))

 

print(

"New Samples, Class Predictions: {}\n"

.format(predictions))

輸出:

New Samples, Class Predictions: [1 2]

註意:

    由於predict()函數執行的返回結果類型是generator。所以上述程式將其轉換為一個list對象。

4. Logging and Monitoring

  由於TensorFlow的機器學習Estimator是黑箱學習,用戶無法瞭解模型執行發生了什麼,以及模型什麼時候收斂。所以tf.contrib.learn提供的一個Monitor API,可以幫助用戶記錄和評估模型。

4.1 Default ValidationMonitor

  預設使用fit()函數訓練Estimator模型時,TensorFlow會產生一些summary數據到fit()函數指定的路徑中。用戶可以使用Tensorborad來展示更詳細的信息。如圖 1所示,執行上述程式DNNClassifier的fit()和evaluate()函數後,預設在TensorBoard頁面顯示的常量信息。

圖 1

 

4.2 Monitors

  為了讓用戶更直觀地瞭解模型訓練過程的細節,tf.contrib.learn提供了一些高級Monitors,使得用戶在調用fit()函數時,可以使用Monitors來記錄和跟蹤模型的執行細節。如表 41所示是fitt()函數支持的Monitors類型:

表 41

Monitor

Description

CaptureVariable

每執行n步訓練,就將保存指定的變數值到一個集合(collection)

PrintTensor

每執行n步訓練,記錄指定的Tensor

SummarySaver

每執行n步訓練,使用tf.summary.FileWriter函數保存tf.Summary 緩存

ValidationMonitor

每執行n步訓練,記錄一批評估metrics,同時可設置停止條件

 

如\tensorflow\examples\tutorials\monitors\ iris_monitors.py所示的程式:

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

 

import os

 

import numpy as np

import tensorflow as tf

 

tf.logging.set_verbosity(tf.logging.INFO)

 

# Data sets

IRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv")

IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv")

 

 

def main(unused_argv):

# Load datasets.

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float)

test_set = tf.contrib.learn.datasets.base.load_csv_with_header(

filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float)

 

validation_metrics = {

"accuracy":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_accuracy,

prediction_key="classes"),

"precision":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_precision,

prediction_key="classes"),

"recall":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_recall,

prediction_key="classes"),

"mean":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_mean,

prediction_key="classes")

}

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(

test_set.data,

test_set.target,

every_n_steps=50,

metrics=validation_metrics,

early_stopping_metric="loss",

early_stopping_metric_minimize=True,

early_stopping_rounds=200)

 

# Specify that all features have real-value data

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

 

# Build 3 layer DNN with 10, 20, 10 units respectively.

classifier = tf.contrib.learn.DNNClassifier(

feature_columns=feature_columns,

hidden_units=[10, 20, 10],

n_classes=3,

model_dir="/tmp/iris_model",

config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))

 

# Fit model.

classifier.fit(x=training_set.data,

y=training_set.target,

steps=2000,

monitors=[validation_monitor])

 

# Evaluate accuracy.

accuracy_score = classifier.evaluate(

x=test_set.data, y=test_set.target)["accuracy"]

print("Accuracy: {0:f}".format(accuracy_score))

 

# Classify two new flower samples.

new_samples = np.array(

[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)

y = list(classifier.predict(new_samples))

print("Predictions: {}".format(str(y)))

 

 

if __name__ == "__main__":

tf.app.run()

 

4.3 Configuring ValidationMonitor

  如圖 1所示,如果沒有指定任何evaluation metrics,那麼ValidationMonitor預設會記錄loss和accuracy信息。但用戶可以通過創建ValidationMonitor對象來自定義metrics信息。

即通過向ValidationMonitor構造函數傳遞一個metrics參數,該參數是一個Map類型(dist),其中的key是希望顯示的名字,value是一個MetricSpec對象。

其中tf.contrib.learn.MetricSpec類的構造函數有如下四個參數:

  1. metric_fn:是一個函數,TensorFlowtf.contrib.metrics模塊中預定義了一些函數,用戶可以直接使用;
  2. prediction_key:如果模型返回一個Tensor或與一個單一的入口,那麼這個參數可以被忽略;
  3. label_key:可選
  4. weights_key:可選

 

如下所示創建一個dist類型的對象:

validation_metrics = {

"accuracy":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_accuracy,

prediction_key="classes"),

"precision":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_precision,

prediction_key="classes"),

"recall":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_recall,

prediction_key="classes"),

"mean":

tf.contrib.learn.MetricSpec(

metric_fn=tf.contrib.metrics.streaming_mean,

prediction_key="classes")

}

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(

test_set.data,

test_set.target,

every_n_steps=50,

metrics=validation_metrics,

early_stopping_metric="loss",

early_stopping_metric_minimize=True,

early_stopping_rounds=200)

註意:Python中的dist可以直接以一對"{}"初始化元素,如上validation_metrics對象創建所示。

5. 參考文獻

  [1]
您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 在iOS中常用的框架是Quartz 2D,它是Core Graphics框架的一部分,是一個強大的二維圖像繪製引擎。我們日常開發所用到的UIKit的組件都是由Core Graphics框架進行繪製的。當我們導入UIKit框架時,會自動導入Core Graphics框架。 在iOS中繪圖一般分為以下幾 ...
  • 前言 上個知識點介紹了OKHttp的基本使用,在Activity中寫了大量訪問網路的代碼,這種代碼寫起來很無聊,並且對技術沒什麼提升。在真實的企業開發中,肯定是把這些代碼封裝起來,做一個庫,給Activity調用。 封裝之前我們需要考慮以下這些問題: 封裝基本的公共方法給外部調用。get請求,Pos ...
  • 有的網頁中會使用"<meta name="viewport" content="width=device-width, initial-scale=1.0, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no">"這個標簽來設置網頁的寬度,不 ...
  • + (CGSize)boundingALLRectWithSize:(NSString *)txt Font:(UIFont *)font Size:(CGSize)size { NSMutableAttributedString *attributedString = [[NSMutableAtt... ...
  • cordova-plugin-IFlyspeech 科大訊飛的語音聽說讀寫的cordova插件 Supported Platforms iOS android Installation 插件安裝命令:cordova plugin add https://github.com/Edc-zhang/co ...
  • 引言:交互的概念是很難用語言描述的,怎樣才能讓一個抽象的想法得到充分溝通和測試呢?一個原型工具就能回答這個問題。 原型是一個想法成為App或者網頁的旅途上的伴侶。一個建築師不會從一開始就去挖地下室,他會在繪製完草圖後,一步一步地設計電腦上的和真實的模型,並且反覆地測試和修訂。同樣平面設計師也會在UI ...
  • 最近在做一個Toolbar,setNavigationIcon()這個方法一直無效,說什麼的都有,什麼getSupportActionBar().setNavigationIcon()的,說設置style的,說放到setSupportActionBar()之後的。 其實沒有說全,還應該放到Drawe ...
  • 1. RNN迴圈神經網路 1.1 結構 迴圈神經網路(recurrent neural network,RNN)源自於1982年由Saratha Sathasivam 提出的霍普菲爾德網路。RNN的主要用途是處理和預測序列數據。全連接的前饋神經網路和捲積神經網路模型中,網路結構都是從輸入層到隱藏層再 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...