【煉丹技巧】指數(shù)移動(dòng)平均(EMA)的原理及PyTorch實(shí)現(xiàn) | 您所在的位置:網(wǎng)站首頁(yè) › 屬豬不能帶三樣?xùn)|西嗎為什么 › 【煉丹技巧】指數(shù)移動(dòng)平均(EMA)的原理及PyTorch實(shí)現(xiàn) |
來(lái)自 | 知乎 地址 | https://zhuanlan.zhihu.com/p/68748778 作者 | Nicolas 編輯 | 樸素人工智能 在深度學(xué)習(xí)中,經(jīng)常會(huì)使用EMA(指數(shù)移動(dòng)平均)這個(gè)方法對(duì)模型的參數(shù)做平均,以求提高測(cè)試指標(biāo)并增加模型魯棒。 今天瓦礫準(zhǔn)備介紹一下EMA以及它的Pytorch實(shí)現(xiàn)代碼。 EMA的定義![]() ![]() ![]() ![]() ![]() ![]() 瓦礫看了網(wǎng)上的一些實(shí)現(xiàn),使用起來(lái)都不是特別方便,所以自己寫了一個(gè)。 代碼語(yǔ)言:javascript復(fù)制class EMA(): def __init__(self, model, decay): self.model = model self.decay = decay self.shadow = {} self.backup = {} def register(self): for name, param in self.model.named_parameters(): if param.requires_grad: self.shadow[name] = param.data.clone() def update(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] self.shadow[name] = new_average.clone() def apply_shadow(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow self.backup[name] = param.data param.data = self.shadow[name] def restore(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.backup param.data = self.backup[name] self.backup = {} # 初始化 ema = EMA(model, 0.999) ema.register() # 訓(xùn)練過(guò)程中,更新完參數(shù)后,同步update shadow weights def train(): optimizer.step() ema.update() # eval前,apply shadow weights;eval之后,恢復(fù)原來(lái)模型的參數(shù) def evaluate(): ema.apply_shadow() # evaluate ema.restore()References機(jī)器學(xué)習(xí)模型性能提升技巧: 指數(shù)加權(quán)平均(EMA), https://blog.csdn.net/mikelkl/article/details/85227053Exponential Weighted Average for Deep Neutal Networks, https://www.ashukumar27.io/exponentially-weighted-average/ |
今日新聞 |
推薦新聞 |
專題文章 |
CopyRight 2018-2019 實(shí)驗(yàn)室設(shè)備網(wǎng) 版權(quán)所有 |