ShuffleNet總結

来源:http://www.cnblogs.com/heguanyou/archive/2017/12/22/8087422.html
-Advertisement-
Play Games

在2017年末,Face++發了一篇論文[ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices ](https://arxiv.org/abs/1707.01083)討論了一個極有效率且可... ...


在2017年末,Face++發了一篇論文ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices討論了一個極有效率且可以運行在手機等移動設備上的網路結構——ShuffleNet。這個英文名我更願意翻譯成“重組通道網路”,ShuffleNet通過分組捲積與\(1 \times 1\)的捲積核來降低計算量,通過重組通道來豐富各個通道的信息。這個論文的mxnet源碼的開源地址為:MXShuffleNet

分組捲積與核大小對計算量的影響

論文說中到“We propose using pointwise group convolutions to reduce computation complexity of 1 × 1 convolutions”,那麼為什麼用分組捲積與小的捲積核會減少計算的複雜度呢?先來看看捲積在編程中是如何實現的,Caffe與mxnet的CPU版本都是用差不多的方法實現的,但Caffe的計算代碼會更加簡潔。

不分組且只有一個樣本

在不分組與輸入的樣本量為1(batch_size=1)的條件下,輸出一個通道上的一個點是捲積核會與所有的通道捲積之積,如圖1所示:

conv1

圖1 輸入層(第一層)只有一個通道,那個第二層一個通道上的點是第一層通道相應區域與相應捲積核的捲積,第三層一個通道上的點是與第二層所有通道上相應區域與相應捲積核的捲積,而且對於輸出通道每個輸入通道對應的捲積核是不一樣的,不同的輸出通道也有不同的捲積核,所以說捲積核的參數量是$C_{out} \times C_{in} \times H_k \times W_k$

在Caffe的計算方法中,先要將輸入張量為\(n \times C_{in} \times H_{in} \times W_{in}\)(n是batch_size)轉化為一個$ \left(C_{in} \times H_k \times H_w\right) \times \left(H_{in} \times W_{in}\right)\(的矩陣,這個過程叫**im2col**。最後得到的輸出張量為\)n \times C_{out} \times H_{in} \times W_{in}$。

conv2

圖2 輸入的所有通道按捲積核的大小提取出來排列成一行,要註意的是在這隻是示意圖,在實際的程式中,一般會排成一列,因為在防問數據時會一個通道一個通道地訪問的。輸出一個點要輸出的數據有$C_{in} \times K_h \times K_w$個。

conv3

圖3 輸出一個通道就有$H_{out} \times W_{out}$個點,而且在程式中同一個通道(圖中的同一個顏色)的內容是按行排列的,所以說轉換出來的的矩陣是圖中$ \left(C_{in} \times H_k \times H_w\right) \times \left(H_{out} \times W_{out}\right)$矩陣的轉置。

conv4

圖4 同樣捲積核(Filter)也要Reshape成$C_{out} \times \left( C_{in} \times K_h \times K_w \right)$ 的矩陣

得到的兩個矩陣Feature與Filter相乘得到輸出矩陣Output,再Reshape成\(C_{out} \times C_{in} \times H_k \times W_k\)張量:
\[ Filter_{C_{out} \times \left( C_{in} \times K_h \times K_w \right)} \times Feature_{\left(C_{in} \times H_k \times H_w\right) \times \left(H_{out} \times W_{out}\right)} = Output_{C_{out} \times (H_{out} \times W_{out})} \tag{1.1} \]

現在的計算技術中,對方長度為\(n\)的方陣,計算量能從\(n^3\)代碼到\(n^{2.376}\),最小的複雜度現在仍然未知,本文為了方便計算量就以\(n^3\)為基準。所以式(1.1)的矩陣計算最普通的計算量\(Computation\)是:
\[ Computation=C_{out} \times H_{out} \times W_{out} \times \left( C_{in} \times K_h \times K_w \right)^2 \tag{1.2} \]
從式(1.2)中可以看出來,捲積核的大小對計算量影響是很大的,\(3 \times 3\)的捲積核比\(1 \times 1\)的計算量要大\(3^4=81\)倍。

分組且只有一個樣本

什麼叫做分組,就是將輸入與輸出的通道分成幾組,比如輸出與輸入的通道數都是4個且分成2組,那第1、2通道的輸出只使用第1、2通道的輸入,同樣那第3、4通道的輸出只使用第1、2通道的輸入。也就是說,不同組的輸出與輸入沒有關係了,減少聯繫必然會使計算量減小,但同時也會導致信息的丟失

當分成g組後,一層參數量的大小由\(Filter_{C_{out} \times \left( C_{in} \times K_h \times K_w \right)}\)變成\(Filter_{C_{out} \times \left( C_{in} \times K_h \times K_w / g \right)}\)。Feature Matrix的大小雖然沒發生變化,但是每一組的使用量是原來的$1/g,Filter也只用到所有參數的\(1/g\)\(。然後再迴圈計算\)g$次(同時FeatureMatrix與FilterMatrix要有地址偏移),那麼計算公式與計算量的大小為:
\[ Filter_{C_{out}/g \times \left( C_{in} \times K_h \times K_w /g \right)} \times Feature_{\left(C_{in} \times H_k \times H_w /g\right) \times \left(H_{out} \times W_{out}\right)} = Output_{C_{out}/g \times (H_{out} \times W_{out})} \tag{1.3} \]
\[ Computation=C_{out} \times H_{out} \times W_{out} \times \left( C_{in} \times K_h \times K_w /g \right)^2 \tag{1.4} \]

所以,分成\(g\)組可以使參數量變成原來的\(1/g\),計算量是原來的\(1/g^2\)

多個樣本輸入

為了節省記憶體,多個樣本輸入的時候,上述的所有過程都不會改變,而是每一個樣本都運行一次上述的過程。

以上只是最簡單、粗略的分析,實際上計算效率的提升並不會有上述這麼多,一方面因為im2col會消耗與矩陣運算差不多的時間,另一方面因為現代的blas庫優化了矩陣運算,複雜度並沒有上述分析的那麼多,還有計算過程for迴圈是比較耗時的指令,即使用openmp也不能優化捲積的計算過程。

交換通道(Shuffle Channels)

在上面我提到過,分組會導致信息的丟失,那麼有沒有辦法來解決這個問題呢?這個論文給出的方法就是交換通道,因為在同一組中不同的通道蘊含的信息可能是相同的,如果在不同的組之後交換一些通道,那麼就能交換信息,使得各個組的信息更豐富,能提取到的特征自然就更多,這樣是有利於得到更好的結果。

shufleChannels

圖5 分組交換通道的示意圖,a)是不交換通道但是分成3組了,要吧看到,不同的組是完全獨立的;b)每組內又分成3組,不分別交換到其它組中,這樣信息就發生了交換,c)這個是與b)是等價的。

