神經網路(BP)演算法Python實現及簡單應用

来源:https://www.cnblogs.com/xuyiqing/archive/2018/04/11/8797048.html
-Advertisement-
Play Games

首先用Python實現簡單地神經網路演算法: 使用自己定義的神經網路演算法實現一些簡單的功能: 小案例: X: Y 0 0 0 0 1 1 1 0 1 1 1 0 from NN.NeuralNetwork import NeuralNetwork import numpy as np nn = Neu ...


首先用Python實現簡單地神經網路演算法:

import numpy as np


# 定義tanh函數
def tanh(x):
    return np.tanh(x)


# tanh函數的導數
def tan_deriv(x):
    return 1.0 - np.tanh(x) * np.tan(x)


# sigmoid函數
def logistic(x):
    return 1 / (1 + np.exp(-x))


# sigmoid函數的導數
def logistic_derivative(x):
    return logistic(x) * (1 - logistic(x))


class NeuralNetwork:
    def __init__(self, layers, activation='tanh'):
        """
        神經網路演算法構造函數
        :param layers: 神經元層數
        :param activation: 使用的函數(預設tanh函數)
        :return:none
        """
        if activation == 'logistic':
            self.activation = logistic
            self.activation_deriv = logistic_derivative
        elif activation == 'tanh':
            self.activation = tanh
            self.activation_deriv = tan_deriv

        # 權重列表
        self.weights = []
        # 初始化權重(隨機)
        for i in range(1, len(layers) - 1):
            self.weights.append((2 * np.random.random((layers[i - 1] + 1, layers[i] + 1)) - 1) * 0.25)
            self.weights.append((2 * np.random.random((layers[i] + 1, layers[i + 1])) - 1) * 0.25)

    def fit(self, X, y, learning_rate=0.2, epochs=10000):
        """
        訓練神經網路
        :param X: 數據集(通常是二維)
        :param y: 分類標記
        :param learning_rate: 學習率(預設0.2)
        :param epochs: 訓練次數(最大迴圈次數,預設10000)
        :return: none
        """
        # 確保數據集是二維的
        X = np.atleast_2d(X)

        temp = np.ones([X.shape[0], X.shape[1] + 1])
        temp[:, 0: -1] = X
        X = temp
        y = np.array(y)

        for k in range(epochs):
            # 隨機抽取X的一行
            i = np.random.randint(X.shape[0])
            # 用隨機抽取的這一組數據對神經網路更新
            a = [X[i]]
            # 正向更新
            for l in range(len(self.weights)):
                a.append(self.activation(np.dot(a[l], self.weights[l])))
            error = y[i] - a[-1]
            deltas = [error * self.activation_deriv(a[-1])]

            # 反向更新
            for l in range(len(a) - 2, 0, -1):
                deltas.append(deltas[-1].dot(self.weights[l].T) * self.activation_deriv(a[l]))
                deltas.reverse()
            for i in range(len(self.weights)):
                layer = np.atleast_2d(a[i])
                delta = np.atleast_2d(deltas[i])
                self.weights[i] += learning_rate * layer.T.dot(delta)

    def predict(self, x):
        x = np.array(x)
        temp = np.ones(x.shape[0] + 1)
        temp[0:-1] = x
        a = temp
        for l in range(0, len(self.weights)):
            a = self.activation(np.dot(a, self.weights[l]))
        return a

 

 

 

使用自己定義的神經網路演算法實現一些簡單的功能:

 小案例:

X:                  Y 0 0                 0 0 1                 1 1 0                 1 1 1                 0  
from NN.NeuralNetwork import NeuralNetwork
import numpy as np

nn = NeuralNetwork([2, 2, 1], 'tanh')
temp = [[0, 0], [0, 1], [1, 0], [1, 1]]
X = np.array(temp)
y = np.array([0, 1, 1, 0])
nn.fit(X, y)
for i in temp:
    print(i, nn.predict(i))

 

發現結果基本機制,無限接近0或者無限接近1

 

