trax.optimizers

adafactor

Adafactor optimizer class.

class trax.optimizers.adafactor.Adafactor(learning_rate=0.05, factored=True, multiply_by_parameter_scale=True, do_clipping=True, do_momentum=False, momentum_in_bfloat16=False, beta1=0.0, decay_rate=0.8, clipping_threshold=1.0, weight_decay_rate=1e-05, weight_decay_n_steps=0, epsilon1=1e-16, epsilon2=0.001)

Bases: trax.optimizers.base.Optimizer

Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.

__init__(learning_rate=0.05, factored=True, multiply_by_parameter_scale=True, do_clipping=True, do_momentum=False, momentum_in_bfloat16=False, beta1=0.0, decay_rate=0.8, clipping_threshold=1.0, weight_decay_rate=1e-05, weight_decay_n_steps=0, epsilon1=1e-16, epsilon2=0.001)

Create the Adafactor optimizer.

Adafactor is described in https://arxiv.org/abs/1804.04235.

Parameters:
  • learning_rate – float: Trax-provided learning rate.
  • factored – boolean: whether to use a factored second-moment estimator for 2-D variables.
  • multiply_by_parameter_scale – boolean: if True, scale the provided learning_rate by the parameter scale (the root-mean-square of the weights); if False, the provided learning_rate is the absolute step size.
  • do_clipping – whether to clip updates; if True, set clipping_threshold.
  • do_momentum – whether to use momentum; if True, set beta1.
  • momentum_in_bfloat16 – if True, store momentum in bfloat16 to save memory.
  • beta1 – a float between 0 and 1; a nonzero value enables momentum and uses extra memory. Off (0.0) by default.
  • decay_rate – float: controls second-moment exponential decay schedule.
  • clipping_threshold – an optional float >= 1; if None, no update clipping is applied.
  • weight_decay_rate – rate at which to decay weights.
  • weight_decay_n_steps – for how many steps to decay weights (always if None).
  • epsilon1 – Regularization constant for the squared gradient.
  • epsilon2 – Regularization constant for the parameter scale.
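With factored=True, a 2-D variable's second-moment estimate is stored as one running average per row and one per column instead of one per entry. The sketch below illustrates that idea from the Adafactor paper in plain Python on a nested-list matrix. It assumes the paper's decay schedule beta2_t = 1 - t^(-decay_rate) for a 1-based step t; how Trax indexes steps internally is not stated here. The function names are hypothetical, and this is not the Trax source.

```python
def decay_schedule(t, decay_rate=0.8):
    """Second-moment decay beta2_t = 1 - t**(-decay_rate) for 1-based step t (assumed)."""
    return 1.0 - t ** (-decay_rate)


def factored_second_moment(grad, row_acc, col_acc, beta2, epsilon1=1e-16):
    """Update row/column accumulators for a 2-D gradient and rebuild the full estimate.

    Only len(row_acc) + len(col_acc) numbers persist between steps, instead of
    one number per matrix entry.
    """
    n_rows, n_cols = len(grad), len(grad[0])
    sq = [[g * g + epsilon1 for g in row] for row in grad]
    row_means = [sum(row) / n_cols for row in sq]
    col_means = [sum(sq[i][j] for i in range(n_rows)) / n_rows for j in range(n_cols)]
    # Exponential moving averages of the row and column means of g^2.
    row_acc = [beta2 * r + (1.0 - beta2) * m for r, m in zip(row_acc, row_means)]
    col_acc = [beta2 * c + (1.0 - beta2) * m for c, m in zip(col_acc, col_means)]
    # Rank-1 reconstruction: v_hat[i][j] = row_acc[i] * col_acc[j] / mean(row_acc).
    mean_row = sum(row_acc) / n_rows
    v_hat = [[r * c / mean_row for c in col_acc] for r in row_acc]
    return row_acc, col_acc, v_hat


# First step (beta2 = 0): the estimate is exact when g^2 is itself rank-1.
rows, cols, v_hat = factored_second_moment(
    [[1.0, 2.0], [2.0, 4.0]], [0.0, 0.0], [0.0, 0.0], decay_schedule(1))
```

The reconstruction only approximates g² in general. It is exact for rank-1 matrices like the one in the usage line, where v_hat comes out as [[1, 4], [4, 16]].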
init(weights)

Creates optimizer slots that fit the given weights.

Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • grads – Gradient values for this node (from back-propagation during a training step).
  • weights – Current weight values for this node (i.e., layer weights).
  • slots – Current slot values for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.
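One update step for an unfactored (1-D) weight can be sketched as follows. The sketch shows how decay_rate, epsilon1, clipping_threshold, multiply_by_parameter_scale and epsilon2 interact, following the update described in the Adafactor paper. Several details are assumptions, not confirmed Trax behavior: learning_rate plays the role of the paper's relative step size, the step is 1-based, and momentum (beta1) and weight decay are omitted. The helper names are hypothetical.

```python
import math


def rms(values):
    """Root mean square of a list of numbers."""
    return math.sqrt(sum(x * x for x in values) / len(values))


def adafactor_step_1d(step, grads, weights, v, learning_rate=0.05, decay_rate=0.8,
                      clipping_threshold=1.0, epsilon1=1e-16, epsilon2=1e-3):
    """One unfactored Adafactor step for a 1-D weight vector; step is 1-based.

    Momentum (beta1) and weight decay are left out of this sketch.
    """
    beta2 = 1.0 - step ** (-decay_rate)
    v = [beta2 * vi + (1.0 - beta2) * (g * g + epsilon1) for vi, g in zip(v, grads)]
    update = [g / math.sqrt(vi) for g, vi in zip(grads, v)]
    # Update clipping: shrink the whole update if its RMS exceeds the threshold.
    clip_factor = max(1.0, rms(update) / clipping_threshold)
    update = [u / clip_factor for u in update]
    # multiply_by_parameter_scale=True: step size is relative to the weights' RMS.
    step_size = learning_rate * max(epsilon2, rms(weights))
    new_weights = [w - step_size * u for w, u in zip(weights, update)]
    return new_weights, v


new_weights, v = adafactor_step_1d(1, [1.0, -2.0], [3.0, 4.0], [0.0, 0.0])
```

Scaling by the weights' RMS makes the step size relative. A weight of typical magnitude 3.5 moves about 0.05 × 3.5 per step here, rather than by a fixed absolute amount.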

adam

Adam optimizer class.

class trax.optimizers.adam.Adam(learning_rate=0.0001, weight_decay_rate=1e-05, b1=0.9, b2=0.999, eps=1e-05, clip_grad_norm=None)

Bases: trax.optimizers.base.Optimizer

Adam optimizer, as described in https://arxiv.org/abs/1412.6980.

The update rule for time step \(t\), given gradients \(g_t\) and “Stepsize” \(\alpha\), is:

\[\begin{split}m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t, \qquad \hat{m}_t \leftarrow m_t\ /\ (1 - \beta_1^t) \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2, \qquad \hat{v}_t \leftarrow v_t\ /\ (1 - \beta_2^t) \\ \theta_t &\leftarrow \theta_{t-1} -\ \alpha \cdot \hat{m}_t / \big(\sqrt{\hat{v}_t} + \epsilon\big)\end{split}\]
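The first step is a quick check on the bias correction. Assuming zero-initialized moments \(m_0 = v_0 = 0\), as in the original paper, the first step gives

\[\hat{m}_1 = \frac{(1 - \beta_1)\, g_1}{1 - \beta_1} = g_1, \qquad \hat{v}_1 = \frac{(1 - \beta_2)\, g_1^2}{1 - \beta_2} = g_1^2, \qquad \theta_1 = \theta_0 - \alpha \cdot \frac{g_1}{|g_1| + \epsilon} \approx \theta_0 - \alpha \cdot \operatorname{sign}(g_1)\]

So for \(|g_1| \gg \epsilon\), each element initially moves by about \(\alpha\), regardless of the gradient's scale.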
__init__(learning_rate=0.0001, weight_decay_rate=1e-05, b1=0.9, b2=0.999, eps=1e-05, clip_grad_norm=None)

Creates an Adam optimizer.

Parameters:
  • learning_rate – Initial (unadapted) learning rate \(\alpha\); the original paper calls this Stepsize and suggests 0.001 as a generally good value.
  • weight_decay_rate – Fraction of prior weight values to subtract on each step; equivalent to multiplying each weight element by 1 - weight_decay_rate. (This is not part of the core Adam algorithm.)
  • b1 – Exponential decay rate \(\beta_1\) for first moment estimates.
  • b2 – Exponential decay rate \(\beta_2\) for second moment estimates.
  • eps – Small positive constant \(\epsilon\) for numerical stability.
  • clip_grad_norm – Threshold on the L2 norm of all gradient elements (treated as a single vector) above which gradients are rescaled. (This is not part of the core Adam algorithm.)
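A scalar sketch of one Adam step, using the update rule above and this constructor's default hyperparameters, shows where each parameter enters. This is an illustration, not the Trax source. It assumes a 1-based step t and zero-initialized moments. It also assumes that weight decay multiplies the previous weight by (1 - weight_decay_rate) before the Adam step is subtracted; the ordering is not stated here. Gradient clipping is omitted.

```python
import math


def adam_update(step, grad, weight, m, v, learning_rate=1e-4,
                weight_decay_rate=1e-5, b1=0.9, b2=0.999, eps=1e-5):
    """One Adam step for a scalar weight; step is the 1-based t in the formula."""
    m = b1 * m + (1 - b1) * grad          # biased first-moment estimate
    v = b2 * v + (1 - b2) * grad ** 2     # biased second-moment estimate
    m_hat = m / (1 - b1 ** step)          # bias corrections
    v_hat = v / (1 - b2 ** step)
    # Assumed ordering: decay the previous weight, then subtract the Adam step.
    new_weight = ((1 - weight_decay_rate) * weight
                  - learning_rate * m_hat / (math.sqrt(v_hat) + eps))
    return new_weight, m, v


# First step from zero moments: the weight moves by about learning_rate.
new_weight, m, v = adam_update(1, 2.0, 1.0, 0.0, 0.0,
                               learning_rate=0.1, weight_decay_rate=0.0, eps=0.0)
```

The pair (m, v) is what the optimizer keeps as its slots for each weight between steps.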
init(weights)

Creates optimizer slots that fit the given weights.

Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • grads – Gradient values for this node (from back-propagation during a training step).
  • weights – Current weight values for this node (i.e., layer weights).
  • slots – Current slot values for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

base

Trax base optimizer class.

class trax.optimizers.base.Optimizer(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)

Bases: object

Base class for optimizers that work hand in hand with Trax layers.

To define an optimizer subclass, specify its behavior with respect to a single node in the network (e.g., a single dense layer):

  • init: how to create/initialize optimizer-internal parameters (“slots”),
    as a function of the node’s weights.
  • update: how to use gradient information to update node weights and
    optimizer slots.

The Trax runtime combines these node-local computations into layer weight updates and optimizer slot updates for the whole tree of layers in the model.
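The division of labor between node-local init/update and the tree-level methods can be sketched without Trax. ToyOptimizer, ToyMomentum, flatten and unflatten below are hypothetical stand-ins that mimic the contract on nested Python lists of floats. They are not the Trax classes. For simplicity they keep slots in a flat list, and ToyMomentum uses plain (heavy-ball) momentum.

```python
def flatten(tree):
    """List the leaves of a nested list, left to right."""
    if isinstance(tree, list):
        return [leaf for item in tree for leaf in flatten(item)]
    return [tree]


def unflatten(structure, leaves):
    """Rebuild a nested list shaped like `structure` from a flat list of leaves."""
    it = iter(leaves)

    def build(node):
        if isinstance(node, list):
            return [build(item) for item in node]
        return next(it)

    return build(structure)


class ToyOptimizer:
    """Stand-in for the node-local contract; not the Trax base class."""

    def __init__(self, learning_rate=0.01, **init_opt_params):
        self.opt_params = {"learning_rate": learning_rate, **init_opt_params}

    def init(self, weights):
        raise NotImplementedError

    def update(self, step, grads, weights, slots, opt_params):
        raise NotImplementedError

    def tree_init(self, weight_tree):
        # One slot value per leaf, kept in a flat list for simplicity.
        return [self.init(w) for w in flatten(weight_tree)], self.opt_params

    def tree_update(self, step, grad_tree, weight_tree, slots, opt_params):
        results = [self.update(step, g, w, s, opt_params)
                   for g, w, s in zip(flatten(grad_tree), flatten(weight_tree), slots)]
        new_weights = unflatten(weight_tree, [w for w, _ in results])
        return new_weights, [s for _, s in results]


class ToyMomentum(ToyOptimizer):
    """Plain momentum on scalar weights, written against the toy contract."""

    def init(self, weights):
        return 0.0  # velocity starts at zero

    def update(self, step, grads, weights, slots, opt_params):
        del step  # unused by this rule
        velocity = opt_params["mass"] * slots + grads
        return weights - opt_params["learning_rate"] * velocity, velocity


opt = ToyMomentum(learning_rate=0.1, mass=0.9)
weights = [1.0, [2.0, 3.0]]
slots, opt_params = opt.tree_init(weights)
new_weights, new_slots = opt.tree_update(0, [1.0, [0.0, -1.0]], weights, slots, opt_params)
```

The subclass only states the per-node rule. Walking the tree and reassembling the results is inherited, which mirrors the split described above.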

__init__(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)

Sets initial hyperparameter values for this optimizer.

Takes optimizer hyperparameters as keyword arguments. These values can change over time (training steps), e.g., for learning rate schedules.

To expose subclass hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.

Parameters:
  • learning_rate – Learning rate for the optimizer. This can change during training by means of a learning rate schedule.
  • clip_grad_norm – If specified, this scalar value is used to limit gradient size – all gradient elements in a training step are treated as if they belonged to a single vector and then scaled back if needed so that such a vector’s L2 norm does not exceed clip_grad_norm. If None, no clipping happens.
  • **init_opt_params – Initial values of any additional optimizer parameters.
init(weights)

Creates optimizer slots that fit the given weights.

Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • grads – Gradient values for this node (from back-propagation during a training step).
  • weights – Current weight values for this node (i.e., layer weights).
  • slots – Current slot values for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

slots
opt_params
tree_init(weight_tree)

Assembles node-local initializations into full-tree initialization.

Parameters: weight_tree – Weights for an entire model, in a tree that matches the model’s layer structure.
Returns: Tuple (slots, opt_params), where slots are the initialized optimizer slot values and opt_params are optimizer hyperparameters (e.g., learning rate, momentum).
tree_update(step, grad_tree, weight_tree, slots, opt_params, store_slots=True)

Assembles node-local weight and slot updates for the full layer tree.

Parameters:
  • step – Current step number in the training process.
  • grad_tree – Gradients for the entire model, in a tree that matches the model’s layer structure.
  • weight_tree – Current weights for the entire model, in a tree that matches the model’s layer structure.
  • slots – Optimizer slots.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
  • store_slots – Boolean; if True, stores resulting slots in this object; when set to False, this becomes a pure function.
Returns:

Tuple (weights, slots), where weights are the optimizer-updated weights for the whole model (in a tree matching the model’s layer structure) and slots are the updated optimizer slot values.

class trax.optimizers.base.SGD(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)

Bases: trax.optimizers.base.Optimizer

Stochastic gradient descent (SGD) optimizer.

init(weights)

Creates optimizer slots that fit the given weights.

Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, slots, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • grads – Gradient values for this node (from back-propagation during a training step).
  • weights – Current weight values for this node (i.e., layer weights).
  • slots – Current slot values for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.
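The node-local SGD rule can be sketched on a list of scalar weights. The sketch assumes that opt_params carries a "learning_rate" entry and that SGD keeps no slot state, represented here by None. Both are assumptions of this illustration, not statements about the Trax source.

```python
def sgd_update(step, grads, weights, slots, opt_params):
    """Plain SGD for one node: new_w = w - learning_rate * g, with no optimizer state."""
    del step, slots  # SGD ignores the step number and keeps no slots
    lr = opt_params["learning_rate"]
    new_weights = [w - lr * g for w, g in zip(weights, grads)]
    return new_weights, None


new_weights, new_slots = sgd_update(0, [1.0, -2.0], [0.5, 0.5], None, {"learning_rate": 0.1})
```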

trax.optimizers.base.l2_norm(tree)

Returns an L2 norm computed over all elements of all tensors in tree.

Parameters: tree – Tree-structured collection of tensors, e.g., model weights matching the model’s layer structure.
Returns: A scalar value computed as if all the tensors in tree were combined and flattened into a single vector, and then the L2 norm of that vector was calculated.
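The "flatten everything into one vector" definition can be sketched in plain Python. Nested lists, tuples and dicts of floats stand in for a tree of tensors; this is an assumption made so the example runs without Trax or JAX. The helper name leaves is hypothetical.

```python
import math


def leaves(tree):
    """Yield every number in a nested structure of lists, tuples, and dicts."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from leaves(value)
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from leaves(item)
    else:
        yield tree


def l2_norm(tree):
    """L2 norm of all leaves taken together, as if concatenated into one vector."""
    return math.sqrt(sum(x * x for x in leaves(tree)))


print(l2_norm(([3.0, 0.0], {"bias": [4.0]})))  # 5.0
```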
trax.optimizers.base.clip_grads(grad_tree, max_norm)

Proportionally reduces each gradient value to respect an aggregate limit.

Parameters:
  • grad_tree – Gradient values structured as a tree of tensors matching the model’s layer structure.
  • max_norm – The aggregate limit on gradient values. All gradient elements in grad_tree are treated as if they belonged to a single vector, and that vector is shortened if needed so that its L2 norm does not exceed max_norm.
Returns:

A new tree of tensors matching the structure of grad_tree, but with element values proportionally rescaled as needed to respect the max_norm limit.
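Global-norm clipping can be sketched in the same plain-Python setting: nested lists, tuples and dicts of floats stand in for tensors. It computes one norm over every element and rescales all gradients by the same factor only when that norm exceeds max_norm. The helpers leaves and map_tree are hypothetical, and this is not the Trax source.

```python
import math


def leaves(tree):
    """Yield every number in a nested structure of lists, tuples, and dicts."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from leaves(value)
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from leaves(item)
    else:
        yield tree


def map_tree(fn, tree):
    """Apply fn to every leaf, preserving the nesting."""
    if isinstance(tree, dict):
        return {key: map_tree(fn, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_tree(fn, item) for item in tree)
    return fn(tree)


def clip_grads(grad_tree, max_norm):
    """Rescale all gradients together so their global L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in leaves(grad_tree)))
    if norm <= max_norm:
        return grad_tree
    scale = max_norm / norm
    return map_tree(lambda g: g * scale, grad_tree)


clipped = clip_grads(([3.0], {"b": [4.0]}), max_norm=1.0)  # global norm 5 -> 1
```

Because every element shares one scale factor, clipping preserves the gradient's direction and only shortens it.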

momentum

Nesterov momentum optimizer (also known as Nesterov Accelerated Gradient).

class trax.optimizers.momentum.Momentum(learning_rate=0.01, mass=0.9, weight_decay_rate=1e-05, nesterov=True)

Bases: trax.optimizers.base.Optimizer

A momentum optimizer.

This class implements two variants of momentum stochastic gradient descent (SGD): with and without the Nesterov correction. The implementation of the Nesterov update is based on the concepts in Sutskever et al. (2013) [http://jmlr.org/proceedings/papers/v28/sutskever13.pdf], reformulated in Bengio et al. (2012) [https://arxiv.org/abs/1212.0901], to work well with backpropagation (equations 6 and 7):

\[\begin{split}v_t &= \mu_{t-1}v_{t-1} - \epsilon_{t-1}\nabla f(\Theta_{t-1}) \\ \Theta_t &= \Theta_{t-1} - \mu_{t-1} v_{t-1} + \mu_t v_t + v_t\end{split}\]

where \(\mu_{t-1}\) is the momentum (decay) coefficient at time step \(t-1\) and \(\epsilon_{t-1}\) is the learning rate at \(t-1\).
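The reformulation follows from the usual Nesterov update by tracking shifted parameters. Write \(\theta_t\) for the ordinary parameters (notation introduced here, not from the source) and define \(\Theta_t = \theta_t + \mu_t v_t\). Then

\[\begin{split}v_t &= \mu_{t-1} v_{t-1} - \epsilon_{t-1}\nabla f(\theta_{t-1} + \mu_{t-1} v_{t-1}) = \mu_{t-1} v_{t-1} - \epsilon_{t-1}\nabla f(\Theta_{t-1}) \\ \theta_t &= \theta_{t-1} + v_t \\ \Theta_t &= \theta_t + \mu_t v_t = \theta_{t-1} + v_t + \mu_t v_t = \Theta_{t-1} - \mu_{t-1} v_{t-1} + \mu_t v_t + v_t\end{split}\]

As a result, the gradient is always evaluated at the stored parameters \(\Theta\) rather than at a separate look-ahead point.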

Note that this implementation also applies a weight decay rate (\(\alpha\)) to the parameters, independently of the Nesterov momentum.
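A scalar sketch of one step of the equations above covers both the Nesterov and the plain variants. It assumes a constant mass and learning rate (\(\mu_{t-1} = \mu_t\)). Where the weight decay \(\alpha\) enters is also an assumption: the sketch shrinks the previous parameter value by (1 - weight_decay_rate). The function name is hypothetical, and this is not the Trax source.

```python
def momentum_step(theta, v, grad, learning_rate=0.01, mass=0.9,
                  weight_decay_rate=1e-5, nesterov=True):
    """One step of the reformulated momentum update for a scalar parameter.

    Assumes constant mass and learning rate. `theta` is the stored parameter
    Theta and `grad` is the gradient evaluated at it.
    """
    new_v = mass * v - learning_rate * grad
    if nesterov:
        delta = -mass * v + mass * new_v + new_v
    else:
        delta = new_v  # plain momentum SGD: Theta_t = Theta_{t-1} + v_t
    # Assumed placement of weight decay: shrink the previous parameter value.
    return (1.0 - weight_decay_rate) * theta + delta, new_v


theta, v = momentum_step(1.0, 0.0, 1.0, learning_rate=0.1, weight_decay_rate=0.0)
```

From zero velocity, the Nesterov variant moves by \((1 + \mu)\epsilon\) times the gradient on the first step (0.19 in the usage line). Plain momentum moves by \(\epsilon\) times the gradient (0.1).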

__init__(learning_rate=0.01, mass=0.9, weight_decay_rate=1e-05, nesterov=True)

Sets initial hyperparameter values for this optimizer.

Takes optimizer hyperparameters as keyword arguments. These values can change over time (training steps), e.g., for learning rate schedules.

This constructor overrides the base-class constructor with explicitly named keyword arguments, which exposes these hyperparameters for gin configuration.

Parameters:
  • learning_rate – Learning rate for the optimizer (\(\epsilon\) in the equations above). This can change during training by means of a learning rate schedule.
  • mass – Momentum (decay) coefficient \(\mu\).
  • weight_decay_rate – Weight decay rate \(\alpha\) applied to the parameters.
  • nesterov – If True, apply the Nesterov correction; if False, use plain momentum SGD.
init(weights)

Creates optimizer slots that fit the given weights.

Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, velocity, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • grads – Gradient values for this node (from back-propagation during a training step).
  • weights – Current weight values for this node (i.e., layer weights).
  • velocity – Current velocity values (the optimizer slots) for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

rms_prop

RMSProp optimizer class.

class trax.optimizers.rms_prop.RMSProp(learning_rate=0.001, gamma=0.9, eps=1e-08, clip_grad_norm=None)

Bases: trax.optimizers.base.Optimizer

RMSProp optimizer.

Uses optimizer weights (“slots”) to maintain an exponentially decaying average of squared gradients from prior training batches. Each update is divided by the square root of this average (the root mean square of recent gradients).
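A scalar sketch of one RMSProp step, using this class's default hyperparameters, shows the role of gamma and eps. The running average avg_sq_grad is passed in explicitly. What value Trax initializes it to is not stated here, so the sketch leaves that to the caller. This is an illustration, not the Trax source.

```python
import math


def rmsprop_step(weight, grad, avg_sq_grad, learning_rate=0.001, gamma=0.9, eps=1e-8):
    """One RMSProp step for a scalar weight."""
    # Exponentially decaying average of squared gradients (the optimizer slot).
    avg_sq_grad = gamma * avg_sq_grad + (1.0 - gamma) * grad ** 2
    # Divide the step by the root mean square of recent gradients.
    new_weight = weight - learning_rate * grad / (math.sqrt(avg_sq_grad) + eps)
    return new_weight, avg_sq_grad


new_weight, avg = rmsprop_step(1.0, 1.0, 1.0, learning_rate=0.1, eps=0.0)
```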

__init__(learning_rate=0.001, gamma=0.9, eps=1e-08, clip_grad_norm=None)

Sets initial hyperparameter values for this optimizer.

Takes optimizer hyperparameters as keyword arguments. These values can change over time (training steps), e.g., for learning rate schedules.

To expose subclass hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.

Parameters:
  • learning_rate – Learning rate for the optimizer. This can change during training by means of a learning rate schedule.
  • clip_grad_norm – If specified, this scalar value is used to limit gradient size – all gradient elements in a training step are treated as if they belonged to a single vector and then scaled back if needed so that such a vector’s L2 norm does not exceed clip_grad_norm. If None, no clipping happens.
  • gamma – Decay rate of the running average of squared gradients.
  • eps – Small positive constant added to the denominator for numerical stability.
init(weights)

Creates optimizer slots that fit the given weights.

Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, grads, weights, avg_sq_grad, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • grads – Gradient values for this node (from back-propagation during a training step).
  • weights – Current weight values for this node (i.e., layer weights).
  • avg_sq_grad – Current running average of squared gradients (the optimizer slots) for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

sm3

SM3 optimizer class.

class trax.optimizers.sm3.MomentumType

Bases: enum.IntEnum

Momentum variants supported by the SM3 optimizer.

EMA = 1
HEAVY_BALL = 2
NESTEROV = 3
class trax.optimizers.sm3.SM3(learning_rate=0.01, momentum=0.9, second_moment_averaging=1.0, weight_decay=0.0, momentum_type=<MomentumType.EMA: 1>)

Bases: trax.optimizers.base.Optimizer

SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.

__init__(learning_rate=0.01, momentum=0.9, second_moment_averaging=1.0, weight_decay=0.0, momentum_type=<MomentumType.EMA: 1>)

Create the SM3 optimizer.

SM3 is described in “Memory-Efficient Adaptive Optimization”, https://arxiv.org/abs/1901.11150.

Parameters:
  • learning_rate – a positive scalar value for the initial learning rate.
  • momentum – optional, a positive scalar value for momentum.
  • second_moment_averaging – averaging coefficient for second moments (if 1.0, accumulates from the beginning of training, like AdaGrad).
  • weight_decay – Weight decay for regularizing the model.
  • momentum_type – Nesterov, Heavy-Ball, or EMA (default).
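SM3's memory saving comes from storing one accumulator per row and one per column of a 2-D weight, rather than one per entry. The sketch below follows the SM3-II rule from the paper: each entry's estimate is the smallest accumulator covering it plus the current squared gradient, and each accumulator then keeps the maximum over the entries it covers. It assumes second_moment_averaging=1.0 (AdaGrad-style accumulation), momentum off and weight_decay=0.0, so momentum and momentum_type are not modeled. The function name is hypothetical, and this is not the Trax source.

```python
import math


def sm3_step(weights, grads, row_acc, col_acc, learning_rate=0.01):
    """One SM3-II step for a 2-D weight matrix, with momentum and weight decay off."""
    n_rows, n_cols = len(weights), len(weights[0])
    # nu[i][j] = min over accumulators covering (i, j), plus the new squared gradient.
    nu = [[min(row_acc[i], col_acc[j]) + grads[i][j] ** 2 for j in range(n_cols)]
          for i in range(n_rows)]
    # AdaGrad-style step; entries with nu == 0 (no gradient yet) are left unchanged.
    new_weights = [[weights[i][j] - learning_rate * grads[i][j] / math.sqrt(nu[i][j])
                    if nu[i][j] > 0 else weights[i][j]
                    for j in range(n_cols)] for i in range(n_rows)]
    # Each accumulator keeps the max of nu over the entries it covers.
    new_row_acc = [max(nu[i]) for i in range(n_rows)]
    new_col_acc = [max(nu[i][j] for i in range(n_rows)) for j in range(n_cols)]
    return new_weights, new_row_acc, new_col_acc


w, rows, cols = sm3_step([[0.0, 0.0], [0.0, 0.0]], [[1.0, -2.0], [0.0, 3.0]],
                         [0.0, 0.0], [0.0, 0.0], learning_rate=0.1)
```

For an m × n matrix, this keeps m + n accumulators instead of m·n.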
init(w)

Creates optimizer slots that fit the given weights.

Parameters: w – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.
update(step, g, w, slots, opt_params)

Computes updated layer weights and optimizer slots for one training step.

Parameters:
  • step – Training step number.
  • g – Gradient values for this node (from back-propagation during a training step).
  • w – Current weight values for this node (i.e., layer weights).
  • slots – Current slot values for this node.
  • opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns:

Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.