ShuffleUnit

ShuffleUnit的設計參考了ResNet,總有兩個基本單元,兩人個基本單元功能不一樣,將他們組合起來就可以得到ShuffleNet。這樣的設計可以在增加網路的深度(比mobilenet深約一倍)的同時,減少參數總量與計算量(本人運行Cifar10時,速度大約是molibenet的10倍)。

shufleunit

圖6 b)與c)是兩人個ShuffleNet的基本單元,這兩個單元是參考了a)的設計,單元b)輸出與輸入的Shape一致,只是豐富了每個通道的信息,單元c)增加了一倍的通道數且輸出的$H_{out}$、$W_{out}$ 比$H_{in}$、$W_{in}$減少了一倍

源碼解讀

def combine(residual, data, combine):
    if combine == 'add':
        return residual + data
    elif combine == 'concat':
        return mx.sym.concat(residual, data, dim=1)
    return None

add是代表圖6中的單元b),concat是代表圖6中的單元c)。

def channel_shuffle(data, groups):
    data = mx.sym.reshape(data, shape=(0, -4, groups, -1, -2))
    data = mx.sym.swapaxes(data, 1, 2)
    data = mx.sym.reshape(data, shape=(0, -3, -2))
    return data

這個函數就是交換通道的函數,函數的第一行data = mx.sym.reshape(data, shape=(0, -4, groups, -1, -2))是將輸入為\(n \times C_{in} \times H_{in} \times W_{in}\)reshape成\(n \times (C_{in}/g) \times g\times H_{in} \times W_{in}\),要註意的是mxnet中reshape不會改變張量在記憶體中的排列順序。至於要mxnet中的0,-1,-2,-3,-4的具體意義可以這樣看到:

import mxnet as mx
print(help(mx.sym.reshape))

可以看到輸出以下(只提取出一小部分,其餘的可用上述方法查看),這裡有各個參數的具體意義:

- ``0``  copy this dimension from the input to the output shape.
- ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions
- ``-2`` copy all/remainder of the input dimensions to the output shape.
- ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.
- ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1).

函數的第二行是交換第一與第二個維度,那麼現在這個symbol的符號的shape就變成了\(n \times g \times (C_{in}/g) \times H_{in} \times W_{in}\)。這裡的第零個維度是\(n\)要註意的是交換維度改變了張量在記憶體中的排列順序,改變了記憶體中的順序實現上就是完成了圖5c)中的Channel Shuffle操作,不同的顏色代碼數據在原來記憶體中的位置。
函數的最後一行合併了第一與第二個維度,輸出的張量與輸入的張量shape都是\(n \times C_{in} \times H_{in} \times W_{in}\)

