trax.supervised

decoding

Decoding with Trax models.

trax.supervised.decoding.autoregressive_sample_stream(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, accelerate=True)

Yields samples from model, in autoregressive language model fashion.

This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and further calls to autoregressive_sample_stream repeat the process for successive positions indefinitely.

Inputs and outputs always come in batches, even if size 1. If inputs is present, it must have shape (batch_size, inputs_sequence_length), and each output in the stream has shape (batch_size, 1).

Parameters:
  • model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM).
  • inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model generates the first output based on just the start symbol.
  • batch_size – Number of sequences to generate in parallel as a batch.
  • temperature – Parameter that controls the sharpness of the softmax that feeds the sampling process. Values range from 0.0 (all probability mass goes to one candidate; like an argmax) to positive infinity (all candidates have equal probability).
  • start_id – Integer representing the start symbol for the autoregressive process, or array of shape (batch_size, 1) of such integers.
  • accelerate – If True, create an accelerated version of model and use it for generating outputs.
Yields:

Tensor of integers with shape (batch_size, 1), representing the batch of outputs for the next position in the stream.
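
For illustration, here is a minimal sketch of consuming the stream. The checkpoint path and model sizes below are hypothetical; any autoregressive model created in ‘predict’ mode and initialized from trained weights would do.

from trax import models
from trax.supervised import decoding

# Hypothetical small language model and checkpoint path.
model = models.TransformerLM(vocab_size=256, mode='predict')
model.init_from_file('my_model.pkl.gz', weights_only=True)

stream = decoding.autoregressive_sample_stream(model, temperature=0.8)
tokens = []
for _ in range(5):           # take the first 5 positions from the stream
  sample = next(stream)      # shape (batch_size, 1) == (1, 1)
  tokens.append(int(sample[0, 0]))
print(tokens)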

trax.supervised.decoding.autoregressive_sample(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, eos_id=1, max_length=100, accelerate=True)

Returns a batch of sequences created by autoregressive sampling.

This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and this loop repeats until either the model outputs the eos_id value or the output sequence reaches max_length items.

Parameters:
  • model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM).
  • inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model must generate the first output with no input to guide it.
  • batch_size – Number of sequences to generate in parallel as a batch.
  • temperature – Parameter that controls the sharpness of the softmax that feeds the sampling process. Values range from 0.0 (all probability mass goes to one candidate; like an argmax) to positive infinity (all candidates have equal probability).
  • start_id – The start symbol (ID/integer) for the autoregressive process, or array of shape (batch_size, 1) of such integers.
  • eos_id – The end-of-sequence symbol (ID/integer) for the autoregressive process.
  • max_length – Maximum length for generated sequences.
  • accelerate – If True, create an accelerated version of model and use it for generating outputs.
Returns:

Tensor of integers with shape (batch_size, output_length) representing a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than max_length.
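
As a hedged sketch, here is sampling conditioned on a prompt; this assumes model is the ‘predict’-mode TransformerLM from the previous example, and the prompt token IDs are purely illustrative.

import numpy as np
from trax.supervised import decoding

prompt = np.array([[12, 7, 33]])     # shape (batch_size=1, inputs_sequence_length=3)
output = decoding.autoregressive_sample(
    model, inputs=prompt,
    temperature=0.0,                 # temperature 0.0 behaves like argmax
    eos_id=1, max_length=20)
print(output.shape)                  # (1, output_length) with output_length <= 20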

trax.supervised.decoding.beam_search(model, inputs=None, batch_size=1, n_beams=2, start_id=0, eos_id=1, max_length=100, length_penalty=1.0, accelerate=True)

Returns a batch of n_beams sequences created by beam search.

This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and this loop repeats until either the model outputs the eos_id value or the output sequence reaches max_length items, keeping the n_beams top-scoring beams throughout.

Parameters:
  • model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM).
  • inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model must generate the first output with no input to guide it.
  • batch_size – Number of sequences to generate in parallel as a batch.
  • n_beams – How many beams to consider at the same time.
  • start_id – The start symbol (ID/integer) for the autoregressive process, or array of shape (batch_size, 1) of such integers.
  • eos_id – The end-of-sequence symbol (ID/integer) for the autoregressive process.
  • max_length – Maximum length for generated sequences.
  • length_penalty – Factor alpha in calculating the length penalty for beams.
  • accelerate – If True, create an accelerated version of model and use it for generating outputs.
Returns:

Tensor of integers with shape (batch_size, n_beams, output_length) with a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than max_length.
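
A hedged usage sketch, reusing the model and prompt from the sampling examples above:

from trax.supervised import decoding

beams = decoding.beam_search(model, inputs=prompt, n_beams=4,
                             eos_id=1, max_length=20)
print(beams.shape)   # (batch_size, n_beams, output_length) == (1, 4, <= 20)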

lr_schedules

Learning rate (LR) schedules.

In Trax a learning rate schedule is a function: \(\text{step} \mapsto \text{learning\_rate}\). This module provides helpers for constructing such functions. For example:

constant(0.001)

returns a function that always returns 0.001.

trax.supervised.lr_schedules.constant(value)

Returns an LR schedule that is constant from time (step) 1 to infinity.

trax.supervised.lr_schedules.warmup(n_warmup_steps, max_value)

Returns an LR schedule with linear warm-up followed by constant value.

Parameters:
  • n_warmup_steps – Number of steps during which the learning rate rises on a line connecting (0, 0) and (n_warmup_steps, max_value).
  • max_value – Value for learning rate after warm-up has finished.
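
A small sketch of how this schedule behaves, based only on the description above (a linear ramp from (0, 0) to (n_warmup_steps, max_value), then constant):

from trax.supervised import lr_schedules

lr = lr_schedules.warmup(n_warmup_steps=100, max_value=0.01)
print(lr(1), lr(50), lr(100), lr(1000))
# roughly 0.0001, 0.005, 0.01, and then a constant 0.01 thereafter
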
trax.supervised.lr_schedules.warmup_and_rsqrt_decay(n_warmup_steps, max_value)

Returns an LR schedule with warm-up + reciprocal square root decay.
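
A sketch that only inspects the schedule's shape over time; the exact decay constants come from the implementation, so no specific values are asserted here.

from trax.supervised import lr_schedules

lr = lr_schedules.warmup_and_rsqrt_decay(n_warmup_steps=400, max_value=0.001)
for step in (1, 200, 400, 1600, 6400):
  print(step, lr(step))   # rises toward max_value, then decays roughly as 1/sqrt(step)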

trax.supervised.lr_schedules.multifactor(factors='constant * linear_warmup * rsqrt_decay', constant=0.1, warmup_steps=400, decay_factor=0.5, steps_per_decay=20000, steps_per_cycle=100000, second_constant=0.01, second_constant_step=10000, minimum=0)

Factor-based learning rate schedule.

Interprets factors in the factors string, which can consist of:
  • constant: interpreted as the constant value,
  • linear_warmup: interpreted as linear warmup until warmup_steps,
  • rsqrt_decay: divide by square root of max(step, warmup_steps),
  • decay_every: every k steps, decay the learning rate by decay_factor,
  • cosine_decay: cyclic cosine decay, uses the steps_per_cycle parameter,
  • two_constants: constant until second_constant_step, then switch to second_constant.

Parameters:
  • factors – a string with factors separated by ‘*’ that defines the schedule.
  • constant – float, the starting constant for the learning rate schedule.
  • warmup_steps – how many steps to warm up for in the warmup schedule.
  • decay_factor – The amount to decay the learning rate by.
  • steps_per_decay – How often to decay the learning rate.
  • steps_per_cycle – Steps per cycle when using cosine decay.
  • second_constant – float, the second constant for the learning rate schedule.
  • second_constant_step – the step when the second_constant is triggered.
  • minimum – if the computed rate is below the minimum, then return the minimum.
Returns:

float -> {‘learning_rate’: float}, the step-dependent lr.

Return type:

a function learning_rate(step)
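
A hedged sketch of composing factors with the module's default factor string; the constants here are illustrative, not recommended values.

from trax.supervised import lr_schedules

lr = lr_schedules.multifactor(
    factors='constant * linear_warmup * rsqrt_decay',
    constant=0.03, warmup_steps=1000)
# lr is a callable; lr(step) gives the step-dependent learning rate, warming up
# linearly for 1000 steps and then decaying as 1 / sqrt(max(step, warmup_steps)).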

training

Simplified API (under development) for supervised learning/training in Trax.

This module will eventually replace trainer_lib.Trainer.

Key classes:

  • Loop: Core training loop for an n-step training session, starting from random initialization.
  • TrainTask: Labeled data + feedback mechanism (loss function w/ optimizer) for modifying a model’s weights.
  • Optimizer: How to compute model weight updates using loss-derived gradients. May contain state (“slots”, 1-1 with model weights) that accumulates across training steps. (This class is defined in trax.optimizers.)
  • EvalTask: How and when to measure model performance as a function of training step number.
class trax.supervised.training.Loop(model, tasks, eval_model=None, eval_tasks=None, output_dir=None, checkpoint_at=None, checkpoint_low_metric=None, checkpoint_high_metric=None, permanent_checkpoint_at=None, eval_at=None, which_task=None, n_devices=None, random_seed=None, loss_chunk_size=0, use_memory_efficient_trainer=False, adasum=False, callbacks=None)

Bases: object

Loop that can run for a given number of steps to train a supervised model.

Can train the model on multiple tasks by interleaving updates according to the which_task argument.

The typical supervised training process randomly initializes a model and updates its weights via feedback (loss-derived gradients) from a training task, by looping through batches of labeled data. A training loop can also be configured to run periodic evals and save intermediate checkpoints.

For speed, the implementation takes advantage of JAX’s composable function transformations (specifically, jit and grad). It creates JIT-compiled pure functions derived from variants of the core model; schematically:

  • training variant: jit(grad(pure_function(model+loss)))
  • evals variant: jit(pure_function(model+evals))

In training or during evals, these variants are called with explicit arguments for all relevant input data, model weights/state, optimizer slots, and random number seeds:

  • batch: labeled data
  • model weights/state: trainable weights and input-related state (e.g., as used by batch norm)
  • optimizer slots: weights in the optimizer that evolve during the training process
  • random number seeds: JAX PRNG keys that enable high-quality, distributed, repeatable generation of pseudo-random numbers
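
A purely schematic sketch of these transformations, using raw JAX on a stand-in pure function; this is not the Loop implementation itself.

import jax
import jax.numpy as jnp

def loss_fn(weights, batch):             # stands in for pure_function(model+loss)
  x, y = batch
  return jnp.mean((x @ weights - y) ** 2)

train_fn = jax.jit(jax.grad(loss_fn))    # "training variant": jit(grad(...))
eval_fn = jax.jit(loss_fn)               # "evals variant": jit(...)

w = jnp.zeros((3,))
batch = (jnp.ones((4, 3)), jnp.zeros((4,)))
print(train_fn(w, batch))                # gradient w.r.t. weights, shape (3,)
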
__init__(model, tasks, eval_model=None, eval_tasks=None, output_dir=None, checkpoint_at=None, checkpoint_low_metric=None, checkpoint_high_metric=None, permanent_checkpoint_at=None, eval_at=None, which_task=None, n_devices=None, random_seed=None, loss_chunk_size=0, use_memory_efficient_trainer=False, adasum=False, callbacks=None)

Configures a training Loop, including a random initialization.

Parameters:
  • model – Trax layer, representing the core model to be trained. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
  • tasks – List of TrainTask instances, which define the training data, loss function, and optimizer to be used in respective tasks in this training loop. It can also be a single TrainTask instance which is treated in the same way as a singleton list.
  • eval_model – Optional Trax layer, representing model used for evaluation, e.g., with dropout turned off. If None, the training model (model) will be used.
  • eval_tasks – List of EvalTask instances which define how to evaluate the model: which validation data to use and which metrics to report. Evaluation on each of the tasks will run and be reported separately, which allows scoring a model on different subtasks. This argument can also be None, in which case no evals will be run, or a single EvalTask, which will be treated in the same way as a singleton list.
  • output_dir – Path telling where to save outputs (evals and checkpoints). Can be None if both eval_tasks and checkpoint_at are None.
  • checkpoint_at – Function (integer –> boolean) telling, for step n, whether that step should have its checkpoint saved. If None, the default is periodic checkpointing at task.n_steps_per_checkpoint.
  • checkpoint_low_metric – Name of metric, or None. The metric name must be one of the metric names from the evals in eval_tasks. At checkpoint times determined by checkpoint_at, a separate specially named checkpoint will be saved (overwriting any previous version) if the designated metric reaches a value less than or equal to any previous recorded low value. No such checkpoint is saved if arg value is None.
  • checkpoint_high_metric – Name of metric, or None. The metric name must be one of the metric names from the evals in eval_tasks. At checkpoint times determined by checkpoint_at, a separate specially named checkpoint will be saved (overwriting any previous version) if the designated metric reaches a value greater than or equal to any previous recorded high value. No such checkpoint is saved if arg value is None.
  • permanent_checkpoint_at – Function (integer –> boolean) telling, for step n, whether that step should have its checkpoint saved permanently. If None, the default is periodic checkpointing at task.n_steps_per_permanent_checkpoint.
  • eval_at – Function (integer –> boolean) that says, for training step n, whether that step should run evals. If None, run evals on the first step and on every N’th step, as determined by the first training task.
  • which_task – Function (integer –> integer) indicating which task should be used at which training step. Can be set to None in single-task training.
  • n_devices – integer or None, the number of devices for this computation.
  • random_seed – the random seed to use; time/os dependent if None (default).
  • loss_chunk_size – int, if > 0 use chunks of this size to make loss computation more memory-efficient.
  • use_memory_efficient_trainer – whether to use a special memory-efficient trainer; if set to 2, the memory efficiency is very aggressive.
  • adasum – if True, use adaptive summation for multi-device gradients
  • callbacks – List of subclasses of StepCallback to call on training steps.
run(n_steps=1)

Runs this training loop for n steps.

Optionally runs evals and saves checkpoints at specified points.

Parameters:
  • n_steps – Stop training after completing n steps.
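
A hedged end-to-end sketch of configuring and running a Loop on a tiny TransformerLM; the toy data generator, model sizes, and output directory are all placeholders.

import numpy as np
from trax import layers as tl
from trax import models, optimizers
from trax.supervised import training

def toy_batches():
  # Endless stream of (inputs, targets, loss weights) batches of token IDs.
  while True:
    x = np.random.randint(0, 32, size=(8, 64))
    yield (x, x, np.ones_like(x, dtype=np.float32))

train_task = training.TrainTask(
    labeled_data=toy_batches(),
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=optimizers.Adam(0.001),
    n_steps_per_checkpoint=100)

eval_task = training.EvalTask(
    labeled_data=toy_batches(),
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()])

loop = training.Loop(
    models.TransformerLM(vocab_size=32, d_model=64, d_ff=128, n_layers=2, n_heads=2),
    train_task,
    eval_tasks=[eval_task],
    output_dir='/tmp/toy_lm')   # hypothetical output directory
loop.run(n_steps=200)
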
step

Returns current step number in this training session.

history

Returns history in this training session.

n_devices

Returns the number of devices to be used in this computation.

is_chief

Returns true if this Loop is the chief.

model

Returns the model that is training.

tasks

Returns the training tasks.

eval_model

Returns the model used for evaluation.

eval_tasks

Returns the evaluation tasks.

output_dir

Returns the output directory.

new_rng()

Returns a new single-use random number generator (JAX PRNG key).

update_weights_and_state(weights=None, state=None)

Updates the weights and state of the trained model.

Sends this data both to the singleton model accessible via Loop.model and to the replicated model on the accelerator.

Useful when the weights or state are modified outside of training, e.g. during data collection in RL agents.

Parameters:
  • weights – Model weights or None. If None, don’t set.
  • state – Model state or None. If None, don’t set.
run_evals(summary_writers=None)

Runs and records evals for this training session.

Parameters:
  • summary_writers – List of per-task Jaxboard summary writers to log metrics.
log_summary(values, summary_writer, value_prefix, log_prefix, stdout=True)

Logs and saves provided metrics.

Parameters:
  • values – Dict from metric name to metric value.
  • summary_writer – Jaxboard summary writer.
  • value_prefix – String prepended to summary_writer entries.
  • log_prefix – String prepended to log lines.
  • stdout – Boolean saying if logs should be logged to stdout as well.
save_checkpoint(basename)

Saves checkpoint (multiple files) to disk for the current training step.

Saving a checkpoint will overwrite any previous checkpoint saved with the same basename. Use differing basename values to save multiple checkpoints or multiple copies of the same checkpoint.

Parameters:
  • basename – Basename for saving a checkpoint. Full file paths for the saved checkpoint will combine the output dir, basename, and relevant file extensions (e.g., .weights.npy.gz).
load_checkpoint(directory=None, filename=None)

Loads model weights and step from a checkpoint on disk.

Parameters:
  • directory – Directory with the checkpoint (self._output_dir by default).
  • filename – Checkpoint file name (model.pkl.gz by default).
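
A minimal sketch, assuming the loop object from the training example above; the basename is illustrative.

loop.save_checkpoint('model')                  # writes e.g. <output_dir>/model.pkl.gz and related files
loop.load_checkpoint(filename='model.pkl.gz')  # restore weights and step from that checkpoint
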
trax.supervised.training.pickle_to_file(obj, file_path, gzip=False)

Pickle obj to file_path with gzipping and failure protection.

trax.supervised.training.unpickle_from_file(file_path, gzip=False)

Unpickle obj from file_path with gzipping.
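
A minimal sketch of the two pickle helpers with gzip compression; the file path is illustrative.

from trax.supervised import training

training.pickle_to_file({'step': 123}, '/tmp/state.pkl.gz', gzip=True)
state = training.unpickle_from_file('/tmp/state.pkl.gz', gzip=True)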

trax.supervised.training.init_host_and_devices(n_devices=None, random_seed=None)

Initializes host and device attributes for this trainer.

Parameters:
  • n_devices – Number of devices this trainer will use. If None, get the number from the backend.
  • random_seed – Random seed as the starting point for all random numbers used by the trainer. If None, calculate one from system time and host id.
Returns:

A tuple (is_chief, host_count, n_devices, random_seed), where:
  • is_chief – True if this trainer has special chief responsibilities.
  • host_count – Number of hosts in this computation.
  • n_devices – The passed-in value of n_devices or a computed default (for this host).
  • random_seed – The passed-in value of random_seed or a computed default.
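
A hedged sketch of unpacking the returned values:

from trax.supervised import training

is_chief, host_count, n_devices, random_seed = training.init_host_and_devices(
    n_devices=None, random_seed=17)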