第二個例子:識別圖片中的數字

導入數據:

from sklearn.datasets import load_digits
import pylab as pl

digits = load_digits()
print(digits.data.shape)
pl.gray()
pl.matshow(digits.images[0])
pl.show()

 

觀察下:大小:(1797, 64)

數字0

 

接下來的代碼是識別它們:

import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer
from NN.NeuralNetwork import NeuralNetwork
from sklearn.cross_validation import train_test_split

# 載入數據集
digits = load_digits()
X = digits.data
y = digits.target
# 處理數據,使得數據處於0,1之間,滿足神經網路演算法的要求
X -= X.min()
X /= X.max()

# 層數:
# 輸出層10個數字
# 輸入層64因為圖片是8*8的,64像素
# 隱藏層假設100
nn = NeuralNetwork([64, 100, 10], 'logistic')
# 分隔訓練集和測試集
X_train, X_test, y_train, y_test = train_test_split(X, y)

# 轉化成sklearn需要的二維數據類型
labels_train = LabelBinarizer().fit_transform(y_train)
labels_test = LabelBinarizer().fit_transform(y_test)
print("start fitting")
# 訓練3000次
nn.fit(X_train, labels_train, epochs=3000)
predictions = []
for i in range(X_test.shape[0]):
    o = nn.predict(X_test[i])
    # np.argmax:第幾個數對應最大概率值
    predictions.append(np.argmax(o))

# 列印預測相關信息
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))

 

結果:

矩陣對角線代表預測正確的數量,發現正確率很多

 

這張表更直觀地顯示出預測正確率:

共450個案例,成功率94%

 


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 命令模式是我們能夠實現發送者和接收者之間的完全解耦,發送者是調用操作的對象,而接收者是接收請求並執行特定操作的對象。通過解耦,發送者無需瞭解接收者的介面。在這裡,請求的含義是需要被執行的命令。 作用 將一個請求封裝為一個對象,從而使你可用不同的請求對客戶進行參數化;對請求排隊或記錄請求日誌,以及支持 ...
  • 登錄鑒權:1. 用戶名+密碼 登錄請求2. 後臺接收登錄請求,生成ToKen(用戶名/密碼正確) 返回token3. 請求其他api 都帶上token,後臺校驗token是否存在/過期 後臺代碼如下:登錄/登出 @RestController@RequestMappingclass AuthCont ...
  • 基於spring boot 2.x + quartz 的CRUD任務管理系統,適用於中小項目。 基於spring boot +quartz 的CRUD任務管理系統: https://gitee.com/52itstyle/spring boot quartz 開發環境 JDK1.8、Maven、Ec ...
  • 學習目的: selenium目前版本已經到了3代目,你想加薪,就跟面試官扯這個,你贏了,工資就到位了,加上一個腳本的應用,結局你懂的 正式步驟 需求背景:抓取淘寶美食 Step1:流程分析 搜索關鍵字:利用selenium驅動瀏覽器搜索關鍵字,得到查詢後的商品列表 分析頁碼並翻頁:得到商品頁碼數,模 ...
  • 把系統分為各個功能不同的板塊,以電腦主機為例,高聚合就是指主板,cup等內的各種零件之間的緊密聯繫,松耦合就是指主板與cpu的連接,主板與顯卡的連接,主板與電源的連接。把顯卡,主板內的零件看作小號零件,這些小號零件組成了大號零件“顯卡”和“主板”,小號零件之間的連接相比大號零件之間的連接更緊密。緊密 ...
  • Spring的bean管理(註解) 註解 1.代碼裡面特殊標記(ep:@Test),使用註解完成一些相關功能 2.註解寫法 @註解名稱(屬性名稱=屬性值) 3.可以用在類,方法,屬性上都可以 4.Spring里替代部分配置文件,更方便 Spring註解開發的準備工作 導入 1.導入基本jar包 2. ...
  • 1.重置用戶信息 2.用戶登陸 ...
  • 二分法查找 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...