人工神經網路,支持任意數量隱藏層,多層隱藏層,python代碼分享

来源:http://www.cnblogs.com/bambipai/archive/2017/09/05/7474445.html
-Advertisement-
Play Games

程式以《集體智慧編程》第四章 “nn.py" 為原型和框架,可以建立多層隱藏層的感知器神經網路,並利用反向傳播法進行權值的修正。但該網路不能用於預測。 ...


  人工神經網路包含多種不同的神經網路,此處的代碼建立的是多層感知器網路,代碼以《集體智慧編程》第四章 “nn.py" 為原型和框架,可以指定隱藏網路的層數和每層的節點數,利用反向傳播法修正權值,並連接資料庫,保存每層每個節點的權值等信息。對網路進行訓練後,我們可以利用這個網路對輸入內容進行分類。代碼在演算法方面並沒有做出改進,結構上可能不是特別嚴謹和簡潔,只是為建立多層隱藏網路提供一個思路,可以對神經網路有更好的理解。

  新建一個文件(hiddens.py),併在其中新建一個類,取名為searchnet:

 1 from math import tanh
 2 import sqlite3 as sqlite
 3 import random
 4 class searchnet:
 5     def __init__(self,dbname,n,num):
 6       self.con=sqlite.connect(dbname)
 7       self.h=n#隱藏層的數量
 8       self.hiddennodes=num#每個隱藏層的節點數
11     def __del__(self):
12       self.con.close()
14     def maketables(self):
15       for i in range(self.h-1):
16         self.con.execute('create table hiddennode_%d(create_key,fromid,toid,strength)' % (i))
17       self.con.execute('create table wordhidden(fromid,toid,strength)')
18       self.con.execute('create table hiddenurl(fromid,toid,strength)')
19       self.con.commit()

其中,n和num分別是隱藏層的數量以及對應層數的節點數,然後我們建立了n-1張表存放隱藏層節點之間的權值,creat_key起到標示節點的作用,用以區別不同輸入形成的隱藏層節點,input和out分別是輸入內容和分類類別。

  接下來,我們來建立隱藏層以及節點之間的連接。

 1 def generatehiddennode(self,wordids,urls):
 2       #用以標示不同輸入產生的不同網路
 3       sorted_words=[id for id in wordids]
 4       sorted_words.sort()
 5       self.createkey='_'.join(sorted_words)
 6       
 7       #生成所有隱藏層節點並建立連接,creatkey標示了輸入的數據,每層每個節點的fromid和toid均不相同,代表了其層次和第幾個
 8       for i in range(self.h-1):
 9         for j in range(self.hiddennodes[i]):
10           for k in range(self.hiddennodes[i+1]):
11             table='hiddennode_%d' % i
12             fromid=str(i)+'_'+str(j)
13             toid=str(i+1)+'_'+str(k)
14             strn=random.random()
15             self.con.execute("insert into %s (create_key,fromid,toid,strength) values ('%s','%s','%s',%.2f)" % (table,self.createkey,fromid,toid,strn))
16    
17       #建立輸入和隱藏層的連接
18       table='wordhidden'
19       strength=0.1
20       for j in range(self.hiddennodes[0]):
21         hiddenid='0_'+str(j)
22         for wordid in wordids:
23           self.con.execute("insert into %s (fromid,toid,strength) values ('%s','%s',%f)" % (table,wordid,hiddenid,strength))
24           
25       #建立輸出和隱藏層的連接    
26       table='hiddenurl' 
27       strength=0.2
28       for j in range(self.hiddennodes[self.h-1]): 
29         hiddenid=str(self.h-1)+'_'+str(j)
30         for urlid in urls:
31             self.con.execute("insert into %s (fromid,toid,strength) values ('%s','%s',%f)" % (table,hiddenid,urlid,strength))
32       self.con.commit()

  首先連接輸入的內容作為標示不同輸入產生的不同隱藏層節點的標誌,然後迴圈建立隱藏層節點之間的連接(因為是隱藏層之間的連接,所以只需要n-1個表),除了create_key還有一個id來區分節點,即代碼中的fromid,x_y代表的是第x層隱藏層,第y個節點(x>=0,y>=0),隱藏層K層的toid就是K+1層的fromid,節點連接之間的權重在0-1之間隨機產生。然後建立第0層隱藏層和輸入層之間的連接,權值預設為0.1,第n-1層(最後一層)隱藏層和輸出層之間的連接,權值預設為0.2,並將這些信息存入表。

  我們可以運行看一下效果。

  

                

  另外我想說明的一點是,我們所建立的節點以及下麵要建立的網路都是抽象的,而資料庫中的表是具象的,但這並不是說表就是網路的具象化,它僅是存儲了網路中節點之間的連接,對於一個表中的fromid和toid來說,僅僅是一個名稱,並不代表真正抽象的節點,所以我們在建立隱藏層K層與K+1層節點之間的連接時,即便還並沒有生成及存儲K+1層的formid,我們仍然可以完成K層數據的生成與存儲,只要我們知道我們即將要生成的K+1層節點的名稱(fromid)即可。

  產生了隱藏層所有節點之後,可以開始建立網路了,利用資料庫中保存的信息,建立起包括所有當前權重值在內的相應網路。setupnetwork函數為searchnet類定義了多個實例變數,包括:輸入內容列表、隱藏層節點及分類分別,每個節點的數值輸出,節點之間的權重值(從資料庫中獲得)。

 1 def getallhiddenids(self,wordids,urlids):
 2       ll={}
 3       ll.setdefault(0,{})
 4       cur=self.con.execute("select toid from wordhidden where fromid='%s'" % wordids[0])
 5       for row in cur: ll[0].setdefault(row[0],1)
 6       res=row[0]
 7       for i in range(self.h-1):
 8         ll.setdefault(i+1,{})
 9         cur=self.con.execute("select toid from hiddennode_%d where create_key='%s' and fromid='%s' " % (i,res,self.createkey))