def shuffleUnit(residual, in_channels, out_channels, combine_type, groups=3, grouped_conv=True):

    if combine_type == 'add':
        DWConv_stride = 1
    elif combine_type == 'concat':
        DWConv_stride = 2
        out_channels -= in_channels

    first_groups = groups if grouped_conv else 1

    bottleneck_channels = out_channels // 4

    data = mx.sym.Convolution(data=residual, num_filter=bottleneck_channels, 
                      kernel=(1, 1), stride=(1, 1), num_group=first_groups)
    data = mx.sym.BatchNorm(data=data)
    data = mx.sym.Activation(data=data, act_type='relu')

    data = channel_shuffle(data, groups)

    data = mx.sym.Convolution(data=data, num_filter=bottleneck_channels, kernel=(3, 3), 
                       pad=(1, 1), stride=(DWConv_stride, DWConv_stride), num_group=groups)
    data = mx.sym.BatchNorm(data=data)

    data = mx.sym.Convolution(data=data, num_filter=out_channels, 
                       kernel=(1, 1), stride=(1, 1), num_group=groups)
    data = mx.sym.BatchNorm(data=data)

    if combine_type == 'concat':
        residual = mx.sym.Pooling(data=residual, kernel=(3, 3), pool_type='avg', 
                              stride=(2, 2), pad=(1, 1))

    data = combine(residual, data, combine_type)

    return data

ShuffleUnit這個函數實現上是實現圖6的b)與c),add對應成b),comcat對應於c)。

def make_stage(data, stage, groups=3):
    stage_repeats = [3, 7, 3]

    grouped_conv = stage > 2

    if groups == 1:
        out_channels = [-1, 24, 144, 288, 567]
    elif groups == 2:
        out_channels = [-1, 24, 200, 400, 800]
    elif groups == 3:
        out_channels = [-1, 24, 240, 480, 960]
    elif groups == 4:
        out_channels = [-1, 24, 272, 544, 1088]
    elif groups == 8:
        out_channels = [-1, 24, 384, 768, 1536]
       
    data = shuffleUnit(data, out_channels[stage - 1], out_channels[stage], 
                       'concat', groups, grouped_conv)

    for i in range(stage_repeats[stage - 2]):
        data = shuffleUnit(data, out_channels[stage], out_channels[stage], 
                           'add', groups, True)

    return data

def get_shufflenet(num_classes=10):
    data = mx.sym.var('data')
    data = mx.sym.Convolution(data=data, num_filter=24, 
                              kernel=(3, 3), stride=(2, 2), pad=(1, 1))
    data = mx.sym.Pooling(data=data, kernel=(3, 3), pool_type='max', 
                          stride=(2, 2), pad=(1, 1))
    
    data = make_stage(data, 2)
    
    data = make_stage(data, 3)
    
    data = make_stage(data, 4)
     
    data = mx.sym.Pooling(data=data, kernel=(1, 1), global_pool=True, pool_type='avg')
    
    data = mx.sym.flatten(data=data)
    
    data = mx.sym.FullyConnected(data=data, num_hidden=num_classes)
    
    out = mx.sym.SoftmaxOutput(data=data, name='softmax')

    return out

這兩個函數可以直接得到作者在論文中的表:

table

圖7

結果比較

論文後面用了種實驗證明這兩個技術的有效性,且證實了ShuffleNet的優秀,這裡就不細說,看論文後面的表就能一目瞭然。


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • Supervisor是用Python開發的一套client/server架構的進程管理程式,能做到開機啟動,以daemon進程的方式運行程式,並可以監控進程狀態等等。 linux進程管理方式有傳統的rc.d、新興的upstart、systemd等,與這些相比,Supervisor有著自己的特點。 便 ...
  • 如題,具體效果及代碼如下: 1 print("你好你叫什麼名字?") 2 print("我叫張三") 3 print("你好,張三,我叫李四") 4 5 print("你好你叫什麼名字?",end="") 6 print("我叫張三!",end="") 7 print("你好,張三,我叫李四",en ...
  • 版本:3.4.10 問題:啟動 zkServer.cmd時報錯如下, 解決辦法: bin目錄下 zkEnv.cmd配置 修改為: 把雙引號放在外面。 此時運行zkServer.cmd 用zkCli.cmd連接成功: 成功! ...
  • 使用While迴圈時經常會犯的一些小錯誤。以及猜年齡程式的2種編寫方式。 ...
  • 為什麼要用插件 主要還是自動化的考慮,如果額外使用Dockerfile進行鏡像生成,可能會需要自己手動指定jar/war位置,並且打包和生成鏡像間不同步,帶來很多瑣碎的工作。 插件選擇 使用比較多的是spotify的插件:https://github.com/spotify/docker maven ...
  • 關於本文說明,本人原博客地址位於http://blog.csdn.net/qq_37608890,本文來自筆者於2017年12月06日 18:06:30所撰寫內容(http://blog.csdn.net/qq_37608890/article/details/78731169)。 本文根據最近學習 ...
  • Trie樹與AC自動機 作為現階段的學習中個人應有的常識,AC自動機形象的來講就是在Trie樹上跑的一個KMP。由此,我們就先來談一談Trie樹。(有圖) 1. Trie樹 又稱單詞查找樹,字典樹,一般用於字元串的匹配。它利用公共的字元串首碼進行查詢,減少了無謂的操作,是空間換時間的經典演算法。舉例: ...
  • 本文秉承著 你看不懂是你sb,我寫的代碼就要牛逼 的理念來介紹一些js的裝逼技巧。 下麵的技巧,後三個,請謹慎用於團隊項目中(主要考慮到可讀性的問題),不然,leader 乾你沒商量。 [圖片上傳失敗...(image-922e98-1513315809572)] image.png Boolean ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...