学习优化方法的基本框架
首先定义:待优化参数:w,目标函数:f(w),学习率α。
然后进行迭代优化,每个epoch t 的操作如下:
1.目标函数与当前参数的梯度:gt=∇f(wt)
2.由历史梯度按特定方法计算得出的一阶动量:mt=ϕ(g1,g2,⋯,gt)和二阶动量:Vt=ψ(g1,g2,⋯,gt)
3.计算当前参数的下降梯度:ηt=α⋅mt√Vt
4.对当前参数进行更新:wt+1=wt−ηt
关键在于第2步。第二步的变化,产生了不同的算法
SGD
Stochastic Gradient Decent:值考虑当前梯度,不考虑历史梯度,没有动量设置: wt+1=wt−ηt=wt−α⋅gt
SGD with Momentum
为了抑制震荡和加速收敛,从而加入对历史梯度的考虑,引入一阶动量: mt=β1⋅mt−1+(1−β1)⋅gt
SGD with Nesterov Acceleration
NAG中引入了一种试探步的想法,即计算当前位置按一阶动量走一步后位置的梯度。gt不是t时刻的参数wt的梯度,而是t时刻参数wt按动量一更新一次后的参数的梯度: gt=∇f(wt−α⋅mt−1√Vt−1)
wt+1=wt−ηt=wt−α⋅mt=wt−(β1⋅mt−1+(1−β1)⋅∇f(wt−α⋅mt−1√Vt−1))
AdaGrad
引入二阶动量来衡量参数的更新频率,以此来调整学习率。对于经常更新的参数(对应梯度的绝对值大,单个样本依赖强)学习率小一点,而偶尔更新的参数(对应梯度的绝对值小,单个样本依赖弱)学习率大一点,以此实现自适应学习率,对样本进行均衡化的依赖。
二阶动量为:参数对应维度上,所有梯度值的平方和: Vt=t∑τ=1g2τ
wt+1=wt−ηt=wt−α√Vt⋅mt
用上式可知学习率实际上是α√Vt。为了避免分母为0,分母会加上一个小的平滑项。参数更新越频繁,二阶动量越大,学习率就越小。同时分母是单调递增的,会使的学习率单调递减至0,导致训练提前结束,训练不够充分。
AdaDelta / RMSProp
AdaGrad二阶动量的计算过于激进,于是改变为:只关注过去一段时间窗口的下降梯度,而不是累积全部历史梯度,从而避免训练过早结束。改用指数平均的方式计算: Vt=β2∗Vt−1+(1−β2)g2t
wt+1=wt−ηt=wt−α√Vt⋅mt
Adam
把一阶动量和二阶动量都用起来,就是 Adam了——Adaptive + Momentum gt=∇f(wt)m1=g1,V1=g21mt=β1⋅mt−1+(1−β1)⋅gtVt=β2⋅Vt−1+(1−β2)⋅g2twt+1=wt−ηt=wt−α√Vt⋅mt
Nadam
Nadam是Nesterov和Adam的结合, Nadam=Nesterov + Adam gt=∇f(wt+α√Vt−1⋅mt−1)m1=g1,V1=g21mt=β1⋅mt−1+(1−β1)⋅gtVt=β2⋅Vt−1+(1−β2)⋅g2twt+1=wt−ηt=wt−α√Vt⋅mt
优化算法的常用tricks
最后,分享一些在优化算法的选择和使用方面的一些tricks。
- 首先,各大算法孰优孰劣并无定论。如果是刚入门,优先考虑 SGD+Nesterov Momentum 或者 Adam.(Standford 231n : The two recommended updates to use are either SGD+Nesterov Momentum or Adam)
- 选择你熟悉的算法——这样你可以更加熟练地利用你的经验进行调参
- 充分了解你的数据——如果数据是非常稀疏的,那么优先考虑自适应学习率的算法。
- 根据你的需求来选择——在模型设计实验过程中,要快速验证新模型的效果,可以先用Adam进行快速实验优化;在模型上线或者结果发布前,可以用精调的SGD进行模型的极致优化。
- 先用小数据集进行实验。论文 Stochastic Gradient Descent Tricks 指出,随机梯度下降算法的收敛速度和数据集的大小的关系不大。因此可以先用一个具有代表性的小数据集进行实验,测试一下最好的优化算法,并通过参数搜索来寻找最优的训练参数。
- 考虑不同算法的组合。先用Adam进行快速下降,而后再换到 SGD 进行充分的调优。切换策略可以参考本文介绍的方法。
- 数据集一定要充分的打散(shuffle)。这样在使用自适应学习率算法的时候,可以避免某些特征集中出现,而导致的有时学习过度、有时学习不足,使得下降方向出现偏差的问题。
- 训练过程中持续监控训练数据和验证数据上的目标函数值以及精度或者 AUC 等指标的变化情况。对训练数据的监控是要保证模型进行了充分的训练——下降方向正确,且学习率足够高;对验证数据的监控是为了避免出现过拟合。
- 制定一个合适的学习率衰减策略。可以使用定期衰减策略,比如每过多少个 epoch 就衰减一次;或者利用精度或者 AUC 等性能指标来监控,当测试集上的指标不变或者下跌时,就降低学习率。
reference
https://zhuanlan.zhihu.com/p/32230623