10         for row in cur: ll[i+1].setdefault(row[0],1)
11         res=row[0]
12       hn={}
13       for i in range(self.h):
14         node=sorted(ll[i].keys())
15         hn.setdefault(i,node)
16       return hn 
1 def getstrength(self,fromid,toid,layer):
2       if layer==-1: table='wordhidden'#-1層是輸入層
3       elif 0<=layer<self.h-1: table='hiddennode_%d' % layer
4       else: table='hiddenurl'
5       res=self.con.execute("select strength from %s fromid='%s' and toid='%s'" % (table,fromid,toid)).fetchone()
6       if res==None: 
7           if layer==-1: return -0.2
8           if 0<=layer<self.h: return 0
9       return res[0]

   在獲得隱藏層權值和節點時,要註意create_key這一項,如果沒有create_key這一項,儲存在同一張表中不同輸入所獲得的隱藏層節點id是相同的,那麼在獲得權值時就會產生錯誤,這是create_key這一項存在的重要意義。

  接下來建立網路,並構造前饋演算法,即演算法接受一系列輸入,按照神經網路的計算規則(y=Σ(x*w)),得到所有節點的輸出結果。

 1 def setupnetwork(self,wordids,urlids):
 2         # value lists
 3         self.wordids=wordids
 4         self.hns=self.getallhiddenids(wordids,urlids)#hns是個嵌套列表的字典
 5         self.urlids=urlids
 6  
 7         self.ah={}#ah是隱藏層的節點值,key是哪一層,value是節點值的列表
 8         self.w={}#w是權重值,key是該權重指向的哪一層,value是嵌套的列表,v[i][j]代表 i-j的weight
 9         # node outputs
10         self.ai = [1.0]*len(self.wordids)
11         for i in range(self.h):
12           self.ah.setdefault(i,[1.0]*len(self.hns[i]))
13         self.ao = [1.0]*len(self.urlids)
14         
15         # create weights matrix
16         self.w.setdefault(-1,[[self.getstrength(wordid,hiddenid,-1)for hiddenid in self.hns[0]] 
17                     for wordid in self.wordids])
18         for i in range(self.h-1):
19           self.w.setdefault(i,[[self.getstrength(fromid,toid,i)for toid in self.hns[i+1]] 
20                         for fromid in self.hns[i]])
21         self.w.setdefault('o',[[self.getstrength(hiddenid,urlid,1) for urlid in self.urlids] 
22                       for hiddenid in self.hns[self.h-1]])
setupnetwork

 

 1 def feedforward(self):
 2         # the only inputs are the query words
 3         for i in range(len(self.wordids)):
 4             self.ai[i] = 1.0
 5 
 6         # 首先利用輸入值得到第一層隱藏層的節點值
 7         for j in range(len(self.hns[0])):
 8             sum = 0.0
 9             for i in range(len(self.wordids)):
