trax.supervised
decoding
Decoding with Trax models.

trax.supervised.decoding.autoregressive_sample_stream(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, accelerate=True, eval_mode=False, eval_min_length=1)
Yields samples from model, in autoregressive language model fashion.
This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and each further item drawn from the stream repeats the process for the next position, indefinitely.
Inputs and outputs always come in batches, even if size 1. If inputs is present, it must have shape (batch_size, inputs_sequence_length), and each output in the stream has shape (batch_size, 1).
Parameters:  model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM), unless eval_mode is set; in that case any model can be sampled, but the sampling process may be much slower.
 inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model generates the first output based on just the start symbol.
 batch_size – Number of sequences to generate in parallel as a batch.
 temperature – Parameter that controls the sharpness of the softmax that feeds the sampling process. Values range from 0.0 (all probability mass goes to one candidate; like an argmax) to positive infinity (all candidates have equal probability).
 start_id – Integer representing the start symbol for the autoregressive process, or array of shape (batch_size, 1) of such integers.
 accelerate – If True, create an accelerated version of model and use it for generating outputs.
 eval_mode – If True, assume the model is created in eval mode and sample by collecting all previous outputs and passing the whole tensor.
 eval_min_length – If set, the minimum length to pad to in eval mode.
Yields: Tensor of integers with shape (batch_size, 1), representing the batch of outputs for the next position in the stream.
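The role of temperature and the streaming behaviour can be illustrated with a minimal pure-Python sketch. It is not the Trax implementation: it drops the batch dimension, and the names sample_with_temperature, toy_sample_stream, and next_logits are hypothetical stand-ins for a real model.

```python
import math
import random


def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Samples an index from `logits` after temperature scaling (illustrative)."""
    if temperature == 0.0:
        # Limit case: all probability mass goes to the highest-scoring candidate.
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max before exponentiating, for stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def toy_sample_stream(next_logits, start_id=0, temperature=1.0, rng=random):
    """Yields one sampled symbol at a time, feeding each output back as input."""
    history = [start_id]
    while True:  # the stream never ends on its own; the caller decides when to stop
        symbol = sample_with_temperature(next_logits(history), temperature, rng)
        history.append(symbol)
        yield symbol
```

A caller would typically take a bounded number of items from such a stream, for example with itertools.islice.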

trax.supervised.decoding.autoregressive_sample(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, eos_id=1, max_length=100, accelerate=True, eval_mode=False, eval_min_length=1)
Returns a batch of sequences created by autoregressive sampling.
This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and this loop repeats until either the model outputs the eos_id value or the output sequence reaches max_length items.
Parameters:  model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM), unless eval_mode is set; in that case any model can be sampled, but the sampling process may be much slower.
 inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model must generate the first output with no input to guide it.
 batch_size – Number of sequences to generate in parallel as a batch.
 temperature – Parameter that controls the sharpness of the softmax that feeds the sampling process. Values range from 0.0 (all probability mass goes to one candidate; like an argmax) to positive infinity (all candidates have equal probability).
 start_id – The start symbol (ID/integer) for the autoregressive process, or array of shape (batch_size, 1) of such integers.
 eos_id – The end-of-sequence symbol (ID/integer) for the autoregressive process.
 max_length – Maximum length for generated sequences.
 accelerate – If True, create an accelerated version of model and use it for generating outputs.
 eval_mode – If True, assume the model is created in eval mode and sample by collecting all previous outputs and passing the whole tensor.
 eval_min_length – If set, the minimum length to pad to in eval mode.
Returns: Tensor of integers with shape (batch_size, output_length) representing a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than max_length.
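The two stopping conditions described above (an eos_id output, or max_length items) can be sketched for a single sequence as follows. This is an illustration under assumptions, not the library code: toy_autoregressive_sample and next_symbol are hypothetical names, batching is omitted, and including the end-of-sequence symbol in the returned sequence is a choice made here.

```python
def toy_autoregressive_sample(next_symbol, start_id=0, eos_id=1, max_length=100):
    """Generates one sequence, stopping at eos_id or after max_length symbols."""
    history = [start_id]  # everything the stand-in model has seen so far
    output = []
    while len(output) < max_length:
        symbol = next_symbol(history)
        output.append(symbol)
        history.append(symbol)  # the new output becomes the next input
        if symbol == eos_id:
            break
    return output
```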

trax.supervised.decoding.beam_search(model, inputs=None, batch_size=1, n_beams=2, start_id=0, eos_id=1, max_length=100, length_penalty=1.0, accelerate=True)
Returns a batch of n_beams sequences created by beam search.
This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and this loop repeats until either the model outputs the eos_id value or the output sequence reaches max_length items, keeping the n_beams top beams throughout.
Parameters:  model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM).
 inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model must generate the first output with no input to guide it.
 batch_size – Number of sequences to generate in parallel as a batch.
 n_beams – How many beams to consider at the same time.
 start_id – The start symbol (ID/integer) for the autoregressive process, or array of shape (batch_size, 1) of such integers.
 eos_id – The end-of-sequence symbol (ID/integer) for the autoregressive process.
 max_length – Maximum length for generated sequences.
 length_penalty – Factor alpha in calculating the length penalty for beams.
 accelerate – If True, create an accelerated version of model and use it for generating outputs.
Returns: Tensor of integers with shape (batch_size, n_beams, output_length) with a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than max_length.
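How a length-penalty factor alpha can affect beam ranking is sketched below using the GNMT-style penalty ((5 + length) / 6) ** alpha. That Trax uses exactly this formula is an assumption; the helper names length_penalty and rank_beams are hypothetical, and the sketch only ranks finished candidates rather than running a full search.

```python
def length_penalty(length, alpha=1.0):
    """GNMT-style length penalty; grows with sequence length when alpha > 0."""
    return ((5.0 + length) / 6.0) ** alpha


def rank_beams(beams, alpha=1.0, n_beams=2):
    """Keeps the n_beams best candidates.

    `beams` is a list of (tokens, log_prob) pairs. Each log-probability is
    divided by the length penalty, so longer sequences are penalized less
    harshly for their extra (negative) log-probability terms.
    """
    def score(beam):
        tokens, log_prob = beam
        return log_prob / length_penalty(len(tokens), alpha)

    return sorted(beams, key=score, reverse=True)[:n_beams]
```

With alpha set to 0.0 the penalty is always 1.0, so candidates are ranked by raw log-probability alone.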
lr_schedules
Learning rate (LR) schedules.
In Trax a learning rate schedule is a function: \(\text{step} \mapsto \text{learning_rate}\). This module provides helpers for constructing such functions. For example:
constant(0.001)
returns a function that always returns 0.001.

trax.supervised.lr_schedules.constant(value)
Returns an LR schedule that is constant from time (step) 1 to infinity.

trax.supervised.lr_schedules.warmup(n_warmup_steps, max_value)
Returns an LR schedule with linear warmup followed by constant value.
Parameters:  n_warmup_steps – Number of steps during which the learning rate rises on a line connecting (0, 0) and (n_warmup_steps, max_value).
 max_value – Value for learning rate after warmup has finished.
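The shape of this schedule, a line from (0, 0) to (n_warmup_steps, max_value) followed by a constant, can be written as a plain step-to-rate function. The sketch below is illustrative only; toy_warmup is a hypothetical name and the library's exact handling of boundary steps may differ.

```python
def toy_warmup(n_warmup_steps, max_value):
    """Returns a function step -> learning rate with linear warmup."""
    def learning_rate(step):
        if step >= n_warmup_steps:
            return max_value  # warmup finished: constant from here on
        # Point on the line connecting (0, 0) and (n_warmup_steps, max_value).
        return max_value * step / n_warmup_steps
    return learning_rate
```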

trax.supervised.lr_schedules.warmup_and_rsqrt_decay(n_warmup_steps, max_value)
Returns an LR schedule with warmup + reciprocal square root decay.
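One plausible reading of "warmup + reciprocal square root decay" is a linear rise to max_value followed by a decay proportional to 1/sqrt(step), scaled so the two pieces meet at n_warmup_steps. The sketch below assumes that form; toy_warmup_and_rsqrt_decay is a hypothetical name and the library's exact scaling is not confirmed here.

```python
import math


def toy_warmup_and_rsqrt_decay(n_warmup_steps, max_value):
    """Returns step -> learning rate: linear warmup, then 1/sqrt(step) decay."""
    if n_warmup_steps < 1:
        raise ValueError("this sketch assumes n_warmup_steps >= 1")

    def learning_rate(step):
        if step < n_warmup_steps:
            return max_value * step / n_warmup_steps
        # Equals max_value at step == n_warmup_steps, then decays as 1/sqrt(step).
        return max_value * math.sqrt(n_warmup_steps / step)

    return learning_rate
```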

trax.supervised.lr_schedules.multifactor(factors='constant * linear_warmup * rsqrt_decay', constant=0.1, warmup_steps=400, decay_factor=0.5, steps_per_decay=20000, steps_per_cycle=100000, second_constant=0.01, second_constant_step=10000, minimum=0)
Factor-based learning rate schedule.
Interprets factors in the factors string, which can consist of:
 * constant: interpreted as the constant value,
 * linear_warmup: interpreted as linear warmup until warmup_steps,
 * rsqrt_decay: divide by square root of max(step, warmup_steps),
 * decay_every: every k steps, decay the learning rate by decay_factor,
 * cosine_decay: cyclic cosine decay, uses the steps_per_cycle parameter,
 * two_constants: constant until second_constant_step, then switch to second_constant.
Parameters:  factors – A string with factors separated by ‘*’ that defines the schedule.
 constant – Float, the starting constant for the learning rate schedule.
 warmup_steps – How many steps to warm up for in the warmup schedule.
 decay_factor – The amount to decay the learning rate by.
 steps_per_decay – How often to decay the learning rate.
 steps_per_cycle – Steps per cycle when using cosine decay.
 second_constant – Float, the second constant for the learning rate schedule.
 second_constant_step – The step when the second_constant is triggered.
 minimum – If the computed rate is below the minimum, then return the minimum.
Returns: A function learning_rate(step): float -> {‘learning_rate’: float}, the step-dependent learning rate.
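The way a factors string composes into a schedule can be sketched by multiplying one term per named factor. The version below is an illustration, not the library code: toy_multifactor is a hypothetical name, it covers only four of the factors listed above, it returns a bare float instead of a dict, and it assumes warmup_steps > 0.

```python
import math


def toy_multifactor(factors="constant * linear_warmup * rsqrt_decay",
                    constant=0.1, warmup_steps=400,
                    decay_factor=0.5, steps_per_decay=20000, minimum=0.0):
    """Returns step -> learning rate built from '*'-separated factor names."""
    names = [name.strip() for name in factors.split("*")]

    def learning_rate(step):
        rate = 1.0
        for name in names:
            if name == "constant":
                rate *= constant
            elif name == "linear_warmup":
                rate *= min(1.0, step / warmup_steps)  # assumes warmup_steps > 0
            elif name == "rsqrt_decay":
                rate /= math.sqrt(max(step, warmup_steps))
            elif name == "decay_every":
                rate *= decay_factor ** (step // steps_per_decay)
            else:
                raise ValueError(f"Unknown factor: {name!r}")
        return max(rate, minimum)  # never report less than the floor

    return learning_rate
```

With the default factors string, the rate rises linearly for warmup_steps steps and then falls off as the reciprocal square root of the step number.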
training
Simplified API (under development) for supervised learning/training in Trax.
This module will eventually replace trainer_lib.Trainer.
Key classes:
 Loop: Core training loop for an n-step training session, starting from random initialization.
 TrainTask: Labeled data + feedback mechanism (loss function w/ optimizer) for modifying a model’s weights.
 Optimizer: How to compute model weight updates using loss-derived gradients. May contain state (“slots”, 1-1 with model weights) that accumulates across training steps. (This class is defined in trax.optimizers.)
 EvalTask: How and when to measure model performance as a function of training step number.

class trax.supervised.training.Loop(model, tasks, eval_model=None, eval_tasks=None, output_dir=None, checkpoint_at=None, checkpoint_low_metric=None, checkpoint_high_metric=None, permanent_checkpoint_at=None, eval_at=None, which_task=None, n_devices=None, random_seed=None, loss_chunk_size=0, use_memory_efficient_trainer=False, adasum=False, callbacks=None)
Bases: object
Loop that can run for a given number of steps to train a supervised model.
Can train the model on multiple tasks by interleaving updates according to the which_task argument.
The typical supervised training process randomly initializes a model and updates its weights via feedback (loss-derived gradients) from a training task, by looping through batches of labeled data. A training loop can also be configured to run periodic evals and save intermediate checkpoints.
For speed, the implementation takes advantage of JAX’s composable function transformations (specifically, jit and grad). It creates JIT-compiled pure functions derived from variants of the core model; schematically:
 training variant: jit(grad(pure_function(model+loss)))
 evals variant: jit(pure_function(model+evals))
In training or during evals, these variants are called with explicit arguments for all relevant input data, model weights/state, optimizer slots, and random number seeds:
 batch: labeled data
 model weights/state: trainable weights and input-related state (e.g., as used by batch norm)
 optimizer slots: weights in the optimizer that evolve during the training process
 random number seeds: JAX PRNG keys that enable high-quality, distributed, repeatable generation of pseudo-random numbers

__init__(model, tasks, eval_model=None, eval_tasks=None, output_dir=None, checkpoint_at=None, checkpoint_low_metric=None, checkpoint_high_metric=None, permanent_checkpoint_at=None, eval_at=None, which_task=None, n_devices=None, random_seed=None, loss_chunk_size=0, use_memory_efficient_trainer=False, adasum=False, callbacks=None)
Configures a training Loop, including a random initialization.
Parameters:  model – Trax layer, representing the core model to be trained. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 tasks – List of TrainTask instances, which define the training data, loss function, and optimizer to be used in respective tasks in this training loop. It can also be a single TrainTask instance, which is treated in the same way as a singleton list.
 eval_model – Optional Trax layer, representing the model used for evaluation, e.g., with dropout turned off. If None, the training model (model) will be used.
 eval_tasks – List of EvalTask instances, which define how to evaluate the model: which validation data to use and which metrics to report. Evaluation on each task is run and reported separately, which makes it possible to score a model on different subtasks. This argument can also be None, in which case no evals will be run, or a single EvalTask, which will be treated in the same way as a singleton list.
 output_dir – Path telling where to save outputs (evals and checkpoints). Can be None if both eval_tasks and checkpoint_at are None.
 checkpoint_at – Function (integer –> boolean) telling, for step n, whether that step should have its checkpoint saved. If None, the default is periodic checkpointing at task.n_steps_per_checkpoint.
 checkpoint_low_metric – Name of metric, or None. The metric name must be one of the metric names from the evals in eval_tasks. At checkpoint times determined by checkpoint_at, a separate specially named checkpoint will be saved (overwriting any previous version) if the designated metric reaches a value less than or equal to any previous recorded low value. No such checkpoint is saved if arg value is None.
 checkpoint_high_metric – Name of metric, or None. The metric name must be one of the metric names from the evals in eval_tasks. At checkpoint times determined by checkpoint_at, a separate specially named checkpoint will be saved (overwriting any previous version) if the designated metric reaches a value greater than or equal to any previous recorded high value. No such checkpoint is saved if arg value is None.
 permanent_checkpoint_at – Function (integer –> boolean) telling, for step n, whether that step should have its checkpoint saved permanently. If None, the default is periodic checkpointing at task.n_steps_per_permanent_checkpoint.
 eval_at – Function (integer –> boolean) that says, for training step n, whether that step should run evals. If None, run evals on the first step and on every N’th step, as determined by the first training task.
 which_task – Function (integer –> integer) indicating which task should be used at which training step. Can be set to None in single-task training.
 n_devices – Integer or None, the number of devices for this computation.
 random_seed – The random seed to use; time/os dependent if None (default).
 loss_chunk_size – Int, if > 0 use chunks of this size to make loss computation more memory-efficient.
 use_memory_efficient_trainer – Whether to use a special memory-efficient trainer; if set to 2, the memory efficiency is very aggressive.
 adasum – If True, use adaptive summation for multi-device gradients.
 callbacks – List of subclasses of StepCallback to call on training steps.
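Several of these arguments (checkpoint_at, permanent_checkpoint_at, eval_at, which_task) are plain functions of the training step number. A minimal sketch of such functions follows; the helper names every_n_steps and round_robin and the particular intervals are hypothetical choices, not library defaults.

```python
def every_n_steps(n):
    """Returns a predicate (integer -> boolean) that is True on every n-th step."""
    return lambda step: step % n == 0


def round_robin(n_tasks):
    """Returns a function (integer -> integer) cycling through task indices."""
    return lambda step: step % n_tasks


# Example values that could be passed as keyword arguments to Loop.
checkpoint_at = every_n_steps(1000)
eval_at = lambda step: step == 1 or step % 500 == 0  # first step, then periodic
which_task = round_robin(2)  # alternate between two training tasks
```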

run(n_steps=1)
Runs this training loop for n steps.
Optionally runs evals and saves checkpoints at specified points.
Parameters: n_steps – Stop training after completing n steps.

step
Returns current step number in this training session.

history
Returns history in this training session.

n_devices
Returns the number of devices to be used in this computation.

is_chief
Returns true if this Loop is the chief.

model
Returns the model that is training.

tasks
Returns the training tasks.

eval_model
Returns the model used for evaluation.

eval_tasks
Returns the evaluation tasks.

output_dir
Returns the output directory.

new_rng()
Returns a new single-use random number generator (JAX PRNG key).

update_weights_and_state(weights=None, state=None)
Updates the weights and state of the trained model.
Sends this data both to the singleton model accessible via Loop.model and to the replicated model on the accelerator.
Useful when the weights or state are modified outside of training, e.g. during data collection in RL agents.
Parameters:  weights – Model weights or None. If None, don’t set.
 state – Model state or None. If None, don’t set.

run_evals(summary_writers=None)
Runs and records evals for this training session.
Parameters: summary_writers – List of per-task Jaxboard summary writers to log metrics.

log_summary(values, summary_writer, value_prefix, log_prefix, stdout=True)
Logs and saves provided metrics.
Parameters:  values – Dict from metric name to metric value.
 summary_writer – Jaxboard summary writer.
 value_prefix – String prepended to summary_writer entries.
 log_prefix – String prepended to logs.
 stdout – Boolean saying whether logs should also be written to stdout.

save_checkpoint(basename)
Saves checkpoint (multiple files) to disk for the current training step.
Saving a checkpoint will overwrite any previous checkpoint saved with the same basename. Use differing basename values to save multiple checkpoints or multiple copies of the same checkpoint.
Parameters: basename – Basename for saving a checkpoint. Full file paths for the saved checkpoint will combine the output dir, basename, and relevant file extensions (e.g., .weights.npy.gz).

load_checkpoint(directory=None, filename=None)
Loads model weights and step from a checkpoint on disk.
Parameters:  directory – Directory with the checkpoint (self._output_dir by default).
 filename – Checkpoint file name (model.pkl.gz by default).

trax.supervised.training.pickle_to_file(obj, file_path, gzip=False)
Pickles obj to file_path, optionally gzipped, with failure protection.

trax.supervised.training.unpickle_from_file(file_path, gzip=False)
Unpickles obj from file_path, optionally gzipped.
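One common way to get "failure protection" when writing a pickle is to write to a temporary file and only then move it over the destination, so an interrupted write never leaves a truncated file in place. The sketch below assumes that interpretation; it is not the Trax implementation, and toy_pickle_to_file, toy_unpickle_from_file, and the use_gzip parameter name are hypothetical.

```python
import gzip
import os
import pickle
import tempfile


def toy_pickle_to_file(obj, file_path, use_gzip=False):
    """Pickles obj to file_path via a temporary file in the same directory."""
    directory = os.path.dirname(os.path.abspath(file_path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as raw:
            if use_gzip:
                with gzip.GzipFile(fileobj=raw, mode="wb") as gz:
                    pickle.dump(obj, gz)
            else:
                pickle.dump(obj, raw)
        # Only replace the destination once the write has fully succeeded.
        os.replace(tmp_path, file_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # do not leave partial files behind
        raise


def toy_unpickle_from_file(file_path, use_gzip=False):
    """Loads an object written by toy_pickle_to_file."""
    opener = gzip.open if use_gzip else open
    with opener(file_path, "rb") as f:
        return pickle.load(f)
```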

trax.supervised.training.init_host_and_devices(n_devices=None, random_seed=None)
Initializes host and device attributes for this trainer.
Parameters:  n_devices – Number of devices this trainer will use. If None, get the number from the backend.
 random_seed – Random seed as the starting point for all random numbers used by the trainer. If None, calculate one from system time and host id.
Returns:  is_chief – True if this trainer has special chief responsibilities.
 host_count – Number of hosts in this computation.
 n_devices – The passed in value of n_devices or a computed default (for this host).
 random_seed – The passed in value of random_seed or a computed default.