在筆者對(duì)于對(duì)比學(xué)習(xí)的認(rèn)識(shí)中,主要有2個(gè)維度的事情需要考慮:
如何選取合適的負(fù)樣本如何選取合適的損失函數(shù)以下結(jié)合一些訓(xùn)練經(jīng)驗(yàn),簡(jiǎn)要筆記下。
如何構(gòu)造合適的負(fù)樣本之前在[1]中簡(jiǎn)單介紹過一些構(gòu)造負(fù)樣本的方法,總體來說,基于用戶行為數(shù)據(jù)我們可以通過batch negative和無點(diǎn)數(shù)據(jù)進(jìn)行負(fù)樣本構(gòu)建。
batch negative在搜索過程中,用戶行為存在很大的隨機(jī)性,比如有展現(xiàn)但沒有點(diǎn)擊的數(shù)據(jù)并不一定就是負(fù)樣本,為了獲取更可靠的用戶數(shù)據(jù),我們可以選擇在用戶點(diǎn)擊過的Doc之間組成負(fù)樣本。沒錯(cuò),我們認(rèn)為用戶點(diǎn)擊過的行為是更為可靠的,雖然即便是點(diǎn)擊行為也可能只是因?yàn)橛脩舻暮闷嫘袨榛蛘哒`操作等等,但是對(duì)比于無點(diǎn)行為總歸是更為可靠的。假設(shè)用戶的query i ii和點(diǎn)擊過的Doc組成二元組,其中的C \mathcal{C}C表示所有有點(diǎn)行為的集合,那么我們認(rèn)為其負(fù)樣本就是
。當(dāng)數(shù)據(jù)足夠龐大是,有點(diǎn)數(shù)據(jù)
的規(guī)模也會(huì)非常龐大,我們無法一次將所有負(fù)樣本都列舉出來(同時(shí),也沒有必要),我們通常會(huì)在一個(gè)batch內(nèi)對(duì)所有用戶點(diǎn)擊二元組進(jìn)行組合。也即是將
的規(guī)模限制在一個(gè)batch內(nèi),如Fig 1.1所示,其中的對(duì)角線都是二元組正樣本,而其他元素都是負(fù)樣本。通過一個(gè)矩陣乘法,我們就可以實(shí)現(xiàn)這個(gè)操作。如式子(1.1)所示。
Fig 1.1 Batch Negative的方式從一個(gè)batch中構(gòu)造負(fù)樣本。
無點(diǎn)擊樣本無點(diǎn)數(shù)據(jù)也不是一無是處,在某些搜索產(chǎn)品中,如果排序到前面的結(jié)果本身就不夠好,那么用戶的點(diǎn)擊數(shù)據(jù)和無點(diǎn)擊數(shù)據(jù)就具有足夠的區(qū)分度,無點(diǎn)數(shù)據(jù)拿來視為負(fù)樣本就是合理的,這個(gè)和產(chǎn)品具體的設(shè)計(jì),或者呈現(xiàn)UI形式等等有關(guān),需要在實(shí)踐中才能實(shí)驗(yàn)出來。
在實(shí)踐中,通常還會(huì)去進(jìn)行batch negative和無點(diǎn)擊數(shù)據(jù)的混合以達(dá)到獲取足夠多的負(fù)樣本的目的。
使用何種損失函數(shù)常用在對(duì)比學(xué)習(xí)中的損失函數(shù)主要有兩種,hinge loss[2]和交叉熵?fù)p失。其中的hinge loss形式如(2.1)所示:
hinge loss和SVM一樣[3],存在一個(gè)margin,一旦正樣本和負(fù)樣本打分的差距超過這個(gè)margin,那么損失就變?yōu)?,通過這種手段可以讓hinge loss學(xué)習(xí)到正樣本和負(fù)樣本之間的表征區(qū)別,而且又可以更好地控制訓(xùn)練過程。而交叉熵?fù)p失是我們的老朋友了,如式子(2.2)所示
其中的N為樣本數(shù)量,M為分類類別數(shù)量,而則是預(yù)測(cè)的logit經(jīng)過softmax之后的概率分布。注意到,正如Fig 1.1所示,對(duì)于每個(gè)
而言,其每一行都有
個(gè)負(fù)樣本;對(duì)于每個(gè)
而言,其每一列都有
個(gè)負(fù)樣本,那么就可以組織雙向的損失函數(shù)計(jì)算。這種方式對(duì)于雙塔模型結(jié)構(gòu)來說特別地“劃算”,因?yàn)閷?duì)于雙塔模型而言只需要計(jì)算一次矩陣計(jì)算就可以得到
的打分矩陣,然后通過雙向計(jì)算損失,可以實(shí)現(xiàn)更高效地對(duì)模型的訓(xùn)練。
在hinge loss計(jì)算過程中,還可以通過在這每一行(或者每一列)的N − 1個(gè)負(fù)樣本中選擇一個(gè)最難的負(fù)樣本,也就是打分最高的負(fù)樣本。這一點(diǎn)很容易理解,負(fù)樣本的打分如果打得很高,那么就可以認(rèn)為模型很大程度地將其誤認(rèn)為正樣本了,如果能將其分開,那么模型的表征能力應(yīng)該是更上一層樓的,因此將最難負(fù)樣本作為式子(2.1)的進(jìn)行訓(xùn)練。
訓(xùn)練過程在對(duì)比學(xué)習(xí)訓(xùn)練過程中,我們暫時(shí)只考慮雙塔模型(因?yàn)榻换ナ侥P偷呢?fù)樣本選取策略不同),雖然理論上hinge loss這種基于pairwise樣本選取策略的損失,可以很好地對(duì)比正負(fù)樣本的表征區(qū)別,但是如果模型并沒有進(jìn)行很好地訓(xùn)練就拿去用hinge loss進(jìn)行訓(xùn)練,有可能因?yàn)樨?fù)樣本太難導(dǎo)致訓(xùn)練出現(xiàn)“損失坍縮”(loss collapse)的現(xiàn)象,此時(shí)模型對(duì)正樣本和負(fù)樣本沒有區(qū)分能力,因此對(duì)兩者的打分都極為相似,有,此時(shí)loss坍縮到margin并且恒等于margin不再變化,如Fig 3.1所示。我們可以認(rèn)為模型陷入了平凡解。
Fig 3.1 采用hinge loss導(dǎo)致?lián)p失坍縮的現(xiàn)象。圖省事就直接ipad上畫了,有點(diǎn)丑見諒:_)這個(gè)現(xiàn)象也不一定就會(huì)出現(xiàn),如果采用的模型已經(jīng)進(jìn)行過合適的初始化,就不一定會(huì)出現(xiàn)這個(gè)問題。另外,采用交叉熵?fù)p失進(jìn)行一開始的訓(xùn)練是一種比較穩(wěn)定的方法。在CLIP模型中[4],作者采用了batch size=32,768的配置,在進(jìn)行過allgather機(jī)制,對(duì)所有特征進(jìn)行匯聚后[5],甚至可以實(shí)現(xiàn)32768 × 32768 大小的打分矩陣,這意味著有著海量的負(fù)樣本可供學(xué)習(xí),這也同時(shí)意味著對(duì)模型學(xué)習(xí)的巨大挑戰(zhàn)。因此CLIP文章的作者沒有采用hinge loss訓(xùn)練,而是采用了雙向的交叉熵?fù)p失進(jìn)行訓(xùn)練。
然而在巨大的batch size中訓(xùn)練是有著非常大的誘惑的,在[1]中我們就曾經(jīng)討論過對(duì)于對(duì)比學(xué)習(xí)中,負(fù)樣本增多意味著表征詞典的詞表的增大,有著巨大的效果增益。那么要如何去訓(xùn)練這種超大規(guī)模的batch size下的對(duì)比學(xué)習(xí)任務(wù)呢?筆者個(gè)人認(rèn)為需要進(jìn)行階段式地訓(xùn)練,一步步提高batch size大小。筆者曾經(jīng)試驗(yàn)過,如果一開始就采用很大的batch size進(jìn)行訓(xùn)練,在hinge loss的情況下,將會(huì)非常不穩(wěn)定,很容易出現(xiàn)損失坍塌的現(xiàn)象。而如果循序漸進(jìn)則不會(huì)出現(xiàn)這個(gè)問題,那么是否可以通過這種方法將batch size增加到很大呢(不考慮硬件的約束情況),這個(gè)筆者也還在實(shí)驗(yàn),希望后續(xù)能有個(gè)比較正向的結(jié)論。
同時(shí),在超大規(guī)模的對(duì)比學(xué)習(xí)過程中,如何結(jié)合交叉熵?fù)p失和hinge loss損失也是一個(gè)值得思考的問題。交叉熵?fù)p失穩(wěn)定,但是學(xué)習(xí)速度較慢(筆者實(shí)驗(yàn)發(fā)現(xiàn),不一定準(zhǔn)確),hinge loss不穩(wěn)定,但是學(xué)習(xí)速度更快,如何進(jìn)行平衡是一個(gè)值得嘗試的方向。對(duì)比學(xué)習(xí)在大規(guī)模數(shù)據(jù)上的訓(xùn)練的確還有很多值得探索的呢,公開的論文提供的細(xì)節(jié)也不多。
Reference
[1]. https://fesian.blog.csdn.net/article/details/119515146
[2]. https://blog.csdn.net/LoseInVain/article/details/103995962
[3]. https://blog.csdn.net/LoseInVain/article/details/78636176
[4]. https://fesian.blog.csdn.net/article/details/119516894
[5]. https://medium.com/@cresclux/example-on-torch-distributed-gather-7b5921092