10                 sum = sum + self.ai[i] * self.w[-1][i][j]
11             self.ah[0][j] = tanh(sum)
12             
13         #然後迴圈得到其他隱藏層的節點值   
14         for i in range(1,self.h):
15           for j in range(len(self.hns[i])):
16             sum=0.0
17             for k in range(len(self.hns[i-1])):
18               sum = sum + self.ah[i-1][k] * self.w[i-1][k][j]
19             self.ah[i][j] = tanh(sum)
20 
21         # 最後得到輸出層的
22         for k in range(len(self.urlids)):
23             sum = 0.0
24             for j in range(len(self.hns[self.h-1])):
25                 sum = sum + self.ah[self.h-1][j] * self.w['o'][j][k]
26             self.ao[k] = tanh(sum)
27 
28         return self.ao[:]
feedforward

   

  因為初始化輸入值是相同的,因此兩個節點的輸出值也均為相同。

  下麵利用反向傳播法調整權值,反向傳播,即讓誤差沿著網路反向傳播,所得到的值便是權值的修正量。反向傳播法是經典的權值修正演算法,但此處對於演算法不做具體說明,需要瞭解的童鞋自行查閱。下麵程式中,N是學習率,tanh(y)=1-y*y,因此,y為0時,tanh最大,y為1時,tanh最小,tanh(output)與誤差相乘,乘權值,乘輸入作為權值的修正量,這是因為,我們在訓練時,會指定一個輸出節點的輸出目標為1(或接近1),這樣,如果輸出節點的輸出接近於1,tanh(output)很小,權值的修正量就小,反之,權值的修正量就大。

 1 def backPropagate(self, targets, N=0.5):
 2         # calculate errors for output
 3         output_deltas = {}#每一層每個結點的誤差=error*dtanh(out),out是正向傳播時相對於這一層的輸出
 4         change={}#權值的改變值
 5         
 6         output_deltas.setdefault('o',[])
 7         for k in range(len(self.urlids)):
 8             error = targets[k]-self.ao[k]
 9             output_deltas['o'].append(dtanh(self.ao[k]) * error)
10         
11         hid=self.h-1
12         output_deltas.setdefault(hid,[])
13         for j in range(len(self.hns[hid])):
14             error = 0.0
15             for k in range(len(self.urlids)):
16                 error = error + output_deltas['o'][k]*self.w['o'][j][k]
17             output_deltas[hid].append(dtanh(self.ah[hid][j]) * error)
18 
19         # calculate errors for hidden layer
20         for i in range(hid-1,-1,-1):
21           output_deltas.setdefault(i,[])
22           for j in range(len(self.hns[i])):
23             error = 0.0
24             for k in range(len(self.hns[i+1])):
25                 error = error + output_deltas[i+1][k]*self.w[i][j][k]
26             output_deltas[i].append(dtanh(self.ah[i][j]) * error)
27 
28         # update output weights
29         for j in range(len(self.hns[hid])):
30             for k in range(len(self.urlids)):
31                 change = output_deltas['o'][k]*self.ah[hid][j]
32                 self.w['o'][j][k] = self.w['o'][j][k] + N*change
33 
34         # update input weights
35         for i in range(hid-1,-1,-1):
36           for k in range(len(self.hns[i])):
37               for j in range(len(self.hns[i+1])):
38                   change = output_deltas[i+1][j]*self.ah[i][k]
39                   self.w[i][k][j] = self.w[i][k][j] + N*change
40                   
41         for j in range(len(self.wordids)):
42             for k in range(len(self.hns[0])):
43                 change = output_deltas[0][k]*self.ai[j]
44                 self.w[-1][j][k] = self.w[-1][j][k] + N*change

 

    

  我們看到,在經過5次的權值修正後,輸出結果相對於最初的輸出結果相對更接近於目標值。

  我們可以把修正後的權值更新到資料庫,在這裡同樣註意“creat_key".

 1 def setstrength(self,fromid,toid,layer,strength):
 2       if layer==-1: 
 3         table='wordhidden'
 4         res=self.con.execute("select rowid from %s where fromid='%s' and toid='%s'" % (table,fromid,toid)).fetchone()
 5       elif 0<=layer<self.h-1: 
 6         table='hiddennode_%d' % layer
 7         res=self.con.execute("select rowid from %s where create_key='%s' and fromid='%s' and toid='%s'" % (table,self.createkey,fromid,toid)).fetchone()
 8       else: 
 9         table='hiddenurl'
10         res=self.con.execute("select rowid from %s where fromid='%s' and toid='%s'" % (table,fromid,toid)).fetchone()
11       rowid=res[0]
12       self.con.execute('update %s set strength=%f where rowid=%d' % (table,strength,rowid))
setstrength
 1 def updatedatabase(self):
 2       # set them to database values
 3       for i in range(len(self.wordids)):
 4           for j in range(len(self.hns[0])):
 5               self.setstrength(self.wordids[i],self.hns[0][j],-1,self.w[-1][i][j])
 6       for k in range(self.h-1):
 7         for i in range(len(self.hns[k])):
 8           for j in range(len(self.hns[k+1])):
 9             self.setstrength(self.hns[k][i],self.hns[k+1][j],k,self.w[k][i][j])
