Trax Quick Intro

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where you can find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! PRs with code for new models and layers are welcome, as are improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

General Setup

Execute the following few cells (once) before running any of the code samples.

[1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np



[2]:
#@title
# Import Trax

!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

[3]:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300, d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.

You can use Trax either as a library from your own python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.

3. Walkthrough

Here you can learn how Trax works, how to create new models and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors, multi-dimensional arrays sometimes also called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax uses the numpy API for them as well.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends – JAX and TensorFlow numpy.

[4]:
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix =\n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix =
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]

Gradients can be calculated using trax.fastmath.grad.

[5]:
def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
print(f'grad(2x^2) at -2 = {grad_f(-2.0)}')
grad(2x^2) at 1 = 4.0
grad(2x^2) at -2 = -8.0
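
Besides grad, the fastmath package also exposes just-in-time compilation from its backend. A minimal sketch, assuming trax.fastmath.jit wraps the backend's JIT compiler (as JAX's jit does):

fast_f = trax.fastmath.jit(f)              # Compile f for the active backend.
print(f'jit(2x^2) at 3 = {fast_f(3.0)}')   # Same value as plain f(3.0).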

Layers

Layers are basic building blocks of Trax models. You will learn all about them in the layers intro, but for now, just take a look at the implementation of one core Trax layer, Embedding:

class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each id in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Randomly initializes this layer's weights."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

[6]:
from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)

Models

Models in Trax are built from layers, most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.

[7]:
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
]
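
The Serial combinator above chains layers one after another; Branch, also mentioned above, applies several layers in parallel to copies of the same input and pushes all of their outputs. A minimal sketch (not from the original notebook), assuming the standard tl.Branch, tl.Relu and tl.Add layers:

# Branch runs both sublayers on copies of the input; Add then sums their
# outputs, giving a simple residual-style block.
block = tl.Serial(
    tl.Branch(tl.Dense(32), tl.Serial(tl.Dense(32), tl.Relu())),
    tl.Add(),
)
print(block)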

Data

To train your model, you need data. In Trax, data streams are represented as python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax allows you to use TensorFlow Datasets easily and you can also get an iterator from your own text file using the standard open('my_file.txt').

[8]:
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. Pipelines are built with trax.data.Serial; a pipeline is a function that you apply to a stream to get a processed stream.

[9]:
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
shapes = [(8, 2048), (8,), (8,)]

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

[10]:
from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)

Step      1: Total number of trainable weights: 2097666
Step      1: Ran 1 train steps in 1.15 secs
Step      1: train WeightedCategoryCrossEntropy |  0.69192106
Step      1: eval  WeightedCategoryCrossEntropy |  0.69349981
Step      1: eval      WeightedCategoryAccuracy |  0.50312500

Step    500: Ran 499 train steps in 10.62 secs
Step    500: train WeightedCategoryCrossEntropy |  0.50712883
Step    500: eval  WeightedCategoryCrossEntropy |  0.42969493
Step    500: eval      WeightedCategoryAccuracy |  0.81406250

Step   1000: Ran 500 train steps in 8.89 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.35916388
Step   1000: eval  WeightedCategoryCrossEntropy |  0.41775789
Step   1000: eval      WeightedCategoryAccuracy |  0.79531250

Step   1500: Ran 500 train steps in 9.13 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.35241464
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35194683
Step   1500: eval      WeightedCategoryAccuracy |  0.85117188

Step   2000: Ran 500 train steps in 8.54 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.29129386
Step   2000: eval  WeightedCategoryCrossEntropy |  0.37591279
Step   2000: eval      WeightedCategoryAccuracy |  0.84062500

After training the model, run it like any layer to get results.

[11]:
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: There are a few aspects to Park's movies, and in particular Wallace & Gromit, that I would say make them so great. The first is subtlety and observation, the flagship of which is the character of Gromit. He doesn't speak, he doesn't make any noise, all he has are his eyes, brow, and body posture, and with these he commands the film. Park manages to give us everything we need from this silent character through his expression. The comedy and the emotion is conveyed through the subtlest of movements and it works superbly well.<br /><br />Watching the movie you have to be aware of the entire screen. Normally you'll be guided to things in the movies, the screen won't be cluttered too much, there won't be many things to take your eyes away from the main clue or action. Park seems to need to look the other way with his movies. He throws extra content at his audience, there's action in the background, to the side of the screen, even off screen, and there's just about always something in the foreground to catch your eye. His movies are about multiple viewing and discovery, they're layered with jokes and ancillary action.<br /><br />Throughout this film there are layers of things happening on screen, jokes in the foreground maybe on a jar label and background shadows that give away action. You can imagine that for Park the movies has always been an event, and the movies he loves are ones which he wants to watch again and again. This is what shows in his movies, and in through his most beloved characters.<br /><br />Then there are the bizarre and wacky inventions which Wallace make, something which is reflected in the storyline and the twists and turns of the plot, everything is bizarre and off the wall, yet it seems so perfectly normal in this world. You can imagine that inside Park is the mind of Wallace.<br /><br />There's also one more thing that make these movies so unique, and that's the modelling and precise hand animation. I must admit I was concerned when I knew Dreamworks was involved in the making of this movie, and I thought that they would bring their computer animation experience to the forefront. What I was scared of was Wallace & Gromit becoming CGI entities, or at the smallest, CGI being used to clean up the feel that the modelling brought to the movie.<br /><br />Not so. You can still see thumbprints and toolmarks on the characters, and far from distracting from the movie, this just adds so much real feeling to it and a feeling of physical depth to the characters and the scene on screen.<br /><br />So what of the movie? Well I must say that the plot twist was something I had thought about well before the film was in the cinema and it came as no surprise, but that did not affect my enjoyment one little bit. Actually watching the twist unfold and the comic timing of the discovery and reactions was everything, and it had me just as sucked in as if it was a thriller, yet all the time I was laughing.<br /><br />Watching the movie was fascinating in various ways. To see the animation completed, how wild the inventions are, how Wallace is going to get into trouble and Gromit get him out, where all the cross references are in the movie, and where all the jokes are! 
I must admit afterwards talking with my friends I couldn't believe how much I had missed.<br /><br />There's something different in this movie than with the others, there's a new level of adult humour in here, and I don't mean rude jokes (although there are a couple that are just so British you can't help laughing), I mean jokes that simply fly over kids heads but slap adults in the face. The kind you are used to seeing come out of somewhere like Pixar. This just adds even more appeal to the movie.<br /><br />Okay though, let me try and be a bit negative here. I didn't notice the voices in this movie, you know how you usually listen to the actors and see if you can recognise them? Well I was just too wrapped up in the movie to care or to notice who they were...okay, that's not negative. Let me try again. The main plot wasn't as strong and gripping as I'd expected, and I found myself being caught up in the side stories and the characters themselves...again...that's not a bad thing, the film was just so much rich entertainment.<br /><br />I honestly can't think of a bad thing to say about this movie, probably the worst thing I could say is that the title sequence at the end is quite repetitive...until the final title! Really, that's the worst I can say.<br /><br />The story is a lot of fun, well set-up, well written, well executed. There's lot's of fantastic characters in here, not just Wallace & Gromit. There's so much happening on screen, so many references and jokes (check out the dresses of Lady Tottingham), cheese jokes everywhere, jokes for all the family. The characters are superbly absorbing and you'll find that you've taken to them before you realise. There's just so much in this movie for everyone.<br /><br />There's so much I could say and write about, but I know it will quickly turn into a backslapping exercise for Park and Aardman, it would also just turn into a series of "this bit was really funny" and "there's a bit when...", and what I would rather do is tell you that this is a superb movie, to go see it, and to experience the whole thing for yourselves. 
I will say though that the bunnies are excellent!<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[0.36765265 2.7904649 ]]
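
As a final sketch (not part of the original notebook), you can classify a brand-new sentence by tokenizing it with the same subword vocabulary and running the trained model; this assumes trax.data.tokenize falls back to the same default vocab directory used by the pipeline above:

new_sentence = 'This movie was surprisingly good!'
new_tokens = next(trax.data.tokenize(iter([new_sentence]),
                                     vocab_file='en_8k.subword'))
scores = model(new_tokens[None, :])                 # Add batch dimension.
print(f'predicted class = {int(np.argmax(scores, axis=-1)[0])}')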