

Modifications to data and computation to use accelerators (better).

class trax.layers.acceleration.Accelerate(layer, n_devices=None)

Bases: trax.layers.base.Layer

Accelerates a layer, running it in a data-parallel way on multiple devices.

By default it uses all available accelerators, splits the input on the first (batch) axis, and runs each part on the corresponding accelerator. If only one accelerator is available, this layer JIT-compiles the underlying layer and in this way makes it run faster.

The output is guaranteed to be the same as the output of the original layer if the batch dimension is divisible by the number of devices. If it is not, then 0-padding is added to make it divisible and the output may be affected if it relies on layers like batch normalization.
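The split-and-pad behavior described above can be sketched in NumPy. This is an illustration under stated assumptions, not the Trax implementation: the helper name `split_batch_for_devices` is hypothetical, and the real layer dispatches each shard to an accelerator rather than returning a reshaped array.

```python
import numpy as np

def split_batch_for_devices(x, n_devices):
    """Hypothetical helper: zero-pad axis 0 to a multiple of n_devices, then split."""
    remainder = x.shape[0] % n_devices
    if remainder:
        pad_rows = n_devices - remainder
        # Pad only the batch axis; all other axes are left untouched.
        pad_width = [(0, pad_rows)] + [(0, 0)] * (x.ndim - 1)
        x = np.pad(x, pad_width)
    # Resulting shape: (n_devices, per_device_batch, ...).
    return x.reshape((n_devices, -1) + x.shape[1:])

x = np.ones((6, 3))                      # batch of 6 is not divisible by 4
shards = split_batch_for_devices(x, n_devices=4)
```

Here the last shard consists entirely of zero padding, which is why layers that compute batch statistics, such as batch normalization, can produce different results.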

This layer does not require calling init if the underlying layer has already been initialized, so it can be used as follows:

layer = tl.Serial(...)
fast_layer = tl.Accelerate(layer)
y = fast_layer(x)  # Split x on batch and run data-parallel

In case the weights of this layer need to be set using the weights of the sublayer, use the replicate_weights function:

# Instead of layer.weights = new_weights:
fast_layer.replicate_weights(new_weights)
__init__(layer, n_devices=None)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Returns the unique sublayer managed by this layer.

pure_fn(x, weights, state, rng, use_cache=False)

Calls self.sublayer.pure_fn in an accelerated way.


Calls self.sublayer.init and replicates its values onto devices.


Sets the weights of the sublayer and replicates them for this layer.


Sets the state of the sublayer and replicates it for this layer.


Returns this layer’s weights.

Depending on the layer, the weights can be in the form of:

  • an empty tuple
  • a tensor (ndarray)
  • a nested structure of tuples and tensors

If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.


Returns a tuple containing this layer’s state; may be empty.

If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.

trax.layers.acceleration.mean_or_pmean(n_devices, x, axis=None)

Computes the mean of a distributed value x.

  • n_devices – Number of devices.
  • x – Distributed array.
  • axis – Axis along which to compute means; can only be 0 or None.

A local array.

trax.layers.acceleration.jit_forward(forward, n_devices, do_mean=True)

Returns a JIT-compiled forward function running on n_devices.

trax.layers.acceleration.reshape_by_device(x, n_devices, pure_np=False)

Reshapes possibly nested x into a shape (n_devices, ...).

trax.layers.acceleration.for_n_devices(x, n_devices)

Replicates/broadcasts x for n_devices.


Puts x in CPU memory in JAX.


Puts x in (single) accelerator memory in JAX.


Layers that compute activation functions.

An activation layer computes element-wise a nonlinear function of the preceding layer’s output. Historically, an activation function was considered part of each node in each layer of the neural network. Trax follows the common current practice of separating the activation function as its own layer, which enables easier experimentation across different activation functions.


Returns a layer that computes the Rectified Linear Unit (ReLU) function.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]

Returns a layer that computes a ReLU function with the given slope.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ ax & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters:a – Slope of line for positive inputs.

Returns a ReLU-like layer with linear nonzero outputs for negative inputs.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} ax & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters:a – Slope of line for negative inputs.

Returns a ReLU-like layer with exponential outputs for negative inputs.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} a \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]

(Asymptotically, \(f(x)\rightarrow -a\) as \(x\rightarrow - \infty\).)

Parameters:a – Coefficient multiplying the exponential, for negative inputs.
trax.layers.activation_fns.Selu(alpha=1.6732632423543772, lmbda=1.0507009873554805)

Returns an Elu-like layer with an additional scaling/slope parameter.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ \lambda \cdot x & \text{otherwise}. \end{array} \right.\end{split}\]
  • alpha – Coefficient multiplying the exponential, for negative inputs.
  • lmbda – Coefficient scaling the whole function.
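The Elu and Selu formulas above can be checked numerically with a small NumPy sketch. The function names `elu` and `selu` and the default `a=1.0` are assumptions made for illustration; these are plain array functions, not Trax layers.

```python
import numpy as np

def elu(x, a=1.0):
    x = np.asarray(x, dtype=np.float64)
    # expm1(x) = e^x - 1; the argument is clamped so the unused branch cannot overflow.
    return np.where(x <= 0, a * np.expm1(np.minimum(x, 0.0)), x)

def selu(x, alpha=1.6732632423543772, lmbda=1.0507009873554805):
    # Selu is Elu with coefficient alpha, scaled as a whole by lmbda.
    return lmbda * elu(x, a=alpha)
```

For very negative inputs `selu` approaches `-lmbda * alpha`, mirroring the asymptote noted for Elu.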

Returns a layer that computes the Gaussian Error Linear Unit function.

\[f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}}))\]

Returns a layer that computes a fast approximation to Gelu.

\[f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3))\]

where \(a = 0.7978845608\) and \(b = 0.044715\).
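To see how closely the fast approximation tracks Gelu, both formulas can be evaluated side by side. This is a scalar sketch using only the standard library; the names `gelu`, `fast_gelu`, and `max_gap` are illustrative and not part of Trax.

```python
import math

def gelu(x):
    # Exact form: x/2 * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def fast_gelu(x, a=0.7978845608, b=0.044715):
    # Tanh-based approximation: x/2 * (1 + tanh(a*x + a*b*x^3)).
    return 0.5 * x * (1.0 + math.tanh(a * x + a * b * x ** 3))

# Largest absolute difference on a grid over [-5, 5].
max_gap = max(abs(gelu(v / 100.0) - fast_gelu(v / 100.0))
              for v in range(-500, 501))
```

On this grid the two functions differ only by a small amount relative to the function values.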


Returns a layer that computes the sigmoid function.

\[f(x) = \frac{1}{1 + e^{-x}}\]

Returns a layer that computes the hyperbolic tangent function.

\[f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]

Returns a layer that computes a linear approximation to Sigmoid.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{if}\ 0 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]

Returns a layer that computes a linear approximation to Tanh.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} -1 & \text{if}\ x \leq -1, \\ x & \text{if}\ -1 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]
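Both piecewise-linear approximations above amount to clipping the input to a fixed interval. A minimal NumPy sketch, with illustrative function names that are not Trax layers:

```python
import numpy as np

def hard_sigmoid(x):
    # 0 for x <= 0, x for 0 < x < 1, 1 otherwise.
    return np.clip(x, 0.0, 1.0)

def hard_tanh(x):
    # -1 for x <= -1, x for -1 < x < 1, 1 otherwise.
    return np.clip(x, -1.0, 1.0)
```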

Returns a layer that computes the softplus function.

\[f(x) = \ln(e^x + 1)\]

Returns a layer that computes the element-wise exponential of a tensor.


Returns a layer that computes the element-wise logarithm of a tensor.


Returns a layer that computes the Swish function.

\[f(x) = x \cdot \text{sigmoid}(x)\]

Returns a layer that computes the Gated Linear Unit function.

\[f(x) = a \cdot \text{sigmoid}(b)\]

where a and b are formed by splitting input in half along axis
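The split-and-gate computation can be sketched in NumPy as follows. The function name `glu` and the default `axis=-1` are assumptions for illustration; the axis size must be even for the split to succeed.

```python
import numpy as np

def glu(x, axis=-1):
    # Split the input into two equal halves along the chosen axis.
    a, b = np.split(x, 2, axis=axis)
    # Gate the first half with the sigmoid of the second half.
    return a * (1.0 / (1.0 + np.exp(-b)))

x = np.array([[1.0, 2.0, 0.0, 0.0]])
y = glu(x)   # sigmoid(0) = 0.5, so each element of `a` is halved
```

Note that the output is half the size of the input along the split axis.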

class trax.layers.activation_fns.ThresholdedLinearUnit(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Bases: trax.layers.base.Layer

Thresholded Linear Unit.


Initializes this layer’s single weight to zero.


Executes this layer as part of a forward pass through the model.

Parameters:inputs – Tensor.
Returns:Tensor of same shape and dtype as the input.


Attention-related layers, as used in Transformer(-like) models.

Attention is a trainable mechanism for mapping between collections of vectors:

\[\text{Attention}: \mathbf{X}^{n} \rightarrow \mathbf{X}^{n}\!, \ \text{for} \ \mathbf{X} \in \mathbb{R}^d\]

Whereas classic neural networks assemble nodes of numbers with weighted connections:

  • node activations: floating point values (one float per node)
  • inter-node connections: trainable weights (one float per connection),

attention lets one assemble nodes of vectors and use further vectors to calculate connection strengths:

  • node activations: floating point vectors, and
  • inter-node connections: computed using trainable vectors.

Computing connection strengths involves several concepts – queries, keys, values, masks, attention heads – that factor heavily into the API below.

NOTE: Attention, positional encoding, and shift layers in this module include mode-dependent behavior. The possible modes are:

  • 'train': in training – dropouts and position shifts active
  • 'eval': in evals – dropouts inactive, position shifts active
  • 'predict': in prediction – dropouts and position shifts inactive
trax.layers.attention.Attention(d_feature, n_heads=1, dropout=0.0, mode='train')

Returns a layer that maps (vectors, mask) to (new_vectors, mask).

This layer type represents one pass of multi-head self-attention, from vector set to vector set, using masks to represent out-of-bound (e.g., padding) positions. It:

  • makes three copies of incoming activations and maps these to multi-head query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
  • for each head, computes the scaled dot product of each Q-K pair;
  • applies mask to screen out positions that come from padding tokens (indicated by 0 value);
  • [in 'train' mode] applies dropout to Q-K dot products;
  • for each head, computes Q-K attention strengths using a per-query softmax of the Q-K dot products;
  • for each head, for each query position, combines V vectors according to the Q-K attention strengths; and
  • concatenates and fuses resulting per-head vectors into outgoing activations matching original input activation shapes.
  • d_feature – Last/innermost dimension of activations in the input to and output from this layer.
  • n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
  • dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
  • mode – One of 'train', 'eval', or 'predict'.
trax.layers.attention.AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train', cache_KV_in_predict=False, q_sparsity=None, result_sparsity=None)

Returns a layer that maps (AQ, AK, AV, mask) to (new-A, mask).

Unlike Attention above, AttentionQKV allows the incoming activations (AQ, AK, and AV) to come from different sources. This is used, for instance, in encoder-decoder attention (Q-related activations AQ from the decoder, K- and V-related activations – AK and AV – from the encoder). Otherwise, see the Attention description for further context/details.

  • d_feature – Last/innermost dimension of activations in the input to and output from this layer.
  • n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
  • dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
  • mode – One of 'train', 'eval', or 'predict'.
  • cache_KV_in_predict – Whether to cache K/V arrays in 'predict' mode.
  • q_sparsity – Sparsity with which to process queries. If None, Dense is used; if 'noop', no processing is used.
  • result_sparsity – Sparsity with which to process result of the attention. If None, Dense is used; if 'noop', no processing is used.
class trax.layers.attention.PureAttention(n_heads=1, dropout=0.0, mode='train')

Bases: trax.layers.base.Layer

Returns a layer that maps (Q, K, V, mask) to (activations, mask).

This layer type performs the inner workings of one pass of multi-head self-attention. It:

  • subdivides incoming Q/K/V activations into multi-head versions;
  • for each head, computes the scaled dot product of each Q-K pair;
  • applies mask to screen out positions that come from padding tokens (indicated by 0 value);
  • [in 'train' mode] applies dropout to Q-K dot products;
  • for each head, computes Q-K attention strengths using a per-query softmax of the Q-K dot products;
  • for each head, for each query position, combines V vectors according to the Q-K attention strengths; and
  • concatenates and fuses resulting per-head vectors into outgoing activations matching original input activation shapes.
__init__(n_heads=1, dropout=0.0, mode='train')

Returns a new PureAttention instance.

  • n_heads – Number of attention heads.
  • dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
  • mode – One of 'train', 'eval', or 'predict'.

Returns attention-computed activations and unmodified mask.

Parameters:inputs – A (Q, K, V, mask) tuple, whose query, key, and value activations have not yet been subdivided into heads.
class trax.layers.attention.DotProductAttention(dropout=0.0, mode='train')

Bases: trax.layers.base.Layer

Returns a layer that computes per-head attention (via scaled dot-product).

This layer computes the core of the attention mechanism. Given per-head queries (Q), keys (K), values (V), and mask, it:

  • computes the scaled dot product of each Q-K pair;
  • applies mask to screen out positions that come from padding tokens (indicated by 0 value);
  • [if created in 'train' mode] applies dropout to Q-K dot products;
  • computes Q-K attention strengths using a per-query softmax of the Q-K dot products; and
  • for each query position, combines V vectors according to the Q-K attention strengths.
__init__(dropout=0.0, mode='train')

Creates a DotProductAttention instance in a specific mode.

  • dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
  • mode – One of 'train', 'eval', 'predict' or 'viz'.

Returns attention-computed per-head activations and unchanged mask.

Parameters:inputs – A (Q, K, V, mask) tuple, whose query, key, and value activations have been subdivided into heads.
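The core steps listed for DotProductAttention (scaled dot products, masking, per-query softmax, weighted sum of values) can be sketched in NumPy. This is a simplified illustration, not the Trax code: dropout is omitted, the function name is hypothetical, and masked positions are handled by substituting a large negative score, which is one common approach.

```python
import numpy as np

def dot_product_attention(q, k, v, mask=None):
    """q: (..., n_q, d); k, v: (..., n_k, d); mask broadcastable to (..., n_q, n_k)."""
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)       # scaled Q-K dot products
    if mask is not None:
        # Positions where mask == 0 get a very negative score, hence zero weight.
        scores = np.where(mask, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # per-query softmax
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 3, 4))
k = rng.normal(size=(1, 3, 4))
v = rng.normal(size=(1, 3, 4))
mask = np.array([[[1, 1, 0]]])   # last key position is padding
out, weights = dot_product_attention(q, k, v, mask)
```

Every query's weights sum to 1, and the padded key position receives zero weight.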
trax.layers.attention.SplitIntoHeads(n_heads, merged_batch_and_head=True)

Returns a layer that reshapes an array for multi-head computation.

trax.layers.attention.MergeHeads(n_heads, merged_batch_and_head=True)

Returns a layer that rejoins heads, after multi-head computation.
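The reshaping done around multi-head computation can be illustrated with a NumPy sketch. The exact array layout used by Trax is an assumption here: the sketch splits the last axis into `n_heads` subvectors and, when `merged_batch_and_head` is true, folds the head axis into the batch axis.

```python
import numpy as np

def split_into_heads(x, n_heads, merged_batch_and_head=True):
    batch, seq_len, d_feature = x.shape
    assert d_feature % n_heads == 0, "d_feature must be divisible by n_heads"
    d_head = d_feature // n_heads
    # (batch, seq, d_feature) -> (batch, n_heads, seq, d_head)
    x = x.reshape(batch, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)
    if merged_batch_and_head:
        x = x.reshape(batch * n_heads, seq_len, d_head)
    return x

def merge_heads(x, n_heads, merged_batch_and_head=True):
    if merged_batch_and_head:
        x = x.reshape(-1, n_heads, x.shape[-2], x.shape[-1])
    batch, _, seq_len, d_head = x.shape
    # (batch, n_heads, seq, d_head) -> (batch, seq, n_heads * d_head)
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, n_heads * d_head)

x = np.arange(2 * 3 * 8, dtype=np.float32).reshape(2, 3, 8)
heads = split_into_heads(x, n_heads=4)
restored = merge_heads(heads, n_heads=4)
```

Merging after splitting recovers the original array, so the pair acts as an exact round trip.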

trax.layers.attention.ConfigurableAttention(q_layer, k_layer, v_layer, final_layer, qkv_attention_layer, n_heads=1)

Returns a configured multi-head self-attention layer.

A ConfigurableAttention layer acts similarly to Attention layers, but with configurable components. It

  • makes three copies of incoming activations and uses q_layer, k_layer, and v_layer to map activations to multi-head query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
  • uses qkv_attention_layer to compute per-head attention, similar to DotProductAttention or DotProductCausalAttention;
  • concatenates and fuses resulting per-head vectors into activations matching original input activation shapes; and
  • applies a final layer, final_layer, mapping activations to activations (with shape matching the original input activations).
  • q_layer – Layer that maps input activations to per-head query activations.
  • k_layer – Layer that maps input activations to per-head key activations.
  • v_layer – Layer that maps input activations to per-head value activations.
  • final_layer – After main multi-head computation and rejoining of heads, layer that maps activations to activations (with shape matching the original input activations).
  • qkv_attention_layer – Layer that does the core multi-head self-attention computation.
  • n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
trax.layers.attention.CausalAttention(d_feature, n_heads=1, dropout=0.0, max_inference_length=2048, mode='train')

Returns a layer that maps activations to activations, with causal masking.

Like Attention, this layer type represents one pass of multi-head self-attention, but with causal masking rather than padding-based masking.

  • d_feature – Last/innermost dimension of activations in the input to and output from this layer.
  • n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
  • dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
  • max_inference_length – Maximum sequence length allowed in non-training modes.
  • mode – One of 'train', 'eval', or 'predict'.
class trax.layers.attention.DotProductCausalAttention(dropout=0.0, max_inference_length=2048, mode='train')

Bases: trax.layers.base.Layer

Layer that computes attention strengths by masking out the “future”.

Causal attention uses masking to prevent a given sequence position from attending to positions greater than / following it. This is used, for example, when training autoregressive sequence models, or when decoding a sequence symbol by symbol.

This layer performs the core per-head attention calculation. The layer assumes that any splitting into attention heads precedes it, and that any merging of attention heads will follow it.
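Causal masking can be pictured as a lower-triangular mask applied to the query-key scores before the softmax. The NumPy sketch below is illustrative only: the function names are hypothetical, queries and keys are assumed to have the same length, and dropout and fast-inference caching are left out.

```python
import numpy as np

def causal_mask(seq_len):
    # Lower-triangular: query position i may attend to key positions 0..i only.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def causal_attention_weights(q, k):
    d = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    # "Future" positions get a very negative score, hence zero weight.
    scores = np.where(causal_mask(q.shape[-2]), scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4, 8))
w = causal_attention_weights(q, q)
```

The first position can attend only to itself, so its weight on position 0 is 1.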

__init__(dropout=0.0, max_inference_length=2048, mode='train')

Creates a DotProductCausalAttention instance.

  • dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
  • max_inference_length – Maximum sequence length allowed in non-training modes.
  • mode – One of 'train', 'eval', or 'predict'.

Returns attention-computed activations.

Parameters:inputs – A (queries, keys, values) tuple.

Initializes this layer for fast inference, if in 'predict' mode.

trax.layers.attention.ShiftRight(n_positions=1, mode='train')

Returns a layer that can insert padding to shift the input sequence.

  • n_positions – Number of positions to shift the input sequence rightward; initial positions freed by the shift get padded with zeros. Applies only if layer is created in a non-'predict' mode.
  • mode – One of 'train', 'eval', or 'predict'.
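The shift itself can be sketched in NumPy as zero-padding at the start of the sequence axis followed by truncation to the original length. The function name is hypothetical, and the sequence axis is assumed to be axis 1; mode handling is left out.

```python
import numpy as np

def shift_right(x, n_positions=1):
    """Shift along axis 1, filling the freed initial positions with zeros."""
    pad_width = [(0, 0)] * x.ndim
    pad_width[1] = (n_positions, 0)      # pad only at the start of axis 1
    padded = np.pad(x, pad_width)
    return padded[:, :x.shape[1]]        # keep the original sequence length

tokens = np.array([[1, 2, 3], [4, 5, 6]])
shifted = shift_right(tokens)
```

The last `n_positions` elements of each sequence are dropped, so the output shape matches the input shape.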

Returns a layer that maps integer sequences to padding masks.

The layer expects as input a batch of integer sequences. The layer output is an N-D array that marks for each sequence position whether the integer (e.g., a token ID) in that position represents padding – value pad – versus text/content – all other values. The padding mask shape is (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast to cover any number of attention heads and axis 2 will broadcast to cover decoder sequence positions.

Parameters:pad – Integer that represents padding rather than a token/content ID.
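The mask shape described above can be reproduced with a short NumPy sketch. The function name and the default `pad=0` are assumptions for illustration.

```python
import numpy as np

def padding_mask(token_ids, pad=0):
    # True for content positions, False for padding positions.
    mask = token_ids != pad
    # Insert two singleton axes: (batch, 1, 1, sequence_length).
    return mask[:, None, None, :]

token_ids = np.array([[5, 7, 0, 0],
                      [3, 0, 0, 0]])
mask = padding_mask(token_ids)
```

The singleton axes let the mask broadcast against per-head, per-query attention scores without being copied.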

Returns a layer that creates a mask for encoder-decoder cross attention.

The layer expects two inputs:

  • decoder_input: batch of integer (e.g., token ID) sequences
  • mask: padding mask from the encoder

The layer output is a mask that marks for each sequence position (for both encoder and decoder) whether that position can be attended to or not. The encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, encoder_sequence_length), such that axis 1 will automatically broadcast to cover any number of attention heads.
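One way to obtain the stated output shape is to broadcast the encoder padding mask across decoder positions. The NumPy sketch below is an assumption about the mechanics, not the Trax source; only the length of `decoder_input` is used.

```python
import numpy as np

def encoder_decoder_mask(decoder_input, mask):
    batch, enc_len = mask.shape[0], mask.shape[-1]
    dec_len = decoder_input.shape[1]
    mask = mask.reshape(batch, 1, 1, enc_len)
    # Repeat the encoder mask for every decoder position (axis 2).
    return np.broadcast_to(mask, (batch, 1, dec_len, enc_len))

encoder_mask = np.array([[1, 1, 0, 0],
                         [1, 0, 0, 0]]).reshape(2, 1, 1, 4)
decoder_input = np.array([[9, 8, 7],
                          [6, 5, 4]])
cross_mask = encoder_decoder_mask(decoder_input, encoder_mask)
```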

class trax.layers.attention.PositionalEncoding(max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2, ), use_bfloat16=False, start_from_zero_prob=1.0, max_offset_to_add=0, d_feature=None, mode='train')

Bases: trax.layers.base.Layer

Implements bare positional encoding.

Positional encoding includes a kind of dropout, if the layer is created in 'train' mode with a nonzero dropout value. For such a layer, on each forward pass a subset of sequence positions selected at random will not receive positional marking.

__init__(max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2, ), use_bfloat16=False, start_from_zero_prob=1.0, max_offset_to_add=0, d_feature=None, mode='train')

Creates a PositionalEncoding instance in a given mode.

  • max_len – Maximum input sequence length.
  • dropout – Probability of not adding positional encoding to a sequence position. Applies only if layer is created in 'train' mode.
  • dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
  • use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.
  • start_from_zero_prob – How often to start from position 0 during training (if 1.0, we always start from position 0; if less, we randomize).
  • max_offset_to_add – maximum offset to add to the positions during training when randomizing; this offset plus input length must still be less than max_len for all training examples.
  • d_feature – int or None; have this dimension for embeddings + shared FF if not None.
  • mode – One of 'train', 'eval', or 'predict'.

Returns the input activations, with added positional information.


Randomly initializes the positional encoding vectors.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on.


The key layer abstraction (Layer class) and supporting machinery.

class trax.layers.base.Layer(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Bases: object

Base class for composable layers in a deep learning network.

Layers are the basic building blocks for deep learning models. A layer computes a function from zero or more inputs to zero or more outputs, optionally using trainable weights (common) and non-parameter state (not common).

Layer subclasses typically override at most two methods of the base Layer class:

forward(inputs):
Computes the layer’s output as part of a forward pass through the model.
init_weights_and_state(self, input_signature):
Initializes the layer’s weights and state to handle input with the given signature (number, shapes and dtypes of input arguments).

A small number of layer types are combinators – they organize the computation of their sublayers, e.g., applying their sublayers in series or in parallel.

All layers have the following properties, with default values implemented in the base Layer class:

  • n_in: int (default 1)
  • n_out: int (default 1)
  • weights: tuple (default empty – the layer has no weights)
  • state: tuple (default empty – the layer has no non-parameter state)
  • sublayers: tuple (default empty – the layer has no sublayers)

The inputs to a layer are tensors, packaged according to how many there are:

  • n_in = 0: an empty tuple
  • n_in = 1: one tensor (NOT wrapped in a tuple)
  • n_in > 1: a tuple of tensors

(The special treatment of the single-input case is meant to simplify the work of layer writers; this design choice may be revisited in the future.)

The outputs from a layer are also tensors, packaged the same as layer inputs:

  • n_out = 0: an empty tuple
  • n_out = 1: the tensor (NOT wrapped in a tuple)
  • n_out > 1: a tuple of tensors

The Trax runtime maintains a data stack with which layer calls are composed. For more complex data network architectures, possibly involving multiple data flows, one can view each layer as a function from stack state to stack state, where the function’s inputs are a slice from the stack, and the function’s outputs are spliced back into the stack.
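The stack view can be made concrete with a toy model in plain Python. The helper `apply_on_stack` is hypothetical and only mimics the idea: a layer consumes `n_in` items from the top of the stack and puts `n_out` items back, following the packaging rules above.

```python
def apply_on_stack(stack, fn, n_in, n_out):
    """Hypothetical helper: treat fn as a function from stack state to stack state."""
    inputs, rest = stack[:n_in], stack[n_in:]
    # A single input is passed bare; zero or several inputs are passed as a tuple.
    packaged = inputs[0] if n_in == 1 else tuple(inputs)
    result = fn(packaged)
    # A single output comes back bare; zero or several come back as a tuple.
    outputs = (result,) if n_out == 1 else tuple(result)
    return outputs + rest

stack = (3, 4, 'untouched')
after_add = apply_on_stack(stack, lambda xs: xs[0] + xs[1], n_in=2, n_out=1)
after_dup = apply_on_stack(after_add, lambda x: (x, x), n_in=1, n_out=2)
```

Items below the consumed slice pass through unchanged, which is what allows several data flows to coexist on one stack.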

__init__(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.
__call__(x, weights=None, state=None, rng=None)

Makes layers callable; for use in tests or interactive settings.

This convenience method helps library users play with, test, or otherwise probe the behavior of layers outside of a full training environment. It presents the layer as a callable function from inputs to outputs, with the option of manually specifying weights and non-parameter state per individual call. For convenience, weights and non-parameter state are cached per layer instance. They start from default values of EMPTY_WEIGHTS and EMPTY_STATE and acquire non-empty values either by initialization or from values explicitly provided via the weights and state keyword arguments; in the latter case the old weights will be preserved, and the state will be updated.

  • x – Zero or more input tensors, packaged as described in the Layer class docstring.
  • weights – Weights or None; if None, use self’s cached weights value.
  • state – State or None; if None, use self’s cached state value.
  • rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.

Zero or more output tensors, packaged as described in the Layer class docstring.


Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng)

Custom backward pass to propagate gradients in a custom way.

  • inputs – Input tensors; can be a (possibly nested) tuple.
  • output – The result of running this layer on inputs.
  • grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
  • weights – This layer’s weights.
  • state – This layer’s state prior to the current forward pass.
  • new_state – This layer’s state after the current forward pass.
  • rng – Single-use random number generator (JAX PRNG key).

The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.

init(input_signature, rng=None, use_cache=False)

Initializes weights/state of this layer and its sublayers recursively.

Initialization creates layer weights and state, for layers that use them. It derives the necessary array shapes and data types from the layer’s input signature, which is itself just shape and data type information.

For layers without weights or state, this method safely does nothing.

This method is designed to create weights/state only once for each layer instance, even if the same layer instance occurs in multiple places in the network. This enables weight sharing to be implemented as layer sharing.

  • input_signature – ShapeDtype instance (if this layer takes one input) or list/tuple of ShapeDtype instances.
  • rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
  • use_cache – If True, and if this layer instance has already been initialized elsewhere in the network, then return special marker values – tuple (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE). Else return this layer’s newly initialized weights and state.

A (weights, state) tuple.

init_from_file(file_name, weights_only=False, input_signature=None)

Initializes this layer and its sublayers from a pickled checkpoint.

In the common case (weights_only=False), the file must be a gzipped pickled dictionary containing items with keys ‘flat_weights’, ‘flat_state’ and ‘input_signature’, which are used to initialize this layer. If input_signature is specified, it’s used instead of the one in the file. If weights_only is True, the dictionary does not need to have the ‘flat_state’ item and the state is not restored either.

  • file_name – Name/path of the pickled weights/state file.
  • weights_only – If True, initialize only the layer’s weights. Else initialize both weights and state.
  • input_signature – Input signature to be used instead of the one from file.

A (weights, state) tuple.
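The checkpoint container described above, a gzipped pickled dictionary, can be written and read with the standard library alone. In this sketch the three keys come from the description, while the values are placeholders: the real flattened weights, state, and ShapeDtype signature produced by Trax are not reproduced here.

```python
import gzip
import os
import pickle
import tempfile

checkpoint = {
    'flat_weights': [[0.1, 0.2], [0.3]],   # placeholder values
    'flat_state': [],                       # placeholder: no state
    'input_signature': {'shape': (8, 16), 'dtype': 'float32'},  # stand-in for a ShapeDtype
}

path = os.path.join(tempfile.mkdtemp(), 'model.pkl.gz')

# Write: pickle the dictionary through a gzip stream.
with gzip.open(path, 'wb') as f:
    pickle.dump(checkpoint, f)

# Read: decompress and unpickle in one pass.
with gzip.open(path, 'rb') as f:
    restored = pickle.load(f)
```

As with any pickle-based format, such files should only be loaded from trusted sources.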

save_to_file(file_name, weights_only=False, input_signature=None)

Saves this layer and its sublayers to a pickled checkpoint.

  • file_name – Name/path of the pickled weights/state file.
  • weights_only – If True, save only the layer’s weights. Else save both weights and state.
  • input_signature – Input signature to be used.

Returns the name of this layer.


Returns how many tensors this layer expects as input.


Returns how many tensors this layer promises as output.


Returns a tuple containing this layer’s sublayers; may be empty.


Returns this layer’s weights.

Depending on the layer, the weights can be in the form of:

  • an empty tuple
  • a tensor (ndarray)
  • a nested structure of tuples and tensors

If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.


Returns a tuple containing this layer’s state; may be empty.

If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.

weights_and_state_signature(input_signature, unsafe=False)

Return a pair containing the signatures of weights and state.


Returns this layer’s current single-use random number generator.

Code that wants to base random samples on this generator must explicitly split off new generators from it. (See, for example, the rng setter code below.)

pure_fn(x, weights, state, rng, use_cache=False)

Applies this layer as a pure function with no optional args.

This method exposes the layer’s computation as a pure function. This is especially useful for JIT compilation. Do not override, use forward instead.

  • x – Zero or more input tensors, packaged as described in the Layer class docstring.
  • weights – A tuple or list of trainable weights, with one element for this layer if this layer has no sublayers, or one for each sublayer if this layer has sublayers. If a layer (or sublayer) has no trainable weights, the corresponding weights element is an empty tuple.
  • state – Layer-specific non-parameter state that can update between batches.
  • rng – Single-use random number generator (JAX PRNG key).
  • use_cache – If True, cache weights and state in the layer object; used to implement layer sharing in combinators.

A tuple of (tensors, state). The tensors match the number (n_out) promised by this layer, and are packaged as described in the Layer class docstring.


Returns output signature this layer would give for input_signature.

class trax.layers.base.PureLayer(forward_fn, n_in=1, n_out=1, name='PureLayer')

Bases: trax.layers.base.Layer

Pure function from inputs to outputs, packaged as neural network layer.

The PureLayer class represents the simplest kinds of layers: layers with no trainable weights and no randomness, hence pure functions from inputs to outputs.

__init__(forward_fn, n_in=1, n_out=1, name='PureLayer')

Creates an unconnected PureLayer instance.

  • forward_fn – Pure function from input tensors to output tensors, where inputs and outputs are packaged as specified for forward.
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use only in debugging.

Overrides Layer.forward.

Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

trax.layers.base.Fn(name, f, n_out=1)

Returns a layer with no weights that applies the function f.

f can take and return any number of arguments, and takes only positional arguments – no default or keyword arguments. It often uses JAX-numpy (jnp). The following, for example, would create a layer that takes two inputs and returns two outputs – element-wise sums and maxima:

Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

The layer’s number of inputs (n_in) is automatically set to number of positional arguments in f, but you must explicitly set the number of outputs (n_out) whenever it’s not the default value 1.

  • name – Class-like name for the resulting layer; for use in debugging.
  • f – Pure function from input tensors to output tensors, where each input tensor is a separate positional arg, e.g., f(x0, x1) –> x0 + x1. Output tensors must be packaged as specified in the Layer class docstring.
  • n_out – Number of outputs promised by the layer; default value 1.

Layer executing the function f.
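As a rough illustration of how n_in can follow from the signature of f, the sketch below counts positional parameters with the standard inspect module. The helper count_positional_args is hypothetical and is not part of Trax; it only shows one way such a count could be obtained, and plain Python numbers stand in for tensors.

```python
import inspect


def count_positional_args(f):
    """Counts the positional parameters of f (hypothetical helper, not a Trax API)."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = inspect.signature(f).parameters.values()
    return sum(1 for p in params if p.kind in positional_kinds)


def sum_and_max(x0, x1):
    # Two inputs and two outputs; wrapped in Fn this would need n_out=2.
    return x0 + x1, max(x0, x1)


n_in = count_positional_args(sum_and_max)  # number of inputs inferred from f
outputs = sum_and_max(3, 5)                # the number of outputs is not inferred
```

The number of outputs cannot be read off the signature in the same way, which is consistent with n_out having to be passed explicitly.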

exception trax.layers.base.LayerError(layer_name, function_name, caller, input_signature, traceback_string)

Bases: Exception

Exception raised in the layer stack.

__init__(layer_name, function_name, caller, input_signature, traceback_string)

Initialize self. See help(type(self)) for accurate signature.


Assembles current layer context into an error message.

trax.layers.base.flatten_weights_and_state(weights, state)

Flatten weights and state into lists, excluding empty and cached ones.

trax.layers.base.unflatten_weights_and_state(flat_weights, flat_state, weights_and_state_signature, weights_only=False)

Unflatten weights and state given their signatures.

trax.layers.base.np_to_file(list_of_nparrays, file_path, compresslevel)

Save numpy arrays to file_path with gzipping and failure protection.

trax.layers.base.np_from_file(file_path, compresslevel)

Load numpy arrays from file_path with gzipping.


Converts layer outputs to a nested list, for easier equality testing.

Parameters: outputs – A tensor or tuple/list of tensors coming from the forward application of a layer. Each tensor is NumPy ndarray-like, which complicates simple equality testing (e.g., via assertEquals): such tensors require equality testing to use either all (all elements match) or any (at least one element matches), which is not directly supported in absltest.
Returns: A nested list structure containing all the output values, but now directly testable using assertEquals.

trax.layers.base.shard(tensors, n_shards=None)

Shard tensors across n_shards.

trax.layers.base.unshard_in_pmap(tensors, n_shards)

Unshard tensors that were sharded into n_shards (call inside pmap).

trax.layers.base.unshard(tensors, n_shards=None)

Unshard tensors that were sharded into n_shards (outside of pmap).


Combinators for composing layers.

class trax.layers.combinators.Serial(*sublayers, name=None, sublayers_to_print=None)

Bases: trax.layers.base.Layer

Combinator that applies layers serially (by function composition).

This combinator is commonly used to construct deep networks, e.g., like this:

mlp = tl.Serial(

A Serial combinator uses stack semantics to manage data for its sublayers. Each sublayer sees only the inputs it needs and returns only the outputs it has generated. The sublayers interact via the data stack. For instance, a sublayer k, following sublayer j, gets called with the data stack in the state left after layer j has applied. The Serial combinator then:

  • takes n_in items off the top of the stack (n_in = k.n_in) and calls layer k, passing those items as arguments; and
  • takes layer k’s n_out return values (n_out = k.n_out) and pushes them onto the data stack.

A Serial instance with no sublayers acts as a special-case (but useful) 1-input 1-output no-op.
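The stack semantics described above can be modelled in a few lines of plain Python. This is only a sketch under simplifying assumptions: each "layer" is a hypothetical (fn, n_in, n_out) triple, the stack is a list whose index 0 is the top, and numbers stand in for tensors. It is not how Trax implements Serial.

```python
def run_serial(layers, stack):
    """Applies (fn, n_in, n_out) triples in order to a stack; index 0 is the top."""
    stack = list(stack)
    for fn, n_in, n_out in layers:
        # Take n_in items off the top of the stack and call the layer with them.
        args, rest = stack[:n_in], stack[n_in:]
        outputs = fn(*args)
        if n_out == 1:
            outputs = (outputs,)
        # Push the layer's n_out return values back onto the stack.
        stack = list(outputs) + rest
    return stack


add = (lambda a, b: a + b, 2, 1)
double = (lambda a: 2 * a, 1, 1)
dup = (lambda a: (a, a), 1, 2)

result = run_serial([add, double, dup], [1, 2, 10])
noop = run_serial([], [7])  # no sublayers: the stack is left unchanged
```

Note how the item 10 at the bottom of the stack is never touched: each sublayer sees only the inputs it needs.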

__init__(*sublayers, name=None, sublayers_to_print=None)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Executes this layer as part of a forward pass through the model.


Initializes weights and state for inputs with the given signature.

class trax.layers.combinators.Parallel(*sublayers, name=None)

Bases: trax.layers.base.Layer

Combinator that applies a list of layers in parallel to its inputs.

Layers in the list apply to successive spans of inputs, where the spans are determined by how many inputs each layer takes. The resulting output is the (flattened) concatenation of the respective layer outputs.

For example, suppose one has three layers:

  • F: 1 input, 1 output
  • G: 3 inputs, 1 output
  • H: 2 inputs, 2 outputs (h1, h2)

Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:

  • inputs: a, b, c, d, e, f
  • outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f)

As an important special case, a None argument to Parallel acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For example:

Parallel(None, F)

creates a layer that passes its first input unchanged and applies F to the following input(s).
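The span assignment can be sketched in plain Python as follows. The (fn, n_in, n_out) triples and the run_parallel helper are hypothetical stand-ins for real layers, and numbers stand in for tensors; the sketch only mirrors the F, G, H example and the None special case described above.

```python
def run_parallel(layers, inputs):
    """Applies each (fn, n_in, n_out) triple to its own span of inputs.

    A None entry stands for a one-input no-op.
    """
    outputs, start = [], 0
    for layer in layers:
        fn, n_in, n_out = layer if layer is not None else (lambda x: x, 1, 1)
        result = fn(*inputs[start:start + n_in])
        # Flatten multi-output results into the combined output list.
        outputs.extend(result if n_out > 1 else [result])
        start += n_in
    return outputs


F = (lambda a: -a, 1, 1)
G = (lambda b, c, d: b + c + d, 3, 1)
H = (lambda e, f: (e * f, e - f), 2, 2)

parallel_out = run_parallel([F, G, H], [1, 2, 3, 4, 5, 6])  # 6 inputs, 4 outputs
passthrough_out = run_parallel([None, F], [10, 20])         # first input unchanged
```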

__init__(*sublayers, name=None)

The constructor.

  • *sublayers – A list of sublayers.
  • name – Descriptive name for this layer.

A new layer in which each of the given sublayers applies to its corresponding span of elements in the dataflow stack.


Executes this layer as part of a forward pass through the model.


Initializes weights and state for inputs with the given signature.

class trax.layers.combinators.Concatenate(n_items=2, axis=-1)

Bases: trax.layers.base.Layer

Concatenates a number of tensors into a single tensor.

For example:

import numpy as np
from trax import layers as tl

x = np.array([1, 2])
y = np.array([3, 4])
z = np.array([5, 6])
concat3 = tl.Concatenate(n_items=3)
out = concat3((x, y, z))  # out = [1, 2, 3, 4, 5, 6]

Use the axis argument to specify on which axis to concatenate the tensors. By default it’s the last axis, axis=-1, and n_items=2.

__init__(n_items=2, axis=-1)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Executes this layer as part of a forward pass through the model.

class trax.layers.combinators.Split(n_items=2, axis=-1)

Bases: trax.layers.base.Layer

Splits the input into n items along an axis.

__init__(n_items=2, axis=-1)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Executes this layer as part of a forward pass through the model.

class trax.layers.combinators.Scan(layer, axis=0, n_carry=1, remat=False, mode='train')

Bases: trax.layers.base.Layer

Applies a layer progressively/cumulatively to an axis-derived sequence.

Conceptually, this is a function from a list to a same-length list of partial (cumulative) results. For instance, a list of values ([1, 2, 3, 4, 5]) can transform to a list of cumulative sums ([1, 3, 6, 10, 15]). Functions for the same concept are called scan in Scala, scanl in Haskell, and accumulate* in Factor.

In more detail, we assume the layer takes a tuple of inputs of the following form:

(input1, …, inputN, carry1, …, carryM)

and returns:

(output1, …, outputK, new_carry1, …, new_carryM)

The scanned version applies the layer iteratively to a tensor treating values at the given axis as if they were a list. For example, to calculate all sums of prefixes of a tensor, we can do this:

def add():
  def f(x, carry):
    res = x + carry
    return res, res  # output and new carry are the same
  return tl.Fn('add', f, n_out=2)

# Schematically: Scan(add())([1, 2, 3], 0) gives ([1, 3, 6], 6)

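For readers who want to check the prefix-sum behaviour without Trax, here is a minimal pure-Python sketch of the same idea with one input and one carry. The scan helper below is hypothetical and works on Python lists rather than tensor axes.

```python
def scan(step, inputs, carry):
    """Applies step(x, carry) -> (output, new_carry) along a sequence."""
    outputs = []
    for x in inputs:
        out, carry = step(x, carry)
        outputs.append(out)
    return outputs, carry


def add_step(x, carry):
    res = x + carry
    return res, res  # output and new carry are the same


prefix_sums, final_carry = scan(add_step, [1, 2, 3], 0)
longer_sums, _ = scan(add_step, [1, 2, 3, 4, 5], 0)
```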
__init__(layer, axis=0, n_carry=1, remat=False, mode='train')

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Returns the unique sublayer managed by this layer.


Returns a tuple containing this layer’s state.


Executes this layer as part of a forward pass through the model.


Initializes weights and state for inputs with the given signature.

class trax.layers.combinators.Cond(cond, true, false=None, name=None)

Bases: trax.layers.base.Layer

Applies layers conditionally.

For parameters cond, true, and false, this layer runs the equivalent of true(y) if cond(x) else false(y), where x is the cond.n_in elements from the front of the stack and y is the rest of the stack. Exactly one of the true and false functions is executed, so the layer can be used to conditionally run long computations. The state of the non-executed function is not updated. Note that different branches may be executed on different devices if cond returns different values on them. By default, the false function is the identity.

cond must return exactly one element: a Boolean value. true and false must have the same n_in, and the same n_out.
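The way the stack is divided between cond and the branches can be sketched in plain Python. The run_cond helper and the (fn, n_in) pairs are hypothetical; the sketch assumes single-output branches and uses numbers in place of tensors, so it illustrates only the control flow, not Trax's implementation.

```python
def run_cond(cond, true, false, stack):
    """cond, true and false are (fn, n_in) pairs; stack[0] is the top."""
    cond_fn, cond_n_in = cond
    # x feeds the condition; y is everything else on the stack.
    x, y = stack[:cond_n_in], stack[cond_n_in:]
    # Only the chosen branch is ever called.
    branch_fn, branch_n_in = true if cond_fn(*x) else false
    output = branch_fn(*y[:branch_n_in])
    return [output] + y[branch_n_in:]


is_positive = (lambda flag: flag > 0, 1)
square = (lambda v: v * v, 1)
identity = (lambda v: v, 1)  # plays the role of the default false branch

taken = run_cond(is_positive, square, identity, [1, 3, 7])
skipped = run_cond(is_positive, square, identity, [-1, 3, 7])
```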

__init__(cond, true, false=None, name=None)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Initializes weights and state for inputs with the given signature.


Executes this layer as part of a forward pass through the model.

Parameters: xs – Tensors as required by the branches of this conditional.
Returns: Tensors resulting from running the chosen branch.

trax.layers.combinators.Chunk(layer, chunk_size, pass_unchunkable=True)

Executes layer using batch chunks of size chunk_size to save memory.

trax.layers.combinators.Branch(*layers, name='Branch')

Combinator that applies a list of layers in parallel to copies of inputs.

Each layer in the input list is applied to as many inputs from the stack as it needs, and their outputs are successively combined on the stack.

For example, suppose one has three layers:

  • F: 1 input, 1 output
  • G: 3 inputs, 1 output
  • H: 2 inputs, 2 outputs (h1, h2)

Then Branch(F, G, H) will take 3 inputs and give 4 outputs:

  • inputs: a, b, c
  • outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)

As an important special case, a None argument to Branch acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

  • *layers – List of layers.
  • name – Descriptive name for this layer.

A branch layer built from the given sublayers.
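In contrast to Parallel, every layer in a Branch starts reading from the first input. A small pure-Python sketch of the F, G, H example follows; the run_branch helper and the (fn, n_in, n_out) triples are hypothetical stand-ins for real layers.

```python
def run_branch(layers, inputs):
    """Applies each (fn, n_in, n_out) triple to a copy of the leading inputs."""
    outputs = []
    for fn, n_in, n_out in layers:
        # Every layer reads from the start of the same inputs.
        result = fn(*inputs[:n_in])
        outputs.extend(result if n_out > 1 else [result])
    return outputs


F = (lambda a: -a, 1, 1)
G = (lambda a, b, c: a + b + c, 3, 1)
H = (lambda a, b: (a * b, a - b), 2, 2)

branch_out = run_branch([F, G, H], [1, 2, 3])  # 3 inputs, 4 outputs
```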

trax.layers.combinators.Residual(*layers, shortcut=None)

Wraps a series of layers with a residual connection.

  • *layers – One or more layers, to be applied in series.
  • shortcut – If None (the usual case), the Residual layer computes the element-wise sum of the stack-top input with the output of the layer series. If specified, the shortcut layer applies to a copy of the inputs and (elementwise) adds its output to the output from the main layer series.

A layer representing a residual connection paired with a layer series.

trax.layers.combinators.Select(indices, n_in=None, name=None)

Copies, reorders, or deletes stack elements according to indices.

  • indices – A list or tuple of 0-based indices to select elements relative to the top of the stack.
  • n_in – Number of input elements to pop from the stack, and replace with those specified by indices. If not specified, its value will be calculated as max(indices) + 1.
  • name – Descriptive name for this layer.

Tensors, matching the number selected (n_out = len(indices)). Specifically:

  • n_out = 0: an empty tuple
  • n_out = 1: one tensor (NOT wrapped in a tuple)
  • n_out > 1: a tuple of tensors, with n_out items
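The index-based selection can be modelled on a plain Python list, with index 0 as the top of the stack. The select helper below is a hypothetical sketch, not the Trax implementation; it shows how copying, reordering, and deleting all reduce to a choice of indices and n_in.

```python
def select(indices, stack, n_in=None):
    """Pops n_in items off the stack and pushes popped[i] for each i in indices."""
    if n_in is None:
        n_in = max(indices) + 1
    popped, rest = stack[:n_in], stack[n_in:]
    return [popped[i] for i in indices] + rest


stack = ['a', 'b', 'c', 'd']
swapped = select([1, 0], stack)      # reorder the top two elements
duplicated = select([0, 0], stack)   # copy the top element
dropped = select([], stack, n_in=1)  # delete the top element
```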


Drops the top stack element.


Duplicates (copies) the top element on the data stack.


Swaps the top two stack elements.

trax.layers.combinators.SerialWithSideOutputs(layers, n_side_outputs=1)

Serial layer with side outputs.

This layer makes it easier to manage the stack when layers have side outputs.

In the simplest case of layers with n_in=1, n_out=2 and with n_side_outputs=1, this layer runs the following computation on x:

side_outputs = []
for i in range(len(layers)):
  x, side_output = layers[i](x)
  side_outputs.append(side_output)  # set the side output aside
return [x] + side_outputs

In the general case of layers with variable n_in and n_out and n_side_outputs being a list of N integers, it does the following:

side_outputs = []
for i in range(N):
  res = layers[i](cur_stack)  # remove n_in from stack
  cur_stack.append(res[:n_side_outputs[i]])  # put back some on stack
return cur_stack + side_outputs

  • layers – a list of layers to execute
  • n_side_outputs – an int or a list of ints, how many outputs of each layer to put aside

A layer that performs the above computation.


Flatten lists.


Adds two tensors.


Subtracts the first tensor from the second.


Multiplies two tensors.


Returns a gating layer on a (memory, gate, candidate) tuple.

Final update is memory * gate + (1 - gate) * candidate

This gating equation is also the one used in Highway Networks.
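A worked element-wise instance of the update memory * gate + (1 - gate) * candidate, written in plain Python with lists standing in for tensors. The gate function here is an illustrative helper, not the Trax layer.

```python
def gate(memory, gate_values, candidate):
    """Element-wise memory * gate + (1 - gate) * candidate over lists."""
    return [m * g + (1 - g) * c
            for m, g, c in zip(memory, gate_values, candidate)]


# gate = 1 keeps the memory, gate = 0 takes the candidate, 0.5 averages them.
updated = gate([1.0, 2.0, 3.0], [1.0, 0.0, 0.5], [10.0, 20.0, 30.0])
```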

class trax.layers.combinators.Cache(layer)

Bases: trax.layers.base.Layer

Applies a layer on the first run and returns the outputs on next calls.


Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Returns the unique sublayer managed by this layer.


Returns a tuple containing this layer’s state; may be empty.


Initializes weights and state for inputs with the given signature.


Executes this layer as part of a forward pass through the model.

Parameters: inputs – Tensors required by the sublayer.
Returns: Tensors resulting from running the sublayer the first time.

class trax.layers.combinators.BatchLeadingAxes(layer, n_last_axes_to_keep=1)

Bases: trax.layers.base.Layer

Applies a layer after flattening all but n_last_axes_to_keep to batch.

This can be used to make layers accept an arbitrary number of leading axes (dimensions) as batch. For example, a Convolution layer may normally only operate on tensors of shape [B, W, H, C]. In this case, the layer

BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3)

will operate on any tensor […, W, H, C] and treat the leading axes as batch.

__init__(layer, n_last_axes_to_keep=1)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Returns the unique sublayer managed by this layer.


Executes this layer as part of a forward pass through the model.


Initializes weights and state for inputs with the given signature.

trax.layers.combinators.inputs_from_stack(stack, n)

Returns n inputs from stack.

trax.layers.combinators.outputs_onto_stack(outputs, stack, n)

Returns the new stack after removing n items and pushing outputs there.


Trax convolution layers.

class trax.layers.convolution.Conv(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)

Bases: trax.layers.base.Layer

Layer constructor function for a general convolution layer.

__init__(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

class trax.layers.convolution.CausalConv(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)

Bases: trax.layers.convolution.Conv

Causal (masked) convolution for [batch x time x depth] sequences.

Maintains causality along time axis. Used in language modeling tasks.

__init__(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

trax.layers.convolution.Conv1d(filters, kernel_size, stride=1, padding='VALID', kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)


Core layer types and key functions used by various layers.

class trax.layers.core.Dense(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, use_bfloat16=False)

Bases: trax.layers.base.Layer

A dense (a.k.a. fully-connected, affine) layer.

Dense layers are the prototypical example of a trainable layer, i.e., a layer with trainable weights. Each node in a dense layer computes a weighted sum of all node values from the preceding layer and adds to that sum a node-specific bias term. The full layer computation is expressed compactly in linear algebra as an affine map y = Wx + b, where W is a matrix and y, x, and b are vectors. The layer is trained, or “learns”, by updating the values in W and b.

Less commonly, a dense layer can omit the bias term and be a pure linear map: y = Wx.
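To make the affine map concrete, here is a tiny pure-Python sketch of y = Wx + b for a single input vector, following the notation used above. The dense helper and the example values are illustrative only; the actual layer operates on batched tensors and may store W in a different orientation.

```python
def dense(x, w, b=None):
    """Computes y = Wx + b, or y = Wx when b is None, for one input vector."""
    y = [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]
    if b is not None:
        y = [y_i + b_i for y_i, b_i in zip(y, b)]
    return y


W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, -1.0]]  # maps R^3 to R^2, so n_units would be 2
b = [0.5, -0.5]
x = [1.0, 2.0, 3.0]

affine = dense(x, W, b)  # with bias
linear = dense(x, W)     # pure linear map, as with use_bias=False
```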

__init__(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, use_bfloat16=False)

Returns a dense (fully connected) layer of width n_units.

A dense layer maps collections of R^m vectors to R^n, where n (= n_units) is fixed at layer creation time, and m is set at layer initialization time.

  • n_units – Number of nodes in the layer, also known as the width of the layer.
  • kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
  • bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
  • use_bias – If True, compute an affine map y = Wx + b; else compute a linear map y = Wx.
  • use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.

Executes this layer as part of a forward pass through the model.

Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns: Tensor of same shape and dtype as the input, except the final dimension is the layer’s n_units value.

Randomly initializes this layer’s weights.

Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.

Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.

class trax.layers.core.Embedding(vocab_size, d_feature, use_bfloat16=False, kernel_initializer=<function ScaledInitializer.<locals>.Init>)

Bases: trax.layers.base.Layer

Trainable layer that maps discrete tokens/IDs to vectors.

Embedding layers are commonly used to map discrete data, like words in NLP, into vectors. Here is a canonical example:

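import numpy as np
from trax import layers as tl, shapes
vocab_size = 5
word_ids = np.array([1, 2, 3, 4], dtype=np.int32)  # word_ids < vocab_size
embedding_layer = tl.Embedding(vocab_size, 32)
embedding_layer.init(shapes.signature(word_ids))  # create weights before first use
embedded = embedding_layer(word_ids)  # embedded.shape = (4, 32)
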
__init__(vocab_size, d_feature, use_bfloat16=False, kernel_initializer=<function ScaledInitializer.<locals>.Init>)

Returns an embedding layer with given vocabulary size and vector size.

The layer clips input values (token IDs) to the range [0, vocab_size). That is, negative token IDs all clip to 0 before being mapped to a vector, and token IDs with value vocab_size or greater all clip to vocab_size - 1 before being mapped to a vector.

  • vocab_size – Size of the input vocabulary. The layer will assign a unique vector to each id in range(vocab_size).
  • d_feature – Dimensionality/depth of the output vectors.
  • use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.
  • kernel_initializer – Function that creates (random) initial vectors for the embedding.
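The clipping rule for out-of-range token IDs can be illustrated with a plain-Python lookup. The embed helper and the small table are hypothetical; they only mirror the documented behaviour that negative IDs map to row 0 and IDs of vocab_size or more map to the last row.

```python
def embed(token_ids, table):
    """Looks up rows of table after clipping IDs to the range [0, vocab_size)."""
    vocab_size = len(table)
    clipped = [min(max(t, 0), vocab_size - 1) for t in token_ids]
    return [table[t] for t in clipped]


table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]  # vocab_size = 3, d_feature = 2
vectors = embed([-4, 1, 7], table)  # -4 clips to 0, 7 clips to 2
```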

Returns embedding vectors corresponding to input token IDs.

Parameters: x – Tensor of token IDs.
Returns: Tensor of embedding vectors.

Randomly initializes this layer’s weights.

class trax.layers.core.Dropout(rate=0.0, shared_axes=None, mode='train')

Bases: trax.layers.base.Layer

A layer that stochastically ignores a subset of inputs each training step.

In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).

The shared_axes parameter specifies a list of axes on which the mask is shared: the dropout mask has size 1 on those axes and is broadcast across them. Sharing reduces randomness, but can save memory.

This layer is active only during training (mode='train'). In other circumstances it is a no-op.
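The rescaling of surviving values and the no-op behaviour outside training can be sketched in plain Python. This is an illustrative model using the standard random module and a fixed seed, assuming rate < 1; it does not reproduce Trax's random number handling or the shared_axes option.

```python
import random


def dropout(values, rate, mode='train', seed=0):
    """Zeroes each value with probability rate and rescales the survivors."""
    if mode != 'train' or rate == 0.0:
        return list(values)  # no-op outside training
    rng = random.Random(seed)
    scale = 1.0 / (1.0 - rate)  # compensates for the dropped fraction
    return [v * scale if rng.random() >= rate else 0.0 for v in values]


activations = [1.0] * 10
train_out = dropout(activations, rate=0.5)             # entries are 0.0 or 2.0
eval_out = dropout(activations, rate=0.5, mode='eval')  # unchanged
```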

Originally introduced in the paper “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”.

__init__(rate=0.0, shared_axes=None, mode='train')

Creates a dropout layer with the given target drop rate.

  • rate – Stochastic rate (probability) for dropping an activation value from the preceding layer (setting it to zero).
  • shared_axes – List of axes on which the mask is shared.
  • mode – If ‘train’, this layer will perform dropout; else, it will pass all values through unaltered.

Sets layer-specific internal state.


Executes this layer as part of a forward pass through the model.

Parameters: x – Tensor of activations.
Returns: Tensor of same shape and dtype as the input.

class trax.layers.core.Weights(initializer, shape=(), use_bfloat16=False)

Bases: trax.layers.base.Layer

Learnable weights as a layer.

It takes no input and returns a single tensor: weights.

__init__(initializer, shape=(), use_bfloat16=False)

Returns a learnable tensor of shape shape.

  • initializer – Function taking shape and rng as arguments.
  • shape – Shape of the learnable weights.
  • use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.

Executes this layer as part of a forward pass through the model.

Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns: Tensor with previously specified shape and dtype.

Returns newly initialized weights for this layer.

Weights is a single w tensor with previously specified shape.

Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.

trax.layers.core.PrintShape(n_in=1, msg='')

Prints the shapes of n_in inputs and returns them unchanged.

class trax.layers.core.SummaryImage(name, n_in, num_summaries=5, recover_fn=None)

Bases: trax.layers.base.Layer

A layer receiving a tensor, and adding it to TensorBoard as an image.

It takes an input and returns it unchanged. It stores this input as a state to be used as a metric in TensorBoard. It converts a tensor to a scalar by running a given aggregation function (mean by default). On TensorBoard, results for each device will be reported separately.

__init__(name, n_in, num_summaries=5, recover_fn=None)

Takes a tensor and returns it.

  • name – Name of the metric to be reported.
  • n_in – Number of inputs.
  • num_summaries – Number of images to show.
  • recover_fn – Function for converting a tensor to a displayable image.

Executes this layer as part of a forward pass through the model.

Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns: Tensor with previously specified shape and dtype.

Returns newly initialized weights for this layer.

Weights is a single w tensor with previously specified shape.

Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.

class trax.layers.core.SummaryScalar(name, aggregation_fun=<sphinx.ext.autodoc.importer._MockObject object>)

Bases: trax.layers.base.Layer

A layer receiving a tensor, and adding it to TensorBoard as a scalar.

It takes an input and returns it unchanged. It stores this input as a state to be used as a metric in TensorBoard. It converts a tensor to a scalar by running a given aggregation function (mean by default). On TensorBoard, results for each device will be reported separately.

__init__(name, aggregation_fun=<sphinx.ext.autodoc.importer._MockObject object>)

Takes a tensor and returns it.

  • name – Name of the metric to be reported.
  • aggregation_fun – Aggregation function to be used.

Executes this layer as part of a forward pass through the model.

Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns: Tensor with previously specified shape and dtype.

Returns newly initialized weights for this layer.

Weights is a single w tensor with previously specified shape.

Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.

class trax.layers.core.RandomUniform(min_val=0.0, max_val=1.0, shape=(), dtype=<sphinx.ext.autodoc.importer._MockObject object>, sync=False)

Bases: trax.layers.base.Layer

Layer returning a tensor with random values distributed uniformly.

__init__(min_val=0.0, max_val=1.0, shape=(), dtype=<sphinx.ext.autodoc.importer._MockObject object>, sync=False)

Layer returning a tensor with random values distributed uniformly.

  • min_val – Lower end of uniform distribution.
  • max_val – Upper end of uniform distribution.
  • shape – Shape of the tensor to return. Values are sampled independently.
  • dtype – Type of value to return.
  • sync – Whether to synchronise rng across devices.

Executes this layer as part of a forward pass through the model.

Parameters: xs – Unused tensors.
Returns: Random uniform tensor of the shape and type specified in constructor.

class trax.layers.core.LocallyConnected1d(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID')

Bases: trax.layers.base.Layer

Locally-connected layer for 1D inputs.

The LocallyConnected1d layer applies a different set of filters to each patch of the input. This is similar to applying a convolution layer, except that a locally-connected layer uses a different set of weights for each patch.

The size of each patch is determined by the kernel size. The stride is currently not modifiable and is set to one. This means that for an input of shape (…, L, D), the output shape is (…, L, filters) for paddings ‘SAME’ and ‘WRAP’, and (…, L-kernel_size+1, filters) for padding ‘VALID’, where L is the number of “pixels” or “steps” in the input and D is the size of the embedding.

Note that, since the weights for different patches are not shared, the number of “pixels” or “steps” cannot change after calling init_weights_and_state. This is because each “pixel” is assigned its own set of weights.
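The output shapes stated above can be checked with a small shape calculation. The helper locally_connected_output_shape is hypothetical and only encodes the documented rule for stride 1; it does not touch any weights.

```python
def locally_connected_output_shape(input_shape, filters, kernel_size, padding):
    """Returns the output shape for an input of shape (..., L, D) with stride 1."""
    *leading, length, _ = input_shape
    if padding == 'VALID':
        length = length - kernel_size + 1
    elif padding not in ('SAME', 'WRAP'):
        raise ValueError(f'Unknown padding: {padding}')
    return (*leading, length, filters)


valid_shape = locally_connected_output_shape((8, 20, 16), 4, 3, 'VALID')
same_shape = locally_connected_output_shape((8, 20, 16), 4, 3, 'SAME')
```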

__init__(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID')

Returns a locally-connected conv-like layer.

  • filters – Number of output filters in the convolution.
  • kernel_size – Length of the convolution window. Must be an odd number.
  • kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
  • bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
  • use_bias – If True, the layer uses a bias vector.
  • padding – The type of padding to use; must be ‘VALID’, ‘SAME’, or ‘WRAP’.

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor of same shape and dtype as the input, except the final dimension is the layer’s filters value, and the second to last dimension is shrunk if ‘VALID’ padding is used with kernel_size bigger than one.

Randomly initializes this layer’s weights.

Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on.

Returns a layer that combines one or more trailing axes of a tensor.

Flattening keeps all the values of the input tensor, but reshapes it by collapsing one or more trailing axes into a single axis. For example, a Flatten(n_axes_to_keep=2) layer would map a tensor with shape (2, 3, 5, 7, 11) to the same values with shape (2, 3, 385).

Parameters:n_axes_to_keep – Number of leading axes to leave unchanged when reshaping; collapse only the axes after these.
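The shape example above can be reproduced with a plain NumPy reshape. This is a sketch of the behavior described, not the Trax layer itself; the helper name and its default of n_axes_to_keep=1 are assumptions.

```python
import numpy as np

def flatten(x, n_axes_to_keep=1):
    """Collapses all axes after the first n_axes_to_keep into one axis."""
    kept = x.shape[:n_axes_to_keep]
    return x.reshape(kept + (-1,))

x = np.arange(2 * 3 * 5 * 7 * 11).reshape(2, 3, 5, 7, 11)
y = flatten(x, n_axes_to_keep=2)   # shape (2, 3, 385), since 5 * 7 * 11 = 385
```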

Returns a layer that applies log softmax along one tensor axis.

Note that the implementation actually computes x - LogSumExp(x), which is mathematically equal to LogSoftmax(x).

LogSoftmax acts on a group of values and normalizes them to look like a set of log probability values. (Probability values must be non-negative, and as a set must sum to 1. A group of log probability values can be seen as the natural logarithm function applied to a set of probability values.)

Parameters:axis – Axis along which values are grouped for computing log softmax.
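A minimal NumPy sketch of the x - LogSumExp(x) formulation follows. Subtracting the per-group maximum before exponentiating is a common stabilization step added here as an assumption; the text above does not say how the layer handles large values.

```python
import numpy as np

def log_sum_exp(x, axis=-1):
    """Computes log(sum(exp(x))) along axis, shifted by the max for stability."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def log_softmax(x, axis=-1):
    """Returns x - LogSumExp(x), i.e. log probabilities along axis."""
    return x - log_sum_exp(x, axis=axis)

x = np.array([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
log_p = log_softmax(x)
probs = np.exp(log_p)   # each row is non-negative and sums to 1
```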

Returns a layer that computes log(sum(exp(x))) along one tensor axis.

Parameters:axis – Axis along which values are grouped for computing log-sum-exp.

Returns a layer that applies softmax along one tensor axis.

Softmax acts on a group of values and normalizes them to look like a set of probability values. (Probability values must be non-negative, and as a set must sum to 1.)

Parameters:axis – Axis along which values are grouped for computing softmax.

Returns a layer that changes the dtype of a tensor to float32.

trax.layers.core.Mean(axis=-1, keepdims=False)

Returns a layer that computes mean values using one tensor axis.

Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group. The resulting values can either remain in their own size 1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.

  • axis – Axis along which values are grouped for computing a mean.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
trax.layers.core.Min(axis=-1, keepdims=False)

Returns a layer that applies min along one tensor axis.

  • axis – Axis along which values are grouped for computing minimum.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
trax.layers.core.Max(axis=-1, keepdims=False)

Returns a layer that applies max along one tensor axis.

  • axis – Axis along which values are grouped for computing maximum.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
trax.layers.core.Sum(axis=None, keepdims=False)

Returns a layer that computes sums using one tensor axis.

Sum uses one tensor axis to form groups of values and replaces each group with the sum of that group. The resulting sum values can either remain in their own size 1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.

  • axis – Axis along which values are grouped for computing a sum; if None, compute sum over all elements in tensor.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
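The effect of axis and keepdims on the output shape, shared by Mean, Min, Max, and Sum, can be seen with the corresponding NumPy reductions. The NumPy calls stand in for the layers here as an illustration.

```python
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

mean_dropped = np.mean(x, axis=-1)                 # shape (2,): rank lowered by one
mean_kept = np.mean(x, axis=-1, keepdims=True)     # shape (2, 1): size 1 axis kept
total = np.sum(x, axis=None)                       # sum over all elements
```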

Returns a layer that thresholds inputs to yield outputs in {0, 1}.


Returns a layer that calculates argmax along the given axis.


Returns a layer that computes the element-wise negation of a tensor.


Returns an identity layer with a stop gradient.

trax.layers.core.one_hot(x, n_categories, dtype=<sphinx.ext.autodoc.importer._MockObject object>)

Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).
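The added dimension can be illustrated with a small NumPy equivalent. This is a sketch, assuming float32 as the output dtype; it is not the library function.

```python
import numpy as np

def one_hot(x, n_categories, dtype=np.float32):
    """Maps an int array of n dims to a one-hot array of n+1 dims."""
    x = np.asarray(x)
    # Compare each category id against 0..n_categories-1 along a new last axis.
    return (x[..., np.newaxis] == np.arange(n_categories)).astype(dtype)

encoded = one_hot([0, 2], n_categories=3)   # shape (2, 3)
```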

trax.layers.core.log_softmax(x, axis=-1)

Transforms activation vectors to log-probability vectors.

Log probability vectors are derived by, in effect, applying softmax to raw activation vectors and then applying log element-wise. The actual implementation uses a mathematically valid simplification of this.

  • x – An ndarray with activation vectors along the given axis.
  • axis – Axis along which values are grouped for computing log softmax.

An ndarray containing log-probability vectors derived from the raw activation vectors in x.

trax.layers.core.log_gaussian_pdf(x, mu, sigma)

Returns log N(x | mu, sigma).

  • x – <tbd>
  • mu – <tbd>
  • sigma – <tbd>
trax.layers.core.log_gaussian_diag_pdf(x, mu, diag_sigma)

Returns log N(x | mu, eye(diag_sigma)).

  • x – <tbd>
  • mu – <tbd>
  • diag_sigma – <tbd>
trax.layers.core.multigaussian_loss(preds, targets, ngauss=1)

Returns a mixture of gaussians loss.

  • preds – <tbd>
  • targets – <tbd>
  • ngauss – <tbd>
trax.layers.core.logsoftmax_sample(log_probs, temperature=1.0)

Returns a sample from a log-softmax output, with temperature.

  • log_probs – Logarithms of probabilities (often coming from LogSoftmax)
  • temperature – For scaling before sampling (1.0 = default, 0.0 = pick argmax)
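One way to obtain the behavior described (temperature 1.0 samples from the distribution, temperature 0.0 picks the argmax) is the Gumbel-max trick. The sketch below assumes that technique and adds a hypothetical rng argument for reproducibility; it is not presented as the library's code.

```python
import numpy as np

def logsoftmax_sample(log_probs, temperature=1.0, rng=None):
    """Samples category indices from log probabilities via Gumbel noise."""
    rng = np.random.default_rng() if rng is None else rng
    # Keep u strictly inside (0, 1) so both logarithms stay finite.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=np.shape(log_probs))
    gumbel = -np.log(-np.log(u))
    # temperature scales the noise: 0.0 removes it, leaving a plain argmax.
    return np.argmax(log_probs + temperature * gumbel, axis=-1)

log_probs = np.log(np.array([0.1, 0.7, 0.2]))
greedy = logsoftmax_sample(log_probs, temperature=0.0)
sampled = logsoftmax_sample(log_probs, temperature=1.0, rng=np.random.default_rng(0))
```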


Trax initializers.


Loads parameters from .npy file.


Returns an initializer for random normal coefficients.


Returns an initializer for random uniform coefficients.

trax.layers.initializers.ScaledInitializer(out_dim, in_dim, scale, mode, distribution)

Returns an initializer that adjusts its scale based on weight shapes.

trax.layers.initializers.GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random Glorot-scaled coefficients.

trax.layers.initializers.GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random uniform Glorot-scaled coefficients.

trax.layers.initializers.LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random LeCun-scaled coefficients.

trax.layers.initializers.LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random uniform LeCun-scaled coefficients.

trax.layers.initializers.KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.0)

Returns an initializer for random Kaiming-scaled coefficients.

trax.layers.initializers.KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.0)

Returns an initializer for random uniform Kaiming-scaled coefficients.


Returns an orthogonal initializer.

trax.layers.initializers.AtariConvInit(kernel_shape, rng, dtype=<sphinx.ext.autodoc.importer._MockObject object>)

The standard init for Conv layers and Atari.


Layers for computing loss functions and evaluation metrics.

A metric layer computes a scalar value from two or three ndarray inputs:

  • model outputs: Batch of predicted values (typically vectors).
  • targets: Batch of target values (e.g., categories or vectors).
  • weights: Float values that allow for uneven weighting of batch items, sequence positions, or vector components when computing an overall scalar value for the batch.

Most metric computations take into account the items that make up a batch. For each item in a batch, a raw metric value is computed by comparing (item-wise) the model output to the target value. These item-wise values are then combined into a single scalar for the batch by a function such as sum, average, or weighted-average. For example:

  • CategoryAccuracy: Treat model output as vectors whose components correspond to the possible categories; measure a vector as correct (value 1) if its largest component is the target category, else as incorrect (value 0). The accuracy for the batch is then the average across vectors of these 1’s and 0’s.
  • CategoryCrossEntropy: Treat model output and target values as the source of two probability distributions; measure the cross-entropy of the model’s predicted distribution relative to the (assumed true) target distribution. The scalar value for the batch is then the average of the item-wise cross-entropy values.

Returns a layer that computes category prediction accuracy.

The layer takes two inputs:

  • A batch of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index (category) having the highest probability.
  • A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\).

The predicted category from each vector is the index of the highest-valued vector component. The layer returns the accuracy of these predictions averaged over the batch.
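The argmax-then-average computation described above can be sketched in NumPy as follows. The function name and sample values are illustrative assumptions.

```python
import numpy as np

def category_accuracy(outputs, targets):
    """Fraction of vectors whose largest component is the target category."""
    predictions = np.argmax(outputs, axis=-1)
    return np.mean(predictions == targets)

outputs = np.array([[0.1, 2.0, -1.0],
                    [3.0, 0.0, 0.5],
                    [0.2, 0.1, 0.9],
                    [1.0, 4.0, 2.0]])
targets = np.array([1, 0, 1, 1])
acc = category_accuracy(outputs, targets)   # predictions are [1, 0, 2, 1]
```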


Returns a layer that computes a weighted category prediction accuracy.

The layer takes three inputs:

  • A batch of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index (category) having the highest probability.
  • A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
  • A batch of weights, which matches or can be broadcast to match the shape of the target ndarray. This arg can give uneven weighting to different items in the batch (depending, for instance, on the item’s target category).

The predicted category from each vector is the index of the highest-valued vector component. The layer returns a weighted average accuracy of these predictions.


Returns a layer that computes cross-entropy from activations and integers.

The layer takes two inputs:

  • A batch of activation vectors. The components in a given vector should be pre-softmax activations (mappable to a probability distribution via softmax). For performance reasons, the softmax and cross-entropy computations are combined inside the layer.
  • A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.

To compute cross-entropy per batch item, the layer derives probability distributions:

  • from model output (vectors): \(\ q = \text{softmax}(v)\)
  • from target categories (integers): \(\ p = \text{one_hot}(n)\) or \(p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}\), where \(\varepsilon\) is the label smoothing factor.

(The conversion of integer category targets to one-hot vectors amounts to assigning all the probability mass to the target category.) Cross-entropy per batch item is computed between the resulting distributions:

\[\text{cross_entropy} = - \sum_{i=0}^{N-1} p_i \log q_i\]

The layer returns the average of these cross-entropy values over all items in the batch.

Parameters:label_smoothing – Creates soft targets if provided. Must be between 0 and 1.
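The formulas above translate fairly directly into NumPy. The sketch below is an unfused illustration of the math, with max-subtraction added for numerical stability as an assumption; the layer itself combines softmax and cross-entropy internally.

```python
import numpy as np

def category_cross_entropy(outputs, targets, label_smoothing=None):
    """Mean over the batch of -sum_i p_i * log q_i, with q = softmax(outputs)."""
    n = outputs.shape[-1]
    shifted = outputs - np.max(outputs, axis=-1, keepdims=True)
    log_q = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    p = np.eye(n)[targets]                      # one_hot(n)
    if label_smoothing:
        p = (1.0 - label_smoothing) * p + label_smoothing / n
    return np.mean(-np.sum(p * log_q, axis=-1))

outputs = np.log(np.array([[0.5, 0.25, 0.25]]))   # softmax recovers these probabilities
targets = np.array([0])
loss = category_cross_entropy(outputs, targets)   # equals -log(0.5)
```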

Returns a layer like CategoryCrossEntropy, with weights as third input.

The layer takes three inputs:

  • A batch of activation vectors. The components in a given vector should be pre-softmax activations (mappable to a probability distribution via softmax). For performance reasons, the softmax and cross-entropy computations are combined inside the layer.
  • A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
  • A batch of weights, which matches or can be broadcast to match the shape of the target ndarray. This arg can give uneven weighting to different items in the batch (depending, for instance, on the item’s target category).

The layer returns the weighted average of these cross-entropy values over all items in the batch.

Parameters:label_smoothing – Creates soft targets if provided. Must be between 0 and 1.

Returns a layer that computes cross-entropy for binary classification.

The layer takes two inputs:

  • A batch of activation values; each batch item \(x\) is a float in \((-\infty, \infty)\).
  • A batch of binary targets; each target \(t\) is an integer in \(\{0, 1\}\).

The layer maps each activation value into the range \((0, 1)\), interpreted as the model-predicted probability that item’s category is 1:

\[q = \frac 1 {1 + e^{-x}} \ \ \text{[model-predicted probability]}\]

and computes cross-entropy (per batch item) by treating the target category as having probability 1:

\[\begin{split}\text{cross_entropy} = \left\{ \begin{array}{cl} - \log q & \text{if}\ t = 1, \\ - \log (1 - q) & \text{if}\ t = 0. \end{array} \right.\end{split}\]

The layer returns the average of these cross-entropy values over all items in the batch.
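A direct NumPy transcription of the two formulas above is shown below. It is a sketch for moderate activation values and takes no special care with very large magnitudes, where log(0) could occur.

```python
import numpy as np

def binary_cross_entropy(activations, targets):
    """Mean of -log q for t = 1 and -log(1 - q) for t = 0, with q = sigmoid(x)."""
    q = 1.0 / (1.0 + np.exp(-activations))
    per_item = np.where(targets == 1, -np.log(q), -np.log(1.0 - q))
    return np.mean(per_item)

activations = np.array([np.log(3.0), 0.0])   # q = [0.75, 0.5]
targets = np.array([1, 0])
loss = binary_cross_entropy(activations, targets)
```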


Returns a layer that computes sequence prediction accuracy with masking.

This layer type is intended for variable length sequences, especially text, represented as a batch of fixed-length sequences via padding for unused positions.

The layer takes three inputs:

  • A batch of sequences of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index having the highest probability. In text modeling, the index represents a token id from a predetermined token vocabulary (or padding).
  • A batch of target integer sequences, with values in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality. In text modeling, these sequences typically represent token ids from a predetermined token vocabulary (or padding).
  • A batch of weights/masks, which matches or can be broadcast to match the shape of the target ndarray. This arg is used to give weight 0 to padding positions, which masks those positions out of the calculation. Only the zero/non-zero distinction matters; all non-zero values are treated alike as signaling non-masked (i.e., valid/in-use) positions.

The predicted integer value for each sequence position is the index of the highest-valued component of the position’s vector. A predicted integer sequence is judged correct if it matches the target integer sequence in all non-zero-weighted positions. The layer returns the accuracy of predicted sequences averaged over the batch.
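The all-positions-must-match rule can be sketched in NumPy as follows. The function name and sample data are illustrative assumptions; the one-hot outputs simply make the predicted ids easy to read.

```python
import numpy as np

def masked_sequence_accuracy(outputs, targets, weights):
    """Fraction of sequences correct at every position with non-zero weight."""
    predictions = np.argmax(outputs, axis=-1)               # (batch, length)
    position_ok = (predictions == targets) | (weights == 0)  # masked positions pass
    sequence_ok = np.all(position_ok, axis=-1)
    return np.mean(sequence_ok)

predicted_ids = np.array([[1, 2, 0], [1, 1, 0]])
outputs = np.eye(4)[predicted_ids]            # (batch=2, length=3, vocab=4)
targets = np.array([[1, 2, 3], [1, 2, 0]])
weights = np.array([[1, 1, 0], [1, 1, 0]])    # last position is padding
acc = masked_sequence_accuracy(outputs, targets, weights)
```

In this example the first sequence differs from its target only at a masked position and counts as correct, while the second has a mismatch at an unmasked position and counts as incorrect.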


Returns a layer that computes mean category prediction accuracy.

DEPRECATED; use WeightedCategoryAccuracy instead.

Parameters:classifier – Layer that transforms activation vectors into category predictions.

Returns a layer that computes mean sequence prediction accuracy.

DEPRECATED; use MaskedSequenceAccuracy instead.

Parameters:classifier – Layer that transforms activation vectors into category predictions.

Returns a layer that outputs multiclass prediction-target cross-entropy.

DEPRECATED; refactor to use WeightedCategoryCrossEntropy or CategoryCrossEntropy instead.

(CrossEntropyLoss by itself does not compute cross-entropy. In older code, this layer had to be preceded by LogSoftmax, and the two layers together did the work of converting category information to probability distributions and computing the cross-entropy between those distributions. All this is now done by WeightedCategoryCrossEntropy.)


Mean prediction-target cross-entropy for multiclass classification.


Returns a layer that outputs binary prediction-target cross-entropy.

DEPRECATED; refactor to use BinaryCrossEntropy instead. (The newer BinaryCrossEntropy does not use weights, so refactor accordingly. Unless and until clear motivating use cases arise, the library will not include a binary cross-entropy function with weights.)


Returns a layer that computes an L2-like loss for one batch.

The layer takes three inputs:

  • Model output from one batch, an ndarray of float-valued elements.
  • A batch of element-wise target values, which matches the shape of the model output.
  • A batch of weights, which matches the shape of the model output.

The layer returns a weighted average of element-wise squared error terms \((y_i - t_i)^2\).


Returns a layer that computes a weighted, smoothed L1 loss for one batch.

The layer takes three inputs:

  • Model output from one batch, an ndarray of float-valued elements.
  • A batch of element-wise target values, which matches the shape of the model output.
  • A batch of weights, which matches the shape of the model output.

The layer computes a “smooth” L1 loss (a.k.a. Huber loss), for model output float \(y_i\) and target float \(t_i\):

\[\begin{split}\text{output} = \left\{ \begin{array}{cl} \frac 1 2 (y_i - t_i)^2, & \text{if}\ |y_i - t_i| < 1, \\ |y_i - t_i| - \frac 1 2, & \text{otherwise}. \end{array} \right.\end{split}\]

The layer returns a weighted average of these element-wise values.
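The piecewise definition above can be written with np.where. In this sketch the weighted average is taken as sum(weights * values) / sum(weights), which is an assumption about the normalization.

```python
import numpy as np

def smooth_l1_loss(outputs, targets, weights):
    """Weighted average of the element-wise smooth L1 terms."""
    abs_diff = np.abs(outputs - targets)
    per_element = np.where(abs_diff < 1.0, 0.5 * abs_diff ** 2, abs_diff - 0.5)
    return np.sum(per_element * weights) / np.sum(weights)

outputs = np.array([0.5, 3.0])
targets = np.array([0.0, 0.0])
weights = np.array([1.0, 1.0])
loss = smooth_l1_loss(outputs, targets, weights)   # terms are 0.125 and 2.5
```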

trax.layers.metrics.MacroAveragedFScore(beta=1.0, initial_category_index=0)

Returns a layer that computes a macro-averaged F-score.

The macro-averaged F-score summarizes how well the classifier’s k predictions align with the observed/gold instances of k. It additionally treats all classes equally, regardless of their size.

  • beta – a parameter that determines the weight of recall in the F-score.
  • initial_category_index – an index of the initial category.

The layer takes two inputs:

  • Model output from one batch, an ndarray of float-valued elements.
  • A batch of element-wise target values, which matches the shape of the model output.

The layer returns a macro-averaged F-score across all the classes.

trax.layers.metrics.WeightedFScore(beta=1.0, initial_category_index=0)

Returns a layer that computes a weighted F-score.

The weighted F-score summarizes how well the classifier’s k predictions align with the observed/gold instances of k. It additionally weights the summary by the number of observed/gold and predicted examples in each class.

  • beta – a parameter that determines the weight of recall in the F-score.
  • initial_category_index – an index of the initial category.

The layer takes two inputs:

  • Model output from one batch, an ndarray of float-valued elements.
  • A batch of element-wise target values, which matches the shape of the model output.

The layer returns a weighted F-score across all the classes.


Returns a layer that computes a weighted sum of the given values.


Sum of prediction-target cross entropies for multiclass classification.


Sum of prediction-target cross entropies for binary classification.


Trax normalization layers.

class trax.layers.normalization.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train')

Bases: trax.layers.base.Layer

Layer that performs batch normalization.

In training, batch normalization keeps smoothed cumulative statistics across batches of input data and normalizes each new batch so that its components have roughly zero mean and unit variance (before any learned scale and offset). In eval or inference, a BatchNorm instance uses its stored mean and variance to approximately normalize each new batch of data.

See the original batch normalization paper for the presentation and motivation of the technique.
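The split between training and eval behavior can be illustrated with a small NumPy class. This is a simplified sketch: it normalizes over axis 0 only, omits the learned center and scale parameters, and uses an exponential moving average for the stored statistics, all of which are assumptions rather than a description of the Trax implementation.

```python
import numpy as np

class BatchNormSketch:
    """Normalizes over axis 0 and tracks running statistics (illustrative)."""

    def __init__(self, n_features, epsilon=1e-5, momentum=0.999):
        self.epsilon = epsilon
        self.momentum = momentum
        self.running_mean = np.zeros(n_features)
        self.running_var = np.ones(n_features)

    def __call__(self, x, mode="train"):
        if mode == "train":
            # Use statistics of the current batch and update the stored ones.
            mean, var = x.mean(axis=0), x.var(axis=0)
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mean
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            # Eval/inference: rely on the stored statistics only.
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.epsilon)

bn = BatchNormSketch(n_features=2, momentum=0.5)
batch = np.array([[1.0, 10.0], [3.0, 30.0]])
train_out = bn(batch, mode="train")
eval_out = bn(batch, mode="eval")
```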

__init__(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train')

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes batch normalization as part of a forward pass in the model.


Helper to initialize batch norm weights and state.

class trax.layers.normalization.LayerNorm(center=True, epsilon=1e-06)

Bases: trax.layers.base.Layer

Layer normalization.

__init__(center=True, epsilon=1e-06)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
class trax.layers.normalization.FilterResponseNorm(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001)

Bases: trax.layers.base.Layer

Filter Response Normalization layer without Threshold Linear Unit.


__init__(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


Trax pooling layers.

trax.layers.pooling.MaxPool(pool_size=(2, 2), strides=None, padding='VALID')

Reduces each multi-dimensional window to the max of the window’s values.

Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).

  • pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
  • strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
  • padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the selection of max values.

N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the max value from that window. An output array has the same number of dimensions as its input, but has fewer elements.
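For a 4-dimensional input of shape (batch, height, width, channels), ‘VALID’ max pooling can be sketched with explicit loops in NumPy. The function name is hypothetical and the loop form favors clarity over speed; it is not how the layer is implemented.

```python
import numpy as np

def max_pool_2d_valid(x, pool_size=(2, 2), strides=(1, 1)):
    """Max over each full window of the two middle axes; partial windows dropped."""
    batch, height, width, channels = x.shape
    ph, pw = pool_size
    sh, sw = strides
    out_h = (height - ph) // sh + 1
    out_w = (width - pw) // sw + 1
    out = np.empty((batch, out_h, out_w, channels), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, i * sh:i * sh + ph, j * sw:j * sw + pw, :]
            out[:, i, j, :] = window.max(axis=(1, 2))
    return out

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
dense = max_pool_2d_valid(x)                      # strides of 1: shape (1, 3, 3, 1)
halved = max_pool_2d_valid(x, strides=(2, 2))     # shape (1, 2, 2, 1)
```

Replacing window.max with window.sum or window.mean gives the analogous ‘VALID’ sum and average pooling.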

trax.layers.pooling.SumPool(pool_size=(2, 2), strides=None, padding='VALID')

Reduces each multi-dimensional window to the sum of the window’s values.

Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).

  • pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
  • strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
  • padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the computation of sums.

N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the sum of values in that window. An output array has the same number of dimensions as its input, but has fewer elements.

trax.layers.pooling.AvgPool(pool_size=(2, 2), strides=None, padding='VALID')

Reduces each multi-dimensional window to the mean of the window’s values.

Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).

  • pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
  • strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
  • padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed but is not counted in the computation of averages.

N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the mean of values in that window. An output array has the same number of dimensions as its input, but has fewer elements.


Layers that can run in reverse to compute inputs from outputs.

Reversible layers reduce the memory required for backpropagation-based training, especially for deep networks. In a series of reversible layers, input activations from a forward pass don’t need to be stored: they can be reconstructed on the backward pass, layer by layer, from outputs to inputs.

See, e.g., “The Reversible Residual Network: Backpropagation Without Storing Activations” and “Reformer: The Efficient Transformer”.
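The idea of reconstructing inputs from outputs can be illustrated with the additive coupling used in reversible residual networks. The sketch below uses placeholder functions f and g in place of real sublayers; the names are hypothetical and the code is independent of the Trax classes in this module.

```python
import numpy as np

def f(x):
    return np.tanh(x)      # stand-in for an arbitrary sublayer

def g(x):
    return 0.5 * x         # stand-in for a second sublayer

def coupling_forward(x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def coupling_reverse(y1, y2):
    # Undo the two additions in the opposite order; f and g need no inverse.
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

x1 = np.array([0.3, -1.2, 2.0])
x2 = np.array([1.5, 0.0, -0.7])
y1, y2 = coupling_forward(x1, x2)
r1, r2 = coupling_reverse(y1, y2)   # recovers x1 and x2 up to rounding
```

Since the inputs are recomputed from the outputs, the forward activations do not have to be kept in memory for the backward pass.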

class trax.layers.reversible.ReversibleLayer(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Bases: trax.layers.base.Layer

Reversible Layer.

reverse(output, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

reverse_and_grad(output, grad, weights=(), state=(), new_state=(), rng=None)

Backward pass: computes the inverse of a layer and propagates gradients.

While you may choose to only implement reverse, some layers implement this function directly as computation may be shared between reversing and computing gradients.

  • output – Output activations; can be a (possibly nested) tuple.
  • grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
  • weights – layer weights
  • state – start state
  • new_state – updated state computed by the forward pass
  • rng – Single-use random number generator (JAX PRNG key).

A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.


Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng)

Custom backward pass to propagate gradients in a custom way.

  • inputs – Input tensors; can be a (possibly nested) tuple.
  • output – The result of running this layer on inputs.
  • grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
  • weights – This layer’s weights.
  • state – This layer’s state prior to the current forward pass.
  • new_state – This layer’s state after the current forward pass.
  • rng – Single-use random number generator (JAX PRNG key).

The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.

class trax.layers.reversible.ReversibleConcatenatePair

Bases: trax.layers.reversible.ReversibleLayer

Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on last axis.
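The mapping and its inverse can be sketched in NumPy. Passing the depth of x explicitly to the reverse function is an assumption made for this illustration; how the layer itself determines the split point is not stated here.

```python
import numpy as np

def concat_pair_forward(x, y):
    """(x, y) -> ([x, y], [x, y]), concatenating on the last axis."""
    xy = np.concatenate([x, y], axis=-1)
    return xy, xy

def concat_pair_reverse(outputs, x_depth):
    """Recovers (x, y) by splitting the first output on the last axis."""
    xy, _ = outputs
    return xy[..., :x_depth], xy[..., x_depth:]

x = np.array([[1.0, 2.0]])
y = np.array([[3.0, 4.0]])
outputs = concat_pair_forward(x, y)
x_back, y_back = concat_pair_reverse(outputs, x_depth=x.shape[-1])
```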


Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(outputs, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

class trax.layers.reversible.ReversibleSelect(indices, n_in=None, name=None)

Bases: trax.layers.reversible.ReversibleLayer

Reversible version of the Select combinator.

__init__(indices, n_in=None, name=None)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(outputs, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.
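
A selection can only be undone if every input appears somewhere in indices. The following plain-Python sketch shows that idea; the helper names are hypothetical and the error handling is an assumption, not a description of the Trax code.

```python
def select(xs, indices):
    """Picks elements of the tuple xs in the order given by indices."""
    return tuple(xs[i] for i in indices)

def reverse_indices(indices, n_in):
    """For each input position, finds where it ended up in the output."""
    rev = []
    for i in range(n_in):
        if i not in indices:
            raise ValueError(f"input {i} is dropped, so the select is not reversible")
        rev.append(indices.index(i))
    return rev

indices = [2, 0, 1]
xs = ("a", "b", "c")
ys = select(xs, indices)
rev = reverse_indices(indices, n_in=3)
xs_rec = select(ys, rev)
```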

class trax.layers.reversible.ReversibleReshape(shape1, shape2, n_in=1)

Bases: trax.layers.reversible.ReversibleLayer

Reversible reshaping layer.

__init__(shape1, shape2, n_in=1)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(outputs, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.
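
A reshape is reversible because the original shape can always be restored. The NumPy sketch below assumes that shape1 describes the trailing input dimensions and shape2 the trailing output dimensions; this reading of the arguments, and the helper name, are assumptions rather than documented behavior.

```python
import numpy as np

def reshape_trailing(x, old_shape, new_shape):
    """Reshapes the trailing dims of x from old_shape to new_shape."""
    old_shape, new_shape = tuple(old_shape), tuple(new_shape)
    n_lead = x.ndim - len(old_shape)
    if n_lead < 0 or x.shape[n_lead:] != old_shape:
        raise ValueError(f"trailing shape {x.shape} does not end with {old_shape}")
    return x.reshape(x.shape[:n_lead] + new_shape)

shape1, shape2 = (3, 4), (12,)
x = np.arange(24).reshape(2, 3, 4)
y = reshape_trailing(x, shape1, shape2)      # forward: shape1 -> shape2
x_rec = reshape_trailing(y, shape2, shape1)  # reverse: shape2 -> shape1
```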

class trax.layers.reversible.ReversiblePrintShape(n_in=1, msg='')

Bases: trax.layers.reversible.ReversibleLayer

Reversible PrintShape for debugging reversible serial layers.

__init__(n_in=1, msg='')

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(outputs, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

class trax.layers.reversible.ReversibleSerial(*layers)

Bases: trax.layers.reversible.ReversibleLayer, trax.layers.combinators.Serial

A reversible version of tl.Serial (requires reversible sub-layers).


Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.
reverse(output, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

reverse_and_grad(output, grad, weights=(), state=(), new_state=(), rng=None)

Backward pass: computes the inverse of a layer and propagates gradients.

While you may choose to only implement reverse, some layers implement this function directly as computation may be shared between reversing and computing gradients.

  • output – Output activations; can be a (possibly nested) tuple.
  • grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
  • weights – layer weights
  • state – start state
  • new_state – updated state computed by the forward pass
  • rng – Single-use random number generator (JAX PRNG key).

A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.
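
To make the return structure concrete, here is a toy sketch of reversing a serial stack while propagating gradients. ToyAffine and serial_reverse_and_grad are hypothetical stand-ins (scalar layers with hand-written gradients), not Trax classes; the point is only that sublayers are visited in reverse order and that each one returns its reconstructed input together with the gradients.

```python
class ToyAffine:
    """Hypothetical reversible layer: y = a * x + b with weights (a, b), a != 0."""

    def forward(self, x, weights):
        a, b = weights
        return a * x + b

    def reverse_and_grad(self, output, grad, weights):
        a, b = weights
        x = (output - b) / a             # reconstruct the input
        x_grad = grad * a                # gradient signal for the input
        weights_grad = (grad * x, grad)  # gradients for a and b
        return x, (x_grad, weights_grad)

def serial_reverse_and_grad(layers, output, grad, weights):
    """Walks the stack backwards, reusing each reconstructed input."""
    weights_grads = []
    for layer, w in zip(reversed(layers), reversed(weights)):
        output, (grad, w_grad) = layer.reverse_and_grad(output, grad, w)
        weights_grads.append(w_grad)
    return output, (grad, tuple(reversed(weights_grads)))

layers = [ToyAffine(), ToyAffine()]
weights = [(2.0, 1.0), (4.0, -3.0)]
y = 5.0
for layer, w in zip(layers, weights):
    y = layer.forward(y, w)
x, (x_grad, w_grads) = serial_reverse_and_grad(layers, y, 1.0, weights)
```

No intermediate activations are stored during the forward loop; they are recomputed from the final output, which is the memory saving reversible layers aim for.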

class trax.layers.reversible.ReversibleHalfResidual(*residual_layers, attention_layer=None, name=None)

Bases: trax.layers.reversible.ReversibleLayer

Half of a RevNet-style residual that optionally performs attention.

When attention_layer is None, this layer has the signature

[accumulator, *context] -> [accumulator + f(context), *context]

The attention_layer must be an instance of EfficientAttentionBase or one of its subclasses, or None.

Attention is special-cased for the following two reasons:

  • LSH attention needs to save bucket assignments from the forward pass to the backward pass, for training stability. This requires special-casing it.
  • We can call attention_layer.forward_and_or_backward to compute its output (needed for inverting a reversible residual layer) while simultaneously performing the backward pass. Sharing computation between these two operations improves training speed.
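
For the case without attention, the signature [accumulator, *context] -> [accumulator + f(context), *context] can be inverted by subtracting f(context) again, since the context passes through unchanged. A minimal NumPy sketch, with a single context tensor and np.tanh as a hypothetical stand-in for the residual layers f:

```python
import numpy as np

def half_residual_forward(accumulator, context, f):
    """[accumulator, context] -> [accumulator + f(context), context]."""
    return accumulator + f(context), context

def half_residual_reverse(output, context, f):
    """Recomputes f(context) and subtracts it to recover the accumulator."""
    return output - f(context), context

f = np.tanh  # hypothetical stand-in for the residual sublayers
acc = np.array([1.0, 2.0])
ctx = np.array([0.5, -0.5])
out, ctx_out = half_residual_forward(acc, ctx, f)
acc_rec, _ = half_residual_reverse(out, ctx_out, f)
```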
__init__(*residual_layers, attention_layer=None, name=None)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(output, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

reverse_and_grad(output, ct, weights=(), state=(), new_state=(), rng=None)

Backward pass: computes the inverse of a layer and propagates gradients.

While you may choose to only implement reverse, some layers implement this function directly as computation may be shared between reversing and computing gradients.

  • output – Output activations; can be a (possibly nested) tuple.
  • ct – Gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
  • weights – layer weights
  • state – start state
  • new_state – updated state computed by the forward pass
  • rng – Single-use random number generator (JAX PRNG key).

A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.


Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


Implementations of common recurrent neural network cells (RNNs).

class trax.layers.rnn.LSTMCell(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

LSTM Cell.

For a nice overview of the motivation and (i, o, f) gates, see this tutorial:

See this paper for a description and detailed study of all gate types:

__init__(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

Makes zeros of shape like x but removing the length (axis 1).

trax.layers.rnn.LSTM(n_units, mode='train')

LSTM running on axis 1.

class trax.layers.rnn.GRUCell(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Builds a traditional GRU cell with dense internal transformations.

Gated Recurrent Unit paper:

__init__(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
trax.layers.rnn.GRU(n_units, mode='train')

GRU running on axis 1.

trax.layers.rnn.ConvGRUCell(n_units, kernel_size=(3, 3))

Builds a convolutional GRU.


  • n_units – Number of hidden units
  • kernel_size – Kernel size for convolution

A Stax model representing a GRU cell with convolution transforms.

trax.layers.rnn.GeneralGRUCell(candidate_transform, memory_transform_fn=None, gate_nonlinearity=<function Sigmoid>, candidate_nonlinearity=<function Tanh>, dropout_rate_c=0.1, sigmoid_bias=0.5)

Parametrized Gated Recurrent Unit (GRU) cell construction.

GRU update equations for update gate, reset gate, candidate memory, and new state:

\[\begin{split}u_t &= \sigma(U' \times s_{t-1} + B') \\ r_t &= \sigma(U'' \times s_{t-1} + B'') \\ c_t &= \tanh(U \times (r_t \odot s_{t-1}) + B) \\ s_t &= u_t \odot s_{t-1} + (1 - u_t) \odot c_t\end{split}\]

See combinators.Gate for details on the gating function.

  • candidate_transform – Transform to apply inside the Candidate branch. Applied before nonlinearities.
  • memory_transform_fn – Optional transformation on the memory before gating.
  • gate_nonlinearity – Function to use as gate activation; allows trying alternatives to Sigmoid, such as HardSigmoid.
  • candidate_nonlinearity – Nonlinearity to apply after candidate branch; allows trying alternatives to traditional Tanh, such as HardTanh.
  • dropout_rate_c – Amount of dropout on the transform (c) gate. Dropout works best in a GRU when applied exclusively to this branch.
  • sigmoid_bias – Constant to add before sigmoid gates. Generally want to start off with a positive bias.

A model representing a GRU cell with specified transforms.
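
The four update equations above can be written out directly in NumPy. This is a sketch of the math only, not of how GeneralGRUCell is assembled from layers: the parameter names in the params dict are hypothetical, plain matrix products stand in for the configurable transforms, and dropout and sigmoid_bias are left out.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(s_prev, params):
    """One state update following the four equations above."""
    u = sigmoid(params["U_u"] @ s_prev + params["B_u"])        # update gate u_t
    r = sigmoid(params["U_r"] @ s_prev + params["B_r"])        # reset gate r_t
    c = np.tanh(params["U_c"] @ (r * s_prev) + params["B_c"])  # candidate c_t
    return u * s_prev + (1.0 - u) * c                          # new state s_t

rng = np.random.default_rng(0)
d = 4
params = {k: rng.normal(size=(d, d)) for k in ("U_u", "U_r", "U_c")}
params.update({k: np.zeros(d) for k in ("B_u", "B_r", "B_c")})
s_prev = rng.normal(size=d)
s_next = gru_step(s_prev, params)
```

When the update gate saturates at 1, the state is carried over unchanged, which is why a positive starting bias on the gates favors remembering.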


The inner (non-parallel) computation of an SRU.

trax.layers.rnn.SRU(n_units, activation=None, mode='train')

SRU (Simple Recurrent Unit) layer as in

As defined in the paper:

\[\begin{split}y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\ f_t &= \sigma(Wf x_t + bf) \\ r_t &= \sigma(Wr x_t + br) \\ c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\ h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t\end{split}\]

We assume the input is of shape [batch, length, depth] and recurrence happens on the length dimension. This function returns a single layer; the paper recommends stacking at least two, except inside a Transformer.

  • n_units – output depth of the SRU layer.
  • activation – Optional activation function.
  • mode – If ‘predict’, the previous state is saved for one-by-one inference.

The SRU layer.
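
The recurrence above can be sketched in NumPy as follows. This is a sequential reference loop under stated assumptions, not the Trax layer: the optional bias B is omitted, the initial cell state c_0 is taken to be zero, the input depth equals n_units (needed for the (1 - r_t) × x_t term), and the argument names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru_sketch(x, W, Wf, bf, Wr, br, activation=np.tanh):
    """x has shape [batch, length, depth]; returns h with the same shape."""
    y = x @ W                  # y_t, computed for all time steps at once
    f = sigmoid(x @ Wf + bf)   # forget gate f_t
    r = sigmoid(x @ Wr + br)   # reset gate r_t
    c = np.zeros_like(x[:, 0, :])
    hs = []
    for t in range(x.shape[1]):  # only this loop is sequential
        c = f[:, t] * c + (1.0 - f[:, t]) * y[:, t]
        hs.append(r[:, t] * activation(c) + (1.0 - r[:, t]) * x[:, t])
    return np.stack(hs, axis=1)

rng = np.random.default_rng(0)
depth = 4
x = rng.normal(size=(2, 5, depth))
W, Wf, Wr = (rng.normal(size=(depth, depth)) for _ in range(3))
h = sru_sketch(x, W, Wf, np.zeros(depth), Wr, np.zeros(depth))
```

All matrix products depend only on x_t, so they can be computed for the whole sequence in parallel; only the elementwise update of c_t is recurrent.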


Attention Layers optimized for efficiency (second-pass implementation).

The approach taken in the first round of efficient attention implementations revealed several limitations, which this code attempts to address:

  1. Simultaneously instantiating queries, keys, and values for all heads can exceed the memory budget. Transformers are typically tuned such that n_heads * d_attention_key == d_model. Since attention involves queries, keys, AND values, the memory to store them can be ~3x the memory needed to store the input activations. Once the O(n^2) dot-product bottleneck is removed – as is the case in all of our efficient attention implementations – this becomes the next critical bottleneck for scaling up Transformer models.
  2. Attention masking is implemented by associating an integer (typically, the sequence position) with each query and key vector, and defining a function to compute attention masks from this information. The standard attention API is unscalable because it instantiates O(n^2)-size attention masks, and the previous efficient implementations only supported causal masking.
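
The ~3x figure in the first point follows from simple counting. The sketch below uses hypothetical sizes and assumes the value depth equals d_attention_key:

```python
def qkv_elements(seq_len, n_heads, d_attention_key):
    """Elements needed to hold queries, keys, and values for one example."""
    return 3 * seq_len * n_heads * d_attention_key

# Hypothetical configuration with n_heads * d_attention_key == d_model.
seq_len, d_model, n_heads = 65536, 1024, 16
d_attention_key = d_model // n_heads
input_elements = seq_len * d_model
ratio = qkv_elements(seq_len, n_heads, d_attention_key) / input_elements
```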
trax.layers.research.efficient_attention.length_normalized(x, epsilon=1e-06)
trax.layers.research.efficient_attention.hash_vecs(vecs, n_buckets_in, n_hashes, rng)

Hash vectors into buckets.

  • vecs – vectors to hash, a tensor of shape [batch_size, depth]
  • n_buckets_in – an int or a list of ints, number of hash buckets; if it is a list, we do hierarchical hashing as specified by the list
  • n_hashes – number of hashes
  • rng – random generator to use for hashing

A pair (buckets, n_buckets) where buckets is a tensor of shape [n_hashes, batch_size] of integers – the hash bucket IDs, and n_buckets is an int, the total number of hash buckets, equal to the product of all items in n_buckets_in.
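
One common way to hash vectors into buckets for LSH attention is angular hashing with random rotations, as used in Reformer. The NumPy sketch below shows that technique for a single even integer n_buckets; it is an assumption that hash_vecs works along these lines, hierarchical hashing is not covered, and rng here is a NumPy Generator rather than a JAX PRNG key.

```python
import numpy as np

def hash_vecs_sketch(vecs, n_buckets, n_hashes, rng):
    """vecs: [batch_size, depth] -> buckets: [n_hashes, batch_size]."""
    if n_buckets % 2 != 0:
        raise ValueError("this sketch needs an even number of buckets")
    depth = vecs.shape[-1]
    # One random projection per hash round, onto n_buckets // 2 directions.
    rotations = rng.normal(size=(depth, n_hashes, n_buckets // 2))
    rotated = np.einsum("bd,dhr->hbr", vecs, rotations)
    # Pair each direction with its negation, then pick the best-aligned one.
    rotated = np.concatenate([rotated, -rotated], axis=-1)
    buckets = np.argmax(rotated, axis=-1)
    return buckets, n_buckets

rng = np.random.default_rng(0)
vecs = rng.normal(size=(8, 16))
buckets, n_buckets = hash_vecs_sketch(vecs, n_buckets=4, n_hashes=2, rng=rng)
```

Vectors pointing in similar directions tend to receive the same bucket ID, and using several hash rounds reduces the chance that similar vectors are always separated.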

trax.layers.research.efficient_attention.look_adjacent(x, n_chunks_before, n_chunks_after)

Used to implement attention between consecutive chunks.

  • x – array of shape [n_chunks, chunk_len, …]
  • n_chunks_before – Number of previous chunks to attend to.
  • n_chunks_after – Number of subsequent chunks to attend to.

array of shape [n_chunks, N * chunk_len, …], where N = (1 + n_chunks_before + n_chunks_after).
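
The shape contract above can be illustrated with a small NumPy sketch. It assumes that neighboring chunks wrap around at the sequence boundaries, so the first chunk sees the last one as its predecessor; the function name is hypothetical.

```python
import numpy as np

def look_adjacent_sketch(x, n_chunks_before, n_chunks_after):
    """x: [n_chunks, chunk_len, ...] -> [n_chunks, N * chunk_len, ...]."""
    slices = []
    for i in range(-n_chunks_before, n_chunks_after + 1):
        if i == 0:
            slices.append(x)
        else:
            # Shift chunks by i positions, wrapping around the ends.
            slices.append(np.concatenate([x[i:], x[:i]], axis=0))
    return np.concatenate(slices, axis=1)

x = np.arange(6).reshape(3, 2)  # 3 chunks of length 2
out = look_adjacent_sketch(x, n_chunks_before=1, n_chunks_after=0)
```

Each output row holds the previous chunk followed by the chunk itself, so a query chunk can attend to both with one batched dot product.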

trax.layers.research.efficient_attention.mask_self_attention(dots, q_info, kv_info, causal=True, exclude_self=True, masked=False)

Performs masking for self-attention.
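
A sketch of position-based masking, assuming q_info and kv_info hold integer sequence positions and that masking is done by subtracting a large constant from the dot products. The penalty values and the smaller penalty for self-attention are assumptions made for illustration, and the masked flag is not covered.

```python
import numpy as np

def mask_self_attention_sketch(dots, q_info, kv_info, causal=True, exclude_self=True):
    """dots: [q_len, kv_len]; q_info: [q_len]; kv_info: [kv_len]."""
    q = q_info[:, None]
    kv = kv_info[None, :]
    if causal:
        # Keys at later positions than the query are pushed far down.
        dots = dots - 1e9 * (q < kv)
    if exclude_self:
        # A milder penalty: self-attention remains a fallback option.
        dots = dots - 1e5 * (q == kv)
    return dots

positions = np.arange(3)
masked_dots = mask_self_attention_sketch(np.zeros((3, 3)), positions, positions)
```

Because the mask is computed from two position vectors on the fly, no O(n^2) mask needs to be passed into the layer.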

trax.layers.research.efficient_attention.attend(q, k=None, v=None, q_chunk_len=None, kv_chunk_len=None, n_chunks_before=0, n_chunks_after=0, mask_fn=None, q_info=None, kv_info=None, dropout=0.0, rng=None)

Dot-product attention, with optional chunking and/or masking.

  • q – Query vectors, shape [q_len, d_qk]
  • k – Key vectors, shape [kv_len, d_qk]; or None
  • v – Value vectors, shape [kv_len, d_v]
  • q_chunk_len – Set to non-zero to enable chunking for query vectors
  • kv_chunk_len – Set to non-zero to enable chunking for key/value vectors
  • n_chunks_before – Number of adjacent previous chunks to attend to
  • n_chunks_after – Number of adjacent subsequent chunks to attend to
  • mask_fn – TODO(kitaev) doc
  • q_info – Query-associated metadata for masking
  • kv_info – Key-associated metadata for masking
  • dropout – Dropout rate
  • rng – RNG for dropout

A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and dots_logsumexp has shape [q_len]. The logsumexp of the attention probabilities is useful for combining multiple rounds of attention (as in LSH attention).
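
To show why returning the logsumexp is useful, the following NumPy sketch implements plain dot-product attention (no chunking, masking, scaling, or dropout) and then merges two rounds that attended to disjoint sets of keys. The function names are hypothetical.

```python
import numpy as np

def attend_sketch(q, k, v):
    """q: [q_len, d_qk], k: [kv_len, d_qk], v: [kv_len, d_v]."""
    dots = q @ k.T
    m = dots.max(axis=-1, keepdims=True)  # subtract the max for stability
    dots_logsumexp = m[:, 0] + np.log(np.exp(dots - m).sum(axis=-1))
    probs = np.exp(dots - dots_logsumexp[:, None])
    return probs @ v, dots_logsumexp

def combine_rounds(outputs, logsumexps):
    """Merges rounds by weighting each with its share of the softmax mass."""
    outs = np.stack(outputs)    # [n_rounds, q_len, d_v]
    lses = np.stack(logsumexps)  # [n_rounds, q_len]
    total = np.logaddexp.reduce(lses, axis=0)
    weights = np.exp(lses - total)
    return (weights[..., None] * outs).sum(axis=0)

rng = np.random.default_rng(0)
q = rng.normal(size=(3, 4))
k = rng.normal(size=(6, 4))
v = rng.normal(size=(6, 5))
full, lse = attend_sketch(q, k, v)
o1, l1 = attend_sketch(q, k[:3], v[:3])
o2, l2 = attend_sketch(q, k[3:], v[3:])
merged = combine_rounds([o1, o2], [l1, l2])
```

In this setting the merged result matches attention over all keys at once, so each round can be computed and discarded independently.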

trax.layers.research.efficient_attention.apply_broadcasted_dropout(vecs, dropout_rate, rng)

Apply dropout, broadcasted across all but the last dimension of vecs.
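
Broadcasting the dropout mask means that one mask is drawn for the last dimension and reused for every other position, which avoids materializing a full-size mask. A NumPy sketch under the assumption of inverted-dropout scaling; rng is a NumPy Generator here, not a JAX PRNG key:

```python
import numpy as np

def broadcasted_dropout_sketch(vecs, dropout_rate, rng):
    """Drops the same features of the last axis for every leading index."""
    if dropout_rate <= 0.0:
        return vecs
    keep_prob = 1.0 - dropout_rate
    mask_shape = (1,) * (vecs.ndim - 1) + vecs.shape[-1:]
    keep = rng.random(size=mask_shape) < keep_prob
    return np.where(keep, vecs / keep_prob, 0.0)

rng = np.random.default_rng(0)
vecs = np.ones((5, 8))
dropped = broadcasted_dropout_sketch(vecs, 0.5, rng)
```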

trax.layers.research.efficient_attention.permute_via_gather(val, permutation, inverse_permutation, axis=0)

Permutation helper for LSH attention.

trax.layers.research.efficient_attention.permute_via_sort(val, keys, inverse_keys, axis=0)

Permutation helper for LSH attention.
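
Both helpers deal with reordering values (for example, sorting by bucket ID) and later restoring the original order. The NumPy sketch below shows only that forward and inverse relationship, using example bucket IDs as sort keys; it does not reproduce the helpers themselves.

```python
import numpy as np

keys = np.array([2, 0, 1, 0])  # e.g. bucket IDs per position
val = np.array([10.0, 20.0, 30.0, 40.0])

permutation = np.argsort(keys, kind="stable")  # order that sorts the keys
inverse_permutation = np.argsort(permutation)   # order that undoes it

sorted_val = np.take(val, permutation, axis=0)
restored = np.take(sorted_val, inverse_permutation, axis=0)
```

A stable sort keeps positions with equal keys in their original relative order, which matters when causal masking is applied within a bucket.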

class trax.layers.research.efficient_attention.EfficientAttentionBase(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.base.Layer

Base class for efficient attention.

This is a base class that implements memory-efficient batching for both the forward and backward passes. Subclasses should override create_weights_unbatched, create_state_unbatched, forward_unbatched, and optionally incremental_forward_unbatched to define the actual attention mechanism.

__init__(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)

Constructs an EfficientAttentionBase instance.

  • n_heads – Number of attention heads.
  • n_in – Number of inputs to the layer (default 1).
  • n_parallel_heads

    Number of attention heads to compute in parallel.

    • If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
    • If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
    • If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
    • If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
  • incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
  • predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
  • predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
  • use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
  • use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
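
The four n_parallel_heads modes listed above can be summarized by how many examples and heads are processed per step. The helper below is a hypothetical illustration of those rules, not part of the API:

```python
def parallelism_per_step(n_heads, n_parallel_heads, batch_size):
    """Returns (examples_per_step, heads_per_step) for a given setting."""
    if n_parallel_heads is None:
        # Maximum parallelism: everything at once.
        return batch_size, n_heads
    if n_parallel_heads % n_heads == 0:
        # Sub-batches of whole examples.
        return n_parallel_heads // n_heads, n_heads
    if n_parallel_heads < n_heads and n_heads % n_parallel_heads == 0:
        # Several heads at a time within a single example (includes 1).
        return 1, n_parallel_heads
    raise ValueError("n_parallel_heads must divide or be a multiple of n_heads")

settings = [None, 1, 4, 16]
steps = [parallelism_per_step(8, n, batch_size=4) for n in settings]
```

Moving down the list trades speed for memory, which matches the advice to start with None and reduce parallelism only if memory runs out.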

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
create_weights_unbatched(input_signature, rng)
create_state_unbatched(input_signature, rng)
forward_unbatched(*inputs, weights, state)

Perform attention for a single batch element and head.

Subclasses should override this method.

  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.

A tuple (output, new_state) – output and new state for a single example and attention head.


Computes this layer’s output as part of a forward pass through the model.

Parameters:inputs – Layer inputs (subclasses may use different inputs)
Returns:A tuple (output, new_state).

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)

Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.

  • inputs – inputs to the attention layer
  • weights – weights for the attention layer
  • state – state of the attention layer
  • rng – PRNG key for the layer (shared across all examples and heads)
  • output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
  • compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
  • update_state – bool: whether to return an updated layer state.

A tuple (output, new_state, inputs_grad, weights_grad).

  • output is not None iff compute_output is True
  • new_state is not None iff update_state is True
  • inputs_grad & weights_grad are not None iff output_grad is not None

class trax.layers.research.efficient_attention.SelfAttention(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.base.Layer

Memory-efficient self-attention (second attempt).

__init__(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Construct a self-attention layer.

  • n_heads – int: Number of attention heads
  • d_qk – int: Depth of query and key vectors
  • d_v – int: Depth of value vectors
  • share_qk – bool: Set to True to share query and key projection weights
  • causal – bool: Set to True to mask out attention to future items
  • masked – bool: Set to True to accept an additional mask argument, that allows masking out attention to padding tokens.
  • chunk_len (optional) – Number of tokens per chunk. Setting this option will enable chunked attention.
  • n_chunks_before – Number of previous chunks to attend to, when using chunked attention.
  • n_chunks_after – Number of subsequent chunks to attend to, when using chunked attention. Don’t use this option for causal attention, because attention to future tokens will be masked out anyway. However, note that cross-chunk attention “wraps around” in both directions, so this option is never a strict no-op.
  • bias – bool: Set to True to add bias vectors when computing query/key/value
  • mode – ‘train’, ‘eval’, or ‘predict’
  • predict_mem_len – int: Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten. When chunking is enabled, the default is to store chunk_len * (1 + n_chunks_before) elements.
  • predict_drop_len – int: Number of input elements to drop once the fast inference input cache fills up. When chunking is enabled, the default is to drop exactly chunk_len elements.
  • attention_dropout – Dropout probability for attention mask.
  • output_dropout – Dropout probability for the layer output.
  • n_parallel_heads

    Number of attention heads to compute in parallel.

    • If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
    • If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
    • If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
    • If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
  • use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
  • use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
create_weights_unbatched(input_signature, rng)
create_state_unbatched(input_signature, rng)
forward_unbatched(x, mask=None, *, weights, state, rng, update_state)

Perform attention for a single batch element and head.

  • x – Inputs for a single example (subclasses may use different inputs)
  • mask – Mask for the inputs.
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
  • rng – PRNG key for the layer (shared across all examples and heads)
  • update_state – bool: whether to return an updated layer state.

A tuple (output, new_state) – output and new state for a single example and attention head.


Computes this layer’s output as part of a forward pass through the model.

Parameters:inputs – Layer inputs (subclasses may use different inputs)
Returns:A tuple (output, new_state).

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)

Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.

  • inputs – inputs to the attention layer
  • weights – weights for the attention layer
  • state – state of the attention layer
  • rng – PRNG key for the layer (shared across all examples and heads)
  • output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
  • compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
  • update_state – bool: whether to return an updated layer state.

A tuple (output, new_state, inputs_grad, weights_grad).

  • output is not None iff compute_output is True
  • new_state is not None iff update_state is True
  • inputs_grad & weights_grad are not None iff output_grad is not None

class trax.layers.research.efficient_attention.LSHSelfAttention(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.base.Layer

LSH self-attention (second implementation).

__init__(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)

Construct an LSH self-attention layer.


Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
create_weights_unbatched(input_signature, rng)
create_state_unbatched(input_signature, rng)
hash_vectors(vecs, rng, mask=None)
forward_unbatched(x, mask=None, *, weights, state, rng, update_state)

Computes this layer’s output as part of a forward pass through the model.

Parameters:inputs – Layer inputs (subclasses may use different inputs)
Returns:A tuple (output, new_state).

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)

Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.

  • inputs – inputs to the attention layer
  • weights – weights for the attention layer
  • state – state of the attention layer
  • rng – PRNG key for the layer (shared across all examples and heads)
  • output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
  • compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
  • update_state – bool: whether to return an updated layer state.

A tuple (output, new_state, inputs_grad, weights_grad).

  • output is not None iff compute_output is True
  • new_state is not None iff update_state is True
  • inputs_grad & weights_grad are not None iff output_grad is not None

class trax.layers.research.efficient_attention.PureLSHSelfAttention(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.base.Layer

LSH self-attention without weights.

__init__(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)

Construct an LSH self-attention layer.


Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
create_state_unbatched(input_signature, rng)
hash_vectors(vecs, rng, mask=None)
forward_unbatched(qk, v, mask=None, *, state, rng, update_state)

Computes this layer’s output as part of a forward pass through the model.

Parameters:inputs – Layer inputs (subclasses may use different inputs)
Returns:A tuple (output, new_state).

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)

Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward(inputs, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.

  • inputs – inputs to the attention layer tuple (qk, v, mask)
  • state – state of the attention layer
  • rng – PRNG key for the layer (shared across all examples and heads)
  • output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
  • compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
  • update_state – bool: whether to return an updated layer state.

A tuple (output, new_state, inputs_grad, weights_grad).

  • output is not None iff compute_output is True
  • new_state is not None iff update_state is True
  • inputs_grad & weights_grad are not None iff output_grad is not None

class trax.layers.research.efficient_attention.MixedLSHSelfAttention(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, std_length=None, mode='train', output_dropout=0.0, attention_dropout=0.0, force_no_dropout=False, **pure_lsh_implementation_kwargs)

Bases: trax.layers.base.Layer

LSH attention mixed with standard attention used until std_length.

__init__(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, std_length=None, mode='train', output_dropout=0.0, attention_dropout=0.0, force_no_dropout=False, **pure_lsh_implementation_kwargs)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Initializes weights and state for inputs with the given signature.


Executes this layer as part of a forward pass through the model.

forward_and_or_backward(inputs, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

class trax.layers.research.efficient_attention.PureLSHSelfAttentionWrapper(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, output_dropout=0.0, attention_dropout=0.0, pure_lsh_implementation=None, bias=True, mode='train', num_weights=3, sparsity=16, weights_format='model', **pure_lsh_implementation_kwargs)

Bases: trax.layers.combinators.Serial

Pure LSH serial.

__init__(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, output_dropout=0.0, attention_dropout=0.0, pure_lsh_implementation=None, bias=True, mode='train', num_weights=3, sparsity=16, weights_format='model', **pure_lsh_implementation_kwargs)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.
forward_and_or_backward(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

  • inputs – inputs to the attention layer
  • weights – weights for the attention layer
  • state – state of the attention layer
  • rng – PRNG key for the layer (shared across all examples and heads)
  • output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
  • compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
  • update_state – bool: whether to return an updated layer state.

A tuple (output, new_state, inputs_grad, weights_grad).

  • output is not None iff compute_output is True
  • new_state is not None iff update_state is True
  • inputs_grad & weights_grad are not None iff output_grad is not None

class trax.layers.research.efficient_attention.EncDecAttention(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.research.efficient_attention.EfficientAttentionBase

Memory-efficient encoder-decoder attention.

__init__(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Constructs an EfficientAttentionBase instance.

  • n_heads – Number of attention heads.
  • n_in – Number of inputs to the layer (default 1).
  • n_parallel_heads

    Number of attention heads to compute in parallel.

    • If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
    • If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
    • If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
    • If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
  • incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
  • predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
  • predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
  • use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
  • use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
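
The four n_parallel_heads modes described above can be summarized in a small helper. This is only a sketch of the documented rules, not the library's implementation: `parallelism_plan` is a hypothetical name, and returning None for "the whole batch" is an assumption made for illustration.

```python
def parallelism_plan(n_heads, n_parallel_heads):
    """Returns (examples_per_step, heads_per_step) for the documented modes.

    examples_per_step is None when the whole batch is processed at once.
    """
    if n_parallel_heads is None:
        # Maximum parallelism: all examples and all heads together.
        return None, n_heads
    if n_parallel_heads == 1:
        # One head of one example at a time: least memory, slowest.
        return 1, 1
    if n_parallel_heads % n_heads == 0:
        # Sub-batches of whole examples, with all heads of each example.
        return n_parallel_heads // n_heads, n_heads
    if 1 < n_parallel_heads < n_heads:
        if n_heads % n_parallel_heads != 0:
            raise ValueError('n_heads must be a multiple of n_parallel_heads')
        # Several heads at a time, within a single example.
        return 1, n_parallel_heads
    raise ValueError('unsupported n_parallel_heads for this n_heads')


plan_full = parallelism_plan(n_heads=8, n_parallel_heads=None)
plan_single = parallelism_plan(n_heads=8, n_parallel_heads=1)
plan_sub_batch = parallelism_plan(n_heads=8, n_parallel_heads=16)
plan_partial = parallelism_plan(n_heads=8, n_parallel_heads=4)
```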
create_weights_unbatched(input_signature, rng)
forward_unbatched(q_antecedent, kv_antecedent, mask=None, *, weights, state, rng, update_state)

Perform attention for a single batch element and head.

Subclasses should override this method.

  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.

A tuple (output, new_state) – output and new state for a single example and attention head.

class trax.layers.research.efficient_attention.LSHFF(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Feed-forward block with LSH.

The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense that takes an input, makes it of size d_ff (usually larger than it was) and then brings it back to the original size after Relu. It is commonly used in Transformer models where it often accounts for most of the trainable weights.

The original block can be slow in decoding due to the need to fetch a lot of weights from memory. The LSH block aims to exploit the sparsity of the activations after Relu, many of which are 0. So in the first Dense(d_ff) layer, instead of performing a full matrix multiplication, this block only multiplies by the parts of the weights matrix that have the highest chance to give a non-zero result after Relu. This is determined by taking a number of locality-sensitive hashes and masking to only include weights that have at least one hash identical to that of the multiplied element.
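
The masking idea can be sketched in plain Python. This is an illustration under stated assumptions, not the LSHFF implementation: `lsh_masked_dense_relu` is a hypothetical helper, the hash buckets are passed in precomputed rather than derived from real locality-sensitive hashes, and bias terms are omitted.

```python
def lsh_masked_dense_relu(x, columns, x_hashes, column_hashes):
    """Computes relu(x . column) only for columns sharing a hash bucket with x."""
    out = []
    for column, hashes in zip(columns, column_hashes):
        # Keep the column if at least one of its hashes equals the
        # corresponding hash of the input vector.
        if any(h_x == h_c for h_x, h_c in zip(x_hashes, hashes)):
            dot = sum(a * b for a, b in zip(x, column))
            out.append(max(dot, 0.0))
        else:
            # Skipped column: its activation is treated as 0.
            out.append(0.0)
    return out


x = [1.0, -2.0]
columns = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.5]]  # three weight columns
x_hashes = [3, 7]                                # two hashes of the input
column_hashes = [[3, 1], [0, 7], [5, 5]]         # two hashes per column
activations = lsh_masked_dense_relu(x, columns, x_hashes, column_hashes)
```

In this toy input the third column shares no hash with x and is skipped, even though its exact activation would be non-zero. The masking is therefore an approximation whose quality depends on how well the hashes predict large dot products.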

__init__(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Returns a LSH feed-forward block.


Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor of same shape and dtype as the input.

Randomly initializes this layer’s weights.


Experimenting with position encodings.

class trax.layers.research.position_encodings.AxialPositionalEncoding(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')

Bases: trax.layers.base.Layer

Axial positional encoding.

__init__(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
class trax.layers.research.position_encodings.SinCosPositionalEncoding(add_offset=2048, dropout=0.0, dropout_broadcast_dims=(-2, ), start_from_zero_one_in=2, mode='train')

Bases: trax.layers.base.Layer

Implements the sin-cos positional encoding.

__init__(add_offset=2048, dropout=0.0, dropout_broadcast_dims=(-2, ), start_from_zero_one_in=2, mode='train')

Creates a SinCosPositionalEncoding instance.

  • add_offset – Maximum number to add to positions during training.
  • dropout – Probability of not adding positional encoding to a sequence position.
  • dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
  • start_from_zero_one_in – How often to start from 0 during training: once in every this many times (e.g., if 4, then 25% of the time).
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
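
The interaction of add_offset and start_from_zero_one_in can be pictured with a small sampler. This is a sketch, not the layer's code: `sample_start_position` is a hypothetical name, and drawing the non-zero offset uniformly from 0 to add_offset inclusive is an assumption.

```python
import random


def sample_start_position(add_offset, start_from_zero_one_in, rng):
    """Returns 0 about once in start_from_zero_one_in draws, else a random offset."""
    if rng.randrange(start_from_zero_one_in) == 0:
        return 0
    # Assumption: offsets are uniform over 0..add_offset inclusive.
    return rng.randint(0, add_offset)


rng = random.Random(0)
starts = [sample_start_position(2048, 4, rng) for _ in range(10000)]
# With start_from_zero_one_in=4, roughly a quarter of the starts are 0.
zero_fraction = starts.count(0) / len(starts)
```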

Returns the input activations, with added positional information.


Randomly initializes the positional encoding vectors.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on.
class trax.layers.research.position_encodings.FixedBasePositionalEncoding(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=2, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Implements fixed-base positional encoding.

__init__(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=2, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
trax.layers.research.position_encodings.threefry_2x32_prf(key, x)

Apply the threefry PRF to an array of inputs.

This function is vectorized over x. For threefry_2x32: K = X = uint32[2]

  • key – uint32[2] the key of the PRF
  • x – uint32[…, 2] the inputs

uint32[…, 2] the outputs

trax.layers.research.position_encodings.threefry_2x32_prange(key, lo: int = 0, hi: int = 2)

Splits a key into a stream of random keys.

This uses the little-endian counter mode.

  • key – uint32[2] the key to split
  • lo – the range to start extracting from
  • hi – the range to stop extracting from

uint32[hi - lo, 2] the split keys
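
To make the output shape concrete, the sketch below builds one counter block per split key, giving hi - lo rows of two 32-bit words. It does not apply the threefry PRF and is not the library's code: `counter_blocks` is a hypothetical helper, and storing the low word first is an assumption about what "little-endian" means here.

```python
def counter_blocks(lo=0, hi=2):
    """Encodes counters lo..hi-1 as [low_word, high_word] pairs of 32-bit values."""
    mask = 0xFFFFFFFF
    # Low 32 bits first, then the next 32 bits (assumed little-endian word order).
    return [[i & mask, (i >> 32) & mask] for i in range(lo, hi)]


blocks = counter_blocks(lo=0, hi=3)             # three rows, one per split key
wrapped = counter_blocks(lo=2**32, hi=2**32 + 1)  # counter that needs the high word
```

In a counter-mode construction, each such block would be passed through the PRF under the given key to produce one split key.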

class trax.layers.research.position_encodings.InfinitePositionalEncoding(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')

Bases: trax.layers.base.Layer

Infinite positional encoding.

__init__(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')

Initializes the encoding.

The encoding tries to roughly evenly traverse the latent space. The recurrence time is dependent on how many bits per dimension you use.

There are two parameters to control randomization:

  • randomizing the origin every 1/drift steps by letting it drift
  • randomizing the origin per call

  • drift – variance in position difference per unit of difference
  • affine – whether to randomize the origin every call
  • transform – learnable transform after encoding (any/diag/none)
  • time_bin_length – Adds the features that AxialPositionalEncoding learns if TimeBinCausalAttention is the first layer. It should match TimeBinCausalAttention.bin_length. If you set transform=‘diag’, this flag increases your model capacity to close to that of transform=‘any’, though it will still train slower.
  • mode – if ‘predict’, allow evaluating one token at a time

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
class trax.layers.research.position_encodings.TimeBinPositionalEncoding(time_bin_length, mode='train')

Bases: trax.layers.base.Layer

Just the engineered features from InfinitePositionalEncoding.

num_features = 3
__init__(time_bin_length, mode='train')

Initializes the encoding.

  • time_bin_length – TimeBinCausalAttention.bin_length of the first layer.
  • mode – if ‘predict’, allow evaluating one token at a time

Computes this layer’s output as part of a forward pass through the model.

A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.

Initializes weights and state, to handle input with the given signature.

A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.