Core class and functions for handling data abstractly as shapes/dtypes.

class trax.shapes.ShapeDtype(shape, dtype=np.float32)

Bases: object

A NumPy ndarray-like object abstracted as shape and dtype.

Main use is for representing input and output signatures.

__init__(shape, dtype=np.float32)

Creates a ShapeDtype instance, with canonicalized shape and dtype.

Parameters:
  • shape – A tuple or list, each element of which is an int or, less often, None.
  • dtype – A dtype object, either from NumPy or TensorFlow.

Returns: A ShapeDtype instance whose shape is a tuple and dtype is a NumPy dtype object.


replace(**kwargs)

Creates a copy of the object with some parameters replaced.


trax.shapes.signature(obj)

Returns a ShapeDtype signature for the given obj.

A signature is either a ShapeDtype instance or a tuple of ShapeDtype instances. This function is permissive about its inputs: it accepts lists, tuples, or dicts, and the underlying objects can be of any type as long as they have shape and dtype attributes. It returns the corresponding nested structure of ShapeDtype instances.

Parameters: obj – An object that has shape and dtype attributes, or a list/tuple/dict of such objects.
Returns: A corresponding nested structure of ShapeDtype instances.
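The nesting behavior described above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: ShapeDtypeSketch, FakeArray, and signature_sketch are hypothetical stand-ins introduced here so the example runs without Trax installed, and it assumes that lists map to lists, tuples to tuples, and dicts to dicts.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapeDtypeSketch:
    """Hypothetical stand-in for ShapeDtype (illustration only)."""
    shape: tuple
    dtype: str


class FakeArray:
    """Hypothetical array-like object exposing only shape and dtype."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


def signature_sketch(obj):
    """Mirrors the nesting of obj, replacing each leaf by its shape and dtype."""
    if isinstance(obj, tuple):
        return tuple(signature_sketch(x) for x in obj)
    if isinstance(obj, list):
        return [signature_sketch(x) for x in obj]
    if isinstance(obj, dict):
        return {k: signature_sketch(v) for k, v in obj.items()}
    # Leaf case: any object with shape and dtype attributes is accepted.
    return ShapeDtypeSketch(tuple(obj.shape), obj.dtype)


batch = {
    "inputs": FakeArray((8, 16), "int32"),
    "targets": [FakeArray((8,), "float32")],
}
sig = signature_sketch(batch)
```

Here sig is a dict with the same keys as batch; the list under "targets" stays a list, and each array-like leaf is replaced by its shape/dtype record.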

trax.shapes.splice_signatures(*sigs)

Creates a new signature by splicing together any number of signatures.

The splicing effectively flattens the top-level input signatures. For instance, it would perform the following mapping:

  • *sigs: sd1, (sd2, sd3, sd4), (), sd5
  • return: (sd1, sd2, sd3, sd4, sd5)
Parameters: *sigs – Any number of signatures. A signature is either a ShapeDtype instance or a tuple of ShapeDtype instances.
Returns: A single ShapeDtype instance if the spliced signature has one element, else a tuple of ShapeDtype instances.
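The mapping above can be reproduced with a short pure-Python sketch. It is an illustration under stated assumptions rather than the library's code: ShapeDtypeSketch and splice_signatures_sketch are hypothetical names, and only the top level of each argument is flattened.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapeDtypeSketch:
    """Hypothetical stand-in for ShapeDtype (illustration only)."""
    shape: tuple
    dtype: str


def splice_signatures_sketch(*sigs):
    """Flattens one level of tuples/lists and concatenates the results."""
    spliced = []
    for sig in sigs:
        if isinstance(sig, (list, tuple)):
            spliced.extend(sig)  # a tuple contributes its elements
        else:
            spliced.append(sig)  # a single instance contributes itself
    # A one-element result is returned bare, not wrapped in a tuple.
    return spliced[0] if len(spliced) == 1 else tuple(spliced)


sd1, sd2, sd3, sd4, sd5 = (
    ShapeDtypeSketch((i,), "float32") for i in range(1, 6)
)
result = splice_signatures_sketch(sd1, (sd2, sd3, sd4), (), sd5)
single = splice_signatures_sketch((), sd1)
```

With these inputs, result is the flat tuple (sd1, sd2, sd3, sd4, sd5), and single is sd1 itself because the spliced signature has exactly one element.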

trax.shapes.assert_shape_equals(array, shape)

Asserts that an array has the given shape.

trax.shapes.assert_same_shape(array1, array2)

Asserts that two arrays have the same shape.


Trax trainer.

trax.trainer.tf_init_tpu(worker='', protocol=None)

Initializes TPU for TensorFlow.

Parameters:
  • worker – The BNS address of the remote TPU worker. If it’s empty (the default value), TF will assume the TPU devices are connected to the local host.
  • protocol – The network protocol used to connect to the TPU worker.

Returns: The device name of the TPU worker’s CPU.



Trainer for RL environments.

For now, PPO is the only supported RL algorithm.

Sample invocation:

python trax/ \
  --config_file=trax/rl/configs/ppo_acrobot.gin \
  --train_batch_size=${TRAIN_BATCH_SIZE} \
  --output_dir=${HOME}/ppo_acrobot

trax.rl_trainer.train_rl(output_dir, n_epochs=10000, light_rl=True, light_rl_trainer=<class ''>)

Train the RL agent.

Parameters:
  • output_dir – Output directory.
  • n_epochs – Number of epochs to run the training for.
  • light_rl – Deprecated and always True; kept only for compatibility with old gin configs.
  • light_rl_trainer – Which light RL trainer to use (experimental).


Trax-to-Keras converter.

trax.trax2keras.tensor_shapes_to_shape_dtypes(shapes, dtype)

class trax.trax2keras.AsKeras(trax_layer, batch_size=None, initializer_rng=None, rng=None, rng_updater=None, dtype=None)

Bases: tf.keras.layers.Layer

A Keras layer built from a Trax layer.

This subclass of tf.keras.layers.Layer takes in a Trax layer as a constructor argument and wraps it to be a Keras layer. It uses tf.Variable to store weights and state (initialized according to the Trax layer), and uses the Trax layer’s forward function as its forward function.

Consider this code snippet:

keras_layer = AsKeras(trax_layer, initializer_rng=initializer_rng,
                      rng=rng, rng_updater=rng_updater)
keras_layer.build(...)  # optional
outputs = keras_layer(inputs)

(Note that in Keras, calling Layer.build is optional. If omitted, it will be called automatically by Layer.__call__.)

If trax_layer already has weights at build time, the snippet is roughly equivalent to:

weights = trax_layer.weights
state = trax_layer.state
keras_layer = tf.keras.layers.Layer()
keras_layer._weights = tf.Variable(weights)
keras_layer._state = tf.Variable(state)
keras_layer._rng = tf.Variable(rng)
outputs, new_state = trax_layer(inputs, keras_layer._weights,
                                keras_layer._state, keras_layer._rng)

If trax_layer doesn’t have weights at build time, the snippet is roughly equivalent to:

weights, state = trax_layer.init(..., rng=initializer_rng)
keras_layer = ...

AsKeras uses tf.Variable to store weights, not shared with the original Trax layer (which uses tensors to store weights), so using AsKeras may double the memory footprint. This problem can be solved by making sure that the Trax layer’s weights/state are cleared whenever tf.Variable.assign (and tf.Variable.assign_add etc.) is called, because tf.Variable is copy-on-write by default.

Mutations in those tf.Variables won’t affect the Trax layer’s weights, but AsKeras’s forward function calls the Trax layer’s forward function, which caches the weights in the Trax layer object, so a forward pass may change the weights cached in the original Trax layer.

Note that this class is not thread-safe. If the same AsKeras object is used in multiple threads, the tf.Variable updates may happen in a non-deterministic order.

__init__(trax_layer, batch_size=None, initializer_rng=None, rng=None, rng_updater=None, dtype=None)

Creates a Keras layer wrapping around a Trax layer.

Parameters:
  • trax_layer – an object of class trax.layers.Layer, the Trax layer to wrap.
  • batch_size – (optional) an integer, the batch size that this Keras layer will be used on. Keras sometimes needs to generate a TF graph for a layer (e.g. for acceleration or checkpointing). The inputs used to trace the graph will have None as the length of their batch dimensions, so as to generate a graph that can handle any batch size. Some Trax layers can’t handle tensors whose shapes contain None. If batch_size is set to an integer, the graph will be traced with batch_size as the batch size instead of None. Note that in this case the graph (and the Keras layer) can only be used on a specific batch size. If you want to use a different batch size, you need to create another AsKeras object with a different batch_size.
  • initializer_rng – (optional) an RNG key used to create the weights and state if trax_layer doesn’t have them. If None, trax.fastmath.random.get_prng(0) will be used.
  • rng – (optional) an RNG key for the forward function (aka the “forward key”). If None, trax.fastmath.random.get_prng(0) will be used.
  • rng_updater – (optional) a function of type rng_key -> rng_key, used to update the forward key after each forward pass. If None, the function lambda x: trax.fastmath.random.split(x, 1)[0] will be used, which advances the RNG key.
  • dtype – (optional) the dtype of the inputs. See the dtype argument of tf.keras.layers.Layer.__init__ for details.
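The interplay between rng and rng_updater can be pictured with a toy model. This is only a sketch of the protocol described above (the current forward key is used for a pass, then rng_updater produces the key for the next pass); it does not use Trax or TensorFlow, and ForwardKeyHolder and toy_updater are hypothetical names with plain integers standing in for real RNG keys.

```python
class ForwardKeyHolder:
    """Toy model of how a forward key is consumed and then advanced."""

    def __init__(self, rng, rng_updater):
        self.rng = rng                  # current forward key
        self.rng_updater = rng_updater  # function: rng_key -> rng_key

    def forward(self, inputs):
        key_used = self.rng
        # After the pass, replace the stored key with the updated one.
        self.rng = self.rng_updater(self.rng)
        # A real layer would compute outputs from inputs and key_used;
        # here the inputs are echoed back together with the key that was used.
        return inputs, key_used


def toy_updater(key):
    """Hypothetical updater: one linear congruential step on an integer key."""
    return (key * 1103515245 + 12345) % (2 ** 31)


holder = ForwardKeyHolder(rng=0, rng_updater=toy_updater)
_, first_key = holder.forward("batch 1")
_, second_key = holder.forward("batch 2")
```

Each call sees a different key because the stored key is advanced after every pass. An updater that returns its argument unchanged would make every forward pass reuse the same key.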