Using Trax with TensorFlow NumPy and Keras

This notebook (which you can run in Colab) shows how to run Trax directly on TensorFlow NumPy. You will also see how to use Trax layers and models inside Keras, so you can take Trax to production, e.g., with TensorFlow.js or TensorFlow Serving.

  1. Trax with TensorFlow NumPy: use Trax with TensorFlow NumPy without any code changes
  2. Convert Trax to Keras: how to get a Keras layer for your Trax model and use it
  3. Exporting Trax Models for Deployment: how to export Trax models to TensorFlow SavedModel

1. Trax with TensorFlow NumPy

In Trax, all computations rely on accelerated math operations happening in the fastmath module. This module can use different backends for acceleration. One of them is TensorFlow NumPy which uses TensorFlow 2 to accelerate the computations.
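The fastmath module exposes a NumPy-mirroring API that dispatches each operation to the active backend. Here is a minimal sketch (not a cell from the original notebook) of calling it directly:

from trax import fastmath
from trax.fastmath import numpy as jnp  # Backend-dispatched "numpy".

x = jnp.arange(10.)
print(jnp.sum(x * x))  # Runs on whichever backend is currently set.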

The backend can be set with a call to trax.fastmath.set_backend, as you’ll see below. Currently available backends are jax (the default), tensorflow-numpy, and numpy (for debugging). The tensorflow-numpy backend uses TensorFlow NumPy to execute fastmath functions on TensorFlow, while the jax backend calls JAX, which lowers the computation to XLA.

The tensorflow-numpy and jax backends can show different speed and memory characteristics, and you may see different error messages when debugging, since each exposes its own internals. For the most part, though, you can choose a backend and not worry about its internal details.

Let’s train the sentiment analysis model from the Trax intro using TensorFlow NumPy to see how it works.

General Setup

Execute the following few cells (once) before running any of the code samples.

[1]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



[2]:
# Install and import Trax
!pip install -q -U git+https://github.com/google/trax@master

import os
import numpy as np
import trax

Here is how you can set the fastmath backend to tensorflow-numpy and verify that it’s been set.

[3]:
# Use the tensorflow-numpy backend.
trax.fastmath.set_backend('tensorflow-numpy')
print(trax.fastmath.backend_name())
tensorflow-numpy
[4]:
# Create data streams.
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),  # Tokenize the text (key 0).
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    # Bucket by length: batch 512 examples of length up to 32, 128 of up to
    # 128, 32 of up to 512, and 8 of up to 2048; the extra batch size covers
    # any longer examples (already filtered out above).
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[512, 128,  32,    8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()  # Append a loss-weights array (one weight per example).
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)

# Print example shapes.
example_batch = next(train_batches_stream)
print(f'batch shapes = {[x.shape for x in example_batch]}')
batch shapes = [(8, 2048), (8,), (8,)]
[5]:
# Create the model.
from trax import layers as tl

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
]
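As an aside, here is a sketch (not a cell from the original notebook) of calling the model directly before training. Trax layers have no weights until they are initialized, so you first initialize from an input signature; the training.Loop below does this for you automatically.

# Sketch: initialize weights from the shapes/dtypes of a concrete batch.
model.init(trax.shapes.signature(example_batch[0]))
untrained_logits = model(example_batch[0])
print(untrained_logits.shape)  # (8, 2): one pair of class logits per example.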
[6]:
# Train the model.
from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)

Step      1: Total number of trainable weights: 2097666
Step      1: Ran 1 train steps in 1.01 secs
Step      1: train WeightedCategoryCrossEntropy |  0.69292086
Step      1: eval  WeightedCategoryCrossEntropy |  0.68457415
Step      1: eval      WeightedCategoryAccuracy |  0.56406250

Step    500: Ran 499 train steps in 19.92 secs
Step    500: train WeightedCategoryCrossEntropy |  0.50587755
Step    500: eval  WeightedCategoryCrossEntropy |  0.46716719
Step    500: eval      WeightedCategoryAccuracy |  0.80625000

Step   1000: Ran 500 train steps in 17.50 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.36375266
Step   1000: eval  WeightedCategoryCrossEntropy |  0.44373559
Step   1000: eval      WeightedCategoryAccuracy |  0.80000000

Step   1500: Ran 500 train steps in 18.40 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.34449804
Step   1500: eval  WeightedCategoryCrossEntropy |  0.34941847
Step   1500: eval      WeightedCategoryAccuracy |  0.84687500

Step   2000: Ran 500 train steps in 17.18 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.28685242
Step   2000: eval  WeightedCategoryCrossEntropy |  0.50030373
Step   2000: eval      WeightedCategoryAccuracy |  0.77539062
[7]:
# Run on an example.
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_activations = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment activations: {np.asarray(sentiment_activations)}')
example input_str: The movie features another exceptional collaboration between director William Wyler and cinematographer Gregg Toland, the first after Toland worked on Citizen Kane. But the talent of both these men was focused on achieving a perfectly crafted movie, understood in the good old American sense as a great story. The technical aspects of the movie are covered so as the viewer gets absorbed into the action that takes place on the screen without submitting to the power of the image. Technique is seen as a vehicle of representation unlike in Citizen Kane where Welles' baroque style almost drew the attention from the story to the way the story was told. One of my favorite moves with deep focus in this film is the drama conveyed by the returning home welcoming of Homer and Al. If Homer's girl, Wilma comes towards him perfectly in focus, Al goes over to his wife also perfectly in focus. This is a brilliant move because it shows only through the use of the image the nature of these relationships as we will see them throughout the movie: Wilma loves Homer and she accepts him as he is, Al's wife loves him also but she feels unprepared to fully welcome him home. Also later in the film we find out that their marriage has not always been a bed of roses.<br /><br />Wyler is a director whose force lies in being true to his work without feeling the need to boast. He wanted to show his audience how hard it was for the American soldiers returning from the war to fit into a society that either didn't understand them or treated them with contempt. With a perfect cast and great dialogue Goldwin and Wyler produced a movie that will forever be the template for any other returning home movie. The three hours which coincide with the "rough cut" because the test audience back then never felt for a moment that the action was slow and indeed every scene from the film seems perfectly justified. The whole thing is constructed beautifully, every character gets a fair amount of exposure, nothing is left to chance and it is quite pitiful that Hollywood nowadays never manages to bring so much character conflict to the screen.
TBYOOL explores the depth of the American way of life, of the American family and society to an extent that makes other movies look like "the children's hour".<pad><pad><pad> ... (remaining <pad> tokens omitted; the example is padded to the bucket length of 2048)
Model returned sentiment activations: [[-1.6396211  1.6328843]]

2. Convert Trax to Keras

Thanks to TensorFlow NumPy, you can convert the model you just trained into a Keras layer using trax.AsKeras. This allows you to:

  • use Trax layers inside Keras models
  • run Trax models with existing Keras input pipelines
  • export Trax models to TensorFlow SavedModel

When you create a Keras layer from a Trax one, the Keras layer's weights are initialized to the weights the Trax layer had at the moment of creation. This way you can create Keras layers from pre-trained Trax models and save them as SavedModels, as shown below.
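Note that trax.AsKeras wraps any Trax layer, not only a full model. As an illustrative sketch (not part of the original notebook), you could wrap just the pre-trained embedding sublayer for reuse inside a larger Keras model:

# Sketch: wrap a single pre-trained sublayer of the Serial model.
embedding_keras_layer = trax.AsKeras(model.sublayers[0])  # The Embedding layer.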

[8]:
# Convert the model into a Keras layer, use the weights from model.
keras_layer = trax.AsKeras(model)
print(keras_layer)

# Run the Keras layer to verify it returns the same result.
sentiment_activations = keras_layer(example_input[None, :])
print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')
<trax.trax2keras.AsKeras object at 0x7efff5a47a90>
Keras returned sentiment activations: [[-1.6396211  1.6328843]]
[9]:
import tensorflow as tf

# Create a full Keras model using the layer from Trax.
inputs = tf.keras.Input(shape=(None,), dtype='int32')
hidden = keras_layer(inputs)
# You can add other Keras layers here operating on hidden.
outputs = hidden
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
print(keras_model)

# Run the Keras model to verify it returns the same result.
sentiment_activations = keras_model(example_input[None, :])
print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')
Keras returned sentiment activations: [[-1.6396211  1.6328843]]
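The two activations are logits for the negative and positive classes, so a softmax turns them into probabilities. A small usage sketch (not a cell from the original notebook):

# Sketch: in imdb_reviews, label 1 is the positive class.
probs = tf.nn.softmax(sentiment_activations)
print(f'P(positive) = {np.asarray(probs)[0, 1]:.3f}')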

3. Exporting Trax Models for Deployment

You can export the Keras model to disk as a TensorFlow SavedModel. It’s as simple as calling keras_model.save, and it lets you use the model with TF tools such as TensorFlow.js, TensorFlow Serving, and TensorFlow Lite.

[10]:
# Save the Keras model to output_dir.
model_file = os.path.join(output_dir, "model_checkpoint")
keras_model.save(model_file)

# Load the model from SavedModel.
loaded_model = tf.keras.models.load_model(model_file)

# Run the loaded model to verify it returns the same result.
sentiment_activations = loaded_model(example_input[None, :])
print(f'Keras returned sentiment activations: {np.asarray(sentiment_activations)}')
Keras returned sentiment activations: [[-1.6396211  1.6328843]]
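If you prefer the lower-level TensorFlow API, the same SavedModel directory can also be loaded with tf.saved_model.load, e.g., to inspect its serving signatures. A hedged sketch (not a cell from the original notebook):

# Sketch: inspect the SavedModel's signatures with the core TF API.
raw_model = tf.saved_model.load(model_file)
print(list(raw_model.signatures.keys()))  # Typically ['serving_default'].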