# Source code for gluonnlp.optimizer.lamb
# coding: utf-8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""LAMB optimizer"""
from mxnet.optimizer import Optimizer, register
from mxnet.ndarray import zeros, NDArray
from mxnet.ndarray import square, power, sqrt, maximum, minimum, clip
# Public names exported by this module via ``from ... import *``.
__all__ = ['LAMB']
@register
class LAMB(Optimizer):
    """The LAMB optimizer proposed in
    `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes <https://arxiv.org/abs/1904.00962>`_.

    If bias_correction is set to False, updates are applied by::

        grad = clip(grad * rescale_grad, clip_gradient)
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * (grad**2)
        r1 = min(max(w.norm(), lower_bound), upper_bound)
        g = m / (sqrt(v) + epsilon) + wd * w
        r2 = g.norm()
        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
        lr = r * lr
        w = w - lr * g

    Otherwise, updates are applied by::

        grad = clip(grad * rescale_grad, clip_gradient)
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * (grad**2)
        m_hat = m / (1 - power(beta1, t))
        v_hat = v / (1 - power(beta2, t))
        r1 = w.norm()
        g = m_hat / (sqrt(v_hat + epsilon)) + wd * w
        r2 = g.norm()
        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
        lr = r * lr
        w = w - lr * g

    Note that ``epsilon`` is added after the square root in the first form and
    inside it in the second, and that the weight-norm bounds are only applied
    when ``bias_correction`` is False.

    Parameters
    ----------
    learning_rate : float, optional, default is 0.001
        Base learning rate; it is multiplied by the trust ratio ``r`` on every update.
    beta1 : float, optional, default is 0.9
        Exponential decay rate for the first moment estimates.
    beta2 : float, optional, default is 0.999
        Exponential decay rate for the second moment estimates.
    epsilon : float, optional, default is 1e-6
        Small value to avoid division by 0.
    lower_bound : float or None, optional, default is 1e-3
        Lower limit of the weight norm ``r1`` (only used when bias_correction is
        False). ``None`` disables the lower limit.
    upper_bound : float or None, optional, default is 10.0
        Upper limit of the weight norm ``r1`` (only used when bias_correction is
        False). ``None`` disables the upper limit.
    bias_correction : bool, optional, default is False
        Whether to use bias correction. The latest version of LAMB removed the
        bias correction and made some simple changes instead.
    **kwargs
        Forwarded to the base ``Optimizer`` constructor (e.g. ``wd``,
        ``rescale_grad``, ``clip_gradient``).
    """

    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
                 lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
        super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.bias_correction = bias_correction

    def create_state(self, index, weight):
        """Create the zero-initialised ``(mean, variance)`` state for ``weight``.

        Both buffers share the shape, context, dtype and storage type of ``weight``.
        """
        stype = weight.stype
        mean = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
        variance = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
        return (mean, variance)

    def update(self, index, weight, grad, state):
        """Apply one LAMB step to ``weight`` in place.

        Parameters
        ----------
        index : int
            Parameter key, used for per-parameter lr/wd lookup and update counting.
        weight : NDArray
            Parameter to update; modified in place.
        grad : NDArray
            Gradient of ``weight``. It is not modified by this method.
        state : tuple of NDArray
            ``(mean, var)`` as returned by ``create_state``; updated in place.
        """
        assert isinstance(weight, NDArray)
        assert isinstance(grad, NDArray)
        self._update_count(index)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        t = self._index_update_count[index]

        # Preprocess the gradient out of place. An in-place ``grad *= ...`` would
        # silently rescale the caller's gradient buffer as a side effect.
        grad = grad * self.rescale_grad
        if self.clip_gradient is not None:
            grad = clip(grad, -self.clip_gradient, self.clip_gradient)

        mean, var = state
        mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
        var[:] = self.beta2 * var + (1. - self.beta2) * square(grad)

        r1 = weight.norm()
        if not self.bias_correction:
            # Bound the weight norm used in the trust ratio; None disables a bound.
            if self.lower_bound is not None:
                r1 = maximum(r1, self.lower_bound)
            if self.upper_bound is not None:
                r1 = minimum(r1, self.upper_bound)
            g = mean / (sqrt(var) + self.epsilon) + wd * weight
        else:
            # Bias-corrected moment estimates.
            mean_hat = mean / (1. - power(self.beta1, t))
            var_hat = var / (1. - power(self.beta2, t))
            g = mean_hat / sqrt(var_hat + self.epsilon) + wd * weight

        r2 = g.norm()

        # Layer-wise trust ratio; fall back to 1 when either norm is zero so the
        # step is never scaled to 0 or to infinity.
        r = 1. if r1 == 0. or r2 == 0. else r1 / r2
        lr *= r

        weight[:] -= lr * g