Trax Tutorials
Trax Quick Intro
Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained by the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where you can find more information.
 Run a pretrained Transformer: create a translator in a few lines of code
 Features and resources: API docs, where to talk to us, how to open an issue and more
 Walkthrough: how Trax works, how to make new models and train on your own data
We welcome contributions to Trax, including PRs with code for new models and layers as well as improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!
General Setup
Execute the following few cells (once) before running any of the code samples.
[1]:
#@title
# Copyright 2020 Google LLC.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
[2]:
#@title
# Import Trax
!pip install -q -U trax
import trax
1. Run a pretrained Transformer
Here is how you create an English-German translator in a few lines of code:
 create a Transformer model in Trax with trax.models.Transformer
 initialize it from a file with pretrained weights with model.init_from_file
 tokenize your input sentence to input into the model with trax.data.tokenize
 decode from the Transformer with trax.supervised.decoding.autoregressive_sample
 detokenize the decoded result to get the translation with trax.data.detokenize
[3]:
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')
# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)
# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]
# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.
# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!
2. Features and resources
Trax includes basic models (like ResNet, LSTM, and Transformer) and RL algorithms (like REINFORCE, A2C, and PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.
You can use Trax either as a library from your own Python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs, and TPUs.
 API docs
 chat with us
 open an issue
 subscribe to trax-discuss for news
3. Walkthrough
You can learn here how Trax works, how to create new models and how to train them on your own data.
Tensors and Fast Math
The basic units flowing through Trax models are tensors: multi-dimensional arrays, sometimes also known as numpy arrays after numpy, the most widely used package for tensor operations. You should take a look at the numpy guide if you don’t know how to operate on tensors: Trax also uses the numpy API for that.
In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends: JAX and TensorFlow numpy.
[4]:
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.
matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix =\n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix =
[[1 2 3]
[4 5 6]
[7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
Gradients can be calculated using trax.fastmath.grad.
[5]:
def f(x):
  return 2.0 * x * x
grad_f = trax.fastmath.grad(f)
print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
print(f'grad(2x^2) at 2 = {grad_f(2.0)}')
grad(2x^2) at 1 = 4.0
grad(2x^2) at 2 = 8.0
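As a sanity check on these values: the derivative of f(x) = 2x^2 is f'(x) = 4x, which gives 4 at x = 1 and 8 at x = 2. The sketch below is plain Python rather than Trax; it estimates the gradient with a central finite difference, a common way to cross-check an autodiff result. The helper name `numeric_grad` and the step size `eps` are our own choices.

```python
def f(x):
  return 2.0 * x * x

def numeric_grad(fn, x, eps=1e-5):
  """Central-difference estimate of d fn / dx at x (hypothetical helper)."""
  return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

for x in (1.0, 2.0):
  analytic = 4.0 * x  # d/dx (2x^2) = 4x
  approx = numeric_grad(f, x)
  print(f'x = {x}: analytic {analytic}, numeric {approx:.6f}')
```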
Layers
Layers are basic building blocks of Trax models. You will learn all about them in the layers intro, but for now, just take a look at the implementation of one core Trax layer, Embedding:
# Excerpt from the Trax source. The imports below are the modules this
# excerpt relies on.
from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import initializers as init
class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""
  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.
    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each id in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer
  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.
    Args:
      x: Tensor of token IDs.
    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')
  def init_weights_and_state(self, input_signature):
    """Randomly initializes this layer's weights."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w
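The forward pass above is a row lookup: `jnp.take(self.weights, x, axis=0, mode='clip')` picks row `i` of the `(vocab_size, d_feature)` weight matrix for each token ID `i`, and `mode='clip'` clamps out-of-range IDs into the valid range instead of raising an error. A plain-Python sketch of that behavior, with a tiny hand-written table standing in for trained weights (the `embed` helper is hypothetical):

```python
def embed(table, token_ids):
  """Looks up one row of `table` per token ID, clipping IDs into range."""
  vocab_size = len(table)
  rows = []
  for token_id in token_ids:
    clipped = min(max(token_id, 0), vocab_size - 1)  # Emulates mode='clip'.
    rows.append(table[clipped])
  return rows

# A toy 3-token vocabulary with 2-dimensional vectors.
table = [[0.0, 0.1],
         [1.0, 1.1],
         [2.0, 2.1]]
print(embed(table, [2, 0, 7]))  # ID 7 is out of range and clips to 2.
```

The output has one row per input ID, which is why 15 token IDs become a (15, 32) tensor in the next cell.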
Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.
[6]:
from trax import layers as tl
# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')
# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))
# Run the layer: y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
shape of y = (15, 32)
Models
Models in Trax are built from layers, most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.
[7]:
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
)
# You can print model structure.
print(model)
Serial[
Embedding_8192_256
Mean
Dense_2
]
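The shapes flowing through this model can be tracked by hand: token IDs of shape (batch, length) become vectors of shape (batch, length, 256) after Embedding, (batch, 256) after Mean(axis=1), and (batch, 2) after Dense(2). Only Embedding and Dense have weights. Assuming Dense stores a (256, 2) kernel plus a length-2 bias, the count works out as below, and it matches the 2097666 trainable weights that the training loop reports further down.

```python
vocab_size, d_feature, n_classes = 8192, 256, 2

embedding_params = vocab_size * d_feature         # One vector per token ID.
dense_params = d_feature * n_classes + n_classes  # Kernel plus bias.
mean_params = 0                                   # Mean has no weights.

total = embedding_params + mean_params + dense_params
print(total)  # 2097666
```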
Data
To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax allows you to use TensorFlow Datasets easily, and you can also get an iterator from your own text file using the standard open('my_file.txt').
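Because a data stream is just a Python iterator, a generator function is enough to turn your own file into one. The sketch below assumes a hypothetical tab-separated format with one `label<TAB>text` pair per line; it reads from an in-memory `io.StringIO` so it runs without a real file, but a file object from `open('my_file.txt')` iterates over lines the same way.

```python
import io

def text_label_stream(lines):
  """Yields (inputs, targets) tuples from 'label<TAB>text' lines.

  The tab-separated layout is an assumption made for this example.
  """
  for line in lines:
    label, text = line.rstrip('\n').split('\t', 1)
    yield text, int(label)

fake_file = io.StringIO('1\tA delightful film.\n0\tA tedious mess.\n')
stream = text_label_stream(fake_file)
print(next(stream))  # ('A delightful film.', 1)
print(next(stream))  # ('A tedious mess.', 0)
```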
[8]:
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream)) # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudolove affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)
Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using trax.data.Serial; they are functions that you apply to streams to create processed streams.
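Conceptually, each pipeline step is a function from a stream to a new stream, and a pipeline chains such functions. The plain-Python sketch below illustrates that idea with two hypothetical steps (a whitespace tokenizer and a length filter) and a small `compose` helper; it is an illustration of the concept, not how trax.data implements Serial internally.

```python
def tokenize_step(stream):
  """Hypothetical step: splits the text of each (text, label) pair."""
  for text, label in stream:
    yield text.split(), label

def filter_by_length_step(max_length):
  """Hypothetical step factory: drops examples longer than max_length."""
  def step(stream):
    for tokens, label in stream:
      if len(tokens) <= max_length:
        yield tokens, label
  return step

def compose(*steps):
  """Returns a function that applies the stream steps in order."""
  def pipeline(stream):
    for step in steps:
      stream = step(stream)
    return stream
  return pipeline

pipeline = compose(tokenize_step, filter_by_length_step(max_length=3))
raw = iter([('a fine movie', 1), ('far too long to keep here', 0)])
print(list(pipeline(raw)))  # [(['a', 'fine', 'movie'], 1)]
```

Because every step is lazy, nothing is read from the source stream until the pipeline's output is consumed.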
[9]:
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}') # Check the shapes.
shapes = [(8, 2048), (8,), (8,)]
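The (8, 2048) shape comes from BucketByLength: examples are grouped by length, each group has its own batch size, and each batch is padded to a common length. Below is a simplified sketch of that rule, assuming an example goes to the first bucket whose boundary is at least its length (with one extra, unbounded bucket at the end, hence five batch sizes for four boundaries) and assuming padding up to the bucket's boundary, which is consistent with the shape printed above. The helper names and `pad_id` are hypothetical, and Trax's own implementation may differ in details.

```python
def bucket_index(length, boundaries):
  """Index of the first bucket whose boundary is >= length; the last
  bucket (index len(boundaries)) catches everything longer."""
  for i, boundary in enumerate(boundaries):
    if length <= boundary:
      return i
  return len(boundaries)

def bucket_by_length(sequences, boundaries, batch_sizes, pad_id=0):
  """Yields padded batches of sequences that fall into the same bucket.

  Partially filled buckets left at the end of the stream are dropped here.
  """
  buckets = [[] for _ in batch_sizes]
  for seq in sequences:
    i = bucket_index(len(seq), boundaries)
    buckets[i].append(seq)
    if len(buckets[i]) == batch_sizes[i]:
      # Pad to the bucket boundary (or to the longest sequence in the last,
      # unbounded bucket).
      if i < len(boundaries):
        pad_to = boundaries[i]
      else:
        pad_to = max(len(s) for s in buckets[i])
      yield [s + [pad_id] * (pad_to - len(s)) for s in buckets[i]]
      buckets[i] = []

boundaries = [32, 128, 512, 2048]
batch_sizes = [512, 128, 32, 8, 1]
print(bucket_index(20, boundaries), bucket_index(600, boundaries))  # 0 3

# Eight sequences of length 600 fill the batch-size-8 bucket.
batch = next(bucket_by_length([[1] * 600] * 8, boundaries, batch_sizes))
print(len(batch), len(batch[0]))  # 8 2048
```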
Supervised training
When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.
[10]:
from trax.supervised import training
# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)
# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)
# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
# Run 2000 steps (batches).
training_loop.run(2000)
Step 1: Total number of trainable weights: 2097666
Step 1: Ran 1 train steps in 1.15 secs
Step 1: train WeightedCategoryCrossEntropy | 0.69192106
Step 1: eval WeightedCategoryCrossEntropy | 0.69349981
Step 1: eval WeightedCategoryAccuracy | 0.50312500
Step 500: Ran 499 train steps in 10.62 secs
Step 500: train WeightedCategoryCrossEntropy | 0.50712883
Step 500: eval WeightedCategoryCrossEntropy | 0.42969493
Step 500: eval WeightedCategoryAccuracy | 0.81406250
Step 1000: Ran 500 train steps in 8.89 secs
Step 1000: train WeightedCategoryCrossEntropy | 0.35916388
Step 1000: eval WeightedCategoryCrossEntropy | 0.41775789
Step 1000: eval WeightedCategoryAccuracy | 0.79531250
Step 1500: Ran 500 train steps in 9.13 secs
Step 1500: train WeightedCategoryCrossEntropy | 0.35241464
Step 1500: eval WeightedCategoryCrossEntropy | 0.35194683
Step 1500: eval WeightedCategoryAccuracy | 0.85117188
Step 2000: Ran 500 train steps in 8.54 secs
Step 2000: train WeightedCategoryCrossEntropy | 0.29129386
Step 2000: eval WeightedCategoryCrossEntropy | 0.37591279
Step 2000: eval WeightedCategoryAccuracy | 0.84062500
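The step-1 loss of about 0.69 is what an untrained two-class model should give: if it assigns roughly equal probability to both classes, the log-probability of the correct class is ln(1/2), so the cross-entropy is ln 2 ≈ 0.693. The plain-Python sketch below computes a weighted mean cross-entropy from raw scores, assuming the loss applies a log-softmax to the model outputs and averages per-example losses using the weights produced by AddLossWeights (so zero-weight examples are ignored). It illustrates the formula and is not Trax's implementation.

```python
import math

def log_softmax(scores):
  """Numerically stable log-softmax over a list of raw scores (logits)."""
  m = max(scores)
  log_z = m + math.log(sum(math.exp(s - m) for s in scores))
  return [s - log_z for s in scores]

def weighted_cross_entropy(batch_scores, targets, weights):
  """Weighted mean of -log p(target) over a batch."""
  total = sum(-w * log_softmax(scores)[t]
              for scores, t, w in zip(batch_scores, targets, weights))
  return total / sum(weights)

# Equal scores for both classes: the loss is ln 2, whatever the targets.
print(weighted_cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, 1], [1.0, 1.0]))
```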
After training the model, run it like any layer to get results.
[11]:
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :]) # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: There are a few aspects to Park's movies, and in particular Wallace & Gromit, that I would say make them so great. The first is subtlety and observation, the flagship of which is the character of Gromit. He doesn't speak, he doesn't make any noise, all he has are his eyes, brow, and body posture, and with these he commands the film. Park manages to give us everything we need from this silent character through his expression. The comedy and the emotion is conveyed through the subtlest of movements and it works superbly well.<br /><br />Watching the movie you have to be aware of the entire screen. Normally you'll be guided to things in the movies, the screen won't be cluttered too much, there won't be many things to take your eyes away from the main clue or action. Park seems to need to look the other way with his movies. He throws extra content at his audience, there's action in the background, to the side of the screen, even off screen, and there's just about always something in the foreground to catch your eye. His movies are about multiple viewing and discovery, they're layered with jokes and ancillary action.<br /><br />Throughout this film there are layers of things happening on screen, jokes in the foreground maybe on a jar label and background shadows that give away action. You can imagine that for Park the movies has always been an event, and the movies he loves are ones which he wants to watch again and again. This is what shows in his movies, and in through his most beloved characters.<br /><br />Then there are the bizarre and wacky inventions which Wallace make, something which is reflected in the storyline and the twists and turns of the plot, everything is bizarre and off the wall, yet it seems so perfectly normal in this world. You can imagine that inside Park is the mind of Wallace.<br /><br />There's also one more thing that make these movies so unique, and that's the modelling and precise hand animation. 
I must admit I was concerned when I knew Dreamworks was involved in the making of this movie, and I thought that they would bring their computer animation experience to the forefront. What I was scared of was Wallace & Gromit becoming CGI entities, or at the smallest, CGI being used to clean up the feel that the modelling brought to the movie.<br /><br />Not so. You can still see thumbprints and toolmarks on the characters, and far from distracting from the movie, this just adds so much real feeling to it and a feeling of physical depth to the characters and the scene on screen.<br /><br />So what of the movie? Well I must say that the plot twist was something I had thought about well before the film was in the cinema and it came as no surprise, but that did not affect my enjoyment one little bit. Actually watching the twist unfold and the comic timing of the discovery and reactions was everything, and it had me just as sucked in as if it was a thriller, yet all the time I was laughing.<br /><br />Watching the movie was fascinating in various ways. To see the animation completed, how wild the inventions are, how Wallace is going to get into trouble and Gromit get him out, where all the cross references are in the movie, and where all the jokes are! I must admit afterwards talking with my friends I couldn't believe how much I had missed.<br /><br />There's something different in this movie than with the others, there's a new level of adult humour in here, and I don't mean rude jokes (although there are a couple that are just so British you can't help laughing), I mean jokes that simply fly over kids heads but slap adults in the face. The kind you are used to seeing come out of somewhere like Pixar. This just adds even more appeal to the movie.<br /><br />Okay though, let me try and be a bit negative here. I didn't notice the voices in this movie, you know how you usually listen to the actors and see if you can recognise them? 
Well I was just too wrapped up in the movie to care or to notice who they were...okay, that's not negative. Let me try again. The main plot wasn't as strong and gripping as I'd expected, and I found myself being caught up in the side stories and the characters themselves...again...that's not a bad thing, the film was just so much rich entertainment.<br /><br />I honestly can't think of a bad thing to say about this movie, probably the worst thing I could say is that the title sequence at the end is quite repetitive...until the final title! Really, that's the worst I can say.<br /><br />The story is a lot of fun, well setup, well written, well executed. There's lot's of fantastic characters in here, not just Wallace & Gromit. There's so much happening on screen, so many references and jokes (check out the dresses of Lady Tottingham), cheese jokes everywhere, jokes for all the family. The characters are superbly absorbing and you'll find that you've taken to them before you realise. There's just so much in this movie for everyone.<br /><br />There's so much I could say and write about, but I know it will quickly turn into a backslapping exercise for Park and Aardman, it would also just turn into a series of "this bit was really funny" and "there's a bit when...", and what I would rather do is tell you that this is a superb movie, to go see it, and to experience the whole thing for yourselves. 
I will say though that the bunnies are excellent!<pad><pad><pad>...<pad><pad><pad>
Model returned sentiment probabilities: [[0.36765265 2.7904649 ]]
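Note that the two printed values do not sum to 1. The model printed earlier ends with Dense_2 and has no final LogSoftmax, so its outputs appear to be unnormalized scores (logits) rather than log-probabilities, and np.exp of them gives unnormalized values. Under that reading, dividing the values by their sum (equivalently, applying a softmax to the raw outputs) yields probabilities; for the output above that is roughly 0.116 for class 0 and 0.884 for class 1. A plain-Python sketch, with a hypothetical `normalize` helper:

```python
def normalize(values):
  """Rescales non-negative values so they sum to 1."""
  total = sum(values)
  return [v / total for v in values]

exp_scores = [0.36765265, 2.7904649]  # Values printed by the cell above.
probs = normalize(exp_scores)
print([round(p, 3) for p in probs])  # [0.116, 0.884]
```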
Trax Layers Intro
This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:
 Layers: the basic building blocks and how to combine them
 Inputs and Outputs: how data streams flow through layers
 Defining New Layer Classes (if combining existing layers isn’t enough)
 Testing and Debugging Layer Classes
General Setup
Execute the following few cells (once) before running any of the code samples in this notebook.
[ ]:
# Copyright 2018 Google LLC.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
[ ]:
# Import Trax
!pip install -q -U trax
!pip install -q tensorflow
from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature
[ ]:
# Settings and utilities for handling inputs, outputs, and object properties.
np.set_printoptions(precision=3) # Reduce visual noise from extra digits.
def show_layer_properties(layer_obj, layer_name):
  template = ('{}.n_in: {}\n'
              '{}.n_out: {}\n'
              '{}.sublayers: {}\n'
              '{}.weights: {}\n')
  print(template.format(layer_name, layer_obj.n_in,
                        layer_name, layer_obj.n_out,
                        layer_name, layer_obj.sublayers,
                        layer_name, layer_obj.weights))
1. Layers
The Layer class represents Trax’s basic building blocks:
class Layer:
  """Base class for composable layers in a deep learning network.
  Layers are the basic building blocks for deep learning models. A Trax layer
  computes a function from zero or more inputs to zero or more outputs,
  optionally using trainable weights (common) and non-parameter state (not
  common). ...
  ...
Layers compute functions.
A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.
The simplest layers, those with no weights or sublayers, can be used without initialization. You can think of them as (pure) mathematical functions that can be plugged into neural networks.
For ease of testing and interactive exploration, layer objects implement the __call__ method, so you can call them directly on input data:
y = my_layer(x)
Layers are also objects, so you can inspect their properties. For example:
print(f'Number of inputs expected by this layer: {my_layer.n_in}')
Example 1. tl.Relu \([n_{in} = 1, n_{out} = 1]\)
[ ]:
relu = tl.Relu()
x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
y = relu(x)
# Show input, output, and two layer properties.
print(f'x:\n{x}\n\n'
      f'relu(x):\n{y}\n\n'
      f'Number of inputs expected by this layer: {relu.n_in}\n'
      f'Number of outputs promised by this layer: {relu.n_out}')
x:
[[ -2 -1 0 1 2]
[-20 -10 0 10 20]]
relu(x):
[[ 0 0 0 1 2]
[ 0 0 0 10 20]]
Number of inputs expected by this layer: 1
Number of outputs promised by this layer: 1
Example 2. tl.Concatenate \([n_{in} = 2, n_{out} = 1]\)
[ ]:
concat = tl.Concatenate()
x0 = np.array([[1, 2, 3],
[4, 5, 6]])
x1 = np.array([[10, 20, 30],
[40, 50, 60]])
y = concat([x0, x1])
print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
      f'concat([x0, x1]):\n{y}\n\n'
      f'Number of inputs expected by this layer: {concat.n_in}\n'
      f'Number of outputs promised by this layer: {concat.n_out}')
x0:
[[1 2 3]
[4 5 6]]
x1:
[[10 20 30]
[40 50 60]]
concat([x0, x1]):
[[ 1 2 3 10 20 30]
[ 4 5 6 40 50 60]]
Number of inputs expected by this layer: 2
Number of outputs promised by this layer: 1
Layers are configurable.
Many layer types have creation-time parameters for flexibility. The Concatenate layer type, for instance, has two optional parameters:
 axis: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.
 n_items: number of tensors to join into one by concatenation; default value is 2.
The following example shows Concatenate configured for 3 input tensors, and concatenation along the initial \((0^{th})\) axis.
Example 3. tl.Concatenate(n_items=3, axis=0)
[ ]:
concat3 = tl.Concatenate(n_items=3, axis=0)
x0 = np.array([[1, 2, 3],
[4, 5, 6]])
x1 = np.array([[10, 20, 30],
[40, 50, 60]])
x2 = np.array([[100, 200, 300],
[400, 500, 600]])
y = concat3([x0, x1, x2])
print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'x2:\n{x2}\n\n'
f'concat3([x0, x1, x2]):\n{y}')
x0:
[[1 2 3]
[4 5 6]]
x1:
[[10 20 30]
[40 50 60]]
x2:
[[100 200 300]
[400 500 600]]
concat3([x0, x1, x2]):
[[ 1 2 3]
[ 4 5 6]
[ 10 20 30]
[ 40 50 60]
[100 200 300]
[400 500 600]]
Layers are trainable.
Many layer types include weights that affect the computation of outputs from inputs, and they use backpropagated gradients to update those weights.
🚧🚧 A very small subset of layer types, such as ``BatchNorm``, also includes modifiable weights (called ``state``) that are updated based on forward-pass inputs/computation rather than backpropagated gradients.
Initialization
Trainable layers must be initialized before use. Trax can take care of this as part of the overall training process. In other settings (e.g., in tests or interactively in a Colab notebook), you need to initialize the outermost/topmost layer explicitly. For this, use init:
def init(self, input_signature, rng=None, use_cache=False):
  """Initializes weights/state of this layer and its sublayers recursively.
  Initialization creates layer weights and state, for layers that use them.
  It derives the necessary array shapes and data types from the layer's input
  signature, which is itself just shape and data type information.
  For layers without weights or state, this method safely does nothing.
  This method is designed to create weights/state only once for each layer
  instance, even if the same layer instance occurs in multiple places in the
  network. This enables weight sharing to be implemented as layer sharing.
  Args:
    input_signature: `ShapeDtype` instance (if this layer takes one input)
        or list/tuple of `ShapeDtype` instances.
    rng: Single-use random number generator (JAX PRNG key), or `None`;
        if `None`, use a default computed from an integer 0 seed.
    use_cache: If `True`, and if this layer instance has already been
        initialized elsewhere in the network, then return special marker
        values: tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
        Else return this layer's newly initialized weights and state.
  Returns:
    A `(weights, state)` tuple.
  """
Input signatures can be built from scratch using ShapeDtype objects, or can be derived from data via the signature function (in module shapes):
def signature(obj):
  """Returns a `ShapeDtype` signature for the given `obj`.
  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
  instances. Note that this function is permissive with respect to its inputs
  (accepts lists or tuples or dicts, and underlying objects can be any type
  as long as they have shape and dtype attributes) and returns the corresponding
  nested structure of `ShapeDtype`.
  Args:
    obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
        of such objects.
  Returns:
    A corresponding nested structure of `ShapeDtype` instances.
  """
Example 4. tl.LayerNorm \([n_{in} = 1, n_{out} = 1]\)
[ ]:
layer_norm = tl.LayerNorm()
x = np.array([[-2, -1, 0, 1, 2],
              [1, 2, 3, 4, 5],
              [10, 20, 30, 40, 50]]).astype(np.float32)
layer_norm.init(shapes.signature(x))
y = layer_norm(x)
print(f'x:\n{x}\n\n'
      f'layer_norm(x):\n{y}\n')
print(f'layer_norm.weights:\n{layer_norm.weights}')
x:
[[-2. -1. 0. 1. 2.]
[ 1. 2. 3. 4. 5.]
[10. 20. 30. 40. 50.]]
layer_norm(x):
[[-1.414 -0.707 0. 0.707 1.414]
[-1.414 -0.707 0. 0.707 1.414]
[-1.414 -0.707 0. 0.707 1.414]]
layer_norm.weights:
(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
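With its initial weights (a scale of ones and a bias of zeros, as printed above), LayerNorm subtracts each row's mean and divides by its standard deviation, which is why all three rows map to the same output. A plain-Python check, assuming the usual formula scale * (x - mean) / sqrt(variance + epsilon) + bias with a small epsilon (1e-6 in this sketch, our choice); `layer_norm_row` is a hypothetical helper.

```python
import math

def layer_norm_row(row, scale=None, bias=None, epsilon=1e-6):
  """Normalizes one row to zero mean and unit variance, then scales and shifts."""
  n = len(row)
  mean = sum(row) / n
  variance = sum((v - mean) ** 2 for v in row) / n
  scale = scale or [1.0] * n
  bias = bias or [0.0] * n
  return [s * (v - mean) / math.sqrt(variance + epsilon) + b
          for v, s, b in zip(row, scale, bias)]

for row in ([-2, -1, 0, 1, 2], [1, 2, 3, 4, 5], [10, 20, 30, 40, 50]):
  print([round(v, 3) for v in layer_norm_row(row)])
  # Each row prints [-1.414, -0.707, 0.0, 0.707, 1.414]
```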
Layers combine into layers.
The Trax library authors encourage users to build networks and network components as combinations of existing layers, by means of a small set of combinator layers. A combinator makes a list of layers behave as a single layer: it combines the sublayer computations yet looks from the outside like any other layer. The combined layer, like other layers, can:
 compute outputs from inputs,
 update parameters from gradients, and
 combine with yet more layers.
Combine with ``Serial``
The most common way to combine layers is with the Serial combinator:
class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).
  This combinator is commonly used to construct deep networks, e.g., like this::
      mlp = tl.Serial(
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(10),
      )
  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:
    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and
    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.
  A Serial instance with no sublayers acts as a special-case (but useful)
  1-input 1-output no-op.
  """
If one layer has the same number of outputs as the next layer has inputs (which is the usual case), the successive layers behave like function composition:
# h(.) = g(f(.))
layer_h = Serial(
layer_f,
layer_g,
)
Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed.
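The stack discipline described in the Serial docstring can be mimicked in plain Python. In this sketch, a hypothetical `ToyLayer` records a function together with its n_in and n_out, and `run_serial` treats index 0 of a list as the top of the stack: for each sublayer it takes n_in items off the top, calls the function, and pushes the outputs back. Trax's own Serial also manages weights, state, and random keys, which the sketch leaves out.

```python
class ToyLayer:
  """Hypothetical layer: a function plus its input/output counts."""
  def __init__(self, fn, n_in=1, n_out=1):
    self.fn, self.n_in, self.n_out = fn, n_in, n_out

def run_serial(layers, stack):
  """Applies layers in order, using stack[0] as the top of the data stack."""
  stack = list(stack)
  for layer in layers:
    args, rest = stack[:layer.n_in], stack[layer.n_in:]
    outputs = layer.fn(*args)
    if layer.n_out == 1:
      outputs = (outputs,)
    stack = list(outputs) + rest
  return stack

double = ToyLayer(lambda x: 2 * x)
add = ToyLayer(lambda x, y: x + y, n_in=2)
dup = ToyLayer(lambda x: (x, x), n_out=2)

# [3] -> dup -> [3, 3] -> double (top only) -> [6, 3] -> add -> [9]
print(run_serial([dup, double, add], [3]))  # [9]
```

Items a sublayer does not consume stay on the stack untouched, so `run_serial([double], [1, 10])` returns `[2, 10]`, and an empty layer list returns its input unchanged, matching the no-op case in the docstring.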
Example 5. y = layer_norm(relu(x)) \([n_{in} = 1, n_{out} = 1]\)
[ ]:
layer_block = tl.Serial(
tl.Relu(),
tl.LayerNorm(),
)
x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)
print(f'x:\n{x}\n\n'
      f'layer_block(x):\n{y}')
x:
[[ -2. -1. 0. 1. 2.]
[-20. -10. 0. 10. 20.]]
layer_block(x):
[[-0.75 -0.75 -0.75 0.5 1.75]
[-0.75 -0.75 -0.75 0.5 1.75]]
And we can inspect the block as a whole, as if it were just another layer:
Example 5’. Inspecting a Serial layer.
[ ]:
print(f'layer_block: {layer_block}\n\n'
f'layer_block.weights: {layer_block.weights}')
layer_block: Serial[
Relu
LayerNorm
]
layer_block.weights: ((), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))
Combine with ``Branch``
The Branch combinator arranges layers into parallel computational channels:
def Branch(*layers, name='Branch'):
  """Combinator that applies a list of layers in parallel to copies of inputs.
  Each layer in the input list is applied to as many inputs from the stack
  as it needs, and their outputs are successively combined on stack.
  For example, suppose one has three layers:
    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)
  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:
    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)
  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)
  Args:
    *layers: List of layers.
    name: Descriptive name for this layer.
  Returns:
    A branch layer built from the given sublayers.
  """
Residual blocks, for example, are implemented using Branch:
def Residual(*layers, shortcut=None):
  """Wraps a series of layers with a residual connection.
  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (element-wise) adds its output to the output from the main
        layer series.
  Returns:
    A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else Serial(layers)
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
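Stripped of the layer machinery, the Residual block above computes x + layer(x), or shortcut(x) + layer(x) when a shortcut is given. The NumPy sketch below makes that explicit; plain Python functions stand in for Trax layers, and the residual helper is hypothetical, for illustration only.

```python
import numpy as np

def residual(layer, shortcut=None):
    """Sketch of Residual: Branch(shortcut, layer) followed by Add()."""
    def apply(x):
        # Branch: one copy goes through the shortcut (identity if None),
        # the other copy goes through the layer series.
        left = x if shortcut is None else shortcut(x)
        right = layer(x)
        # Add: merge the two streams with an element-wise sum.
        return left + right
    return apply

def relu(x):
    return np.maximum(x, 0.0)

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(residual(relu)(x))                              # values -2, -1, 0, 2, 4
print(residual(relu, shortcut=lambda t: 2 * t)(x))    # values -4, -2, 0, 3, 6
```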
Here’s a simple code example to highlight the mechanics.
Example 6. Branch
[ ]:
relu = tl.Relu()
times_100 = tl.Fn("Times100", lambda x: x * 100.0)
branch_relu_t100 = tl.Branch(relu, times_100)
x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
branch_relu_t100.init(shapes.signature(x))
y0, y1 = branch_relu_t100(x)
print(f'x:\n{x}\n\n'
f'y0:\n{y0}\n\n'
f'y1:\n{y1}')
x:
[[ -2 -1 0 1 2]
[-20 -10 0 10 20]]
y0:
[[ 0 0 0 1 2]
[ 0 0 0 10 20]]
y1:
[[ -200. -100. 0. 100. 200.]
[-2000. -1000. 0. 1000. 2000.]]
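The input/output bookkeeping described in the Branch docstring can be mimicked in plain Python. In this sketch (not Trax's implementation), each hypothetical layer is a (function, n_in, n_out) tuple, None acts as a one-input no-op, every layer reads from the front of the same inputs, and all outputs are concatenated in order. The layers F, G and H reproduce the docstring's example.

```python
def branch(*layers):
    """Returns a function applying each layer to copies of the inputs (sketch)."""
    # Treat None as a one-input, one-output no-op, as the docstring describes.
    specs = [(lambda a: a, 1, 1) if layer is None else layer for layer in layers]
    # In this sketch, Branch needs as many inputs as its hungriest layer.
    n_in = max(n for _, n, _ in specs)

    def apply(*inputs):
        assert len(inputs) == n_in, f'expected {n_in} inputs, got {len(inputs)}'
        outputs = []
        for fn, k, n_out in specs:
            result = fn(*inputs[:k])  # Each layer sees only the first k inputs.
            # A multi-output layer returns a tuple; splice it into the outputs.
            outputs.extend(result if n_out > 1 else [result])
        return tuple(outputs)

    return apply

# Hypothetical layers matching the docstring: F (1 -> 1), G (3 -> 1), H (2 -> 2).
F = (lambda a: ('F', a), 1, 1)
G = (lambda a, b, c: ('G', a, b, c), 3, 1)
H = (lambda a, b: (('h1', a, b), ('h2', a, b)), 2, 2)

outs = branch(F, G, H)('a', 'b', 'c')
print(len(outs))  # 4
print(outs)
```

Example 6 above is the two-layer case of the same pattern: Relu and Times100 each take one input, so the branch takes one input and returns two outputs.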
2. Inputs and Outputs
Trax allows layers to have multiple input streams and output streams. When designing a network, you have the flexibility to use layers that:
- process a single data stream (\(n_{in} = n_{out} = 1\)),
- process multiple parallel data streams (\(n_{in} = n_{out} = 2, 3, \ldots\)),
- split or inject data streams (\(n_{in} < n_{out}\)), or
- merge or remove data streams (\(n_{in} > n_{out}\)).
We saw in section 1 the example of Residual, which involves both a split and a merge:
  ...
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
In other words, layer by layer:
- Branch(shortcut, layers): makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [\(n_{in} = 1\), \(n_{out} = 2\)]
- Add(): combines the two streams back into one by adding two tensors element-wise. [\(n_{in} = 2\), \(n_{out} = 1\)]
Data Stack
Trax supports flexible data flows through a network via a data stack, which is managed by the Serial combinator:
class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).
  ...
  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:
    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and
    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.
  ...
  """
Simple Case 1 – Each layer takes one input and has one output.
This is in effect a single data stream pipeline, and the successive layers behave like function composition:
# s(.) = h(g(f(.)))
layer_s = Serial(
    layer_f,
    layer_g,
    layer_h,
)
Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed and the order of operations matches the textual order of layers.
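The stack discipline from the Serial docstring fits in a few lines of plain Python. In this sketch (not Trax's actual code), a hypothetical layer is a (function, n_in, n_out) tuple, and the top of the stack is the front of a Python list. With one-in, one-out layers it reduces to plain function composition; with other layers it splits and merges streams.

```python
def serial(*layers):
    """Sketch of Serial's data-stack semantics for (fn, n_in, n_out) layers."""
    def apply(*inputs):
        stack = list(inputs)  # stack[0] is the top of the stack.
        for fn, n_in, n_out in layers:
            # Take n_in items off the top and call the layer on them.
            args, stack = stack[:n_in], stack[n_in:]
            result = fn(*args)
            outputs = list(result) if n_out > 1 else [result]
            # Push the n_out results back on top, in order.
            stack = outputs + stack
        return stack[0] if len(stack) == 1 else tuple(stack)
    return apply

# Hypothetical layers: duplicate (1 -> 2), add (2 -> 1), scale (1 -> 1).
dup = (lambda x: (x, x), 1, 2)
add = (lambda a, b: a + b, 2, 1)
times_10 = (lambda x: 10 * x, 1, 1)

print(serial(dup, add, times_10)(3))  # (3 + 3) * 10 = 60
print(serial(times_10)(1, 2))         # the untouched item stays: (10, 2)
```

The second call shows the general case: a layer that needs only one input leaves the rest of the stack untouched for later layers.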
Simple Case 2 – Each layer consumes all outputs of the preceding layer.
This is still a single pipeline, but data streams internal to it can split and merge. The Residual example above illustrates this kind.
General Case – Successive layers interact via the data stack.
As described in the Serial class docstring, each layer gets its inputs from the data stack after the preceding layer has put its outputs onto the stack. This covers the simple cases above, but also allows for more flexible data interactions between non-adjacent layers. The following example is schematic:
x, y_target = get_batch_of_labeled_data()
model_plus_eval = Serial(
    my_fancy_deep_model(),  # Takes one arg (x) and has one output (y_hat)
    my_eval(),  # Takes two args (y_hat, y_target) and has one output (score)
)
eval_score = model_plus_eval((x, y_target))
Here is the corresponding progression of stack states:
0. At start: (empty)
1. After get_batch_of_labeled_data(): x, y_target
2. After my_fancy_deep_model(): y_hat, y_target
3. After my_eval(): score
Note in particular how the application of the model (between stack states 1 and 2) only uses and affects the top element on the stack: x -> y_hat. The rest of the data stack (y_target) comes into use only later, for the eval function.
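The progression of stack states can be reproduced by hand with a Python list. Here my_model and my_eval are hypothetical stand-ins (a doubling "model" and an absolute-error "eval"), chosen only to make the trace concrete.

```python
def my_model(x):
    # Hypothetical model: one input (x), one output (y_hat).
    return 2 * x

def my_eval(y_hat, y_target):
    # Hypothetical eval: two inputs (y_hat, y_target), one output (score).
    return abs(y_hat - y_target)

x, y_target = 3, 5
states = []

stack = [x, y_target]                            # After get_batch_of_labeled_data().
states.append(list(stack))

stack = [my_model(stack[0])] + stack[1:]         # The model replaces only the top item.
states.append(list(stack))

stack = [my_eval(stack[0], stack[1])] + stack[2:]  # The eval consumes both items.
states.append(list(stack))

for s in states:
    print(s)  # [3, 5], then [6, 5], then [1]
```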
3. Defining New Layer Classes
If you need a layer type that is not easily defined as a combination of existing layer types, you can define your own layer classes in a couple of different ways.
With the Fn layer-creating function.
Many layer types needed in deep learning compute pure functions from inputs to outputs, using neither weights nor randomness. You can use Trax’s Fn function to define your own pure layer types:
def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.
  `f` can take and return any number of arguments, and takes only positional
  arguments (no default or keyword arguments). It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs, element-wise sums and maxima:
      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
  The layer's number of inputs (`n_in`) is automatically set to the number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.
  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) -> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.
  Returns:
    Layer executing the function `f`.
  """
Example 7. Use Fn to define a new layer type:
[ ]:
# Define new layer type.
def Gcd():
  """Returns a layer to compute the greatest common divisor, element-wise."""
  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))
# Use it.
gcd = Gcd()
x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
y = gcd((x0, x1))
print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'gcd((x0, x1)):\n{y}')
x0:
[ 1 2 3 4 5 6 7 8 9 10]
x1:
[11 12 13 14 15 16 17 18 19 20]
gcd((x0, x1)):
[ 1 2 1 2 5 2 1 2 1 10]
The Fn function infers n_in (number of inputs) as the length of f’s arg list. Fn does not infer n_out (number of outputs) though. If your f has more than one output, you need to give an explicit value using the n_out keyword arg.
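One way to infer n_in is to count f's positional parameters with Python's standard inspect module, as sketched below. Trax's own Fn may do this differently, and the helper name count_positional_args is hypothetical.

```python
import inspect
import math

def count_positional_args(f):
    """Counts f's positional parameters (hypothetical helper, for illustration)."""
    params = inspect.signature(f).parameters.values()
    positional = (inspect.Parameter.POSITIONAL_ONLY,
                  inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return sum(1 for p in params if p.kind in positional)

def gcd_fn(x0, x1):
    # Two inputs, one output.
    return math.gcd(x0, x1)

def sum_and_max(x0, x1):
    # Two inputs, two outputs (returned as a tuple).
    return x0 + x1, max(x0, x1)

print(count_positional_args(gcd_fn))       # 2
print(count_positional_args(sum_and_max))  # 2
```

Both functions report n_in = 2, yet one returns a single value and the other a pair. The signature says nothing about how many values f returns, which is why n_out has to be given explicitly.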
Example 8. Fn with multiple outputs:
[ ]:
# Define new layer type.
def SumAndMax():
  """Returns a layer to compute sums and maxima of two input tensors."""
  return tl.Fn('SumAndMax',
               lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
               n_out=2)
# Use it.
sum_and_max = SumAndMax()
x0 = np.array([1, 2, 3, 4, 5])
x1 = np.array([10, -20, 30, -40, 50])
y0, y1 = sum_and_max([x0, x1])
print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'y0:\n{y0}\n\n'
f'y1:\n{y1}')
x0:
[1 2 3 4 5]
x1:
[ 10 -20 30 -40 50]
y0:
[ 11 -18 33 -36 55]
y1:
[10 2 30 4 50]
Example 9. Use Fn to define a configurable layer:
[ ]:
# Function defined in trax/layers/core.py:
def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.
  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.
  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return tl.Fn(layer_name, f)
flatten_keep_1_axis = Flatten(n_axes_to_keep=1)
flatten_keep_2_axes = Flatten(n_axes_to_keep=2)
x = np.array([[[1, 2, 3],
[10, 20, 30],
[100, 200, 300]],
[[4, 5, 6],
[40, 50, 60],
[400, 500, 600]]])
y1 = flatten_keep_1_axis(x)
y2 = flatten_keep_2_axes(x)
print(f'x:\n{x}\n\n'
f'flatten_keep_1_axis(x):\n{y1}\n\n'
f'flatten_keep_2_axes(x):\n{y2}')
x:
[[[ 1 2 3]
[ 10 20 30]
[100 200 300]]
[[ 4 5 6]
[ 40 50 60]
[400 500 600]]]
flatten_keep_1_axis(x):
[[ 1 2 3 10 20 30 100 200 300]
[ 4 5 6 40 50 60 400 500 600]]
flatten_keep_2_axes(x):
[[[ 1 2 3]
[ 10 20 30]
[100 200 300]]
[[ 4 5 6]
[ 40 50 60]
[400 500 600]]]
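The reshape at the heart of Flatten is x.shape[:n_axes_to_keep] + (-1,), where -1 asks the reshape to infer the size of the collapsed axis. The NumPy check below, with NumPy standing in for jnp, reproduces the shape example from the docstring.

```python
import numpy as np

def flatten(x, n_axes_to_keep=1):
    # Keep the leading axes and let -1 absorb all trailing ones.
    if x.ndim <= n_axes_to_keep:
        raise ValueError(f'Input rank ({x.ndim}) must exceed {n_axes_to_keep}.')
    return np.reshape(x, x.shape[:n_axes_to_keep] + (-1,))

x = np.zeros((2, 3, 5, 7, 11))
print(flatten(x, n_axes_to_keep=2).shape)  # (2, 3, 385), since 5 * 7 * 11 = 385
print(flatten(x, n_axes_to_keep=1).shape)  # (2, 1155), since 3 * 385 = 1155
```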
By defining a Layer subclass
If you need a layer type that uses trainable weights (or state), you can extend the base Layer class:
class Layer:
  """Base class for composable layers in a deep learning network.
  ...
  Authors of new layer subclasses typically override at most two methods of
  the base `Layer` class:
    `forward(inputs)`:
      Computes this layer's output as part of a forward pass through the model.
    `init_weights_and_state(self, input_signature)`:
      Initializes weights and state for inputs with the given signature.
  ...
  """
The forward method uses weights stored in the layer object (self.weights) to compute outputs from inputs. For example, here is the definition of forward for Trax’s Dense layer:
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.
  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.
  Returns:
    Tensor of same shape and dtype as the input, except the final dimension
    is the layer's `n_units` value.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
    return jnp.dot(x, w) + b  # Affine map.
  else:
    w = self.weights
    return jnp.dot(x, w)  # Linear map.
Layer weights must be initialized before the layer can be used; the init_weights_and_state method specifies how. Continuing the Dense example, here is the corresponding initialization code:
def init_weights_and_state(self, input_signature):
  """Randomly initializes this layer's weights.
  Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the
  default case), or a `w` tensor for layers created with `use_bias=False`.
  Args:
    input_signature: `ShapeDtype` instance characterizing the input this layer
        should compute on.
  """
  shape_w = (input_signature.shape[-1], self._n_units)
  shape_b = (self._n_units,)
  rng_w, rng_b = fastmath.random.split(self.rng, 2)
  w = self._kernel_initializer(shape_w, rng_w)
  if self._use_bias:
    b = self._bias_initializer(shape_b, rng_b)
    self.weights = (w, b)
  else:
    self.weights = w
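Put together, forward and init_weights_and_state amount to the NumPy sketch below. The class name DenseSketch, the use of NumPy's random generator, the scaled-normal kernel and the zero bias are all assumptions for illustration; Trax's Dense takes its initializers from the configurable _kernel_initializer and _bias_initializer instead.

```python
import numpy as np

class DenseSketch:
    """Minimal dense layer, y = x @ w + b (hypothetical, for illustration)."""

    def __init__(self, n_units, use_bias=True, seed=0):
        self._n_units = n_units
        self._use_bias = use_bias
        self._rng = np.random.default_rng(seed)
        self.weights = None

    def init_weights_and_state(self, input_shape):
        # The kernel maps the input's last axis to n_units.
        shape_w = (input_shape[-1], self._n_units)
        shape_b = (self._n_units,)
        # Assumed initializers: scaled normal kernel, zero bias.
        w = self._rng.normal(size=shape_w) / np.sqrt(shape_w[0])
        if self._use_bias:
            self.weights = (w, np.zeros(shape_b))
        else:
            self.weights = w

    def forward(self, x):
        if self._use_bias:
            w, b = self.weights
            return np.dot(x, w) + b  # Affine map.
        return np.dot(x, self.weights)  # Linear map.

layer = DenseSketch(n_units=4)
x = np.ones((8, 3, 5))  # Batch of 8, 3 positions, 5 features.
layer.init_weights_and_state(x.shape)
print(layer.forward(x).shape)  # (8, 3, 4): only the last axis changes
```

Taking the kernel's input size from the last axis of the input shape is what lets the layer accept inputs of any rank, as in the (8, 3, 5) example.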
By defining a Combinator subclass
TBD
4. Testing and Debugging Layer Classes
TBD
Using Trax with TensorFlow NumPy and Keras
This notebook (run it in colab) shows how you can run Trax directly with TensorFlow NumPy. You will also see how to use Trax layers and models inside Keras so you can use Trax in production, e.g., with TensorFlow.js or TensorFlow Serving.
- Trax with TensorFlow NumPy: use Trax with TensorFlow NumPy without any code changes
- Convert Trax to Keras: how to get a Keras layer for your Trax model and use it
- Exporting Trax Models for Deployment: how to export Trax models to TensorFlow SavedModel
1. Trax with TensorFlow NumPy
In Trax, all computations rely on accelerated math operations happening in the fastmath module. This module can use different backends for acceleration. One of them is TensorFlow NumPy, which uses TensorFlow 2 to accelerate the computations.
The backend can be set using a call to trax.fastmath.set_backend, as you’ll see below. Currently available backends are jax (default), tensorflow-numpy and numpy (for debugging). The tensorflow-numpy backend uses TensorFlow NumPy for executing fastmath functions on TensorFlow, while the jax backend calls JAX, which lowers to TensorFlow XLA.
You may see that the tensorflow-numpy and jax backends show different speed and memory characteristics. You may also see different error messages when debugging, since each backend can expose you to its own internals. However, for the most part, users can choose a backend and not worry about the internal details of these backends.
Let’s train the sentiment analysis model from the Trax intro using TensorFlow NumPy to see how it works.
General Setup
Execute the following few cells (once) before running any of the code samples.
[1]:
#@title
# Copyright 2020 Google LLC.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[2]:
# Install and import Trax
!pip install -q -U git+https://github.com/google/trax@master
import os
import numpy as np
import trax
Here is how you can set the fastmath backend to tensorflow-numpy and verify that it’s been set.
[3]:
# Use the tensorflow-numpy backend.
trax.fastmath.set_backend('tensorflow-numpy')
print(trax.fastmath.backend_name())
tensorflow-numpy
[4]:
# Create data streams.
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
data_pipeline = trax.data.Serial(
trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
trax.data.Shuffle(),
trax.data.FilterByLength(max_length=2048, length_keys=[0]),
trax.data.BucketByLength(boundaries=[ 32, 128, 512, 2048],
batch_sizes=[512, 128, 32, 8, 1],
length_keys=[0]),
trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
# Print example shapes.
example_batch = next(train_batches_stream)
print(f'batch shapes = {[x.shape for x in example_batch]}')
batch shapes = [(8, 2048), (8,), (8,)]
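BucketByLength picks a batch size from each sequence's length, which keeps the number of tokens per batch roughly constant. The plain-Python sketch below assumes that a sequence goes into the first bucket whose boundary is at least its length, with one extra bucket for anything longer (hence one more batch size than boundaries). The helper bucket_batch_size is hypothetical, not part of trax.data.

```python
import bisect

BOUNDARIES = [32, 128, 512, 2048]
BATCH_SIZES = [512, 128, 32, 8, 1]  # One more entry than BOUNDARIES.

def bucket_batch_size(length, boundaries=BOUNDARIES, batch_sizes=BATCH_SIZES):
    """Hypothetical helper: batch size for a sequence of the given length."""
    # bisect_left returns the index of the first boundary >= length;
    # lengths beyond the last boundary get index len(boundaries).
    bucket = bisect.bisect_left(boundaries, length)
    return batch_sizes[bucket]

for n in (10, 100, 500, 1500, 5000):
    print(n, bucket_batch_size(n))  # batch sizes 512, 128, 32, 8, 1
```

Under these assumptions, the example batch of shape (8, 2048) above comes from the fourth bucket, which is consistent with padding each batch up to its bucket's boundary.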
[5]:
# Create the model.
from trax import layers as tl
model = tl.Serial(
tl.Embedding(vocab_size=8192, d_feature=256),
tl.Mean(axis=1), # Average on axis 1 (length of sentence).
tl.Dense(2), # Classify 2 classes.
)
# You can print model structure.
print(model)
Serial[
Embedding_8192_256
Mean
Dense_2
]
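The shapes flowing through this model, and the weight count that the training loop reports below (2097666), can be checked with a NumPy sketch. It assumes Embedding holds a vocab_size by d_feature table, Mean has no weights, and Dense(2) holds a kernel plus a bias; random arrays stand in for trained weights and real token ids.

```python
import numpy as np

vocab_size, d_feature, n_classes = 8192, 256, 2
rng = np.random.default_rng(0)

# Weights: an embedding table, then a dense kernel and bias (random stand-ins).
embedding = rng.normal(size=(vocab_size, d_feature))
w = rng.normal(size=(d_feature, n_classes))
b = np.zeros(n_classes)

tokens = rng.integers(0, vocab_size, size=(8, 2048))  # 8 sequences of 2048 token ids.
h = embedding[tokens]  # Embedding: (8, 2048, 256)
h = h.mean(axis=1)     # Mean over the sentence axis: (8, 256)
logits = h @ w + b     # Dense(2): (8, 2)
print(logits.shape)

n_weights = embedding.size + w.size + b.size
print(n_weights)  # 8192 * 256 + 256 * 2 + 2 = 2097666
```

Nearly all of the weights sit in the embedding table; the classifier head adds only 514.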
[6]:
# Train the model.
from trax.supervised import training
# Training task.
train_task = training.TrainTask(
labeled_data=train_batches_stream,
loss_layer=tl.WeightedCategoryCrossEntropy(),
optimizer=trax.optimizers.Adam(0.01),
n_steps_per_checkpoint=500,
)
# Evaluation task.
eval_task = training.EvalTask(
labeled_data=eval_batches_stream,
metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
n_eval_batches=20 # For less variance in eval numbers.
)
# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
training_loop = training.Loop(model,
train_task,
eval_tasks=[eval_task],
output_dir=output_dir)
# Run 2000 steps (batches).
training_loop.run(2000)
Step 1: Total number of trainable weights: 2097666
Step 1: Ran 1 train steps in 1.01 secs
Step 1: train WeightedCategoryCrossEntropy | 0.69292086
Step 1: eval WeightedCategoryCrossEntropy | 0.68457415
Step 1: eval WeightedCategoryAccuracy | 0.56406250
Step 500: Ran 499 train steps in 19.92 secs
Step 500: train WeightedCategoryCrossEntropy | 0.50587755
Step 500: eval WeightedCategoryCrossEntropy | 0.46716719
Step 500: eval WeightedCategoryAccuracy | 0.80625000
Step 1000: Ran 500 train steps in 17.50 secs
Step 1000: train WeightedCategoryCrossEntropy | 0.36375266
Step 1000: eval WeightedCategoryCrossEntropy | 0.44373559
Step 1000: eval WeightedCategoryAccuracy | 0.80000000
Step 1500: Ran 500 train steps in 18.40 secs
Step 1500: train WeightedCategoryCrossEntropy | 0.34449804
Step 1500: eval WeightedCategoryCrossEntropy | 0.34941847
Step 1500: eval WeightedCategoryAccuracy | 0.84687500
Step 2000: Ran 500 train steps in 17.18 secs
Step 2000: train WeightedCategoryCrossEntropy | 0.28685242
Step 2000: eval WeightedCategoryCrossEntropy | 0.50030373
Step 2000: eval WeightedCategoryAccuracy | 0.77539062
[7]:
# Run on an example.
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_activations = model(example_input[None, :]) # Add batch dimension.
print(f'Model returned sentiment activations: {np.asarray(sentiment_activations)}')
example input_str: The movie features another exceptional collaboration between director William Wyler and cinematographer Gregg Toland, the first after Toland worked on Citizen Kane. But the talent of both these men was focused on achieving a perfectly crafted movie, understood in the good old American sense as a great story. The technical aspects of the movie are covered so as the viewer gets absorbed into the action that takes place on the screen without submitting to the power of the image. Technique is seen as a vehicle of representation unlike in Citizen Kane where Welles' baroque style almost drew the attention from the story to the way the story was told. One of my favorite moves with deep focus in this film is the drama conveyed by the returning home welcoming of Homer and Al. If Homer's girl, Wilma comes towards him perfectly in focus, Al goes over to his wife also perfectly in focus. This is a brilliant move because it shows only through the use of the image the nature of these relationships as we will see them throughout the movie: Wilma loves Homer and she accepts him as he is, Al's wife loves him also but she feels unprepared to fully welcome him home. Also later in the film we find out that their marriage has not always been a bed of roses.<br /><br />Wyler is a director whose force lies in being true to his work without feeling the need to boast. He wanted to show his audience how hard it was for the American soldiers returning from the war to fit into a society that either didn't understand them or treated them with contempt. With a perfect cast and great dialogue Goldwin and Wyler produced a movie that will forever be the template for any other returning home movie. The three hours which coincide with the "rough cut" because the test audience back then never felt for a moment that the action was slow and indeed every scene from the film seems perfectly justified. 
The whole thing is constructed beautifully, every character gets a fair amount of exposure, nothing is left to chance and it is quite pitiful that Hollywood nowadays never manages to bring so much character conflict to the screen. TBYOOL explores the depth of the American way of life, of the American family and society to an extent that makes other movies look like "the children's hour".<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment activations: [[1.6396211 1.6328843]]
2. Convert Trax to Keras
Thanks to TensorFlow NumPy, you can convert the model you just trained into a Keras layer using trax.AsKeras. This allows you to:
- use Trax layers inside Keras models
- run Trax models with existing Keras input pipelines
- export Trax models to TensorFlow SavedModel
When creating a Keras layer from a Trax one, the Keras layer weights will get initialized to the ones the Trax layer had at the moment of creation. In this way, you can create Keras layers from pretrained Trax models and save them as SavedModel as shown below.
[8]:
# Convert the model into a Keras layer, use the weights from model.
keras_layer = trax.AsKeras(model)
print(keras_layer)
# Run the Keras layer to verify it returns the same result.
sentiment_activations = keras_layer(example_input[None, :])
print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')
<trax.trax2keras.AsKeras object at 0x7efff5a47a90>
Keras returned sentiment activations: [[1.6396211 1.6328843]]
[9]:
import tensorflow as tf
# Create a full Keras model using the layer from Trax.
inputs = tf.keras.Input(shape=(None,), dtype='int32')
hidden = keras_layer(inputs)
# You can add other Keras layers here operating on hidden.
outputs = hidden
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(keras_model)
# Run the Keras model to verify it returns the same result.
sentiment_activations = keras_model(example_input[None, :])
print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')
Keras returned sentiment activations: [[1.6396211 1.6328843]]
3. Exporting Trax Models for Deployment
You can export the Keras model to disk as TensorFlow SavedModel. It’s as simple as calling keras_model.save, and it allows you to use models with TF tools such as TensorFlow.js, TensorFlow Serving and TensorFlow Lite.
[10]:
# Save the Keras model to output_dir.
model_file = os.path.join(output_dir, "model_checkpoint")
keras_model.save(model_file)
# Load the model from SavedModel.
loaded_model = tf.keras.models.load_model(model_file)
# Run the loaded model to verify it returns the same result.
sentiment_activations = loaded_model(example_input[None, :])
print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')
Keras returned sentiment activations: [[1.6396211 1.6328843]]
Trax API
trax
trax.fastmath
ops
Trax accelerated math operations for fast computing on GPUs and TPUs.
Import these operations directly from fastmath and import fastmath.numpy as np:
from trax import fastmath
from trax.fastmath import numpy as np
x = np.array([1.0, 2.0]) # Use like numpy.
y = np.exp(x) # Common numpy ops are available and accelerated.
z = fastmath.logsumexp(y) # Special operations available from fastmath.
Trax uses either TensorFlow 2 or JAX as backend for accelerating operations. You can select which one to use (e.g., for debugging) with use_backend.

class trax.fastmath.ops.Backend
Bases: enum.Enum
An enumeration.
  JAX = 'jax'
  TFNP = 'tensorflow-numpy'
  NUMPY = 'numpy'

class trax.fastmath.ops.NumpyBackend
Bases: object
Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.

trax.fastmath.ops.numpy
Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.

class trax.fastmath.ops.RandomBackend
Bases: object
Backend providing random functions.
  get_prng(seed)
  split(prng, num=2)
  fold_in(rng, data)
  uniform(*args, **kwargs)
  randint(*args, **kwargs)
  normal(*args, **kwargs)
  bernoulli(*args, **kwargs)


trax.fastmath.ops.logsumexp(*args, **kwargs)
Computes the log of the sum of exponentials of input elements.
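Computing log(sum(exp(x))) directly overflows for large inputs. The usual remedy, sketched here in NumPy (not necessarily how each fastmath backend implements it), subtracts the maximum first, using log Σ exp(x_i) = m + log Σ exp(x_i - m) with m = max_i x_i.

```python
import numpy as np

def logsumexp(x, axis=None):
    # Shift by the max so the largest exponent is exp(0) = 1 (no overflow).
    m = np.max(x, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    # Drop the kept axis (or return a Python float when reducing everything).
    return np.squeeze(out, axis=axis) if axis is not None else out.item()

print(logsumexp(np.array([1000.0, 1000.0])))   # 1000 + log(2), about 1000.693
print(logsumexp(np.log([1.0, 2.0, 3.0])))      # log(6), about 1.792
```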

trax.fastmath.ops.expit(*args, **kwargs)
Computes the expit (sigmoid) function.

trax.fastmath.ops.sigmoid(*args, **kwargs)
Computes the sigmoid (expit) function.

trax.fastmath.ops.erf(*args, **kwargs)
Computes the erf function.

trax.fastmath.ops.conv(*args, **kwargs)
Computes a generalized convolution.

trax.fastmath.ops.avg_pool(*args, **kwargs)
Average pooling.

trax.fastmath.ops.max_pool(*args, **kwargs)
Max pooling.

trax.fastmath.ops.sum_pool(*args, **kwargs)
Sum pooling.

trax.fastmath.ops.top_k(*args, **kwargs)
Top k.

trax.fastmath.ops.sort_key_val(*args, **kwargs)
Sorts keys along dimension and applies same permutation to values.
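What "applies same permutation to values" means can be seen with NumPy's argsort. This is a sketch of the semantics along one axis, not the fastmath implementation.

```python
import numpy as np

def sort_key_val(keys, values, axis=-1):
    # Find the permutation that sorts the keys...
    perm = np.argsort(keys, axis=axis)
    # ...and apply that same permutation to both arrays.
    return (np.take_along_axis(keys, perm, axis=axis),
            np.take_along_axis(values, perm, axis=axis))

keys = np.array([3, 1, 2])
values = np.array(['c', 'a', 'b'])
sorted_keys, sorted_values = sort_key_val(keys, values)
print(sorted_keys, sorted_values)  # keys 1, 2, 3 and values 'a', 'b', 'c'
```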

trax.fastmath.ops.scan(*args, **kwargs)
Scan to make recurrent functions run faster on accelerators.

trax.fastmath.ops.map(*args, **kwargs)
Map a function over leading array axes.

trax.fastmath.ops.fori_loop(lower, upper, body_fn, init_val)
Loop from lower to upper running body_fn starting from init_val.
The semantics of fori_loop are as follows:
def fori_loop(lower, upper, body_fn, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fn(i, val)
  return val
Parameters:
- lower – an integer representing the loop index lower bound (inclusive)
- upper – an integer representing the loop index upper bound (exclusive)
- body_fn – function of type (int, a) -> a.
- init_val – initial loop carry value of type a.
Returns: Loop value from the final iteration.
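The reference semantics above run as ordinary Python. As a usage example, the sketch below accumulates the squares of 0 through 4; the body function add_square is an illustrative choice, and the accelerated fastmath version is documented to follow the same semantics.

```python
def fori_loop(lower, upper, body_fn, init_val):
    # Pure-Python reference semantics: thread val through body_fn
    # for i = lower, lower + 1, ..., upper - 1.
    val = init_val
    for i in range(lower, upper):
        val = body_fn(i, val)
    return val

def add_square(i, total):
    # body_fn has type (int, a) -> a; here a is an int running total.
    return total + i * i

print(fori_loop(0, 5, add_square, 0))  # 0 + 1 + 4 + 9 + 16 = 30
```

When lower equals upper the body never runs and init_val is returned unchanged.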

trax.fastmath.ops.
remat
(*args, **kwargs)¶ Recompute everything in the backward pass to same memory.

trax.fastmath.ops.
cond
(*args, **kwargs)¶ Conditional computation to run on accelerators.

trax.fastmath.ops.
lt
(*args, **kwargs)¶ Lessthan function for backends that do not override <.

trax.fastmath.ops.
index_update
(*args, **kwargs)¶

trax.fastmath.ops.
index_add
(*args, **kwargs)¶

trax.fastmath.ops.
index_min
(*args, **kwargs)¶

trax.fastmath.ops.
index_max
(*args, **kwargs)¶

trax.fastmath.ops.
dynamic_slice
(*args, **kwargs)¶

trax.fastmath.ops.
dynamic_slice_in_dim
(*args, **kwargs)¶

trax.fastmath.ops.
dynamic_update_slice
(*args, **kwargs)¶

trax.fastmath.ops.
dynamic_update_slice_in_dim
(*args, **kwargs)¶

trax.fastmath.ops.
stop_gradient
(*args, **kwargs)¶ Identity on the forward pass but 0 (no gradient) on the backward pass.

trax.fastmath.ops.
jit
(*args, **kwargs)¶ JustInTime compiles the given function for use on accelerators.

trax.fastmath.ops.
disable_jit
()¶ Disables JIT compilation; helpful for debugging.

trax.fastmath.ops.
vmap
(*args, **kwargs)¶ Vectorizes the specified function (returns a function).

trax.fastmath.ops.
grad
(*args, **kwargs)¶ Computes the gradient of the specified function (returns a function).

trax.fastmath.ops.
value_and_grad
(*args, **kwargs)¶ Computes the gradient of the specified function together with the value.

trax.fastmath.ops.
vjp
(*args, **kwargs)¶ Computes the vector-Jacobian product for the specified function.

trax.fastmath.ops.
custom_grad
(*args, **kwargs)¶ Set a custom gradient computation (override the default) for a function.

trax.fastmath.ops.
custom_vjp
(f, f_fwd, f_bwd, nondiff_argnums=())¶ Set a custom vjp computation (override the default) for a function.

trax.fastmath.ops.
pmap
(*args, **kwargs)¶ Parallel map to apply a function on multiple accelerators in parallel.

trax.fastmath.ops.
psum
(*args, **kwargs)¶ Parallel sum to use within a pmap’d function for aggregation.

trax.fastmath.ops.
abstract_eval
(*args, **kwargs)¶ Evaluates a function using only the signatures (shapes and dtypes) of its arguments; returns the signatures of its outputs.

trax.fastmath.ops.
dataset_as_numpy
(*args, **kwargs)¶ Convert a tf.data.Dataset to a stream of numpy arrays.

trax.fastmath.ops.
global_device_count
(*args, **kwargs)¶ Return the number of accelerators (GPUs or TPUs) in all hosts.

trax.fastmath.ops.
local_device_count
(*args, **kwargs)¶ Return the number of accelerators (GPUs or TPUs) available on this host.

trax.fastmath.ops.
set_backend
(name)¶ Sets the default backend to use in Trax.

trax.fastmath.ops.
backend
(name='jax')¶ Returns the backend used to provide fastmath ops (‘tf’ or ‘jax’).

trax.fastmath.ops.
use_backend
(name)¶ Call fastmath functions with a specified backend.

trax.fastmath.ops.
backend_name
()¶ Returns the name of the backend currently in use (‘tf’ or ‘jax’).

trax.fastmath.ops.
is_backend
(backend_)¶
trax.layers¶
acceleration¶
Modifications to data and computation to use accelerators (better).

class
trax.layers.acceleration.
Accelerate
(layer, n_devices=None)¶ Bases:
trax.layers.base.Layer
Accelerates a layer, running it in a data-parallel way on multiple devices.
By default it uses all available accelerators, splits the input on the first (batch) axis, and runs each part on the corresponding accelerator. If only one accelerator is available, this layer JIT-compiles the underlying layer, which makes it run faster.
The output is guaranteed to be the same as the output of the original layer if the batch dimension is divisible by the number of devices. If it is not, then zero-padding is added to make it divisible, and the output may be affected if it relies on layers like batch normalization.
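A plain-Python sketch of the batch handling just described: zero-pad the batch up to a multiple of the device count, then split it into equal per-device shards along the first axis. The helper name pad_and_shard is hypothetical, and examples are modeled as scalars in a list; the real layer works on (possibly nested) arrays.

```python
def pad_and_shard(batch, n_devices, pad_value=0):
    """Zero-pad `batch` to a multiple of n_devices and split it into shards."""
    remainder = len(batch) % n_devices
    if remainder:
        # Pad so the batch dimension becomes divisible by n_devices.
        batch = batch + [pad_value] * (n_devices - remainder)
    shard_size = len(batch) // n_devices
    return [batch[d * shard_size:(d + 1) * shard_size] for d in range(n_devices)]


print(pad_and_shard([1, 2, 3, 4, 5], n_devices=2))  # [[1, 2, 3], [4, 5, 0]]
```

The padded entry is why batch-dependent layers such as batch normalization can see different statistics when the batch size is not divisible by the device count.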
This layer does not require calling init if the underlying layer has already been initialized, so it can be used as follows:
  layer = tl.Serial(...)
  layer.init(...)
  fast_layer = tl.Accelerate(layer)
  y = fast_layer(x)  # Split x on batch and run data-parallel.
In case the weights of this layer need to be set using the weights of the sublayer, use the replicate_weights function:
  # Instead of layer.weights = new_weights:
  fast_layer.replicate_weights(new_weights)

__init__
(layer, n_devices=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  layer – The layer to accelerate.
 n_devices – Number of devices to use; if None (the default), all available accelerators are used.

sublayer
¶ Returns the unique sublayer managed by this layer.

pure_fn
(x, weights, state, rng, use_cache=False)¶ Calls self.sublayer.pure_fn in an accelerated way.

init
(input_signature)¶ Calls self.sublayer.init and replicates its values onto devices.

replicate_weights
(weights)¶ Sets the weights of the sublayer and replicates them for this layer.

replicate_state
(state)¶ Sets the state of the sublayer and replicates it for this layer.

weights
¶ Returns this layer’s weights.
Depending on the layer, the weights can be in the form of:
 an empty tuple
 a tensor (ndarray)
 a nested structure of tuples and tensors
If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.

state
¶ Returns a tuple containing this layer’s state; may be empty.
If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.


trax.layers.acceleration.
mean_or_pmean
(n_devices, x, axis=None)¶ Computes the mean of a distributed value x.
Parameters:  n_devices – Number of devices.
 x – Distributed array.
 axis – Axis along which to compute means; can only be 0 or None.
Returns: A local array.

trax.layers.acceleration.
jit_forward
(forward, n_devices, do_mean=True)¶ Returns a JIT-compiled forward function running on n_devices devices.

trax.layers.acceleration.
reshape_by_device
(x, n_devices, pure_np=False)¶ Reshapes a possibly nested x into shape (n_devices, ...).

trax.layers.acceleration.
for_n_devices
(x, n_devices)¶ Replicates/broadcasts x for n_devices.

trax.layers.acceleration.
on_cpu
(x)¶ Puts x in CPU memory in JAX.

trax.layers.acceleration.
on_accelerator
(x)¶ Puts x in (single) accelerator memory in JAX.
activation_fns¶
Layers that compute activation functions.
An activation layer computes an element-wise nonlinear function of the preceding layer’s output. Historically, an activation function was considered part of each node in each layer of the neural network. Trax follows the common current practice of separating the activation function into its own layer, which enables easier experimentation across different activation functions.

trax.layers.activation_fns.
Relu
()¶ Returns a layer that computes the Rectified Linear Unit (ReLU) function.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]

trax.layers.activation_fns.
ParametricRelu
(a=1.0)¶ Returns a layer that computes a ReLU function with the given slope.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ ax & \text{otherwise}. \end{array} \right.\end{split}\]Parameters: a – Slope of line for positive inputs.

trax.layers.activation_fns.
LeakyRelu
(a=0.01)¶ Returns a ReLU-like layer with linear nonzero outputs for negative inputs.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} ax & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]Parameters: a – Slope of line for negative inputs.

trax.layers.activation_fns.
Elu
(a=1.0)¶ Returns a ReLU-like layer with exponential outputs for negative inputs.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} a \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\](Asymptotically, \(f(x)\rightarrow -a\) as \(x\rightarrow -\infty\).)
Parameters: a – Coefficient multiplying the exponential, for negative inputs.

trax.layers.activation_fns.
Selu
(alpha=1.6732632423543772, lmbda=1.0507009873554805)¶ Returns an Elu-like layer with an additional scaling/slope parameter.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ \lambda \cdot x & \text{otherwise}. \end{array} \right.\end{split}\]Parameters:  alpha – Coefficient multiplying the exponential, for negative inputs.
 lmbda – Coefficient scaling the whole function.
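The piecewise definitions above translate directly into scalar Python functions. This sketch uses plain math and the documented defaults for a, alpha, and lmbda; it is not the Trax layer code, which applies the same formulas element-wise to arrays.

```python
import math


def relu(x):
    return x if x > 0 else 0.0


def parametric_relu(x, a=1.0):
    return a * x if x > 0 else 0.0


def leaky_relu(x, a=0.01):
    return x if x > 0 else a * x


def elu(x, a=1.0):
    return x if x > 0 else a * (math.exp(x) - 1.0)


def selu(x, alpha=1.6732632423543772, lmbda=1.0507009873554805):
    # SELU is a scaled ELU: lmbda * elu(x, alpha).
    return lmbda * (x if x > 0 else alpha * (math.exp(x) - 1.0))


for f in (relu, parametric_relu, leaky_relu, elu, selu):
    print(f.__name__, f(-2.0), f(3.0))
```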

trax.layers.activation_fns.
Gelu
()¶ Returns a layer that computes the Gaussian Error Linear Unit function.
\[f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}}))\]

trax.layers.activation_fns.
FastGelu
()¶ Returns a layer that computes a fast approximation to Gelu.
\[f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3))\]where \(a = 0.7978845608\) and \(b = 0.044715\).
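A quick numeric check, in plain Python, of how closely the FastGelu formula tracks the exact Gelu on a few sample points. It takes a = sqrt(2/pi), which matches the 0.7978845608 above, and b = 0.044715; the sample points are illustrative only.

```python
import math

A = math.sqrt(2.0 / math.pi)  # approximately 0.7978845608
B = 0.044715


def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def fast_gelu(x):
    return 0.5 * x * (1.0 + math.tanh(A * x + A * B * x ** 3))


points = [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]
max_err = max(abs(gelu(x) - fast_gelu(x)) for x in points)
print(f"max |gelu - fast_gelu| on samples: {max_err:.2e}")
```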

trax.layers.activation_fns.
Sigmoid
()¶ Returns a layer that computes the sigmoid function.
\[f(x) = \frac{1}{1 + e^{-x}}\]

trax.layers.activation_fns.
Tanh
()¶ Returns a layer that computes the hyperbolic tangent function.
\[f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]

trax.layers.activation_fns.
HardSigmoid
()¶ Returns a layer that computes a linear approximation to Sigmoid.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{if}\ 0 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]

trax.layers.activation_fns.
HardTanh
()¶ Returns a layer that computes a linear approximation to Tanh.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} -1 & \text{if}\ x \leq -1, \\ x & \text{if}\ -1 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]
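The sigmoid and tanh formulas are related by tanh(x) = 2 sigmoid(2x) - 1, and HardSigmoid and HardTanh amount to clipping. A plain-Python sketch that checks the identity numerically and evaluates the hard variants as the piecewise definitions above state them (not the Trax layer code):

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def hard_sigmoid(x):
    # 0 for x <= 0, x on (0, 1), 1 otherwise: a clip to [0, 1].
    return min(1.0, max(0.0, x))


def hard_tanh(x):
    # -1 for x <= -1, x on (-1, 1), 1 otherwise: a clip to [-1, 1].
    return min(1.0, max(-1.0, x))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    # tanh expressed through the sigmoid.
    assert math.isclose(math.tanh(x), 2.0 * sigmoid(2.0 * x) - 1.0, abs_tol=1e-12)
    print(x, sigmoid(x), hard_sigmoid(x), hard_tanh(x))
```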

trax.layers.activation_fns.
Softplus
()¶ Returns a layer that computes the softplus function.
\[f(x) = \ln(e^x + 1)\]

trax.layers.activation_fns.
Exp
()¶ Returns a layer that computes the element-wise exponential of a tensor.

trax.layers.activation_fns.
Log
()¶ Returns a layer that computes the element-wise logarithm of a tensor.

trax.layers.activation_fns.
Swish
()¶ Returns a layer that computes the Swish function.
\[f(x) = x \cdot \text{sigmoid}(x)\]

trax.layers.activation_fns.
Glu
()¶ Returns a layer that computes the Gated Linear Unit function.
\[f(x) = a \cdot \text{sigmoid}(b)\]where a and b are formed by splitting input in half along axis
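A minimal sketch of the Glu computation on a single vector. It assumes the first half of the input is a and the second half is b; the description above only says the input is split in half, so both the order and the choice of axis are assumptions here. Plain Python, not the Trax layer.

```python
import math


def glu(vector):
    """Gated Linear Unit on a 1-D list of even length: a * sigmoid(b)."""
    if len(vector) % 2:
        raise ValueError("GLU needs an even-length input to split in half")
    half = len(vector) // 2
    a, b = vector[:half], vector[half:]  # assumed order: first half is a
    return [ai / (1.0 + math.exp(-bi)) for ai, bi in zip(a, b)]


print(glu([2.0, 4.0, 0.0, 100.0]))  # gates ~ [0.5, 1.0], result ~ [1.0, 4.0]
```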

class
trax.layers.activation_fns.
ThresholdedLinearUnit
(n_in=1, n_out=1, name=None, sublayers_to_print=None)¶ Bases:
trax.layers.base.Layer
Thresholded Linear Unit; cf. https://arxiv.org/pdf/1911.09737.pdf.

init_weights_and_state
(input_signature)¶ Initializes this layer’s single weight to zero.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.
Parameters: inputs – Tensor. Returns: Tensor of same shape and dtype as the input.

attention¶
Attention-related layers, as used in Transformer(-like) models.
Attention is a trainable mechanism for mapping between collections of vectors:
Whereas classic neural networks assemble nodes of numbers with weighted connections:
 node activations: floating point values (one float per node)
 inter-node connections: trainable weights (one float per connection),
attention lets one assemble nodes of vectors and use further vectors to calculate connection strengths:
 node activations: floating point vectors, and
 inter-node connections: computed using trainable vectors.
Computing connection strengths involves several concepts – queries, keys, values, masks, attention heads – that factor heavily into the API below.
NOTE: Attention, positional encoding, and shift layers in this module include mode-dependent behavior. The possible modes are:
 'train': in training – dropouts and position shifts active
 'eval': in evals – dropouts inactive, position shifts active
 'predict': in prediction – dropouts and position shifts inactive

trax.layers.attention.
Attention
(d_feature, n_heads=1, dropout=0.0, mode='train')¶ Returns a layer that maps (vectors, mask) to (new_vectors, mask).
This layer type represents one pass of multi-head self-attention, from vector set to vector set, using masks to represent out-of-bound (e.g., padding) positions. It:
 makes three copies of incoming activations and maps these to multi-head query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
 for each head, computes the scaled dot product of each Q-K pair;
 applies mask to screen out positions that come from padding tokens (indicated by 0 value);
 [in 'train' mode] applies dropout to Q-K dot products;
 for each head, computes Q-K attention strengths using a per-query softmax of the Q-K dot products;
 for each head, for each query position, combines V vectors according to the Q-K attention strengths; and
 concatenates and fuses resulting per-head vectors into outgoing activations matching original input activation shapes.
Parameters:  d_feature – Last/innermost dimension of activations in the input to and output from this layer.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', or 'predict'.

trax.layers.attention.
AttentionQKV
(d_feature, n_heads=1, dropout=0.0, mode='train', cache_KV_in_predict=False, q_sparsity=None, result_sparsity=None)¶ Returns a layer that maps (AQ, AK, AV, mask) to (newA, mask).
Unlike Attention above, AttentionQKV allows the incoming activations (AQ, AK, and AV) to come from different sources. This is used, for instance, in encoder-decoder attention (Q-related activations AQ from the decoder, K- and V-related activations – AK and AV – from the encoder). Otherwise, see the Attention description for further context/details.
Parameters:  d_feature – Last/innermost dimension of activations in the input to and output from this layer.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', or 'predict'.
 cache_KV_in_predict – Whether to cache K/V arrays in 'predict' mode.
 q_sparsity – Sparsity with which to process queries. If None, Dense is used; if 'noop', no processing is used.
 result_sparsity – Sparsity with which to process the result of the attention. If None, Dense is used; if 'noop', no processing is used.

class
trax.layers.attention.
PureAttention
(n_heads=1, dropout=0.0, mode='train')¶ Bases:
trax.layers.base.Layer
Returns a layer that maps (Q, K, V, mask) to (activations, mask).
This layer type performs the inner workings of one pass of multi-head self-attention. It:
 subdivides incoming Q/K/V activations into multi-head versions;
 for each head, computes the scaled dot product of each Q-K pair;
 applies mask to screen out positions that come from padding tokens (indicated by 0 value);
 [in 'train' mode] applies dropout to Q-K dot products;
 for each head, computes Q-K attention strengths using a per-query softmax of the Q-K dot products;
 for each head, for each query position, combines V vectors according to the Q-K attention strengths; and
 concatenates and fuses resulting per-head vectors into outgoing activations matching original input activation shapes.

__init__
(n_heads=1, dropout=0.0, mode='train')¶ Returns a new PureAttention instance.
Parameters:  n_heads – Number of attention heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', or 'predict'.

forward
(inputs)¶ Returns attention-computed activations and unmodified mask.
Parameters: inputs – A (Q, K, V, mask) tuple, whose query, key, and value activations have not yet been subdivided into heads.

class
trax.layers.attention.
DotProductAttention
(dropout=0.0, mode='train')¶ Bases:
trax.layers.base.Layer
Returns a layer that computes per-head attention (via scaled dot-product).
This layer computes the core of the attention mechanism. Given per-head queries (Q), keys (K), values (V), and mask, it:
 computes the scaled dot product of each Q-K pair;
 applies mask to screen out positions that come from padding tokens (indicated by 0 value);
 [if created in 'train' mode] applies dropout to Q-K dot products;
 computes Q-K attention strengths using a per-query softmax of the Q-K dot products; and
 for each query position, combines V vectors according to the Q-K attention strengths.
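The steps listed above, minus dropout, can be sketched in plain Python for a single head: scaled Q-K dot products, masking, a per-query softmax, and a weighted sum of V vectors. This is an illustrative sketch, not Trax's DotProductAttention. It assumes a mask given per key position (1 = attend, 0 = padding) and uses a large negative score to screen out masked positions.

```python
import math


def dot_product_attention(q, k, v, mask):
    """Single-head attention on lists of vectors; mask[j] == 0 marks padding."""
    d = len(q[0])
    outputs = []
    for qi in q:
        # Scaled dot product of this query with every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        # Screen out padded key positions before the softmax.
        scores = [s if m else -1e9 for s, m in zip(scores, mask)]
        # Per-query softmax (subtract the max for numerical stability).
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Combine value vectors according to the attention strengths.
        outputs.append([sum(w * vj[c] for w, vj in zip(weights, v))
                        for c in range(len(v[0]))])
    return outputs


q = [[1.0, 0.0]]
k = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
v = [[2.0, 0.0], [4.0, 0.0], [100.0, 100.0]]
print(dot_product_attention(q, k, v, mask=[1, 1, 0]))  # ~ [[3.0, 0.0]]
```

Because the first two keys score equally and the third is masked, the query splits its attention evenly between the first two value vectors.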

__init__
(dropout=0.0, mode='train')¶ Creates a DotProductAttention instance in a specific mode.
Parameters:  dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', 'predict' or 'viz'.

forward
(inputs)¶ Returns attention-computed per-head activations and unchanged mask.
Parameters: inputs – A (Q, K, V, mask) tuple, whose query, key, and value activations have been subdivided into heads.

trax.layers.attention.
SplitIntoHeads
(n_heads, merged_batch_and_head=True)¶ Returns a layer that reshapes an array for multi-head computation.

trax.layers.attention.
MergeHeads
(n_heads, merged_batch_and_head=True)¶ Returns a layer that rejoins heads, after multi-head computation.
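Conceptually, splitting into heads cuts each activation vector of size d_feature into n_heads subvectors of size d_feature / n_heads, and merging concatenates them back. A plain-Python sketch for a single vector; the real layers operate on batched arrays and, judging by the merged_batch_and_head=True default, fold the head axis into the batch axis, which is not modeled here. Function names are hypothetical.

```python
def split_into_heads(vector, n_heads):
    d_feature = len(vector)
    if d_feature % n_heads:
        raise ValueError("d_feature must be divisible by n_heads")
    d_head = d_feature // n_heads
    return [vector[h * d_head:(h + 1) * d_head] for h in range(n_heads)]


def merge_heads(heads):
    # Concatenate per-head subvectors back into one activation vector.
    return [x for head in heads for x in head]


heads = split_into_heads([0, 1, 2, 3, 4, 5], n_heads=3)
print(heads)               # [[0, 1], [2, 3], [4, 5]]
print(merge_heads(heads))  # [0, 1, 2, 3, 4, 5]
```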

trax.layers.attention.
ConfigurableAttention
(q_layer, k_layer, v_layer, final_layer, qkv_attention_layer, n_heads=1)¶ Returns a configured multi-head self-attention layer.
A ConfigurableAttention layer acts similarly to Attention layers, but with configurable components. It:
 makes three copies of incoming activations and uses q_layer, k_layer, and v_layer to map activations to multi-head query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
 uses qkv_attention_layer to compute per-head attention, similar to DotProductAttention or DotProductCausalAttention;
 concatenates and fuses resulting per-head vectors into activations matching original input activation shapes; and
 applies a final layer, final_layer, mapping activations to activations (with shape matching the original input activations).
Parameters:  q_layer – Layer that maps input activations to per-head query activations.
 k_layer – Layer that maps input activations to per-head key activations.
 v_layer – Layer that maps input activations to per-head value activations.
 final_layer – After main multi-head computation and rejoining of heads, layer that maps activations to activations (with shape matching the original input activations).
 qkv_attention_layer – Layer that does the core multi-head self-attention computation.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.

trax.layers.attention.
CausalAttention
(d_feature, n_heads=1, dropout=0.0, max_inference_length=2048, use_dconv=False, mode='train')¶ Returns a layer that maps activations to activations, with causal masking.
Like Attention, this layer type represents one pass of multi-head self-attention, but with causal masking rather than padding-based masking.
Parameters:  d_feature – Last/innermost dimension of activations in the input to and output from this layer.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 max_inference_length – Maximum sequence length allowed in non-training modes.
 use_dconv – If True, use depthwise convolutions on top of dense layers for Q, K and V.
 mode – One of 'train', 'eval', or 'predict'.

class
trax.layers.attention.
DotProductCausalAttention
(dropout=0.0, max_inference_length=2048, mode='train')¶ Bases:
trax.layers.base.Layer
Layer that computes attention strengths by masking out the “future”.
Causal attention uses masking to prevent a given sequence position from attending to positions greater than / following it. This is used, for example, when training autoregressive sequence models, or when decoding a sequence symbol by symbol.
This layer performs the core per-head attention calculation. The layer assumes that any splitting into attention heads precedes it, and that any merging of attention heads will follow it.
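A minimal sketch of the causal mask this layer relies on: query position i may attend to key position j only if j <= i, which gives a lower-triangular boolean matrix. Plain Python, for illustration only; the layer's own mask construction is not described here.

```python
def causal_mask(length):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [[j <= i for j in range(length)] for i in range(length)]


for row in causal_mask(3):
    print(row)
# [True, False, False]
# [True, True, False]
# [True, True, True]
```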

__init__
(dropout=0.0, max_inference_length=2048, mode='train')¶ Creates a DotProductCausalAttention instance.
Parameters:  dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 max_inference_length – Maximum sequence length allowed in non-training modes.
 mode – One of 'train', 'eval', or 'predict'.

monkey_patched_mask
()¶

forward
(inputs)¶ Returns attention-computed activations.
Parameters: inputs – A (queries, keys, values) tuple.

init_weights_and_state
(input_signature)¶ Initializes this layer for fast inference, if in 'predict' mode.


trax.layers.attention.
ShiftRight
(n_positions=1, mode='train')¶ Returns a layer that can insert padding to shift the input sequence.
Parameters:  n_positions – Number of positions to shift the input sequence rightward; initial positions freed by the shift get padded with zeros. Applies only if layer is created in a non-'predict' mode.
 mode – One of 'train', 'eval', or 'predict'.
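A plain-Python sketch of the shift for a batch of sequences: pad n_positions zeros on the left and drop the same number of positions on the right, so each sequence keeps its length. The mode check follows the NOTE at the top of this module (position shifts inactive in 'predict' mode). The function name shift_right is hypothetical.

```python
def shift_right(batch, n_positions=1, mode="train"):
    """Shift each sequence right by n_positions, padding the front with zeros."""
    if mode == "predict":
        # Position shifts are inactive in 'predict' mode.
        return batch
    return [[0] * n_positions + list(seq[:len(seq) - n_positions])
            for seq in batch]


print(shift_right([[5, 6, 7, 8]], n_positions=2))  # [[0, 0, 5, 6]]
```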

trax.layers.attention.
PaddingMask
(pad=0)¶ Returns a layer that maps integer sequences to padding masks.
The layer expects as input a batch of integer sequences. The layer output is an N-D array that marks for each sequence position whether the integer (e.g., a token ID) in that position represents padding – value pad – versus text/content – all other values. The padding mask shape is (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast to cover any number of attention heads and axis 2 will broadcast to cover decoder sequence positions.
Parameters: pad – Integer that represents padding rather than a token/content ID.
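A plain-Python sketch of a padding mask for a batch of token-ID sequences, using nested lists with the (batch_size, 1, 1, encoder_sequence_length) layout described above. The convention that True means "content token, may be attended to" is an assumption made for illustration, and padding_mask is a hypothetical name.

```python
def padding_mask(batch, pad=0):
    """Return a nested list of shape (batch_size, 1, 1, sequence_length)."""
    # True marks content tokens; False marks positions holding the pad value.
    return [[[[token != pad for token in seq]]] for seq in batch]


mask = padding_mask([[7, 3, 0, 0], [4, 0, 0, 0]])
print(mask[0][0][0])  # [True, True, False, False]
```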

trax.layers.attention.
EncoderDecoderMask
()¶ Returns a layer that creates a mask for encoder-decoder cross-attention.
The layer expects two inputs:
 decoder_input: batch of integer (e.g., token ID) sequences
 mask: padding mask from the encoder
The layer output is a mask that marks for each sequence position (for both encoder and decoder) whether that position can be attended to or not. The encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, encoder_sequence_length), such that axis 1 will automatically broadcast to cover any number of attention heads.
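One plausible way to reach the (batch_size, 1, decoder_sequence_length, encoder_sequence_length) shape is to repeat each example's encoder padding row once per decoder position, so every decoder position sees the same attendable encoder positions. This plain-Python sketch assumes that broadcasting behavior; it is an illustration, not the layer's code, and the function name is hypothetical.

```python
def encoder_decoder_mask(decoder_input, encoder_padding_mask):
    """Expand a (batch, 1, 1, enc_len) mask to (batch, 1, dec_len, enc_len)."""
    result = []
    for dec_seq, enc_mask in zip(decoder_input, encoder_padding_mask):
        enc_row = enc_mask[0][0]  # the encoder padding row for this example
        # Every decoder position gets a copy of the same encoder row.
        result.append([[list(enc_row) for _ in dec_seq]])
    return result


enc_mask = [[[[True, True, False]]]]  # one example, encoder length 3
dec_input = [[11, 12]]                # one example, decoder length 2
out = encoder_decoder_mask(dec_input, enc_mask)
print(out)  # [[[[True, True, False], [True, True, False]]]]
```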

class
trax.layers.attention.
PositionalEncoding
(max_len=2048, dropout=0.0, dropout_broadcast_dims=(2, ), use_bfloat16=False, start_from_zero_prob=1.0, max_offset_to_add=0, d_feature=None, mode='train')¶ Bases:
trax.layers.base.Layer
Implements bare positional encoding.
Positional encoding includes a kind of dropout, if the layer is created in 'train' mode with a nonzero dropout value. For such a layer, on each forward pass a subset of sequence positions selected at random will not receive positional marking.
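A plain-Python sketch of the positional-dropout behavior just described: in 'train' mode, each sequence position independently skips the positional encoding with probability dropout. The encoding vectors pe are supplied by the caller (how the layer initializes them is not modeled), dropout_broadcast_dims is ignored, and the helper name add_positional_encoding is hypothetical.

```python
import random


def add_positional_encoding(x, pe, dropout=0.0, mode="train", rng=None):
    """x and pe are lists of per-position vectors with matching shapes."""
    rng = rng or random.Random(0)
    out = []
    for xv, pv in zip(x, pe):
        # In 'train' mode with nonzero dropout, some positions get no marking.
        skip = mode == "train" and dropout > 0 and rng.random() < dropout
        out.append(list(xv) if skip else [a + b for a, b in zip(xv, pv)])
    return out


x = [[1.0, 1.0], [1.0, 1.0]]
pe = [[0.0, 0.5], [0.25, 0.75]]
print(add_positional_encoding(x, pe))               # every position marked
print(add_positional_encoding(x, pe, dropout=1.0))  # no position marked
```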
__init__
(max_len=2048, dropout=0.0, dropout_broadcast_dims=(2, ), use_bfloat16=False, start_from_zero_prob=1.0, max_offset_to_add=0, d_feature=None, mode='train')¶ Creates a PositionalEncoding instance in a given mode.
Parameters:  max_len – Maximum input sequence length.
 dropout – Probability of not adding positional encoding to a sequence position. Applies only if layer is created in 'train' mode.
 dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.
 start_from_zero_prob – How often to start from 0 during training (if 1.0, we always start from position 0; if less, we randomize).
 max_offset_to_add – Maximum offset to add to the positions during training when randomizing; this offset plus input length must still be less than max_len for all training examples.
 d_feature – int or None; have this dimension for embeddings + shared FF if not None.
 mode – One of 'train', 'eval', or 'predict'.

forward
(inputs)¶ Returns the input activations, with added positional information.

init_weights_and_state
(input_signature)¶ Randomly initializes the positional encoding vectors.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.

base¶
The key layer abstraction (Layer class) and supporting machinery.

class
trax.layers.base.
Layer
(n_in=1, n_out=1, name=None, sublayers_to_print=None)¶ Bases:
object
Base class for composable layers in a deep learning network.
Layers are the basic building blocks for deep learning models. A layer computes a function from zero or more inputs to zero or more outputs, optionally using trainable weights (common) and non-parameter state (not common).
Layer subclasses typically override at most two methods of the base Layer class:
 forward(inputs):
 Computes the layer’s output as part of a forward pass through the model.
 init_weights_and_state(self, input_signature):
 Initializes the layer’s weights and state to handle input with the given signature (number, shapes and dtypes of input arguments).
A small number of layer types are combinators – they organize the computation of their sublayers, e.g., applying their sublayers in series or in parallel.
All layers have the following properties, with default values implemented in the base Layer class:
 n_in: int (default 1)
 n_out: int (default 1)
 weights: tuple (default empty – the layer has no weights)
 state: tuple (default empty – the layer has no non-parameter state)
 sublayers: tuple (default empty – the layer has no sublayers)
The inputs to a layer are tensors, packaged according to how many there are:
 n_in = 0: an empty tuple
 n_in = 1: one tensor (NOT wrapped in a tuple)
 n_in > 1: a tuple of tensors
(The special treatment of the single-input case is meant to simplify the work of layer writers; this design choice may be revisited in the future.)
The outputs from a layer are also tensors, packaged the same as layer inputs:
 n_out = 0: an empty tuple
 n_out = 1: the tensor (NOT wrapped in a tuple)
 n_out > 1: a tuple of tensors
The Trax runtime maintains a data stack with which layer calls are composed. For more complex data network architectures, possibly involving multiple data flows, one can view each layer as a function from stack state to stack state, where the function’s inputs are a slice from the stack, and the function’s outputs are spliced back into the stack.
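To illustrate the stack view, here is a toy plain-Python model, not Trax's runtime. Each layer is a function tagged with n_in and n_out; it consumes the top n_in stack items and pushes its outputs back, and a single output is returned as a bare value rather than a 1-tuple, following the packaging rules above. The names ToyLayer and run_serial are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class ToyLayer:
    fn: Callable[..., Any]
    n_in: int = 1
    n_out: int = 1


def run_serial(layers: List[ToyLayer], stack: List[Any]) -> List[Any]:
    """Apply layers in order; stack[0] is the top of the stack."""
    stack = list(stack)
    for layer in layers:
        args, stack = stack[:layer.n_in], stack[layer.n_in:]
        outputs = layer.fn(*args)
        # n_out == 1 yields a bare value; otherwise a tuple of values.
        outputs = [outputs] if layer.n_out == 1 else list(outputs)
        stack = outputs + stack
    return stack


add = ToyLayer(lambda a, b: a + b, n_in=2, n_out=1)
dup = ToyLayer(lambda a: (a, a), n_in=1, n_out=2)
print(run_serial([dup, add], [3, 10]))  # [6, 10]
```

Here dup turns [3, 10] into [3, 3, 10], and add then consumes the top two items, leaving [6, 10]; the untouched 10 shows how a layer sees only a slice of the stack.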

__init__
(n_in=1, n_out=1, name=None, sublayers_to_print=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Classlike name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

__call__
(x, weights=None, state=None, rng=None)¶ Makes layers callable; for use in tests or interactive settings.
This convenience method helps library users play with, test, or otherwise probe the behavior of layers outside of a full training environment. It presents the layer as a callable function from inputs to outputs, with the option of manually specifying weights and non-parameter state per individual call. For convenience, weights and non-parameter state are cached per layer instance, starting from default values of EMPTY_WEIGHTS and EMPTY_STATE, and acquiring non-empty values either by initialization or from values explicitly provided via the weights and state keyword arguments, in which case the old weights will be preserved, and the state will be updated.
Parameters:  x – Zero or more input tensors, packaged as described in the Layer class docstring.
 weights – Weights or None; if None, use self’s cached weights value.
 state – State or None; if None, use self’s cached state value.
 rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng)¶ Custom backward pass to propagate gradients in a custom way.
Parameters:  inputs – Input tensors; can be a (possibly nested) tuple.
 output – The result of running this layer on inputs.
 grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
 weights – This layer’s weights.
 state – This layer’s state prior to the current forward pass.
 new_state – This layer’s state after the current forward pass.
 rng – Single-use random number generator (JAX PRNG key).
Returns: The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.

init
(input_signature, rng=None, use_cache=False)¶ Initializes weights/state of this layer and its sublayers recursively.
Initialization creates layer weights and state, for layers that use them. It derives the necessary array shapes and data types from the layer’s input signature, which is itself just shape and data type information.
For layers without weights or state, this method safely does nothing.
This method is designed to create weights/state only once for each layer instance, even if the same layer instance occurs in multiple places in the network. This enables weight sharing to be implemented as layer sharing.
Parameters:  input_signature – ShapeDtype instance (if this layer takes one input) or list/tuple of ShapeDtype instances.
 rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
 use_cache – If True, and if this layer instance has already been initialized elsewhere in the network, then return special marker values – tuple (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE). Else return this layer’s newly initialized weights and state.
Returns: A (weights, state) tuple.

init_from_file
(file_name, weights_only=False, input_signature=None)¶ Initializes this layer and its sublayers from a pickled checkpoint.
In the common case (weights_only=False), the file must be a gzipped pickled dictionary containing items with keys ‘flat_weights’, ‘flat_state’ and ‘input_signature’, which are used to initialize this layer. If input_signature is specified, it’s used instead of the one in the file. If weights_only is True, the dictionary does not need to have the ‘flat_state’ item, and the state is not restored either.
Parameters:  file_name – Name/path of the pickled weights/state file.
 weights_only – If True, initialize only the layer’s weights. Else initialize both weights and state.
 input_signature – Input signature to be used instead of the one from file.
Returns: A (weights, state) tuple.

save_to_file
(file_name, weights_only=False, input_signature=None)¶ Saves this layer and its sublayers to a pickled checkpoint.
Parameters:  file_name – Name/path of the pickled weights/state file.
 weights_only – If True, save only the layer’s weights. Else save both weights and state.
 input_signature – Input signature to be used.
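The checkpoint layout described for init_from_file (a gzipped pickled dictionary with ‘flat_weights’, ‘flat_state’ and ‘input_signature’ keys) can be illustrated with the standard library alone. The helpers save_checkpoint and load_checkpoint below are hypothetical, and plain Python values stand in for the flattened weight and state arrays; this is a sketch of the file format only, not of Trax’s save_to_file/init_from_file code.

```python
import gzip
import os
import pickle
import tempfile


def save_checkpoint(file_name, flat_weights, flat_state, input_signature,
                    weights_only=False):
    """Hypothetical helper: writes a gzipped pickled dict of the described shape."""
    data = {'flat_weights': flat_weights, 'input_signature': input_signature}
    if not weights_only:
        data['flat_state'] = flat_state  # weights-only files carry no state
    with gzip.open(file_name, 'wb') as f:
        pickle.dump(data, f)


def load_checkpoint(file_name, weights_only=False, input_signature=None):
    """Hypothetical helper: returns (flat_weights, flat_state, input_signature)."""
    with gzip.open(file_name, 'rb') as f:
        data = pickle.load(f)
    # An explicitly passed signature is used instead of the stored one.
    if input_signature is None:
        input_signature = data['input_signature']
    flat_state = None if weights_only else data['flat_state']
    return data['flat_weights'], flat_state, input_signature


# Usage: round-trip a tiny checkpoint through a temporary file.
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pkl.gz')
save_checkpoint(path, [[1.0, 2.0]], [0], 'sig')
weights, state, signature = load_checkpoint(path)
```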

name
¶ Returns the name of this layer.

n_in
¶ Returns how many tensors this layer expects as input.

n_out
¶ Returns how many tensors this layer promises as output.

sublayers
¶ Returns a tuple containing this layer’s sublayers; may be empty.

weights
¶ Returns this layer’s weights.
Depending on the layer, the weights can be in the form of:
 an empty tuple
 a tensor (ndarray)
 a nested structure of tuples and tensors
If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.

state
¶ Returns a tuple containing this layer’s state; may be empty.
If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.

weights_and_state_signature
(input_signature, unsafe=False)¶ Return a pair containing the signatures of weights and state.

rng
¶ Returns this layer’s current single-use random number generator.
Code that wants to base random samples on this generator must explicitly split off new generators from it. (See, for example, the rng setter code below.)

pure_fn
(x, weights, state, rng, use_cache=False)¶ Applies this layer as a pure function with no optional args.
This method exposes the layer’s computation as a pure function. This is especially useful for JIT compilation. Do not override, use forward instead.
Parameters:  x – Zero or more input tensors, packaged as described in the Layer class docstring.
 weights – A tuple or list of trainable weights, with one element for this layer if this layer has no sublayers, or one for each sublayer if this layer has sublayers. If a layer (or sublayer) has no trainable weights, the corresponding weights element is an empty tuple.
 state – Layer-specific non-parameter state that can update between batches.
 rng – Single-use random number generator (JAX PRNG key).
 use_cache – if True, cache weights and state in the layer object; used to implement layer sharing in combinators.
Returns: A tuple of (tensors, state). The tensors match the number (n_out) promised by this layer, and are packaged as described in the Layer class docstring.

output_signature
(input_signature)¶ Returns output signature this layer would give for input_signature.

class
trax.layers.base.
PureLayer
(forward_fn, n_in=1, n_out=1, name='PureLayer')¶ Bases:
trax.layers.base.Layer
Pure function from inputs to outputs, packaged as neural network layer.
The PureLayer class represents the simplest kinds of layers: layers with no trainable weights and no randomness, hence pure functions from inputs to outputs.

__init__
(forward_fn, n_in=1, n_out=1, name='PureLayer')¶ Creates an unconnected PureLayer instance.
Parameters:  forward_fn – Pure function from input tensors to output tensors, where inputs and outputs are packaged as specified for forward.
 n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use only in debugging.

forward
(inputs)¶ Overrides Layer.forward.
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.


trax.layers.base.
Fn
(name, f, n_out=1)¶ Returns a layer with no weights that applies the function f.
f can take and return any number of arguments, and takes only positional arguments (no default or keyword arguments). It often uses JAX-numpy (jnp). The following, for example, would create a layer that takes two inputs and returns two outputs, element-wise sums and maxima:
Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)
The layer’s number of inputs (n_in) is automatically set to the number of positional arguments in f, but you must explicitly set the number of outputs (n_out) whenever it’s not the default value 1.
Parameters:  name – Classlike name for the resulting layer; for use in debugging.
 f – Pure function from input tensors to output tensors, where each input tensor is a separate positional arg, e.g., f(x0, x1) –> x0 + x1. Output tensors must be packaged as specified in the Layer class docstring.
 n_out – Number of outputs promised by the layer; default value 1.
Returns: Layer executing the function f.
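Since n_in is derived from the number of positional arguments of f, that bookkeeping can be illustrated with Python’s inspect module. The helper count_positional_args is hypothetical (Trax’s own logic may differ), and NumPy’s np.maximum stands in for jnp.maximum so the sketch runs without JAX.

```python
import inspect

import numpy as np


def count_positional_args(f):
    """Hypothetical helper: the n_in an Fn-style layer would get for f."""
    params = list(inspect.signature(f).parameters.values())
    for p in params:
        # Only plain positional arguments without defaults are allowed.
        if p.default is not inspect.Parameter.empty or p.kind not in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD):
            raise ValueError(f'unsupported parameter: {p.name}')
    return len(params)


# The SumAndMax function, with NumPy in place of jnp.
sum_and_max = lambda x0, x1: (x0 + x1, np.maximum(x0, x1))

n_in = count_positional_args(sum_and_max)  # 2; n_out=2 must still be given explicitly
sums, maxima = sum_and_max(np.array([1, 5]), np.array([4, 2]))
```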

exception
trax.layers.base.
LayerError
(layer_name, function_name, caller, input_signature, traceback_string)¶ Bases:
Exception
Exception raised in the layer stack.

__init__
(layer_name, function_name, caller, input_signature, traceback_string)¶ Initialize self. See help(type(self)) for accurate signature.

message
¶ Assembles current layer context into an error message.


trax.layers.base.
flatten_weights_and_state
(weights, state)¶ Flatten weights and state into lists, excluding empty and cached ones.

trax.layers.base.
unflatten_weights_and_state
(flat_weights, flat_state, weights_and_state_signature, weights_only=False)¶ Unflatten weights and state given their signatures.

trax.layers.base.
np_to_file
(list_of_nparrays, file_path, compresslevel)¶ Save numpy arrays to file_path with gzipping and failure protection.

trax.layers.base.
np_from_file
(file_path, compresslevel)¶ Load numpy arrays from file_path with gzipping.

trax.layers.base.
to_list
(outputs)¶ Converts layer outputs to a nested list, for easier equality testing.
Parameters: outputs – A tensor or tuple/list of tensors coming from the forward application of a layer. Each tensor is NumPy ndarray-like, which complicates simple equality testing (e.g., via assertEquals): such tensors require equality testing to use either all (all elements match) or any (at least one element matches), which is not directly supported in absltest. Returns: A nested list structure containing all the output values, but now directly testable using assertEquals.
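The problem to_list addresses, and the fix, can be seen with NumPy directly. to_list_sketch below is a hypothetical stand-in, not Trax’s implementation: it recursively converts tuples/lists of arrays into nested Python lists, which compare with a plain ==.

```python
import numpy as np


def to_list_sketch(outputs):
    """Hypothetical stand-in for to_list: (nested) ndarrays -> nested lists."""
    if isinstance(outputs, (list, tuple)):
        return [to_list_sketch(y) for y in outputs]
    return np.asarray(outputs).tolist()


a = np.array([[1, 2], [3, 4]])

# `a == [[1, 2], [3, 4]]` is an element-wise boolean array whose truth value is
# ambiguous (bool() raises ValueError). Converted lists compare directly:
same = to_list_sketch((a, np.array([5]))) == [[[1, 2], [3, 4]], [5]]
```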

trax.layers.base.
shard
(tensors, n_shards=None)¶ Shard tensors across n_shards.

trax.layers.base.
unshard_in_pmap
(tensors, n_shards)¶ Unshard tensors that were sharded into n_shards (call inside pmap).

trax.layers.base.
unshard
(tensors, n_shards=None)¶ Unshard tensors that were sharded into n_shards (outside of pmap).
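As a rough mental model of shard and unshard, one can picture the leading (batch) axis being split into n_shards equal pieces, one per device, and merged back afterwards. That layout is an assumption here, shown on a single NumPy array with hypothetical helpers; Trax’s functions operate on (nested) tensors, pick a default when n_shards is None, and have a separate variant for use inside pmap.

```python
import numpy as np


def shard_sketch(x, n_shards):
    """Hypothetical: split the leading (batch) axis into n_shards equal pieces."""
    if x.shape[0] % n_shards != 0:
        raise ValueError('batch size must be divisible by n_shards')
    return x.reshape((n_shards, x.shape[0] // n_shards) + x.shape[1:])


def unshard_sketch(x):
    """Hypothetical inverse: merge the shard axis back into the batch axis."""
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


batch = np.arange(12).reshape(6, 2)
sharded = shard_sketch(batch, n_shards=3)  # shape (3, 2, 2)
restored = unshard_sketch(sharded)         # shape (6, 2), equal to batch
```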
combinators¶
Combinators for composing layers.

class
trax.layers.combinators.
Serial
(*sublayers, name=None, sublayers_to_print=None)¶ Bases:
trax.layers.base.Layer
Combinator that applies layers serially (by function composition).
This combinator is commonly used to construct deep networks, e.g., like this:
mlp = tl.Serial(
    tl.Dense(128),
    tl.Relu(),
    tl.Dense(10),
)
A Serial combinator uses stack semantics to manage data for its sublayers. Each sublayer sees only the inputs it needs and returns only the outputs it has generated. The sublayers interact via the data stack. For instance, a sublayer k, following sublayer j, gets called with the data stack in the state left after layer j has applied. The Serial combinator then:
 takes n_in items off the top of the stack (n_in = k.n_in) and calls layer k, passing those items as arguments; and
 takes layer k’s n_out return values (n_out = k.n_out) and pushes them onto the data stack.
A Serial instance with no sublayers acts as a special-case (but useful) 1-input 1-output no-op.
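The stack discipline described above can be simulated in a few lines of plain Python. StackLayer and run_serial are hypothetical names for illustration only: the stack is a list whose top is element 0, and in this sketch a layer with n_out == 1 returns a bare value rather than a 1-tuple.

```python
from typing import Callable, NamedTuple


class StackLayer(NamedTuple):
    """Hypothetical minimal layer: n_in inputs -> n_out outputs via fn."""
    n_in: int
    n_out: int
    fn: Callable


def run_serial(layers, stack):
    """Applies layers in order to a stack whose top is stack[0]."""
    stack = list(stack)
    for layer in layers:
        args, stack = stack[:layer.n_in], stack[layer.n_in:]  # pop n_in items
        out = layer.fn(*args)
        outs = [out] if layer.n_out == 1 else list(out)  # single output is unwrapped
        if len(outs) != layer.n_out:
            raise ValueError('layer returned the wrong number of outputs')
        stack = outs + stack  # push n_out items
    return stack


add = StackLayer(2, 1, lambda a, b: a + b)
dup = StackLayer(1, 2, lambda a: (a, a))
double = StackLayer(1, 1, lambda a: 2 * a)

# Stack (top first): 1, 2, 10. `add` consumes 1 and 2; 10 is never touched.
result = run_serial([add, dup, double], [1, 2, 10])  # [6, 3, 10]
```

With no layers, run_serial returns its input stack unchanged, which matches the no-op case above.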

__init__
(*sublayers, name=None, sublayers_to_print=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

class
trax.layers.combinators.
Parallel
(*sublayers, name=None)¶ Bases:
trax.layers.base.Layer
Combinator that applies a list of layers in parallel to its inputs.
Layers in the list apply to successive spans of inputs, where the spans are determined by how many inputs each layer takes. The resulting output is the (flattened) concatenation of the respective layer outputs.
For example, suppose one has three layers:
 F: 1 input, 1 output
 G: 3 inputs, 1 output
 H: 2 inputs, 2 outputs (h1, h2)
Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:
 inputs: a, b, c, d, e, f
 outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f)
As an important special case, a None argument to Parallel acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For example,
Parallel(None, F)
creates a layer that passes its first input unchanged and applies F to the following input(s).
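The span rule can be modeled in plain Python with strings standing in for tensors. run_parallel is a hypothetical helper, and each layer is represented as an (n_in, fn) pair whose fn returns a tuple of outputs; None is treated as the 1-input, 1-output no-op.

```python
def run_parallel(layers, inputs):
    """Model of Parallel: each layer consumes the next n_in inputs in order."""
    outputs, pos = [], 0
    for layer in layers:
        if layer is None:  # one-arg no-op
            outputs.append(inputs[pos])
            pos += 1
            continue
        n_in, fn = layer
        outputs.extend(fn(*inputs[pos:pos + n_in]))
        pos += n_in
    if pos != len(inputs):
        raise ValueError(f'expected {pos} inputs, got {len(inputs)}')
    return outputs


F = (1, lambda a: (f'F({a})',))
G = (3, lambda a, b, c: (f'G({a},{b},{c})',))
H = (2, lambda e, f: (f'h1({e},{f})', f'h2({e},{f})'))

fgh = run_parallel([F, G, H], list('abcdef'))  # ['F(a)', 'G(b,c,d)', 'h1(e,f)', 'h2(e,f)']
skip_first = run_parallel([None, F], ['a', 'b'])  # ['a', 'F(b)']
```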

__init__
(*sublayers, name=None)¶ The constructor.
Parameters:  *sublayers – A list of sublayers.
 name – Descriptive name for this layer.
Returns: A new layer in which each of the given sublayers applies to its corresponding span of elements in the dataflow stack.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

class
trax.layers.combinators.
Concatenate
(n_items=2, axis=-1)¶ Bases:
trax.layers.base.Layer
Concatenates a number of tensors into a single tensor.
For example:
x = np.array([1, 2])
y = np.array([3, 4])
z = np.array([5, 6])
concat3 = tl.Concatenate(n_items=3)
out = concat3((x, y, z))  # out = [1, 2, 3, 4, 5, 6]
Use the axis argument to specify the axis along which to concatenate the tensors. By default it is the last axis (axis=-1), and n_items=2.

__init__
(n_items=2, axis=-1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.


class
trax.layers.combinators.
Split
(n_items=2, axis=-1)¶ Bases:
trax.layers.base.Layer
Splits the input into n items along an axis.
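Assuming Split divides its input into n_items equal parts the way NumPy’s np.split does, its behavior along the last axis (axis=-1) looks like this; NumPy is used here as an analogue, not as Trax’s code.

```python
import numpy as np

x = np.arange(12).reshape(2, 6)

# Rough NumPy analogue of splitting into n_items=3 parts along axis=-1.
parts = np.split(x, 3, axis=-1)  # three arrays of shape (2, 2)

# Concatenating along the same axis undoes the split.
restored = np.concatenate(parts, axis=-1)
```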

__init__
(n_items=2, axis=-1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.


class
trax.layers.combinators.
Scan
(layer, axis=0, n_carry=1, remat=False, mode='train')¶ Bases:
trax.layers.base.Layer
Applies a layer progressively/cumulatively to an axis-derived sequence.
Conceptually, this is a function from a list to a same-length list of partial (cumulative) results. For instance, a list of values ([1, 2, 3, 4, 5]) can transform to a list of cumulative sums ([1, 3, 6, 10, 15]). Functions for the same concept are called scan in Scala, scanl in Haskell, and accumulate* in Factor.
In more detail, we assume the layer takes a tuple of inputs of the following form:
(input1, …, inputN, carry1, …, carryM)
and returns:
(output1, …, outputK, new_carry1, …, new_carryM)
The scanned version applies the layer iteratively to a tensor, treating values at the given axis as if they were a list. For example, to calculate all sums of prefixes of a tensor, we can do this:
def f(x, carry):
  res = x + carry
  return res, res  # output and carry are the same
add = tl.Fn('add', f, n_out=2)
Scan(add)([1, 2, 3], 0) = [1, 3, 6], 6
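For a single input and a single carry, the loop that Scan performs can be written out in plain Python. The scan function below is a hypothetical model for intuition only; Trax’s Scan operates on tensors along axis and supports several inputs and carries.

```python
def scan(fn, xs, carry):
    """Model of Scan with one input and one carry.

    fn(x, carry) -> (output, new_carry); returns (outputs, final_carry).
    """
    outputs = []
    for x in xs:  # walk along the scanned axis
        y, carry = fn(x, carry)
        outputs.append(y)
    return outputs, carry


def add(x, carry):
    res = x + carry
    return res, res  # output and carry are the same


prefix_sums = scan(add, [1, 2, 3], 0)       # ([1, 3, 6], 6)
cumulative = scan(add, [1, 2, 3, 4, 5], 0)  # ([1, 3, 6, 10, 15], 15)
```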

__init__
(layer, axis=0, n_carry=1, remat=False, mode='train')¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

sublayer
¶ Returns the unique sublayer managed by this layer.

state
¶ Returns a tuple containing this layer’s state.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.


class
trax.layers.combinators.
Cond
(cond, true, false=None, name=None)¶ Bases:
trax.layers.base.Layer
Applies layers conditionally.
Given layers cond, true, and false, this runs the equivalent of true(y) if cond(x) else false(y), where x is the first cond.n_in elements of the stack and y is the rest of the stack. Exactly one of the true and false branches is executed, so Cond can be used to run long computations conditionally. The state of the non-executed branch is not updated. Note that different branches may be executed on different devices if cond returns different values on them. By default, false is the identity function.
cond must return exactly one element: a Boolean value. true and false must have the same n_in, and the same n_out.
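The control flow can be modeled in plain Python. run_cond and expensive are hypothetical names; the calls list records which branch actually ran, to show that only one branch executes, and a missing false branch behaves as the identity.

```python
def run_cond(cond, true_fn, false_fn, stack, cond_n_in=1):
    """Model of Cond: cond reads the first cond_n_in items, a branch reads the rest."""
    x, y = stack[:cond_n_in], stack[cond_n_in:]
    if cond(*x):
        return true_fn(*y)  # only this branch runs
    if false_fn is None:
        return tuple(y)     # default false branch: identity
    return false_fn(*y)


calls = []


def expensive(v):
    calls.append(v)  # record that this branch actually ran
    return (v * 100,)


on = run_cond(lambda flag: flag, expensive, None, [True, 3])    # (300,)
off = run_cond(lambda flag: flag, expensive, None, [False, 3])  # (3,)
# calls == [3]: the expensive branch ran only when the condition held.
```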

__init__
(cond, true, false=None, name=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.
Parameters: xs – Tensors as required by the branches of this conditional. Returns: Tensors resulting from running the chosen branch.


trax.layers.combinators.
Chunk
(layer, chunk_size, pass_unchunkable=True)¶ Executes layer using batch chunks of size chunk_size to save memory.
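The memory-saving idea can be sketched in NumPy: the wrapped function only ever sees chunk_size examples at once, and the per-chunk results are concatenated back along the batch axis. run_chunked is a hypothetical helper, and treating pass_unchunkable=True as "process an indivisible batch in one piece" is an assumption about that flag rather than documented behavior.

```python
import numpy as np


def run_chunked(fn, x, chunk_size, pass_unchunkable=True):
    """Applies fn to consecutive batch chunks of x and re-joins the results."""
    batch = x.shape[0]
    if batch % chunk_size != 0:
        if not pass_unchunkable:
            raise ValueError('batch size must be divisible by chunk_size')
        chunk_size = batch  # assumption: fall back to a single chunk
    chunks = x.reshape((batch // chunk_size, chunk_size) + x.shape[1:])
    # fn is called on one chunk of chunk_size examples at a time.
    outs = [fn(chunk) for chunk in chunks]
    return np.concatenate(outs, axis=0)


x = np.arange(8.0).reshape(4, 2)
y = run_chunked(lambda c: c * 2, x, chunk_size=2)  # same result as x * 2
```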

trax.layers.combinators.
Branch
(*layers, name='Branch')¶ Combinator that applies a list of layers in parallel to copies of inputs.
Each layer in the input list is applied to as many inputs from the stack as it needs, and their outputs are successively combined on stack.
For example, suppose one has three layers:
 F: 1 input, 1 output
 G: 3 inputs, 1 output
 H: 2 inputs, 2 outputs (h1, h2)
Then Branch(F, G, H) will take 3 inputs and give 4 outputs:
 inputs: a, b, c
 outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)
As an important special case, a None argument to Branch acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.)
Parameters:  *layers – List of layers.
 name – Descriptive name for this layer.
Returns: A branch layer built from the given sublayers.
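In contrast to Parallel, every sublayer of Branch reads from the front of the same inputs. A plain-Python model, using a hypothetical run_branch helper, (n_in, fn) pairs for layers, and strings in place of tensors:

```python
def run_branch(layers, inputs):
    """Model of Branch: each layer reads its n_in inputs from the front of `inputs`."""
    outputs = []
    for layer in layers:
        if layer is None:  # one-arg no-op: copy the first input
            outputs.append(inputs[0])
        else:
            n_in, fn = layer
            outputs.extend(fn(*inputs[:n_in]))
    return outputs


F = (1, lambda a: (f'F({a})',))
G = (3, lambda a, b, c: (f'G({a},{b},{c})',))
H = (2, lambda a, b: (f'h1({a},{b})', f'h2({a},{b})'))

fgh = run_branch([F, G, H], ['a', 'b', 'c'])  # ['F(a)', 'G(a,b,c)', 'h1(a,b)', 'h2(a,b)']
keep_and_apply = run_branch([None, F], ['x'])  # ['x', 'F(x)']
```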

trax.layers.combinators.
Residual
(*layers, shortcut=None)¶ Wraps a series of layers with a residual connection.
Parameters:  *layers – One or more layers, to be applied in series.
 shortcut – If None (the usual case), the Residual layer computes the element-wise sum of the stack-top input with the output of the layer series. If specified, the shortcut layer applies to a copy of the inputs and (element-wise) adds its output to the output from the main layer series.
Returns: A layer representing a residual connection paired with a layer series.
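For the single-input, single-output case, the computation reduces to "run the series, then add a shortcut". The NumPy sketch below is a hypothetical model (residual is not Trax’s function) showing both the default identity shortcut and an explicit one.

```python
import numpy as np


def residual(layers, x, shortcut=None):
    """Model of Residual: series(x) + shortcut(x), identity shortcut by default."""
    y = x
    for f in layers:  # apply the layer series
        y = f(y)
    s = x if shortcut is None else shortcut(x)
    return y + s  # element-wise sum


x = np.array([1.0, 2.0])
series = [lambda v: 2 * v, lambda v: v + 1]
out = residual(series, x)                                # (2x + 1) + x = [4., 7.]
out_shortcut = residual(series, x, shortcut=lambda v: -v)  # (2x + 1) - x = [2., 3.]
```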

trax.layers.combinators.
Select
(indices, n_in=None, name=None)¶ Copies, reorders, or deletes stack elements according to indices.
Parameters:  indices – A list or tuple of 0based indices to select elements relative to the top of the stack.
 n_in – Number of input elements to pop from the stack, and replace with those specified by indices. If not specified, its value will be calculated as max(indices) + 1.
 name – Descriptive name for this layer.
Returns: Tensors, matching the number selected (n_out = len(indices)). Specifically:
 n_out = 0: an empty tuple
 n_out = 1: one tensor (NOT wrapped in a tuple)
 n_out > 1: a tuple of tensors, with n_out items
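Both the stack effect and the output packaging can be modeled in plain Python, with index 0 as the top of the stack. select_outputs and apply_select are hypothetical helpers; the empty-indices case (pop nothing unless n_in is given) is an assumption of this sketch.

```python
def select_outputs(inputs, indices):
    """Model of Select's output packaging."""
    picked = tuple(inputs[i] for i in indices)
    if len(picked) == 1:
        return picked[0]  # n_out == 1: a bare value, not a 1-tuple
    return picked         # n_out == 0 gives (), otherwise an n_out-tuple


def apply_select(stack, indices, n_in=None):
    """Model of Select on a whole stack: pop n_in items, push the selection."""
    if n_in is None:
        n_in = max(indices, default=-1) + 1  # no indices: pop nothing
    popped, rest = stack[:n_in], stack[n_in:]
    return [popped[i] for i in indices] + rest


stack = ['a', 'b', 'c']
copied = apply_select(stack, [0, 0])         # ['a', 'a', 'b', 'c']
swapped = apply_select(stack, [1, 0])        # ['b', 'a', 'c']
dropped = apply_select(stack, [0], n_in=2)   # ['a', 'c']: 'b' is deleted
```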

trax.layers.combinators.
Drop
()¶ Drops the top stack element.

trax.layers.combinators.
Dup
()¶ Duplicates (copies) the top element on the data stack.

trax.layers.combinators.
Swap
()¶ Swaps the top two stack elements.
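On a stack modeled as a Python list with the top element first, the three operations are one-liners. These functions are hypothetical illustrations; the Select equivalences in the comments are stated for intuition, not as a description of how Trax defines Drop, Dup and Swap.

```python
def drop(stack):
    return stack[1:]                 # like Select([], n_in=1)


def dup(stack):
    return stack[:1] + stack         # like Select([0, 0])


def swap(stack):
    return stack[1::-1] + stack[2:]  # like Select([1, 0])


stack = ['a', 'b', 'c']
```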

trax.layers.combinators.
SerialWithSideOutputs
(layers, n_side_outputs=1)¶ Serial layer with side outputs.
This layer makes it easier to manage the stack when layers have side outputs.
In the simplest case of layers with n_in=1, n_out=2 and with n_side_outputs=1, this layer runs the following computation on x:
side_outputs = []
for i in range(len(layers)):
  x, side_output = layers[i](x)
  side_outputs.append(side_output)
return [x] + side_outputs
In the general case of layers with variable n_in and n_out and n_side_outputs being a list of N integers, it does the following:
side_outputs = []
for i in range(N):
  res = layers[i](cur_stack)  # removes layers[i].n_in items from the stack
  n_keep = len(res) - n_side_outputs[i]
  cur_stack.extend(res[:n_keep])  # put the main outputs back on the stack
  side_outputs.extend(res[n_keep:])  # set the last n_side_outputs[i] outputs aside
return cur_stack + side_outputs
Parameters:  layers – a list of layers to execute
 n_side_outputs – an int or a list of ints, how many outputs of each layer to put aside
Returns: A layer that performs the above computation.

trax.layers.combinators.
FlattenList
()¶ Flatten lists.

trax.layers.combinators.
Add
()¶ Adds two tensors.

trax.layers.combinators.
SubtractTop
()¶ Subtracts the first tensor from the second.

trax.layers.combinators.
Multiply
()¶ Multiplies two tensors.

trax.layers.combinators.
Gate
()¶ Returns a gating layer on a (memory, gate, candidate) tuple.
Final update is memory * gate + (1 - gate) * candidate.
This gating equation is also known as a Highway Network gate. Highway Networks: https://arxiv.org/abs/1505.00387
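The gating formula can be checked element by element in NumPy. gate is a hypothetical helper; the gate values here are chosen by hand in [0, 1], whereas in a model they would typically come from a sigmoid (an assumption, not stated above).

```python
import numpy as np


def gate(memory, gate_values, candidate):
    """memory * gate + (1 - gate) * candidate, element-wise."""
    return memory * gate_values + (1.0 - gate_values) * candidate


memory = np.array([1.0, 1.0, 1.0])
candidate = np.array([5.0, 5.0, 5.0])
g = np.array([1.0, 0.5, 0.0])  # 1 keeps memory, 0 takes the candidate
out = gate(memory, g, candidate)  # [1., 3., 5.]
```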

class
trax.layers.combinators.
Cache
(layer)¶ Bases:
trax.layers.base.Layer
Applies a layer on the first run and returns the outputs on next calls.
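The "compute once, then replay" contract can be modeled with a small Python class. CacheSketch is hypothetical and keeps the cached result in an ordinary attribute; it does not reflect how Trax stores the cached outputs.

```python
class CacheSketch:
    """Hypothetical model of Cache: run fn on the first call, replay afterwards."""

    def __init__(self, fn):
        self._fn = fn
        self._has_value = False
        self._value = None

    def __call__(self, *inputs):
        if not self._has_value:
            self._value = self._fn(*inputs)
            self._has_value = True
        return self._value  # later calls ignore their inputs


n_runs = []
cached_double = CacheSketch(lambda v: n_runs.append(v) or v * 2)
first = cached_double(3)     # runs the function: 6
second = cached_double(100)  # replays the first result: 6
```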

__init__
(layer)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

sublayer
¶ Returns the unique sublayer managed by this layer.

state
¶ Returns a tuple containing this layer’s state; may be empty.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.
Parameters: inputs – Tensors required by the sublayer. Returns: Tensors resulting from running the sublayer the first time.


class
trax.layers.combinators.
BatchLeadingAxes
(layer, n_last_axes_to_keep=1)¶ Bases:
trax.layers.base.Layer
Applies a layer after flattening all but n_last_axes_to_keep to batch.
This can be used to make layers accept an arbitrary number of leading axes (dimensions) as batch. For example, a Convolution layer may normally only operate on tensors of shape [B, W, H, C]. In this case, the layer
BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3)
will operate on any tensor […, W, H, C] and treat the leading axes as batch.
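The reshaping trick can be sketched in NumPy: fold every leading axis into one batch axis, call the wrapped function, then unfold. batch_leading_axes and dense_2d are hypothetical names, and a rank-checked matrix multiply stands in for a layer that only accepts [batch, features] inputs.

```python
import numpy as np


def batch_leading_axes(fn, x, n_last_axes_to_keep=1):
    """Model of BatchLeadingAxes: fold all leading axes into one batch axis."""
    split = x.ndim - n_last_axes_to_keep
    lead, kept = x.shape[:split], x.shape[split:]
    y = fn(x.reshape((-1,) + kept))       # fn sees shape (batch, *kept)
    return y.reshape(lead + y.shape[1:])  # restore the leading axes


w = np.ones((3, 4))


def dense_2d(v):
    if v.ndim != 2:
        raise ValueError('this toy layer only accepts [batch, features]')
    return v @ w


x = np.arange(105.0).reshape(5, 7, 3)
y = batch_leading_axes(dense_2d, x)  # shape (5, 7, 4)
```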

__init__
(layer, n_last_axes_to_keep=1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

sublayer
¶ Returns the unique sublayer managed by this layer.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.


trax.layers.combinators.
Bidirectional
(forward_layer, axis=1, merge_layer=Concatenate())¶ Bidirectional combinator for RNNs.
Parameters:  forward_layer – A layer, such as trax.layers.LSTM or trax.layers.GRU.
 axis – The time axis of the inputs. Default value is 1.
 merge_layer – A combinator used to combine outputs of the forward and backward RNNs. Default value is ‘trax.layers.Concatenate’.
Example
Bidirectional(RNN(n_units=8))
Returns: The Bidirectional combinator for RNNs.
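The data flow of a bidirectional RNN can be sketched in NumPy: one function runs forward in time, a second runs on the time-reversed input and has its outputs flipped back, and the two results are concatenated on the feature axis. The helper bidirectional and the running-sum "RNN" are hypothetical stand-ins; how Trax constructs the backward copy of forward_layer is not shown here.

```python
import numpy as np


def bidirectional(forward_fn, backward_fn, x, axis=1):
    """Run forward_fn in time order and backward_fn in reverse, then concatenate."""
    fwd = forward_fn(x)
    bwd = np.flip(backward_fn(np.flip(x, axis=axis)), axis=axis)
    return np.concatenate([fwd, bwd], axis=-1)


# A toy causal "RNN": running sum over the time axis of a [batch, time, features] input.
running_sum = lambda v: np.cumsum(v, axis=1)

x = np.array([[[1.0], [2.0], [3.0]]])  # batch=1, time=3, features=1
y = bidirectional(running_sum, running_sum, x)
# Each step sees the past (first feature) and the future (second feature).
```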

trax.layers.combinators.
inputs_from_stack
(stack, n)¶ Returns n inputs from stack.

trax.layers.combinators.
outputs_onto_stack
(outputs, stack, n)¶ Returns the new stack after removing n items and pushing outputs there.
convolution¶
Trax convolution layers.

class
trax.layers.convolution.
Conv
(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Bases:
trax.layers.base.Layer
Layer constructor function for a general convolution layer.

__init__
(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.convolution.
CausalConv
(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Bases:
trax.layers.convolution.Conv
Causal (masked) convolution for [batch x time x depth] sequences.
Maintains causality along time axis. Used in language modeling tasks.
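Causality can be illustrated with a single-channel, one-dimensional sketch: left-pad the sequence with kernel_width - 1 zeros so that output step t only sees inputs up to t. causal_conv1d is hypothetical and ignores the batch and depth axes (and depth mixing) that CausalConv handles; left padding is one standard way to build a causal convolution, assumed here rather than taken from Trax’s source.

```python
import numpy as np


def causal_conv1d(x, kernel):
    """Causal 1-D convolution for x of shape [time] and kernel of shape [width]."""
    width = len(kernel)
    padded = np.concatenate([np.zeros(width - 1), x])  # pad on the left only
    # Cross-correlation (no kernel flip), as is conventional for conv layers.
    return np.array([padded[t:t + width] @ kernel for t in range(len(x))])


x = np.array([1.0, 2.0, 3.0, 4.0])
k = np.array([0.5, 0.25, 1.0])  # the last tap weights the current time step
y = causal_conv1d(x, k)         # [1.0, 2.25, 4.0, 5.75]
```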

__init__
(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.


trax.layers.convolution.
Conv1d
(filters, kernel_size, stride=1, padding='VALID', kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶

class
trax.layers.convolution.
CausalDepthwiseConv
(kernel_size=3, kernel_initializer=<function ScaledInitializer.<locals>.Init>, use_bfloat16=False)¶ Bases:
trax.layers.base.Layer
A causal depthwise convolution layer.

__init__
(kernel_size=3, kernel_initializer=<function ScaledInitializer.<locals>.Init>, use_bfloat16=False)¶ Returns a causal depthwise convolution layer.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.

core¶
Core layer types and key functions used by various layers.

class
trax.layers.core.
Dense
(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, use_bfloat16=False)¶ Bases:
trax.layers.base.Layer
A dense (a.k.a. fully-connected, affine) layer.
Dense layers are the prototypical example of a trainable layer, i.e., a layer with trainable weights. Each node in a dense layer computes a weighted sum of all node values from the preceding layer and adds to that sum a node-specific bias term. The full layer computation is expressed compactly in linear algebra as an affine map y = Wx + b, where W is a matrix and y, x, and b are vectors. The layer is trained, or “learns”, by updating the values in W and b.
Less commonly, a dense layer can omit the bias term and be a pure linear map: y = Wx.
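In array code the affine map is usually written for a batch of row vectors as x @ W + b, which is y = Wx + b with W stored transposed (shape [m, n]). The NumPy sketch below assumes that convention; the shapes and random values are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 2                   # input width m, n_units n
W = rng.normal(size=(m, n))   # rows index inputs, columns index units
b = rng.normal(size=(n,))

x = rng.normal(size=(5, m))   # a batch of five R^m vectors
y = x @ W + b                 # affine map applied to each row: shape (5, n)
y_linear = x @ W              # the use_bias=False variant: y = Wx
```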

__init__
(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, use_bfloat16=False)¶ Returns a dense (fully connected) layer of width n_units.
A dense layer maps collections of R^m vectors to R^n, where n (= n_units) is fixed at layer creation time, and m is set at layer initialization time.
Parameters:  n_units – Number of nodes in the layer, also known as the width of the layer.
 kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
 bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
 use_bias – If True, compute an affine map y = Wx + b; else compute a linear map y = Wx.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input, except the final dimension is the layer’s n_units value.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.
Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.


class
trax.layers.core.
Embedding
(vocab_size, d_feature, use_bfloat16=False, kernel_initializer=<function ScaledInitializer.<locals>.Init>)¶ Bases:
trax.layers.base.Layer
Trainable layer that maps discrete tokens/IDs to vectors.
Embedding layers are commonly used to map discrete data, like words in NLP, into vectors. Here is a canonical example:
vocab_size = 5
word_ids = np.array([1, 2, 3, 4], dtype=np.int32)  # word_ids < vocab_size
embedding_layer = tl.Embedding(vocab_size, 32)
embedding_layer.init(trax.shapes.signature(word_ids))
embedded = embedding_layer(word_ids)  # embedded.shape = (4, 32)

__init__
(vocab_size, d_feature, use_bfloat16=False, kernel_initializer=<function ScaledInitializer.<locals>.Init>)¶ Returns an embedding layer with given vocabulary size and vector size.
The layer clips input values (token IDs) to the range [0, vocab_size). That is, negative token IDs all clip to 0 before being mapped to a vector, and token IDs with value vocab_size or greater all clip to vocab_size - 1 before being mapped to a vector.
Parameters:  vocab_size – Size of the input vocabulary. The layer will assign a unique vector to each id in range(vocab_size).
 d_feature – Dimensionality/depth of the output vectors.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.
 kernel_initializer – Function that creates (random) initial vectors for the embedding.
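The clipping rule can be reproduced with NumPy’s np.clip followed by row indexing into an embedding table. The table values and the embed helper are illustrative (a deterministic arange instead of a random initializer).

```python
import numpy as np

vocab_size, d_feature = 5, 3
table = np.arange(vocab_size * d_feature, dtype=np.float32).reshape(vocab_size, d_feature)


def embed(ids):
    """Look up rows of `table`, clipping ids into [0, vocab_size) first."""
    return table[np.clip(ids, 0, vocab_size - 1)]


out = embed(np.array([-3, 0, 4, 9]))
# -3 clips to 0 and 9 clips to 4, so out equals table[[0, 0, 4, 4]].
```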

forward
(x)¶ Returns embedding vectors corresponding to input token IDs.
Parameters: x – Tensor of token IDs. Returns: Tensor of embedding vectors.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.


class
trax.layers.core.
Dropout
(rate=0.0, shared_axes=None, mode='train')¶ Bases:
trax.layers.base.Layer
A layer that stochastically ignores a subset of inputs each training step.
In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).
The parameter shared_axes allows you to specify a list of axes on which the mask will be shared: the dropout mask has size 1 on those axes and is broadcast along them. Sharing reduces randomness, but can save memory.
This layer is active only during training (mode='train'). In other circumstances it is a no-op.
Originally introduced in the paper “Dropout: A Simple Way to Prevent Neural Networks from Overfitting” available under the following link: https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf
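The scaling and mask-sharing behavior described above can be sketched in NumPy as "inverted" dropout. The `dropout` function, its `rng` and `train` arguments are hypothetical and are not the Trax API; the sketch only shows how a size-1 mask axis is broadcast and how survivors are rescaled.

```python
import numpy as np

def dropout(x, rate, rng, shared_axes=None, train=True):
    """Hypothetical inverted-dropout helper (sketch, not the Trax API)."""
    if not train or rate == 0.0:
        return x  # outside training (or with rate 0) the input passes through unchanged
    mask_shape = list(x.shape)
    for axis in shared_axes or []:
        mask_shape[axis] = 1  # size 1 here, so one mask value is broadcast along this axis
    keep = rng.random(tuple(mask_shape)) >= rate  # each entry kept with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)  # survivors scaled by 1 / (1 - rate)

rng = np.random.default_rng(0)
x = np.ones((4, 6), dtype=np.float32)
y = dropout(x, rate=0.5, rng=rng, shared_axes=[0])
print(y)  # every row shows the same pattern of 0.0 and 2.0
```

Because the mask has shape (1, 6), all four rows drop the same columns; this is the reduced randomness (and smaller mask) that sharing trades for.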

__init__
(rate=0.0, shared_axes=None, mode='train') Creates a dropout layer with the given target drop rate.
Parameters:  rate – Stochastic rate (probability) for dropping an activation value from the preceding layer (setting it to zero).
 shared_axes – List of axes on which the mask is shared.
 mode – If ‘train’, this layer will perform dropout; else, it will pass all values through unaltered.

init_weights_and_state
(input_signature) Sets layer-specific internal state.

forward
(x) Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of activations. Returns: Tensor of same shape and dtype as the input.


class
trax.layers.core.
Weights
(initializer, shape=(), use_bfloat16=False) Bases:
trax.layers.base.Layer
Learnable weights as a layer.
It takes no input and returns a single tensor: weights.

__init__
(initializer, shape=(), use_bfloat16=False) Returns a learnable tensor of the given shape.
Parameters:  initializer – Function taking shape and rng as arguments.
 shape – Shape of the learnable weights.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.

forward
(x) Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor with previously specified shape and dtype.

init_weights_and_state
(input_signature) Returns newly initialized weights for this layer.
Weights is a single w tensor with previously specified shape.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.


trax.layers.core.
PrintShape
(n_in=1, msg='') Prints the shapes of n_in inputs and returns them unchanged.

class
trax.layers.core.
SummaryImage
(name, n_in, num_summaries=5, recover_fn=None) Bases:
trax.layers.base.Layer
A layer receiving a tensor, and adding it to TensorBoard as an image.
It takes an input and returns it unchanged. It stores this input as a state to be used as a metric in TensorBoard. It converts a tensor to a scalar by running a given aggregation function (mean by default). On TensorBoard, results for each device will be reported separately.

__init__
(name, n_in, num_summaries=5, recover_fn=None) Takes a tensor and returns it.
Parameters:  name – Name of the metric to be reported.
 n_in – Number of inputs.
 num_summaries – Number of images to show.
 recover_fn – Function for converting a tensor to a displayable image.

forward
(x) Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor with previously specified shape and dtype.

init_weights_and_state
(input_signature) Returns newly initialized weights for this layer.
Weights is a single w tensor with previously specified shape.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.


class
trax.layers.core.
SummaryScalar
(name, aggregation_fun=<sphinx.ext.autodoc.importer._MockObject object>) Bases:
trax.layers.base.Layer
A layer receiving a tensor, and adding it to TensorBoard as a scalar.
It takes an input and returns it unchanged. It stores this input as a state to be used as a metric in TensorBoard. It converts a tensor to a scalar by running a given aggregation function (mean by default). On TensorBoard, results for each device will be reported separately.

__init__
(name, aggregation_fun=<sphinx.ext.autodoc.importer._MockObject object>) Takes a tensor and returns it.
Parameters:  name – Name of the metric to be reported.
 aggregation_fun – Aggregation function to be used.

forward
(x) Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor with previously specified shape and dtype.

init_weights_and_state
(input_signature) Returns newly initialized weights for this layer.
Weights is a single w tensor with previously specified shape.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.


class
trax.layers.core.
RandomUniform
(min_val=0.0, max_val=1.0, shape=(), dtype=<sphinx.ext.autodoc.importer._MockObject object>, sync=False) Bases:
trax.layers.base.Layer
Layer returning a tensor with random values distributed uniformly.

__init__
(min_val=0.0, max_val=1.0, shape=(), dtype=<sphinx.ext.autodoc.importer._MockObject object>, sync=False) Layer returning a tensor with random values distributed uniformly.
Parameters:  min_val – Lower end of uniform distribution.
 max_val – Upper end of uniform distribution.
 shape – Shape of the tensor to return. Values are sampled independently.
 dtype – Type of value to return.
 sync – Whether to synchronise rng across devices.

forward
(xs) Executes this layer as part of a forward pass through the model.
Parameters: xs – Unused tensors. Returns: Random uniform tensor of the shape and type specified in constructor.


class
trax.layers.core.
LocallyConnected1d
(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID') Bases:
trax.layers.base.Layer
Locally-connected layer for 1D inputs.
The LocallyConnected1d layer applies a different set of filters to each patch of the input. This is similar to a convolution layer, except that a locally-connected layer uses a different set of weights for each patch.
The size of each patch is determined by the kernel size. The stride is currently fixed at one and cannot be changed. For an input of shape (…, L, D), the output shape is (…, L, filters) for paddings ‘SAME’ and ‘WRAP’, and (…, L - kernel_size + 1, filters) for padding ‘VALID’; here L is the number of “pixels” or “steps” in the input and D is the size of the embedding.
Note that, since the weights for different patches are not shared, the number of “pixels” or “steps” cannot change after calling init_weights_and_state. This is because each “pixel” is assigned its own set of weights.
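To make the ‘VALID’ shape rule and the per-patch weights concrete, here is a NumPy sketch for a single unbatched input. The function name and the weight layout (one (kernel_size, D, filters) block per output position) are assumptions for illustration, not necessarily how Trax stores its weights.

```python
import numpy as np

def locally_connected_1d_valid(x, w, b):
    """Hypothetical sketch. x: (L, D); w: (L_out, kernel_size, D, filters); b: (L_out, filters)."""
    n_out, kernel_size = w.shape[0], w.shape[1]
    assert n_out == x.shape[0] - kernel_size + 1  # 'VALID' output length
    # patches[i] = x[i : i + kernel_size]; stride is one, so patches overlap.
    patches = np.stack([x[i:i + kernel_size] for i in range(n_out)])  # (L_out, k, D)
    # Unlike a convolution, patch i is multiplied by its own weights w[i].
    return np.einsum('pkd,pkdf->pf', patches, w) + b

L, D, k, filters = 10, 4, 3, 5
rng = np.random.default_rng(0)
x = rng.normal(size=(L, D))
w = rng.normal(size=(L - k + 1, k, D, filters))
b = np.zeros((L - k + 1, filters))
y = locally_connected_1d_valid(x, w, b)
print(y.shape)  # (8, 5)
```

If every w[i] were the same block, the result would reduce to an ordinary 1D convolution; keeping them separate is also why the input length cannot change after initialization.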

__init__
(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID') Returns a locally-connected conv-like layer.
Parameters:  filters – Number of output filters in the convolution.
 kernel_size – Length of the convolution window; must be an odd number.
 kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
 bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
 use_bias – If True, the layer uses a bias vector.
 padding – The type of padding to use; must be ‘VALID’, ‘SAME’, or ‘WRAP’.

forward
(x) Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input, except that the final dimension is the layer’s filters value, and the second-to-last dimension shrinks if ‘VALID’ padding is used with a kernel_size greater than one.

init_weights_and_state
(input_signature) Randomly initializes this layer’s weights.
Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.


trax.layers.core.
Flatten
(n_axes_to_keep=1) Returns a layer that combines one or more trailing axes of a tensor.
Flattening keeps all the values of the input tensor, but reshapes it by collapsing one or more trailing axes into a single axis. For example, a Flatten(n_axes_to_keep=2) layer would map a tensor with shape (2, 3, 5, 7, 11) to the same values with shape (2, 3, 385).
Parameters: n_axes_to_keep – Number of leading axes to leave unchanged when reshaping; collapse only the axes after these.
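In NumPy terms the reshape described above amounts to keeping the leading shape and letting -1 absorb the rest. The `flatten` helper is a hypothetical sketch, not the Trax layer itself.

```python
import numpy as np

def flatten(x, n_axes_to_keep=1):
    """Hypothetical sketch: collapse every axis after the first n_axes_to_keep into one."""
    return x.reshape(x.shape[:n_axes_to_keep] + (-1,))

x = np.zeros((2, 3, 5, 7, 11))
print(flatten(x, n_axes_to_keep=2).shape)  # (2, 3, 385), since 5 * 7 * 11 = 385
print(flatten(x).shape)                    # (2, 1155)
```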

trax.layers.core.
LogSoftmax
(axis=-1) Returns a layer that applies log softmax along one tensor axis.
Note that the implementation actually computes x - LogSumExp(x), which is mathematically equal to LogSoftmax(x).
LogSoftmax acts on a group of values and normalizes them to look like a set of log probability values. (Probability values must be nonnegative, and as a set must sum to 1. A group of log probability values can be seen as the natural logarithm function applied to a set of probability values.)
Parameters: axis – Axis along which values are grouped for computing log softmax.
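The identity LogSoftmax(x) = x - LogSumExp(x) can be checked with a NumPy sketch. Subtracting the maximum before exponentiating is one standard way to keep LogSumExp from overflowing; it is our choice for the sketch, not a statement about Trax internals.

```python
import numpy as np

def log_sum_exp(x, axis=-1):
    """Stable log(sum(exp(x))) along `axis`, keeping the reduced axis for broadcasting."""
    m = np.max(x, axis=axis, keepdims=True)  # shift so the largest exponent is exp(0)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def log_softmax(x, axis=-1):
    """x - LogSumExp(x), mathematically equal to log(softmax(x))."""
    return x - log_sum_exp(x, axis=axis)

x = np.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
lp = log_softmax(x)
print(np.exp(lp).sum(axis=-1))  # each row of probabilities sums to 1
```

A naive log(exp(x) / sum(exp(x))) would overflow on the second row; the shifted form returns log(1/3) for each entry.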

trax.layers.core.
LogSumExp
(axis=-1) Returns a layer that computes log(sum(exp(x))) along one tensor axis.
Parameters: axis – Axis along which values are grouped for computing logsumexp.

trax.layers.core.
Softmax
(axis=-1) Returns a layer that applies softmax along one tensor axis.
Softmax acts on a group of values and normalizes them to look like a set of probability values. (Probability values must be nonnegative, and as a set must sum to 1.)
Parameters: axis – Axis along which values are grouped for computing softmax.

trax.layers.core.
ToFloat
() Returns a layer that changes the dtype of a tensor to float32.

trax.layers.core.
Mean
(axis=-1, keepdims=False) Returns a layer that computes mean values using one tensor axis.
Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group. The resulting values can either remain in their own size 1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.
Parameters:  axis – Axis along which values are grouped for computing a mean.
 keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.

trax.layers.core.
Min
(axis=-1, keepdims=False) Returns a layer that applies min along one tensor axis.
Parameters:  axis – Axis along which values are grouped for computing minimum.
 keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.

trax.layers.core.
Max
(axis=-1, keepdims=False) Returns a layer that applies max along one tensor axis.
Parameters:  axis – Axis along which values are grouped for computing maximum.
 keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.

trax.layers.core.
Sum
(axis=None, keepdims=False) Returns a layer that computes sums using one tensor axis.
Sum uses one tensor axis to form groups of values and replaces each group with the sum of that group. The resulting sum values can either remain in their own size 1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.
Parameters:  axis – Axis along which values are grouped for computing a sum; if None, compute sum over all elements in tensor.
 keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
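The axis and keepdims conventions described for Sum (and for Mean, Min, and Max) mirror NumPy's reductions, so a short NumPy session shows the resulting shapes. Using NumPy as a stand-in is an assumption; the Trax layers themselves are not called here.

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)                  # [[0., 1., 2.], [3., 4., 5.]]
per_row = np.sum(x, axis=-1)                      # shape (2,): the summed axis is removed
per_row_kept = np.sum(x, axis=-1, keepdims=True)  # shape (2, 1): a size-1 axis remains
total = np.sum(x)                                 # axis=None: 15.0, sum over all elements
row_means = np.mean(x, axis=-1)                   # [1., 4.]
print(per_row.shape, per_row_kept.shape, total, row_means)
```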

trax.layers.core.
ThresholdToBinary
(threshold=0.5) Returns a layer that thresholds inputs to yield outputs in {0, 1}.

trax.layers.core.
ArgMax
(axis=-1) Returns a layer that calculates argmax along the given axis.

trax.layers.core.
Negate
() Returns a layer that computes the elementwise negation of a tensor.

trax.layers.core.
StopGradient
() Returns an identity layer with a stop gradient.

trax.layers.core.
one_hot
(x, n_categories, dtype=<sphinx.ext.autodoc.importer._MockObject object>) Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

trax.layers.core.
log_softmax
(x, axis=-1) Transforms activation vectors to log-probability vectors.
Log probability vectors are derived by, in effect, applying softmax to raw activation vectors and then applying log elementwise. The actual implementation uses a mathematically valid simplification of this.
Parameters:  x – An ndarray with activation vectors along the given axis.
 axis – Axis along which values are grouped for computing log softmax.
Returns: An ndarray containing log-probability vectors derived from the raw activation vectors in x.

trax.layers.core.
log_gaussian_pdf
(x, mu, sigma) Returns log N(x | mu, sigma).
Parameters:  x – <tbd>
 mu – <tbd>
 sigma – <tbd>

trax.layers.core.
log_gaussian_diag_pdf
(x, mu, diag_sigma) Returns log N(x | mu, eye(diag_sigma)).
Parameters:  x – <tbd>
 mu – <tbd>
 diag_sigma – <tbd>

trax.layers.core.
multigaussian_loss
(preds, targets, ngauss=1) Returns a mixture-of-Gaussians loss.
Parameters:  preds – <tbd>
 targets – <tbd>
 ngauss – <tbd>

trax.layers.core.
logsoftmax_sample
(log_probs, temperature=1.0) Returns a sample from a log-softmax output, with temperature.
Parameters:  log_probs – Logarithms of probabilities (often coming from LogSoftmax)
 temperature – For scaling before sampling (1.0 = default, 0.0 = pick argmax)
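One common way to sample from log-probabilities with a temperature is the Gumbel-max trick, sketched below in NumPy. The helper and its `rng` argument are hypothetical, and this page does not say which sampling method Trax uses; the sketch only shows why temperature 0.0 reduces to argmax.

```python
import numpy as np

def logsoftmax_sample(log_probs, temperature=1.0, rng=None):
    """Sketch: sample indices from log-probabilities with the Gumbel-max trick.

    argmax(log_probs + temperature * g) with g ~ Gumbel(0, 1) samples from
    softmax(log_probs / temperature); temperature=0.0 reduces to argmax.
    """
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)  # avoid log(0)
    gumbel = -np.log(-np.log(u))
    return np.argmax(log_probs + temperature * gumbel, axis=-1)

log_probs = np.log(np.array([[0.1, 0.7, 0.2]]))
print(logsoftmax_sample(log_probs, temperature=0.0))  # [1]
rng = np.random.default_rng(0)
samples = logsoftmax_sample(np.repeat(log_probs, 10000, axis=0), rng=rng)
print(np.bincount(samples, minlength=3) / len(samples))  # roughly [0.1 0.7 0.2]
```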
initializers
Trax initializers.

trax.layers.initializers.
InitializerFromFile
(path) Loads parameters from a .npy file.

trax.layers.initializers.
RandomNormalInitializer
(stddev=0.01) Returns an initializer for random normal coefficients.

trax.layers.initializers.
RandomUniformInitializer
(lim=1.0) Returns an initializer for random uniform coefficients.

trax.layers.initializers.
ScaledInitializer
(out_dim, in_dim, scale, mode, distribution) Returns an initializer that adjusts its scale based on weight shapes.

trax.layers.initializers.
GlorotNormalInitializer
(out_dim=-1, in_dim=-2, scale=1.0) Returns an initializer for random Glorot-scaled coefficients.

trax.layers.initializers.
GlorotUniformInitializer
(out_dim=-1, in_dim=-2, scale=1.0) Returns an initializer for random uniform Glorot-scaled coefficients.
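As a rough illustration of Glorot scaling, here is a NumPy sketch of the standard Glorot/Xavier uniform rule, limit = sqrt(6 * scale / (fan_in + fan_out)). The helper `glorot_uniform` is hypothetical. It assumes that in_dim and out_dim index the fan-in and fan-out axes of the weight shape, and that scale multiplies the target variance, as in common variance-scaling initializers; Trax's exact formula is not documented on this page.

```python
import numpy as np

def glorot_uniform(shape, rng, scale=1.0, in_dim=-2, out_dim=-1):
    """Hypothetical Glorot/Xavier uniform initializer (sketch, not Trax's code)."""
    fan_in, fan_out = shape[in_dim], shape[out_dim]
    # U(-limit, limit) has variance limit**2 / 3 = 2 * scale / (fan_in + fan_out).
    limit = np.sqrt(6.0 * scale / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)

rng = np.random.default_rng(0)
w = glorot_uniform((256, 128), rng)
print(w.var(), 2.0 / (256 + 128))  # the two values should be close
```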

trax.layers.initializers.
LeCunNormalInitializer
(out_dim=-1, in_dim=-2, scale=1.0) Returns an initializer for random LeCun-scaled coefficients.

trax.layers.initializers.
LeCunUniformInitializer
(out_dim=-1, in_dim=-2, scale=1.0) Returns an initializer for random uniform LeCun-scaled coefficients.

trax.layers.initializers.
KaimingNormalInitializer
(out_dim=-1, in_dim=-2, param=0.0) Returns an initializer for random Kaiming-scaled coefficients.

trax.layers.initializers.
KaimingUniformInitializer
(out_dim=-1, in_dim=-2, param=0.0) Returns an initializer for random uniform Kaiming-scaled coefficients.

trax.layers.initializers.
OrthogonalInitializer
(stddev=1.0) Returns an orthogonal initializer.

trax.layers.initializers.
AtariConvInit
(kernel_shape, rng, dtype=<sphinx.ext.autodoc.importer._MockObject object>) The standard init for Conv layers and Atari.
metrics
Layers for computing loss functions and evaluation metrics.
A metric layer computes a scalar value from two or three ndarray inputs:
 model outputs: Batch of predicted values (typically vectors).
 targets: Batch of target values (e.g., categories or vectors).
 weights: Float values that allow for uneven weighting of batch items, sequence positions, or vector components when computing an overall scalar value for the batch.
Most metric computations take into account the items that make up a batch. For each item in a batch, a raw metric value is computed by comparing (item-wise) the model output to the target value. These item-wise values are then combined into a single scalar for the batch by a function such as sum, average, or weighted average. For example:
 CategoryAccuracy: Treat model output as vectors whose components correspond to the possible categories; measure a vector as correct (value 1) if its largest component is the target category, else as incorrect (value 0). The accuracy for the batch is then the average across vectors of these 1’s and 0’s.
 CategoryCrossEntropy: Treat model output and target values as the source of two probability distributions; measure the cross-entropy of the model’s predicted distribution relative to the (assumed true) target distribution. The scalar value for the batch is then the average of the item-wise cross-entropy values.

trax.layers.metrics.
CategoryAccuracy
() Returns a layer that computes category prediction accuracy.
The layer takes two inputs:
 A batch of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index (category) having the highest probability.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
The predicted category from each vector is the index of the highest-valued vector component. The layer returns the accuracy of these predictions averaged over the batch.

trax.layers.metrics.
WeightedCategoryAccuracy
() Returns a layer that computes a weighted category prediction accuracy.
The layer takes three inputs:
 A batch of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index (category) having the highest probability.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
 A batch of weights, which matches or can be broadcast to match the shape of the target ndarray. This arg can give uneven weighting to different items in the batch (depending, for instance, on the item’s target category).
The predicted category from each vector is the index of the highest-valued vector component. The layer returns a weighted average accuracy of these predictions.
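The computation behind CategoryAccuracy and WeightedCategoryAccuracy can be sketched in NumPy: take the argmax over the last axis, score each item 1 or 0, then average with or without weights. The helper is hypothetical, and normalizing by the sum of the weights is our reading of “weighted average”.

```python
import numpy as np

def weighted_category_accuracy(activations, targets, weights=None):
    """Hypothetical sketch: argmax over the last axis, then a (weighted) mean of hits."""
    correct = (np.argmax(activations, axis=-1) == targets).astype(np.float32)
    if weights is None:  # unweighted, as in CategoryAccuracy
        return float(np.mean(correct))
    weights = np.broadcast_to(weights, correct.shape)
    return float(np.sum(correct * weights) / np.sum(weights))

activations = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
targets = np.array([1, 1, 1])
print(weighted_category_accuracy(activations, targets))                             # 2 of 3 correct
print(weighted_category_accuracy(activations, targets, np.array([1.0, 0.0, 1.0])))  # 1.0
```

Giving the misclassified second item weight 0 removes it from the average, so the weighted accuracy rises to 1.0.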

trax.layers.metrics.
CategoryCrossEntropy
(label_smoothing=None) Returns a layer that computes cross-entropy from activations and integers.
The layer takes two inputs:
 A batch of activation vectors. The components in a given vector should be pre-softmax activations (mappable to a probability distribution via softmax). For performance reasons, the softmax and cross-entropy computations are combined inside the layer.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
To compute cross-entropy per batch item, the layer derives probability distributions:
 from model output (vectors): \(\ q = \text{softmax}(v)\)
 from target categories (integers): \(\ p = \text{one_hot}(n)\) or \(p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}\), where \(\varepsilon\) is the label smoothing factor.
(The conversion of integer category targets to one-hot vectors amounts to assigning all the probability mass to the target category.) Cross-entropy per batch item is computed between the resulting distributions:
\[\text{cross_entropy} = -\sum_{i=0}^{N-1} p_i \log q_i\]
The layer returns the average of these cross-entropy values over all items in the batch.
Parameters: label_smoothing – Creates soft targets if provided. Must be between 0 and 1.
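The formulas above (softmax for q, optionally smoothed one-hot for p, then the batch mean of -Σ p_i log q_i) can be sketched in NumPy. The helper `category_cross_entropy` is hypothetical; computing log q with a max-shifted log-softmax is our choice for numerical stability, not necessarily how the layer fuses the two steps.

```python
import numpy as np

def category_cross_entropy(logits, targets, label_smoothing=None):
    """Hypothetical sketch: batch mean of -sum_i p_i * log q_i with q = softmax(logits)."""
    n = logits.shape[-1]
    m = logits.max(axis=-1, keepdims=True)
    log_q = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    p = np.eye(n)[targets]                       # one_hot(targets)
    if label_smoothing is not None:
        eps = label_smoothing
        p = (1.0 - eps) * p + eps / n            # soft targets
    return np.mean(-np.sum(p * log_q, axis=-1))

logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
targets = np.array([0, 2])
print(category_cross_entropy(logits, targets))
print(category_cross_entropy(logits, targets, label_smoothing=0.1))
```

For the uniform second row the loss is log 3 whatever the target or smoothing; on the confident first row, smoothing moves target mass onto low-probability classes and so raises the loss.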

trax.layers.metrics.
WeightedCategoryCrossEntropy
(label_smoothing=None, cutoff=0.0) Returns a layer like CategoryCrossEntropy, with weights as third input.
The layer takes three inputs:
 A batch of activation vectors. The components in a given vector should be pre-softmax activations (mappable to a probability distribution via softmax). For performance reasons, the softmax and cross-entropy computations are combined inside the layer.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
 A batch of weights, which matches or can be broadcast to match the shape of the target ndarray. This arg can give uneven weighting to different items in the batch (depending, for instance, on the item’s target category).
The layer returns the weighted average of these cross-entropy values over all items in the batch.
Parameters:  label_smoothing – Creates soft targets if provided. Must be between 0 and 1.
 cutoff – Prevent loss lower than this cutoff (0.0 meaning none by default).

trax.layers.metrics.
BinaryCrossEntropy
() Returns a layer that computes cross-entropy for binary classification.
The layer takes two inputs:
 A batch of activation values; each batch item \(x\) is a float in \((-\infty, \infty)\).
 A batch of binary targets; each target \(t\) is an integer in \(\{0, 1\}\).
The layer maps each activation value into the range \((0, 1)\), interpreted as the model-predicted probability that the item’s category is 1:
\[q = \frac{1}{1 + e^{-x}} \ \ \text{[model-predicted probability]}\]
and computes cross-entropy (per batch item) by treating the target category as having probability 1:
\[\begin{split}\text{cross_entropy} = \left\{ \begin{array}{cl} -\log q & \text{if}\ t = 1, \\ -\log (1 - q) & \text{if}\ t = 0. \end{array} \right.\end{split}\]
The layer returns the average of these cross-entropy values over all items in the batch.
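A NumPy sketch of the two formulas above follows. The helper is hypothetical. It rewrites log q = -log(1 + e^{-x}) and log(1 - q) = -log(1 + e^{x}) with `np.logaddexp`, which stays finite for large |x|; that rewrite is our choice, and the page does not say how Trax evaluates it.

```python
import numpy as np

def binary_cross_entropy(x, t):
    """Hypothetical sketch: mean of -log q (t == 1) or -log(1 - q) (t == 0), q = sigmoid(x)."""
    log_q = -np.logaddexp(0.0, -x)           # log q     = -log(1 + e^{-x})
    log_one_minus_q = -np.logaddexp(0.0, x)  # log(1 - q) = -log(1 + e^{x})
    return np.mean(-np.where(t == 1, log_q, log_one_minus_q))

x = np.array([3.0, -2.0, 0.0])  # activations
t = np.array([1, 0, 1])         # binary targets
print(binary_cross_entropy(x, t))
```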

trax.layers.metrics.
MaskedSequenceAccuracy
() Returns a layer that computes sequence prediction accuracy with masking.
This layer type is intended for variable length sequences, especially text, represented as a batch of fixed-length sequences via padding for unused positions.
The layer takes three inputs:
 A batch of sequences of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index having the highest probability. In text modeling, the index represents a token ID from a predetermined token vocabulary (or padding).
 A batch of target integer sequences, with values in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality. In text modeling, these sequences typically represent token IDs from a predetermined token vocabulary (or padding).
 A batch of weights/masks, which matches or can be broadcast to match the shape of the target ndarray. This arg is used to give weight 0 to padding positions, which masks those positions out of the calculation. Only the zero/nonzero distinction matters; all nonzero values are treated alike as signaling non-masked (i.e., valid/in-use) positions.
The predicted integer value for each sequence position is the index of the highest-valued component of the position’s vector. A predicted integer sequence is judged correct if it matches the target integer sequence in all nonzero-weighted positions. The layer returns the accuracy of predicted sequences averaged over the batch.
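The all-positions-must-match rule can be sketched in NumPy: a masked position always counts as correct, and a sequence scores 1 only if every position is correct. The helper is hypothetical and not the Trax layer.

```python
import numpy as np

def masked_sequence_accuracy(activations, targets, weights):
    """Hypothetical sketch: share of sequences correct at every nonzero-weight position."""
    predictions = np.argmax(activations, axis=-1)   # (batch, seq_len)
    weights = np.broadcast_to(weights, targets.shape)
    ok = (predictions == targets) | (weights == 0)  # masked positions always "pass"
    return np.mean(np.all(ok, axis=-1))             # one 0/1 verdict per sequence

targets = np.array([[1, 2, 0], [3, 3, 0]])
weights = np.array([[1, 1, 0], [1, 1, 0]])     # last position is padding
predicted_ids = np.array([[1, 2, 3], [3, 1, 0]])
activations = np.eye(4)[predicted_ids]         # one-hot activations, vocabulary size 4
print(masked_sequence_accuracy(activations, targets, weights))  # 0.5
```

The first sequence is wrong only at a padding position and therefore counts as correct; the second is wrong at an unmasked position.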

trax.layers.metrics.
Accuracy
(classifier=ArgMax) Returns a layer that computes mean category prediction accuracy.
DEPRECATED; use WeightedCategoryAccuracy instead.
Parameters: classifier – Layer that transforms activation vectors into category predictions.

trax.layers.metrics.
SequenceAccuracy
(classifier=ArgMax) Returns a layer that computes mean sequence prediction accuracy.
DEPRECATED; use MaskedSequenceAccuracy instead.
Parameters: classifier – Layer that transforms activation vectors into category predictions.

trax.layers.metrics.
CrossEntropyLoss
() Returns a layer that outputs multiclass prediction-target cross-entropy.
DEPRECATED; refactor to use WeightedCategoryCrossEntropy or CategoryCrossEntropy instead. (CrossEntropyLoss by itself does not compute cross-entropy. In older code, this layer had to be preceded by LogSoftmax, and the two layers together did the work of converting category information to probability distributions and computing the cross-entropy between those distributions. All this is now done by WeightedCategoryCrossEntropy.)

trax.layers.metrics.
CrossEntropyLossWithLogSoftmax
() Mean prediction-target cross-entropy for multiclass classification.

trax.layers.metrics.
BinaryCrossEntropyLoss
() Returns a layer that outputs binary prediction-target cross-entropy.
DEPRECATED; refactor to use BinaryCrossEntropy instead. (The newer BinaryCrossEntropy does not use weights, so refactor accordingly. Unless and until clear motivating use cases arise, the library will not include a binary cross-entropy function with weights.)

trax.layers.metrics.
L2Loss
() Returns a layer that computes an L2-like loss for one batch.
The layer takes three inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of elementwise target values, which matches the shape of the model output.
 A batch of weights, which matches the shape of the model output.
The layer returns a weighted average of elementwise squared error terms \((y_i - t_i)^2\).
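A NumPy sketch of this weighted average follows, assuming “weighted average” means dividing the weighted sum by the sum of the weights. The `l2_loss` helper is hypothetical.

```python
import numpy as np

def l2_loss(y, t, w):
    """Hypothetical sketch: weighted average of element-wise squared errors."""
    return np.sum(w * (y - t) ** 2) / np.sum(w)

y = np.array([1.0, 2.0, 3.0])  # model output
t = np.array([1.0, 0.0, 5.0])  # targets
w = np.array([1.0, 1.0, 0.0])  # weight 0 drops the third element
print(l2_loss(y, t, w))  # (0 + 4) / 2 = 2.0
```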

trax.layers.metrics.
SmoothL1Loss
() Returns a layer that computes a weighted, smoothed L1 loss for one batch.
The layer takes three inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of elementwise target values, which matches the shape of the model output.
 A batch of weights, which matches the shape of the model output.
The layer computes a “smooth” L1 loss (a.k.a. Huber loss), for model output float \(y_i\) and target float \(t_i\):
\[\begin{split}\text{output} = \left\{ \begin{array}{cl} \frac 1 2 (y_i - t_i)^2, & \text{if}\ |y_i - t_i| < 1, \\ |y_i - t_i| - \frac 1 2, & \text{otherwise}. \end{array} \right.\end{split}\]
The layer returns a weighted average of these elementwise values.
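The piecewise formula can be sketched in NumPy. The helper is hypothetical and, as with the L2 case, assumes the weighted average divides by the sum of the weights. The two branches meet at |y_i - t_i| = 1, where both equal 1/2.

```python
import numpy as np

def smooth_l1_loss(y, t, w):
    """Hypothetical sketch: weighted average of 0.5*d**2 if |d| < 1, else |d| - 0.5."""
    d = np.abs(y - t)
    per_element = np.where(d < 1.0, 0.5 * d ** 2, d - 0.5)
    return np.sum(w * per_element) / np.sum(w)

y = np.array([0.5, 3.0])
t = np.array([0.0, 0.0])
w = np.array([1.0, 1.0])
print(smooth_l1_loss(y, t, w))  # (0.125 + 2.5) / 2 = 1.3125
```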

trax.layers.metrics.
MacroAveragedFScore
(beta=1.0, initial_category_index=0) Returns a layer that computes a macro-averaged F-score.
The macro-averaged F-score summarizes how well the classifier’s k predictions align with the observed/gold instances of k. It weights all classes equally, regardless of their size.
Parameters:  beta – a parameter that determines the weight of recall in the F-score.
 initial_category_index – an index of the initial category.
The layer takes two inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of elementwise target values, which matches the shape of the model output.
The layer returns a macro-averaged F-score across all the classes.
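To show what “all classes count equally” means, here is a NumPy sketch of a macro-averaged F_beta score, using F_beta = (1 + beta^2) P R / (beta^2 P + R) per class and an unweighted mean across classes. The helper is hypothetical. For simplicity it takes already-argmaxed integer predictions and ignores initial_category_index; both are assumptions rather than the layer's documented behavior.

```python
import numpy as np

def macro_f_score(predictions, targets, n_classes, beta=1.0):
    """Hypothetical sketch: unweighted mean over classes of per-class F_beta."""
    scores = []
    for c in range(n_classes):
        tp = np.sum((predictions == c) & (targets == c))
        fp = np.sum((predictions == c) & (targets != c))
        fn = np.sum((predictions != c) & (targets == c))
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        denom = beta ** 2 * precision + recall
        f = (1 + beta ** 2) * precision * recall / denom if denom > 0 else 0.0
        scores.append(f)
    return float(np.mean(scores))  # every class counts equally, whatever its size

predictions = np.array([0, 0, 1, 1, 2])  # already argmax-ed class indices
targets = np.array([0, 1, 1, 1, 2])
print(macro_f_score(predictions, targets, n_classes=3))  # (2/3 + 0.8 + 1.0) / 3 ≈ 0.822
```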

trax.layers.metrics.
WeightedFScore
(beta=1.0, initial_category_index=0) Returns a layer that computes a weighted F-score.
The weighted F-score summarizes how well the classifier’s k predictions align with the observed/gold instances of k. It additionally weights the summary by the number of observed/gold and predicted examples in each class.
Parameters:  beta – a parameter that determines the weight of recall in the F-score.
 initial_category_index – an index of the initial category.
The layer takes two inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of elementwise target values, which matches the shape of the model output.
The layer returns a weighted F-score across all the classes.

trax.layers.metrics.
WeightedSum
() Returns a layer that computes a weighted sum of the given values.

trax.layers.metrics.
CrossEntropySum
() Sum of prediction-target cross entropies for multiclass classification.

trax.layers.metrics.
BinaryCrossEntropySum
() Sum of prediction-target cross entropies for binary classification.
normalization
Trax normalization layers.

class
trax.layers.normalization.
BatchNorm
(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train') Bases:
trax.layers.base.Layer
Layer that performs batch normalization.
In training, batch normalization keeps smoothed cumulative statistics across batches of input data and modifies each new batch so that its components have approximately zero mean and unit variance. In eval or inference, a BatchNorm instance uses its stored mean and variance to approximately normalize each new batch of data.
See https://arxiv.org/abs/1502.03167 for the original presentation and motivation of batch normalization.
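A simplified NumPy sketch of the train/eval split described above follows. It assumes the stored statistics are updated as an exponential moving average controlled by momentum, a common choice; the exact Trax update rule is not given here. It also omits the learned center/scale parameters, and both function names are hypothetical.

```python
import numpy as np

def batch_norm_train(x, running_mean, running_var, momentum=0.999,
                     epsilon=1e-5, axis=(0, 1, 2)):
    """Hypothetical sketch: normalize with batch statistics, update running ones."""
    mean = x.mean(axis=axis)  # one value per channel for (N, H, W, C) input
    var = x.var(axis=axis)
    new_mean = momentum * running_mean + (1.0 - momentum) * mean
    new_var = momentum * running_var + (1.0 - momentum) * var
    y = (x - mean) / np.sqrt(var + epsilon)
    return y, new_mean, new_var

def batch_norm_eval(x, running_mean, running_var, epsilon=1e-5):
    """In eval/inference, reuse the stored statistics instead of batch ones."""
    return (x - running_mean) / np.sqrt(running_var + epsilon)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4, 4, 2))  # (batch, height, width, channels)
y, m, v = batch_norm_train(x, running_mean=np.zeros(2), running_var=np.ones(2))
y_eval = batch_norm_eval(x, m, v)
print(y.mean(axis=(0, 1, 2)), y.std(axis=(0, 1, 2)))  # close to 0 and 1 per channel
```

With momentum=0.999 the running statistics move slowly, so many batches are needed before eval-mode outputs approach the train-mode ones.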

__init__
(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train') Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Classlike name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x) Computes batch normalization as part of a forward pass in the model.

init_weights_and_state
(input_signature) Helper to initialize batch norm weights and state.


class
trax.layers.normalization.
LayerNorm
(center=True, epsilon=1e-06) Bases:
trax.layers.base.Layer
Layer normalization.

__init__
(center=True, epsilon=1e-06) Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Classlike name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x) Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature) Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.normalization.
FilterResponseNorm
(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001) Bases:
trax.layers.base.Layer
Filter Response Normalization layer without Threshold Linear Unit.
See https://arxiv.org/pdf/1911.09737.pdf

__init__
(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001) Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Classlike name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

pooling¶
Trax pooling layers.

trax.layers.pooling.
MaxPool
(pool_size=(2, 2), strides=None, padding='VALID')¶ Reduces each multidimensional window to the max of the window’s values.
Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).
Parameters:  pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
 strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
 padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the selection of max values.
Returns: N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the max value from that window. An output array has the same number of dimensions as its input, but typically has fewer elements.
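To make the window arithmetic concrete, the following NumPy sketch max-reduces 2-D windows of a `[batch, height, width, channels]` array with 'VALID' padding, so only complete windows contribute and the first and last axes are left alone. The helper `max_pool_valid` is hypothetical and written only for illustration; it is not part of Trax.

```python
import numpy as np

def max_pool_valid(x, pool_size=(2, 2), strides=None):
    """Max-reduce each full window of a [batch, h, w, channels] array."""
    strides = strides or (1, 1)  # Mirrors the documented default: offset 1 per window axis.
    ph, pw = pool_size
    sh, sw = strides
    out_h = (x.shape[1] - ph) // sh + 1  # Partial windows are discarded.
    out_w = (x.shape[2] - pw) // sw + 1
    out = np.empty((x.shape[0], out_h, out_w, x.shape[3]), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, i * sh:i * sh + ph, j * sw:j * sw + pw, :]
            out[:, i, j, :] = window.max(axis=(1, 2))  # Reduce only the window axes.
    return out

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
print(max_pool_valid(x, strides=(2, 2))[0, :, :, 0].tolist())  # [[5.0, 7.0], [13.0, 15.0]]
```

With the default stride of 1, the same 4x4 input yields a 3x3 output, showing how overlapping windows shrink the array less than non-overlapping ones.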

trax.layers.pooling.
SumPool
(pool_size=(2, 2), strides=None, padding='VALID')¶ Reduces each multidimensional window to the sum of the window’s values.
Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).
Parameters:  pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
 strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
 padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the computation of sums.
Returns: N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the sum of values in that window. An output array has the same number of dimensions as its input, but typically has fewer elements.

trax.layers.pooling.
AvgPool
(pool_size=(2, 2), strides=None, padding='VALID')¶ Reduces each multidimensional window to the mean of the window’s values.
Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).
Parameters:  pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
 strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
 padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed but is not counted in the computation of averages.
Returns: N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the mean of values in that window. An output array has the same number of dimensions as its input, but typically has fewer elements.
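The 'SAME' behavior for averages, where padded positions are excluded from the count, can be sketched in NumPy for the 1-D case: divide each window sum by the number of real (unpadded) elements in that window. This assumes stride 1 and puts any odd padding element at the end; the helper `avg_pool_1d_same` is hypothetical and only illustrates the documented semantics.

```python
import numpy as np

def avg_pool_1d_same(x, window=3):
    """1-D average pooling, stride 1, 'SAME' padding, padding not counted."""
    total_pad = window - 1
    before = total_pad // 2
    after = total_pad - before
    padded = np.pad(x, (before, after))              # Zero padding contributes nothing to sums.
    ones = np.pad(np.ones_like(x), (before, after))  # 1 marks a real element, 0 a pad.
    kernel = np.ones(window)
    sums = np.convolve(padded, kernel, mode='valid')    # Sum over each window.
    counts = np.convolve(ones, kernel, mode='valid')    # Real elements per window.
    return sums / counts

print(avg_pool_1d_same(np.array([1.0, 2.0, 3.0, 4.0])).tolist())  # [1.5, 2.0, 3.0, 3.5]
```

The edge values are 1.5 and 3.5 rather than 1.0 and 2.33, because the pads are left out of the divisor instead of being averaged in as zeros.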
reversible¶
Layers that can run in reverse to compute inputs from outputs.
Reversible layers reduce the memory required for backpropagation-based training, especially for deep networks. In a series of reversible layers, input activations from a forward pass don’t need to be stored: they can be reconstructed on the backward pass, layer by layer, from outputs to inputs.
See, e.g., [The Reversible Residual Network: Backpropagation Without Storing Activations](https://arxiv.org/abs/1707.04585) and [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451).
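The core idea can be sketched with the additive coupling from the RevNet paper: split the activations into two halves, update each with a residual function of the other, and recover the inputs from the outputs by subtracting the same functions in the opposite order. `F` and `G` below are arbitrary stand-in functions chosen for illustration; they do not need to be invertible themselves.

```python
import numpy as np

def F(x):  # Arbitrary residual function (stand-in for, e.g., attention).
    return np.tanh(2.0 * x)

def G(x):  # Arbitrary residual function (stand-in for, e.g., a feed-forward block).
    return np.sin(x) + 0.5 * x

def rev_block_forward(x1, x2):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def rev_block_reverse(y1, y2):
    # Undo the updates in the opposite order; nothing from the forward pass is stored.
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(2, 5))
y1, y2 = rev_block_forward(x1, x2)
r1, r2 = rev_block_reverse(y1, y2)
print(np.allclose(r1, x1), np.allclose(r2, x2))  # True True
```

Reconstruction costs an extra evaluation of `F` and `G` per layer, which is the compute-for-memory trade-off reversible networks make.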

class
trax.layers.reversible.
ReversibleLayer
(n_in=1, n_out=1, name=None, sublayers_to_print=None)¶ Bases:
trax.layers.base.Layer
Reversible Layer.

reverse
(output, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.

reverse_and_grad
(output, grad, weights=(), state=(), new_state=(), rng=None)¶ Backward pass: computes the inverse of a layer and propagates gradients.
A layer may implement only reverse; some layers implement this method directly instead, because reversing and computing gradients can share computation.
Parameters:  output – Output activations; can be a (possibly nested) tuple.
 grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
 weights – layer weights
 state – start state
 new_state – updated state computed by the forward pass
 rng – Single-use random number generator (JAX PRNG key).
Returns: A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.
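For a concrete sense of the return structure, here is a NumPy sketch of reverse-and-grad for a hypothetical element-wise scaling layer `y = w * x`. Reversing divides by `w`, and the reconstructed `x` is then reused to compute the weight gradient, which is the kind of shared computation described above. The functions are illustrative stand-ins, not Trax classes; they only mirror the documented `(x, (x_grad, weights_grad))` shape.

```python
import numpy as np

def scale_forward(x, w):
    return w * x

def scale_reverse_and_grad(output, grad, w):
    """Reconstruct x from y = w * x and backpropagate grad (dL/dy)."""
    x = output / w                      # Reverse: recover the input from the output.
    x_grad = grad * w                   # dL/dx = dL/dy * w
    w_grad = np.sum(grad * x, axis=0)   # dL/dw reuses the reconstructed x.
    return x, (x_grad, w_grad)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w = np.array([0.5, 2.0, -1.5])  # Non-zero, so the layer is invertible.
y = scale_forward(x, w)
grad = np.ones_like(y)          # Gradient of L = sum(y).
x_rec, (x_grad, w_grad) = scale_reverse_and_grad(y, grad, w)
print(np.allclose(x_rec, x), np.allclose(w_grad, x.sum(axis=0)))  # True True
```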

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng)¶ Custom backward pass to propagate gradients in a custom way.
Parameters:  inputs – Input tensors; can be a (possibly nested) tuple.
 output – The result of running this layer on inputs.
 grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
 weights – This layer’s weights.
 state – This layer’s state prior to the current forward pass.
 new_state – This layer’s state after the current forward pass.
 rng – Single-use random number generator (JAX PRNG key).
Returns: The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.


class
trax.layers.reversible.
ReversibleConcatenatePair
¶ Bases:
trax.layers.reversible.ReversibleLayer
Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on the last axis.

__init__
()¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.
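The mapping and its inverse are simple enough to sketch directly in NumPy. Reversing needs only one copy of the concatenation, which is split back into `x` and `y`. The sketch assumes both inputs have the same last-axis size so that an even split recovers them; the function names and that assumption are illustrative and may not match the Trax code.

```python
import numpy as np

def concatenate_pair(x, y):
    pair = np.concatenate([x, y], axis=-1)  # Concatenate on the last axis.
    return pair, pair                       # (x, y) -> ([x, y], [x, y])

def concatenate_pair_reverse(outputs):
    pair, _ = outputs                       # The second copy is redundant.
    x, y = np.split(pair, 2, axis=-1)       # Assumes x and y have equal depth.
    return x, y

x = np.zeros((2, 3))
y = np.ones((2, 3))
x_rec, y_rec = concatenate_pair_reverse(concatenate_pair(x, y))
print(x_rec.shape, np.array_equal(y_rec, y))  # (2, 3) True
```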


class
trax.layers.reversible.
ReversibleSelect
(indices, n_in=None, name=None)¶ Bases:
trax.layers.reversible.ReversibleLayer
Reversible version of the Select combinator.

__init__
(indices, n_in=None, name=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.
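When `indices` is a permutation of the input positions, reversing a select amounts to applying the inverse permutation. A minimal sketch on Python tuples, assuming `indices` contains each input position exactly once; the helpers `select` and `select_reverse` are hypothetical.

```python
def select(inputs, indices):
    # Forward: output i is input indices[i].
    return tuple(inputs[i] for i in indices)

def select_reverse(outputs, indices):
    # Input j sits at the output position p where indices[p] == j.
    return tuple(outputs[indices.index(j)] for j in range(len(indices)))

indices = [2, 0, 1]
outputs = select(('a', 'b', 'c'), indices)
print(outputs)                           # ('c', 'a', 'b')
print(select_reverse(outputs, indices))  # ('a', 'b', 'c')
```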


trax.layers.reversible.
ReversibleSwap
()¶

class
trax.layers.reversible.
ReversibleReshape
(shape1, shape2, n_in=1)¶ Bases:
trax.layers.reversible.ReversibleLayer
Reversible reshaping layer.

__init__
(shape1, shape2, n_in=1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.


class
trax.layers.reversible.
ReversiblePrintShape
(n_in=1, msg='')¶ Bases:
trax.layers.reversible.ReversibleLayer
Reversible PrintShape for debugging reversible serial layers.

__init__
(n_in=1, msg='')¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.


class
trax.layers.reversible.
ReversibleSerial
(*layers)¶ Bases:
trax.layers.reversible.ReversibleLayer
, trax.layers.combinators.Serial
A reversible version of tl.Serial (requires reversible sublayers).

__init__
(*layers)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

reverse
(output, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.

reverse_and_grad
(output, grad, weights=(), state=(), new_state=(), rng=None)¶ Backward pass: computes the inverse of a layer and propagates gradients.
A layer may implement only reverse; some layers implement this method directly instead, because reversing and computing gradients can share computation.
Parameters:  output – Output activations; can be a (possibly nested) tuple.
 grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
 weights – layer weights
 state – start state
 new_state – updated state computed by the forward pass
 rng – Single-use random number generator (JAX PRNG key).
Returns: A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.
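Because every sublayer is reversible, the whole stack can be reversed by calling the sublayers' reverse functions in the opposite order. The sketch below uses hypothetical `(forward, reverse)` function pairs in place of Trax layers and covers reconstruction only, not gradients.

```python
import numpy as np

# Each hypothetical reversible "layer" is a (forward, reverse) pair of functions.
add_one = (lambda x: x + 1.0, lambda y: y - 1.0)
double = (lambda x: 2.0 * x, lambda y: y / 2.0)
negate = (lambda x: -x, lambda y: -y)

def serial_forward(layers, x):
    for forward, _ in layers:
        x = forward(x)
    return x

def serial_reverse(layers, y):
    # Undo the last layer first, then work back toward the input.
    for _, reverse in reversed(layers):
        y = reverse(y)
    return y

layers = [add_one, double, negate]
x = np.array([0.0, 1.5, -2.0])
y = serial_forward(layers, x)
print(y.tolist())                          # [-2.0, -5.0, 2.0]
print(serial_reverse(layers, y).tolist())  # [0.0, 1.5, -2.0]
```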


class
trax.layers.reversible.
ReversibleHalfResidual
(*residual_layers, attention_layer=None, name=None)¶ Bases:
trax.layers.reversible.ReversibleLayer
Half of a RevNet-style residual that optionally performs attention.
When attention_layer is None, this layer has the signature
[accumulator, *context] -> [accumulator + f(context), *context]
The attention_layer must be an instance of EfficientAttentionBase or one of its subclasses (see efficient_attention.py), or None.
Attention is special-cased for the following two reasons:
 LSH attention needs to save bucket assignments from the forward pass to the backward pass, for training stability. This requires special-casing it.
 We can call attention_layer.forward_and_or_backward to compute its output (needed for inverting a reversible residual layer) while simultaneously performing the backward pass. Sharing computation between these two operations improves training speed.
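Ignoring the attention special case, the signature `[accumulator, *context] -> [accumulator + f(context), *context]` and its inverse can be sketched as follows. Here `f` is an arbitrary stand-in that reads all context tensors; since the context passes through unchanged, the reverse can recompute `f(context)` and subtract it.

```python
import numpy as np

def f(*context):
    # Stand-in residual function of all context tensors (not a Trax layer).
    return np.tanh(sum(context))

def half_residual(accumulator, *context):
    return (accumulator + f(*context), *context)

def half_residual_reverse(accumulator_out, *context):
    # The context is unchanged by the forward pass, so f(context) can be recomputed.
    return (accumulator_out - f(*context), *context)

acc = np.array([1.0, 2.0])
ctx = (np.array([0.1, 0.2]), np.array([0.3, -0.4]))
out = half_residual(acc, *ctx)
restored = half_residual_reverse(*out)
print(np.allclose(restored[0], acc))  # True
```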

__init__
(*residual_layers, attention_layer=None, name=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(output, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.

reverse_and_grad
(output, ct, weights=(), state=(), new_state=(), rng=None)¶ Backward pass: computes the inverse of a layer and propagates gradients.
A layer may implement only reverse; some layers implement this method directly instead, because reversing and computing gradients can share computation.
Parameters:  output – Output activations; can be a (possibly nested) tuple.
 ct – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
 weights – layer weights
 state – start state
 new_state – updated state computed by the forward pass
 rng – Single-use random number generator (JAX PRNG key).
Returns: A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
rnn¶
Implementations of common recurrent neural network cells (RNNs).

class
trax.layers.rnn.
LSTMCell
(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Bases:
trax.layers.base.Layer
LSTM Cell.
For a nice overview of the motivation and (i, o, f) gates, see this tutorial: https://colah.github.io/posts/2015-08-Understanding-LSTMs/
See this paper for a description and detailed study of all gate types: https://arxiv.org/pdf/1503.04069.pdf
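A single LSTM step can be sketched in NumPy using the standard gate equations. The fused weight layout (one matrix producing all four gate pre-activations from `[x, h]`) and the gate order are assumptions of this sketch and may differ from Trax's LSTMCell; `forget_bias` is added to the forget-gate pre-activation, as its name suggests.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c, h, W, b, forget_bias=1.0):
    """x: [batch, d_in]; c, h: [batch, n_units]; W: [d_in + n_units, 4 * n_units]."""
    z = np.concatenate([x, h], axis=-1) @ W + b
    i, j, f, o = np.split(z, 4, axis=-1)  # Input gate, candidate, forget gate, output gate.
    new_c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
    new_h = sigmoid(o) * np.tanh(new_c)
    return new_c, new_h

rng = np.random.default_rng(0)
batch, d_in, n_units = 2, 3, 4
W = rng.normal(scale=0.1, size=(d_in + n_units, 4 * n_units))
b = np.zeros(4 * n_units)
c = h = np.zeros((batch, n_units))
c, h = lstm_step(rng.normal(size=(batch, d_in)), c, h, W, b)
print(c.shape, h.shape)  # (2, 4) (2, 4)
```

A positive `forget_bias` pushes the forget gate toward 1 at initialization, so the cell starts out retaining its memory.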

__init__
(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.layers.rnn.
MakeZeroState
(depth_multiplier=1)¶ Makes zeros of shape like x but removing the length (axis 1).

trax.layers.rnn.
LSTM
(n_units, mode='train', return_state=False, initial_state=False)¶ LSTM running on axis 1.
Parameters:  n_units – n_units for the LSTMCell.
 mode – if ‘predict’ then we save the previous state for one-by-one inference.
 return_state – Boolean. Whether to return the final state in addition to the output. Default: False.
 initial_state – Boolean. Whether the initial RNN state (c, h) is read from the stack. Default: False.
Returns: An LSTM layer.
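"Running on axis 1" means the cell is applied step by step along the length axis of a `[batch, length, depth]` input, starting from a zero state shaped like the input minus that axis. The NumPy sketch below repeats a simple LSTM step so it is self-contained; the weight layout, the name `lstm_over_axis1`, and the return convention are assumptions, not the Trax API.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c, h, W, b, forget_bias=1.0):
    z = np.concatenate([x, h], axis=-1) @ W + b
    i, j, f, o = np.split(z, 4, axis=-1)
    c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(j)
    return c, sigmoid(o) * np.tanh(c)

def lstm_over_axis1(xs, W, b, n_units, return_state=False):
    batch = xs.shape[0]
    # Zero initial state: like xs with the length axis (axis 1) removed.
    c = np.zeros((batch, n_units))
    h = np.zeros((batch, n_units))
    outputs = []
    for t in range(xs.shape[1]):  # Recurrence over the length axis.
        c, h = lstm_step(xs[:, t, :], c, h, W, b)
        outputs.append(h)
    ys = np.stack(outputs, axis=1)  # [batch, length, n_units]
    return (ys, (c, h)) if return_state else ys

rng = np.random.default_rng(0)
xs = rng.normal(size=(2, 5, 3))
W = rng.normal(scale=0.1, size=(3 + 4, 16))
ys, (c, h) = lstm_over_axis1(xs, W, np.zeros(16), n_units=4, return_state=True)
print(ys.shape, np.allclose(ys[:, -1], h))  # (2, 5, 4) True
```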

class
trax.layers.rnn.
GRUCell
(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Bases:
trax.layers.base.Layer
Builds a traditional GRU cell with dense internal transformations.
Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555

__init__
(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.layers.rnn.
GRU
(n_units, mode='train')¶ GRU running on axis 1.

trax.layers.rnn.
ConvGRUCell
(n_units, kernel_size=(3, 3))¶ Builds a convolutional GRU.
Paper: https://arxiv.org/abs/1511.06432.
Parameters:  n_units – Number of hidden units
 kernel_size – Kernel size for convolution
Returns: A Stax model representing a GRU cell with convolution transforms.

trax.layers.rnn.
GeneralGRUCell
(candidate_transform, memory_transform_fn=None, gate_nonlinearity=<function Sigmoid>, candidate_nonlinearity=<function Tanh>, dropout_rate_c=0.1, sigmoid_bias=0.5)¶ Parametrized Gated Recurrent Unit (GRU) cell construction.
GRU update equations for update gate, reset gate, candidate memory, and new state:
\[\begin{split}u_t &= \sigma(U' \times s_{t-1} + B') \\ r_t &= \sigma(U'' \times s_{t-1} + B'') \\ c_t &= \tanh(U \times (r_t \odot s_{t-1}) + B) \\ s_t &= u_t \odot s_{t-1} + (1 - u_t) \odot c_t\end{split}\]
See combinators.Gate for details on the gating function.
Parameters:  candidate_transform – Transform to apply inside the Candidate branch. Applied before nonlinearities.
 memory_transform_fn – Optional transformation on the memory before gating.
 gate_nonlinearity – Function to use as gate activation; allows trying alternatives to Sigmoid, such as HardSigmoid.
 candidate_nonlinearity – Nonlinearity to apply after candidate branch; allows trying alternatives to traditional Tanh, such as HardTanh.
 dropout_rate_c – Amount of dropout on the transform (c) gate. Dropout works best in a GRU when applied exclusively to this branch.
 sigmoid_bias – Constant to add before sigmoid gates. A positive starting bias is generally preferable.
Returns: A model representing a GRU cell with specified transforms.
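The four update equations translate directly into NumPy. Note that they act only on the previous state \(s_{t-1}\); there is no separate input term. In this sketch `U_u`, `U_r`, `U_c` and `B_u`, `B_r`, `B_c` stand for \(U', U'', U\) and \(B', B'', B\), with a row-vector convention (`s @ U` rather than \(U \times s\)), and `sigmoid_bias` is added inside both gates. Dropout and the optional memory transform are omitted; this is an illustration, not the Trax code.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def general_gru_step(s_prev, U_u, B_u, U_r, B_r, U_c, B_c, sigmoid_bias=0.5):
    """Applies the update equations to a state s_prev of shape [batch, d]."""
    u = sigmoid(s_prev @ U_u + B_u + sigmoid_bias)  # Update gate u_t.
    r = sigmoid(s_prev @ U_r + B_r + sigmoid_bias)  # Reset gate r_t.
    c = np.tanh((r * s_prev) @ U_c + B_c)           # Candidate c_t.
    return u * s_prev + (1.0 - u) * c               # New state s_t.

d = 3
s = np.ones((2, d))
zeros, bias0 = np.zeros((d, d)), np.zeros(d)
s_next = general_gru_step(s, zeros, bias0, zeros, bias0, zeros, bias0)
print(np.allclose(s_next, sigmoid(0.5) * s))  # True
```

With zero weights the candidate is 0 and the update gate equals \(\sigma(0.5) \approx 0.62\), so the positive `sigmoid_bias` makes the cell keep most of its previous state at the start of training.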

trax.layers.rnn.
InnerSRUCell
()¶ The inner (non-parallel) computation of an SRU.

trax.layers.rnn.
ScanSRUCell
(mode, monkey_patched_mask=None)¶ The inner (non-parallel) computation of an SRU.

trax.layers.rnn.
SRU
(n_units, activation=None, mode='train')¶ SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.
As defined in the paper:
\[\begin{split}y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\ f_t &= \sigma(W_f x_t + b_f) \\ r_t &= \sigma(W_r x_t + b_r) \\ c_t &= f_t \odot c_{t-1} + (1 - f_t) \odot y_t \\ h_t &= r_t \odot \hbox{activation}(c_t) + (1 - r_t) \odot x_t\end{split}\]
We assume the input has shape [batch, length, depth] and that the recurrence runs along the length dimension. This returns a single layer; the paper recommends stacking at least 2, except inside a Transformer.
Parameters:  n_units – output depth of the SRU layer.
 activation – Optional activation function.
 mode – if ‘predict’ then we save the previous state for one-by-one inference
Returns: The SRU layer.
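The SRU equations can be sketched in NumPy as below. Only the \(c_t\) recurrence is sequential, and it uses element-wise operations only; the three projections are computed for all time steps at once, which is what makes the SRU cheap. The sketch omits the optional bias \(B\), uses `tanh` as the activation, and assumes the output depth equals the input depth (required by the highway term \((1 - r_t) x_t\)); all three are assumptions for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru(x, W, Wf, bf, Wr, br, activation=np.tanh):
    """x: [batch, length, d]; W, Wf, Wr: [d, d]; bf, br: [d]."""
    # Parallel part: projections for every position at once.
    y = x @ W
    f = sigmoid(x @ Wf + bf)
    r = sigmoid(x @ Wr + br)
    c = np.zeros_like(x[:, 0, :])
    hs = []
    for t in range(x.shape[1]):  # Sequential part: element-wise only.
        c = f[:, t] * c + (1.0 - f[:, t]) * y[:, t]
        hs.append(r[:, t] * activation(c) + (1.0 - r[:, t]) * x[:, t])
    return np.stack(hs, axis=1)

rng = np.random.default_rng(0)
d = 3
x = rng.normal(size=(2, 5, d))
params = [rng.normal(scale=0.5, size=(d, d)) for _ in range(3)]
h = sru(x, params[0], params[1], np.zeros(d), params[2], np.zeros(d))
print(h.shape)  # (2, 5, 3)
```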
research.efficient_attention¶
Attention layers optimized for efficiency (second-pass implementation).
The approach taken in the first round of efficient attention implementations revealed several limitations, which this code attempts to address:
 Simultaneously instantiating queries, keys, and values for all heads can exceed the memory budget. Transformers are typically tuned such that n_heads * d_attention_key == d_model. Since attention involves queries, keys, AND values, the memory to store them can be ~3x the memory needed to store the input activations. Once the O(n^2) dot-product bottleneck is removed, as is the case in all of our efficient attention implementations, this becomes the next critical bottleneck for scaling up Transformer models.
 Attention masking is implemented by associating an integer (typically, the sequence position) with each query and key vector, and defining a function to compute attention masks from this information. The standard attention API (attention.py) does not scale because it instantiates O(n^2)-size attention masks, and the previous efficient implementations (efficient_attention.py) only supported causal masking.

trax.layers.research.efficient_attention.
length_normalized
(x, epsilon=1e-06)¶

trax.layers.research.efficient_attention.
hash_vecs
(vecs, n_buckets_in, n_hashes, rng)¶ Hash vectors into buckets.
Parameters:  vecs – vectors to hash, a tensor of shape [batch_size, depth]
 n_buckets_in – an int or a list of ints, number of hash buckets; if it is a list, we do hierarchical hashing as specified by the list
 n_hashes – number of hashes
 rng – random generator to use for hashing
Returns: A pair (buckets, n_buckets) where buckets is a tensor of shape [n_hashes, batch_size] of integers – the hash bucket IDs, and n_buckets is an int, the total number of hash buckets, equal to the product of all items in n_buckets_in.
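For a single (non-hierarchical) integer bucket count, the Reformer paper hashes by random rotation: project each vector onto `n_buckets // 2` random directions, concatenate the projections with their negations, and take the argmax. The NumPy sketch below follows that scheme and the documented `[n_hashes, batch_size]` output shape. It uses a NumPy `Generator` instead of a JAX PRNG key, and `hash_vecs_sketch` is a hypothetical name; it is not Trax's exact code.

```python
import numpy as np

def hash_vecs_sketch(vecs, n_buckets, n_hashes, rng):
    """vecs: [batch_size, depth]; n_buckets must be even. Returns [n_hashes, batch_size]."""
    assert n_buckets % 2 == 0, 'random-rotation hashing needs an even bucket count'
    depth = vecs.shape[-1]
    # One set of n_buckets // 2 random directions per hash round.
    rotations = rng.normal(size=(depth, n_hashes, n_buckets // 2))
    rotated = np.einsum('bd,dhk->hbk', vecs, rotations)     # [n_hashes, batch, n_buckets // 2]
    rotated = np.concatenate([rotated, -rotated], axis=-1)  # [n_hashes, batch, n_buckets]
    return np.argmax(rotated, axis=-1)

rng = np.random.default_rng(0)
vecs = rng.normal(size=(6, 8))
buckets = hash_vecs_sketch(vecs, n_buckets=4, n_hashes=3, rng=rng)
print(buckets.shape)  # (3, 6)
```

Because only the direction of a vector matters to the argmax, rescaling a vector does not change its bucket, and vectors with small angles between them tend to share buckets.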

trax.layers.research.efficient_attention.
look_adjacent
(x, n_chunks_before, n_chunks_after)¶ Used to implement attention between consecutive chunks.
Parameters:  x – array of shape [n_chunks, chunk_len, …]
 n_chunks_before – Number of previous chunks to attend to.
 n_chunks_after – Number of subsequent chunks to attend to.
Returns: array of shape [n_chunks, N * chunk_len, …], where N = (1 + n_chunks_before + n_chunks_after).
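A NumPy sketch of the chunk-neighborhood construction: each chunk is concatenated, along the chunk-length axis, with its `n_chunks_before` predecessors and `n_chunks_after` successors. This sketch wraps around at the sequence edges via `np.roll`, leaving it to masking to discard unwanted pairs; that wrap-around choice and the name `look_adjacent_sketch` are assumptions.

```python
import numpy as np

def look_adjacent_sketch(x, n_chunks_before, n_chunks_after):
    """x: [n_chunks, chunk_len, ...] -> [n_chunks, N * chunk_len, ...]."""
    if n_chunks_before == 0 and n_chunks_after == 0:
        return x
    # Offset i pairs chunk k with chunk k + i (wrapping around at the ends).
    slices = [np.roll(x, -i, axis=0) for i in range(-n_chunks_before, n_chunks_after + 1)]
    return np.concatenate(slices, axis=1)

x = np.arange(4).reshape(4, 1)  # 4 chunks of length 1, labeled 0..3.
print(look_adjacent_sketch(x, 1, 0).tolist())  # [[3, 0], [0, 1], [1, 2], [2, 3]]
```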

trax.layers.research.efficient_attention.
mask_self_attention
(dots, q_info, kv_info, causal=True, exclude_self=True, masked=False)¶ Performs masking for self-attention.
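One way to picture this masking is subtracting large constants from the attention logits `dots` wherever a key should not be attended to, using the per-position metadata `q_info` and `kv_info` (typically sequence positions). In the NumPy sketch below, causal masking hides keys at later positions, and `exclude_self` applies a smaller penalty so a query attends to itself only when nothing else is available. The penalty constants and the helper name are assumptions, and the `masked` option is omitted.

```python
import numpy as np

def mask_self_attention_sketch(dots, q_info, kv_info, causal=True, exclude_self=True):
    """dots: [q_len, kv_len]; q_info: [q_len]; kv_info: [kv_len]."""
    q = q_info[:, None]
    kv = kv_info[None, :]
    if causal:
        dots = dots - 1e9 * (q < kv)   # Hide keys at later positions.
    if exclude_self:
        dots = dots - 1e5 * (q == kv)  # Strongly discourage, but do not forbid, self-attention.
    return dots

pos = np.arange(3)
masked = mask_self_attention_sketch(np.zeros((3, 3)), pos, pos)
print((masked > -1.0).astype(int).tolist())  # [[0, 0, 0], [1, 0, 0], [1, 1, 0]]
```

In the first row every key is penalized, but the self-penalty is much smaller than the causal one, so after the softmax the first query still attends to itself.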

trax.layers.research.efficient_attention.
attend
(q, k=None, v=None, q_chunk_len=None, kv_chunk_len=None, n_chunks_before=0, n_chunks_after=0, mask_fn=None, q_info=None, kv_info=None, dropout=0.0, rng=None)¶ Dot-product attention, with optional chunking and/or masking.
Parameters:  q – Query vectors, shape [q_len, d_qk]
 k – Key vectors, shape [kv_len, d_qk]; or None
 v – Value vectors, shape [kv_len, d_v]
 q_chunk_len – Set to non-zero to enable chunking for query vectors
 kv_chunk_len – Set to non-zero to enable chunking for key/value vectors
 kv_chunk_len – Set to nonzero to enable chunking for key/value vectors
 n_chunks_before – Number of adjacent previous chunks to attend to
 n_chunks_after – Number of adjacent subsequent chunks to attend to
 mask_fn – TODO(kitaev) doc
 q_info – Query-associated metadata for masking
 kv_info – Key-associated metadata for masking
 dropout – Dropout rate
 rng – RNG for dropout
Returns: A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and dots_logsumexp has shape [q_len]. The logsumexp of the attention probabilities is useful for combining multiple rounds of attention (as in LSH attention).
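Without chunking or masking, this reduces to softmax attention that also reports the log-sum-exp of each query's logits. The NumPy sketch below (the `1/sqrt(d_qk)` scaling and the name `attend_sketch` are assumptions) shows why that second output is useful: attending separately to two disjoint halves of the keys and reweighting the partial outputs by their log-sum-exp values reproduces full attention exactly. The same idea underlies combining multiple rounds of attention.

```python
import numpy as np

def attend_sketch(q, k, v):
    """q: [q_len, d_qk]; k: [kv_len, d_qk]; v: [kv_len, d_v]. Returns (output, logsumexp)."""
    dots = q @ k.T / np.sqrt(q.shape[-1])
    m = dots.max(axis=-1, keepdims=True)                   # For numerical stability.
    lse = m[:, 0] + np.log(np.exp(dots - m).sum(axis=-1))  # [q_len]
    probs = np.exp(dots - lse[:, None])                    # Softmax over keys.
    return probs @ v, lse

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(4, 8)), rng.normal(size=(6, 8)), rng.normal(size=(6, 5))
full_out, full_lse = attend_sketch(q, k, v)

# Attend to two disjoint halves of the keys, then combine using the logsumexps.
out1, lse1 = attend_sketch(q, k[:3], v[:3])
out2, lse2 = attend_sketch(q, k[3:], v[3:])
total = np.logaddexp(lse1, lse2)
combined = np.exp(lse1 - total)[:, None] * out1 + np.exp(lse2 - total)[:, None] * out2
print(np.allclose(combined, full_out), np.allclose(total, full_lse))  # True True
```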

trax.layers.research.efficient_attention.
apply_broadcasted_dropout
(vecs, dropout_rate, rng)¶ Apply dropout, broadcasted across all but the last dimension of vecs.
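"Broadcasted" means a single keep/drop mask of shape `[depth]` is sampled and shared by every position and example. A NumPy sketch, assuming inverted-dropout scaling by `1 / keep_prob` (the usual convention) and a NumPy `Generator` in place of a JAX key; the function name is hypothetical.

```python
import numpy as np

def broadcasted_dropout_sketch(vecs, dropout_rate, rng):
    if dropout_rate == 0.0:
        return vecs
    keep_prob = 1.0 - dropout_rate
    keep = rng.random(vecs.shape[-1]) < keep_prob  # One mask over the last axis only.
    return vecs * (keep / keep_prob)               # Broadcasts across all other axes.

vecs = np.ones((2, 3, 8))
out = broadcasted_dropout_sketch(vecs, 0.5, np.random.default_rng(0))
print(np.array_equal(out[0, 0], out[1, 2]))  # True: every position shares one mask.
```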

trax.layers.research.efficient_attention.
permute_via_gather
(val, permutation, inverse_permutation, axis=0)¶ Permutation helper for LSH attention.

trax.layers.research.efficient_attention.
permute_via_sort
(val, keys, inverse_keys, axis=0)¶ Permutation helper for LSH attention.
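In LSH attention, vectors are sorted by bucket so that similar vectors land in the same chunk, and results must later be returned to the original order. The NumPy sketch below shows the gather-based idea: the permutation comes from a stable `argsort` of bucket IDs, and its inverse is the `argsort` of the permutation itself. The bucket values are made up for illustration.

```python
import numpy as np

buckets = np.array([2, 0, 1, 0, 2, 1])          # Hypothetical bucket ID per position.
val = np.array([10., 11., 12., 13., 14., 15.])  # Values to permute along axis 0.

permutation = np.argsort(buckets, kind='stable')  # Positions grouped by bucket.
inverse_permutation = np.argsort(permutation)     # Undoes the permutation.

sorted_val = np.take(val, permutation, axis=0)
print(sorted_val.tolist())  # [11.0, 13.0, 12.0, 15.0, 10.0, 14.0]
restored = np.take(sorted_val, inverse_permutation, axis=0)
print(np.array_equal(restored, val))  # True
```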

class
trax.layers.research.efficient_attention.
EfficientAttentionBase
(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)¶ Bases:
trax.layers.base.Layer
Base class for efficient attention.
This is a base class that implements memory-efficient batching for both the forward and backward passes. Subclasses should override create_weights_unbatched, create_state_unbatched, forward_unbatched, and optionally incremental_forward_unbatched to define the actual attention mechanism.

__init__
(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)¶ Constructs an EfficientAttentionBase instance.
Parameters:  n_heads – Number of attention heads.
 n_in – Number of inputs to the layer (default 1).
 n_parallel_heads –
Number of attention heads to compute in parallel.
 If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
 If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
 If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
 If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
 incremental – If True, enable fast inference for self-attention types. Set this flag only for self-attention, not for encoder-decoder attention.
 predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
 predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
 use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
 use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
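The four n_parallel_heads cases listed above can be condensed into a small decision function. describe_parallelism is a hypothetical helper written only to restate the documented rules; it is not part of Trax and says nothing about how the layer actually schedules its computation.

```python
def describe_parallelism(n_heads, n_parallel_heads=None):
    """Hypothetical helper: classify an n_parallel_heads setting per the docs."""
    if n_parallel_heads is None:
        # Maximum parallelism: fastest, but uses the most memory.
        return "all examples and heads at once"
    if n_parallel_heads == 1:
        # Least memory: one head of one example at a time.
        return "one head, one example at a time"
    if n_parallel_heads % n_heads == 0:
        # Sub-batches of whole examples, all heads of each example together.
        return f"sub-batches of {n_parallel_heads // n_heads} examples"
    if 1 < n_parallel_heads < n_heads:
        # Several heads at a time, but only within a single example.
        if n_heads % n_parallel_heads != 0:
            raise ValueError("n_heads must be a multiple of n_parallel_heads")
        return f"{n_parallel_heads} heads at a time within one example"
    raise ValueError("unsupported n_parallel_heads setting")
```

For example, describe_parallelism(8, 16) returns "sub-batches of 2 examples", while describe_parallelism(8, 3) raises ValueError because 8 is not a multiple of 3.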

init_weights_and_state
(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_weights_unbatched
(input_signature, rng)

create_state_unbatched
(input_signature, rng)

forward_unbatched
(*inputs, weights, state)
Perform attention for a single batch element and head.
Subclasses should override this method.
Parameters:  *inputs – Inputs for a single example (subclasses may use different inputs)
 weights – Weights for a single attention head
 state – State for a single example & attention head pair.
Returns: A tuple (output, new_state) – output and new state for a single example and attention head.

forward
(inputs)
Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs)
Returns: A tuple (output, new_state).

has_backward
Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)
Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)
Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None
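To make the return contract concrete, here is a toy layer y = w * x (one scalar weight, a list of scalar inputs) that follows the same calling convention. toy_forward_and_or_backward and its call-counter state are hypothetical; the only thing taken from the documentation is which entries of the returned tuple are None.

```python
def toy_forward_and_or_backward(inputs, weights, state, output_grad=None,
                                compute_output=True, update_state=True):
    """Hypothetical toy layer y = w * x obeying the documented return contract."""
    w = weights
    # output is not None iff compute_output is True.
    output = [w * x for x in inputs] if compute_output else None
    # new_state is not None iff update_state is True (toy state: a call counter).
    new_state = state + 1 if update_state else None
    if output_grad is None:
        # No backward pass requested: both gradients are None.
        inputs_grad = weights_grad = None
    else:
        # dL/dx_i = w * dL/dy_i and dL/dw = sum_i x_i * dL/dy_i.
        inputs_grad = [w * g for g in output_grad]
        weights_grad = sum(x * g for x, g in zip(inputs, output_grad))
    return output, new_state, inputs_grad, weights_grad
```

A pure backward pass is the call with compute_output=False, update_state=False and an output_grad supplied: only the two gradients come back.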


class
trax.layers.research.efficient_attention.
SelfAttention
(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)
Bases: trax.layers.base.Layer
Memory-efficient self-attention (second attempt).

__init__
(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)
Construct a self-attention layer.
Parameters:  n_heads – int: Number of attention heads
 d_qk – int: Depth of query and key vectors
 d_v – int: Depth of value vectors
 share_qk – bool: Set to True to share query and key projection weights
 causal – bool: Set to True to mask out attention to future items
 masked – bool: Set to True to accept an additional mask argument that allows masking out attention to padding tokens.
 chunk_len (optional) – Number of tokens per chunk. Setting this option will enable chunked attention.
 n_chunks_before – Number of previous chunks to attend to, when using chunked attention.
 n_chunks_after – Number of subsequent chunks to attend to, when using chunked attention. Don’t use this option for causal attention, because attention to future tokens will be masked out anyway. However, note that cross-chunk attention “wraps around” in both directions, so this option is never a strict no-op.
 bias – bool: Set to True to add bias vectors when computing query/key/value
 mode – ‘train’, ‘eval’, or ‘predict’
 predict_mem_len – int: Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten. When chunking is enabled, the default is to store chunk_len * (1 + n_chunks_before) elements.
 predict_drop_len – int: Number of input elements to drop once the fast inference input cache fills up. When chunking is enabled, the default is to drop exactly chunk_len elements.
 attention_dropout – Dropout probability for attention mask.
 output_dropout – Dropout probability for the layer output.
 n_parallel_heads –
Number of attention heads to compute in parallel.
 If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
 If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
 If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
 If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
 use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
 use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
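The chunking and fast-inference options can be made concrete with a little index arithmetic. The sketch below is an illustration under stated assumptions, not Trax code: attended_chunks lists which key chunks each query chunk may attend to (modelling the documented wrap-around with modular indexing), default_predict_cache restates the documented defaults when chunking is enabled, and append_to_cache is a hypothetical cache that forgets predict_drop_len of its oldest entries whenever it is full. All three helper names are hypothetical.

```python
def attended_chunks(n_chunks, n_chunks_before, n_chunks_after):
    """Hypothetical: key chunks each query chunk may attend to, with wrap-around."""
    offsets = range(-n_chunks_before, n_chunks_after + 1)
    # Modular indexing models the documented "wraps around in both directions".
    return [[(q + off) % n_chunks for off in offsets] for q in range(n_chunks)]


def default_predict_cache(chunk_len, n_chunks_before):
    """Documented chunked defaults: (predict_mem_len, predict_drop_len)."""
    return chunk_len * (1 + n_chunks_before), chunk_len


def append_to_cache(cache, new_items, mem_len, drop_len):
    """Hypothetical fast-inference cache that forgets drop_len oldest items when full."""
    for item in new_items:
        if len(cache) >= mem_len:
            del cache[:drop_len]  # cache is full: forget the oldest entries
        cache.append(item)
    return cache
```

With four chunks and n_chunks_before=1, chunk 0 looks back at chunk 3: that is the wrap-around. In causal mode the tokens of that wrapped chunk lie in the future and would be masked out; the sketch itself applies no masking. With chunk_len=2 and n_chunks_before=1 the cache holds 4 positions, so feeding positions 0 to 4 one at a time leaves [2, 3, 4] after the first eviction.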

init_weights_and_state
(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_weights_unbatched
(input_signature, rng)

create_state_unbatched
(input_signature, rng)

forward_unbatched
(x, mask=None, *, weights, state, rng, update_state)
Perform attention for a single batch element and head.
Parameters:  x – Inputs for a single example (subclasses may use different inputs)
 mask – Mask for the inputs.
 weights – Weights for a single attention head
 state – State for a single example & attention head pair.
 rng – PRNG key for the layer (shared across all examples and heads)
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state) – output and new state for a single example and attention head.
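forward_unbatched works on one example and one head. As a rough point of reference for what such a computation involves, here is textbook scaled dot-product attention in pure Python with an optional causal mask and an optional padding mask (the masked=True case). This assumes standard softmax attention and omits the layer's projections, chunking, dropout and every other detail of the Trax implementation; attend_single_head is a hypothetical name.

```python
import math


def attend_single_head(q, k, v, causal=False, mask=None):
    """Softmax attention for one example and one head (hypothetical sketch).

    q, k: lists of n vectors of depth d_qk; v: list of n vectors of depth d_v.
    mask: optional list of n bools, True where a key position may be attended
    to (e.g. non-padding tokens). Assumes every query has an allowed key.
    """
    d_qk = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            allowed = (not causal or j <= i) and (mask is None or mask[j])
            if allowed:
                scores.append(sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_qk))
            else:
                scores.append(-math.inf)  # masked out: gets zero probability
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        out.append([sum(p * vj[c] for p, vj in zip(probs, v))
                    for c in range(len(v[0]))])
    return out
```

With causal=True the first position can only see itself, so its output is exactly its own value vector.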

forward
(inputs)
Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs)
Returns: A tuple (output, new_state).

has_backward
Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)
Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)
Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
LSHSelfAttention
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)
Bases: trax.layers.base.Layer
LSH self-attention (second implementation).

__init__
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)
Construct an LSH self-attention layer.

init_weights_and_state
(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_weights_unbatched
(input_signature, rng)

create_state_unbatched
(input_signature, rng)

hash_vectors
(vecs, rng, mask=None)

forward_unbatched
(x, mask=None, *, weights, state, rng, update_state)

forward
(inputs)
Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs)
Returns: A tuple (output, new_state).

has_backward
Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)
Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)
Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
PureLSHSelfAttention
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)
Bases: trax.layers.base.Layer
LSH self-attention without weights.

__init__
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)
Construct an LSH self-attention layer.

init_weights_and_state
(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_state_unbatched
(input_signature, rng)

hash_vectors
(vecs, rng, mask=None)

forward_unbatched
(qk, v, mask=None, *, state, rng, update_state)

forward
(inputs)
Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs)
Returns: A tuple (output, new_state).

has_backward
Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)
Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, state, rng, output_grad=None, compute_output=True, update_state=True)
Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer: a tuple (qk, v, mask)
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
MixedLSHSelfAttention
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, std_length=None, mode='train', output_dropout=0.0, attention_dropout=0.0, force_no_dropout=False, **pure_lsh_implementation_kwargs)
Bases: trax.layers.base.Layer
LSH attention mixed with standard attention used until std_length.

__init__
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, std_length=None, mode='train', output_dropout=0.0, attention_dropout=0.0, force_no_dropout=False, **pure_lsh_implementation_kwargs)
Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

init_weights_and_state
(input_signature)
Initializes weights and state for inputs with the given signature.

forward
(xs)
Executes this layer as part of a forward pass through the model.

forward_and_or_backward
(inputs, state, rng, output_grad=None, compute_output=True, update_state=True)
Performs batched forward and/or backward passes.


class
trax.layers.research.efficient_attention.
PureLSHSelfAttentionWrapper
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, output_dropout=0.0, attention_dropout=0.0, pure_lsh_implementation=None, bias=True, mode='train', num_weights=3, sparsity=16, weights_format='model', rotary_position_emb=False, **pure_lsh_implementation_kwargs)
Bases: trax.layers.combinators.Serial
Pure LSH serial.

__init__
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, output_dropout=0.0, attention_dropout=0.0, pure_lsh_implementation=None, bias=True, mode='train', num_weights=3, sparsity=16, weights_format='model', rotary_position_emb=False, **pure_lsh_implementation_kwargs)
Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)
Performs batched forward and/or backward passes.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
EncDecAttention
(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)
Bases: trax.layers.research.efficient_attention.EfficientAttentionBase
Memory-efficient encoder-decoder attention.

__init__
(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)
Constructs an EfficientAttentionBase instance.
Parameters:  n_heads – Number of attention heads.
 n_in – Number of inputs to the layer (default 1).
 n_parallel_heads –
Number of attention heads to compute in parallel.
 If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
 If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
 If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
 If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
 incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
 predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
 predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
 use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
 use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.

create_weights_unbatched
(input_signature, rng)

forward_unbatched
(q_antecedent, kv_antecedent, mask=None, *, weights, state, rng, update_state)
Perform attention for a single batch element and head.
Subclasses should override this method.
Parameters:  *inputs – Inputs for a single example (subclasses may use different inputs)
 weights – Weights for a single attention head
 state – State for a single example & attention head pair.
Returns: A tuple (output, new_state) – output and new state for a single example and attention head.
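What distinguishes encoder-decoder attention is that queries come from one sequence (q_antecedent) while keys and values come from another (kv_antecedent), possibly of a different length, and the mask applies to the key side. The sketch below assumes plain softmax attention for a single example and head with identity projections (no learned weights, so the encoder vectors serve directly as keys and values); enc_dec_attention is a hypothetical name, not the Trax implementation.

```python
import math


def enc_dec_attention(q_antecedent, kv_antecedent, mask=None):
    """Hypothetical single-head encoder-decoder attention for one example.

    q_antecedent: m decoder-side query vectors; kv_antecedent: n encoder-side
    vectors used as both keys and values (identity projections, same depth).
    mask: optional list of n bools, False for encoder padding positions.
    """
    d = len(q_antecedent[0])
    d_v = len(kv_antecedent[0])
    out = []
    for q in q_antecedent:
        scores = [
            sum(a * b for a, b in zip(q, kv)) / math.sqrt(d)
            if mask is None or mask[j] else -math.inf  # padding gets no weight
            for j, kv in enumerate(kv_antecedent)
        ]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(wi * kv[c] for wi, kv in zip(w, kv_antecedent)) / total
                    for c in range(d_v)])
    return out
```

The output has one row per decoder query (m rows) regardless of the encoder length n.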


class
trax.layers.research.efficient_attention.
LSHFF
(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)
Bases: trax.layers.base.Layer
Feed-forward block with LSH.
The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense that takes an input, projects it to size d_ff (usually larger than the input size), and then brings it back to the original size after the Relu. It is commonly used in Transformer models, where it often accounts for most of the trainable weights.
The original block can be slow in decoding because it needs to fetch a lot of weights from memory. Since Relu maps every negative pre-activation to 0, many of the d_ff intermediate values are often 0, and the LSH block aims to exploit this sparsity. In the first Dense(d_ff) layer, instead of performing a full matrix multiplication, this block only multiplies by the parts of the weight matrix that have the highest chance to give a non-zero value after the Relu. These parts are determined by computing a number of locality-sensitive hashes and masking the weights so that only those sharing at least one hash with the multiplied element are included.
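The masking idea above can be sketched in pure Python for a single input vector. The hashing scheme is an assumption: it uses random-projection (angular) LSH, where a vector's bucket is the index of the largest entry of [xR, -xR] for a random matrix R, as in the Reformer paper; the actual LSHFF hashing may differ. lsh_hash and lsh_ff are hypothetical helpers and biases are omitted. The sketch hashes the input and every row of the first weight matrix, computes only the hidden units that share a bucket with the input under at least one hash, then applies Relu and the second Dense.

```python
import random


def lsh_hash(vec, rotation):
    """Angular LSH bucket: argmax over [x.R, -x.R].

    rotation is a list of n_buckets // 2 random vectors (the columns of R),
    so the bucket id lies in range(n_buckets).
    """
    proj = [sum(a * b for a, b in zip(vec, col)) for col in rotation]
    scores = proj + [-p for p in proj]
    return max(range(len(scores)), key=scores.__getitem__)


def lsh_ff(x, w1, w2, n_buckets=4, n_hashes=2, seed=0):
    """Sketch of an LSH-masked Dense(d_ff)-Relu-Dense for one vector x.

    w1: d_ff rows, row u holding the weights that feed hidden unit u.
    w2: d_model rows, each holding the d_ff weights that feed one output.
    """
    rng = random.Random(seed)
    d_model = len(x)
    # n_hashes independent random rotations with n_buckets // 2 columns each.
    rotations = [[[rng.gauss(0.0, 1.0) for _ in range(d_model)]
                  for _ in range(n_buckets // 2)]
                 for _ in range(n_hashes)]
    x_buckets = [lsh_hash(x, rot) for rot in rotations]
    hidden = []
    for row in w1:
        # Compute this unit only if it shares a bucket with x under some hash.
        if any(lsh_hash(row, rot) == bucket
               for rot, bucket in zip(rotations, x_buckets)):
            pre_activation = sum(a * b for a, b in zip(row, x))
            hidden.append(max(pre_activation, 0.0))  # Relu
        else:
            hidden.append(0.0)  # skipped unit, treated as zero
    return [sum(a * h for a, h in zip(out_row, hidden)) for out_row in w2]
```

Treating a skipped unit as 0 is exact whenever that unit's pre-activation would have been negative anyway; the approximation error comes only from skipped units that would have been positive.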

__init__
(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)
Returns an LSH feed-forward block.

forward
(x)
Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns: Tensor of same shape and dtype as the input.

init_weights_and_state
(input_signature)
Randomly initializes this layer’s weights.

research.position_encodings
Experimenting with position encodings.

class
trax.layers.research.position_encodings.
AxialPositionalEncoding
(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')
Bases: trax.layers.base.Layer
Axial positional encoding.
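Axial encodings factor a long range of positions into a grid. Assuming that a flat position index is unraveled (row-major) into one index per axis of shape and that the per-axis embeddings, of sizes d_embs, are concatenated, the defaults cover 64 * 64 * 3 = 12288 positions with an embedding depth of 384 + 384 + 256 = 1024. The helpers unravel and axial_embedding below are hypothetical and use tiny hand-written tables instead of learned weights.

```python
def unravel(position, shape):
    """Row-major unravel of a flat position into one index per axis."""
    idx = []
    for size in reversed(shape):
        idx.append(position % size)
        position //= size
    return tuple(reversed(idx))


def axial_embedding(position, shape, tables):
    """Hypothetical: concatenate one per-axis embedding row per axis index.

    tables[a] has shape[a] rows, each of depth d_embs[a].
    """
    out = []
    for axis_index, table in zip(unravel(position, shape), tables):
        out.extend(table[axis_index])
    return out
```

Only sum(shape) embedding rows are stored (64 + 64 + 3 = 131 for the defaults) instead of one row per position, which is the memory saving that motivates the axial factorization.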

__init__
(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')
Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)
Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.research.position_encodings.
SinCosPositionalEncoding
(add_offset=2048, dropout=0.0, dropout_broadcast_dims=(2,), start_from_zero_one_in=2, mode='train')
Bases: trax.layers.base.Layer
Implements the sin-cos positional encoding.

__init__
(add_offset=2048, dropout=0.0, dropout_broadcast_dims=(2,), start_from_zero_one_in=2, mode='train')
Creates a SinCosPositionalEncoding instance.
Parameters:  add_offset – Maximum number to add to positions during training.
 dropout – Probability of not adding positional encoding to a sequence position.
 dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
 start_from_zero_one_in – how often to start from 0 during training, every one in that many times (e.g., if 4, then it’s 25% of the time).
 mode – One of ‘train’, ‘eval’, or ‘predict’.
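The training-time randomization described by add_offset and start_from_zero_one_in can be sketched as follows. The encoding formula is an assumption: it is the standard Transformer sin/cos encoding (sin on even features, cos on odd features, wavelengths growing as 10000^(2i/d)), and the exact formula, feature layout and sampling used by this layer may differ. sincos_encoding and training_positions are hypothetical helpers.

```python
import math
import random


def sincos_encoding(position, d_feature):
    """Standard sin/cos encoding: sin on even features, cos on odd features."""
    enc = []
    for i in range(d_feature):
        angle = position / (10000 ** (2 * (i // 2) / d_feature))
        enc.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return enc


def training_positions(length, add_offset, start_from_zero_one_in, rng):
    """Hypothetical sampling of (possibly shifted) positions for one sequence."""
    if rng.randrange(start_from_zero_one_in) == 0:
        start = 0  # one time in start_from_zero_one_in, start from 0
    else:
        start = rng.randrange(add_offset + 1)  # random offset in [0, add_offset]
    return [start + t for t in range(length)]
```

With start_from_zero_one_in=4, roughly 25% of sequences start at position 0, matching the example in the parameter description; the rest see shifted positions, so the model is trained on more of the encoding than its training lengths alone would cover.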

forward
(inputs)
Returns the input activations, with added positional information.

init_weights_and_state
(input_signature)
Randomly initializes the positional encoding vectors.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.


class
trax.layers.research.position_encodings.
FixedBasePositionalEncoding
(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=2, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)
Bases: trax.layers.base.Layer
Implements fixed-base positional encoding.
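One plausible reading of "fixed-base" encoding, stated here as an assumption rather than as what this layer computes, is that each position is written out in every base listed in bases using n_digits digits per base, and the layer then embeds those digits. The hypothetical helpers below perform only that digit decomposition.

```python
def base_digits(position, base, n_digits):
    """Least-significant-first digits of position in the given base."""
    digits = []
    for _ in range(n_digits):
        digits.append(position % base)
        position //= base
    return digits


def fixed_base_features(position, bases=(11, 13, 14, 15), n_digits=8):
    """Hypothetical: the digits of position in each of the fixed bases."""
    return [base_digits(position, b, n_digits) for b in bases]
```

With n_digits=8, even the smallest default base gives 11**8 = 214358881 distinct digit patterns before the representation wraps around.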

__init__
(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=2, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)
Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Classlike name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)
Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.layers.research.position_encodings.
threefry_2x32_prf
(key, x)
Apply the threefry PRF to an array of inputs.
This function is vectorized over x. For threefry_2x32: K = X = uint32[2]
Parameters:  key – uint32[2] the key of the PRF
 x – uint32[…, 2] the inputs
Returns: y – uint32[…, 2] the outputs

trax.layers.research.position_encodings.
threefry_2x32_prange
(key, lo: int = 0, hi: int = 2)
Splits a key into a stream of random keys.
This uses the little-endian counter mode.
Parameters:  key – uint32[2] the key to split
 lo – the range to start extracting from
 hi – the range to stop extracting from
Returns: keys – uint32[hi - lo, 2] the split keys
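The counter-mode splitting of threefry_2x32_prange can be illustrated without reimplementing Threefry. In the sketch below each counter i in range(lo, hi) is written as a uint32[2] pair with the low 32-bit word first (our reading of "little-endian", an assumption), and a keyed PRF maps that pair to a new key. The PRF is a stand-in built from hashlib.blake2b and is NOT Threefry, so the values will not match Trax; only the counter layout and the uint32[hi - lo, 2] output shape mirror the docstring. stand_in_prf and prange_sketch are hypothetical names.

```python
import hashlib

MASK32 = 0xFFFFFFFF


def stand_in_prf(key, x):
    """NOT Threefry: a keyed hash mapping uint32[2] -> uint32[2], for illustration."""
    key_bytes = b"".join(k.to_bytes(4, "little") for k in key)
    msg = b"".join(v.to_bytes(4, "little") for v in x)
    digest = hashlib.blake2b(msg, key=key_bytes, digest_size=8).digest()
    return [int.from_bytes(digest[:4], "little"),
            int.from_bytes(digest[4:], "little")]


def prange_sketch(key, lo=0, hi=2, prf=stand_in_prf):
    """Split key into hi - lo keys by applying prf to little-endian counters."""
    keys = []
    for i in range(lo, hi):
        counter = [i & MASK32, (i >> 32) & MASK32]  # low word first
        keys.append(prf(key, counter))
    return keys
```

Because each derived key depends only on (key, i), prange_sketch(key, 2, 4) returns exactly the last two keys of prange_sketch(key, 0, 4); this random access into the key stream is the point of counter mode.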

class
trax.layers.research.position_encodings.
InfinitePositionalEncoding
(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')
Bases: trax.layers.base.Layer
Infinite positional encoding.

__init__
(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')
Initializes the encoding.
The encoding tries to roughly evenly traverse the latent space. The recurrence time is dependent on how many bits per dimension you use.
There are two parameters to control randomization:
 randomizing the origin every 1/drift steps by letting it drift
 randomizing the origin per call
Parameters:  drift – variance in position difference per unit of difference
 affine – whether to randomize the origin every call
 transform – learnable transform after encoding (any/diag/none)
 time_bin_length – Add the features AxialPositionalEncoding learns if TimeBinCausalAttention is the first layer. bin_length should match TBCA.bin_length (TBCA = TimeBinCausalAttention). If you set transform=’diag’, this flag increases your model capacity to close to that of transform=’any’, though it will still train slower.
 mode – if ‘predict’, allow evaluating one token at a time
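The drift parameter is described as the variance in position difference per unit of difference. One way to read that, and it is only an assumption about the semantics, is a random walk: each step advances the latent position by 1 plus Gaussian noise of variance drift, so two positions k steps apart differ by k on average with variance drift * k. The hypothetical helpers below simulate that reading; they do not implement InfinitePositionalEncoding.

```python
import random
import statistics


def drifting_positions(length, drift, rng):
    """Hypothetical random-walk positions: each step adds 1 + N(0, drift)."""
    pos, out = 0.0, []
    for _ in range(length):
        out.append(pos)
        # random.gauss takes a standard deviation, so pass sqrt(drift).
        pos += 1.0 + rng.gauss(0.0, drift ** 0.5)
    return out


def difference_variance(k, drift, n_samples=5000, seed=0):
    """Empirical variance of pos[k] - pos[0] across independent walks."""
    rng = random.Random(seed)
    diffs = []
    for _ in range(n_samples):
        walk = drifting_positions(k + 1, drift, rng)
        diffs.append(walk[k] - walk[0])
    return statistics.variance(diffs)
```

Under this reading, the default drift=0.03 gives positions 50 steps apart a variance of about 50 * 0.03 = 1.5, i.e. a spread of roughly 1.2 latent units.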

forward
(inputs)
Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.research.position_encodings.
TimeBinPositionalEncoding
(time_bin_length, mode='train')¶ Bases:
trax.layers.base.Layer
Just the engineered features from InfinitePositionalEncoding.

num_features
= 3¶

__init__
(time_bin_length, mode='train')¶ Initializes the encoding.
Parameters:  time_bin_length – TimeBinCausalAttention.bin_length of the first layer.
 mode – if ‘predict’, allow evaluating one token at a time

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

trax.models¶
atari_cnn¶
Simple net for playing Atari games using PPO.

trax.models.atari_cnn.
AtariCnn
(n_frames=4, hidden_sizes=(32, 32), output_size=128, mode='train')¶ An Atari CNN.

trax.models.atari_cnn.
AtariCnnBody
(n_frames=4, hidden_sizes=(32, 64, 64), output_size=512, mode='train', kernel_initializer=None, padding='VALID')¶ An Atari CNN.

trax.models.atari_cnn.
FrameStackMLP
(n_frames=4, hidden_sizes=(64, ), output_size=64, mode='train')¶ MLP operating on a fixed number of last frames.
mlp¶
mlp – functions that assemble “multilayer perceptron” networks.

trax.models.mlp.
MLP
(layer_widths=(128, 64), activation_fn=<function Relu>, out_activation=False, flatten=True, mode='train')¶ A “multilayer perceptron” (MLP) network.
This is a classic fully connected feedforward network, with one or more layers and a (nonlinear) activation function between each layer. For historical reasons, such networks are often called multilayer perceptrons; but they are more accurately described as multilayer networks, where each layer + activation function is a perceptron-like unit (see, e.g., [https://en.wikipedia.org/wiki/Multilayer_perceptron#Terminology]).
Parameters:  layer_widths – Tuple of ints telling the number of layers and the width of each layer. For example, setting layer_widths=(128, 64, 32) would yield 3 layers with successive widths of 128, 64, and 32.
 activation_fn – Type of activation function between pairs of fully connected layers; must be an activation-type subclass of Layer.
 out_activation – If True, include a copy of the activation function as the last layer in the network.
 flatten – If True, insert a layer at the head of the network to flatten the input tensor into a matrix of shape (batch_size, -1).
 mode – Ignored.
Returns: An assembled MLP network with the specified layers. This network can either be initialized and trained as a full model, or can be used as a building block in a larger network.
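A minimal NumPy sketch of the forward pass this docstring describes, assuming plain Dense layers (x @ w + b) and ReLU. The helper names init_mlp and mlp_forward are hypothetical, and the random initialization is an illustrative choice rather than Trax's initializer. The sketch is meant to show where flatten and out_activation act.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def init_mlp(input_dim, layer_widths=(128, 64), seed=0):
    """Random (weight, bias) pairs, one per Dense layer in layer_widths."""
    rng = np.random.default_rng(seed)
    params, d_in = [], input_dim
    for d_out in layer_widths:
        w = rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out))
        params.append((w, np.zeros(d_out)))
        d_in = d_out
    return params


def mlp_forward(params, x, activation_fn=relu, out_activation=False, flatten=True):
    if flatten:
        # Collapse all non-batch axes: shape becomes (batch_size, -1).
        x = x.reshape(x.shape[0], -1)
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        is_last = i == len(params) - 1
        # The activation sits between Dense layers; after the last one it is optional.
        if not is_last or out_activation:
            x = activation_fn(x)
    return x
```

For example, a (4, 2, 3) input is flattened to (4, 6) before the first Dense layer, and layer_widths=(8, 5) yields a (4, 5) output.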
neural_gpu¶
Implementation of the improved Neural GPU (NGPU).

trax.models.neural_gpu.
SaturationCost
(x, limit=0.9)¶

trax.models.neural_gpu.
DiagonalGate
()¶ Splits channels into 3 parts. Shifts the 1st and 3rd sections to the left/right.
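Here is one plausible reading of DiagonalGate, sketched in NumPy. It rests on several assumptions that the docstring does not state. Channels are on the last axis and their count divides by 3. The shift is one step along axis 1. The 1st section moves left and the 3rd moves right, and vacated positions are zero-filled. diagonal_gate is a hypothetical name.

```python
import numpy as np


def diagonal_gate(x, axis=1):
    """Splits the last (channel) axis into 3 equal parts; shifts part 1 left and
    part 3 right by one step along `axis`, zero-filling vacated slots."""
    first, middle, last = np.split(x, 3, axis=-1)
    left = np.zeros_like(first)
    right = np.zeros_like(last)
    n = x.shape[axis]
    dst = [slice(None)] * x.ndim
    src = [slice(None)] * x.ndim
    # Shift left: position i takes the value from position i + 1.
    dst[axis], src[axis] = slice(0, n - 1), slice(1, n)
    left[tuple(dst)] = first[tuple(src)]
    # Shift right: position i takes the value from position i - 1.
    dst[axis], src[axis] = slice(1, n), slice(0, n - 1)
    right[tuple(dst)] = last[tuple(src)]
    return np.concatenate([left, middle, right], axis=-1)
```

The middle third passes through unchanged, so information can move one step left, stay in place, or move one step right per application.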

trax.models.neural_gpu.
ConvDiagonalGRU
(units, kernel_size=(3, 3))¶ Builds a convolutional GRU with diagonal gating, as in ImprovedNGPU.

trax.models.neural_gpu.
NeuralGPU
(d_feature=96, steps=16, vocab_size=2, mode='train')¶ Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.
Parameters:  d_feature – Number of memory channels (dimensionality of feature embedding).
 steps – Number of depthwise recurrence steps.
 vocab_size – Vocabulary size.
 mode – Whether we are training or evaluating or doing inference.
Returns: A NeuralGPU Stax model.
resnet¶
ResNet.

trax.models.resnet.
ConvBlock
(kernel_size, filters, strides, norm, non_linearity, mode='train')¶ ResNet convolutional striding block.

trax.models.resnet.
IdentityBlock
(kernel_size, filters, norm, non_linearity, mode='train')¶ ResNet identical size block.

trax.models.resnet.
Resnet50
(d_hidden=64, n_output_classes=1001, mode='train', norm=<sphinx.ext.autodoc.importer._MockObject object>, non_linearity=<function Relu>)¶ ResNet.
Parameters:  d_hidden – Dimensionality of the first hidden layer (multiplied later).
 n_output_classes – Number of distinct output classes.
 mode – Whether we are training or evaluating or doing inference.
 norm – Layer used for normalization, Ex: BatchNorm or FilterResponseNorm.
 non_linearity – Layer used as a nonlinearity, Ex: If norm is BatchNorm then this is a Relu, otherwise for FilterResponseNorm this should be ThresholdedLinearUnit.
Returns: The list of layers comprising a ResNet model with the given parameters.

trax.models.resnet.
WideResnetBlock
(channels, strides=(1, 1), bn_momentum=0.9, mode='train')¶ WideResnet convolutional block.

trax.models.resnet.
WideResnetGroup
(n, channels, strides=(1, 1), bn_momentum=0.9, mode='train')¶

trax.models.resnet.
WideResnet
(n_blocks=3, widen_factor=1, n_output_classes=10, bn_momentum=0.9, mode='train')¶ WideResnet from https://arxiv.org/pdf/1605.07146.pdf.
Parameters:  n_blocks – int, number of blocks in a group. total layers = 6n + 4.
 widen_factor – int, widening factor of each group. k=1 is vanilla resnet.
 n_output_classes – int, number of distinct output classes.
 bn_momentum – float, momentum in BatchNorm.
 mode – Whether we are training or evaluating or doing inference.
Returns: The list of layers comprising a WideResnet model with the given parameters.
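The code below works through the depth formula above. With the default n_blocks=3, the network has 6 × 3 + 4 = 22 layers. The WRN-d-k labels used in the Wide ResNet paper (for example WRN-28-10) correspond to n_blocks = (d - 4) / 6 and widen_factor = k. The helper names are hypothetical.

```python
def wide_resnet_depth(n_blocks: int) -> int:
    """Total layers for a WideResnet with n_blocks blocks per group (6n + 4)."""
    return 6 * n_blocks + 4


def n_blocks_for_depth(depth: int) -> int:
    """Inverts 6n + 4; raises if depth is not of that form for a positive n."""
    if depth < 10 or (depth - 4) % 6 != 0:
        raise ValueError(f"depth {depth} is not 6n + 4 for a positive n")
    return (depth - 4) // 6


def wide_resnet_name(n_blocks: int, widen_factor: int) -> str:
    """'WRN-d-k' label: d = total depth, k = widen factor."""
    return f"WRN-{wide_resnet_depth(n_blocks)}-{widen_factor}"
```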
rl¶
Policy networks.

trax.models.rl.
Policy
(policy_distribution, body=None, normalizer=None, head_init_range=None, batch_axes=None, mode='train')¶ Attaches a policy head to a model body.

trax.models.rl.
Value
(body=None, normalizer=None, inject_actions=False, inject_actions_n_layers=1, inject_actions_dim=64, batch_axes=None, mode='train', is_discrete=False, vocab_size=2, multiplicative_action_injection=False, head_init_range=None)¶ Attaches a value head to a model body.

trax.models.rl.
PolicyAndValue
(policy_distribution, body=None, policy_top=<function Policy>, value_top=<function Value>, normalizer=None, joint=True, mode='train')¶ Attaches policy and value heads to a model body.

trax.models.rl.
Quality
(body=None, normalizer=None, batch_axes=None, mode='train', n_actions=2, head_init_range=None)¶ The network takes as input an observation and outputs values of actions.
rnn¶
RNNs (recurrent neural networks).

trax.models.rnn.
RNNLM
(vocab_size, d_model=512, n_layers=2, rnn_cell=<sphinx.ext.autodoc.importer._MockObject object>, rnn_cell_d_state_multiplier=2, dropout=0.1, mode='train')¶ Returns an RNN language model.
This model performs autoregressive language modeling:
 input: rank 2 tensor representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). The tensor elements are integers in range(vocab_size), and 0 values mark padding positions.
 output: rank 3 tensor representing a batch of log-probability distributions for each sequence position over possible token IDs; shape is (batch_size, sequence_length, vocab_size).
Parameters:  vocab_size – Input vocabulary size – each element of the input tensor should be an integer in range(vocab_size). These integers typically represent token IDs from a vocabulary-based tokenizer.
 d_model – Embedding depth throughout the model.
 n_layers – Number of RNN layers.
 rnn_cell – Type of RNN cell; must be a subclass of Layer.
 rnn_cell_d_state_multiplier – Multiplier for feature depth of RNN cell state.
 dropout – Stochastic rate (probability) for dropping an activation value when applying dropout.
 mode – If ‘predict’, use fast inference; if ‘train’, apply dropout.
Returns: An RNN language model as a layer that maps from a tensor of tokens to activations over a vocab set.
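You can exercise the input/output contract shared by RNNLM and GRULM without Trax. This NumPy sketch uses hypothetical helper names. It turns raw scores into log-probabilities over vocab_size and averages the negative log-likelihood of the target IDs. It skips positions whose token ID is 0, which the docstrings reserve for padding. Trax's own loss layers may weight or normalize differently.

```python
import numpy as np


def log_softmax(logits, axis=-1):
    """Numerically stable log-softmax over the vocabulary axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def masked_nll(log_probs, targets):
    """Mean negative log-likelihood over non-padding positions.

    log_probs: (batch_size, sequence_length, vocab_size)
    targets:   (batch_size, sequence_length) ints; 0 marks padding.
    """
    picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    mask = (targets != 0).astype(log_probs.dtype)
    return -(picked * mask).sum() / mask.sum()
```

With uniform scores over a vocabulary of 4, every real token has log-probability -log 4, so the loss is log 4 no matter how much padding follows.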

trax.models.rnn.
GRULM
(vocab_size=256, d_model=512, n_layers=1, mode='train')¶ Returns a GRU (gated recurrent unit) language model.
This model performs autoregressive language modeling:
 input: rank 2 tensor representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). The tensor elements are integers in range(vocab_size), and 0 values mark padding positions.
 output: rank 3 tensor representing a batch of log-probability distributions for each sequence position over possible token IDs; shape is (batch_size, sequence_length, vocab_size).
Parameters:  vocab_size – Input vocabulary size – each element of the input tensor should be an integer in range(vocab_size). These integers typically represent token IDs from a vocabulary-based tokenizer.
 d_model – Embedding depth throughout the model.
 n_layers – Number of GRU layers.
 mode – If ‘predict’, use fast inference (and omit the right shift).
Returns: A GRU language model as a layer that maps from a tensor of tokens to activations over a vocab set.
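The "right shift" mentioned for mode='predict' refers to the usual autoregressive training trick. During training, the model reads the target sequence delayed by one position, so the output at position t is trained to predict token t from the tokens before it. The NumPy sketch below shows that shift. It assumes the pad/start ID is 0, matching the padding convention above, and shift_right is a hypothetical helper, not a Trax call. In predict mode, tokens are fed one at a time, which is presumably why the shift is omitted there.

```python
import numpy as np


def shift_right(tokens, pad_id=0):
    """Prepends one pad token to each row and drops the last token.

    tokens: (batch_size, sequence_length) ints.
    """
    padded = np.pad(tokens, ((0, 0), (1, 0)), constant_values=pad_id)
    return padded[:, :-1]
```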

trax.models.rnn.
LSTMSeq2SeqAttn
(input_vocab_size=256, target_vocab_size=256, d_model=512, n_encoder_layers=2, n_decoder_layers=2, n_attention_heads=1, attention_dropout=0.0, mode='train')¶ Returns an LSTM sequence-to-sequence model with attention.
This model is an encoder-decoder that performs tokenized string-to-string (“source”-to-“target”) transduction:
inputs (2):
 source: rank 2 tensor representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). The tensor elements are integers in range(input_vocab_size), and 0 values mark padding positions.
 target: rank 2 tensor representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). The tensor elements are integers in range(target_vocab_size), and 0 values mark padding positions.
output: rank 3 tensor representing a batch of log-probability distributions for each sequence position over possible token IDs; shape is (batch_size, sequence_length, target_vocab_size).
An example use would be to translate (tokenized) sentences from English to German.
The model works as follows:
 Input encoder runs on the input tokens and creates activations that are used as both keys and values in attention.
 Pre-attention decoder runs on the targets and creates activations that are used as queries in attention.
 Attention runs on the queries, keys and values masking out input padding.
 Decoder runs on the result, followed by a cross-entropy loss.
Parameters:  input_vocab_size – Input vocabulary size – each element of the input tensor should be an integer in range(input_vocab_size). These integers typically represent token IDs from a vocabulary-based tokenizer.
 target_vocab_size – Target vocabulary size.
 d_model – Final dimension of tensors at most points in the model, including the initial embedding output.
 n_encoder_layers – Number of LSTM layers in the encoder.
 n_decoder_layers – Number of LSTM layers in the decoder after attention.
 n_attention_heads – Number of attention heads.
 attention_dropout – Stochastic rate (probability) for dropping an activation value when applying dropout within an attention block.
 mode – If ‘predict’, use fast inference. If ‘train’, each attention block will include dropout; else, it will pass all values through unaltered.
Returns: An LSTM sequence-to-sequence model as a layer that maps from a source-target tokenized text pair to activations over a vocab set.
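The four steps listed above meet in the attention step. The NumPy sketch below implements single-head scaled dot-product attention with the padding mask. As the docstring says, queries come from the pre-attention decoder and the encoder activations serve as both keys and values. The 1/sqrt(d) scaling and the -1e9 mask value are common choices assumed here, not read from Trax's code, and masked_attention is a hypothetical name.

```python
import numpy as np


def masked_attention(queries, keys, values, source_tokens):
    """Dot-product attention that ignores padded source positions.

    queries:       (batch, target_len, d)  from the pre-attention decoder
    keys, values:  (batch, source_len, d)  from the input encoder
    source_tokens: (batch, source_len)     ints; 0 marks padding
    """
    d = queries.shape[-1]
    scores = queries @ keys.transpose(0, 2, 1) / np.sqrt(d)
    mask = (source_tokens != 0)[:, None, :]  # broadcast over all queries
    scores = np.where(mask, scores, -1e9)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values, weights
```

Padded source positions receive zero attention weight, so padding length does not change the result.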
transformer¶
Transformer models: encoder, decoder, language model, and encoder-decoder.
The “Transformer” name and network architecture were introduced in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762).

trax.models.transformer.
TransformerEncoder
(vocab_size, n_classes=10, d_model=512, d_ff=2048, n_layers=6, n_heads=8, max_len=2048, dropout=0.1, dropout_shared_axes=None, mode='train', ff_activation=<function Relu>)¶ Returns a Transformer encoder suitable for N-way classification.
This model maps tokenized text to N-way (n_classes) activations:
 input: Array representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length), where sequence_length <= max_len. Array elements are integers in range(vocab_size), and 0 values mark padding positions.
 output: Array representing a batch of raw (non-normalized) activations over n_classes categories; shape is (batch_size, n_classes).
Parameters:  vocab_size – Input vocabulary size – each element of the input array should be an integer in range(vocab_size). These integers typically represent token IDs from a vocabulary-based tokenizer.
 n_classes – Last/innermost dimension of output arrays, suitable for N-way classification.
 d_model – Last/innermost dimension of activation arrays at most points in the model, including the initial embedding output.
 d_ff – Last/innermost dimension of special (typically wider) Dense layer in the feedforward part of each encoder block.
 n_layers – Number of encoder blocks. Each block includes attention, dropout, residual, layer-norm, feedforward (Dense), and activation layers.
 n_heads – Number of attention heads.
 max_len – Maximum symbol length for positional encoding.
 dropout – Stochastic rate (probability) for dropping an activation value when applying dropout within encoder blocks. The same rate is also used for attention dropout in encoder blocks.
 dropout_shared_axes – Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (dropout_shared_axes=(0,1)) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions.
 mode – If 'train', each encoder block will include dropout; else, it will pass all values through unaltered.
 ff_activation – Type of activation function at the end of each encoder block; must be an activation-type subclass of Layer.
Returns: A Transformer model that maps strings (conveyed by token IDs) to raw (non-normalized) activations over a range of output classes.
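The dropout_shared_axes option can be made concrete. A shared axis gets size 1 in the mask, and broadcasting reuses the same mask entry along it. With dropout_shared_axes=(0,1) on a (batch, length, d_model) array, one mask of d_model values serves every batch element and position. This NumPy sketch uses inverted dropout, which scales kept values by 1 / (1 - rate). That scaling convention is an assumption, and dropout here is a hypothetical helper.

```python
import numpy as np


def dropout(x, rate, shared_axes=None, rng=None):
    """Inverted dropout; every axis in shared_axes reuses one mask entry."""
    rng = np.random.default_rng(0) if rng is None else rng
    mask_shape = list(x.shape)
    for axis in shared_axes or ():
        mask_shape[axis] = 1  # one mask value, broadcast along this axis
    keep = rng.random(mask_shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0), keep
```

The mask here has only d_model entries rather than batch × length × d_model, which is where the memory saving comes from.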

trax.models.transformer.
TransformerDecoder
(vocab_size=None, d_model=512, d_ff=2048, n_layers=6, n_heads=8, max_len=2048, dropout=0.1, dropout_shared_axes=None, mode='train', ff_activation=<function Relu>)¶ Returns a Transformer decoder.
This model maps sequential inputs to sequential outputs:
 input if vocab_size is specified: array representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). The tensor elements are integers in range(vocab_size), and 0 values mark padding positions.
 input if vocab_size is None: 3D array representing a batch of sequences of activation vectors; shape is (batch_size, sequence_length, d_model).
 output: 3D array with shape (batch_size, sequence_length, d_model).
The model uses causal attention and does not shift the input to the right. Thus, the output for position t is based on inputs up to and including position t.
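The causal behaviour described above comes from masking attention scores above the diagonal. The NumPy sketch below is single-head and omits learned projections, which is a simplification: queries, keys, and values are all the input itself. causal_self_attention is a hypothetical name. Changing the input at a later position leaves earlier outputs untouched, and position 0 can only attend to itself.

```python
import numpy as np


def causal_self_attention(x):
    """Single-head self-attention (queries = keys = values = x) with a causal mask.

    x: (batch, seq_len, d)
    """
    seq_len, d = x.shape[-2], x.shape[-1]
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(d)
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # allow j <= t
    scores = np.where(causal, scores, -1e9)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x
```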
Parameters:  vocab_size – If specified, gives the input vocabulary size – each element of the input tensor should be an integer in range(vocab_size). If None, indicates that the model expects as input sequences of floating point vectors, each with d_model components.
 d_model – Last/innermost dimension of activation arrays at most points in the model, including the initial embedding output.
 d_ff – Last/innermost dimension of special (typically wider) Dense layer in the feedforward part of each decoder block.
 n_layers – Number of decoder blocks. Each block includes attention, dropout, residual, layer-norm, feedforward (Dense), and activation layers.
 n_heads – Number of attention heads.
 max_len – Maximum symbol length for positional encoding.
 dropout – Stochastic rate (probability) for dropping an activation value when applying dropout within decoder blocks. The same rate is also used for attention dropout in decoder blocks.
 dropout_shared_axes – Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (dropout_shared_axes=(0,1)) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions.
 mode – If 'train', each decoder block will include dropout; else, it will pass all values through unaltered.
 ff_activation – Type of activation function at the end of each decoder block; must be an activation-type subclass of Layer.
Returns: If vocab_size is defined: a Transformer model that maps strings (conveyed by token IDs) to sequences of activation vectors.
If vocab_size is None: a Transformer model that maps sequences of activation vectors to sequences of activation vectors.

trax.models.transformer.
TransformerLM
(vocab_size, d_model=512, d_ff=2048, n_layers=6, n_heads=8, max_len=2048, dropout=0.1, dropout_shared_axes=None, mode='train', ff_activation=<function Relu>)¶ Returns a Transformer language model.
This model performs autoregressive language modeling:
 input: Array representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). Array elements are integers in range(vocab_size), and 0 values mark padding positions.
 output: 3D array of raw activations with last/innermost dimension of vocab_size, suitable for decoding into a batch of token strings; shape is (batch_size, sequence_length, vocab_size).
This model uses only the decoder part of the overall Transformer.
Parameters:  vocab_size – Input vocabulary size – each element of the input array should be an integer in range(vocab_size). These integers typically represent token IDs from a vocabulary-based tokenizer.
 d_model – Last/innermost dimension of activation arrays at most points in the model, including the initial embedding output.
 d_ff – Last/innermost dimension of special (typically wider) Dense layer in the feedforward part of each decoder block.
 n_layers – Number of decoder blocks. Each block includes attention, dropout, residual, layer-norm, feedforward (Dense), and activation layers.
 n_heads – Number of attention heads.
 max_len – Maximum symbol length for positional encoding.
 dropout – Stochastic rate (probability) for dropping an activation value when applying dropout within decoder blocks. The same rate is also used for attention dropout in decoder blocks.
 dropout_shared_axes – Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (dropout_shared_axes=(0,1)) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions.
 mode – If 'predict', use fast inference. If 'train', each decoder block will include dropout; else, it will pass all values through unaltered.
 ff_activation – Type of activation function at the end of each decoder block; must be an activation-type subclass of Layer.
Returns: A Transformer language model that maps strings (represented as token ID sequences) to sequences of raw (non-normalized) activation vectors; each vector in the sequence can be mapped (e.g., by argmax) to a token ID.
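If you map activations to token IDs by argmax and repeat this one token at a time, you get greedy decoding. The sketch below assumes an lm_fn that returns (1, length, vocab_size) activations whose last position scores the next token. toy_lm is a hypothetical stand-in, so the loop runs without a trained model. The loop re-runs the whole prefix at each step; in the tutorial itself, decoding is done with trax.supervised.decoding.autoregressive_sample rather than a hand-written loop like this.

```python
import numpy as np


def toy_lm(tokens, vocab_size=10):
    """Hypothetical stand-in LM: at every position, (token + 1) % vocab scores highest."""
    nxt = (tokens + 1) % vocab_size
    return np.eye(vocab_size)[nxt]  # (batch, length, vocab_size)


def greedy_decode(lm_fn, prefix, n_steps):
    """Appends argmax of the last position's activations, n_steps times."""
    tokens = list(prefix)
    for _ in range(n_steps):
        activations = lm_fn(np.array([tokens]))  # (1, length, vocab_size)
        tokens.append(int(np.argmax(activations[0, -1])))
    return tokens
```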

trax.models.transformer.
Transformer
(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, max_len=2048, dropout=0.1, dropout_shared_axes=None, mode='train', ff_activation=<function Relu>)¶ Returns a full Transformer model.
This model is an encoder-decoder that performs tokenized string-to-string (“source”-to-“target”) transduction:
inputs (2):
 source: Array representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length), where sequence_length <= max_len. Array elements are integers in range(input_vocab_size), and 0 values mark padding positions.
 target: Array representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length), where sequence_length <= max_len. Array elements are integers in range(output_vocab_size), and 0 values mark padding positions.
output: 3D array of raw activations with last/innermost dimension of output_vocab_size, suitable for decoding into a batch of token strings; shape is (batch_size, sequence_length, vocab_size).
An example use would be to translate (tokenized) sentences from English to German.
Parameters:  input_vocab_size – Input vocabulary size – each element of the input tensor should be an integer in range(input_vocab_size). These integers typically represent token IDs from a vocabulary-based tokenizer.
 output_vocab_size – If specified, gives the vocabulary size for the targets; if None, then input and target integers (token IDs) are assumed to come from the same vocabulary.
 d_model – Last/innermost dimension of activation arrays at most points in the model, including the initial embedding output.
 d_ff – Last/innermost dimension of special (typically wider) Dense layer in the feedforward part of each encoder/decoder block.
 n_encoder_layers – Number of encoder blocks.
 n_decoder_layers – Number of decoder blocks.
 n_heads – Number of attention heads.
 max_len – Maximum symbol length for positional encoding.
 dropout – Stochastic rate (probability) for dropping an activation value when applying dropout within encoder/decoder blocks. The same rate is also used for attention dropout in encoder/decoder blocks.
 dropout_shared_axes – Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (dropout_shared_axes=(0,1)) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions.
 mode – If 'predict', use fast inference. If 'train', each encoder/decoder block will include dropout; else, it will pass all values through unaltered.
 ff_activation – Type of activation function at the end of each encoder/decoder block; must be an activation-type subclass of Layer.
Returns: A Transformer model as a layer that maps from a source-target tokenized text pair to activations over a vocab set.
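max_len bounds the positional encoding used by the encoder and decoder. Suppose the sinusoidal scheme from Attention Is All You Need (linked at the top of this module) is used. Then the table has one row per position up to max_len, which is why sequence_length <= max_len. This is a NumPy sketch of the paper's formula, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). It is not necessarily the exact layer Trax uses. d_model is assumed to be even, and sinusoidal_positions is a hypothetical name.

```python
import numpy as np


def sinusoidal_positions(max_len, d_model):
    """(max_len, d_model) table: sin on even channels, cos on odd channels.

    Assumes d_model is even.
    """
    pos = np.arange(max_len)[:, None]
    two_i = np.arange(0, d_model, 2)[None, :]  # the 2i in the formula
    angles = pos / np.power(10000.0, two_i / d_model)
    table = np.zeros((max_len, d_model))
    table[:, 0::2] = np.sin(angles)
    table[:, 1::2] = np.cos(angles)
    return table
```

A batch of length L would add table[:L] to its embeddings, so any L greater than max_len has no rows to draw from.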
reformer.reformer¶
Reformer Models.

trax.models.reformer.reformer.
DecoderBlock
(d_model, d_ff, d_attention_key, d_attention_value, n_heads, attention_type, dropout, ff_activation, ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity, attention_chunk_size, n_attention_layers=1, n_feedforward_layers=1, center_layernorm=True, use_bfloat16=False, mode='train')¶ Reversible transformer decoder layer.
Parameters:  d_model – int: depth of embedding
 d_ff – int: depth of feedforward layer
 d_attention_key – int: depth of key vector for each attention head
 d_attention_value – int: depth of value vector for each attention head
 n_heads – int: number of attention heads
 attention_type – subclass of tl.BaseCausalAttention: attention class to use
 dropout – float: dropout rate (how much to drop out)
 ff_activation – the nonlinearity in feedforward layer
 ff_dropout – the dropout rate in feedforward layer
 ff_use_sru – int; if > 0, we use this many SRU layers instead of feedforward
 ff_chunk_size – int; if > 0, chunk feedforward into this-sized chunks
 ff_sparsity – int, if > 0 use sparse feedforward block with this sparsity
 attention_chunk_size – int, if > 0 run attention chunked at this size
 n_attention_layers – how many residual causal attention layers should we have before the feedforward block (default: 1, the standard block)
 n_feedforward_layers – how many FFNN layers should we have (default 1).
 center_layernorm – whether to use centering in LayerNorm (default) or to skip it, which is known as RMS normalization.
 use_bfloat16 – whether to use bfloat16 for weights (default: False).
 mode – str: ‘train’ or ‘eval’
Returns: the layer.
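The "reversible" in DecoderBlock refers to the RevNet-style residual used by Reformer. The features are split into two halves (x1, x2), and the block computes y1 = x1 + F(x2) and y2 = x2 + G(y1), where F is the attention sub-block and G is the feed-forward sub-block. Because the inputs can be recomputed exactly from the outputs, activations need not be stored for backpropagation. The NumPy sketch below uses arbitrary stand-in functions for F and G, and the helper names are hypothetical.

```python
import numpy as np


def reversible_forward(x1, x2, f, g):
    """RevNet-style residual: f plays attention, g plays feed-forward."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def reversible_inverse(y1, y2, f, g):
    """Recovers (x1, x2) from (y1, y2) by undoing the two residual adds in reverse."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2
```

Neither f nor g needs to be invertible; only the additive coupling does.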

trax.models.reformer.reformer.
ReformerLM
(vocab_size, d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64, n_layers=6, n_heads=8, dropout=0.1, max_len=2048, attention_type=<sphinx.ext.autodoc.importer._MockObject object>, pos_type=None, pos_axial_shape=(), pos_d_axial_embs=None, pos_start_from_zero_prob=1.0, pos_max_offset_to_add=0, ff_activation=<function FastGelu>, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0, loss_sparsity_type='mult', loss_sparsity=0, loss_d_lowrank=0, loss_sparsity_prob=None, attention_chunk_size=0, mode='train')¶ Reversible transformer language model (only uses a decoder, no encoder).
Parameters:  vocab_size – int: vocab size
 d_model – int: depth of each half of the two-part features
 d_ff – int: depth of feedforward layer
 d_attention_key – int: depth of key vector for each attention head
 d_attention_value – int: depth of value vector for each attention head
 n_layers – int: number of decoder layers
 n_heads – int: number of attention heads
 dropout – float: dropout rate (how much to drop out)
 max_len – int: maximum symbol length for positional encoding
 attention_type – class: attention class to use, such as SelfAttention.
 pos_type – string, the type of positional embeddings to use.
 pos_axial_shape – tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled.
 pos_d_axial_embs – tuple of ints: depth of position embedding for each axis. Tuple length must match pos_axial_shape, and values must sum to d_model.
 pos_start_from_zero_prob – how often to start from position 0 during training (if 1.0, we always start from position 0; if less, we randomize).
 pos_max_offset_to_add – maximum offset to add to positions during training when randomizing; this offset plus input length must still be less than max_len for all training examples.
 ff_activation – the nonlinearity in feedforward layer
 ff_use_sru – int; if > 0, we use this many SRU layers instead of feedforward
 ff_chunk_size – int; if > 0, chunk feedforward into this-sized chunks
 ff_sparsity – int, if > 0 use sparse feedforward block with this sparsity
 loss_sparsity_type – str, type of sparsity to use in the loss layer. See SparseDenseWithOptions for options. None if no sparsity should be used.
 loss_sparsity – int, the sparsity for loss layer (if used)
 loss_d_lowrank – int, the dimensions for intermediate layer (if used)
 loss_sparsity_prob – float, the probability for sparse version of loss to be used. If None, only sparse version is used.
 attention_chunk_size – int, if > 0 run attention chunked at this size
 mode – str: ‘train’, ‘eval’, or ‘predict’
Returns: the layer.
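pos_axial_shape and pos_d_axial_embs describe an axial position encoding. The sketch below assumes one table per axis, whose entries are broadcast over the other axes and concatenated along depth. Under that assumption, position p in a (64, 32) grid gets row p // 32 from the first table and row p % 32 from the second. Random tables stand in for learned weights, and axial_positional_embeddings is a hypothetical name. For shape (64, 32) with depths (256, 256), this stores 64 × 256 + 32 × 256 values instead of 2048 × 512 for a full table.

```python
import numpy as np


def axial_positional_embeddings(pos_axial_shape, pos_d_axial_embs, d_model, seed=0):
    """Returns (prod(pos_axial_shape), d_model) embeddings built from per-axis tables."""
    if len(pos_axial_shape) != len(pos_d_axial_embs):
        raise ValueError("pos_d_axial_embs must have one entry per axis")
    if sum(pos_d_axial_embs) != d_model:
        raise ValueError("pos_d_axial_embs must sum to d_model")
    rng = np.random.default_rng(seed)
    # One table per axis; random values stand in for learned weights.
    tables = [rng.normal(size=(n, d)) for n, d in zip(pos_axial_shape, pos_d_axial_embs)]
    parts = []
    for axis, table in enumerate(tables):
        shape = [1] * len(tables) + [table.shape[1]]
        shape[axis] = table.shape[0]
        target = list(pos_axial_shape) + [table.shape[1]]
        parts.append(np.broadcast_to(table.reshape(shape), target))
    grid = np.concatenate(parts, axis=-1)  # pos_axial_shape + (d_model,)
    return grid.reshape(-1, d_model)       # row-major: p = r * 32 + c for (64, 32)
```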

trax.models.reformer.reformer.
ReformerShortenLM
(vocab_size, shorten_factor=1, d_embedding=256, d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64, n_layers=6, n_heads=8, dropout=0.1, max_len=2048, attention_type=<sphinx.ext.autodoc.importer._MockObject object>, pos_type=None, pos_axial_shape=(), pos_d_axial_embs=None, ff_activation=<function FastGelu>, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0, attention_chunk_size=0, mode='train')¶ Reversible transformer language model with shortening.
When shorten_factor is F and processing an input of shape [batch, length], we embed the (shifted-right) input and then group every F elements (along length) into a single vector, so that we process a tensor of shape
[batch, length // F, d_model]
almost until the end. At the end, the tensor is unshortened and an SRU is applied. This reduces the length processed inside the main model body, effectively making the model faster but possibly slightly less accurate.
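The grouping step described above can be sketched as a reshape. Every F consecutive embedding vectors are concatenated into one vector of size F × d_embedding, which a Dense layer could then project to d_model. This NumPy sketch assumes length is divisible by F, since how Trax handles a remainder is not stated here. shorten is a hypothetical helper name.

```python
import numpy as np


def shorten(x, shorten_factor):
    """Groups every F consecutive positions into one vector.

    x: (batch, length, d) -> (batch, length // F, F * d). Assumes length % F == 0.
    """
    batch, length, d = x.shape
    f = shorten_factor
    if length % f != 0:
        raise ValueError("this sketch assumes length is divisible by shorten_factor")
    return x.reshape(batch, length // f, f * d)
```

Applying a (F × d_embedding, d_model) projection to the result would give the [batch, length // F, d_model] tensor described above.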
Parameters:  vocab_size – int: vocab size
 shorten_factor – by how much to shorten, see above
 d_embedding – the depth of the embedding layer and final logits
 d_model – int: depth of each half of the two-part features
 d_ff – int: depth of feedforward layer
 d_attention_key – int: depth of key vector for each attention head
 d_attention_value – int: depth of value vector for each attention head
 n_layers – int: number of decoder layers
 n_heads – int: number of attention heads
 dropout – float: dropout rate (how much to drop out)
 max_len – int: maximum symbol length for positional encoding
 attention_type – class: attention class to use, such as SelfAttention.
 pos_type – string, the type of positional embeddings to use.
 pos_axial_shape – tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled.
 pos_d_axial_embs – tuple of ints: depth of position embedding for each axis. Tuple length must match pos_axial_shape, values must sum to d_embedding.
 ff_activation – the nonlinearity in feedforward layer
 ff_use_sru – int; if > 0, we use this many SRU layers instead of feedforward
 ff_chunk_size – int; if > 0, chunk feedforward into this-sized chunks
 ff_sparsity – int, if > 0 use sparse feedforward block with this sparsity
 attention_chunk_size – int, if > 0 run attention chunked at this size
 mode – str: ‘train’ or ‘eval’
Returns: the layer.

trax.models.reformer.reformer.
EncoderBlock
(d_model, d_ff, n_heads, attention_type, dropout, ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0, attention_chunk_size=0, center_layernorm=True, use_bfloat16=False, use_two_swaps_per_block=True, mode='train')¶ Returns a list of layers that implements a Reformer encoder block.
The input to the layer is a pair, (activations, mask), where the mask was created from the original source tokens to prevent attending to the padding part of the input.
Parameters:  d_model – int: depth of embedding
 d_ff – int: depth of feedforward layer
 n_heads – int: number of attention heads
 attention_type – subclass of tl.BaseCausalAttention: attention class to use
 dropout – float: dropout rate (how much to drop out)
 ff_activation – the nonlinearity in feedforward layer
 ff_dropout – the dropout rate in feedforward layer
 ff_use_sru – int; if > 0, we use this many SRU layers instead of feedforward
 ff_chunk_size – int; if > 0, chunk feedforward into this-sized chunks
 ff_sparsity – int, if > 0 use sparse feedforward block with this sparsity
 attention_chunk_size – int, if > 0 run attention chunked at this size
 center_layernorm – whether to use centering in LayerNorm (default) or to skip it, which is known as RMS normalization.
 use_bfloat16 – whether to use bfloat16 for weights (default: False)
 use_two_swaps_per_block – bool, if True use two reversible swaps in Encoder block, otherwise use only one swap.
 mode – str: ‘train’ or ‘eval’
Returns: A list of layers that maps (activations, mask) to (activations, mask).
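center_layernorm toggles between standard layer normalization and RMS normalization. In the sketch below, both variants divide by the root mean square over the last axis; the only difference is whether the mean is subtracted first. Learned scale and bias are omitted, and eps and the helper name are assumptions.

```python
import numpy as np


def layer_norm(x, center=True, eps=1e-6):
    """Normalizes over the last axis; center=False gives RMS normalization."""
    if center:
        x = x - x.mean(axis=-1, keepdims=True)
    # With centering, mean(x**2) is the variance; without it, the mean square.
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
```

Skipping the mean saves one reduction per call, and the output is no longer forced to have zero mean.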

trax.models.reformer.reformer.
EncoderDecoderBlock
(d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0)¶ Reversible transformer decoder layer.
Parameters:  d_model – int: depth of embedding
 d_ff – int: depth of feedforward layer
 n_heads – int: number of attention heads
 dropout – float: dropout rate (how much to drop out)
 ff_activation – the nonlinearity in feedforward layer
 ff_dropout – float: (optional) separate dropout rate for feedforward layer
 mode – str: ‘train’ or ‘eval’
 ff_use_sru – int; if > 0, we use this many SRU layers instead of feedforward
 ff_chunk_size – int; if > 0, chunk feedforward into this-sized chunks
 ff_sparsity – int, if > 0 use sparse feedforward block with this sparsity
Returns: the layer.

trax.models.reformer.reformer.
Reformer
(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, ff_activation=<function Relu>, ff_dropout=None, mode='train', pos_type=None, pos_axial_shape=None, pos_d_axial_embs=None, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0)¶ Reversible transformer encoder-decoder model.
This model expects an input pair: target, source.
At the moment, this model supports dot-product attention only. For the attention types in the Reformer paper, see ReformerLM.
Parameters:  input_vocab_size – int: vocab size of the source.
 output_vocab_size – int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab.
 d_model – int: depth of embedding
 d_ff – int: depth of feedforward layer
 n_encoder_layers – int: number of encoder layers
 n_decoder_layers – int: number of decoder layers
 n_heads – int: number of attention heads
 dropout – float: dropout rate (how much to drop out)
 max_len – int: maximum symbol length for positional encoding
 ff_activation – the nonlinearity in feedforward layer
 ff_dropout – float: (optional) separate dropout rate at feedforward nonlinearity. This is called relu_dropout in T2T.
 mode – str: ‘train’ or ‘eval’
 pos_type – string, the type of positional embeddings to use.
 pos_axial_shape – tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled.
 pos_d_axial_embs – tuple of ints: depth of position embedding for each axis. Tuple length must match pos_axial_shape, and values must sum to d_model.
 ff_use_sru – int; if > 0, we use this many SRU layers instead of feedforward
 ff_chunk_size – int; if > 0, chunk feedforward into this-sized chunks
 ff_sparsity – int, if > 0 use sparse feedforward block with this sparsity
Returns: A Reformer model as a layer that maps from a target, source pair to activations over a vocab set.
research.bert¶
BERT.

class
trax.models.research.bert.
AddBias
(n_in=1, n_out=1, name=None, sublayers_to_print=None)¶ Bases:
trax.layers.base.Layer

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.models.research.bert.
BERTClassifierHead
(n_classes)¶

trax.models.research.bert.
BERTRegressionHead
()¶

trax.models.research.bert.
BERTMLMHead
(vocab_size=30522)¶

trax.models.research.bert.
BERTPretrainingLoss
()¶

trax.models.research.bert.
BERTPretrainingHead
(n_classes)¶

trax.models.research.bert.
BERT
(d_model=768, vocab_size=30522, max_len=512, type_vocab_size=2, n_heads=12, d_ff=3072, n_layers=12, head=None, init_checkpoint=None, mode='eval')¶ BERT (default hparams are for bert-base-uncased).

class
trax.models.research.bert.
PretrainedBERT
(*sublayers, init_checkpoint=None)¶ Bases:
trax.layers.combinators.Serial
Wrapper that always initializes weights from a pretrained checkpoint.

__init__
(*sublayers, init_checkpoint=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

classmethod
download_model
(model_name)¶ Returns model dir path with model filename.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

research.skipping_transformer¶
trax.data¶
inputs¶
Data sources and input processing.
Trax authors recommend constructing input pipelines using layer-like functions and combinators. For example, the following is an input pipeline for training sentiment analysis tasks on the IMDB dataset:
from trax import data
inputs = data.Serial(
    data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
    data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    data.Shuffle(),
    data.FilterByLength(max_length=2048, length_keys=[0]),
    data.BucketByLength(boundaries=[ 32, 128, 512, 2048],
                        batch_sizes=[128, 32, 8, 2, 1],
                        length_keys=[0]),
    data.AddLossWeights()
)
Each of these functions creates a Python generator of tuples of data arrays. For example:
data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)
creates a generator of examples (tuples of NumPy ndarray objects) from the TFDS imdb_reviews dataset; see https://www.tensorflow.org/datasets/catalog/imdb_reviews
As that page shows, this dataset has 'text' and 'label' fields, and we create tuples containing the text and the label from the training split by specifying keys=('text', 'label'), train=True.
Other functions, like Tokenize and Shuffle, take a generator and output another generator, in this way converting tuples into other tuples or mixing the training stream. For example, Tokenize(..., keys=[0]) tokenizes the first element of a tuple, converting it from text to a NumPy integer array. And Shuffle randomizes the order of examples.
Note that all elements in the data pipeline are just functions on generators, so you can use Python’s map and filter and other native functions too. For example, you can create an input pipeline for a language model reading lines from my_file.txt as follows:
from trax import data
inputs = data.Serial(
    lambda _: open('my_file.txt'),
    lambda g: map(lambda line: line.strip(), g),
    data.Tokenize(vocab_file='en_8k.subword'),
    lambda g: filter(lambda x: x.shape[0] < 513, g),  # At most 512 tokens.
    data.Shuffle(),
    lambda g: map(lambda x: (x, x), g),  # Language models have inputs = targets.
    data.BucketByLength(boundaries=[ 32, 64, 128, 256, 512],
                        batch_sizes=[ 32, 16, 8, 4, 2, 1]),
    data.AddLossWeights(id_to_mask=0)
)

trax.data.inputs.
Serial
(*fns)¶ Combines generator functions into one that runs them serially.
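To make the idea of "functions on generators" concrete, here is a plain-Python sketch of serial composition: each stage receives the stream produced by the previous one, and the first stage (the source) ignores its input. The helper name `serial` and the in-memory source are hypothetical; this is not Trax's implementation of Serial.

```python
from typing import Callable, Iterable, Iterator, Optional

# A stage maps an input stream (None for the source) to an output stream.
Stage = Callable[[Optional[Iterable]], Iterable]


def serial(*stages: Stage) -> Callable[..., Iterator]:
    """Hypothetical sketch: feed each stage's output into the next stage."""
    def run(stream: Optional[Iterable] = None) -> Iterator:
        for stage in stages:
            stream = stage(stream)
        return iter(stream)
    return run


pipeline = serial(
    lambda _: ["hello world\n", "a b\n"],         # in-memory source
    lambda g: map(lambda line: line.strip(), g),  # strip whitespace
    lambda g: filter(lambda s: len(s) < 6, g),    # keep short lines
    lambda g: map(lambda x: (x, x), g),           # inputs = targets
)
print(list(pipeline()))  # [('a b', 'a b')]
```

Because the source lambda returns a fresh list on every call, calling pipeline() again rebuilds the whole chain of generators from scratch.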

trax.data.inputs.
Parallel
(fns=None, counters=None, reweight_by_minimum=False, gradually_reweight=False, use_remainders=False)¶ Combines generator functions into one that runs them in parallel.
Parameters:  fns – a sequence of datasets which are combined in parallel.
 counters – a sequence of ints with same length as fns, please see comments on its use below.
 reweight_by_minimum – if set to True, then we reweight every counter by the minimal counter. E.g. counters (10000, 100000) are translated to (1, 10) and hence for every 10 examples from the second dataset we are getting 1 example from the first dataset. Without reweighting, we would first see 20 thousand examples from the first and second datasets and then 90 thousand examples only from the second dataset.
 gradually_reweight – if set to True, then we loop through the generators using a recursive rule defined in emit_examples. First we sort generators by the counters. If we have datasets with counters 1, 20, 40 (after sorting) then we yield examples (a(b c^2)^20)^*, where examples of type a come from the first dataset, of type b from the second and of type c from the third. The exponents are obtained through divisions of subsequent counters.
 use_remainders – if set to True as well as gradually_reweight, and counters are 1, 20, 45, then after dealing with all examples in the format (a(b c^2)^20)^*, the generator yields the remaining 5 examples from the dataset with counter 45.
Returns: a generator that yields samples according to the given counters; if counters are not given, samples are generated uniformly.
Return type: parallel_generator
Example 1:
gen = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3))
defines a generator that yields 33% of its examples from dataset1, 17% from dataset2 and 50% from dataset3.
Example 2:
gen = data.Parallel([dataset1, dataset2, dataset3], counters=(20, 50, 30))
defines a generator that yields 20% of its examples from dataset1, 50% from dataset2 and 30% from dataset3.
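A rough plain-Python sketch of counter-based mixing follows. It assumes the non-gradual mode takes counters[i] consecutive examples from the i-th stream on each round and stops when any stream runs out; the name `parallel_by_counters` is hypothetical, and the exact interleaving order inside Trax may differ. It reproduces the proportions of Example 1 and the (10000, 100000) to (1, 10) reweighting described above.

```python
import itertools
from typing import Iterable, Iterator, Optional, Sequence


def parallel_by_counters(streams: Sequence[Iterable],
                         counters: Optional[Sequence[int]] = None,
                         reweight_by_minimum: bool = False) -> Iterator:
    """Hypothetical sketch: mix streams in proportion to integer counters."""
    iterators = [iter(s) for s in streams]
    counters = list(counters) if counters else [1] * len(iterators)
    if reweight_by_minimum:
        smallest = min(counters)
        counters = [c // smallest for c in counters]  # (10000, 100000) -> (1, 10)
    while True:
        for it, count in zip(iterators, counters):
            for _ in range(count):
                try:
                    yield next(it)
                except StopIteration:
                    return  # stop as soon as any stream is exhausted


gen = parallel_by_counters(
    [itertools.repeat("d1"), itertools.repeat("d2"), itertools.repeat("d3")],
    counters=(2, 1, 3))
print(list(itertools.islice(gen, 6)))
# ['d1', 'd1', 'd2', 'd3', 'd3', 'd3']: 2/6, 1/6 and 3/6 of the stream
```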

trax.data.inputs.
Shuffle
(queue_size=1024)¶ Returns a shuffle function with the given queue size.

trax.data.inputs.
Batch
(batch_size)¶ Returns a batching function with given batch size.

trax.data.inputs.
Dup
()¶ Duplicates (copies) the top element (inputs).
The generator stream is augmented in the following way:
 If the stream consists of a single element (inputs, ), the inputs simply get copied to (inputs, inputs).
 If the stream consists of multiple elements, for example (inputs, weights), the rest of elements get moved toward the right side (inputs, inputs, weights).
Returns: the duplicating function.
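The two cases above amount to copying the first element of each tuple into position 1. A minimal sketch, assuming every stream item is a tuple (the name `dup` is hypothetical):

```python
def dup(stream):
    """Hypothetical sketch: (inputs, *rest) -> (inputs, inputs, *rest)."""
    for inputs, *rest in stream:
        yield (inputs, inputs, *rest)


print(list(dup([(1,), (2, 0.5)])))  # [(1, 1), (2, 2, 0.5)]
```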

trax.data.inputs.
FilterEmptyExamples
(axes=None, debug=False)¶ Filters empty examples.
Filters any example that has an array of size (0,) (if axes=None). Alternatively, checks only the axes provided in the axes list. Unlike FilterByLength used with several elements with length_axis, here the example is filtered out if ANY of the dimensions listed in axes contains an empty array.
Parameters:  axes – list of indices to check, if None, all of them.
 debug – If true, emits a log every time we filter out an empty example.
Returns: Function filtering empty examples.

trax.data.inputs.
FilterByLength
(max_length, min_length=0, length_keys=None, length_axis=0)¶ Returns a function that filters out examples by length.
Parameters:  max_length – int. If not None, indicates maximum length.
 min_length – int. If not None, indicates minimum length.
 length_keys – (list) which example keys to take into account.
 length_axis – which shape axis to take into account.
Returns: a function that filters out examples by length.
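A minimal sketch of length filtering, assuming examples are tuples of NumPy arrays, that an example's length is the largest size along length_axis among the selected keys, and that both bounds are inclusive. The function `filter_by_length` is hypothetical, not the Trax source.

```python
import numpy as np


def filter_by_length(stream, max_length, min_length=0,
                     length_keys=(0,), length_axis=0):
    """Hypothetical sketch: keep examples whose length is within bounds."""
    for example in stream:
        # Assumed: length is the longest selected feature along length_axis.
        length = max(example[k].shape[length_axis] for k in length_keys)
        if max_length is not None and length > max_length:
            continue
        if min_length is not None and length < min_length:
            continue
        yield example


stream = [(np.zeros(3),), (np.zeros(10),), (np.zeros(0),)]
print([ex[0].shape for ex in filter_by_length(stream, max_length=5, min_length=1)])
# [(3,)]
```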

trax.data.inputs.
TruncateToLength
(len_map=None)¶ Returns a stream function that resizes items as specified by len_map.
Parameters: len_map – Dictionary that specifies maximum shapes for potentially multiple features per stream item. For example, given a stream of tokenized string pairs, one could enforce a maximum length of 256 tokens for each string by using len_map={0: (256,), 1: (256,)}.
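The len_map semantics can be sketched with NumPy slicing. This assumes each stream item is a tuple of arrays and that truncation keeps the leading elements of each listed dimension; `truncate_to_length` is a hypothetical name.

```python
import numpy as np


def truncate_to_length(stream, len_map):
    """Hypothetical sketch: cut listed features down to their maximum shapes."""
    for example in stream:
        example = list(example)
        for key, max_shape in len_map.items():
            # One slice per listed dimension; shorter dimensions are unchanged.
            slices = tuple(slice(0, n) for n in max_shape)
            example[key] = example[key][slices]
        yield tuple(example)


pairs = [(np.arange(300), np.arange(10))]
out = next(truncate_to_length(pairs, {0: (256,), 1: (256,)}))
print(out[0].shape, out[1].shape)  # (256,) (10,)
```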

trax.data.inputs.
PadToLength
(len_map=None, pad_value=0, multiple=False)¶ Pads the values to lengths given in len_map.
len_map contains a dictionary of example keys to dimension sizes.
Parameters:  len_map – dict of int to int; we pad examples to lengths given by the values of the dict. If multiple is True, the dimensions are padded to a multiple of this value.
 pad_value – dict of int to int. The value is passed as constant_values to numpy.pad for the given dimension.
 multiple – boolean. If False, pads to the value of len_map. If True, pads up to the closest multiple of the value of len_map.
Returns: Function to pad examples to given lengths.
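A sketch of the padding rule for 1-D features, built on numpy.pad. To keep it short it assumes a single scalar pad_value for every key (the parameter description above allows a per-key value) and pads only at the end; `pad_to_length` is a hypothetical helper.

```python
import numpy as np


def pad_to_length(example, len_map, pad_value=0, multiple=False):
    """Hypothetical sketch: pad 1-D features of one example per len_map."""
    example = list(example)
    for key, target in len_map.items():
        x = example[key]
        length = x.shape[0]
        if multiple:
            # Round the current length up to the closest multiple of target.
            target = -(-length // target) * target
        pad = max(target - length, 0)  # never truncate
        example[key] = np.pad(x, (0, pad), constant_values=pad_value)
    return tuple(example)


print(pad_to_length((np.array([1, 2, 3]),), {0: 5})[0])              # [1 2 3 0 0]
print(pad_to_length((np.arange(1, 6),), {0: 4}, multiple=True)[0].shape)  # (8,)
```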

trax.data.inputs.
BucketByLength
(boundaries, batch_sizes, length_keys=None, length_axis=0, strict_pad_on_len=False)¶ Returns a function for bucketing inputs, see bucket_by_length.

trax.data.inputs.
MLM
(vocab_size=None, max_length=None, noise_density=0.15, mean_noise_span_length=3.0)¶ Pipeline that just does MLM.

trax.data.inputs.
PrefixLM
(input_length=128, output_length=512)¶ Chunks examples so as to make inputs/outputs of specified lengths.

trax.data.inputs.
ConcatenateToLMInput
(pad_to_length=None)¶ Prepares the input needed for training of Language Models.
Each example needs to contain two elements (input and target). Input is concatenated to target and, if pad_to_length is given, padded to the length provided. The loss_weights mark only the target, excluding both the input and the padding.
Parameters: pad_to_length – int, total length to which the concatenated input and target arrays are padded. Returns: Function to return input for a LM.
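The concatenation and the loss weights can be sketched with NumPy as follows. The output layout (the concatenated array used as both inputs and targets, plus weights) is an assumption for illustration, as is the name `concatenate_to_lm_input`.

```python
import numpy as np


def concatenate_to_lm_input(example, pad_to_length=None):
    """Hypothetical sketch: (input, target) -> (concat, concat, weights)."""
    inp, tgt = example
    concat = np.concatenate([inp, tgt])
    # Weight 0 on the input part, 1 on the target part.
    weights = np.concatenate([np.zeros_like(inp, dtype=np.float32),
                              np.ones_like(tgt, dtype=np.float32)])
    if pad_to_length is not None:
        pad = pad_to_length - concat.shape[0]
        if pad < 0:
            raise ValueError("pad_to_length is shorter than input + target")
        concat = np.pad(concat, (0, pad))
        weights = np.pad(weights, (0, pad))  # padding also gets weight 0
    return concat, concat, weights


x, y, w = concatenate_to_lm_input((np.array([1, 2]), np.array([3, 4, 5])), 8)
print(x.tolist(), w.tolist())
# [1, 2, 3, 4, 5, 0, 0, 0] [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```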

trax.data.inputs.
CastTo
(dtype=<sphinx.ext.autodoc.importer._MockObject object>, indices=(0, 1), debug=False)¶ Casts the given indices to the given dtype.

trax.data.inputs.
AppendValue
(val=None)¶ Appends values provided in val to inputs.
val is keyed by example keys; its values contain the tensors to append.
Parameters: val – dict of int to tensors. Specific keys get the tensors specified in values appended. Returns: Function to append tensors to examples.

trax.data.inputs.
AddLossWeights
(id_to_mask=None)¶ Returns a function to add loss weights; see add_loss_weights.

trax.data.inputs.
UnBatch
()¶ Returns a function which unbatches.

trax.data.inputs.
Prefetch
(n_prefetch=2)¶ Prefetches a number of examples from generator in a separate process.

trax.data.inputs.
UniformlySeek
(name=None, host_id=None, n_hosts=None, dataset_size=None)¶ Sets each host at the start of its own share of the dataset, in steps of dataset_size/n_hosts examples.

trax.data.inputs.
CountAndSkip
(name)¶ Returns a function that counts and skips examples (see count_and_skip).

trax.data.inputs.
Log
(n_steps_per_example=1, only_shapes=True)¶ Creates a logging component of the input pipeline.

trax.data.inputs.
shuffle
(samples, queue_size)¶ Shuffles a sample stream using a random-out, next-in queue of the given size.
Parameters:  samples – Stream of samples for eventual use as training data or eval data.
 queue_size – Minimum number of samples within which the streamed shuffling takes place.
Yields: Shuffled stream of samples, ready for further processing, e.g., grouping into batches.
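A random-out, next-in queue can be sketched in a few lines: samples enter the queue one at a time, and once it holds queue_size samples, a uniformly chosen one leaves. The `seed` argument and the function name `shuffle_stream` are additions for this sketch, not part of the documented signature.

```python
import random


def shuffle_stream(samples, queue_size, seed=None):
    """Hypothetical sketch of a streaming random-out, next-in shuffle."""
    rng = random.Random(seed)
    queue = []
    for sample in samples:
        queue.append(sample)                       # next in
        if len(queue) >= queue_size:
            i = rng.randrange(len(queue))
            queue[i], queue[-1] = queue[-1], queue[i]
            yield queue.pop()                      # random out
    rng.shuffle(queue)                             # drain the remainder
    yield from queue


out = list(shuffle_stream(range(20), queue_size=5, seed=0))
print(sorted(out) == list(range(20)))  # True: same samples, new order
```

A larger queue_size mixes samples from further apart in the stream at the cost of holding more of them in memory.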

trax.data.inputs.
batch
(generator, batch_size)¶ Batch and pad generator as in tf.data.Dataset.padded_batch.

trax.data.inputs.
pad_to_max_dims
(tensors, boundary=None, strict_pad_on_len=False)¶ Pad a tuple of tensors to a joint dimension and return their batch.
For example, a pair of tensors of shape (2, 10) and (3, 9) will both be padded to (3, 10), and the returned tensor will have shape (2, 3, 10).
When boundary is specified, we try to pad all unknown dimensions to boundary if possible, which can help reduce the number of different shapes occurring in the tensors and speed up XLA compilation. So, for example, a pair of tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12).
One special case occurs when boundary is much higher than the padding length that we’d use without boundary. For example, tensors (2, 10) and (3, 9) with boundary=12 could end up padded to (12, 12), but this is very wasteful in the first dimension. In that case, we will use the closest power of 2 instead of the boundary, so we will end up padding to (4, 12) instead of (12, 12).
Parameters:  tensors – a tuple or list of tensors to pad
 boundary – int or None; if given, expand the padded dimensions to this size
 strict_pad_on_len – bool; if true we pad on the length dimension, dim[0] strictly as a multiple of boundary.
Returns: a tensor, the tensors padded together
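The padding rule can be sketched with NumPy. The per-dimension rule used here (pad to the maximum size, or, with a boundary, to the smaller of the boundary and the next power of 2 at or above the maximum) is an assumption chosen because it reproduces the three examples above; it is not taken from the Trax source. `pad_and_stack` and `_padded_size` are hypothetical names, and strict_pad_on_len is not modeled.

```python
import numpy as np


def _padded_size(max_d, boundary):
    """Assumed per-dimension rule that matches the examples above."""
    if boundary is None or max_d >= boundary:
        return max_d
    power = 1
    while power < max_d:
        power *= 2
    return min(power, boundary)  # small dimensions stop at a power of 2


def pad_and_stack(tensors, boundary=None):
    """Hypothetical sketch: zero-pad tensors to a joint shape, then stack."""
    ndim = tensors[0].ndim
    target = [_padded_size(max(t.shape[i] for t in tensors), boundary)
              for i in range(ndim)]
    padded = [np.pad(t, [(0, n - s) for n, s in zip(target, t.shape)])
              for t in tensors]
    return np.stack(padded)


a, b = np.ones((2, 10)), np.ones((3, 9))
print(pad_and_stack([a, b]).shape)               # (2, 3, 10)
print(pad_and_stack([a, b], boundary=12).shape)  # (2, 4, 12)
```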

trax.data.inputs.
bucket_by_length
(generator, length_fn, boundaries, batch_sizes, strict_pad_on_len=False)¶ Bucket by length, like tf.data.experimental.bucket_by_sequence_length.
This function draws examples from the provided generator and puts each example into a bucket depending on l = length_fn(example). The bucket is chosen by which pair of boundaries l falls between. When a bucket reaches its batch size, as specified by batch_sizes, the function yields a batch of padded examples from that bucket.
Parameters:  generator – python generator to draw data from.
 length_fn – a function taking the example and returning the length.
 boundaries – a list of bucket boundaries.
 batch_sizes – a list of batch sizes.
 strict_pad_on_len – bool; if true we pad on the length dimension, dim[0] strictly as a multiple of boundary.
Yields: An input batch, which comes from one of the buckets.
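A plain-Python sketch of the bucketing loop, assuming half-open buckets [boundaries[i-1], boundaries[i]) as in tf.data's bucket_by_sequence_length, so batch_sizes needs one more entry than boundaries. Padding is left out to keep the focus on bucket selection, and leftover partial buckets are simply dropped; `bucket_examples` is a hypothetical name, not the Trax code.

```python
import bisect


def bucket_examples(examples, length_fn, boundaries, batch_sizes):
    """Hypothetical sketch: group examples by length, emit full buckets."""
    assert len(batch_sizes) == len(boundaries) + 1
    buckets = [[] for _ in batch_sizes]
    for example in examples:
        # Bucket i holds lengths in [boundaries[i-1], boundaries[i]).
        i = bisect.bisect_right(boundaries, length_fn(example))
        buckets[i].append(example)
        if len(buckets[i]) == batch_sizes[i]:
            yield buckets[i]   # a real pipeline would pad and stack here
            buckets[i] = []


words = ["a", "bb", "abcd", "abcdefgh"]
print(list(bucket_examples(words, len, boundaries=[3, 6], batch_sizes=[2, 2, 1])))
# [['a', 'bb'], ['abcdefgh']]; 'abcd' waits for a second mid-length word
```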

trax.data.inputs.
add_loss_weights
(generator, id_to_mask=None)¶ Add weights to inputs without weights and masks by id if requested.
The generator stream is augmented in the following way:
 If the stream consists of pairs (inputs, targets), a loss mask is added, created as a tensor of ones with the same shape as targets.
 If id_to_mask is not None, and the stream (after the previous point) has triples (inputs, targets, weights), the weights are multiplied by a 0/1 mask that is 0 iff targets is equal to id_to_mask (1 otherwise).
Parameters:  generator – Stream of tuples.
 id_to_mask – If not None, int-valued id that represents padding, as opposed to true target IDs.
Yields: Examples from the augmented stream.
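The two rules can be sketched with NumPy for streams of (inputs, targets) pairs or (inputs, targets, weights) triples. The float32 weight dtype and the name `add_weights` are assumptions for this sketch.

```python
import numpy as np


def add_weights(stream, id_to_mask=None):
    """Hypothetical sketch of the two augmentation rules described above."""
    for example in stream:
        if len(example) == 2:
            inputs, targets = example
            weights = np.ones_like(targets, dtype=np.float32)
        else:
            inputs, targets, weights = example
        if id_to_mask is not None:
            # Zero the weight wherever the target equals the padding id.
            weights = weights * (targets != id_to_mask)
        yield inputs, targets, weights


_, _, w = next(add_weights([(np.array([1]), np.array([5, 7, 0, 0]))], id_to_mask=0))
print(w.tolist())  # [1.0, 1.0, 0.0, 0.0]
```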

trax.data.inputs.
generate_random_noise_mask
(noise_density=0.15, mean_noise_span_length=3.0, seed1=None, seed2=None)¶ Returns a function that generates a random noise mask.

trax.data.inputs.
consume_noise_mask
(vocab_size=32100)¶ Consumes (tokens, noise mask) and returns (inputs, targets).

trax.data.inputs.
generate_sequential_chunks
(max_length=None)¶ Returns a function that generates chunks of length at most max_length.

trax.data.inputs.
addition_input_stream
(vocab_size=<sphinx.ext.autodoc.importer._MockObject object>, batch_size=<sphinx.ext.autodoc.importer._MockObject object>, min_length=<sphinx.ext.autodoc.importer._MockObject object>, max_length=<sphinx.ext.autodoc.importer._MockObject object>, pad_to_multiple=32, encdec=False)¶ Data stream for the add problem: <S>x+y<S>(x+y).
Parameters:  vocab_size – how many symbols to use.
 batch_size – how large are the batches.
 min_length – minimal length of w.
 max_length – maximal length of w.
 pad_to_multiple – int, pad length to be multiple of this number.
 encdec – bool, if True return encoder-decoder style inputs (default: False)
Returns: python generator of tuples of data examples

trax.data.inputs.
random_spans_noise_mask
(length, noise_density=0.15, mean_noise_span_length=3.0, seed1=None, seed2=None, example=None)¶ Computes span corruption masks given input parameters.

trax.data.inputs.
lower_endian_to_number
(l, base)¶ Helper function: convert a list of digits in the given base to a number.

trax.data.inputs.
number_to_lower_endian
(n, base)¶ Helper function: convert a number to a list of digits in the given base.

trax.data.inputs.
random_number_lower_endian
(length, base)¶ Helper function: generate a random number as a lower-endian digits list.
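For reference, "lower-endian" means the least-significant digit comes first, so [3, 2, 1] in base 10 is 123. The sketches below are illustrative re-implementations, not the Trax source; in particular, keeping the last (most significant) digit of a random number nonzero is an assumption made so the list always represents a number with exactly `length` digits.

```python
import random


def lower_endian_to_number(digits, base):
    """[3, 2, 1] in base 10 -> 123 (least-significant digit first)."""
    return sum(d * base ** i for i, d in enumerate(digits))


def number_to_lower_endian(n, base):
    """123 in base 10 -> [3, 2, 1]; 0 -> [0]."""
    digits = [n % base]
    n //= base
    while n > 0:
        digits.append(n % base)
        n //= base
    return digits


def random_number_lower_endian(length, base, rng=random):
    """Random `length`-digit number; last digit kept nonzero (assumption)."""
    if length == 1:
        return [rng.randrange(base)]
    digits = [rng.randrange(base) for _ in range(length - 1)]
    return digits + [rng.randrange(1, base)]


print(number_to_lower_endian(123, 10), lower_endian_to_number([1, 0, 1], 2))
# [3, 2, 1] 5
```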

trax.data.inputs.
count_and_skip
(generator, name)¶ Count the number of items in the generator, skip already counted ones.
This function counts the number of processed examples and puts it into the global variable counters. This variable can be saved and restored, and if restored, this function will skip examples until the restored counter is reached. When the data generator is deterministic, this allows the data reading process to be restored from a checkpoint.
Parameters:  generator – generator for examples in the dataset.
 name – string, a unique id that we use to count the examples
Yields: The examples from generator but first skip the number specified in the global variable counters[name] and next increment this variable every time a new example appears.
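The skip-then-count behavior can be sketched with a module-level dict standing in for the global counters variable. Saving and restoring that dict (what save_data_counters and load_data_counters are for) is not shown; the sketch and its function name `count_then_skip` are hypothetical, and a restore is simulated by reusing the counter left by an earlier run.

```python
import itertools

counters = {}  # stand-in for the global: name -> examples produced so far


def count_then_skip(generator, name):
    """Hypothetical sketch: skip examples counted before a restore."""
    already_seen = counters.get(name, 0)
    position = 0
    for example in generator:
        position += 1
        if position <= already_seen:
            continue                 # consumed before the checkpoint
        counters[name] = position    # update before handing the example out
        yield example


# A first run reads three examples, then stops.
print(list(itertools.islice(count_then_skip(range(10), "train"), 3)))  # [0, 1, 2]
# A restarted, deterministic run resumes where the counter left off.
print(list(count_then_skip(range(5), "train")))  # [3, 4]
```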

trax.data.inputs.
save_data_counters
(output_dir, host_id=None)¶ Checkpoint data counters.

trax.data.inputs.
load_data_counters
(output_dir, host_id=None)¶ Loads checkpointed data counters.

class
trax.data.inputs.
Inputs
(train_stream, eval_stream=None, train_eval_stream=None)¶ Bases:
object
Inputs bundle.
Inputs bundle holds input streams and shapes for a training run. It contains stream-creating functions that return python generators of (input_batch, target_batch) tuples.
 train_stream: training data that will be used for training; may include all the augmentation or selection the training wants; the shape of examples is [batch_fn.batch_size, …]
 train_eval_stream: training data used for evaluation; examples from training data but usually without augmentation; the shape of examples is [batch_fn.eval_batch_size, …]
 eval_stream: evaluation data stream; examples from evaluation data, usually without augmentation; the shape of examples is [batch_fn.eval_batch_size, …]
 input_shape: the shape of inputs; the […] above, without batch size
 input_dtype: the data type of inputs
 target_shape: the shape of targets; the […] above, without batch size
 target_dtype: the data type of targets

__init__
(train_stream, eval_stream=None, train_eval_stream=None)¶ Initialize a new set of inputs.
Parameters:  train_stream – a function taking n_devices (an int) and returning a python generator of training batches.
 eval_stream – a function taking n_devices (an int) and returning a python generator of validation batches; if None, then the training generator will be used for evaluation.
 train_eval_stream – a function taking n_devices (an int) and returning a python generator of batches from the training set used for evaluation (if None, use train_stream).

train_stream
(n_devices)¶

eval_stream
(n_devices)¶

train_eval_stream
(n_devices)¶

input_shape
¶ Example input shape, without batch dimension.

target_shape
¶ Example target shape, without batch dimension.

input_dtype
¶ Dtype of the input.

target_dtype
¶ Dtype of the target.

example_shape_dtype
¶ Shape and Dtype of an example batch.

trax.data.inputs.
make_inputs
(train_stream=<sphinx.ext.autodoc.importer._MockObject object>, eval_stream=None)¶ Create Inputs from two streams; mostly for use in gin configs.

trax.data.inputs.
make_additional_stream
(stream=<sphinx.ext.autodoc.importer._MockObject object>)¶ Create a stream mostly for use in gin configs for additional tasks.

trax.data.inputs.
make_parallel_stream
(streams=<sphinx.ext.autodoc.importer._MockObject object>, counters=None)¶ Create a parallel stream for use in gin configs for additional tasks.

trax.data.inputs.
batcher
(data_streams=<sphinx.ext.autodoc.importer._MockObject object>, variable_shapes=True, batch_size_per_device=32, batch_size=None, eval_batch_size=32, bucket_length=32, buckets=None, buckets_include_inputs_in_length=False, batch_shuffle_size=None, max_eval_length=None, id_to_mask=None, strict_pad_on_len=False)¶ Batcher: create trax Inputs from single-example data streams.

trax.data.inputs.
batch_fn
(dataset, training, n_devices, variable_shapes, batch_size_per_device=32, batch_size=None, eval_batch_size=32, bucket_length=32, buckets=None, buckets_include_inputs_in_length=False, batch_shuffle_size=None, max_eval_length=None, id_to_mask=None, strict_pad_on_len=False)¶ Batching function.

trax.data.inputs.
random_inputs
(input_shape=<sphinx.ext.autodoc.importer._MockObject object>, input_dtype=<sphinx.ext.autodoc.importer._MockObject object>, input_range=(0, 255), output_shape=<sphinx.ext.autodoc.importer._MockObject object>, output_dtype=<sphinx.ext.autodoc.importer._MockObject object>, output_range=(0, 9))¶ Make random Inputs for debugging.
Parameters:  input_shape – the shape of inputs (including batch dimension).
 input_dtype – the type of the inputs (int32 by default).
 input_range – the range of inputs (defaults to (0, 255)).
 output_shape – the shape of outputs (including batch dimension).
 output_dtype – the type of the outputs (int32 by default).
 output_range – the range of outputs (defaults to (0, 9)).
Returns: trax.inputs.Inputs

trax.data.inputs.
sequence_copy_inputs
(vocab_size=<sphinx.ext.autodoc.importer._MockObject object>, batch_size=<sphinx.ext.autodoc.importer._MockObject object>, train_length=<sphinx.ext.autodoc.importer._MockObject object>, eval_min_length=<sphinx.ext.autodoc.importer._MockObject object>, eval_max_length=<sphinx.ext.autodoc.importer._MockObject object>, reverse=False, pad_to_multiple=32)¶ Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*.
Parameters:  vocab_size – how many symbols to use.
 batch_size – how large are the batches.
 train_length – maximum length of w for training.
 eval_min_length – minimum length of w for eval.
 eval_max_length – maximum length of w for eval.
 reverse – bool (optional, false by default): reverse the second sequence.
 pad_to_multiple – int, pad length to be multiple of this number.
Returns: trax.inputs.Inputs

trax.data.inputs.
simple_sequence_copy_inputs
(vocab_size=<sphinx.ext.autodoc.importer._MockObject object>, batch_size=<sphinx.ext.autodoc.importer._MockObject object>, train_length=<sphinx.ext.autodoc.importer._MockObject object>, eval_min_length=<sphinx.ext.autodoc.importer._MockObject object>, eval_max_length=<sphinx.ext.autodoc.importer._MockObject object>, pad_to_multiple=32)¶ Inputs for the sequence copy problem: w for w in [1..vocab_size-1]*.
Parameters:  vocab_size – how many symbols to use.
 batch_size – how large are the batches.
 train_length – maximum length of w for training.
 eval_min_length – minimum length of w for eval.
 eval_max_length – maximum length of w for eval.
 pad_to_multiple – int, pad length to be multiple of this number.
Returns: trax.inputs.Inputs

trax.data.inputs.
addition_inputs
(vocab_size=<sphinx.ext.autodoc.importer._MockObject object>, batch_size=<sphinx.ext.autodoc.importer._MockObject object>, train_length=<sphinx.ext.autodoc.importer._MockObject object>, eval_min_length=<sphinx.ext.autodoc.importer._MockObject object>, eval_max_length=<sphinx.ext.autodoc.importer._MockObject object>, pad_to_multiple=32, encdec=False)¶ Inputs for the add problem: <S>x+y<S>(x+y).
Parameters:  vocab_size – how many symbols to use.
 batch_size – how large are the batches.
 train_length – maximal length of w for training.
 eval_min_length – minimal length of w for eval.
 eval_max_length – maximal length of w for eval.
 pad_to_multiple – int, pad length to be multiple of this number.
 encdec – bool, if True return encoder-decoder style inputs (default: False)
Returns: trax.inputs.Inputs

trax.data.inputs.
sine_inputs
(batch_size=<sphinx.ext.autodoc.importer._MockObject object>, length=<sphinx.ext.autodoc.importer._MockObject object>, max_phase=6.283185307179586, min_period=0.1, max_period=10.0)¶ Sinusoids of random period and phase.
Parameters:  batch_size (int) – Number of examples in a batch.
 length (int) – Length of each sequence.
 max_phase (float) – Maximum phase of the sinusoids.
 min_period (float) – Minimum period of the sinusoids.
 max_period (float) – Maximum period of the sinusoids.
Returns: trax.inputs.Inputs
tf_inputs¶
TensorFlow data sources and associated preprocessing functions.

trax.data.tf_inputs.
t5_data
()¶ Get the T5 data module if available.

trax.data.tf_inputs.
no_preprocess
(dataset, training)¶

trax.data.tf_inputs.
t2t_problems
()¶

trax.data.tf_inputs.
data_streams
(dataset_name, data_dir=None, preprocess_fn=<function no_preprocess>, bare_preprocess_fn=None, shuffle_buffer_size=1024, eval_holdout_size=0, input_name=None, target_name=None)¶ Creates (train, eval) data sources from dataset_name.
Parameters:  dataset_name – Name of dataset belonging to TFDS or T2T. T2T dataset names must start with 't2t_'.
 data_dir – Directory where the data is located.
 preprocess_fn – Function to use for preprocessing after appending targets to inputs.
 bare_preprocess_fn – Function to use for preprocessing before appending targets to inputs.
 shuffle_buffer_size – Size of the shuffle buffer.
 eval_holdout_size – If greater than 0, specifies a fraction of training data to siphon off and use as eval data, in place of a separate eval split.
 input_name – Name of the inputs from the dictionary.
 target_name – Name of the outputs either from the dictionary or as a result of postprocessing.
Returns: A pair of functions, (f, g) for use as data sources; call f() to get an iterator of training data samples, and call g() to get an iterator of eval data samples.

trax.data.tf_inputs.
dataset_to_stream
(dataset, input_name)¶ Takes a tf.Dataset and creates a numpy stream of ready batches.

trax.data.tf_inputs.
TFDS
(dataset_name, data_dir=None, tfds_preprocess_fn=None, keys=None, train=True, use_alt_eval=False, shuffle_train=True, host_id=None, n_hosts=None, eval_holdout_size=0)¶ Creates a data source from TensorFlow dataset dataset_name.
Parameters:  dataset_name – Name of the dataset, as registered in TensorFlow datasets (e.g., 'glue/mnli').
 data_dir – Directory where the data is located.
 tfds_preprocess_fn – If specified, function that applies to items in raw dataset (before selecting specific features).
 keys – Tuple of datasetspecific strings that select features from the dataset.
 train – If True, select the training split from the dataset; else select an eval split.
 use_alt_eval – If True, and if train is False, select the dataset’s alternate eval split if it has one (or fall back to the dataset’s only eval split). This currently affects only the glue/mnli dataset.
 shuffle_train – If True, have TensorFlow pre-shuffle the training data; else receive training data in deterministic sequence.
 host_id – Integer id used for tracking data subsplits, in cases where n_hosts > 1.
 n_hosts – If greater than 1, prepare data subsplits for the given number of hosts.
 eval_holdout_size – If greater than 0, specifies a fraction of training data to siphon off and use as eval data, in place of a separate eval split.
Returns: A function f for use as a training or eval data source; call f() to get an iterator of data samples.

trax.data.tf_inputs.
tokenize
(stream, keys=None, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)¶ Tokenize examples from the stream.
This function assumes that stream generates either strings or tuples/dicts containing strings at some keys. This function maps these strings to numpy arrays of integers – the tokenized version of each string.
Parameters:  stream – A python generator yielding strings, tuples or dicts.
 keys – which keys of the tuple/dict to tokenize (by default: all)
 vocab_type – Type of vocabulary, one of: ‘subword’, ‘sentencepiece’, ‘char’.
 vocab_file – Name of the vocabulary file.
 vocab_dir – Directory which contains the vocabulary file.
 n_reserved_ids – An int, offset added so 0, …, n_reserved_ids-1 are unused. This is common, for example, when reserving 0 for padding and 1 for EOS, but it’s only needed if these symbols are not already included (and thus reserved) in the vocab_file.
Yields: Examples from stream with strings at keys replaced by np.arrays of integers – the tokenized version of these strings.

trax.data.tf_inputs.
Tokenize
(keys=None, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)¶ Returns a function that maps text to integer arrays; see tokenize.

trax.data.tf_inputs.
detokenize
(x, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)¶ Maps integer arrays to text; the opposite of tokenize.
In many cases (all char and subword-type vocabularies and most sentencepiece ones) the tokenization is invertible, so detokenize(tokenize(x)) = x. In rarer cases this can remove some spacing, but it is still often useful to run detokenize to get a readable version of a tokenized string.
Parameters:  x – a list or numpy array of integers.
 vocab_type – Type of vocabulary, one of: ‘subword’, ‘sentencepiece’, ‘char’.
 vocab_file – Name of the vocabulary file.
 vocab_dir – Directory which contains the vocabulary file.
 n_reserved_ids – An int, offset added so 0, …, n_reserved_ids-1 are unused; this is common, for example, when reserving 0 for padding and 1 for EOS, but it’s only needed if these symbols are not already included (and thus reserved) in the vocab_file.
Returns: A string corresponding to the detokenized version of x.
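The inverse mapping can be shown with a short sketch under a hypothetical byte-level vocabulary chosen for illustration; it is not one of Trax's vocab types. The sketch subtracts n_reserved_ids, skips IDs in the reserved range, and decodes the result. Treating the reserved IDs as padding or EOS is an assumption. For this vocabulary detokenize(tokenize(x)) == x holds exactly.

```python
def toy_tokenize_str(text, n_reserved_ids=0):
    """Hypothetical byte-level tokenizer: UTF-8 byte b maps to b + n_reserved_ids."""
    return [b + n_reserved_ids for b in text.encode('utf-8')]


def toy_detokenize(ids, n_reserved_ids=0):
    """Inverse of toy_tokenize_str; reserved IDs (e.g. padding, EOS) are skipped."""
    data = bytes(i - n_reserved_ids for i in ids if i >= n_reserved_ids)
    return data.decode('utf-8', errors='ignore')


ids = toy_tokenize_str('Grüße', n_reserved_ids=2)
print(toy_detokenize(ids + [1, 0, 0], n_reserved_ids=2))  # Grüße
```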

trax.data.tf_inputs.
ConvertToUnicode
(keys=None)¶ Converts the chosen elements of an example to Unicode (UTF-8).
Useful when TFDS outputs byte arrays. All conversion errors are ignored.
Parameters: keys – tuple/list of example dimensions to convert. Returns: Function converting chosen elements of an example to UTF-8.
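A minimal sketch of this behavior on a single dict example. The real function operates on a data stream, and both the name convert_to_unicode and the dict-based example format are assumptions. Undecodable bytes are dropped, which is one way to realize "conversion errors are ignored".

```python
def convert_to_unicode(keys=None):
    """Hypothetical sketch: returns a function that decodes bytes fields to str."""
    def convert(example):
        out = dict(example)
        for k in (list(out) if keys is None else keys):
            if isinstance(out[k], bytes):
                # errors='ignore' silently drops invalid UTF-8 sequences.
                out[k] = out[k].decode('utf-8', errors='ignore')
        return out
    return convert


fn = convert_to_unicode(keys=['text'])
print(fn({'text': b'caf\xc3\xa9 \xff', 'label': b'x'}))
# {'text': 'café ', 'label': b'x'}
```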

trax.data.tf_inputs.
vocab_size
(vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)¶ Returns the size of the vocabulary (number of symbols used).
This function can be used to set the size of the final layers of a model that needs to predict symbols from a given vocabulary. More precisely, if this function returns N then the last layer size should be set to at least N (it can be more). Note that this function does take reserved IDs into account.
Parameters:  vocab_type – Type of vocabulary, one of: ‘subword’, ‘sentencepiece’, ‘char’.
 vocab_file – Name of the vocabulary file.
 vocab_dir – Directory which contains the vocabulary file.
 n_reserved_ids – An int, offset added so 0, …, n_reserved_ids-1 are unused.
Returns: An integer, the number of symbols used (including reserved IDs).

trax.data.tf_inputs.
cifar10_no_augmentation_preprocess
(dataset, training)¶

trax.data.tf_inputs.
cifar10_augmentation_preprocess
(dataset, training)¶ Preprocessing for cifar10 with augmentation (see below).

trax.data.tf_inputs.
cifar10_augmentation_flatten_preprocess
(dataset, training, predict_image_train_weight=0.01)¶ Preprocessing for cifar10 that flattens it and appends targets.

trax.data.tf_inputs.
downsampled_imagenet_flatten_bare_preprocess
(dataset, training)¶ Preprocessing for downsampled_imagenet.
Parameters:  dataset – the dataset.
 training – unused option.
Returns: Flattened dataset.
Preprocessing for downsampled_imagenet 32x32 and 64x64 generation from http://arxiv.org/abs/1601.06759 (page 8).

trax.data.tf_inputs.
concat_preprocess
(dataset, training, pad_symbol=0)¶ Preprocessing function that concatenates input and target for LM.

trax.data.tf_inputs.
squeeze_targets_preprocess
(dataset, training)¶ Preprocessing function that squeezes last axis of targets.

trax.data.tf_inputs.
lm1b_preprocess
(dataset, training, max_target_length=-1, max_eval_target_length=-1)¶ Preprocessing for LM1B: filter out targets exceeding maximum length.

trax.data.tf_inputs.
wmt_preprocess
(dataset, training, max_length=-1, max_eval_length=-1)¶ Preprocessing for WMT: filter out examples exceeding maximum length.

trax.data.tf_inputs.
wmt_concat_preprocess
(dataset, training, max_length=-1, max_eval_length=-1)¶ Preprocessing for WMT: filter out examples exceeding maximum length and concatenate.

trax.data.tf_inputs.
lm_token_preprocessing
(dataset, training)¶ Concatenates inputs, 0, targets, with masking only for targets.
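The layout can be made concrete with a minimal sketch on Python lists rather than a tf.data.Dataset. The helper name is hypothetical. Giving the 0 separator a mask weight of 0 is also an assumption; only target tokens get weight 1.

```python
def concat_with_target_mask(inputs, targets, sep=0):
    """Builds [inputs, sep, targets] and a loss mask that is 1 only on targets."""
    tokens = list(inputs) + [sep] + list(targets)
    mask = [0] * (len(inputs) + 1) + [1] * len(targets)
    return tokens, mask


tokens, mask = concat_with_target_mask([5, 6, 7], [8, 9])
print(tokens)  # [5, 6, 7, 0, 8, 9]
print(mask)    # [0, 0, 0, 0, 1, 1]
```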

trax.data.tf_inputs.
bair_robot_pushing_hparams
(hparams=None, video_num_input_frames=1, video_num_target_frames=15)¶

trax.data.tf_inputs.
bair_robot_pushing_preprocess
(dataset, training)¶ Preprocessing function that concatenates input and target frames.

trax.data.tf_inputs.
sentencepiece_tokenize
(stream, spm_path=None, extra_ids=0)¶ Sentencepiece tokenization.

trax.data.tf_inputs.
SentencePieceTokenize
(spm_path=None, extra_ids=0)¶ Returns a function that maps text to integer arrays.

trax.data.tf_inputs.
c4_preprocess
(dataset, training, max_target_length=-1, tokenization=None, spm_path=None)¶ Preprocessing function for C4 dataset.

trax.data.tf_inputs.
c4_bare_preprocess_fn
(dataset, training=True, spm_path=None, copy_pretokenized=True, sequence_length=None)¶ Returns a dataset that contains ‘inputs’ and ‘targets’ from C4.

trax.data.tf_inputs.
filter_dataset_on_len
(dataset, training, len_map=None, filter_on_eval=False)¶ Filters a dataset based on the lengths given in len_map.
Parameters:  dataset – tf.data.Dataset the dataset to filter.
 training – bool, true if we are in training mode.
 len_map – optional dict of str to (int, int). We filter examples where a feature’s size is beyond the specified bounds. Ex: {‘inputs’: (1, 512), ‘targets’: (64, 128)} will keep only those examples where 1 <= len(inputs) <= 512 and 64 <= len(targets) <= 128.
 filter_on_eval – bool, if true, we will filter in eval mode also.
Returns: a filtered tf.data.Dataset.
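The len_map bounds are inclusive on both ends. The predicate below is a plain-Python sketch with a hypothetical name, keep_example, applied to dicts of lists instead of a tf.data.Dataset. It omits the training and filter_on_eval checks that decide whether filtering happens at all.

```python
def keep_example(example, len_map):
    """True if every feature in len_map satisfies min_len <= len(feature) <= max_len."""
    return all(lo <= len(example[key]) <= hi for key, (lo, hi) in len_map.items())


len_map = {'inputs': (1, 512), 'targets': (64, 128)}
examples = [
    {'inputs': [1] * 10, 'targets': [2] * 100},   # kept
    {'inputs': [1] * 10, 'targets': [2] * 10},    # dropped: targets too short
    {'inputs': [1] * 600, 'targets': [2] * 100},  # dropped: inputs too long
]
kept = [ex for ex in examples if keep_example(ex, len_map)]
print(len(kept))  # 1
```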

trax.data.tf_inputs.
truncate_dataset_on_len
(dataset, training, len_map=None, truncate_on_eval=False)¶ Truncates features in an example to lengths given in len_map.
Parameters:  dataset – tf.data.Dataset the dataset to filter.
 training – bool, true if we are in training mode.
 len_map –
optional dict of str to int, we truncate examples where a feature’s size is beyond the max. Ex: {‘inputs’: 512, ‘targets’: 64} will truncate
examples to be within those bounds.  truncate_on_eval – bool if true, we will truncate in eval mode also.
Returns: a filtered tf.data.Dataset.
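Unlike filtering, truncation keeps every example and only shortens over-long features. Here is a plain-Python sketch on dicts of lists; the helper name truncate_example is hypothetical.

```python
def truncate_example(example, len_map):
    """Cuts each feature named in len_map down to at most len_map[key] elements."""
    out = dict(example)
    for key, max_len in len_map.items():
        out[key] = out[key][:max_len]
    return out


ex = {'inputs': list(range(600)), 'targets': list(range(10))}
short = truncate_example(ex, {'inputs': 512, 'targets': 64})
print(len(short['inputs']), len(short['targets']))  # 512 10
```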

trax.data.tf_inputs.
pad_dataset_to_length
(dataset, training, len_map=None)¶ Pads features shorter than the specified length to that length.

trax.data.tf_inputs.
add_eos_to_output_features
(dataset, training, output_features='targets', eos=1)¶ Adds EOS to all features in output_features.

trax.data.tf_inputs.
generic_text_dataset_preprocess_fn
(dataset, training=True, text_preprocess_fns=None, token_preprocess_fns=None, spm_path=None, copy_pretokenized=False, debug_print_examples=False, debug_print_examples_rate=0.01)¶ Preprocesses, tokenizes and postprocesses a tf.data.Dataset.
Parameters:  dataset – tf.data.Dataset to process.
 training – boolean, set to True if training, False otherwise.
 text_preprocess_fns – None or list of callables: tf.data.Dataset, bool -> tf.data.Dataset; these operate before tokenization. Typically used to select which fields we want to learn over or to change something into “text to text” form.
 token_preprocess_fns – None or list of callables: tf.data.Dataset, bool -> tf.data.Dataset; these operate after tokenization. Since they can view the tokenized fields, they can be used to filter on length, etc.
 spm_path – None or str, path to a sentencepiece model to use for tokenization; by default, the 32k vocabulary from T5 is used.
 copy_pretokenized – bool, if True retains the original fields after tokenization.
 debug_print_examples – bool, if True this prints examples to the logging stream for inspection, both before and after tokenization.
 debug_print_examples_rate – float in [0, 1.0]; on average, this fraction of dataset examples will be printed out in each phase, i.e. pre- and post-tokenization.
Returns: a tf.data.Dataset with all the preprocessing and tokenization performed.

trax.data.tf_inputs.
get_t5_preprocessor_by_name
(name=None, fn_kwargs=None)¶ Returns a closure of any T5 preprocessor function with its arguments.
The main use case is to use this (with gin scopes) to make any preprocessor function available in a gin file to configure and use.
See: TFInputs.test_gin_configurable_preprocessors
Parameters:  name – str, name of the preprocessor function to configure.
 fn_kwargs – optional dictionary, the arguments to configure; these will be partially applied to the function given by name.
Returns: a closure of the preprocessor function along with its arguments. This function takes only two arguments, dataset and boolean training; it ignores training and calls the T5 preprocessor with the dataset (and the closed-over arguments only).
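The closure pattern can be sketched with functools.partial. A toy registry dict stands in for looking up a T5 preprocessor by name. The registry, the toy preprocessor rekey_toy, and its key_map argument are hypothetical.

```python
import functools


def rekey_toy(dataset, key_map):
    """Toy preprocessor: renames fields in each example (a list of dicts here)."""
    return [{new: ex[old] for new, old in key_map.items()} for ex in dataset]


# Hypothetical registry standing in for a lookup of T5 preprocessors by name.
PREPROCESSORS = {'rekey_toy': rekey_toy}


def get_preprocessor_by_name(name, fn_kwargs=None):
    fn = functools.partial(PREPROCESSORS[name], **(fn_kwargs or {}))

    def preprocess(dataset, training=True):
        del training  # Ignored, as described above.
        return fn(dataset)
    return preprocess


prep = get_preprocessor_by_name(
    'rekey_toy', {'key_map': {'inputs': 'question', 'targets': 'answer'}})
print(prep([{'question': 'q1', 'answer': 'a1'}], training=False))
# [{'inputs': 'q1', 'targets': 'a1'}]
```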

trax.data.tf_inputs.
download_and_prepare
(dataset_name, data_dir)¶ Downloads and prepares T2T or TFDS dataset.
Parameters:  dataset_name – tfds dataset or t2t problem name prefixed by ‘t2t_’.
 data_dir – location of existing dataset or None.
Returns: data_dir, the path string of the downloaded data.

trax.data.tf_inputs.
BertSingleSentenceInputs
(batch, labeled=True, cls_id=101, sep_id=102)¶ Prepares inputs for BERT: add [SEP], [CLS] and create embeddings.

trax.data.tf_inputs.
BertDoubleSentenceInputs
(batch, labeled=True, cls_id=101, sep_id=102)¶ Prepares inputs for BERT models by adding [SEP] and [CLS] tokens and creating segment embeddings.
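The standard BERT layout for a sentence pair is [CLS] A [SEP] B [SEP]. The first sentence, including [CLS] and its [SEP], gets segment ID 0, and the second sentence gets segment ID 1. The hypothetical helper below builds that layout for one pair of token lists, using the default IDs above (cls_id=101, sep_id=102). It is an illustration only; the Trax function works on batched streams.

```python
def bert_pair_inputs(sent_a, sent_b, cls_id=101, sep_id=102):
    """Returns (token_ids, segment_ids) for [CLS] A [SEP] B [SEP]."""
    tokens = [cls_id] + list(sent_a) + [sep_id] + list(sent_b) + [sep_id]
    segments = [0] * (len(sent_a) + 2) + [1] * (len(sent_b) + 1)
    return tokens, segments


tokens, segments = bert_pair_inputs([2023, 2003], [2307])
print(tokens)    # [101, 2023, 2003, 102, 2307, 102]
print(segments)  # [0, 0, 0, 0, 1, 1]
```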

trax.data.tf_inputs.
CreateBertInputs
(double_sentence=True, labeled=True, cls_id=101, sep_id=102)¶

trax.data.tf_inputs.
mask_random_tokens
(batch, explicit_vocab_size=30522, masking_prob=0.15, cls_id=101, sep_id=102, mask_id=103, vocab_start_id=999)¶ Prepares input for the masking task.
Preparation consists of masking a masking_prob fraction of the non-special tokens in each input row: round(masking_prob * num_nonspecial_tokens) random tokens are selected, and each selected token is either
 replaced with the [MASK] token with 80% probability,
 replaced with a random token with 10% probability,
 or left unchanged with 10% probability.
The implementation is based on https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L342
Examples:
 batch is a stream with each row having tuple (token_ids,). The function yields rows of form (modified_token_ids, original_tokens, token_weights), where modified_token_ids have [MASK] tokens or random tokens according to the procedure described above.
 batch is a stream with each row having tuple (token_ids, segment_embeddings, nsp_label, nsp_weight). The function yields rows of form (modified_token_ids, segment_embeddings, nsp_label, nsp_weight, original_tokens, token_weights).
Parameters:  batch – stream of inputs. Each row in the stream is a tuple whose first element is an array of tokens.
 explicit_vocab_size – the total size of the vocabulary.
 masking_prob – Determines the fraction of non-special tokens to be selected for masking.
 cls_id – id of the special CLS token.
 sep_id – id of the special SEP token.
 mask_id – id of the special MASK token.
 vocab_start_id – id of the first non-special token in the vocabulary.
Yields: a stream with tokens masked for MLM training and 2 appended arrays:
 original_tokens: a copy of the original tokens, used as a label for MLM training.
 token_weights: weights distributed uniformly over the selected tokens (summing to 1); other tokens have 0 weight.
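A single-row sketch of the 80/10/10 procedure in plain Python. It assumes token IDs below vocab_start_id are special, which is consistent with that parameter's description. The function name, the seeded random.Random generator, and the single-row interface are illustrative choices, not Trax's implementation.

```python
import random


def mask_row(token_ids, explicit_vocab_size=30522, masking_prob=0.15,
             mask_id=103, vocab_start_id=999, rng=None):
    """Hypothetical single-row version of the 80/10/10 masking procedure.

    Returns (modified_token_ids, original_tokens, token_weights).
    """
    rng = rng or random.Random(0)
    # Assumption: IDs below vocab_start_id (e.g. [CLS], [SEP]) are special.
    candidates = [i for i, t in enumerate(token_ids) if t >= vocab_start_id]
    num_to_mask = round(masking_prob * len(candidates))
    selected = rng.sample(candidates, num_to_mask)

    modified = list(token_ids)
    weights = [0.0] * len(token_ids)
    for i in selected:
        r = rng.random()
        if r < 0.8:
            modified[i] = mask_id  # 80%: replace with [MASK]
        elif r < 0.9:
            # 10%: replace with a random non-special token
            modified[i] = rng.randrange(vocab_start_id, explicit_vocab_size)
        # Remaining 10%: keep the original token.
        weights[i] = 1.0 / num_to_mask  # uniform weight over selected tokens
    return modified, list(token_ids), weights


row = [101] + list(range(1000, 1020)) + [102]  # [CLS], 20 tokens, [SEP]
modified, original, weights = mask_row(row)
print(sum(w > 0 for w in weights))  # 3, i.e. round(0.15 * 20)
```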

trax.data.tf_inputs.
BertNextSentencePredictionInputs
(dataset_name, data_dir=None, text_key='text', train=True, shuffle_size=50000)¶ Defines a stream for the next sentence prediction task.

trax.data.tf_inputs.
CorpusToRandomChunks
(dataset_name, num_tokens=512, train=True)¶

trax.data.tf_inputs.
BertGlueTrainStream
(benchmark=<sphinx.ext.autodoc.importer._MockObject object>)¶ Returns a BERT-preprocessed training stream for benchmark.
Parameters: benchmark – Simple lowercase name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'.

trax.data.tf_inputs.
BertGlueEvalStream
(benchmark=<sphinx.ext.autodoc.importer._MockObject object>)¶ Returns a BERT-preprocessed eval data stream for benchmark.
Parameters: benchmark – Simple lowercase name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'. If the benchmark includes an alternate eval (e.g., MNLI’s “mismatched” eval/validation split), you can specify it with an '_e2' suffix, e.g., 'mnli_e2'.

trax.data.tf_inputs.
T5GlueTrainStream
(benchmark=<sphinx.ext.autodoc.importer._MockObject object>)¶ Returns a T5-preprocessed training data stream for benchmark.
Parameters: benchmark – Simple lowercase name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'.

trax.data.tf_inputs.
T5GlueTrainStreamsParallel
(benchmark_list=<sphinx.ext.autodoc.importer._MockObject object>, counters=None, reweight_by_minimum=False, gradually_reweight=False)¶ Returns a parallel set of training streams, based on benchmark_list.
Parameters:  benchmark_list – List of simple lowercase names of GLUE benchmarks, e.g., 'cola', 'mnli', 'rte'.
 counters – a list of counters to be passed to data.Parallel, e.g., [8551, 392702, 2490] would be a reasonable counterpart to benchmark_list = ["cola", "mnli", "rte"]; see https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/glue_utils.py#L42 for more details on counters.
 reweight_by_minimum – divide by the minimal counter.
 gradually_reweight – a more refined reweighting policy, see inputs.py for more details.

trax.data.tf_inputs.
T5GlueEvalStream
(benchmark=<sphinx.ext.autodoc.importer._MockObject object>)¶ Returns a T5-preprocessed eval data stream for benchmark.
Parameters: benchmark – Simple lowercase name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'. If the benchmark includes an alternate eval (e.g., MNLI’s “mismatched” eval/validation split), you can specify it with an '_e2' suffix, e.g., 'mnli_e2'.

trax.data.tf_inputs.
T5GlueEvalStreamsParallel
(benchmark_list=<sphinx.ext.autodoc.importer._MockObject object>)¶ Returns a parallel set of T5 eval streams, based on benchmark_list.
Parameters: benchmark_list – List of strings, each of which is a simple lowercase name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'. If a benchmark includes an alternate eval (e.g., MNLI’s “mismatched” eval/validation split), you can specify it with an '_e2' suffix, e.g., 'mnli_e2'.

trax.data.tf_inputs.
T5GlueEvalTasks
(benchmark_list=<sphinx.ext.autodoc.importer._MockObject object>)¶ Returns a list of T5 GLUE eval tasks, based on benchmark_list.
Parameters: benchmark_list – List of strings, each of which indicates the name and data split of a GLUE benchmark. Data splits are indicated as underscore suffixes, e.g., 'cola_t' (CoLA benchmark, training split), 'rte_e' (RTE benchmark, eval/validation split), and 'mnli_e2' (MNLI alternate “mismatched” eval/validation split).
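Here is a sketch of how names in this suffix convention could be split into a benchmark and a split. The helper name and the split labels ('train', 'eval', 'alternate_eval') are hypothetical; they simply restate the three suffixes described above.

```python
SPLITS = {'t': 'train', 'e': 'eval', 'e2': 'alternate_eval'}


def parse_benchmark(name):
    """Splits e.g. 'mnli_e2' into ('mnli', 'alternate_eval')."""
    benchmark, _, suffix = name.rpartition('_')
    if not benchmark or suffix not in SPLITS:
        raise ValueError(f'Expected <benchmark>_<t|e|e2>, got {name!r}')
    return benchmark, SPLITS[suffix]


print([parse_benchmark(n) for n in ['cola_t', 'rte_e', 'mnli_e2']])
# [('cola', 'train'), ('rte', 'eval'), ('mnli', 'alternate_eval')]
```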

trax.data.tf_inputs.
compute_single_result
(op_name, num_args)¶ An implementation of the most popular ops from the MathQA dataset.

trax.data.tf_inputs.
compute_result
(list_op, list_num)¶ Python execution of MathQA ops.

trax.data.tf_inputs.
single_op_to_python_command
(op_name, num_args)¶ An implementation of the most popular ops from the MathQA dataset.

trax.data.tf_inputs.
compute_program
(list_op)¶ Python execution of MathQA ops.

trax.data.tf_inputs.
compute_nums
(question)¶ Finds numbers in a string and converts them to floats.
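Here is a sketch of this step with a regular expression. The pattern is an assumption: digits, optional comma thousands separators, and optional decimals, with no sign handling. The Trax implementation may handle number formats differently.

```python
import re

# Assumed pattern: 1200, 1,200 and 7.5 all match; signs are not handled.
_NUM_RE = re.compile(r'\d+(?:,\d{3})*(?:\.\d+)?')


def find_numbers(question):
    """Returns every number found in question, as floats, in order."""
    return [float(m.replace(',', '')) for m in _NUM_RE.findall(question)]


print(find_numbers('A train travels 1,200 km in 7.5 hours; find its speed.'))
# [1200.0, 7.5]
```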

trax.data.tf_inputs.
compute_ops
(linear_formula)¶

trax.data.tf_inputs.
process_single_mathqa_example
(example)¶ Executes a single example and verifies the coherence of a MathQA problem.
Parameters: example – a dictionary with the following fields:
 Problem – a natural language formulation of the problem
 Rationale – a natural language solution of the problem
 options – five possible answers (a) b) c) d) and e))
 correct – the letter representing the correct answer
 annotated_formula – formula representing the full solution
 linear_formula – a string of operations separated by the | character, e.g. multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)|multiply(#2,const_100)|divide(#3,#1)
 category – a natural language description of the category to which a given problem belongs.
Returns: answer_num – the numerical answer contained in the example; python_result – numerical answers computed in Python, including intermediate results (answer_num should be close to python_result[-1]); list_op – list of arithmetic operations; list_num – list of identified numbers in the text.
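A minimal interpreter for the linear_formula notation shown above follows. In that notation, nK is the K-th number in the problem, const_X is a constant (with _ read as a decimal point), and #K is the result of the K-th operation. Only four operations are implemented, and the parsing is a hypothetical sketch, not this module's execution code.

```python
import operator

OPS = {'add': operator.add, 'subtract': operator.sub,
       'multiply': operator.mul, 'divide': operator.truediv}


def resolve(arg, nums, results):
    """Maps nK, #K and const_X arguments to float values."""
    if arg.startswith('const_'):
        return float(arg[len('const_'):].replace('_', '.'))
    if arg.startswith('n'):
        return nums[int(arg[1:])]
    if arg.startswith('#'):
        return results[int(arg[1:])]
    raise ValueError(f'Unknown argument {arg!r}')


def run_linear_formula(linear_formula, nums):
    """Executes '|'-separated ops; returns all results (the last is the answer)."""
    results = []
    for op in filter(None, linear_formula.split('|')):
        name, args = op.strip().rstrip(')').split('(')
        a, b = (resolve(x.strip(), nums, results) for x in args.split(','))
        results.append(OPS[name](a, b))
    return results


formula = ('multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)'
           '|multiply(#2,const_100)|divide(#3,#1)')
print(run_linear_formula(formula, [2.0, 5.0, 3.0]))
# [300.0, 10.0, 30.0, 3000.0, 300.0]
```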

trax.data.tf_inputs.
convert_float_to_mathqa
(number)¶

trax.data.tf_inputs.
convert_to_subtract
(const_string)¶

trax.data.tf_inputs.
execute_mathqa_dsl_program
(problem, dsl_code)¶ Executes the DSL code for a given problem.
Parameters:  problem – problem formulation (needed to get parameters).
 dsl_code – DSL code.
Returns: the result of executing the DSL code.

trax.data.tf_inputs.
is_number
(s)¶

trax.data.tf_inputs.
execute_mathqa_program
(problem, program)¶ Executes the Python code for a given problem.
Parameters:  problem – problem formulation (not needed, but we want the same API as in the DSL case).
 program – Python code.
Returns: the result of executing the Python code.

trax.data.tf_inputs.
CreateMathQAInputs
(dataset_path=None, train=True, test=False, challenge=False, tolerance=0.01, cumulative=True, python_code=False, full_dict=False, partial_results=True, nlp_rationale=False, correct_answer=False, answer_in_mathqa_format=True, correct_answer_given_reasoning=False, category=False, order_prediction=False, reduced_operation_name=True, qed=False)¶ Prepares MathQA inputs.
The generation procedure leaves a lot of parameters to be set by the user. Currently we support only correct examples in the following sense: Python execution agrees with the declared answer up to 1%.
According to this criterion, wrong examples such as problem: calculate 85184 ÷ ? = 352, operations: [‘multiply(n0,n1)’] are ignored (this should be divide(n0,n1) in this case).
Parameters:  dataset_path – a path with the MathQA dataset.
 train – if True, then generate training examples; if train, test and challenge are all set to False, then generate validation examples.
 test – if train is set to False and test is set to True, then generate test examples.
 challenge – if train and test are set to False and challenge is set to True, then generate challenge examples.
 tolerance – if for a given example relative difference between Python result and the result declared in the dataset exceeds the level, then the example is dropped; tolerances ranging from 0.1 to 0.001 yield from 18K to 21K examples.
 cumulative – if set to True, then generate examples in the format input: problem + numbers + op1 + op2 + op3, target: op4. If set to False, then examples are in the format input: problem + numbers, target: all operations.
 python_code – if set to True, then generates python code instead of MathQA commands.
 full_dict – if set to True, then Python examples are returned together with the DSL code and the NLP rationale.
 partial_results – if set to True, then partial results will be reported as part of the input, e.g. input: problem + numbers + op1 + #1 + op2 + #2 + op3 + #3, target: op4, where #k is the partial result from operation opk. Active only if cumulative is set to True.
 nlp_rationale – if set to True, then input is the problem and the target is the nlp rationale.
 correct_answer – if set to True, then input is the problem plus all possible answers and the target is the correct answer.
 answer_in_mathqa_format – if set to True, then convert numerical answer to the MathQA format and wrap it in the subtract operation. E.g. “3.13” is converted to “subtract(const_3_13,const_0)”.
 correct_answer_given_reasoning – if set to True, then input is the problem plus linear formula plus all possible answers and the target is the correct answer.
 category – if set to True, then input is the problem and the target is its category.
 order_prediction – if set to True, then input is the problem and a list of all operations; with probability 0.5 two operations are swapped; the task consists in detecting whether the operations were swapped. See the order prediction task in CreateAquaInputs in this file.
 reduced_operation_name – If set to True, then in order prediction consider only the operation token without parameters.
 qed – if set to True, then the reasoning is finished with an additional operation qed.
Returns: mathqa_yield_examples, a generator of MathQA examples; the generator yields non-tokenized examples, which can be further processed using, for example, the tokenize function from this module.
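The cumulative=True layout (input: problem + numbers + op1 + … + op(k-1), target: opk) can be illustrated by expanding one problem into per-step pairs. The helper name and the space-joined string format are assumptions; the real generator's exact formatting may differ.

```python
def cumulative_examples(problem, numbers, ops):
    """Yields (input, target) pairs where each prefix of ops predicts the next op."""
    prefix = problem + ' ' + ' '.join(str(n) for n in numbers)
    for op in ops:
        yield prefix, op
        prefix = prefix + ' ' + op


pairs = list(cumulative_examples('What is 2 times 5?', [2.0, 5.0],
                                 ['multiply(n0,n1)', 'add(#0,const_1)']))
for inp, target in pairs:
    print(repr(inp), '->', target)
# 'What is 2 times 5? 2.0 5.0' -> multiply(n0,n1)
# 'What is 2 times 5? 2.0 5.0 multiply(n0,n1)' -> add(#0,const_1)
```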

trax.data.tf_inputs.
CreateAquaInputs
(dataset_path=None, train=True, cumulative=False, rationale=False, correct_answer=False, correct_answer_given_reasoning=False, partial_reasoning=True, order_prediction=False)¶ Prepares Aqua inputs.
Parameters:  dataset_path – a path with the Aqua dataset.
 train – if True, then generate training examples, otherwise generate validation examples (the dataset also has a test set).
 cumulative – if set to True, then generate examples in the format input: problem + step1 + step2 + step3, target: step4. If set to False, then examples are in the format input: problem, target: all operations.
 rationale – if set to True, then input is the problem and the target is the rationale.
 correct_answer – if set to True, then input is the problem plus all possible answers and the target is the correct answer.
 correct_answer_given_reasoning – if set to True, then input is the problem plus reasoning (aka rationale) plus all possible answers and the target is the correct answer.
 partial_reasoning – an additional option related to correct_answer_given_reasoning; if set to True, then we take a random prefix of the reasoning.
 order_prediction – if set to True, then input is the problem and a list of all operations; with probability 0.5 two operations are swapped; the task consists in detecting whether the operations were swapped. A similar additional task was considered in https://arxiv.org/pdf/1909.11942.pdf and in a recent work of Piotr Piękos, henrykm@ and mateuszm@.
Returns: aqua_yield_examples, a generator of Aqua examples; the generator yields non-tokenized examples, which can be further processed using, for example, the tokenize function from this module.

trax.data.tf_inputs.
CreateDropInputs
(train=True, mathqa_format=False)¶ Prepares Drop inputs.
Parameters:  train – if True, then generate training examples, otherwhise generate validation examples (the dataset has also a test set).
 mathqa_format – if True, then floats in targets are converted to the the MathQA convention and wrapped in the subtract operation. E.g. “3.13” is converted to “subtract(const_3_13,const_0)”.
Returns: a generator of Drop examples; the generator yields nontokenized examples  they can be further processed using for example the tokenize function from this module
Return type: drop_yield_examples

trax.data.tf_inputs.
CreateAnnotatedDropInputs
(dataset_path=None, train=True, single_file=True, unique=False, total_number_of_samples=None, percentile=1.0)¶ Prepares annotated Drop inputs.
Example of an annotated input which can be used with this interface:
 {
 ‘passage’: ‘The Armenian Prelature of Cyprus was established in 973 by Catholicos Khatchig I. Historically, the Prelature has been under the jurisdiction of the Catholicosate of the Great House of Cilicia, while today it is the oldest theme that falls under its jurisdiction. Since 2014 the Prelate, a Catholicosal Vicar General, has been Archbishop Nareg Alemezian. The parish priest in Nicosia is Fr. Momik Habeshian, while the parish priest in Larnaca and Limassol is Fr. Mashdots Ashkarian. For centuries, the Prelature building was located within the Armenian compound in Victoria street in walled Nicosia; when that area was taken over by Turkish-Cypriot extremists in 1963-1964, the Prelature was temporarily housed in Aram Ouzounian street and, later on, in Kyriakos Matsis street in Ayios Dhometios. Thanks to the efforts of Bishop Zareh Aznavorian and with financial aid from the Evangelical Church of Westphalia, the new Prelature building was erected in 1983, next to the Virgin Mary church and the Nareg school in Nicosia, by architects Athos Dikaios & Alkis Dikaios; it was officially inaugurated on 4 March 1984, during the pastoral visit of Catholicos Karekin II. By initiative of Archbishop Varoujan Hergelian, in 1998 the basement of the building was renovated and the “Vahram Utidjian” Hall was formed; previously a store room, it became a reality from the proceeds of the auction in 1994 of the art collection that Vahram Utidjian had donated to the Prelature in 1954. It was inaugurated on 3 February 1999 by Catholicos Aram I; numerous charity, communal and cultural events take place there. The Prelature’s consistory houses a collection of ecclesiastical relics, some of which were previously in the old Virgin Mary church or the Magaravank.’, ‘question’: ‘How many years after the Vahram Utidjian was donated to the Prelature was it sold at an auction?’, ‘answer’: 40, ‘calculation’: ‘subtract(n8,n9)’
}
In this example the calculation is formulated using the notation from the MathQA dataset, but this is not required. subtract(n8,n9) means that the answer 40 can be obtained through the subtraction of the 9th and the 10th number in the input. The input consists of the passage concatenated with the question. The annotations can be generated using, for example, a method from the paper https://arxiv.org/abs/1909.00109.
Parameters:  dataset_path – a path with the annotated Drop dataset.
 train – if True, then generate training examples, otherwise generate validation examples (the dataset also has a test set).
 single_file – if True, then look just for one file. If False, read all json files in a given directory and assume that each file contains one example. Applied only to training data.
 unique – if set to True, then the generator will provide at most one question per passage.
 total_number_of_samples – if set to a positive integer, then the total number of unique samples will be bounded by total_number_of_samples.
 percentile – the fraction of the train dataset used for training; defaults to 1.0, though setting it to a lower value can be interesting when the train data is combined with another source of data.
Returns: drop_annotated_yield_examples, a generator of annotated Drop examples; the generator yields non-tokenized examples, which can be further processed using, for example, the tokenize function from this module.
trax.optimizers¶
adafactor¶
Adafactor optimizer class.

class
trax.optimizers.adafactor.
Adafactor
(learning_rate=0.05, factored=True, multiply_by_parameter_scale=True, do_clipping=True, do_momentum=False, momentum_in_bfloat16=False, beta1=0.0, decay_rate=0.8, clipping_threshold=1.0, weight_decay_rate=1e-05, weight_decay_n_steps=0, epsilon1=1e-16, epsilon2=0.001)¶ Bases:
trax.optimizers.base.Optimizer
Adafactor optimizer, as described in https://arxiv.org/abs/1804.04235.

__init__
(learning_rate=0.05, factored=True, multiply_by_parameter_scale=True, do_clipping=True, do_momentum=False, momentum_in_bfloat16=False, beta1=0.0, decay_rate=0.8, clipping_threshold=1.0, weight_decay_rate=1e-05, weight_decay_n_steps=0, epsilon1=1e-16, epsilon2=0.001)¶ Creates the Adafactor optimizer.
Adafactor is described in https://arxiv.org/abs/1804.04235.
Parameters:  learning_rate – float: Trax-provided learning rate.
 factored – boolean: whether to use a factored second-moment estimator for 2d variables.
 multiply_by_parameter_scale – boolean: if True, then scale the provided learning_rate by the parameter norm; if False, the provided learning_rate is the absolute step size.
 do_clipping – whether to clip gradients; if True, set clipping_threshold.
 do_momentum – whether to use momentum; if True, set beta1.
 momentum_in_bfloat16 – if True, store momentum in bfloat16 to save memory.
 beta1 – a float value between 0 and 1; enables momentum and uses extra memory if non-zero! Off by default.
 decay_rate – float: controls the second-moment exponential decay schedule.
 clipping_threshold – an optional float >= 1; if None, no update clipping.
 weight_decay_rate – rate at which to decay weights.
 weight_decay_n_steps – for how many steps to decay weights (always if None)
 epsilon1 – Regularization constant for squared gradient.
 epsilon2 – Regularization constant for parameter scale.
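With factored=True, a 2-D weight matrix keeps only per-row and per-column running averages of the squared gradient, not a full second-moment matrix. The full estimate is reconstructed from their outer product, as in the Adafactor paper. Below is a NumPy sketch of that estimator alone, without update clipping, parameter scaling, or momentum. The decay schedule 1 - t^(-decay_rate) is an assumption about how decay_rate is used, and epsilon1 is added to the squared gradient as the regularizer described above.

```python
import numpy as np


def factored_second_moment(grad, row_avg, col_avg, step,
                           decay_rate=0.8, epsilon1=1e-16):
    """One step of a factored second-moment estimate for a 2-D gradient."""
    decay = 1.0 - (step + 1.0) ** (-decay_rate)  # assumed schedule; 0 at step 0
    sq = grad ** 2 + epsilon1
    row_avg = decay * row_avg + (1 - decay) * sq.mean(axis=1)  # shape (rows,)
    col_avg = decay * col_avg + (1 - decay) * sq.mean(axis=0)  # shape (cols,)
    # Rank-1 reconstruction: v[i, j] ~ row_avg[i] * col_avg[j] / mean(row_avg).
    v_hat = np.outer(row_avg, col_avg) / row_avg.mean()
    return row_avg, col_avg, v_hat


g = np.outer([1.0, 2.0], [1.0, 3.0, 1.0])  # a gradient whose square is rank 1
r, c, v = factored_second_moment(g, np.zeros(2), np.zeros(3), step=0)
print(np.allclose(v, g ** 2))  # True: exact for rank-1 squared gradients
```

The saving is in the state: rows + cols numbers per matrix instead of rows * cols.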

init
(weights)¶ Creates optimizer slots that fit the given weights.
Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update
(step, grads, weights, slots, opt_params)¶ Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 grads – Gradient values for this node (from backpropagation during a training step).
 weights – Current weight values for this node (i.e., layer weights).
 slots – Current slot values for this node.
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

adam¶
Adam optimizer class.

class
trax.optimizers.adam.
Adam
(learning_rate=0.0001, weight_decay_rate=1e-05, b1=0.9, b2=0.999, eps=1e-05, clip_grad_norm=None)¶ Bases:
trax.optimizers.base.Optimizer
Adam optimizer; described in https://arxiv.org/abs/1412.6980.
The update rule for time step \(t\), given gradients \(g_t\) and “Stepsize” \(\alpha\), is:
\[\begin{split}m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 \\ \hat{m}_t &\leftarrow m_t\ /\ (1 - \beta_1^t) \\ \hat{v}_t &\leftarrow v_t\ /\ (1 - \beta_2^t) \\ \theta_t &\leftarrow \theta_{t-1} - \alpha \cdot \hat{m}_t / \big(\sqrt{\hat{v}_t} + \epsilon\big)\end{split}\]
__init__
(learning_rate=0.0001, weight_decay_rate=1e-05, b1=0.9, b2=0.999, eps=1e-05, clip_grad_norm=None)¶ Creates an Adam optimizer.
Parameters:  learning_rate – Initial (unadapted) learning rate \(\alpha\); original paper calls this Stepsize and suggests .001 as a generally good value.
 weight_decay_rate – Fraction of prior weight values to subtract on each step; equivalent to multiplying each weight element by 1 - weight_decay_rate. (This is not part of the core Adam algorithm.)
 b1 – Exponential decay rate \(\beta_1\) for first moment estimates.
 b2 – Exponential decay rate \(\beta_2\) for second moment estimates.
 eps – Small positive constant \(\epsilon\) for numerical stability.
 clip_grad_norm – Threshold value above which gradient clipping occurs. (This is not part of the core Adam algorithm.)
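Here is a NumPy sketch of one Adam step for a single weight array, following the update rule above with t counted from 1. It also applies weight_decay_rate as described, by multiplying the weights by 1 - weight_decay_rate, and omits clip_grad_norm. Applying the decay after the Adam step is an assumption about ordering; the text above does not specify it.

```python
import numpy as np


def adam_step(t, grad, weights, m, v, learning_rate=1e-4, b1=0.9, b2=0.999,
              eps=1e-5, weight_decay_rate=1e-5):
    """One Adam update at 1-based step t; returns (new_weights, new_m, new_v)."""
    m = b1 * m + (1 - b1) * grad          # first-moment estimate m_t
    v = b2 * v + (1 - b2) * grad ** 2     # second-moment estimate v_t
    m_hat = m / (1 - b1 ** t)             # bias corrections
    v_hat = v / (1 - b2 ** t)
    new_weights = weights - learning_rate * m_hat / (np.sqrt(v_hat) + eps)
    new_weights = new_weights * (1 - weight_decay_rate)  # assumed ordering
    return new_weights, m, v


g = np.array([0.5, -2.0, 3.0])
w, m, v = adam_step(1, g, np.zeros(3), np.zeros(3), np.zeros(3),
                    learning_rate=0.01, eps=1e-8, weight_decay_rate=0.0)
print(np.round(w, 6))  # first step moves each weight by about -lr * sign(g)
```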

init
(weights)¶ Creates optimizer slots that fit the given weights.
Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update
(step, grads, weights, slots, opt_params)¶ Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 grads – Gradient values for this node (from backpropagation during a training step).
 weights – Current weight values for this node (i.e., layer weights).
 slots – Current slot values for this node.
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

base¶
Trax base optimizer class.

class
trax.optimizers.base.
Optimizer
(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)¶ Bases:
object
Base class for optimizers that work hand in hand with Trax layers.
To define an optimizer subclass, specify its behavior with respect to a single node in the network (e.g., a single dense layer):
 init: how to create/initialize optimizer-internal parameters (“slots”), as a function of the node’s weights.
 update: how to use gradient information to update node weights and optimizer slots.
The Trax runtime combines these node-local computations into layer weight updates and optimizer slot updates for the whole tree of layers in the model.
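The node-local contract works as follows: init creates slots from one node's weights, and update returns new weights and slots. The self-contained sketch below implements SGD with momentum in that shape. It deliberately does not subclass trax.optimizers.base.Optimizer. The class name MomentumSketch and the opt_params dict keys are hypothetical, chosen only to mirror the signatures listed here.

```python
import numpy as np


class MomentumSketch:
    """Hypothetical node-local optimizer mirroring the init/update contract."""

    def __init__(self, learning_rate=0.01, mass=0.9):
        self.opt_params = {'learning_rate': learning_rate, 'mass': mass}

    def init(self, weights):
        # One slot per node: a velocity array matching the weights' shape and dtype.
        return np.zeros_like(weights)

    def update(self, step, grads, weights, slots, opt_params):
        del step  # Unused by plain momentum.
        velocity = opt_params['mass'] * slots + grads
        new_weights = weights - opt_params['learning_rate'] * velocity
        return new_weights, velocity


opt = MomentumSketch(learning_rate=0.1, mass=0.9)
w = np.array([1.0, -1.0])
slots = opt.init(w)
for step in range(2):
    w, slots = opt.update(step, np.ones(2), w, slots, opt.opt_params)
print(np.round(w, 2))  # [ 0.71 -1.29]
```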

__init__
(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)¶ Sets initial hyperparameter values for this optimizer.
Takes optimizer hyperparameters as keyword arguments. These values can change over time (training steps), e.g., for learning rate schedules.
To expose subclass hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.
Parameters:  learning_rate – Learning rate for the optimizer. This can change during training by means of a learning rate schedule.
 clip_grad_norm – If specified, this scalar value is used to limit gradient size – all gradient elements in a training step are treated as if they belonged to a single vector and then scaled back if needed so that such a vector’s L2 norm does not exceed clip_grad_norm. If None, no clipping happens.
 **init_opt_params – Initial values of any additional optimizer parameters.

init(weights)
Creates optimizer slots that fit the given weights.
Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update(step, grads, weights, slots, opt_params)
Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 grads – Gradient values for this node (from backpropagation during a training step).
 weights – Current weight values for this node (i.e., layer weights).
 slots – Current slot values for this node.
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

slots

opt_params

tree_init(weight_tree)
Assembles node-local initializations into a full-tree initialization.
Parameters: weight_tree – Weights for an entire model, in a tree that matches the model’s layer structure.
Returns: Tuple (slots, opt_params), where slots are the initialized optimizer slot values and opt_params are optimizer hyperparameters (e.g., learning rate, momentum).

tree_update(step, grad_tree, weight_tree, slots, opt_params, store_slots=True)
Assembles node-local weight and slot updates for the full layer tree.
Parameters:  step – Current step number in the training process.
 grad_tree – Gradients for the entire model, in a tree that matches the model’s layer structure.
 weight_tree – Current weights for the entire model, in a tree that matches the model’s layer structure.
 slots – Optimizer slots.
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum).
 store_slots – Boolean; if True, stores resulting slots in this object; when set to False, this becomes a pure function.
Returns: Tuple (weights, slots), where weights are the optimizer-updated weights for the whole model (in a tree matching the model’s layer structure) and slots are the updated optimizer slot values.

class trax.optimizers.base.SGD(learning_rate=0.01, clip_grad_norm=None, **init_opt_params)
Bases: trax.optimizers.base.Optimizer
Stochastic gradient descent (SGD) optimizer.

init(weights)
Creates optimizer slots that fit the given weights.
Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update(step, grads, weights, slots, opt_params)
Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 grads – Gradient values for this node (from backpropagation during a training step).
 weights – Current weight values for this node (i.e., layer weights).
 slots – Current slot values for this node.
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.


trax.optimizers.base.l2_norm(tree)
Returns an L2 norm computed over all elements of all tensors in tree.
Parameters: tree – Tree-structured collection of tensors, e.g., model weights matching the model’s layer structure.
Returns: A scalar value computed as if all the tensors in tree were combined and flattened into a single vector, and then the L2 norm of that vector was calculated.

trax.optimizers.base.clip_grads(grad_tree, max_norm)
Proportionally reduces each gradient value to respect an aggregate limit.
Parameters:  grad_tree – Gradient values structured as a tree of tensors matching the model’s layer structure.
 max_norm – The aggregate limit on gradient values. All gradient elements in grad_tree are treated as if they belonged to a single vector, and that vector is shortened if needed so that its L2 norm does not exceed max_norm.
Returns: A new tree of tensors matching the structure of grad_tree, but with element values proportionally rescaled as needed to respect the max_norm limit.
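A minimal NumPy sketch of the global-norm clipping described for l2_norm and clip_grads. It assumes gradient trees are nested lists/tuples of arrays and illustrates the technique rather than reproducing Trax's implementation. Every leaf is multiplied by the same factor, max_norm / norm, so the direction of the overall gradient vector is preserved.

```python
import numpy as np


def _leaves(tree):
    """Flattens nested lists/tuples of arrays into a list of arrays."""
    if isinstance(tree, (list, tuple)):
        out = []
        for sub in tree:
            out.extend(_leaves(sub))
        return out
    return [np.asarray(tree, dtype=np.float64)]


def l2_norm(tree):
    """L2 norm of all elements of all tensors, as if flattened into one vector."""
    return float(np.sqrt(sum(np.sum(x * x) for x in _leaves(tree))))


def _scale(tree, factor):
    """Multiplies every leaf by factor, keeping the tree structure."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(_scale(sub, factor) for sub in tree)
    return np.asarray(tree, dtype=np.float64) * factor


def clip_grads(grad_tree, max_norm):
    """Rescales all gradients by one common factor so the global norm <= max_norm."""
    norm = l2_norm(grad_tree)
    factor = 1.0 if norm <= max_norm else max_norm / norm
    return _scale(grad_tree, factor)


grads = [np.array([3.0]), (np.array([[4.0]]),)]
total = l2_norm(grads)                     # 5.0
clipped = clip_grads(grads, max_norm=1.0)  # approximately [0.6] and ([[0.8]],)
```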
momentum
Nesterov momentum optimizer (also known as Nesterov Accelerated Gradient).

class trax.optimizers.momentum.Momentum(learning_rate=0.01, mass=0.9, weight_decay_rate=1e-05, nesterov=True)
Bases: trax.optimizers.base.Optimizer
A momentum optimizer.
This class implements two variants of momentum stochastic gradient descent (SGD): with and without the Nesterov correction. The implementation of the Nesterov update is based on the concepts in Sutskever et al. (2013) [http://jmlr.org/proceedings/papers/v28/sutskever13.pdf], reformulated in Bengio et al. (2012) [https://arxiv.org/abs/1212.0901], to work well with backpropagation (equations 6 and 7):
\[\begin{split}v_t &= \mu_{t-1}v_{t-1} - \epsilon_{t-1}\nabla f(\Theta_{t-1}) \\ \Theta_t &= \Theta_{t-1} - \mu_{t-1} v_{t-1} + \mu_t v_t + v_t\end{split}\]
where \(\mu_{t-1}\) is the momentum (decay) coefficient at time step \(t-1\) and \(\epsilon_{t-1}\) is the learning rate at \(t-1\).
Note that the implementation below also includes a weight decay rate (\(\alpha\)) on the parameters, independent of the Nesterov momentum.
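Equations 6 and 7 can be transcribed directly into a short sketch. It assumes a constant learning rate, omits the weight decay term (\(\alpha\)), and takes a caller-supplied gradient function; it illustrates the displayed update and is not the Trax Momentum implementation.

```python
import numpy as np


def nesterov_step(theta, v, grad_fn, mu=0.9, lr=0.01, mu_next=None):
    """One step of equations 6 and 7 (Bengio et al., 2012).

    theta, v: parameters and velocity after step t-1.
    mu, lr:   momentum coefficient and learning rate at step t-1.
    mu_next:  momentum coefficient at step t (defaults to a constant schedule).
    """
    if mu_next is None:
        mu_next = mu
    v_new = mu * v - lr * grad_fn(theta)                   # equation 6
    theta_new = theta - mu * v + mu_next * v_new + v_new   # equation 7
    return theta_new, v_new


# Minimize f(theta) = 0.5 * ||theta||^2, whose gradient is theta itself.
theta, v = np.array([1.0, -2.0]), np.zeros(2)
for _ in range(200):
    theta, v = nesterov_step(theta, v, grad_fn=lambda t: t, mu=0.9, lr=0.1)
```

With mu=0.0 the step reduces to plain SGD, \(\Theta_t = \Theta_{t-1} - \epsilon\nabla f(\Theta_{t-1})\), which is a quick sanity check on the transcription.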

__init__(learning_rate=0.01, mass=0.9, weight_decay_rate=1e-05, nesterov=True)
Sets initial hyperparameter values for this optimizer.
Takes optimizer hyperparameters as keyword arguments. These values can change over time (training steps), e.g., for learning rate schedules.
To expose subclass hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.
Parameters:  learning_rate – Learning rate for the optimizer. This can change during training by means of a learning rate schedule.
 mass – Momentum coefficient (\(\mu\) in the equations above).
 weight_decay_rate – Weight decay rate (\(\alpha\)) applied to the parameters.
 nesterov – If True, apply the Nesterov correction; if False, use momentum without it.

init(weights)
Creates optimizer slots that fit the given weights.
Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update(step, grads, weights, velocity, opt_params)
Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 grads – Gradient values for this node (from backpropagation during a training step).
 weights – Current weight values for this node (i.e., layer weights).
 velocity – Current velocity for this node (its optimizer slot).
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

rms_prop
RMSProp optimizer class.

class trax.optimizers.rms_prop.RMSProp(learning_rate=0.001, gamma=0.9, eps=1e-08, clip_grad_norm=None)
Bases: trax.optimizers.base.Optimizer
RMSProp optimizer.
Uses optimizer weights (“slots”) to maintain an exponentially decaying average of squared gradients from prior training batches; the square root of this average (the root mean square) scales each update.
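A sketch of one RMSProp step for a single node, written with NumPy. It assumes the slot starts at zeros and that gamma and eps play their usual roles (decay rate of the squared-gradient average and a small stabilizing constant). It is not copied from Trax; it shows why the update is insensitive to the overall scale of the gradients.

```python
import numpy as np


def rmsprop_update(weights, grads, avg_sq_grad, learning_rate=0.001, gamma=0.9, eps=1e-8):
    """One RMSProp step for a single node; avg_sq_grad is the node's slot."""
    # Exponentially decaying average of squared gradients.
    avg_sq_grad = gamma * avg_sq_grad + (1.0 - gamma) * np.square(grads)
    # Divide by the root mean square so the step size is roughly scale-free.
    new_weights = weights - learning_rate * grads / (np.sqrt(avg_sq_grad) + eps)
    return new_weights, avg_sq_grad


w = np.zeros(2)
slot = np.zeros_like(w)  # assumed initial slot: zeros shaped like the weights
small, _ = rmsprop_update(w, np.array([0.01, -0.01]), slot)
large, _ = rmsprop_update(w, np.array([100.0, -100.0]), slot)
# Gradients 10,000x apart produce (nearly) the same first step.
```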

__init__(learning_rate=0.001, gamma=0.9, eps=1e-08, clip_grad_norm=None)
Sets initial hyperparameter values for this optimizer.
Takes optimizer hyperparameters as keyword arguments. These values can change over time (training steps), e.g., for learning rate schedules.
To expose subclass hyperparameters for gin configuration, override this constructor and use explicitly named keyword arguments. See momentum.Momentum.__init__ for one such example.
Parameters:  learning_rate – Learning rate for the optimizer. This can change during training by means of a learning rate schedule.
 gamma – Decay rate of the exponentially decaying average of squared gradients.
 eps – Small constant added to the denominator for numerical stability.
 clip_grad_norm – If specified, this scalar value is used to limit gradient size – all gradient elements in a training step are treated as if they belonged to a single vector and then scaled back if needed so that such a vector’s L2 norm does not exceed clip_grad_norm. If None, no clipping happens.

init(weights)
Creates optimizer slots that fit the given weights.
Parameters: weights – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update(step, grads, weights, avg_sq_grad, opt_params)
Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 grads – Gradient values for this node (from backpropagation during a training step).
 weights – Current weight values for this node (i.e., layer weights).
 avg_sq_grad – Current average of squared gradients for this node (its optimizer slot).
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

sm3
SM3 optimizer class.

class trax.optimizers.sm3.MomentumType
Bases: enum.IntEnum
An enumeration of the momentum variants available to the SM3 optimizer.

EMA = 1

HEAVY_BALL = 2

NESTEROV = 3


class trax.optimizers.sm3.SM3(learning_rate=0.01, momentum=0.9, second_moment_averaging=1.0, weight_decay=0.0, momentum_type=<MomentumType.EMA: 1>)
Bases: trax.optimizers.base.Optimizer
SM3 optimizer, as described in https://arxiv.org/abs/1901.11150.

__init__(learning_rate=0.01, momentum=0.9, second_moment_averaging=1.0, weight_decay=0.0, momentum_type=<MomentumType.EMA: 1>)
Creates the SM3 optimizer.
Memory-Efficient Adaptive Optimization. https://arxiv.org/abs/1901.11150
Parameters:  learning_rate – A positive scalar value for the initial learning rate.
 momentum – Optional, a positive scalar value for momentum.
 second_moment_averaging – Averaging of second moments (if 1.0, accumulates from the beginning of time, like AdaGrad).
 weight_decay – Weight decay for regularizing the model.
 momentum_type – Nesterov, heavy-ball, or EMA (default).
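To show where SM3's memory savings come from, here is a sketch of the SM3-II update from the linked paper for a single 2-D weight matrix, using one accumulator per row and one per column. It leaves out the momentum, weight_decay, and momentum_type options and corresponds to AdaGrad-style accumulation (second_moment_averaging=1.0); it is an illustration, not the Trax code.

```python
import numpy as np


def sm3_matrix_step(w, g, row_acc, col_acc, learning_rate=0.01):
    """One SM3-II step for a 2-D weight matrix, without momentum or weight decay.

    row_acc has shape (rows,) and col_acc has shape (cols,), so the optimizer
    state is O(rows + cols) numbers instead of AdaGrad's O(rows * cols).
    """
    # Per-entry second-moment estimate from the row and column covering it.
    nu = np.minimum(row_acc[:, None], col_acc[None, :]) + np.square(g)
    # AdaGrad-style preconditioned step; entries with nu == 0 are left unchanged.
    safe_nu = np.where(nu > 0, nu, 1.0)
    step = np.where(nu > 0, g / np.sqrt(safe_nu), 0.0)
    new_w = w - learning_rate * step
    # Each accumulator stores the largest estimate among the entries it covers.
    return new_w, nu.max(axis=1), nu.max(axis=0)


w = np.zeros((2, 3))
g = np.array([[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]])
w, row_acc, col_acc = sm3_matrix_step(w, g, np.zeros(2), np.zeros(3))
# row_acc == [9, 36], col_acc == [16, 25, 36]; each entry moved by 0.01 * sign(g)
```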

init(w)
Creates optimizer slots that fit the given weights.
Parameters: w – Trainable weights for one layer. Optimizer slots typically match the data shape and type of the given layer weights.

update(step, g, w, slots, opt_params)
Computes updated layer weights and optimizer slots for one training step.
Parameters:  step – Training step number.
 g – Gradient values for this node (from backpropagation during a training step).
 w – Current weight values for this node (i.e., layer weights).
 slots – Current slot values for this node.
 opt_params – Optimizer hyperparameters (e.g. learning rate, momentum), same across all nodes in the model.
Returns: Tuple of (new_weights, new_slots), which the Trax runtime will use to update the model and optimizer within each training step.

trax.supervised
decoding
Decoding with Trax models.

trax.supervised.decoding.autoregressive_sample_stream(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, accelerate=True, eval_mode=False, eval_min_length=1)
Yields samples from model, in autoregressive language model fashion.
This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and each further item drawn from the stream repeats the process for the next position, indefinitely.
Inputs and outputs always come in batches, even if of size 1. If inputs is present, it must have shape (batch_size, inputs_sequence_length), and each output in the stream has shape (batch_size, 1).
Parameters:  model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM), except if eval_mode is set – any model can be sampled then, but the sampling process may be much slower.
 inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model generates the first output based on just the start symbol.
 batch_size – Number of sequences to generate in parallel as a batch.
 temperature – Parameter that controls the sharpness of the softmax that feeds the sampling process. Values range from 0.0 (all probability mass goes to one candidate; like an argmax) to positive infinity (all candidates have equal probability).
 start_id – Integer representing the start symbol for the autoregressive process, or array of shape (batch_size, 1) of such integers.
 accelerate – If True, create an accelerated version of model and use it for generating outputs.
 eval_mode – If True, assume the model is created in eval mode and sample by collecting all previous outputs and passing the whole tensor.
 eval_min_length – If set, the minimum length to pad to in eval mode.
Yields: Tensor of integers with shape (batch_size, 1), representing the batch of outputs for the next position in the stream.
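The streaming behavior and the effect of temperature can be sketched without Trax. Here next_logits is a hypothetical callable standing in for a model; unlike a model in ‘predict’ mode, which caches state between positions, the sketch re-feeds the whole prefix at every step, which is closer to what eval_mode describes.

```python
import itertools

import numpy as np


def sample_stream(next_logits, batch_size=1, temperature=1.0, start_id=0, seed=0):
    """Yields (batch_size, 1) integer arrays, one position at a time, forever.

    next_logits is a hypothetical callable that maps the (batch_size, length)
    array of tokens so far to (batch_size, vocab_size) next-position logits.
    """
    rng = np.random.default_rng(seed)
    tokens = np.full((batch_size, 1), start_id, dtype=np.int64)
    while True:
        logits = np.asarray(next_logits(tokens), dtype=np.float64)
        if temperature == 0.0:
            # All probability mass on the top candidate, like an argmax.
            nxt = logits.argmax(axis=-1)
        else:
            # Softmax of logits / temperature; higher temperature -> flatter.
            scaled = logits / temperature
            scaled -= scaled.max(axis=-1, keepdims=True)
            probs = np.exp(scaled)
            probs /= probs.sum(axis=-1, keepdims=True)
            nxt = np.array([rng.choice(len(p), p=p) for p in probs])
        nxt = nxt[:, None].astype(np.int64)
        tokens = np.concatenate([tokens, nxt], axis=1)  # output becomes next input
        yield nxt


def count_up(tokens):
    """Toy model: strongly prefers (last token + 1) mod 5."""
    logits = np.full((tokens.shape[0], 5), -10.0)
    logits[np.arange(tokens.shape[0]), (tokens[:, -1] + 1) % 5] = 10.0
    return logits


stream = sample_stream(count_up, batch_size=2, temperature=0.0)
first_six = [out[:, 0].tolist() for out in itertools.islice(stream, 6)]
# first_six == [[1, 1], [2, 2], [3, 3], [4, 4], [0, 0], [1, 1]]
```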

trax.supervised.decoding.autoregressive_sample(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, eos_id=1, max_length=100, accelerate=True, eval_mode=False, eval_min_length=1)
Returns a batch of sequences created by autoregressive sampling.
This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and this loop repeats until either the model outputs the eos_id value or the output sequence reaches max_length items.
Parameters:  model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM), except if eval_mode is set – any model can be sampled then, but the sampling process may be much slower.
 inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model must generate the first output with no input to guide it.
 batch_size – Number of sequences to generate in parallel as a batch.
 temperature – Parameter that controls the sharpness of the softmax that feeds the sampling process. Values range from 0.0 (all probability mass goes to one candidate; like an argmax) to positive infinity (all candidates have equal probability).
 start_id – The start symbol (ID/integer) for the autoregressive process, or array of shape (batch_size, 1) of such integers.
 eos_id – The end-of-sequence symbol (ID/integer) for the autoregressive process.
 max_length – Maximum length for generated sequences.
 accelerate – If True, create an accelerated version of model and use it for generating outputs.
 eval_mode – If True, assume the model is created in eval mode and sample by collecting all previous outputs and passing the whole tensor.
 eval_min_length – If set, the minimum length to pad to in eval mode.
Returns: Tensor of integers with shape (batch_size, output_length) representing a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than max_length.
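The stopping rule of autoregressive_sample can be sketched on top of greedy decoding (temperature 0.0). This sketch assumes a batch stops only once every sequence has produced eos_id, so sequences that finish early keep receiving tokens; next_logits and count_up are hypothetical stand-ins for a model.

```python
import numpy as np


def greedy_sample(next_logits, batch_size=1, start_id=0, eos_id=1, max_length=100):
    """Temperature-0.0 sampling that stops once every sequence in the batch has
    produced eos_id, or after max_length positions."""
    tokens = np.full((batch_size, 1), start_id, dtype=np.int64)
    outputs = []
    eos_seen = np.zeros(batch_size, dtype=bool)
    for _ in range(max_length):
        nxt = np.asarray(next_logits(tokens)).argmax(axis=-1).astype(np.int64)
        outputs.append(nxt)
        tokens = np.concatenate([tokens, nxt[:, None]], axis=1)
        eos_seen |= nxt == eos_id
        if eos_seen.all():
            break
    return np.stack(outputs, axis=1)  # shape (batch_size, output_length)


def count_up(tokens):
    """Toy model: prefers (last token + 1) mod 5."""
    logits = np.zeros((tokens.shape[0], 5))
    logits[np.arange(tokens.shape[0]), (tokens[:, -1] + 1) % 5] = 1.0
    return logits


stopped = greedy_sample(count_up, eos_id=3)                  # [[1 2 3]]
truncated = greedy_sample(count_up, eos_id=3, max_length=2)  # [[1 2]]
```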

trax.supervised.decoding.beam_search(model, inputs=None, batch_size=1, n_beams=2, start_id=0, eos_id=1, max_length=100, length_penalty=1.0, accelerate=True)
Returns a batch of n_beams sequences created by beam search.
This function uses model to generate outputs one position at a time, with access to inputs for the current position and all preceding positions. The new output becomes the next position’s input, and this loop repeats until either the model outputs the eos_id value or the output sequence reaches max_length items, while keeping the n_beams top-scoring beams.
Parameters:  model – A layer object (subclass of trax.layers.Layer) created in ‘predict’ mode and initialized from trained weights. The model must have a structure that allows it to run as an autoregressive one-sample-at-a-time predictor (e.g., trax.models.TransformerLM).
 inputs – Sequence of symbols the model sees as input the first time it generates an output. If None, the model must generate the first output with no input to guide it.
 batch_size – Number of sequences to generate in parallel as a batch.
 n_beams – How many beams to consider at the same time.
 start_id – The start symbol (ID/integer) for the autoregressive process, or array of shape (batch_size, 1) of such integers.
 eos_id – The end-of-sequence symbol (ID/integer) for the autoregressive process.
 max_length – Maximum length for generated sequences.
 length_penalty – Factor alpha in calculating the length penalty for beams.
 accelerate – If True, create an accelerated version of model and use it for generating outputs.
Returns: Tensor of integers with shape (batch_size, n_beams, output_length) with a batch of output sequences. output_length is the maximum length of the output sequences, where each sequence can be no longer than max_length.
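A compact beam-search sketch for a single input shows how keeping several candidates can beat greedy decoding. next_log_probs and toy_log_probs are hypothetical, the output is a list of sequences rather than the (batch_size, n_beams, output_length) tensor described above, and the GNMT-style penalty ((5 + L) / 6) ** alpha, with L the output length, is an assumption; this page does not say which formula Trax uses for length_penalty.

```python
import numpy as np


def length_penalty(length, alpha):
    # GNMT-style normalizer (Wu et al., 2016); assumed, not taken from Trax.
    return ((5.0 + length) / 6.0) ** alpha


def beam_search_one(next_log_probs, n_beams=2, start_id=0, eos_id=1,
                    max_length=100, alpha=1.0):
    """Returns the n_beams best output sequences (lists of ints) for one input."""
    beams = [([start_id], 0.0)]  # (tokens starting with start_id, summed log-prob)
    finished = []
    for _ in range(max_length):
        candidates = []
        for tokens, score in beams:
            log_probs = np.asarray(next_log_probs(tokens), dtype=np.float64)
            # Each live beam proposes its n_beams most likely continuations.
            for tok in np.argsort(log_probs)[::-1][:n_beams]:
                candidates.append((tokens + [int(tok)], score + float(log_probs[tok])))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates:
            if tokens[-1] == eos_id:
                finished.append((tokens, score))  # retire completed sequences
            else:
                beams.append((tokens, score))
            if len(beams) == n_beams:
                break
        if not beams:
            break
    finished.extend(beams)  # beams still open when max_length was reached
    finished.sort(key=lambda c: c[1] / length_penalty(len(c[0]) - 1, alpha),
                  reverse=True)
    return [tokens[1:] for tokens, _ in finished[:n_beams]]


def toy_log_probs(prefix):
    """Toy model over tokens {0, 1=EOS, 2, 3}; unknown prefixes emit EOS."""
    table = {
        (0,): [1e-9, 1e-9, 0.6, 0.4],
        (0, 2): [1e-9, 0.5, 0.3, 0.2],
        (0, 3): [1e-9, 0.9, 0.06, 0.04],
    }
    return np.log(table.get(tuple(prefix), [1e-9, 1.0, 1e-9, 1e-9]))


best = beam_search_one(toy_log_probs, n_beams=2, max_length=10)
# best == [[3, 1], [2, 1]]: 3 -> EOS has probability 0.36, while greedy decoding
# commits to 2 first and ends with 2 -> EOS (probability 0.30).
```

With n_beams=1 the same function reduces to greedy decoding and returns [[2, 1]].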
lr_schedules
Learning rate (LR) schedules.
In Trax a learning rate schedule is a function: \(\text{step} \mapsto \text{learning\_rate}\). This module provides helpers for constructing such functions. For example, constant(0.001) returns a function that always returns 0.001.

trax.supervised.lr_schedules.constant(value)
Returns an LR schedule that is constant from time (step) 1 to infinity.

trax.supervised.lr_schedules.warmup(n_warmup_steps, max_value)
Returns an LR schedule with linear warmup followed by a constant value.
Parameters:  n_warmup_steps – Number of steps during which the learning rate rises on a line connecting (0, 0) and (n_warmup_steps, max_value).
 max_value – Value for learning rate after warmup has finished.

trax.supervised.lr_schedules.warmup_and_rsqrt_decay(n_warmup_steps, max_value)
Returns an LR schedule with warmup followed by reciprocal square root decay.
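Because a schedule is just a function from step to learning rate, the three helpers above can be sketched in a few lines of plain Python. These are illustrative re-implementations, not the Trax functions; in particular, scaling the reciprocal-square-root decay so that it equals max_value at the end of warmup is an assumption.

```python
import math


def constant(value):
    """Schedule that returns the same value at every step."""
    return lambda step: value


def warmup(n_warmup_steps, max_value):
    """Rises on the line from (0, 0) to (n_warmup_steps, max_value), then stays flat."""
    def schedule(step):
        return max_value * min(1.0, step / n_warmup_steps)
    return schedule


def warmup_and_rsqrt_decay(n_warmup_steps, max_value):
    """Linear warmup, then decay proportional to 1 / sqrt(step).

    Assumption: the decay is scaled to equal max_value when warmup ends.
    """
    def schedule(step):
        if step <= n_warmup_steps:
            return max_value * step / n_warmup_steps
        return max_value * math.sqrt(n_warmup_steps / step)
    return schedule


lr = warmup_and_rsqrt_decay(n_warmup_steps=100, max_value=0.01)
values = [lr(s) for s in (50, 100, 400)]  # rises to 0.01, then halves by step 400
```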

trax.supervised.lr_schedules.multifactor(factors='constant * linear_warmup * rsqrt_decay', constant=0.1, warmup_steps=400, decay_factor=0.5, steps_per_decay=20000, steps_per_cycle=100000, second_constant=0.01, second_constant_step=10000, minimum=0)
Factor-based learning rate schedule.
Interprets factors in the factors string, which can consist of:
 constant: interpreted as the constant value,
 linear_warmup: interpreted as linear warmup until warmup_steps,
 rsqrt_decay: divide by square root of max(step, warmup_steps),
 decay_every: every steps_per_decay steps, decay the learning rate by decay_factor,
 cosine_decay: cyclic cosine decay, uses the steps_per_cycle parameter,
 two_constants: constant until second_constant_step, then switch to second_constant.
Parameters:  factors – a string with factors separated by ‘*’ that defines the schedule.
 constant – float, the starting constant for the learning rate schedule.
 warmup_steps – how many steps to warm up for in the warmup schedule.
 decay_factor – The amount to decay the learning rate by.
 steps_per_decay – How often to decay the learning rate.
 steps_per_cycle – Steps per cycle when using cosine decay.
 second_constant – float, the second constant for the learning rate schedule.
 second_constant_step – the step when the second_constant is triggered.
 minimum – if the computed rate is below the minimum, then return the minimum.
Returns: float -> {‘learning_rate’: float}, the step-dependent lr.
Return type: a function learning_rate(step)
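A plain-Python sketch of how a factor string such as 'constant * linear_warmup * rsqrt_decay' can be turned into a schedule by multiplying factors together, following the factor descriptions above. It covers only four of the factors and returns a float rather than a dict; it is not the Trax implementation.

```python
import math


def multifactor(factors="constant * linear_warmup * rsqrt_decay",
                constant=0.1, warmup_steps=400, decay_factor=0.5,
                steps_per_decay=20000, minimum=0.0):
    """Returns step -> learning rate, multiplying the named factors together."""
    names = [name.strip() for name in factors.split("*")]

    def learning_rate(step):
        ret = 1.0
        for name in names:
            if name == "constant":
                ret *= constant
            elif name == "linear_warmup":
                ret *= min(1.0, step / warmup_steps)
            elif name == "rsqrt_decay":
                ret /= math.sqrt(max(step, warmup_steps))
            elif name == "decay_every":
                ret *= decay_factor ** (step // steps_per_decay)
            else:
                raise ValueError(f"factor not covered by this sketch: {name}")
        return max(ret, minimum)

    return learning_rate


lr = multifactor()
# lr(100) == 0.1 * 0.25 / 20, lr(400) == 0.1 / 20, lr(1600) == 0.1 / 40
```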
training
Simplified API (under development) for supervised learning/training in Trax.
This module will eventually replace trainer_lib.Trainer.
Key classes:
 Loop: Core training loop for an n-step training session, starting from random initialization.
 TrainTask: Labeled data + feedback mechanism (loss function w/ optimizer) for modifying a model’s weights.
 Optimizer: How to compute model weight updates using loss-derived gradients. May contain state (“slots”, 1-1 with model weights) that accumulates across training steps. (This class is defined in the trax.optimizers module.)
 EvalTask: How and when to measure model performance as a function of training step number.

class trax.supervised.training.Loop(model, tasks, eval_model=None, eval_tasks=None, output_dir=None, checkpoint_at=None, checkpoint_low_metric=None, checkpoint_high_metric=None, permanent_checkpoint_at=None, eval_at=None, which_task=None, n_devices=None, random_seed=None, loss_chunk_size=0, use_memory_efficient_trainer=False, adasum=False, callbacks=None)
Bases: object
Loop that can run for a given number of steps to train a supervised model.
Can train the model on multiple tasks by interleaving updates according to the which_task argument.
The typical supervised training process randomly initializes a model and updates its weights via feedback (loss-derived gradients) from a training task, by looping through batches of labeled data. A training loop can also be configured to run periodic evals and save intermediate checkpoints.
For speed, the implementation takes advantage of JAX’s composable function transformations (specifically, jit and grad). It creates JIT-compiled pure functions derived from variants of the core model; schematically:
 training variant: jit(grad(pure_function(model+loss)))
 evals variant: jit(pure_function(model+evals))
In training or during evals, these variants are called with explicit arguments for all relevant input data, model weights/state, optimizer slots, and random number seeds:
 batch: labeled data
 model weights/state: trainable weights and input-related state (e.g., as used by batch norm)
 optimizer slots: weights in the optimizer that evolve during the training process
 random number seeds: JAX PRNG keys that enable high-quality, distributed, repeatable generation of pseudo-random numbers
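The idea of calling pure functions with every input passed explicitly can be illustrated without JAX. In this sketch a hand-written gradient stands in for grad, compilation (jit) is omitted, and the model is a two-weight linear regression; loss_fn, grad_fn, and train are hypothetical names and do not describe how Loop is implemented internally.

```python
import numpy as np


def loss_fn(weights, batch):
    """Pure function of explicit inputs: mean squared error of a linear model."""
    x, y = batch
    return float(np.mean((x @ weights - y) ** 2))


def grad_fn(weights, batch):
    """Hand-written gradient of loss_fn; stands in for JAX's grad(loss_fn)."""
    x, y = batch
    return 2.0 * x.T @ (x @ weights - y) / len(y)


def train(n_steps, learning_rate=0.1, eval_at=lambda step: step % 50 == 0, seed=0):
    rng = np.random.default_rng(seed)       # explicit source of randomness
    true_w = np.array([2.0, -3.0])
    x_eval = rng.normal(size=(64, 2))
    eval_batch = (x_eval, x_eval @ true_w)  # held-out labeled data
    weights = np.zeros(2)                   # the model's trainable weights
    history = []
    for step in range(1, n_steps + 1):
        x = rng.normal(size=(32, 2))
        batch = (x, x @ true_w)             # a fresh labeled training batch
        weights = weights - learning_rate * grad_fn(weights, batch)  # SGD update
        if eval_at(step):
            history.append((step, loss_fn(weights, eval_batch)))
    return weights, history


weights, history = train(200)
# history holds (step, eval loss) pairs for steps 50, 100, 150 and 200.
```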

__init__(model, tasks, eval_model=None, eval_tasks=None, output_dir=None, checkpoint_at=None, checkpoint_low_metric=None, checkpoint_high_metric=None, permanent_checkpoint_at=None, eval_at=None, which_task=None, n_devices=None, random_seed=None, loss_chunk_size=0, use_memory_efficient_trainer=False, adasum=False, callbacks=None)
Configures a training Loop, including a random initialization.
Parameters:  model – Trax layer, representing the core model to be trained. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 tasks – List of TrainTask instances, which define the training data, loss function, and optimizer to be used in respective tasks in this training loop. It can also be a single TrainTask instance, which is treated in the same way as a singleton list.
 eval_model – Optional Trax layer, representing the model used for evaluation, e.g., with dropout turned off. If None, the training model (model) will be used.
 eval_tasks – List of EvalTask instances which define how to evaluate the model: which validation data to use and which metrics to report. Evaluation on each of the tasks will run and be reported separately, which allows scoring a model on different subtasks. This argument can also be None, in which case no evals will be run, or a single EvalTask, which will be treated in the same way as a singleton list.
 output_dir – Path telling where to save outputs (evals and checkpoints). Can be None if both eval_tasks and checkpoint_at are None.
 checkpoint_at – Function (integer -> boolean) telling, for step n, whether that step should have its checkpoint saved. If None, the default is periodic checkpointing at task.n_steps_per_checkpoint.
 checkpoint_low_metric – Name of metric, or None. The metric name must be one of the metric names from the evals in eval_tasks. At checkpoint times determined by checkpoint_at, a separate specially named checkpoint will be saved (overwriting any previous version) if the designated metric reaches a value less than or equal to any previously recorded low value. No such checkpoint is saved if the arg value is None.
 checkpoint_high_metric – Name of metric, or None. The metric name must be one of the metric names from the evals in eval_tasks. At checkpoint times determined by checkpoint_at, a separate specially named checkpoint will be saved (overwriting any previous version) if the designated metric reaches a value greater than or equal to any previously recorded high value. No such checkpoint is saved if the arg value is None.
 permanent_checkpoint_at – Function (integer -> boolean) telling, for step n, whether that step should have its checkpoint saved permanently. If None, the default is periodic checkpointing at task.n_steps_per_permanent_checkpoint.
 eval_at – Function (integer -> boolean) that says, for training step n, whether that step should run evals. If None, run evals on the first step and on every N’th step, as determined by the first training task.
 which_task – Function (integer -> integer) indicating which task should be used at which training step. Can be set to None in single-task training.
 n_devices – integer or None, the number of devices for this computation.
 random_seed – the random seed to use; time/OS-dependent if None (default).
(default).  loss_chunk_size – int, if > 0 use chunks of this size to make loss computation more more memoryefficient.
 use_memory_efficient_trainer – whether to use a special memoryefficient trainer; if set to 2, the memory efficiency if very aggressive
 adasum – if True, use adaptive summation for multidevice gradients
 callbacks – List of subclasses of StepCallback to call on training steps.
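Several Loop arguments are plain Python functions of the step number. The sketch below shows hypothetical helpers for building eval_at/checkpoint_at predicates (integer -> boolean) and a which_task selector (integer -> integer), plus a tracker that mimics the "less than or equal to any previously recorded low value" rule of checkpoint_low_metric. None of these helpers exist in Trax.

```python
def every_n_steps(n):
    """Builds an integer -> boolean predicate, e.g. for eval_at or checkpoint_at."""
    return lambda step: step % n == 0


def round_robin(n_tasks):
    """Builds an integer -> integer task selector, e.g. for which_task."""
    return lambda step: step % n_tasks


class LowMetricTracker:
    """Mimics the checkpoint_low_metric rule: save when value <= best so far."""

    def __init__(self):
        self.best = None

    def should_save(self, value):
        if self.best is None or value <= self.best:
            self.best = value
            return True
        return False


eval_at = every_n_steps(100)
which_task = round_robin(2)
tracker = LowMetricTracker()
saves = [tracker.should_save(v) for v in [0.9, 0.7, 0.8, 0.7, 0.5]]
# saves == [True, True, False, True, True]; ties with the best value also save.
```

Functions like these would be passed as keyword arguments, e.g. Loop(model, [task_a, task_b], eval_at=eval_at, which_task=which_task), where model, task_a, and task_b are placeholders.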

run(n_steps=1)
Runs this training loop for n steps.
Optionally runs evals and saves checkpoints at specified points.
Parameters: n_steps – Stop training after completing n steps.

step
Returns the current step number in this training session.

history
Returns the history of this training session.

n_devices
Returns the number of devices to be used in this computation.

is_chief
Returns True if this Loop is the chief.

model
Returns the model that is training.

tasks
Returns the training tasks.

eval_model
Returns the model used for evaluation.

eval_tasks
Returns the evaluation tasks.

output_dir
Returns the output directory.

new_rng()
Returns a new single-use random number generator (JAX PRNG key).

update_weights_and_state(weights=None, state=None)
Updates the weights and state of the trained model.
Sends this data both to the singleton model accessible via Loop.model and to the replicated model on the accelerator.
Useful when the weights or state are modified outside of training, e.g. during data collection in RL agents.
Parameters:  weights – Model weights or None. If None, don’t set.
 state – Model state or None. If None, don’t set.

run_evals(summary_writers=None)
Runs and records evals for this training session.
Parameters: summary_writers – List of per-task Jaxboard summary writers to log metrics.

log_summary(values, summary_writer, value_prefix, log_prefix, stdout=True)
Logs and saves provided metrics.
Parameters:  values – Dict from metric name to metric value.
 summary_writer – Jaxboard summary writer.
 value_prefix – String prepended to summary_writer entries.
 log_prefix – String prepended to logs.
 stdout – Boolean saying whether logs should also be written to stdout.

save_checkpoint(basename)
Saves checkpoint (multiple files) to disk for the current training step.
Saving a checkpoint will overwrite any previous checkpoint saved with the same basename. Use differing basename values to save multiple checkpoints or multiple copies of the same checkpoint.
Parameters: basename – Basename for saving a checkpoint. Full file paths for the saved checkpoint will combine the output dir, basename, and relevant file extensions (e.g., .weights.npy.gz).

load_checkpoint(directory=None, filename=None)
Loads model weights and step from a checkpoint on disk.
Parameters:  directory – Directory with the checkpoint (self._output_dir by default).
 filename – Checkpoint file name (model.pkl.gz by default).

trax.supervised.training.pickle_to_file(obj, file_path, gzip=False)
Pickles obj to file_path, with optional gzipping and failure protection.

trax.supervised.training.unpickle_from_file(file_path, gzip=False)
Unpickles an object from file_path, with optional gzipping.
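One common way to provide the failure protection mentioned above is to write to a temporary file and then rename it over the target, so a crash never leaves a half-written pickle. The sketch assumes that interpretation and uses only the Python standard library; the gzip module is imported as gzip_lib because the function's gzip parameter shadows the module name. It is not the Trax implementation.

```python
import gzip as gzip_lib
import os
import pickle
import tempfile


def pickle_to_file(obj, file_path, gzip=False):
    """Writes to a temporary file in the same directory, then renames it into place."""
    directory = os.path.dirname(os.path.abspath(file_path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as raw:
            if gzip:
                with gzip_lib.GzipFile(fileobj=raw, mode="wb") as f:
                    pickle.dump(obj, f)
            else:
                pickle.dump(obj, raw)
        # Same directory, so same filesystem: the rename replaces the target in one step.
        os.replace(tmp_path, file_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def unpickle_from_file(file_path, gzip=False):
    opener = gzip_lib.open if gzip else open
    with opener(file_path, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "model.pkl.gz")
pickle_to_file({"step": 3, "weights": [1.0, 2.0]}, path, gzip=True)
restored = unpickle_from_file(path, gzip=True)  # {'step': 3, 'weights': [1.0, 2.0]}
```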

trax.supervised.training.init_host_and_devices(n_devices=None, random_seed=None)
Initializes host and device attributes for this trainer.
Parameters:  n_devices – Number of devices this trainer will use. If None, get the number from the backend.
 random_seed – Random seed as the starting point for all random numbers used by the trainer. If None, calculate one from system time and host id.
Returns:  is_chief – True if this trainer has special chief responsibilities.
 host_count – Number of hosts in this computation.
 n_devices – The passed-in value of n_devices or a computed default (for this host).
 random_seed – The passed-in value of random_seed or a computed default.
trax.rl package
actor_critic
Classes for RL training in Trax.

class trax.rl.actor_critic.ActorCriticAgent(task, value_model=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, scale_value_targets=False, q_value=False, q_value_aggregate='logsumexp', q_value_temperature=1.0, q_value_n_samples=1, q_value_normalization=False, offline=False, **kwargs)
Bases: trax.rl.training.PolicyAgent
Trains policy and value models using actor-critic methods.
Attrs:
 on_policy (bool): Whether the algorithm is on-policy. Used in the data generators. Should be set in derived classes.

on_policy = None

__init__
(task, value_model=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, scale_value_targets=False, q_value=False, q_value_aggregate='logsumexp', q_value_temperature=1.0, q_value_n_samples=1, q_value_normalization=False, offline=False, **kwargs)¶ Configures the actorcritic trainer.
Parameters:  task – RLTask instance to use.
 value_model – Model to use for the value function.
 value_optimizer – Optimizer to train the value model.
 value_lr_schedule – lr schedule for value model training.
 value_batch_size – Batch size for value model training.
 value_train_steps_per_epoch – Number of steps are we using to train the value model in each epoch.
 value_evals_per_epoch – Number of value trainer evaluations per RL epoch. Every evaluation, we also synchronize the weights of the target network.
 value_eval_steps – Number of value trainer steps per evaluation; only affects metric reporting.
 n_shared_layers – Number of layers to share between value and policy models.
 added_policy_slice_length – How much longer should slices of trajectories be for policy than for value training; this is useful for TD calculations and only affect the length of elements produced for policy batches; value batches have maximum length set by max_slice_length in **kwargs.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for offpolicy algorithms.
 scale_value_targets – If True, scale value function targets by 1 / (1  gamma).
 q_value – If True, use Qvalues as baselines.
 q_value_aggregate – How to aggregate Qvalues. Options: ‘mean’, ‘max’, ‘softmax’, ‘logsumexp’.
 q_value_temperature – Temperature parameter for the ‘softmax’ and ‘logsumexp’ aggregation methods.
 q_value_n_samples – Number of samples to average over when calculating baselines based on Qvalues.
 q_value_normalization – How to normalize Qvalues before aggregation. Allowed values: ‘std’, ‘abs’, None. If None, don’t normalize.
 offline – Whether to train in offline mode. This matters for some algorithms, e.g. QWR.
 **kwargs – Arguments for PolicyAgent superclass.
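The q_value_aggregate and q_value_temperature options collapse several sampled Q-values per state into one baseline. The minimal NumPy sketch below shows one plausible reading of the four options. It is an assumption, not Trax's code. In particular, it takes 'logsumexp' to be a temperature-scaled log-mean-exp, so the result always lies between the mean and the max, and Trax may define it differently. The function name aggregate_q_values is hypothetical.

```python
import numpy as np

def aggregate_q_values(q_values, method="logsumexp", temperature=1.0):
    """Collapse sampled Q-values (last axis) into one baseline per state.

    Hypothetical helper; the 'softmax' and 'logsumexp' formulas are assumptions.
    """
    q = np.asarray(q_values, dtype=np.float64)
    if method == "mean":
        return q.mean(axis=-1)
    if method == "max":
        return q.max(axis=-1)
    scaled = q / temperature
    # Subtract the per-row max before exponentiating, for numerical stability.
    m = scaled.max(axis=-1, keepdims=True)
    if method == "softmax":
        # Softmax(q / T)-weighted average of the Q-values.
        w = np.exp(scaled - m)
        w /= w.sum(axis=-1, keepdims=True)
        return (w * q).sum(axis=-1)
    if method == "logsumexp":
        # T * log(mean(exp(q / T))): tends to the max as T -> 0.
        log_mean_exp = np.log(np.mean(np.exp(scaled - m), axis=-1)) + m[..., 0]
        return temperature * log_mean_exp
    raise ValueError(f"unknown aggregation method: {method}")
```

Lower temperatures push both 'softmax' and 'logsumexp' toward 'max'. Higher temperatures push them toward 'mean'.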

value_mean
¶ The mean value of the value function.

value_batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

policy_inputs
(trajectory, values)¶ Create inputs to policy model from a TimeStepBatch and values.
Parameters:  trajectory – a TimeStepBatch, the trajectory to create inputs from
 values – a numpy array: value function computed on trajectory
Returns: a tuple of numpy arrays of the form (inputs, x1, x2, …) that will be passed to the policy model; policy model will compute outputs from inputs and (outputs, x1, x2, …) will be passed to self.policy_loss which should be overridden accordingly.

policy_batches_stream
()¶ Use the RLTask self._task to create inputs to the policy model.

train_epoch
()¶ Trains RL for one epoch.

close
()¶

class
trax.rl.actor_critic.
AdvantageBasedActorCriticAgent
(task, advantage_estimator=<function td_lambda>, advantage_normalization=True, advantage_normalization_epsilon=1e-05, advantage_normalization_factor=1.0, added_policy_slice_length=0, **kwargs)¶ Bases:
trax.rl.actor_critic.ActorCriticAgent
Base class for advantage-based actor-critic algorithms.

__init__
(task, advantage_estimator=<function td_lambda>, advantage_normalization=True, advantage_normalization_epsilon=1e-05, advantage_normalization_factor=1.0, added_policy_slice_length=0, **kwargs)¶ Configures the actor-critic trainer.
Parameters:  task – RLTask instance to use.
 value_model – Model to use for the value function.
 value_optimizer – Optimizer to train the value model.
 value_lr_schedule – lr schedule for value model training.
 value_batch_size – Batch size for value model training.
 value_train_steps_per_epoch – Number of steps used to train the value model in each epoch.
 value_evals_per_epoch – Number of value trainer evaluations per RL epoch. Every evaluation, we also synchronize the weights of the target network.
 value_eval_steps – Number of value trainer steps per evaluation; only affects metric reporting.
 n_shared_layers – Number of layers to share between value and policy models.
 added_policy_slice_length – How much longer trajectory slices should be for policy training than for value training; this is useful for TD calculations and only affects the length of elements produced for policy batches; value batches have maximum length set by max_slice_length in **kwargs.
 n_replay_epochs – Number of most recent epochs to keep in the replay buffer; only makes sense for off-policy algorithms.
 scale_value_targets – If True, scale value function targets by 1 / (1 - gamma).
 q_value – If True, use Q-values as baselines.
 q_value_aggregate – How to aggregate Q-values. Options: ‘mean’, ‘max’, ‘softmax’, ‘logsumexp’.
 q_value_temperature – Temperature parameter for the ‘softmax’ and ‘logsumexp’ aggregation methods.
 q_value_n_samples – Number of samples to average over when calculating baselines based on Q-values.
 q_value_normalization – How to normalize Q-values before aggregation. Allowed values: ‘std’, ‘abs’, None. If None, don’t normalize.
 offline – Whether to train in offline mode. This matters for some algorithms, e.g. QWR.
 **kwargs – Arguments for PolicyAgent superclass.

policy_inputs
(trajectory, values)¶ Create inputs to policy model from a TimeStepBatch and values.

policy_loss_given_log_probs
¶ Policy loss given action log-probabilities.

policy_loss
¶ Policy loss.

policy_metrics
¶

advantage_mean
¶

advantage_std
¶


trax.rl.actor_critic.
every
(n_steps)¶ Returns True every n_steps, for use as *_at functions in various places.
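A predicate of this kind is simple to picture. The sketch below is a hypothetical reimplementation, not the library source. It assumes that step 0 counts as a hit, which the one-line description above does not settle.

```python
def every(n_steps):
    """Return a predicate step -> bool that is True on multiples of n_steps.

    Sketch only: assumes step 0 also returns True.
    """
    return lambda step: step % n_steps == 0
```

A predicate like this matches the step -> bool contract of the value_sync_at and network_eval_at arguments of LoopActorCriticAgent.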

class
trax.rl.actor_critic.
LoopActorCriticAgent
(task, model_fn, optimizer=<class 'trax.optimizers.adam.Adam'>, policy_lr_schedule=<function multifactor>, policy_n_steps_per_epoch=1000, policy_weight_fn=<function LoopActorCriticAgent.<lambda>>, value_lr_schedule=<function multifactor>, value_n_steps_per_epoch=1000, value_sync_at=<function LoopActorCriticAgent.<lambda>>, advantage_estimator=<function monte_carlo>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, margin=0, n_replay_epochs=1, **kwargs)¶ Bases:
trax.rl.training.Agent
Base class for actor-critic algorithms based on Loop.

on_policy
= None¶

__init__
(task, model_fn, optimizer=<class 'trax.optimizers.adam.Adam'>, policy_lr_schedule=<function multifactor>, policy_n_steps_per_epoch=1000, policy_weight_fn=<function LoopActorCriticAgent.<lambda>>, value_lr_schedule=<function multifactor>, value_n_steps_per_epoch=1000, value_sync_at=<function LoopActorCriticAgent.<lambda>>, advantage_estimator=<function monte_carlo>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, margin=0, n_replay_epochs=1, **kwargs)¶ Initializes LoopActorCriticAgent.
Parameters:  task – RLTask instance to use.
 model_fn – Function mode -> Trax model, building a joint policy and value network.
 optimizer – Optimizer for the policy and value networks.
 policy_lr_schedule – Learning rate schedule for the policy network.
 policy_n_steps_per_epoch – Number of steps to train the policy network for in each epoch.
 policy_weight_fn – Function advantages -> weights for calculating the log probability weights in policy training.
 value_lr_schedule – Learning rate schedule for the value network.
 value_n_steps_per_epoch – Number of steps to train the value network for in each epoch.
 value_sync_at – Function step -> bool indicating when to synchronize the target network with the trained network in value training.
 advantage_estimator – Advantage estimator to use in policy and value training.
 batch_size – Batch size for training the networks.
 network_eval_at – Function step -> bool indicating when to evaluate the networks.
 n_eval_batches – Number of batches to compute the network evaluation metrics on.
 max_slice_length – Maximum length of a trajectory slice to train on.
 margin – Number of timesteps to add at the end of each trajectory slice for better advantage estimation.
 n_replay_epochs – Number of epochs of trajectories to store in the replay buffer.
 **kwargs – Keyword arguments forwarded to Agent.

loop
¶ Loop exposed for testing.

policy
(trajectory, temperature=1.0)¶ Policy function for playing with this agent.

train_epoch
()¶ Trains RL for one epoch.


class
trax.rl.actor_critic.
A2C
(task, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
Trains policy and value models using the A2C algorithm.

on_policy
= True¶

__init__
(task, entropy_coeff=0.01, **kwargs)¶ Configures the A2C Trainer.

policy_loss_given_log_probs
¶ Definition of the Advantage Actor Critic (A2C) loss.
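The A2C policy loss is commonly written as the negative mean of log pi(a|s) * A(s, a), minus an entropy bonus scaled by entropy_coeff. The minimal NumPy sketch below shows that textbook form. It assumes precomputed log-probabilities, advantages and per-step entropies. It is not Trax's implementation, which may also mask padded steps or normalize advantages.

```python
import numpy as np

def a2c_loss(log_probs, advantages, entropy, entropy_coeff=0.01):
    """Textbook advantage actor-critic policy loss, averaged over time steps."""
    log_probs = np.asarray(log_probs, dtype=np.float64)
    advantages = np.asarray(advantages, dtype=np.float64)
    # Advantages are treated as constants: they weight the log-likelihood term.
    policy_term = -np.mean(log_probs * advantages)
    # Subtracting the entropy bonus rewards more exploratory policies.
    return policy_term - entropy_coeff * np.mean(entropy)
```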


class
trax.rl.actor_critic.
PPO
(task, epsilon=0.2, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
The Proximal Policy Optimization Algorithm aka PPO.
Trains policy and value models using the PPO algorithm.

on_policy
= True¶

__init__
(task, epsilon=0.2, entropy_coeff=0.01, **kwargs)¶ Configures the PPO Trainer.

policy_loss_given_log_probs
¶ Definition of the Proximal Policy Optimization loss.


trax.rl.actor_critic.
awr_weights
(advantages, beta, thresholds)¶

trax.rl.actor_critic.
awr_metrics
(beta, thresholds, preprocess_layer=None)¶

trax.rl.actor_critic.
awr_weight_stat
(stat_name, stat_fn, beta, thresholds, preprocess_layer)¶

trax.rl.actor_critic.
AWRLoss
(beta, w_max, thresholds)¶ Definition of the Advantage Weighted Regression (AWR) loss.
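Advantage Weighted Regression fits the policy by weighted maximum likelihood. Each action's log-probability is weighted by exp(advantage / beta), capped at w_max. The NumPy sketch below illustrates that idea using the beta and w_max parameters listed above. It omits thresholds, whose role is not documented here. It is an illustration under those assumptions, not the AWRLoss layer itself.

```python
import numpy as np

def awr_weights(advantages, beta=1.0, w_max=20.0):
    """Exponentiated advantages, clipped from above at w_max."""
    return np.minimum(np.exp(np.asarray(advantages, dtype=np.float64) / beta), w_max)

def awr_loss(log_probs, advantages, beta=1.0, w_max=20.0):
    """Weighted negative log-likelihood: imitate high-advantage actions more."""
    weights = awr_weights(advantages, beta, w_max)
    return -np.mean(weights * np.asarray(log_probs, dtype=np.float64))
```

Smaller beta values make the weighting greedier. The w_max cap keeps one very large advantage from dominating a batch.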

class
trax.rl.actor_critic.
AWR
(task, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
Trains policy and value models using AWR.

on_policy
= False¶

__init__
(task, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Configures the AWR Trainer.

policy_loss_given_log_probs
¶ Policy loss.


class
trax.rl.actor_critic.
LoopAWR
(task, model_fn, beta=1.0, w_max=20, **kwargs)¶ Bases:
trax.rl.actor_critic.LoopActorCriticAgent
Advantage Weighted Regression.

on_policy
= False¶

__init__
(task, model_fn, beta=1.0, w_max=20, **kwargs)¶ Initializes LoopActorCriticAgent.
Parameters:  task – RLTask instance to use.
 model_fn – Function mode -> Trax model, building a joint policy and value network.
 optimizer – Optimizer for the policy and value networks.
 policy_lr_schedule – Learning rate schedule for the policy network.
 policy_n_steps_per_epoch – Number of steps to train the policy network for in each epoch.
 policy_weight_fn – Function advantages -> weights for calculating the log probability weights in policy training.
 value_lr_schedule – Learning rate schedule for the value network.
 value_n_steps_per_epoch – Number of steps to train the value network for in each epoch.
 value_sync_at – Function step -> bool indicating when to synchronize the target network with the trained network in value training.
 advantage_estimator – Advantage estimator to use in policy and value training.
 batch_size – Batch size for training the networks.
 network_eval_at – Function step -> bool indicating when to evaluate the networks.
 n_eval_batches – Number of batches to compute the network evaluation metrics on.
 max_slice_length – Maximum length of a trajectory slice to train on.
 margin – Number of timesteps to add at the end of each trajectory slice for better advantage estimation.
 n_replay_epochs – Number of epochs of trajectories to store in the replay buffer.
 **kwargs – Keyword arguments forwarded to Agent.


trax.rl.actor_critic.
SamplingAWRLoss
(beta, w_max, thresholds, reweight=False, sampled_all_discrete=False)¶ Definition of the Advantage Weighted Regression (AWR) loss.

class
trax.rl.actor_critic.
SamplingAWR
(task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
Trains policy and value models using Sampling AWR.

on_policy
= False¶

__init__
(task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs)¶ Configures the AWR Trainer.

policy_metrics
¶

policy_loss
¶ Policy loss.

policy_batches_stream
()¶ Use the RLTask self._task to create inputs to the policy model.

actor_critic_joint¶
Classes for RL training in Trax.

class
trax.rl.actor_critic_joint.
ActorCriticJointAgent
(task, joint_model=None, optimizer=None, lr_schedule=<function multifactor>, batch_size=64, train_steps_per_epoch=500, supervised_evals_per_epoch=1, supervised_eval_steps=1, n_trajectories_per_epoch=50, max_slice_length=1, normalize_advantages=True, output_dir=None, n_replay_epochs=1)¶ Bases:
trax.rl.training.Agent
Trains a joint policy-and-value model using actor-critic methods.

__init__
(task, joint_model=None, optimizer=None, lr_schedule=<function multifactor>, batch_size=64, train_steps_per_epoch=500, supervised_evals_per_epoch=1, supervised_eval_steps=1, n_trajectories_per_epoch=50, max_slice_length=1, normalize_advantages=True, output_dir=None, n_replay_epochs=1)¶ Configures the joint trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 joint_model – Trax layer, representing the joint policy and value model.
 optimizer – the optimizer to use to train the joint model.
 lr_schedule – learning rate schedule to use to train the joint model.
 batch_size – batch size used to train the joint model.
 train_steps_per_epoch – how long to train the joint model in each RL epoch.
 supervised_evals_per_epoch – number of value trainer evaluations per RL epoch; only affects metric reporting.
 supervised_eval_steps – number of value trainer steps per evaluation; only affects metric reporting.
 n_trajectories_per_epoch – how many trajectories to collect per epoch.
 max_slice_length – the maximum length of trajectory slices to use.
 normalize_advantages – if True, then normalize advantages (currently implemented only in PPO).
 output_dir – Path telling where to save outputs (evals and checkpoints).
 n_replay_epochs – how many of the most recent epochs to keep in the replay buffer; values > 1 only make sense for off-policy algorithms.

close
()¶

batches_stream
()¶ Use self.task to create inputs to the policy model.

joint_loss
¶ Joint policy and value loss layer.

advantage_mean
¶ Mean of advantages.

advantage_norm
¶ Norm of advantages.

value_loss
¶ Value loss (so far generic for all A2C).

explained_variance
¶ Explained variance metric.

log_probs_mean
¶ Mean of log_probs aka dist_inputs.

preferred_move
¶ Preferred move: the mean of selected moves.

policy
(trajectory, temperature=1.0)¶ Chooses an action to play after a trajectory.

train_epoch
()¶ Trains RL for one epoch.


class
trax.rl.actor_critic_joint.
PPOJoint
(task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic_joint.ActorCriticJointAgent
The Proximal Policy Optimization Algorithm aka PPO.
Trains policy and value models using the PPO algorithm.

on_policy
= True¶

__init__
(task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Configures the PPO Trainer.

batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

joint_loss
¶ Joint policy and value loss.

probs_ratio_mean
¶ Mean of the probability ratio between the new and old policies.

clip_fraction
¶ Fraction of probability ratios that fall outside the clipping range [1 - epsilon, 1 + epsilon].

entropy_loss
¶ Entropy layer.

approximate_kl_divergence
¶ Approximate KL divergence.

unclipped_objective_mean
¶

clipped_objective_mean
¶

ppo_objective
¶ PPO objective with local parameters.

ppo_objective_mean
¶ PPO objective mean.


class
trax.rl.actor_critic_joint.
A2CJoint
(task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic_joint.ActorCriticJointAgent
The A2C algorithm.
Trains policy and value models using the A2C algorithm.

on_policy
= True¶

__init__
(task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Configures the A2C Trainer.

batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

joint_loss
¶ Joint policy and value loss.

entropy_loss
¶ Entropy layer.

approximate_kl_divergence
¶ Approximate KL divergence.

a2c_objective
¶ A2C objective with local parameters.

a2c_objective_mean
¶ A2C objective mean.


class
trax.rl.actor_critic_joint.
AWRJoint
(task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Bases:
trax.rl.actor_critic_joint.ActorCriticJointAgent
Trains a joint policy-and-value model using AWR.

__init__
(task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Configures the joint AWR Trainer.

batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

joint_loss
¶ Joint policy and value loss.

advantages¶
RL advantage estimators.

trax.rl.advantages.
mask_discount
(discount, discount_mask)¶ Computes a discount to apply at a given timestep, based on the mask.

trax.rl.advantages.
discounted_returns
(rewards, gammas)¶ Computes discounted returns for a trajectory or a batch of them.
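Discounted returns satisfy the backward recursion G_t = r_t + gamma_t * G_{t+1}. The NumPy sketch below applies that recursion along the last axis, so it handles a single trajectory or a batch of them. It assumes gammas holds one discount per step, or a scalar that broadcasts, so a discount of 0 cuts the sum at an episode boundary. That is a reading of the signature, not confirmed Trax behaviour.

```python
import numpy as np

def discounted_returns(rewards, gammas):
    """Compute G_t = r_t + gamma_t * G_{t+1} along the last (time) axis."""
    rewards = np.asarray(rewards, dtype=np.float64)
    gammas = np.broadcast_to(np.asarray(gammas, dtype=np.float64), rewards.shape)
    returns = np.zeros_like(rewards)
    running = np.zeros(rewards.shape[:-1])  # One running return per sequence.
    for t in reversed(range(rewards.shape[-1])):
        running = rewards[..., t] + gammas[..., t] * running
        returns[..., t] = running
    return returns
```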

trax.rl.advantages.
monte_carlo
(gamma, margin)¶ Calculate Monte Carlo advantage.
We assume the values are a tensor of shape [batch_size, length] and this is the same shape as rewards and returns.
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
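The estimator-factory shape above (gamma, margin) -> ((rewards, returns, values, dones) -> advantages) can be sketched in NumPy as follows. The sketch assumes the returns passed in are already discounted, so the Monte Carlo advantage is simply returns minus values, and that the last margin steps are dropped from the end. How Trax itself uses gamma and dones here is not shown in this reference.

```python
import numpy as np

def monte_carlo(gamma, margin):
    """Return an estimator (rewards, returns, values, dones) -> advantages."""
    del gamma  # Sketch assumption: returns arrive already discounted.

    def estimator(rewards, returns, values, dones):
        del rewards, dones  # Unused by the plain Monte Carlo estimate.
        advantages = np.asarray(returns, dtype=np.float64) - np.asarray(values)
        length = advantages.shape[1]
        # Keep [batch_size, length - margin]; written so that margin=0 keeps all steps.
        return advantages[:, :length - margin]

    return estimator
```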

trax.rl.advantages.
td_k
(gamma, margin)¶ Calculate TD-k advantage.
The k parameter is assumed to be the same as margin.
We calculate advantage(s_i) as:
gamma^n_steps * value(s_{i + n_steps}) - value(s_i) + discounted_rewards, where n_steps = k and discounted_rewards = sum_{j=0}^{n_steps - 1} gamma^j * reward_{i + j} is the sum of rewards in these steps, discounted by powers of gamma.
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
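The TD-k formula above translates directly into array slicing. Each position i combines k discounted rewards with a bootstrapped value k steps ahead. The NumPy sketch below follows the documented formula with k = margin. As a simplifying assumption, it ignores dones, so it does not stop bootstrapping at episode ends the way a full implementation may.

```python
import numpy as np

def td_k(gamma, margin):
    """TD-k advantage with k = margin; episode ends (dones) are ignored here."""
    def estimator(rewards, returns, values, dones):
        del returns, dones  # Not used in this simplified sketch.
        rewards = np.asarray(rewards, dtype=np.float64)
        values = np.asarray(values, dtype=np.float64)
        k = margin
        out_len = rewards.shape[1] - k
        # discounted_rewards[i] = sum_{j<k} gamma^j * r_{i+j}
        discounted_rewards = np.zeros((rewards.shape[0], out_len))
        for j in range(k):
            discounted_rewards += gamma ** j * rewards[:, j:j + out_len]
        # gamma^k * V(s_{i+k}) - V(s_i)
        bootstrap = gamma ** k * values[:, k:k + out_len]
        return bootstrap - values[:, :out_len] + discounted_rewards

    return estimator
```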

trax.rl.advantages.
td_lambda
(gamma, margin, lambda_=0.95)¶ Calculate TD-lambda advantage.
The estimated return is an exponentially-weighted average of different TD-k returns.
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
 lambda_ – float, the lambda parameter of TD-lambda
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].

trax.rl.advantages.
gae
(gamma, margin, lambda_=0.95)¶ Calculate Generalized Advantage Estimation.
Calculate state values bootstrapping off the following state values (Generalized Advantage Estimation, https://arxiv.org/abs/1506.02438).
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
 lambda_ – float, the lambda parameter of GAE
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
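GAE accumulates one-step TD residuals, delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), backwards with decay gamma * lambda_. The NumPy sketch below implements that recursion in the same estimator-factory shape. Its assumptions are that margin >= 1 (the last value is needed for bootstrapping) and that dones are ignored. It is not Trax's own code.

```python
import numpy as np

def gae(gamma, margin, lambda_=0.95):
    """GAE(gamma, lambda_) estimator; assumes margin >= 1 and ignores dones."""
    def estimator(rewards, returns, values, dones):
        del returns, dones  # Not used in this sketch.
        rewards = np.asarray(rewards, dtype=np.float64)
        values = np.asarray(values, dtype=np.float64)
        # One-step TD residuals for every step that has a successor value.
        deltas = rewards[:, :-1] + gamma * values[:, 1:] - values[:, :-1]
        advantages = np.zeros_like(deltas)
        running = np.zeros(deltas.shape[0])
        for t in reversed(range(deltas.shape[1])):
            running = deltas[:, t] + gamma * lambda_ * running
            advantages[:, t] = running
        return advantages[:, :rewards.shape[1] - margin]

    return estimator
```

With lambda_=0 this reduces to the one-step (TD-1) advantage. With lambda_=1 it sums discounted residuals, which approaches a Monte Carlo estimate.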
distributions¶
Probability distributions for RL training in Trax.

class
trax.rl.distributions.
Distribution
¶ Bases:
object
Abstract class for parametrized probability distributions.

n_inputs
¶ Returns the number of inputs to the distribution (i.e. parameters).

sample
(inputs, temperature=1.0)¶ Samples a point from the distribution.
Parameters:  inputs (jnp.ndarray) – Distribution inputs. Shape is subclass-specific. Broadcasts along the first dimensions. For example, in the categorical distribution parameter shape is (C,), where C is the number of categories. If (B, C) is passed, the object will represent a batch of B categorical distributions with different parameters.
 temperature – sampling temperature; 1.0 is default, at 0.0 chooses the most probable (preferred) action.
Returns: Sampled point of shape dependent on the subclass and on the shape of inputs.

log_prob
(inputs, point)¶ Retrieves log probability (or log probability density) of a point.
Parameters:  inputs (jnp.ndarray) – Distribution parameters.
 point (jnp.ndarray) – Point from the distribution. Shape should be consistent with inputs.
Returns: Array of log probabilities of points in the distribution.
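For the categorical case, sample and log_prob can be sketched as below. The sketch assumes the inputs are unnormalized logits of shape (B, C), since Trax's actual parametrization is not shown here. Sampling uses the Gumbel-max trick, and temperature 0.0 returns the most probable category, as described above. The helper names are hypothetical.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def categorical_sample(logits, temperature=1.0, rng=None):
    """Sample one category per row; temperature 0.0 picks the argmax."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        return logits.argmax(axis=-1)
    rng = np.random.default_rng() if rng is None else rng
    # Gumbel-max: argmax(logits / T + Gumbel noise) ~ softmax(logits / T).
    gumbel = rng.gumbel(size=logits.shape)
    return (logits / temperature + gumbel).argmax(axis=-1)

def categorical_log_prob(logits, point):
    """Log-probability of integer categories `point` under softmax(logits)."""
    logp = log_softmax(np.asarray(logits, dtype=np.float64))
    point = np.asarray(point)
    return np.take_along_axis(logp, point[..., None], axis=-1)[..., 0]
```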

LogProb
()¶ Builds a log probability layer for this distribution.


trax.rl.distributions.
create_distribution
(space)¶ Creates a Distribution for the given Gym space.

trax.rl.distributions.
LogLoss
(distribution, **unused_kwargs)¶ Builds a log loss layer for a Distribution.
normalization¶
Normalization helpers.

trax.rl.normalization.
running_mean_init
(shape, fill_value=0)¶

trax.rl.normalization.
running_mean_update
(x, state)¶

trax.rl.normalization.
running_mean_get_mean
(state)¶

trax.rl.normalization.
running_mean_get_count
(state)¶

trax.rl.normalization.
running_mean_and_variance_init
(shape)¶

trax.rl.normalization.
running_mean_and_variance_update
(x, state)¶

trax.rl.normalization.
running_mean_and_variance_get_mean
(state)¶

trax.rl.normalization.
running_mean_and_variance_get_count
(state)¶

trax.rl.normalization.
running_mean_and_variance_get_variance
(state)¶
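The running_mean_and_variance_* helpers are listed without docstrings. As an illustration only, the sketch below shows one common way to maintain such statistics over batches: Chan et al.'s parallel update of count, mean and sum of squared deviations (M2). The state layout and the mv_* names are hypothetical and only mirror the listed functions. They are not Trax's internals.

```python
import numpy as np

def mv_init(shape):
    """Hypothetical state: (count, mean, M2 = sum of squared deviations)."""
    return (0, np.zeros(shape), np.zeros(shape))

def mv_update(x, state):
    """Merge a batch x (leading axis = batch) into the running statistics."""
    count, mean, m2 = state
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[0]
    batch_mean = x.mean(axis=0)
    batch_m2 = ((x - batch_mean) ** 2).sum(axis=0)
    total = count + n
    delta = batch_mean - mean
    # Chan et al.'s formula for combining two sets of moments.
    new_mean = mean + delta * n / total
    new_m2 = m2 + batch_m2 + delta ** 2 * count * n / total
    return (total, new_mean, new_m2)

def mv_mean(state):
    return state[1]

def mv_count(state):
    return state[0]

def mv_variance(state):
    count, _, m2 = state
    return m2 / count  # Population variance; undefined before any update.
```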

trax.rl.normalization.
LayerNormSquash
(mode, width=128)¶ Dense-LayerNorm-Tanh normalizer inspired by ACME.
rl_layers¶
A number of RL functions intended to be later wrapped as Trax layers.
Wrapping happens with the help of the function tl.Fn.

trax.rl.rl_layers.
ValueLoss
(values, returns, value_loss_coeff)¶ Definition of the loss of the value function.

trax.rl.rl_layers.
ExplainedVariance
(values, returns)¶ Definition of explained variance, an approach from OpenAI Baselines.
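Explained variance, as popularized by OpenAI Baselines, is 1 - Var[returns - values] / Var[returns]. A value of 1.0 means the value function predicts the returns perfectly, and 0.0 means it does no better than a constant. The plain NumPy sketch below computes it outside any Trax layer. It returns NaN when the returns are constant, which is the convention Baselines uses.

```python
import numpy as np

def explained_variance(values, returns):
    """1 - Var[returns - values] / Var[returns]; NaN if the returns are constant."""
    values = np.asarray(values, dtype=np.float64)
    returns = np.asarray(returns, dtype=np.float64)
    var_returns = np.var(returns)
    if var_returns == 0:
        return np.nan
    return 1.0 - np.var(returns - values) / var_returns
```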

trax.rl.rl_layers.
PreferredMove
(dist_inputs, sample)¶ Definition of the preferred move.

trax.rl.rl_layers.
NewLogProbs
(dist_inputs, actions, log_prob_fun)¶ Given distribution inputs and actions, calculates log probabilities.

trax.rl.rl_layers.
EntropyLoss
(dist_inputs, distribution, coeff)¶ Definition of the Entropy Layer.

trax.rl.rl_layers.
ProbsRatio
(dist_inputs, actions, old_log_probs, log_prob_fun)¶ Probability Ratio from the PPO algorithm.

trax.rl.rl_layers.
ApproximateKLDivergence
(dist_inputs, actions, old_log_probs, log_prob_fun)¶ Approximate KL divergence between the old and new policies, from the PPO algorithm.

trax.rl.rl_layers.
UnclippedObjective
(probs_ratio, advantages)¶ Unclipped Objective from the PPO algorithm.

trax.rl.rl_layers.
ClippedObjective
(probs_ratio, advantages, epsilon)¶ Clipped Objective from the PPO algorithm.
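ProbsRatio, UnclippedObjective and ClippedObjective combine into the standard PPO clipped surrogate: the elementwise minimum of r * A and clip(r, 1 - epsilon, 1 + epsilon) * A, where r = pi_new(a|s) / pi_old(a|s). The plain NumPy sketch below shows that textbook form, which is maximized during training. It is an illustration, not the Trax layers. It omits the masking and advantage normalization that PPOObjective may apply.

```python
import numpy as np

def probs_ratio(new_log_probs, old_log_probs):
    # pi_new(a|s) / pi_old(a|s), computed in log space for stability.
    return np.exp(np.asarray(new_log_probs) - np.asarray(old_log_probs))

def unclipped_objective(ratio, advantages):
    return ratio * advantages

def clipped_objective(ratio, advantages, epsilon=0.2):
    return np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages

def ppo_objective(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    """Pessimistic (elementwise min) PPO surrogate, averaged over steps."""
    ratio = probs_ratio(new_log_probs, old_log_probs)
    advantages = np.asarray(advantages, dtype=np.float64)
    return np.mean(np.minimum(unclipped_objective(ratio, advantages),
                              clipped_objective(ratio, advantages, epsilon)))
```

Taking the minimum means a positive advantage earns no extra credit beyond a ratio of 1 + epsilon. A negative advantage is always charged in full.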

trax.rl.rl_layers.
PPOObjective
(dist_inputs, values, returns, dones, rewards, actions, old_log_probs, log_prob_fun, epsilon, normalize_advantages)¶ PPO Objective.

trax.rl.rl_layers.
A2CObjective
(dist_inputs, values, returns, dones, rewards, actions, mask, log_prob_fun, normalize_advantages)¶ Definition of the Advantage Actor Critic (A2C) loss.
serialization_utils¶
Utilities for serializing trajectories into discrete sequences.

trax.rl.serialization_utils.
Serialize
(serializer)¶ Layer that serializes a given array.

trax.rl.serialization_utils.
Interleave
()¶ Layer that interleaves and flattens two serialized sequences.
The first sequence can be longer by 1 than the second one. This is so we can interleave sequences of observations and actions, when there’s 1 extra observation at the end.
For serialized sequences [[x_1_1, …, x_1_R1], …, [x_L1_1, …, x_L1_R1]] and [[y_1_1, …, y_1_R2], …, [y_L2_1, …, y_L2_R2]], where L1 = L2 + 1, the result is [x_1_1, …, x_1_R1, y_1_1, …, y_1_R2, …, x_L2_1, …, x_L2_R1, y_L2_1, …, y_L2_R2, x_L1_1, …, x_L1_R1] (batch dimension omitted for clarity).
The layer inputs are a sequence pair of shapes (B, L1, R1) and (B, L2, R2), where B is batch size, L* is the length of the sequence and R* is the representation length of each element in the sequence.
Returns: Layer that outputs the interleaved sequence, of shape (B, L1 * R1 + L2 * R2).
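The interleaving order described above can be reproduced with plain NumPy arrays instead of a Trax layer. The function below is a hypothetical stand-in that follows the documented layout. It accepts L1 == L2 + 1 (one extra trailing element of the first sequence) and also L1 == L2.

```python
import numpy as np

def interleave(x, y):
    """Interleave (B, L1, R1) and (B, L2, R2) into (B, L1 * R1 + L2 * R2)."""
    l1, l2 = x.shape[1], y.shape[1]
    assert l1 in (l2, l2 + 1), "x may be longer than y by at most one element"
    pieces = []
    for i in range(l2):
        pieces.append(x[:, i])  # (B, R1): representation of x_i
        pieces.append(y[:, i])  # (B, R2): representation of y_i
    if l1 == l2 + 1:
        pieces.append(x[:, -1])  # The extra trailing element of x.
    return np.concatenate(pieces, axis=1)
```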

trax.rl.serialization_utils.
Deinterleave
(x_size, y_size)¶ Layer that does the inverse of Interleave.

trax.rl.serialization_utils.
RepresentationMask
(serializer)¶ Upsamples a mask to cover the serialized representation.

trax.rl.serialization_utils.
SignificanceWeights
(serializer, decay)¶ Multiplies a binary mask with a symbol significance mask.
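One way to picture RepresentationMask and SignificanceWeights is shown in the sketch below. It upsamples a per-step mask to every symbol of the step, then scales each symbol by decay ** significance. The symbol with significance 0 keeps full weight, and less significant symbols are down-weighted. The sketch assumes each step's R symbols are laid out contiguously and that significance_map has length R. Both the layout and the exact formula are assumptions, not taken from the Trax source.

```python
import numpy as np

def representation_mask(mask, representation_length):
    """Upsample a (B, L) per-step mask to a (B, L * R) per-symbol mask."""
    return np.repeat(np.asarray(mask), representation_length, axis=1)

def significance_weights(mask, significance_map, decay):
    """Multiply the upsampled mask by decay ** significance for every symbol."""
    mask = np.asarray(mask)
    significance_map = np.asarray(significance_map)
    upsampled = representation_mask(mask, significance_map.shape[0])
    # Repeat the per-symbol weights once per time step: (R,) -> (L * R,).
    per_symbol = np.tile(decay ** significance_map, mask.shape[1])
    return upsampled * per_symbol
```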

class
trax.rl.serialization_utils.
SerializedModel
(seq_model, observation_serializer, action_serializer, significance_decay, mode='train')¶ Bases:
trax.layers.combinators.Serial
Wraps a world model in serialization machinery for training.
The resulting model takes as input the observation and action sequences, serializes them and interleaves into one sequence, which is fed into a given autoregressive model. The resulting logit sequence is deinterleaved into observations and actions, and the observation logits are returned together with computed symbol significance weights.
The model has a signature (obs, act, obs, mask) -> (obs_logits, obs_repr, weights), where obs are observations (the second occurrence is the target), act are actions, mask is the observation mask, obs_logits are logits of the output observation representation, obs_repr is the target observation representation and weights are the target weights.

__init__
(seq_model, observation_serializer, action_serializer, significance_decay, mode='train')¶ Initializes SerializedModel.
Parameters:  seq_model – Trax autoregressive model taking as input a sequence of symbols and outputting a sequence of symbol logits.
 observation_serializer – Serializer to use for observations.
 action_serializer – Serializer to use for actions.
 significance_decay – Float from (0, 1) for exponential weighting of symbols in the representation.
 mode – ‘train’ or ‘eval’.

observation_serializer
¶

action_serializer
¶

make_predict_model
()¶ Returns a predict-mode model of the same architecture.

seq_model_weights
¶ Extracts the weights of the underlying sequence model.

seq_model_state
¶ Extracts the state of the underlying sequence model.


trax.rl.serialization_utils.
TimeSeriesModel
(seq_model, low=0.0, high=1.0, precision=2, vocab_size=64, significance_decay=0.7, mode='train')¶ Simplified constructor for SerializedModel, for time series prediction.

trax.rl.serialization_utils.
RawPolicy
(seq_model, n_controls, n_actions)¶ Wraps a sequence model in a policy interface.
The resulting model takes as input observation and action sequences, but only uses the observations. Adds output heads for action logits and value predictions.
Parameters:  seq_model – Trax sequence model taking as input and outputting a sequence of continuous vectors.
 n_controls – Number of controls.
 n_actions – Number of action categories in each control.
Returns: A model of signature (obs, act) -> (act_logits, values), with shapes obs: (batch_size, length + 1, obs_depth), act: (batch_size, length, n_controls), act_logits: (batch_size, length, n_controls, n_actions), values: (batch_size, length).

trax.rl.serialization_utils.
substitute_inner_policy_raw
(raw_policy, inner_policy)¶ Substitutes the weights/state of the inner model in a RawPolicy.

trax.rl.serialization_utils.
SerializedPolicy
(seq_model, n_controls, n_actions, observation_serializer, action_serializer)¶ Wraps a policy in serialization machinery for training.
The resulting model takes as input observation and action sequences, and serializes them into one sequence similar to SerializedModel, before passing to the given sequence model. Adds output heads for action logits and value predictions.
Parameters:  seq_model – Trax sequence model taking as input a sequence of symbols and outputting a sequence of continuous vectors.
 n_controls – Number of controls.
 n_actions – Number of action categories in each control.
 observation_serializer – Serializer to use for observations.
 action_serializer – Serializer to use for actions.
Returns: A model of signature (obs, act) -> (act_logits, values), same as in RawPolicy.

trax.rl.serialization_utils.
substitute_inner_policy_serialized
(serialized_policy, inner_policy)¶ Substitutes the weights/state of the inner model in a SerializedPolicy.

trax.rl.serialization_utils.
analyze_action_space
(action_space)¶ Returns the number of controls and actions for an action space.

trax.rl.serialization_utils.
wrap_policy
(seq_model, observation_space, action_space, vocab_size)¶ Wraps a sequence model in either RawPolicy or SerializedPolicy.
Parameters:  seq_model – Trax sequence model.
 observation_space – Gym observation space.
 action_space – Gym action space.
 vocab_size – Either the number of symbols for a serialized policy, or None.
Returns: RawPolicy if vocab_size is None, else SerializedPolicy.

trax.rl.serialization_utils.
substitute_inner_policy
(wrapped_policy, inner_policy, vocab_size)¶ Substitutes the inner weights/state in a {Raw,Serialized}Policy.
Parameters:  wrapped_policy (pytree) – Weights or state of a wrapped policy.
 inner_policy (pytree) – Weights or state of an inner policy.
 vocab_size (int or None) – Vocabulary size of a serialized policy, or None in case of a raw policy.
Returns: New weights or state of wrapped_policy, with the inner weights/state copied from inner_policy.
space_serializer¶
Serialization of elements of Gym spaces into discrete sequences.

class
trax.rl.space_serializer.
SpaceSerializer
(space, vocab_size)¶ Bases:
object
Base class for Gym space serializers.
 Attrs:
 space_type: (type) Gym space class that this SpaceSerializer corresponds to. Should be defined in subclasses.
 representation_length: (int) Number of symbols in the representation of every element of the space.
 significance_map: (np.ndarray) Integer array of the same size as the discrete representation, where elements describe the significance of symbols, e.g. in fixed-precision encoding. 0 is the most significant symbol, 1 the second most significant, etc.
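The significance_map attribute is easiest to see with a fixed-precision encoding. In the hypothetical sketch below, a float in [low, high] becomes precision base-vocab_size digits, most significant first, so the matching significance map is 0, 1, ..., precision - 1. The parameter names echo those of TimeSeriesModel, but this is an illustrative scheme, not Trax's serializer. Its exact rounding and layout are assumptions.

```python
import numpy as np

def serialize_fixed_precision(x, low, high, precision, vocab_size):
    """Encode floats in [low, high] as `precision` base-`vocab_size` digits."""
    x = np.clip(np.asarray(x, dtype=np.float64), low, high)
    n_levels = vocab_size ** precision
    # Quantize to an integer in [0, n_levels - 1]; x == high maps to the top bin.
    ints = np.minimum(((x - low) / (high - low) * n_levels).astype(np.int64),
                      n_levels - 1)
    digits = []
    for _ in range(precision):
        digits.append(ints % vocab_size)
        ints = ints // vocab_size
    # Most significant digit first, matching significance_map = [0, 1, ...].
    return np.stack(digits[::-1], axis=-1).astype(np.int32)

def deserialize_fixed_precision(symbols, low, high, precision, vocab_size):
    """Decode digits back to floats at the centre of each quantization bin."""
    ints = np.zeros(symbols.shape[:-1], dtype=np.int64)
    for i in range(precision):
        ints = ints * vocab_size + symbols[..., i]
    return low + (ints + 0.5) / vocab_size ** precision * (high - low)
```

An error in the first symbol moves the decoded value by up to (high - low) / vocab_size, while an error in the last moves it only by one bin. This asymmetry is the reason for weighting symbols by significance.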

space_type
= None¶

representation_length
= None¶

significance_map
= None¶

__init__
(space, vocab_size)¶ Creates a SpaceSerializer.
Subclasses should retain the signature.
Parameters:  space – (gym.Space) Gym space of type self.space_type.
 vocab_size – (int) Number of symbols in the vocabulary.

vocab_size
¶

serialize
(data)¶ Serializes a batch of space elements into discrete sequences.
Should be defined in subclasses.
Parameters: data – A batch of batch_size elements of the Gym space to be serialized. Returns: int32 array of shape (batch_size, self.representation_length).

deserialize
(representation)¶ Deserializes a batch of discrete sequences into space elements.
Should be defined in subclasses.
Parameters: representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized. Returns: A batch of batch_size deserialized elements of the Gym space.

trax.rl.space_serializer.
create
(space, vocab_size)¶ Creates a SpaceSerializer for the given Gym space.

class
trax.rl.space_serializer.
DiscreteSpaceSerializer
(space, vocab_size)¶ Bases:
trax.rl.space_serializer.SpaceSerializer
Serializer for gym.spaces.Discrete.
Assumes that the size of the space fits in the number of symbols.

space_type
¶ The Gym space class this serializer handles (gym.spaces.Discrete).

representation_length
= 1¶

__init__
(space, vocab_size)¶ Creates a SpaceSerializer.
Subclasses should retain the signature.
Parameters:  space – (gym.Space) Gym space of type self.space_type.
 vocab_size – (int) Number of symbols in the vocabulary.

serialize
(data)¶ Serializes a batch of space elements into discrete sequences.
Should be defined in subclasses.
Parameters: data – A batch of batch_size elements of the Gym space to be serialized. Returns: int32 array of shape (batch_size, self.representation_length).

deserialize
(representation)¶ Deserializes a batch of discrete sequences into space elements.
Should be defined in subclasses.
Parameters: representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized. Returns: A batch of batch_size deserialized elements of the Gym space.

significance_map
¶


class
trax.rl.space_serializer.
MultiDiscreteSpaceSerializer
(space, vocab_size)¶ Bases:
trax.rl.space_serializer.SpaceSerializer
Serializer for gym.spaces.MultiDiscrete.
Assumes that the number of categories in each dimension fits in the number of symbols.

space_type
¶ The Gym space class this serializer handles (gym.spaces.MultiDiscrete).

__init__
(space, vocab_size)¶ Creates a SpaceSerializer.
Subclasses should retain the signature.
Parameters:  space – (gym.Space) Gym space of type self.space_type.
 vocab_size – (int) Number of symbols in the vocabulary.

serialize
(data)¶ Serializes a batch of space elements into discrete sequences.
Should be defined in subclasses.
Parameters: data – A batch of batch_size elements of the Gym space to be serialized. Returns: int32 array of shape (batch_size, self.representation_length).

deserialize
(representation)¶ Deserializes a batch of discrete sequences into space elements.
Should be defined in subclasses.
Parameters: representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized. Returns: A batch of batch_size deserialized elements of the Gym space.

representation_length
¶

significance_map
¶

task¶
Classes for defining RL tasks in Trax.

class
trax.rl.task.
TimeStepBatch
(observation, action, reward, done, mask, dist_inputs, env_info, return_)¶ Bases:
tuple

action
¶ Alias for field number 1

dist_inputs
¶ Alias for field number 5

done
¶ Alias for field number 3

env_info
¶ Alias for field number 6

mask
¶ Alias for field number 4

observation
¶ Alias for field number 0

return_
¶ Alias for field number 7

reward
¶ Alias for field number 2


class
trax.rl.task.
EnvInfo
(control_mask, discount_mask)¶ Bases:
tuple

control_mask
¶ Alias for field number 0

discount_mask
¶ Alias for field number 1


class
trax.rl.task.
Trajectory
(observation)¶ Bases:
object
A trajectory of interactions with a RL environment.
Trajectories are created when interacting with an RL environment. They can be extended and sliced, and once completed, allow recalculating returns.

__init__
(observation)¶ Initialize self. See help(type(self)) for accurate signature.

suffix
(length)¶ Returns a Trajectory with the last length observations.

timesteps
¶

total_return
¶ Sum of all rewards in this trajectory.

last_observation
¶ Return the last observation in this trajectory.

done
¶ Returns whether the trajectory is finished.

extend
(new_observation, mask=1, **kwargs)¶ Take an action in the last state, receive a reward and move to the new state.

calculate_returns
(gamma)¶ Calculate discounted returns.
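The recursion behind discounted returns is G_t = r_t + gamma * G_{t+1}, evaluated from the end of the trajectory backwards. The minimal sketch below shows that recursion in plain Python. The helper discounted_returns is hypothetical and not part of Trax. It also ignores any discount or control masks (see EnvInfo) that the real method may apply.

```python
def discounted_returns(rewards, gamma):
    """Hypothetical helper: G_t = r_t + gamma * G_{t+1}, computed backwards."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk from the last timestep to the first, accumulating the discounted sum.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# Example: three unit rewards with gamma = 0.5.
print(discounted_returns([1.0, 1.0, 1.0], 0.5))  # [1.75, 1.5, 1.0]
```

With gamma = 1 the first entry equals the trajectory's total_return.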

to_np
(margin=1, timestep_to_np=None)¶ Create a tuple of numpy arrays from a given trajectory.
Parameters:  margin (int) – Number of dummy timesteps past the trajectory end to include. By default we include 1, which contains the last observation.
 timestep_to_np (callable or None) – Optional function TimeStepBatch[Any] -> TimeStepBatch[np.array], converting the timestep data into numpy arrays.
Returns: TimeStepBatch, where all fields have shape (len(self) + margin - 1, …).


trax.rl.task.
play
(env, policy, dm_suite=False, max_steps=None, last_observation=None)¶ Play an episode in env taking actions according to the given policy.
The environment is first reset, and from then on a game proceeds. At each step, the policy is asked to choose an action and the environment moves forward. A Trajectory is built this way and returned when the episode finishes, which is either when env returns done or when max_steps is reached.
Parameters:  env – the environment to play in, conforming to gym.Env or DeepMind suite interfaces.
 policy – a function taking a Trajectory and returning a pair consisting of an action (int or float) and the confidence in that action (float, defined as the log of the probability of taking that action).
 dm_suite – whether we are using the DeepMind suite or the gym interface
 max_steps – for how many steps to play.
 last_observation – last observation from a previous trajectory slice, used to begin a new one. Controls whether we reset the environment at the beginning: if None, resets the env and starts the slice from the observation obtained from reset().
Returns: a completed trajectory slice that was just played.
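The loop that play describes can be sketched in plain Python. The sketch rests on several assumptions:

- CountdownEnv and play_sketch are hypothetical.
- The environment follows the classic gym contract, where reset() returns an observation and step() returns (observation, reward, done, info).
- The policy receives the current observation rather than a Trajectory.
- A list of (observation, action, reward) tuples stands in for trax.rl.task.Trajectory.

```python
class CountdownEnv:
    """Hypothetical toy env with the classic gym API:
    reset() -> obs, step(action) -> (obs, reward, done, info)."""

    def __init__(self, start=3):
        self._start = start
        self._state = start

    def reset(self):
        self._state = self._start
        return self._state

    def step(self, action):
        self._state -= 1
        reward = 1.0 if action == 1 else 0.0
        done = self._state <= 0
        return self._state, reward, done, {}


def play_sketch(env, policy, max_steps=None, last_observation=None):
    """Simplified episode loop returning (steps, last_obs, done)."""
    # Reset only when not continuing a previous slice.
    obs = env.reset() if last_observation is None else last_observation
    steps = []
    done = False
    while not done and (max_steps is None or len(steps) < max_steps):
        # The policy returns an action and its log-probability.
        action, _log_prob = policy(obs)
        next_obs, reward, done, _info = env.step(action)
        steps.append((obs, action, reward))
        obs = next_obs
    return steps, obs, done


def always_one(obs):
    # Deterministic policy: action 1 with log-probability 0.0 (probability 1).
    return 1, 0.0
```

Passing last_observation skips the reset. This is how a later slice continues an episode that an earlier call stopped because of max_steps.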
training¶
Classes for RL training in Trax.

class
trax.rl.training.
Agent
(task: RLTask, n_trajectories_per_epoch=None, n_interactions_per_epoch=None, n_eval_episodes=0, eval_steps=None, eval_temperatures=(0.0, ), only_eval=False, output_dir=None, timestep_to_np=None)¶ Bases:
object
Abstract class for RL agents, presenting the required API.

__init__
(task: RLTask, n_trajectories_per_epoch=None, n_interactions_per_epoch=None, n_eval_episodes=0, eval_steps=None, eval_temperatures=(0.0, ), only_eval=False, output_dir=None, timestep_to_np=None)¶ Configures the Agent.
Note that subclasses can have many more arguments, which are configured using defaults and gin; only task and output_dir are passed explicitly.
Parameters:  task – RLTask instance, which defines the environment to train on.
 n_trajectories_per_epoch – How many new trajectories to collect in each epoch.
 n_interactions_per_epoch – How many interactions to collect in each epoch.
 n_eval_episodes – Number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only.
 eval_steps – an optional list of max_steps to use for evaluation (defaults to task.max_steps).
 eval_temperatures – we always train with temperature 1 and evaluate with the temperatures specified in the eval_temperatures list (defaults to (0.0,))
 only_eval – If set to True, trajectories are collected only for evaluation purposes and are not recorded.
 output_dir – Path telling where to save outputs such as checkpoints.
 timestep_to_np – Timestep-to-numpy function to override in the task.

current_epoch
¶ Returns the current epoch number in this training session.

task
¶ Returns the task.

avg_returns
¶

save_gin
(summary_writer=None)¶

save_to_file
(file_name='rl.pkl', task_file_name='trajectories.pkl')¶ Save current epoch number and average returns to file.

init_from_file
(file_name='rl.pkl', task_file_name='trajectories.pkl')¶ Initialize epoch number and average returns from file.

policy
(trajectory, temperature=1.0)¶ Policy function that allows playing with this trainer.
Parameters:  trajectory – an instance of trax.rl.task.Trajectory
 temperature – temperature used to sample from the policy (default=1.0)
Returns: a pair (action, dist_inputs) where action is the action taken and dist_inputs are the parameters of the policy distribution, which will later be used for training.

train_epoch
()¶ Trains this Agent for one epoch – main RL logic goes here.

run
(n_epochs=1, n_epochs_is_total_epochs=False)¶ Runs this loop for n epochs.
Parameters:  n_epochs – Stop training after completing n steps.
 n_epochs_is_total_epochs – if True, consider n_epochs as the total number of epochs to train, including previously trained ones
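The two arguments of run interact as follows. By default, n_epochs more epochs are run. With n_epochs_is_total_epochs=True, epochs already trained count toward the total. The helper below is hypothetical and only mirrors that documented meaning. Clamping the result at zero is an assumption for the case where training has already passed the requested total.

```python
def epochs_to_run(current_epoch, n_epochs, n_epochs_is_total_epochs=False):
    """Hypothetical helper mirroring the documented meaning of run()'s args."""
    if n_epochs_is_total_epochs:
        # n_epochs includes previously trained epochs (clamped at 0: assumption).
        return max(0, n_epochs - current_epoch)
    # Otherwise n_epochs additional epochs are run on top of current_epoch.
    return n_epochs


print(epochs_to_run(3, 10))                                 # 10
print(epochs_to_run(3, 10, n_epochs_is_total_epochs=True))  # 7
```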

close
()¶


class
trax.rl.training.
PolicyAgent
(task, policy_model=None, policy_optimizer=None, policy_lr_schedule=<function multifactor>, policy_batch_size=64, policy_train_steps_per_epoch=500, policy_evals_per_epoch=1, policy_eval_steps=1, n_eval_episodes=0, only_eval=False, max_slice_length=1, output_dir=None, **kwargs)¶ Bases:
trax.rl.training.Agent
Agent that uses a deep learning model for policy.
Many deep RL methods, such as policy gradient (REINFORCE) or actor-critic, fall into this category, so many classes will subclass this one. Some methods, however, only have a value or Q function; those are different.

__init__
(task, policy_model=None, policy_optimizer=None, policy_lr_schedule=<function multifactor>, policy_batch_size=64, policy_train_steps_per_epoch=500, policy_evals_per_epoch=1, policy_eval_steps=1, n_eval_episodes=0, only_eval=False, max_slice_length=1, output_dir=None, **kwargs)¶ Configures the policy trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 policy_model – Trax layer, representing the policy model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 policy_optimizer – the optimizer to use to train the policy model.
 policy_lr_schedule – learning rate schedule to use to train the policy.
 policy_batch_size – batch size used to train the policy model.
 policy_train_steps_per_epoch – how long to train policy in each RL epoch.
 policy_evals_per_epoch – number of policy trainer evaluations per RL epoch; only affects metric reporting.
 policy_eval_steps – number of policy trainer steps per evaluation; only affects metric reporting.
 n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only
 only_eval – If set to True, trajectories are collected only for evaluation purposes and are not recorded.
 max_slice_length – the maximum length of trajectory slices to use.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 **kwargs – arguments for the superclass Agent.

policy_loss
¶ Policy loss.

policy_metrics
¶

policy_batches_stream
()¶ Use self.task to create inputs to the policy model.

policy
(trajectory, temperature=1.0)¶ Chooses an action to play after a trajectory.

train_epoch
()¶ Trains RL for one epoch.

close
()¶


trax.rl.training.
remaining_evals
(cur_step, epoch, train_steps_per_epoch, evals_per_epoch)¶ Helper function to calculate remaining evaluations for a trainer.
Parameters:  cur_step – current step of the supervised trainer
 epoch – current epoch of the RL trainer
 train_steps_per_epoch – supervised trainer steps per RL epoch
 evals_per_epoch – supervised trainer evals per RL epoch
Returns: number of remaining evals to do this epoch
Raises: ValueError if the provided numbers indicate a step mismatch

class
trax.rl.training.
LoopPolicyAgent
(task, model_fn, value_fn, weight_fn, n_replay_epochs, n_train_steps_per_epoch, advantage_normalization, optimizer=<class 'trax.optimizers.adam.Adam'>, lr_schedule=<function multifactor>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, trajectory_stream_preprocessing_fn=None, **kwargs)¶ Bases:
trax.rl.training.Agent
Base class for policy-only Agents based on Loop.

__init__
(task, model_fn, value_fn, weight_fn, n_replay_epochs, n_train_steps_per_epoch, advantage_normalization, optimizer=<class 'trax.optimizers.adam.Adam'>, lr_schedule=<function multifactor>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, trajectory_stream_preprocessing_fn=None, **kwargs)¶ Initializes LoopPolicyAgent.
Parameters:  task – Instance of trax.rl.task.RLTask.
 model_fn – Function (policy_distribution, mode) -> policy_model.
 value_fn – Function TimeStepBatch -> array (batch_size, seq_len) calculating the baseline for advantage calculation.
 weight_fn – Function float -> float to apply to advantages when calculating policy loss.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
 n_train_steps_per_epoch – Number of steps to train the policy network for in each epoch.
 advantage_normalization – Whether to normalize the advantages before passing them to weight_fn.
 optimizer – Optimizer for network training.
 lr_schedule – Learning rate schedule for network training.
 batch_size – Batch size for network training.
 network_eval_at – Function step -> bool indicating the training steps at which network evaluation should be performed.
 n_eval_batches – Number of batches to run during network evaluation.
 max_slice_length – The length of trajectory slices to run the network on.
 trajectory_stream_preprocessing_fn – Function to apply to the trajectory stream before batching. Can be used e.g. to filter trajectories.
 **kwargs – Keyword arguments passed to the superclass.
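LoopPolicyAgent combines three of the arguments above in the policy loss:

- value_fn supplies a baseline.
- advantage_normalization optionally standardizes the advantages.
- weight_fn maps each advantage to a weight.

The minimal sketch below assumes the simplest case, advantage = return - baseline. Trax may compute advantages differently, for example from n-step estimates. The function weighted_advantages and both example weight functions are illustrative, not Trax code.

```python
import math
import statistics


def weighted_advantages(returns, baselines, weight_fn, normalize=True, eps=1e-8):
    """Sketch: advantage = return - baseline, optionally normalized, then weighted."""
    advantages = [r - b for r, b in zip(returns, baselines)]
    if normalize and len(advantages) > 1:
        # Standardize to zero mean and unit (population) standard deviation.
        mean = statistics.mean(advantages)
        std = statistics.pstdev(advantages)
        advantages = [(a - mean) / (std + eps) for a in advantages]
    return [weight_fn(a) for a in advantages]


# Example weight functions (illustrative choices, not Trax defaults):
identity = lambda a: a                                # policy-gradient-style weighting
exp_weight = lambda a, beta=1.0: math.exp(a / beta)  # AWR-style exponential weighting
```

An identity weight keeps the sign of the advantage. An exponential weight makes every weight positive, so the loss becomes a weighted imitation of the sampled actions.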

loop
¶ Loop exposed for testing.

train_epoch
()¶ Trains RL for one epoch.


class
trax.rl.training.
PolicyGradient
(task, model_fn, **kwargs)¶ Bases:
trax.rl.training.LoopPolicyAgent
Trains a policy model using policy gradient on the given RLTask.

__init__
(task, model_fn, **kwargs)¶ Initializes PolicyGradient.
Parameters:  task – Instance of trax.rl.task.RLTask.
 model_fn – Function (policy_distribution, mode) -> policy_model.
 **kwargs – Keyword arguments passed to the superclass.

policy
(trajectory, temperature=1.0)¶ Policy function that samples from the trained network.


trax.rl.training.
sharpened_network_policy
(temperature, temperature_multiplier=1.0, **kwargs)¶ Expert function that runs a policy network with lower temperature.
Parameters:  temperature – Temperature passed from the Agent.
 temperature_multiplier – Multiplier to apply to the temperature to “sharpen” the policy distribution. Should be <= 1, but this is not a requirement.
 **kwargs – Keyword arguments passed to network_policy.
Returns: Pair (action, dist_inputs) where action is the action taken and dist_inputs are the parameters of the policy distribution, which will later be used for training.
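The effect of temperature_multiplier can be illustrated with a categorical policy. Sampling from softmax(logits / T) with T = temperature * temperature_multiplier concentrates probability on the best action as T shrinks. The sketch below is a stand-in, not network_policy itself. Its assumptions: the function sample_with_temperature is hypothetical, the distribution is categorical over plain logits, and T = 0 is treated as greedy argmax.

```python
import math
import random


def sample_with_temperature(logits, temperature, temperature_multiplier=1.0, rng=random):
    """Sketch: sample an index from softmax(logits / T), T = temperature * multiplier.

    T == 0 is treated as greedy (argmax), an assumption for this sketch.
    """
    t = temperature * temperature_multiplier
    if t == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / t for l in logits]
    m = max(scaled)
    # Subtract the max before exponentiating for numerical stability;
    # random.choices does not require normalized weights.
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]
```

A multiplier below 1 (e.g. 0.01) makes the expert almost deterministic while still using the same network.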

class
trax.rl.training.
ExpertIteration
(task, model_fn, expert_policy_fn=<function sharpened_network_policy>, quantile=0.9, n_replay_epochs=10, n_train_steps_per_epoch=1000, filter_buffer_size=256, **kwargs)¶ Bases:
trax.rl.training.LoopPolicyAgent
Trains a policy model using expert iteration with a given expert.

__init__
(task, model_fn, expert_policy_fn=<function sharpened_network_policy>, quantile=0.9, n_replay_epochs=10, n_train_steps_per_epoch=1000, filter_buffer_size=256, **kwargs)¶ Initializes ExpertIteration.
Parameters:  task – Instance of trax.rl.task.RLTask.
 model_fn – Function (policy_distribution, mode) -> policy_model.
 expert_policy_fn – Function of the same signature as network_policy, to be used as an expert. The policy will be trained to mimic the expert on the “solved” trajectories.
 quantile – Quantile of best trajectories to be marked as “solved”. They will be used to train the policy.
 n_replay_epochs – Number of last epochs to include in the replay buffer.
 n_train_steps_per_epoch – Number of policy training steps to run in each epoch.
 filter_buffer_size – Number of trajectories in the trajectory filter buffer, used to select the best trajectories based on the quantile.
 **kwargs – Keyword arguments passed to the superclass.
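The quantile and filter_buffer_size arguments describe a filter that marks a trajectory as "solved" when its return ranks among the best recent ones. The sketch below is one plausible reading, not the Trax implementation. QuantileFilter is hypothetical, and two details are assumptions: the nearest-rank threshold rule, and adding the current return to the buffer before comparing.

```python
import collections


class QuantileFilter:
    """Hypothetical sketch: keep a trajectory if its return is at or above the
    given quantile of the most recent buffer_size returns."""

    def __init__(self, quantile=0.9, buffer_size=256):
        self.quantile = quantile
        # Bounded buffer: the oldest returns fall out automatically.
        self.returns = collections.deque(maxlen=buffer_size)

    def is_solved(self, trajectory_return):
        self.returns.append(trajectory_return)
        ranked = sorted(self.returns)
        # Nearest-rank threshold index in the sorted buffer (assumption).
        idx = min(int(self.quantile * len(ranked)), len(ranked) - 1)
        return trajectory_return >= ranked[idx]
```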

policy
(trajectory, temperature=1.0)¶ Policy function that runs the expert.


trax.rl.training.
network_policy
(collect_model, policy_distribution, loop, trajectory_np, head_index=0, temperature=1.0)¶ Policy function powered by a neural network.
Used to implement Agent.policy() in policy-based agents.
Parameters:  collect_model – the model used for collecting trajectories
 policy_distribution – an instance of trax.rl.distributions.Distribution
 loop – trax.supervised.training.Loop used to train the policy network
 trajectory_np – an instance of trax.rl.task.TimeStepBatch
 head_index – index of the policy head in a multi-head model.
 temperature – temperature used to sample from the policy (default=1.0)
Returns: a pair (action, dist_inputs) where action is the action taken and dist_inputs are the parameters of the policy distribution, which will later be used for training.

class
trax.rl.training.
ValueAgent
(task, value_body=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, exploration_rate=functools.partial(<function multifactor>, factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1), n_eval_episodes=0, only_eval=False, n_replay_epochs=1, max_slice_length=1, sync_freq=1000, scale_value_targets=True, output_dir=None, **kwargs)¶ Bases:
trax.rl.training.Agent
Trainer that uses a deep learning model for the value function.
It computes the loss using variants of the Bellman equation.

__init__
(task, value_body=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, exploration_rate=functools.partial(<function multifactor>, factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1), n_eval_episodes=0, only_eval=False, n_replay_epochs=1, max_slice_length=1, sync_freq=1000, scale_value_targets=True, output_dir=None, **kwargs)¶ Configures the value trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 value_body – Trax layer, representing the body of the value model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 value_optimizer – the optimizer to use to train the value model.
 value_lr_schedule – learning rate schedule to use to train the value model.
 value_batch_size – batch size used to train the value model.
 value_train_steps_per_epoch – how long to train the value model in each RL epoch.
 value_evals_per_epoch – number of value trainer evaluations per RL epoch; only affects metric reporting.
 value_eval_steps – number of value trainer steps per evaluation; only affects metric reporting.
 exploration_rate – exploration rate schedule; used in the policy method.
 n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only
 only_eval – If set to True, trajectories are collected only for evaluation purposes and are not recorded.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
 max_slice_length – the maximum length of trajectory slices to use; it is the second dimension of the value network output: (batch, max_slice_length, number of actions). A higher max_slice_length means the network has to predict more values into the future.
 sync_freq – how often to synchronize the target network with the trained network. This is necessary for training the network on bootstrapped targets, e.g. using n-step returns.
 scale_value_targets – If True, scale value function targets by 1 / (1 - gamma). This addresses the problem of very large returns in some games without introducing an additional hyperparameter.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 **kwargs – arguments for the superclass RLTrainer.
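The default exploration_rate is a multifactor schedule with factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1 and minimum=0.1. The sketch below shows the shape of such a schedule. It is not the Trax multifactor function. It assumes that decay_every multiplies by decay_factor ** (step // steps_per_decay), that minimum acts as a floor, and that step counts RL epochs.

```python
def exploration_rate_sketch(step, constant=1.0, decay_factor=0.99,
                            steps_per_decay=1, minimum=0.1):
    """Sketch of a 'constant * decay_every' schedule with an assumed floor."""
    rate = constant * decay_factor ** (step // steps_per_decay)
    # Never explore less than `minimum` (assumed floor semantics).
    return max(rate, minimum)


print(exploration_rate_sketch(0))     # 1.0
print(exploration_rate_sketch(1000))  # 0.1
```

Under these assumptions the rate decays by 1% per step and reaches the floor after roughly 230 steps.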

value_batches_stream
()¶ Use self.task to create inputs to the value model.

policy
(trajectory, temperature=1)¶ Chooses an action to play after a trajectory.

train_epoch
()¶ Trains RL for one epoch.

close
()¶

value_mean
¶ The mean value of actions selected by the behavioral policy.

returns_mean
¶ The mean value of actions selected by the behavioral policy.


class
trax.rl.training.
DQN
(task, advantage_estimator=<function monte_carlo>, max_slice_length=1, smoothl1loss=True, double_dqn=False, **kwargs)¶ Bases:
trax.rl.training.ValueAgent
Trains a value model using DQN on the given RLTask.
Notice that the algorithm and the parameters significantly diverge from the original DQN paper. In particular, learning and data collection are separated.
The Bellman loss is computed in the value_loss method. The formula takes the state-action value tensors Q and the n-step returns R:
\[L(s,a) = Q(s,a) - R(s,a)\]where R is computed in value_batches_stream. In the simplest case of 1-step returns we get
\[L(s,a) = Q(s,a) - r(s,a) - \gamma \max_{a'} Q'(s',a')\]where s’ is the state reached after taking action a in state s, Q’ is the target network, gamma is the discount factor, and the maximum is taken over all actions available in state s’. The tensor Q’ is updated according to the sync_freq parameter.
In code, the maximum is visible in the policy method, where we take sample = jnp.argmax(values). The epsilon-greedy policy takes a random move with probability epsilon and otherwise, in state s, takes the action argmax_a Q(s,a).
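The pieces of the 1-step formula above can be written out for a single transition in plain Python, together with the epsilon-greedy choice and a smooth L1 (Huber) penalty on the residual. This is an illustrative sketch, not the Trax value_loss. The helper names are hypothetical. Dropping the bootstrap term at terminal states is a standard convention that the text above does not state, so treat it as an assumption.

```python
import random


def one_step_target(reward, next_q_target, gamma, done=False):
    """r + gamma * max_a' Q'(s', a'); no bootstrap past a terminal state (assumption)."""
    if done:
        return reward
    return reward + gamma * max(next_q_target)


def bellman_residual(q_sa, target):
    """L(s, a) = Q(s, a) - R(s, a), before the smooth-L1 or L2 penalty."""
    return q_sa - target


def smooth_l1(x):
    """Huber loss with threshold 1: quadratic near zero, linear beyond."""
    return 0.5 * x * x if abs(x) < 1.0 else abs(x) - 0.5


def epsilon_greedy(q_values, epsilon, rng=random):
    """With probability epsilon take a random action, else argmax_a Q(s, a)."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

The smooth L1 penalty grows only linearly for large residuals, which keeps occasional large returns from dominating the gradient.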

__init__
(task, advantage_estimator=<function monte_carlo>, max_slice_length=1, smoothl1loss=True, double_dqn=False, **kwargs)¶ Configures the value trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 value_body – Trax layer, representing the body of the value model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 value_optimizer – the optimizer to use to train the value model.
 value_lr_schedule – learning rate schedule to use to train the value model.
 value_batch_size – batch size used to train the value model.
 value_train_steps_per_epoch – how long to train the value model in each RL epoch.
 value_evals_per_epoch – number of value trainer evaluations per RL epoch; only affects metric reporting.
 value_eval_steps – number of value trainer steps per evaluation; only affects metric reporting.
 exploration_rate – exploration rate schedule; used in the policy method.
 n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only
 only_eval – If set to True, trajectories are collected only for evaluation purposes and are not recorded.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
 max_slice_length – the maximum length of trajectory slices to use; it is the second dimension of the value network output: (batch, max_slice_length, number of actions). A higher max_slice_length means the network has to predict more values into the future.
 sync_freq – how often to synchronize the target network with the trained network. This is necessary for training the network on bootstrapped targets, e.g. using n-step returns.
 scale_value_targets – If True, scale value function targets by 1 / (1 - gamma). This addresses the problem of very large returns in some games without introducing an additional hyperparameter.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 **kwargs – arguments for the superclass RLTrainer.

value_loss
¶ Value loss computed using smooth L1 loss or L2 loss.

value_batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

policy
(trajectory, temperature=1)¶ Chooses an action to play after a trajectory.

value_mean
¶ The mean value of actions selected by the behavioral policy.

shapes¶
Core class and functions for handling data abstractly as shapes/dtypes.

class
trax.shapes.
ShapeDtype
(shape, dtype=<sphinx.ext.autodoc.importer._MockObject object>)¶ Bases:
object
A NumPy ndarray-like object abstracted as shape and dtype.
Main use is for representing input and output signatures.

__init__
(shape, dtype=<sphinx.ext.autodoc.importer._MockObject object>)¶ Creates a ShapeDtype instance, with canonicalized shape and dtype.
Parameters:  shape – A tuple or list, each element of which is an int or, less often, None.
 dtype – A dtype object, either from NumPy or TensorFlow.
Returns: A ShapeDtype instance whose shape is a tuple and dtype is a NumPy dtype object.

shape
¶

dtype
¶

as_tuple
()¶

replace
(**kwargs)¶ Creates a copy of the object with some parameters replaced.


trax.shapes.
signature
(obj)¶ Returns a ShapeDtype signature for the given obj.
A signature is either a ShapeDtype instance or a tuple of ShapeDtype instances. Note that this function is permissive with respect to its inputs (accepts lists or tuples or dicts, and underlying objects can be any type as long as they have shape and dtype attributes) and returns the corresponding nested structure of ShapeDtype.
Parameters: obj – An object that has shape and dtype attributes, or a list/tuple/dict of such objects. Returns: A corresponding nested structure of ShapeDtype instances.

trax.shapes.
splice_signatures
(*sigs)¶ Creates a new signature by splicing together any number of signatures.
The splicing effectively flattens the top level input signatures. For instance, it would perform the following mapping:
 *sigs: sd1, (sd2, sd3, sd4), (), sd5
 return: (sd1, sd2, sd3, sd4, sd5)
Parameters: *sigs – Any number of signatures. A signature is either a ShapeDtype instance or a tuple of ShapeDtype instances. Returns: A single ShapeDtype instance if the spliced signature has one element, else a tuple of ShapeDtype instances.
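The mapping above flattens only the top level: tuples contribute their elements, and a bare ShapeDtype contributes itself. The sketch below reproduces that mapping with a hypothetical stand-in class SD instead of trax.shapes.ShapeDtype, so it runs without Trax. The function splice is illustrative, not the library code.

```python
class SD:
    """Hypothetical stand-in for trax.shapes.ShapeDtype (shape + dtype name)."""

    def __init__(self, shape, dtype='float32'):
        self.shape = tuple(shape)
        self.dtype = dtype

    def __repr__(self):
        return f'SD({self.shape}, {self.dtype})'


def splice(*sigs):
    """Flatten one level: tuples contribute their elements, single items themselves."""
    flat = []
    for sig in sigs:
        if isinstance(sig, tuple):
            flat.extend(sig)
        else:
            flat.append(sig)
    # One element -> return it bare; otherwise a tuple (mirrors the documented return).
    return flat[0] if len(flat) == 1 else tuple(flat)


sd1, sd2, sd3, sd4, sd5 = (SD((i,)) for i in range(1, 6))
print(splice(sd1, (sd2, sd3, sd4), (), sd5))  # (sd1, sd2, sd3, sd4, sd5)
```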

trax.shapes.
assert_shape_equals
(array, shape)¶ Asserts that an array has the given shape.

trax.shapes.
assert_same_shape
(array1, array2)¶ Asserts that two arrays have the same shapes.
trainer¶
Trax trainer.

trax.trainer.
tf_init_tpu
(worker='', protocol=None)¶ Initializes TPU for TensorFlow.
Parameters:  worker – The BNS address of the remote TPU worker. If it’s empty (the default value), TF will assume the TPU devices are connected to the local host.
 protocol – The network protocol used to connect to the TPU worker.
Returns: The device name of the TPU worker’s CPU.

trax.trainer.
main
(_)¶
rl_trainer¶
Trainer for RL environments.
For now we only support PPO as the RL algorithm.
Sample invocation:
TRAIN_BATCH_SIZE=32
python trax/rl_trainer.py \
  --config_file=trax/rl/configs/ppo_acrobot.gin \
  --train_batch_size=${TRAIN_BATCH_SIZE} \
  --output_dir=${HOME}/ppo_acrobot \
  --alsologtostderr

trax.rl_trainer.
train_rl
(output_dir, n_epochs=10000, light_rl=True, light_rl_trainer=<class 'trax.rl.training.PolicyGradient'>)¶ Train the RL agent.
Parameters:  output_dir – Output directory.
 n_epochs – Number of epochs to run the training for.
 light_rl – deprecated, always True, left out for old gin configs.
 light_rl_trainer – which light RL trainer to use (experimental).

trax.rl_trainer.
main
(argv)¶
trax2keras¶
Trax-to-Keras converter.

trax.trax2keras.
tensor_shapes_to_shape_dtypes
(shapes, dtype)¶

trax.trax2keras.
read_values
(variables)¶

trax.trax2keras.
to_tensors
(args)¶

trax.trax2keras.
to_arrays
(args)¶

class
trax.trax2keras.
AsKeras
(trax_layer, batch_size=None, initializer_rng=None, rng=None, rng_updater=None, dtype=None)¶ Bases:
tf.keras.layers.Layer
A Keras layer built from a Trax layer.
This subclass of tf.keras.layers.Layer takes in a Trax layer as a constructor argument and wraps it to be a Keras layer. It uses tf.Variable to store weights and state (initialized according to the Trax layer), and uses the Trax layer’s forward function as its forward function.
Consider this code snippet:
keras_layer = AsKeras(trax_layer, initializer_rng=initializer_rng, rng=rng, rng_updater=rng_updater)
keras_layer.build(...)  # optional
outputs = keras_layer(inputs)
(Note that in Keras calling Layer.build is optional. If omitted, it will be called automatically by Layer.__call__.)
If trax_layer already has weights at build time, the snippet is roughly equivalent to:
weights = trax_layer.weights
state = trax_layer.state
keras_layer = tf.keras.layers.Layer()
keras_layer._weights = tf.Variable(weights)
keras_layer._state = tf.Variable(state)
keras_layer._rng = tf.Variable(rng)
outputs, new_state = trax_layer(inputs, keras_layer._weights, keras_layer._state, keras_layer._rng)
keras_layer._state.assign(new_state)
keras_layer._rng.assign(rng_updater(rng))
If trax_layer doesn’t have weights at build time, the snippet is roughly equivalent to:
weights, state = trax_layer.init(..., rng=initializer_rng)
keras_layer = ...
...
AsKeras uses tf.Variable to store weights, not shared with the original Trax layer (which uses tensors to store weights), so using AsKeras may double the memory footprint. This problem can be solved by making sure that the Trax layer’s weights/state are cleared whenever tf.Variable.assign (and tf.Variable.assign_add etc.) is called, because tf.Variable is copy-on-write by default.
Mutations in those tf.Variables won’t affect the Trax layer’s weights, but AsKeras’s forward function calls the Trax layer’s forward function, which caches the weights in the Trax