10       for i in range(len(self.hns[self.h-1])):
11         for j in range(len(self.urlids)):
12             self.setstrength(self.hns[self.h-1][i],self.urlids[j],self.h,self.w['o'][i][j])
13       self.con.commit()

   到此,整個網路的建立、訓練代碼就完成了,但是這個網路只是對輸入過的內容分類效果較好,並不能進行預測。

  以上代碼適用於Python3.4,python2.x需要稍作改變

 1 1 def feedforward(self):
 2  2         # the only inputs are the query words
 3  3         for i in range(len(self.wordids)):
 4  4             self.ai[i] = 1.0
 5  6         # 首先利用輸入值得到第一層隱藏層的節點值
 6  7         for j in range(len(self.hns[0])):
 7  8             sum = 0.0
 8  9             for i in range(len(self.wordids)):
 9 10                 sum = sum + self.ai[i] * self.w[-1][i][j]
10 11             self.ah[0][j] = tanh(sum)
11 13         #然後迴圈得到其他隱藏層的節點值   
12 14         for i in range(1,self.h):
13 15           for j in range(len(self.hns[i])):
14 16             sum=0.0
15 17             for k in range(len(self.hns[i-1])):
16 18               sum = sum + self.ah[i-1][k] * self.w[i-1][k][j]
17 19             self.ah[i][j] = tanh(sum)
18 21         # 最後得到輸出層的
19 22         for k in range(len(self.urlids)):
20 23             sum = 0.0
21 24             for j in range(len(self.hns[self.h-1])):
22 25                 sum = sum + self.ah[self.h-1][j] * self.w['o'][j][k]
23 26             self.ao[k] = tanh(sum)
24 28         retur

您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • 這篇文章是我之前寫的博文 資料庫方面的面試技巧,如何從建表方面展示自己能力 和 面試技巧,如何通過索引說資料庫優化能力,內容來自Java web輕量級開發麵試教程是一個系列的,通過面試官的視角和大家分享在資料庫方面的面試經驗,這些內容都來摘自 java web輕量級開發麵試教程。 之前的兩篇文章點擊 ...
  • 自然語言處理在很多APP中都有實際應用的場景,比如在電商軟體中,客服問答系統、評論情感分析、帶有語義識別的搜索、商品自動分類、用戶畫像等等。那麼本篇作為自然語言處理淺學的第一篇,就著重來講一下背景知識。 背景知識 自然語言處理,英文是natural language process, NLP,說白了 ...
  • Mybatis的 mapper.xml 中 update 語句使用 if 標簽判斷對像屬性是否為空值。 UserDTO是傳過來參數的類型,userDTO是在mapperDao介面中給更新方法的參數起的別名。 mapperDao.java <update id="updata" parameterTy ...
  • 測試場景下,使用的oralce遇到表空間的占用超大,可以採用如下的方式進行空間的清理 首先使用sqlplus連接資料庫sqlplus sys/password@orcl as sysdba 之類進行資料庫的連接沒然後進行如下的操作 ##創建表空間對於自己的測試庫和表等最好都建立自己的表空間,以方便清 ...
  • 一、expdp/impdp和exp/imp的區別 1、exp和imp是客戶端工具程式,它們既可以在客戶端使用,也可以在服務端使用。 2、expdp和impdp是服務端的工具程式,他們只能在oracle服務端使用,不能在客戶端使用。 3、imp只適用於exp導出的文件,不適用於expdp導出文件;im ...
  • Pentaho Data Integration (Kettle) 一套基於Java的開源ETL工具集,是商務智能套件Pentaho的一部分。 社區主頁:http://community.pentaho.com/projects/data-integration 幫助文檔:https://help. ...
  • 欄位為id,存取的值有1,1.0,1.09。。。。。現在統一取出結果為1.0 select left(cast(id as DECIMAL(18,6) ) , charindex('.',cast(id as DECIMAL(18,6) )) + 1) ...
  • 資料庫 一個mongodb中可以建立多個資料庫。 MongoDB的預設資料庫為"db",該資料庫存儲在data目錄中。 MongoDB的單個實例可以容納多個獨立的資料庫,每一個都有自己的集合和許可權,不同的資料庫也放置在不同的文件中。 "show dbs" 命令可以顯示所有數據的列表。 執行 "db" ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...