trax.optimizers

adafactor

Adafactor optimizer class.

class trax.optimizers.adafactor.Adafactor(learning_rate=0.05, factored=True, multiply_by_parameter_scale=True, do_clipping=True, do_momentum=False, beta1=0.0, decay_rate=0.8, clipping_threshold=1.0, weight_decay_rate=1e-05, epsilon1=1e-30, epsilon2=0.001)

Bases: trax.optimizers.base.Optimizer

Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.

__init__(learning_rate=0.05, factored=True, multiply_by_parameter_scale=True, do_clipping=True, do_momentum=False, beta1=0.0, decay_rate=0.8, clipping_threshold=1.0, weight_decay_rate=1e-05, epsilon1=1e-30, epsilon2=0.001)

Create the Adafactor optimizer.

Adafactor is described in https://arxiv.org/abs/1804.04235.

Parameters:
  • learning_rate – float: Trax-provided learning rate.
  • factored – boolean: whether to use a factored second-moment estimator for 2-d variables.
  • multiply_by_parameter_scale – boolean: if True, scale the provided learning_rate by the parameter norm; if False, the provided learning_rate is the absolute step size.
  • do_clipping – whether to clip gradients; if True, set clipping_threshold.
  • do_momentum – whether to use momentum; if True, set beta1.
  • beta1 – a float between 0 and 1; a nonzero value enables momentum and uses extra memory. Off by default.
  • decay_rate – float: controls the second-moment exponential decay schedule.
  • clipping_threshold – an optional float >= 1; if None, no update clipping is applied.
  • weight_decay_rate – rate at which to decay weights.
  • epsilon1 – Regularization constant for squared gradient.
  • epsilon2 – Regularization constant for parameter scale.
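
A brief construction sketch, using only the constructor documented above; the keyword values shown are the defaults and any of the listed hyperparameters can be overridden::

    from trax.optimizers import adafactor

    # Instantiate Adafactor with the documented defaults.
    opt = adafactor.Adafactor(
        learning_rate=0.05,
        factored=True,                     # factored second moments for 2-d weights
        multiply_by_parameter_scale=True,  # scale the step size by the parameter norm
        do_clipping=True,
        clipping_threshold=1.0,
        weight_decay_rate=1e-05,
    )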
init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).

adam

Adam optimizer class.

class trax.optimizers.adam.Adam(learning_rate=0.0001, weight_decay_rate=1e-05, b1=0.9, b2=0.999, eps=1e-05, clip_grad_norm=None)

Bases: trax.optimizers.base.Optimizer

Adam optimizer; described in https://arxiv.org/abs/1412.6980.

The update rule for time step \(t\), given gradients \(g_t\) and “Stepsize” \(\alpha\), is:

\[\begin{split}\hat{m}_t &\leftarrow \big(\beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t\big)\ /\ (1 - \beta_1^t) \\ \hat{v}_t &\leftarrow \big(\beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2\big)\ /\ (1 - \beta_2^t) \\ \theta_t &\leftarrow \theta_{t-1} -\ \alpha \cdot \hat{m}_t / \big(\sqrt{\hat{v}_t} + \epsilon\big)\end{split}\]
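
The update rule transcribes directly into code. The following standalone sketch (plain NumPy, not Trax's internal implementation) applies one Adam step to a weight array, with names matching the equations and the constructor defaults below::

    import numpy as np

    def adam_step(theta, g, m, v, t, alpha=0.0001, b1=0.9, b2=0.999, eps=1e-05):
        """One Adam step; t is the 1-based step number."""
        m = b1 * m + (1 - b1) * g             # first-moment estimate
        v = b2 * v + (1 - b2) * g ** 2        # second-moment estimate
        m_hat = m / (1 - b1 ** t)             # bias corrections
        v_hat = v / (1 - b2 ** t)
        theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
        return theta, m, v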
__init__(learning_rate=0.0001, weight_decay_rate=1e-05, b1=0.9, b2=0.999, eps=1e-05, clip_grad_norm=None)

Creates an Adam optimizer.

Parameters:
  • learning_rate – Initial (unadapted) learning rate \(\alpha\); the original paper calls this Stepsize and suggests 0.001 as a generally good value.
  • weight_decay_rate – Fraction of prior weight values to subtract on each step; equivalent to multiplying each weight element by 1 - weight_decay_rate. (This is not part of the core Adam algorithm.)
  • b1 – Exponential decay rate \(\beta_1\) for first moment estimates.
  • b2 – Exponential decay rate \(\beta_2\) for second moment estimates.
  • eps – Small positive constant \(\epsilon\) for numerical stability.
  • clip_grad_norm – Threshold value above which gradient clipping occurs. (This is not part of the core Adam algorithm.)
init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).

base

Trax base optimizer class.

class trax.optimizers.base.Optimizer(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)

Bases: object

Base class for optimizers that work hand in hand with Trax layers.

To define an optimizer subclass, specify its behavior with respect to a single level/node in the network (e.g., a single dense layer):

  • init: how to create/initialize optimizer-internal weights (“slots”)
    whose shape matches the node’s weight shape.
  • update: how to use gradient information to update node weights and
    optimizer slots.

The Trax runtime combines these node-local computations into weight updates and slot updates for the whole tree of layers in the model.
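
For concreteness, here is a minimal, hypothetical subclass; the class name, the extra scale hyperparameter, and the assumption that opt_params is a dictionary keyed by hyperparameter name are illustrative, not part of the documented API::

    import numpy as np
    from trax.optimizers import base

    class ScaledSGD(base.Optimizer):
        """Hypothetical example: plain SGD with an extra scale hyperparameter."""

        def __init__(self, learning_rate=0.01, scale=1.0):
            # Explicitly named keyword arguments keep the hyperparameters
            # visible for gin configuration (see the note on __init__ below).
            super().__init__(learning_rate=learning_rate, scale=scale)

        def init(self, weights):
            # No per-weight state is needed, so there are no slots.
            return None

        def update(self, step, grads, weights, slots, opt_params):
            # Assumes opt_params is a dict of named hyperparameter values.
            lr = opt_params['learning_rate'] * opt_params['scale']
            return weights - lr * grads, slots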

__init__(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)

Sets initial hyperparameter values for this optimizer.

Takes initial optimizer parameters as keyword arguments. These values can be changed between training steps, e.g., for learning rate schedules.

If you want your subclass to expose hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.

Parameters:
  • learning_rate – The initial learning rate.
  • clip_grad_norm – float; the value to which gradients will be clipped.
  • **init_opt_params – Initial values of any additional optimizer parameters.
init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).

slots

Optimizer slot values.

opt_params

Optimizer hyperparameters (e.g. learning rate, momentum).
tree_init(weight_tree)

Assembles node-local initializations into full-tree initialization.

Parameters:
  • weight_tree – Weights for an entire model, in a tree that matches the model’s layer structure.
Returns:

Tuple (slots, opt_params), where slots are the initialized optimizer slot values and opt_params are optimizer hyperparameters (e.g., learning rate, momentum).
tree_update(step, grad_tree, weight_tree, slots, opt_params)

Assembles node-local weight and slot updates for the full layer tree.

Parameters:
  • step – Current step number in the training process.
  • grad_tree – Gradients for the entire model, in a tree that matches the model’s layer structure.
  • weight_tree – Current weights for the entire model, in a tree that matches the model’s layer structure.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple (weights, slots), where weights are the optimizer-updated weights for the whole model (in a tree matching the model’s layer structure) and slots are the updated optimizer slot values.
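
A rough sketch of how these two methods fit together, assuming a nested list of NumPy arrays is accepted as a weight tree (in real use the Trax training loop makes these calls)::

    import numpy as np
    from trax.optimizers import base

    weight_tree = [np.zeros((4, 3)), [np.ones(3)]]         # stand-in for a model's weights
    grad_tree = [np.full((4, 3), 0.1), [np.full(3, 0.2)]]  # matching gradients

    opt = base.SGD(learning_rate=0.01)
    slots, opt_params = opt.tree_init(weight_tree)
    weight_tree, slots = opt.tree_update(
        1, grad_tree, weight_tree, slots, opt_params)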

class trax.optimizers.base.SGD(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)

Bases: trax.optimizers.base.Optimizer

Stochastic gradient descent (SGD) optimizer.

A simple optimizer with no weights (“slots”) of its own.
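
Its node-local update is just the classic rule below (a standalone sketch for illustration, not Trax's internal code)::

    def sgd_step(weights, grads, learning_rate):
        # Plain SGD: step against the gradient; no optimizer slots to maintain.
        return weights - learning_rate * grads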

init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).

trax.optimizers.base.l2_norm(tree)

Computes the L2 norm of a pytree of arrays. Useful for weight decay.

trax.optimizers.base.clip_grads(grad_tree, max_norm)

Clips gradients stored as a pytree of arrays to a maximum norm of max_norm.
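
Example usage of both helpers, assuming a plain list of NumPy arrays is accepted as a pytree::

    import numpy as np
    from trax.optimizers import base

    grad_tree = [np.ones((2, 2)), np.full(3, 2.0)]

    norm = base.l2_norm(grad_tree)             # global L2 norm over all leaves
    clipped = base.clip_grads(grad_tree, 1.0)  # rescaled so the norm is at most 1.0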

momentum

Nesterov momentum optimizer (also known as Nesterov Accelerated Gradient).

class trax.optimizers.momentum.Momentum(learning_rate=0.01, mass=0.9, weight_decay_rate=1e-05, nesterov=True)

Bases: trax.optimizers.base.Optimizer

A momentum optimizer.

This class implements two variants of momentum stochastic gradient descent (SGD): with and without the Nesterov correction. The Nesterov update follows the analysis in Sutskever et al. (2013) [http://jmlr.org/proceedings/papers/v28/sutskever13.pdf], as reformulated in Bengio et al. (2012) [https://arxiv.org/abs/1212.0901] (equations 6 and 7) to work well with backpropagation:

\[\begin{split}v_t &= \mu_{t-1}v_{t-1} - \epsilon_{t-1}\nabla f(\Theta_{t-1}) \\ \Theta_t &= \Theta_{t-1} - \mu_{t-1} v_{t-1} + \mu_t v_t + v_t\end{split}\]

where \(\mu_{t-1}\) is the momentum (decay) coefficient at time step \(t-1\) and \(\epsilon_{t-1}\) is the learning rate at \(t-1\).

Note that the implementation below also includes a weight decay rate (\(\alpha\)) on the parameters, independent of the Nesterov momentum.
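
Written out as code, one step of this update looks roughly as follows: a standalone NumPy sketch transcribing the equations above, including the optional weight decay, not Trax's internal implementation::

    import numpy as np

    def nesterov_step(theta, grad, v, mu_prev, mu, lr, weight_decay_rate=1e-05):
        """One Nesterov momentum step, transcribing the equations above."""
        v_new = mu_prev * v - lr * grad                        # v_t
        theta_new = theta - mu_prev * v + mu * v_new + v_new   # Theta_t
        theta_new = (1 - weight_decay_rate) * theta_new        # optional weight decay
        return theta_new, v_new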

__init__(learning_rate=0.01, mass=0.9, weight_decay_rate=1e-05, nesterov=True)

Sets initial hyperparameter values for this optimizer.

Takes initial optimizer parameters as keyword arguments. These values can be changed between training steps, e.g., for learning rate schedules.

If you want your subclass to expose hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.

Parameters:
  • learning_rate – The initial learning rate.
  • mass – The momentum (decay) coefficient \(\mu\).
  • weight_decay_rate – Rate at which to decay weights (\(\alpha\)), applied independently of the momentum update.
  • nesterov – Whether to apply the Nesterov correction.
init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, velocity, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • velocity – Optimizer slot holding the momentum (velocity) values for this layer.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).

rms_prop

RMSProp optimizer class.

class trax.optimizers.rms_prop.RMSProp(learning_rate=0.001, gamma=0.9, eps=1e-08, clip_grad_norm=None)

Bases: trax.optimizers.base.Optimizer

RMSProp optimizer.

Uses optimizer weights (“slots”) to maintain a root-mean-square exponentially decaying average of gradients from prior training batches.
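
The rule can be sketched as follows: a standalone NumPy sketch of the standard RMSProp update, with hyperparameter names mirroring the constructor below; this is not Trax's internal code::

    import numpy as np

    def rmsprop_step(weights, grads, avg_sq_grad, learning_rate=0.001,
                     gamma=0.9, eps=1e-08):
        """One RMSProp step; avg_sq_grad is the optimizer slot."""
        avg_sq_grad = gamma * avg_sq_grad + (1 - gamma) * grads ** 2
        new_weights = weights - learning_rate * grads / (np.sqrt(avg_sq_grad) + eps)
        return new_weights, avg_sq_grad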

__init__(learning_rate=0.001, gamma=0.9, eps=1e-08, clip_grad_norm=None)

Sets initial hyperparameter values for this optimizer.

Takes initial optimizer parameters as keyword arguments. These values can be changed between training steps, e.g., for learning rate schedules.

If you want your subclass to expose hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.

Parameters:
  • learning_rate – The initial learning rate.
  • gamma – Decay rate for the exponentially decaying average of squared gradients.
  • eps – Small positive constant added for numerical stability.
  • clip_grad_norm – float; the value to which gradients will be clipped.
init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, avg_sq_grad, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • avg_sq_grad – Optimizer slot holding the exponentially decaying average of squared gradients for this layer.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).

sm3

SM3 optimizer class.

class trax.optimizers.sm3.SM3(learning_rate=0.01, momentum=0.9)

Bases: trax.optimizers.base.Optimizer

SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.

__init__(learning_rate=0.01, momentum=0.9)

Create the SM3 optimizer.

Memory-Efficient Adaptive Optimization for Large-Scale Learning. https://arxiv.org/abs/1901.11150

Parameters:
  • learning_rate – a positive scalar value for the initial learning rate.
  • momentum – optional; a positive scalar value for the momentum coefficient.
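
A minimal construction sketch using the documented defaults; the optimizer is then driven through the standard init/update (or tree_init/tree_update) interface described below::

    from trax.optimizers import sm3

    # Both arguments are optional and shown here at their documented defaults.
    opt = sm3.SM3(learning_rate=0.01, momentum=0.9)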
init(weights)

Creates optimizer slots for the given parameters.

Parameters:
  • weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes one step’s worth of updates.

The update computes both new weights for the layer/node and new slot values for the optimizer.

Parameters:
  • step – Current step number in the training process.
  • grads – Gradients for the weights of the sublayer.
  • weights – Current weights for the sublayer.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
Returns:

Tuple of (new_weights, new_slots).