gluonnlp.optimizer

GluonNLP provides special-purpose optimizers for training models in natural language processing.

BERTAdam Optimizer

The Adam optimizer with weight decay regularization for BERT.

BERTAdam

The Adam optimizer with weight decay regularization for BERT.

LAMB Optimizer

Implementation of the LAMB optimizer from the paper Reducing BERT Pre-Training Time from 3 Days to 76 Minutes.

In the paper, the empirical results demonstrate the superior performance of LAMB for BERT and ResNet-50 training. By increasing the batch size to the memory limit of a TPUv3 pod, BERT training time can be reduced from 3 days to 76 minutes.

@inproceedings{You2019LargeBO,
  title={Large Batch Optimization for Deep Learning: Training BERT in 76 minutes},
  author={Yang You and Jing Li and Sashank J. Reddi and Jonathan Hseu and Sanjiv Kumar and Srinadh Bhojanapalli and Xiaodan Song and James Demmel and Cho-Jui Hsieh},
  year={2019}}

LAMB

The LAMB optimizer proposed in Reducing BERT Pre-Training Time from 3 Days to 76 Minutes.

API Reference

NLP optimizer.

class gluonnlp.optimizer.BERTAdam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, **kwargs)[source]

The Adam optimizer with weight decay regularization for BERT.

Updates are applied by:

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w)

Note that this is different from mxnet.optimizer.Adam, where the L2 penalty is added to the gradient and therefore accumulated in m and v. In BERTAdam, the weight decay term is decoupled from the gradient-based update.

This is also slightly different from the AdamW optimizer described in Fixing Weight Decay Regularization in Adam, where the schedule multiplier and learning rate are decoupled and the bias-correction terms are removed. The BERTAdam optimizer uses the same learning rate to apply both the gradient update and the weight decay.
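
For illustration only, a minimal NumPy sketch of a single BERTAdam step following the rule above (the function name and arrays are hypothetical and not part of the gluonnlp API; rescale_grad and clip_gradient handling is omitted):

import numpy as np

def bertadam_step(w, grad, m, v, lr, wd, beta1=0.9, beta2=0.999, epsilon=1e-6):
    # First and second moment estimates, as in Adam.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Weight decay is decoupled: it is added directly to the update rather than
    # folded into the gradient, so it never accumulates in m or v.
    w = w - lr * (m / (np.sqrt(v) + epsilon) + wd * w)
    return w, m, v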

This optimizer accepts the following parameters in addition to those accepted by mxnet.optimizer.Optimizer.

Parameters
  • beta1 (float, optional, default is 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, optional, default is 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, optional, default is 1e-6) – Small value to avoid division by 0.
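
A minimal usage sketch with a toy Gluon block; the optimizer instance is passed directly to mxnet.gluon.Trainer, and the hyperparameter values are illustrative only:

import mxnet as mx
import gluonnlp as nlp

net = mx.gluon.nn.Dense(2)   # toy network for illustration
net.initialize()
optimizer = nlp.optimizer.BERTAdam(learning_rate=1e-5, epsilon=1e-6, wd=0.01)
trainer = mx.gluon.Trainer(net.collect_params(), optimizer)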

create_state(_, weight)[source]

State creation function.

create_state_multi_precision(index, weight)[source]

Multi-precision state creation function.

update(index, weight, grad, state)[source]

Update function.

update_multi_precision(index, weight, grad, state)[source]

Multi-precision update function.

class gluonnlp.optimizer.LAMB(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-06, lower_bound=0.001, upper_bound=10.0, bias_correction=False, **kwargs)[source]

The LAMB optimizer proposed in Reducing BERT Pre-Training Time from 3 Days to 76 Minutes.

If bias_correction is set to False, updates are applied by:

grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
r1 = min(max(w.norm(), lower_bound), upper_bound)
g = m / (sqrt(v) + epsilon) + wd * w
r2 = g.norm()
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr = r * lr
w = w - lr * g

Otherwise, updates are applied by:

grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
m_hat = m / (1 - power(beta1, t))
v_hat = v / (1 - power(beta2, t))
r1 = w.norm()
g = m_hat / (sqrt(v_hat + epsilon)) + wd * w
r2 = g.norm()
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr = r * lr
w = w - lr * g
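
For illustration, a hypothetical NumPy sketch of a single LAMB step with bias correction enabled, mirroring the second rule above (the function name and arrays are placeholders, not part of the gluonnlp API; gradient rescaling and clipping are omitted):

import numpy as np

def lamb_step(w, grad, m, v, t, lr, wd, beta1=0.9, beta2=0.999, epsilon=1e-6):
    # Moment estimates with bias correction, as in Adam.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    g = m_hat / np.sqrt(v_hat + epsilon) + wd * w
    # Layer-wise trust ratio: scale the step by ||w|| / ||g||.
    r1 = np.linalg.norm(w)
    r2 = np.linalg.norm(g)
    r = 1.0 if r1 == 0.0 or r2 == 0.0 else r1 / r2
    w = w - r * lr * g
    return w, m, v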

Parameters
  • beta1 (float, optional, default is 0.9) – Exponential decay rate for the first moment estimates.

  • beta2 (float, optional, default is 0.999) – Exponential decay rate for the second moment estimates.

  • epsilon (float, optional, default is 1e-6) – Small value to avoid division by 0.

  • lower_bound (float, optional, default is 1e-3) – Lower limit of the norm of the weight.

  • upper_bound (float, optional, default is 10.0) – Upper limit of the norm of the weight.

  • bias_correction (bool, optional, default is False) – Whether to use bias correction. In the latest version of the LAMB paper, the bias-correction terms were removed and the update rule was slightly simplified.
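
A minimal usage sketch with a toy network; the hyperparameter values below are placeholders chosen only to show which arguments LAMB accepts:

import mxnet as mx
import gluonnlp as nlp

net = mx.gluon.nn.Dense(2)   # toy network for illustration
net.initialize()
optimizer = nlp.optimizer.LAMB(learning_rate=0.001, lower_bound=1e-3,
                               upper_bound=10.0, bias_correction=True, wd=0.01)
trainer = mx.gluon.Trainer(net.collect_params(), optimizer)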

create_state(index, weight)[source]

Creates auxiliary state for a given weight.

Some optimizers require additional states, e.g. momentum, in addition to gradients in order to update weights. This function creates the state for a given weight, which will be used in update. It is called only once for each weight.

Parameters
  • index (int) – A unique index to identify the weight.

  • weight (NDArray) – The weight.

Returns

state – The state associated with the weight.

Return type

any obj

update(index, weight, grad, state)[source]

Updates the given parameter using the corresponding gradient and state.

Parameters
  • index (int) – The unique index of the parameter into the individual learning rates and weight decays. Learning rates and weight decays may be set via set_lr_mult() and set_wd_mult(), respectively.

  • weight (NDArray) – The parameter to be updated.

  • grad (NDArray) – The gradient of the objective with respect to this parameter.

  • state (any obj) – The state returned by create_state().
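
The low-level interface can also be exercised directly, which makes the roles of create_state and update explicit; the arrays below are random placeholders used purely for illustration:

import mxnet as mx
import gluonnlp as nlp

opt = nlp.optimizer.LAMB(learning_rate=0.001)
weight = mx.nd.random.normal(shape=(4,))
grad = mx.nd.random.normal(shape=(4,))
state = opt.create_state(0, weight)   # allocates the auxiliary state (m and v)
opt.update(0, weight, grad, state)    # applies one LAMB step to weight in place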