原文地址: https://zhuanlan.zhihu.com/p/23309693 https://zhuanlan.zhihu.com/p/23293860 CTC:前向計算例子 這裡我們直接使用warp-ctc中的變數進行分析。我們定義T為RNN輸出的結果的維數,這個問題的最終輸出維度為al ...
原文地址:
https://zhuanlan.zhihu.com/p/23309693
https://zhuanlan.zhihu.com/p/23293860
CTC:前向計算例子
這裡我們直接使用warp-ctc中的變數進行分析。我們定義T為RNN輸出的結果的維數,這個問題的最終輸出維度為alphabet_size。而ground_truth的維數為L。也就是說,RNN輸出的結果為alphabet_size*T的結果,我們要將這個結果和1*L這個向量進行對比,求出最終的Loss。
我們要一步一步地揭開這個演算法的細節……當然這個演算法的實現代碼有點晦澀……
我們的第一步要順著test_cpu.cpp的路線來分析代碼。第一步我們就是要解析small_test()中的內容。也就是做前向計算,計算對於RNN結果來說,對應最終的ground_truth——t的label的概率。
這個計算過程可以用動態規劃的演算法求解。我們可以用一個變數來表示動態規劃的中間過程,它就是:
:表示在RNN計算的時間T時刻,這一時刻對應的ground_truth的label為第i個下標的值t[i]的概率。
這樣的表示有點抽象,我們用一個實際的例子來講解:
RNN結果:,這裡的每一個變數都對應一個列向量。
ground_truth:
那麼表示的結果對應著的概率,當然與此同時,前面的結果也都合理地對應完成。
從上面的結果我們可以看出,如果的結果對應著,那麼的結果也必然對應著。所以前面的結果是確定的。然而對於其他的一些情況來說,我們的轉換存在著一定的不確定性。
CTC:前向計算具體過程
我們還是按照上面的例子進行計算,我們把剛纔的例子搬過來:
RNN結果:,這裡的每一個變數都對應一個列向量。
ground_truth:
alphabet:
按照上面介紹的計算方法,第一步我們先做ground_truth的狀態擴展,於是我們就把長度從3擴展到了7,現在的ground_truth變成了:
我們的RNN結果長度為4,也就是說我們會從上面的7個ground_truth狀態中進行轉移,並最終轉移到最終狀態。理論上利用動態規劃的演算法,我們需要計算4*7=28個中間結果。好了,下麵我們用表示RNN的第T時刻狀態為ground_truth中是第i個位置的概率。
那麼我們就開始計算了:
T=1時,我們只能選擇和blank,所以這一輪我們終結狀態只可能落在0和1上。所以第一輪變成了:
T=2時,我們可以繼續選擇,我們同時也可以選擇,還可以選擇和之間的blank,所以我們可以進一步關註這三個位置的概率,於是我們將其他的位置的概率設為0。
T=3時,留給我們的時間已經不多了,我們還剩2步,要走完整個旅程,我們只能選擇,以及它們之間的空格。於是乎我們關心的位置又發生了變化:
是不是有點看暈了?沒關係,因為還剩最後一步了。下麵是最後一步,因為最後一步我們必須要到以及它後面的空格了,所以我們的概率最終計算也就變成了:
好吧,最終的結果我們求出來了,實際上這就是通過時間的推移不斷迭代求解出來的。關於迭代求解的公式這裡就不再贅述了。我們直接來看一張圖:
於是乎我們從這個計算過程中發現一些問題:
首先是一個相對簡單的問題,我們看到在計算過程中我們發現了大量的連乘。由於每一個數字都是浮點數,那麼這樣連乘下去,最終數字有可能非常小而導致underflow。所以我們要將這個計算過程轉到對數域上。這樣我們就將其中的乘法轉變成了加法。但是原本就是加法的計算呢?比方說我們現在計算了loga和logb,我們如何計算log(a+b)呢,這裡老司機給出瞭解決方案,我們假設兩個數中a>b,那麼有
這樣我們就利用了loga和logb計算出了log(a+b)來。
另外一個問題就是,我們發現在剛纔的計算過程當中,對於每一個時間段,我們實際上並不需要計算每一個ground-truth位置的概率信息,實際上只要計算滿足某個條件的某一部分就可以了。所以我們有沒有希望在計算前就規劃好這條路經,以保證我們只計算最相關的那些值呢?
如何控制計算的數量?
不得不說,這一部分warp-ctc寫得實在有點晦澀,當然也可能是我在這方面的理解比較渣。我們這裡主要關註兩個部分——一個是數據的準備,一個是最終的數據的使用。
在介紹數據準備之前,我們先簡單說一下這部分計算的大概思路。我們用兩個變數start和end表示我們需要計算的狀態的起止點,在每一個時間點,我們要更新start和end這兩個變數。然後我們更新start和end之間的概率信息。這裡我們先要考慮一個問題,start和end的更新有什麼規律?
為了簡化思考,我們先假設ground_truth中沒有重覆的label,我們的大腦瞬間得到瞭解放。好了,下麵我們就要給出代碼中的兩個變數——
T:表示RNN結果中的維度
S/2:ground_truth的維度(S表示了擴展blank之後的維度)
基本上具備一點常識,我們就可以知道T>=S/2。什麼?你覺得有可能出現T<S/2的情況?兄弟,這種見鬼的事情如果發生,你難道要我們把RNN的結果拆開給你用?臣妾不太能做得到啊……
好了,既然接受了上面的事實,那麼我們就來舉幾個例子看看:
我們假設T=3,S/2=3,那麼說白了,它們之間的對應關係是一一對應,說白了這就和blank位置沒啥關係了。在T=1時,我們要轉移到第一個結果,T=2,我們要轉移到第二個結果……
如何控制計算的數量?cont.
好,廢話少說我們書接上回。不明真相的小朋友先看這個:
下麵我們假設T=4,S/2=3,好玩的地方來了。T比S/2多一個,也就是說我們允許冗餘出現了,那麼我們可能的形式也就變多了。我們可以增加一個blank,我們也可以在沒有label位置原地打一輪醬油。選擇更多,歡樂更多。
雖然選擇變多,但是著並不意味著我們可以選擇任意一種狀態轉移的方式,至少:
- 在T=2時,我們至少要轉移到第一個結果
- 在T=3時,我們至少要轉移到第二個結果
- 在T=4時,兄弟我們準備下車了
這其實就是對start的限制。源代碼中有這樣一句話:
int remain = (S / 2) + repeats - (T - t);
這裡我們先忽略repeats,那麼remain這個變數其實是在計算label數量和剩餘時間的差。如果用這樣的語言來表達剛纔的那個問題,我們語言就變成這個樣子:
- 當時間還剩4輪時(包括第4輪),我們在哪都無所謂(實際上是從T=1開始計算的)
- 當時間還剩3輪時(包括第3輪),我們至少要轉移到第一個結果(index=1)
- 當時間還剩2輪時(包括第2輪),我們至少要轉移到第二個結果(index=3)
- 當時間還剩1輪時(包括第1輪),我們至少要轉移到第三個結果(index=5)
好了,這裡我們看出其中的含義了。我們再啰嗦一下,看看這些變數隨T的變化情況:
- T=1,remain=0,start+=1
- T=2,remain=1,start+=2
- T=3,remain=2,start+=2
現在我們已經十分清楚了,當remain>=0時,start都要向前走,限制我們計算前面狀態的概率,因為這些概率已經沒有意義了。下麵的代碼也是這樣描述的:
if(remain >= 0)
start += s_inc[remain];
那麼這個s_inc是什麼東西?它就是我們需要提前準備好的計算量。我們知道經過擴充的label序列中,所有的非空label都處在奇數的index上,而填充的blank都處在偶數的index上(我們是0-based的計算方法,matlab選手請退散……),所以對於上面的問題,當start=0時,下一步我們會從0跳到1,此後我們會從1到3,3到5,跳轉的步數都是2,所以基於這個思路,我們就可以把s_inc這個數組生成出來。當然,我們的前提是沒有重覆。下麵我們會說重覆的問題的。
我們上面說了這麼多,重點把start的變化介紹清楚了。下麵我們來看看end。其實end的原理也類似,我們還是用剛纔的廢話套路來介紹站在end視角的世界:
- 在T=1時,我們最多能到第一個結果
- 在T=2時,我們最多能轉移到第二個結果
- 在T=3時,我們最多能轉移到第三個結果
- 在T=4時,我們已經掌握了整個世界……oh yeah
好了,可以看出end的變化形式,每個時刻end都可以+2,直到到達最後一個非blank的label,end變成了+1,然後end就不用動了,等著start動就可以了……(怎麼感覺有點污?天哪……)
那麼end變化的條件是什麼呢?
if(t <= (S / 2) + repeats)
end += e_inc[t - 1];
我們還是忽略repeats,那麼就十分清楚了,如果當前時刻小於等於label數,那麼儘管前進,如果大於了,基本上也就到頭了,這時候end就不用動了。
好了,前面我們終於說完了簡單模式下start和end的移動規律,下麵我們來看看帶重覆模式下的變化方法。
重覆,重覆
重覆會帶來什麼樣的變化呢?說白瞭如果有重覆的label出現,那麼兩個連續重覆的label中間就要至少出現一個blank。換句話說,每出現一個重覆,我們的S/2就要加一,於是我們再看一眼這兩個計算公式:
int remain = (S / 2) + repeats - (T - t);
if(remain >= 0)
start += s_inc[remain];
if(t <= (S / 2) + repeats)
end += e_inc[t - 1];
我們把repeats和S/2歸到一起,這時候就能看明白了。
同理,在計算s_inc和e_inc的時候,由於有repeats的存在,它們從過去的+2變成了兩個+1。也就是說先從label跳到blank,再跳到下一個label。這樣就可以解釋s_inc和e_inc的初始化策略了:
int e_counter = 0;
int s_counter = 0;
s_inc[s_counter++] = 1;
int repeats = 0;
for (int i = 1; i < L; ++i) {
if (labels[i-1] == labels[i]) {
s_inc[s_counter++] = 1;
s_inc[s_counter++] = 1;
e_inc[e_counter++] = 1;
e_inc[e_counter++] = 1;
++repeats;
}
else {
s_inc[s_counter++] = 2;
e_inc[e_counter++] = 2;
}
}
e_inc[e_counter++] = 1;
好了,到此我們才算把CTC中compute ctc loss這部分介紹完了。教科書上的一個公式看著簡單,落實到代碼就似乎充滿了trick。希望看懂了這個計算的你大腦沒有陣亡。