ExponentialMovingAverage¶
用指数衰减计算参数的滑动平均值。给定参数 \(\theta\),它的指数滑动平均值 (exponential moving average, EMA) 为
\[\begin{split}\begin{align}\begin{aligned}\text{EMA}_0 & = 0\\\text{EMA}_t & = \text{decay} * \text{EMA}_{t-1} + (1 - \text{decay}) * \theta_t\end{aligned}\end{align}\end{split}\]
用 update()
方法计算出的平均结果将保存在由实例化对象创建和维护的临时变量中,并且可以通过调用 apply()
方法把结果应用于当前模型的参数。同时,可用 restore()
方法恢复原始参数。
偏置校正
所有的滑动平均均初始化为 \(0\),因此它们相对于零是有偏的,可以通过除以因子 \((1 - \text{decay}^t)\) 来校正,因此在调用 apply()
方法时,作用于参数的真实滑动平均值将为:
\[\widehat{\text{EMA}}_t = \frac{\text{EMA}_t}{1 - \text{decay}^t}\]
衰减率调节
一个非常接近于 1 的很大的衰减率将会导致平均值滑动得很慢。更优的策略是,开始时设置一个相对较小的衰减率。参数 thres_steps
允许用户传递一个变量以设置衰减率,在这种情况下, 真实的衰减率变为:
\[\min(\text{decay}, \frac{1 + \text{thres_steps}}{10 + \text{thres_steps}})\]
通常 thres_steps
可以是全局的训练迭代步数。
参数¶
decay (float,可选) – 指数衰减率,通常接近 1,如 0.999 ,0.9999 ,···。默认值为 0.999 。
thres_steps (Variable|None,可选) – 调节衰减率的阈值步数,默认值为 None 。
name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。
代码示例¶
import numpy
import paddle
import paddle.static as static
from paddle.static import ExponentialMovingAverage
paddle.enable_static()
data = static.data(name='x', shape=[-1, 5], dtype='float32')
hidden = static.nn.fc(x=data, size=10)
cost = paddle.mean(hidden)
test_program = static.default_main_program().clone(for_test=True)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(cost)
ema = ExponentialMovingAverage(0.999)
ema.update()
place = paddle.CPUPlace()
exe = static.Executor(place)
exe.run(static.default_startup_program())
for pass_id in range(3):
for batch_id in range(6):
data = numpy.random.random(size=(10, 5)).astype('float32')
exe.run(program=static.default_main_program(),
feed={'x': data},
fetch_list=[cost.name])
# usage 1
with ema.apply(exe):
data = numpy.random.random(size=(10, 5)).astype('float32')
exe.run(program=test_program,
feed={'x': data},
fetch_list=[hidden.name])
# usage 2
with ema.apply(exe, need_restore=False):
data = numpy.random.random(size=(10, 5)).astype('float32')
exe.run(program=test_program,
feed={'x': data},
fetch_list=[hidden.name])
ema.restore(exe)