trax.layers

activation_fns

Layers that compute activation functions.

An activation layer computes element-wise a nonlinear function of the preceding layer’s output. Historically, an activation function was considered part of each node in each layer of the neural network. Trax follows the common current practice of separating the activation function as its own layer, which enables easier experimentation across different activation functions.
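For example, an activation layer can be applied directly to a tensor; this is a minimal sketch assuming NumPy inputs and the usual trax.layers alias tl:

import numpy as np
from trax import layers as tl

relu = tl.Relu()                           # activation function as a standalone layer
x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(relu(x))                             # [0. 0. 0. 1. 2.]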

trax.layers.activation_fns.Relu()

Returns a layer that computes the Rectified Linear Unit (ReLU) function.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]
trax.layers.activation_fns.ParametricRelu(a=1.0)

Returns a layer that computes a ReLU function with the given slope.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ ax & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters:a – Slope of line for positive inputs.
trax.layers.activation_fns.LeakyRelu(a=0.01)

Returns a ReLU-like layer with linear nonzero outputs for negative inputs.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} ax & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters:a – Slope of line for negative inputs.
trax.layers.activation_fns.Elu(a=1.0)

Returns a ReLU-like layer with exponential outputs for negative inputs.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} a \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]

(Asymptotically, \(f(x)\rightarrow -a\) as \(x\rightarrow - \infty\).)

Parameters:a – Coefficient multiplying the exponential, for negative inputs.
trax.layers.activation_fns.Selu(alpha=1.6732632423543772, lmbda=1.0507009873554805)

Returns an Elu-like layer with an additional scaling/slope parameter.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ \lambda \cdot x & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters:
  • alpha – Coefficient multiplying the exponential, for negative inputs.
  • lmbda – Coefficient scaling the whole function.
trax.layers.activation_fns.Gelu()

Returns a layer that computes the Gaussian Error Linear Unit function.

\[f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}}))\]
trax.layers.activation_fns.FastGelu()

Returns a layer that computes a fast approximation to Gelu.

\[f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3))\]

where \(a = 0.7978845608\) and \(b = 0.044715\).

trax.layers.activation_fns.Sigmoid()

Returns a layer that computes the sigmoid function.

\[f(x) = \frac{1}{1 + e^{-x}}\]
trax.layers.activation_fns.Tanh()

Returns a layer that computes the hyperbolic tangent function.

\[f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]
trax.layers.activation_fns.HardSigmoid()

Returns a layer that computes a linear approximation to Sigmoid.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{if}\ 0 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]
trax.layers.activation_fns.HardTanh()

Returns a layer that computes a linear approximation to Tanh.

\[\begin{split}f(x) = \left\{ \begin{array}{cl} -1 & \text{if}\ x \leq -1, \\ x & \text{if}\ -1 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]
trax.layers.activation_fns.Softplus()

Returns a layer that computes the softplus function.

\[f(x) = \ln(e^x + 1)\]
class trax.layers.activation_fns.ThresholdedLinearUnit(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Bases: trax.layers.base.Layer

Thresholded Linear Unit; cf. https://arxiv.org/pdf/1911.09737.pdf.

init_weights_and_state(input_signature)

Initializes this layer’s single weight to zero.

forward(inputs)

Executes this layer as part of a forward pass through the model.

Parameters:inputs – Tensor.
Returns:Tensor of same shape and dtype as the input.

attention

Attention-related layers.

Attention is a powerful extension of basic neural network ideas. In a classic neural network:

  • node activations are floating point values (one float per node), and
  • inter-node connections are trainable weights (one float per connection).

Attention assembles networks of vectors and uses vector calculations to derive connection strength; in other words:

  • node activations are floating point vectors, and
  • inter-node connections come from trainable vector computations.

Attention thus involves extra concepts/mechanisms – queries, keys, values, masks, attention heads – that factor heavily into this module’s API. See specific classes and functions for details.

NOTE: Attention layers in this module include mode-dependent behavior. The possible modes are:

  • ‘train’: in training – dropouts and position shifts active
  • ‘eval’: in evals – dropouts inactive, position shifts active
  • ‘predict’: in prediction – dropouts and position shifts inactive
trax.layers.attention.Attention(d_feature, n_heads=1, dropout=0.0, mode='train')

Returns a layer that maps (activations, mask) to (new_activations, mask).

This layer type represents one pass of multi-head self-attention, best known for its central role in Transformer models. Internally, it:

  • maps incoming sequence of activations to sequence of (query, key, value) triples,
  • splits queries, keys, and values into multiple ‘heads’,
  • computes per-head attention weights from per-head (queries, keys),
  • applies mask to screen out positions that come from padding tokens,
  • [in ‘train’ mode] applies dropout to attention weights,
  • uses attention weights to combine per-head values vectors, and
  • fuses per-head results into outgoing activations matching original input activation shapes.
Parameters:
  • d_feature – Depth/dimensionality of feature embedding.
  • n_heads – Number of attention heads.
  • dropout – Probabilistic rate for internal dropout applied to attention activations (based on query-key pairs) before dotting them with values.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
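The following minimal sketch shows how such a layer can be initialized and applied outside a full model. Sizes are illustrative, and the mask layout is assumed to follow the (batch_size, 1, 1, seq_len) shape described under PaddingMask below:

import numpy as np
from trax import layers as tl
from trax import shapes

batch, seq_len, d_feature = 2, 6, 16             # illustrative sizes
attn = tl.Attention(d_feature, n_heads=4, mode='train')

activations = np.random.uniform(size=(batch, seq_len, d_feature)).astype(np.float32)
mask = np.ones((batch, 1, 1, seq_len), dtype=np.bool_)   # all positions are real content

attn.init(shapes.signature((activations, mask)))
new_activations, out_mask = attn((activations, mask))
print(new_activations.shape)                     # (2, 6, 16)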
trax.layers.attention.AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train')

Returns a layer that maps (q, k, v, mask) to (activations, mask).

See Attention above for further context/details.

Parameters:
  • d_feature – Depth/dimensionality of feature embedding.
  • n_heads – Number of attention heads.
  • dropout – Probabilistic rate for internal dropout applied to attention activations (based on query-key pairs) before dotting them with values.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
class trax.layers.attention.PureAttention(n_heads=1, dropout=0.0, mode='train')

Bases: trax.layers.base.Layer

Returns a layer that maps (q, k, v, mask) to (activations, mask).

This layer type performs the inner workings of one pass of multi-head self-attention. It:

  • splits queries, keys, and values into multiple ‘heads’,
  • computes per-head attention weights from per-head (queries, keys),
  • applies mask to screen out positions that come from padding tokens,
  • [in ‘train’ mode] applies dropout to attention weights,
  • uses attention weights to combine per-head values vectors, and
  • merges per-head results into outgoing activations matching original input activation vector shapes.
__init__(n_heads=1, dropout=0.0, mode='train')

Returns a new PureAttention instance.

Parameters:
  • n_heads – Number of attention heads.
  • dropout – Probabilistic rate for dropout applied to attention strengths (based on query-key pairs) before applying them to values.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
forward(inputs)

Returns attention-computed activations and unmodified mask.

Parameters:inputs – A (queries, keys, values, mask) tuple.
trax.layers.attention.DotProductAttention(queries, keys, values, mask, dropout, mode, rng)

Computes new activations via masked attention-weighted sum of values.

This function is the core of the attention mechanism. It:
  • computes per-head attention weights from per-head queries and keys,
  • applies mask to screen out positions that come from padding tokens,
  • optionally applies dropout to attention weights, and
  • uses attention weights to combine per-head values vectors.
Parameters:
  • queries – Per-head activations representing attention queries.
  • keys – Per-head activations representing attention keys.
  • values – Per-head activations to be combined by computed attention weights.
  • mask – Mask that distinguishes positions with real content vs. padding.
  • dropout – Probabilistic rate for dropout applied to attention strengths (based on query-key pairs) before applying them to values.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
  • rng – Single-use random number generator (JAX PRNG key).
Returns:

Per-head activations resulting from masked per-head attention-weighted sum of per-head values.

trax.layers.attention.CausalAttention(d_feature, n_heads=1, dropout=0.0, max_inference_length=2048, mode='train')

Returns a layer that maps activations to activations, with causal masking.

Like Attention, this layer type represents one pass of multi-head self-attention, but with causal masking rather than padding-based masking.

Parameters:
  • d_feature – Depth/dimensionality of feature embedding.
  • n_heads – Number of attention heads.
  • dropout – Probabilistic rate for internal dropout applied to attention activations (based on query-key pairs) before dotting them with values.
  • max_inference_length – maximum length for inference.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
class trax.layers.attention.DotProductCausalAttention(dropout=0.0, max_inference_length=2048, mode='train')

Bases: trax.layers.base.Layer

Layer that computes attention strengths by masking out the “future”.

Causal attention uses masking to prevent a given sequence position from attending to positions greater than / following it. This is used, for example, when training autoregressive sequence models, or when decoding a sequence symbol by symbol.

This layer performs the core per-head attention calculation. The layer assumes that any splitting into attention heads precedes it, and that any merging of attention heads will follow it.

__init__(dropout=0.0, max_inference_length=2048, mode='train')

Creates a DotProductCausalAttention instance.

Parameters:
  • dropout – Probabilistic rate for dropout applied to attention strengths (based on query-key pairs) before applying them to values.
  • max_inference_length – maximum length of sequences during inference.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
forward(inputs)

Returns attention-computed activations.

Parameters:inputs – A (queries, keys, values) tuple.
init_weights_and_state(input_signature)

Initializes this layer for fast inference, if in ‘predict’ mode.

trax.layers.attention.ShiftRight(n_positions=1, mode='train')

Returns a layer that can insert padding to shift the input sequence.

Parameters:
  • n_positions – Number of positions to shift the input sequence rightward; initial positions freed by the shift get padded with zeros.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
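A small sketch of the shift on illustrative token ids; the expected output assumes the layer preserves sequence length, padding the front and dropping the final position:

import numpy as np
from trax import layers as tl

x = np.array([[1, 2, 3, 4]])
print(tl.ShiftRight(n_positions=1, mode='train')(x))   # expected: [[0 1 2 3]]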
trax.layers.attention.PaddingMask(pad=0)

Returns a layer that maps integer sequences to padding masks.

The layer expects as input a batch of integer sequences. The layer output is a tensor that marks for each sequence position whether the integer (e.g., a token ID) in that position represents padding – value pad – versus text/content – all other values. The padding mask shape is (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast to cover any number of attention heads and axis 2 will broadcast to cover decoder sequence positions.

Parameters:pad – Integer that represents padding rather than a token/content ID.
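For instance, with illustrative token ids and 0 as the padding value:

import numpy as np
from trax import layers as tl

token_ids = np.array([[5, 6, 7, 0, 0],
                      [8, 9, 0, 0, 0]])   # 0 marks padding positions
mask = tl.PaddingMask()(token_ids)
print(mask.shape)                          # (2, 1, 1, 5), as described above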
trax.layers.attention.EncoderDecoderMask()

Returns a layer that creates a mask for encoder-decoder cross attention.

The layer expects two inputs:

  • decoder_input: batch of integer (e.g., token ID) sequences
  • mask: padding mask from the encoder

The layer output is a mask that marks for each sequence position (for both encoder and decoder) whether that position can be attended to or not. The encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, encoder_sequence_length), such that axis 1 will automatically broadcast to cover any number of attention heads.

class trax.layers.attention.PositionalEncoding(max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2, ), mode='train')

Bases: trax.layers.base.Layer

Implements bare positional encoding.

Positional encoding includes a kind of dropout, if the layer is created in ‘train’ mode with a nonzero dropout value. For such a layer, on each forward pass a subset of sequence positions selected at random will not receive positional marking.

__init__(max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2, ), mode='train')

Creates a PositionalEncoding instance.

Parameters:
  • max_len – Maximum input sequence length.
  • dropout – Probability of not adding positional encoding to a sequence position.
  • dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
forward(inputs)

Returns the input activations, with added positional information.

init_weights_and_state(input_signature)

Randomly initializes the positional encoding vectors.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on.

base

Base layer class.

class trax.layers.base.Layer(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Bases: object

Base class for composable layers in a deep learning network.

Layers are the basic building blocks for deep learning models. A Trax layer computes a function from zero or more inputs to zero or more outputs, optionally using trainable weights (common) and non-parameter state (not common). Authors of new layer subclasses typically override at most two methods of the base Layer class:

forward(inputs):
Computes this layer’s output as part of a forward pass through the model.
init_weights_and_state(self, input_signature):
Initializes weights and state for inputs with the given signature.

A small subset of layer types are combinators – they organize the computation of their sublayers, e.g., applying their sublayers in series or in parallel.

All layers have the following properties, with default values implemented in the base Layer class:

  • n_in: int (default 1)
  • n_out: int (default 1)
  • weights: tuple (default empty – the layer has no weights)
  • state: tuple (default empty – the layer has no non-parameter state)
  • sublayers: tuple (default empty – the layer has no sublayers)

The inputs to a layer are tensors, packaged according to how many there are:

  • n_in = 0: an empty tuple
  • n_in = 1: one tensor (NOT wrapped in a tuple)
  • n_in > 1: a tuple of tensors

(The special treatment of the single-input case is meant to simplify the work of layer writers; this design choice may be revisited in the future.)

The outputs from a layer are also tensors, packaged the same as layer inputs:

  • n_out = 0: an empty tuple
  • n_out = 1: the tensor (NOT wrapped in a tuple)
  • n_out > 1: a tuple of tensors

The Trax runtime maintains a data stack with which layer calls are composed. For more complex data network architectures, possibly involving multiple data flows, one can view each layer as a function from stack state to stack state, where the function’s inputs are a slice from the stack, and the function’s outputs are spliced back into the stack.

__init__(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
__call__(x, weights=None, state=None, rng=None)

Makes layers callable; for use in tests or interactive settings.

This convenience method helps library users play with, test, or otherwise probe the behavior of layers outside of a full training environment. It presents the layer as a callable function from inputs to outputs, with the option of manually specifying weights and non-parameter state per individual call. For convenience, weights and non-parameter state are cached per layer instance, starting from default values of EMPTY_WEIGHTS and EMPTY_STATE, and acquiring non-empty values either by initialization or from values explicitly provided via the weights and state keyword arguments.

Parameters:
  • x – Zero or more input tensors, packaged as described in the Layer class docstring.
  • weights – Weights or None; if None, use self’s cached weights value.
  • state – State or None; if None, use self’s cached state value.
  • rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
Returns:

Zero or more output tensors, packaged as described in the Layer class docstring.
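A minimal sketch of this convenience path, assuming a Dense layer and NumPy inputs (sizes are illustrative):

import numpy as np
from trax import layers as tl
from trax import shapes

layer = tl.Dense(3)
x = np.ones((2, 5), dtype=np.float32)
layer.init(shapes.signature(x))   # creates and caches weights for this input signature
y = layer(x)                      # __call__ then uses the cached weights
print(y.shape)                    # (2, 3)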

forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
has_backward

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng)

Custom backward pass to propagate gradients in a custom way.

Parameters:
  • inputs – Input tensors; can be a (possibly nested) tuple.
  • output – The result of running this layer on inputs.
  • grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
  • weights – This layer’s weights.
  • state – This layer’s state prior to the current forward pass.
  • new_state – This layer’s state after the current forward pass.
  • rng – Single-use random number generator (JAX PRNG key).
Returns:

The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.

init(input_signature, rng=None, use_cache=False)

Initializes weights/state of this layer and its sublayers recursively.

Initialization creates layer weights and state, for layers that use them. It derives the necessary array shapes and data types from the layer’s input signature, which is itself just shape and data type information.

For layers without weights or state, this method safely does nothing.

This method is designed to create weights/state only once for each layer instance, even if the same layer instance occurs in multiple places in the network. This enables weight sharing to be implemented as layer sharing.

Parameters:
  • input_signature – ShapeDtype instance (if this layer takes one input) or list/tuple of ShapeDtype instances.
  • rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
  • use_cache – If True, and if this layer instance has already been initialized elsewhere in the network, then return special marker values – tuple (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE). Else return this layer’s newly initialized weights and state.
Returns:

A (weights, state) tuple.

init_from_file(file_name, weights_only=False, input_signature=None)

Initializes this layer and its sublayers from a pickled checkpoint.

In the common case (weights_only=False), the file must be a gzipped pickled dictionary containing items with keys ‘flat_weights’, ‘flat_state’ and ‘input_signature’, which are used to initialize this layer. If input_signature is specified, it’s used instead of the one in the file. If weights_only is True, the dictionary does not need to have the ‘flat_state’ item and the state is not restored either.

Parameters:
  • file_name – Name/path of the pickled weights/state file.
  • weights_only – If True, initialize only the layer’s weights. Else initialize both weights and state.
  • input_signature – Input signature to be used instead of the one from file.
name

Returns the name of this layer.

n_in

Returns how many tensors this layer expects as input.

n_out

Returns how many tensors this layer promises as output.

sublayers

Returns a tuple containing this layer’s sublayers; may be empty.

weights

Returns this layer’s weights.

Depending on the layer, the weights can be in the form of:

  • an empty tuple
  • a tensor (ndarray)
  • a nested structure of tuples and tensors

If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.

state

Returns a tuple containing this layer’s state; may be empty.

If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.

weights_and_state_signature(input_signature)

Returns a pair containing the signatures of weights and state.

rng

Returns a single-use random number generator without advancing it.

pure_fn(x, weights, state, rng, use_cache=False)

Applies this layer as a pure function with no optional args.

This method exposes the layer’s computation as a pure function. This is especially useful for JIT compilation. Do not override, use forward instead.

Parameters:
  • x – Zero or more input tensors, packaged as described in the Layer class docstring.
  • weights – A tuple or list of trainable weights, with one element for this layer if this layer has no sublayers, or one for each sublayer if this layer has sublayers. If a layer (or sublayer) has no trainable weights, the corresponding weights element is an empty tuple.
  • state – Layer-specific non-parameter state that can update between batches.
  • rng – Single-use random number generator (JAX PRNG key).
  • use_cache – if True, cache weights and state in the layer object; used to implement layer sharing in combinators.
Returns:

A tuple of (tensors, state). The tensors match the number (n_out) promised by this layer, and are packaged as described in the Layer class docstring.

output_signature(input_signature)

Returns the output signature this layer would give for input_signature.

trax.layers.base.layer(n_in=1, n_out=1, name=None)

Decorator for creating simple layers. DEPRECATED; use base.Fn instead.

class trax.layers.base.PureLayer(forward_fn, n_in=1, n_out=1, name='PureLayer')

Bases: trax.layers.base.Layer

Pure function from inputs to outputs, packaged as a neural network layer.

The PureLayer class represents the simplest kinds of layers: layers with no trainable weights and no randomness, hence pure functions from inputs to outputs.

__init__(forward_fn, n_in=1, n_out=1, name='PureLayer')

Creates an unconnected PureLayer instance.

Parameters:
  • forward_fn – Pure function from input tensors to output tensors, where inputs and outputs are packaged as specified for forward.
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use only in debugging.
forward(inputs)

Overrides Layer.forward.

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
Raises:ValueError – If weights is other than an empty tuple/list.
trax.layers.base.Fn(name, f, n_out=1)

Returns a layer with no weights that applies the function f.

f can take and return any number of arguments, and takes only positional arguments – no default or keyword arguments. It often uses JAX-numpy (jnp). The following, for example, would create a layer that takes two inputs and returns two outputs – element-wise sums and maxima:

Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

The layer’s number of inputs (n_in) is automatically set to number of positional arguments in f, but you must explicitly set the number of outputs (n_out) whenever it’s not the default value 1.

Parameters:
  • name – Class-like name for the resulting layer; for use in debugging.
  • f – Pure function from input tensors to output tensors, where each input tensor is a separate positional arg, e.g., f(x0, x1) –> x0 + x1. Output tensors must be packaged as specified in the Layer class docstring.
  • n_out – Number of outputs promised by the layer; default value 1.
Returns:

Layer executing the function f.
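A usage sketch of the SumAndMax example above, assuming jax.numpy as jnp and NumPy inputs:

import numpy as np
import jax.numpy as jnp
from trax import layers as tl

sum_and_max = tl.Fn('SumAndMax',
                    lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
                    n_out=2)
s, m = sum_and_max((np.array([1, 4]), np.array([3, 2])))
print(s, m)   # [4 6] [3 4]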

exception trax.layers.base.LayerError(layer_name, function_name, caller, input_signature, traceback_string)

Bases: Exception

Exception raised in the layer stack.

__init__(layer_name, function_name, caller, input_signature, traceback_string)

Initialize self. See help(type(self)) for accurate signature.

message

Assembles current layer context into an error message.

trax.layers.base.flatten_weights_and_state(weights, state)

Flattens weights and state into lists, excluding empty and cached ones.

trax.layers.base.unflatten_weights_and_state(flat_weights, flat_state, weights_and_state_signature, weights_only=False)

Un-flattens weights and state given their signatures.

trax.layers.base.to_list(outputs)

Converts layer outputs to a nested list, for easier equality testing.

Parameters:outputs – A tensor or tuple/list of tensors coming from the forward application of a layer. Each tensor is NumPy ndarray-like, which complicates simple equality testing (e.g., via assertEquals): such tensors require equality testing to use either all (all elements match) or any (at least one element matches), which is not directly supported in absltest.
Returns:A nested list structure containing all the output values, but now directly testable using assertEquals.

combinators

Combinators for composing layers.

class trax.layers.combinators.Serial(*sublayers, name=None, sublayers_to_print=None)

Bases: trax.layers.base.Layer

Combinator that applies layers serially (by function composition).

This combinator is commonly used to construct deep networks, e.g., like this:

mlp = tl.Serial(
  tl.Dense(128),
  tl.Relu(),
  tl.Dense(10),
  tl.LogSoftmax()
)

A Serial combinator uses stack semantics to manage data for its sublayers. Each sublayer sees only the inputs it needs and returns only the outputs it has generated. The sublayers interact via the data stack. For instance, a sublayer k, following sublayer j, gets called with the data stack in the state left after layer j has applied. The Serial combinator then:

  • takes n_in items off the top of the stack (n_in = k.n_in) and calls layer k, passing those items as arguments; and
  • takes layer k’s n_out return values (n_out = k.n_out) and pushes them onto the data stack.

A Serial instance with no sublayers acts as a special-case (but useful) 1-input 1-output no-op.
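A sketch of initializing and applying the mlp example above on a batch of inputs (input sizes are illustrative):

import numpy as np
from trax import layers as tl
from trax import shapes

mlp = tl.Serial(
  tl.Dense(128),
  tl.Relu(),
  tl.Dense(10),
  tl.LogSoftmax()
)
x = np.ones((32, 784), dtype=np.float32)
mlp.init(shapes.signature(x))
print(mlp(x).shape)   # (32, 10)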

__init__(*sublayers, name=None, sublayers_to_print=None)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(xs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.combinators.Parallel(*sublayers, name=None)

Bases: trax.layers.base.Layer

Combinator that applies a list of layers in parallel to its inputs.

Layers in the list apply to successive spans of inputs, where the spans are determined by how many inputs each layer takes. The resulting output is the (flattened) concatenation of the respective layer outputs.

For example, suppose one has three layers:

  • F: 1 input, 1 output
  • G: 3 inputs, 1 output
  • H: 2 inputs, 2 outputs (h1, h2)

Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:

  • inputs: a, b, c, d, e, f
  • outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f)

As an important special case, a None argument to Parallel acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For example,

Parallel(None, F)

creates a layer that passes its first input unchanged and applies F to the following input(s).
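A small sketch of this behavior; Negate is a hypothetical Fn layer standing in for F:

import numpy as np
from trax import layers as tl

negate = tl.Fn('Negate', lambda x: -x)      # stands in for F
layer = tl.Parallel(None, negate)           # pass first input through, negate the second
a, b = np.array([1.0, 2.0]), np.array([3.0, 4.0])
out_a, out_b = layer((a, b))
print(out_a, out_b)                         # [1. 2.] [-3. -4.]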

__init__(*sublayers, name=None)

The constructor.

Parameters:
  • *sublayers – A list of sublayers.
  • name – Descriptive name for this layer.
Returns:

A new layer in which each of the given sublayers applies to its corresponding span of elements in the dataflow stack.

forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.combinators.Concatenate(n_items=2, axis=-1)

Bases: trax.layers.base.Layer

Concatenates n tensors into a single tensor.

__init__(n_items=2, axis=-1)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(xs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
class trax.layers.combinators.Split(n_items=2, axis=-1)

Bases: trax.layers.base.Layer

Splits the input into n items along an axis.

__init__(n_items=2, axis=-1)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
class trax.layers.combinators.Scan(layer, axis=0, n_carry=1, remat=False)

Bases: trax.layers.base.Layer

Applies a layer progressively/cumulatively to an axis-derived sequence.

Conceptually, this is a function from a list to a same-length list of partial (cumulative) results. For instance, a list of values ([1, 2, 3, 4, 5]) can transform to a list of cumulative sums ([1, 3, 6, 10, 15]). Functions for the same concept are called scan in Scala, scanl in Haskell, and accumulate* in Factor.

In more detail, we assume the layer takes a tuple of inputs of the following form:

(input1, …, inputN, carry1, …, carryM)

and returns:

(output1, …, outputK, new_carry1, …, new_carryM)

The scanned version applies the layer iteratively to a tensor treating values at the given axis as if they were a list. For example, to calculate all sums of prefixes of a tensor, we can do this:

def add():  # returns a layer that adds an input to the running carry
  def f(x, carry):
    res = x + carry
    return res, res  # output and carry are the same
  return tl.Fn('add', f, n_out=2)

Scan(add())([1, 2, 3], 0) = [1, 3, 6], 6
__init__(layer, axis=0, n_carry=1, remat=False)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
sublayer

Returns the unique sublayer managed by this layer.

forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.combinators.Cond(cond, true, false=None, name=None)

Bases: trax.layers.base.Layer

Applies layers conditionally.

Given parameters cond, true, and false, this layer runs the equivalent of true(y) if cond(x) else false(y), where x is cond.n_in elements from the front of the stack and y is the rest of the stack. Exactly one of the true and false functions is executed, so this can be used to conditionally run long computations. The state of the non-executed function is not updated. Note that different branches may be executed on different devices if cond returns different values on them. By default, the false function is the identity.

cond must return exactly one element: a Boolean value. true and false must have the same n_in, and the same n_out.

__init__(cond, true, false=None, name=None)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
forward(xs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
trax.layers.combinators.Chunk(layer, chunk_size)

Executes layer using batch chunks of size chunk_size to save memory.

trax.layers.combinators.Branch(*layers, name='Branch')

Combinator that applies a list of layers in parallel to copies of inputs.

Each layer in the input list is applied to as many inputs from the stack as it needs, and their outputs are successively combined on the stack.

For example, suppose one has three layers:

  • F: 1 input, 1 output
  • G: 3 inputs, 1 output
  • H: 2 inputs, 2 outputs (h1, h2)

Then Branch(F, G, H) will take 3 inputs and give 4 outputs:

  • inputs: a, b, c
  • outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)

As an important special case, a None argument to Branch acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

Parameters:
  • *layers – List of layers.
  • name – Descriptive name for this layer.
Returns:

A branch layer built from the given sublayers.

trax.layers.combinators.Residual(*layers, shortcut=None)

Wraps a series of layers with a residual connection.

Parameters:
  • *layers – One or more layers, to be applied in series.
  • shortcut – If None (the usual case), the Residual layer computes the element-wise sum of the stack-top input with the output of the layer series. If specified, the shortcut layer applies to a copy of the inputs and (elementwise) adds its output to the output from the main layer series.
Returns:

A layer representing a residual connection paired with a layer series.
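A minimal sketch; the layer series must produce outputs of the same shape as the stack-top input for the element-wise sum (sizes are illustrative):

import numpy as np
from trax import layers as tl
from trax import shapes

d = 8                                         # illustrative feature depth
block = tl.Residual(tl.Dense(d), tl.Relu())   # series output depth matches input depth
x = np.ones((2, d), dtype=np.float32)
block.init(shapes.signature(x))
print(block(x).shape)                         # (2, 8)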

trax.layers.combinators.Select(indices, n_in=None, name=None)

Copies, reorders, or deletes stack elements according to indices.

Parameters:
  • indices – A list or tuple of 0-based indices to select elements relative to the top of the stack.
  • n_in – Number of input elements to pop from the stack, and replace with those specified by indices. If not specified, its value will be calculated as max(indices) + 1.
  • name – Descriptive name for this layer.
Returns:

Tensors, matching the number selected (n_out = len(indices)). Specifically:

  • n_out = 0: an empty tuple
  • n_out = 1: one tensor (NOT wrapped in a tuple)
  • n_out > 1: a tuple of tensors, with n_out items
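For instance, this sketch copies the top stack element and swaps the next two; tl.to_list is used only to make the outputs easy to print:

import numpy as np
from trax import layers as tl

sel = tl.Select([0, 0, 2, 1])       # n_in defaults to max(indices) + 1 = 3
a, b, c = np.array([1]), np.array([2]), np.array([3])
print(tl.to_list(sel((a, b, c))))   # [[1], [1], [3], [2]]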

trax.layers.combinators.Drop()

Drops the top stack element.

trax.layers.combinators.Dup()

Duplicates (copies) the top element on the data stack.

trax.layers.combinators.Swap()

Swaps the top two stack elements.

trax.layers.combinators.SerialWithSideOutputs(layers, n_side_outputs=1)

Serial layer with side outputs.

This layer makes it easier to manage the stack when layers have side outputs.

In the simplest case of layers with n_in=1, n_out=2 and with n_side_outputs=1, this layer runs the following computation on x:

side_outputs = []
for i in range(len(layers)):
  x, side_output = layers[i](x)
  side_outputs.append(side_output)
return [x] + side_outputs

In the general case of layers with variable n_in and n_out and n_side_outputs being a list of N integers, it does the following:

side_outputs = []
for i in range(N):
  res = layers[i](cur_stack)  # remove n_in items from the stack
  cur_stack.extend(res[:n_side_outputs[i]])  # put some outputs back on the stack
  side_outputs.extend(res[n_side_outputs[i]:])
return cur_stack + side_outputs
Parameters:
  • layers – a list of layers to execute
  • n_side_outputs – an int or a list of ints, how many outputs of each layer to put aside
Returns:

A layer that performs the above computation.

trax.layers.combinators.FlattenList()

Flatten lists.

trax.layers.combinators.Add()

Adds two tensors.

trax.layers.combinators.SubtractTop()

Subtracts the first tensor from the second.

trax.layers.combinators.Multiply()

Multiplies two tensors.

trax.layers.combinators.Gate()

Returns a gating layer on a (memory, gate, candidate) tuple.

The final update is memory * gate + (1 - gate) * candidate.

This gating equation is also the one used in Highway Networks (https://arxiv.org/abs/1505.00387).
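A small numeric sketch of the gating equation:

import numpy as np
from trax import layers as tl

memory    = np.array([1.0, 1.0])
gate      = np.array([0.2, 0.8])
candidate = np.array([0.0, 0.0])
print(tl.Gate()((memory, gate, candidate)))   # [0.2 0.8] = memory * gate + (1 - gate) * candidate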

class trax.layers.combinators.Cache(layer)

Bases: trax.layers.base.Layer

Applies a layer on the first run and returns its outputs on subsequent calls.

__init__(layer)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
sublayer

Returns the unique sublayer managed by this layer.

state

Returns a tuple containing this layer’s state; may be empty.

init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
class trax.layers.combinators.BatchLeadingAxes(layer, n_last_axes_to_keep=1)

Bases: trax.layers.base.Layer

Applies a layer after flattening all but n_last_axes_to_keep to batch.

This can be used to make layers accept an arbitrary number of leading axes (dimensions) as batch. For example, a Convolution layer may normally only operate on tensors of shape [B, W, H, C]. In this case, the layer

BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3)

will operate on any tensor […, W, H, C] and treat the leading axes as batch.
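A sketch with a Dense layer instead of a convolution (leading axes and sizes are illustrative):

import numpy as np
from trax import layers as tl
from trax import shapes

# Dense normally treats only the last axis as features; the wrapper folds
# all leading axes into a single batch axis and restores them afterwards.
layer = tl.BatchLeadingAxes(tl.Dense(5), n_last_axes_to_keep=1)
x = np.ones((2, 3, 4, 7), dtype=np.float32)
layer.init(shapes.signature(x))
print(layer(x).shape)   # (2, 3, 4, 5)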

__init__(layer, n_last_axes_to_keep=1)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
sublayer

Returns the unique sublayer managed by this layer.

forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
trax.layers.combinators.inputs_from_stack(stack, n)

Returns n inputs from stack.

trax.layers.combinators.outputs_onto_stack(outputs, stack, n)

Returns the new stack after removing n items and pushing outputs onto it.

convolution

Trax convolution layers.

class trax.layers.convolution.Conv(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Layer constructor function for a general convolution layer.

__init__(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(x)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.convolution.CausalConv(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.convolution.Conv

Causal (masked) convolution for [batch x time x depth] sequences.

Maintains causality along time axis. Used in language modeling tasks.

__init__(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(x)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
trax.layers.convolution.Conv1d(filters, kernel_size, stride=1, padding='VALID', kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

core

Core layer types, such as Dense, Embedding, and Dropout.

class trax.layers.core.Dense(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)

Bases: trax.layers.base.Layer

A dense (a.k.a. fully-connected, affine) layer.

Dense layers are the prototypical example of a trainable layer, i.e., a layer with trainable weights. Each node in a dense layer computes a weighted sum of all node values from the preceding layer and adds to that sum a node-specific bias term. The full layer computation is expressed compactly in linear algebra as an affine map y = Wx + b, where W is a matrix and y, x, and b are vectors. The layer is trained, or “learns”, by updating the values in W and b.

Less commonly, a dense layer can omit the bias term and be a pure linear map: y = Wx.

__init__(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)

Returns a dense (fully connected) layer of width n_units.

A dense layer maps collections of R^m vectors to R^n, where n (= n_units) is fixed at layer creation time, and m is set at layer initialization time.

Parameters:
  • n_units – Number of nodes in the layer, also known as the width of the layer.
  • kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
  • bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
  • use_bias – If True, compute an affine map y = Wx + b; else compute a linear map y = Wx.
forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor of same shape and dtype as the input, except the final dimension is the layer’s n_units value.
init_weights_and_state(input_signature)

Randomly initializes this layer’s weights.

Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on.
class trax.layers.core.Embedding(vocab_size, d_feature, kernel_initializer=<function ScaledInitializer.<locals>.Init>)

Bases: trax.layers.base.Layer

Trainable layer that maps discrete tokens/ids to vectors.

__init__(vocab_size, d_feature, kernel_initializer=<function ScaledInitializer.<locals>.Init>)

Returns an embedding layer with given vocabulary size and vector size.

The layer clips input values (token ids) to the range [0, vocab_size). That is, negative token ids all clip to 0 before being mapped to a vector, and token ids with value vocab_size or greater all clip to vocab_size - 1 before being mapped to a vector.

Parameters:
  • vocab_size – Size of the input vocabulary. The layer will assign a unique vector to each id in range(vocab_size).
  • d_feature – Dimensionality/depth of the output vectors.
  • kernel_initializer – Function that creates (random) initial vectors for the embedding.
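For example, with an illustrative vocabulary and feature size:

import numpy as np
from trax import layers as tl
from trax import shapes

emb = tl.Embedding(vocab_size=10, d_feature=4)
ids = np.array([[1, 2, 9, 0]])        # a batch with one sequence of token ids
emb.init(shapes.signature(ids))
print(emb(ids).shape)                  # (1, 4, 4): one 4-dim vector per token id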
forward(x)

Returns embedding vectors corresponding to input token id’s.

Parameters:x – Tensor of token id’s.
Returns:Tensor of embedding vectors.
init_weights_and_state(input_signature)

Randomly initializes this layer’s weights.

class trax.layers.core.Dropout(rate=0.0, shared_axes=None, mode='train')

Bases: trax.layers.base.Layer

A layer that stochastically ignores a subset of inputs each training step.

In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).

The parameter shared_axes lets you specify a list of axes on which the mask is shared: the dropout mask has size 1 on those axes and is broadcast across them. Sharing reduces randomness, but can save memory.

This layer is active only during training (mode=’train’). In other circumstances it is a no-op.
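A sketch of the scaling behavior in ‘train’ mode; survivors are scaled by 1 / (1 - rate), and which values drop depends on the rng:

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.ones((2, 8), dtype=np.float32)
drop = tl.Dropout(rate=0.5, mode='train')
drop.init(shapes.signature(x))
print(drop(x))   # roughly half the values zeroed, the rest scaled to 2.0
# With mode='eval' or mode='predict', the layer passes x through unchanged.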

__init__(rate=0.0, shared_axes=None, mode='train')

Creates a dropout layer with the given target drop rate.

Parameters:
  • rate – Stochastic rate (probability) for dropping an activation value from the preceding layer (setting it to zero).
  • shared_axes – List of axes on which the mask is shared.
  • mode – If ‘train’, this layer will perform dropout; else, it will pass all values through unaltered.
init_weights_and_state(input_signature)

Sets layer-specific internal state.

forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of activations.
Returns:Tensor of same shape and dtype as the input.
class trax.layers.core.Weights(initializer, shape=())

Bases: trax.layers.base.Layer

Learnable weights as a layer.

It takes no input and returns a single tensor: weights.

__init__(initializer, shape=())

Returns a learnable tensor of shape shape.

Parameters:
  • initializer – Function taking shape and rng as arguments.
  • shape – Shape of the learnable weights.
forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor with previously specified shape and dtype.
init_weights_and_state(input_signature)

Returns newly initialized weights for this layer.

Weights is a single w tensor with previously specified shape.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.
trax.layers.core.PrintShape(n_in=1, msg='')

Prints the shapes of n_in inputs and returns them unchanged.

class trax.layers.core.SummaryScalar(name, aggregation_fun=<sphinx.ext.autodoc.importer._MockObject object>)

Bases: trax.layers.base.Layer

A layer that receives a tensor and adds it to TensorBoard as a scalar.

It takes an input and returns it unchanged, storing the input in the layer state so it can be reported as a metric in TensorBoard. The tensor is converted to a scalar by a given aggregation function (mean by default). On TensorBoard, results for each device are reported separately.

__init__(name, aggregation_fun=<sphinx.ext.autodoc.importer._MockObject object>)

Takes a tensor and returns it.

Parameters:
  • name – Name of the metric to be reported.
  • aggregation_fun – Aggregation function to be used.
forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor with previously specified shape and dtype.
init_weights_and_state(input_signature)

Returns newly initialized weights for this layer.

Weights is a single w tensor with previously specified shape.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.
class trax.layers.core.RandomUniform(min_val=0.0, max_val=1.0, shape=(), dtype=<sphinx.ext.autodoc.importer._MockObject object>, sync=False)

Bases: trax.layers.base.Layer

Layer returning a tensor with random values distributed uniformly.

__init__(min_val=0.0, max_val=1.0, shape=(), dtype=<sphinx.ext.autodoc.importer._MockObject object>, sync=False)

Layer returning a tensor with random values distributed uniformly.

Parameters:
  • min_val – Lower end of uniform distribution.
  • max_val – Upper end of uniform distribution.
  • shape – Shape of the tensor to return. Values are sampled independently.
  • dtype – Type of value to return.
  • sync – Whether to synchronise rng across devices.
forward(xs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
class trax.layers.core.LocallyConnected1d(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID')

Bases: trax.layers.base.Layer

Locally-connected layer for 1D inputs.

The LocallyConnected1d layer applies a different set of filters to each patch of the input. This is similar to applying a convolution layer, except that a locally-connected layer uses a different set of weights for each patch.

The size of each patch is determined by the kernel size. The stride is currently not modifiable and is set to one. This means that for an input of shape (…, L, D), the output shape is (…, L, filters) for paddings ‘SAME’ and ‘WRAP’, and (…, L-kernel_size+1, filters) for padding ‘VALID’; here L is the number of “pixels” or “steps” in the input and D is the size of the embedding.

Note that, since the weights for different patches are not shared, the number of “pixels” or “steps” cannot change after calling init_weights_and_state. This is because each “pixel” is assigned its own set of weights.

__init__(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID')

Returns a locally-connected conv-like layer.

Parameters:
  • filters – Number of output filters in the convolution.
  • kernel_size – A length of the convolution window. Must be an odd number.
  • kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
  • bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
  • use_bias – If True, the layer uses a bias vector.
  • padding – The type of padding to use; must be ‘VALID’, ‘SAME’, or ‘WRAP’.
forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor of same shape and dtype as the input, except the final dimension is the layer’s filters value, and the second-to-last dimension is shrunk if ‘VALID’ padding is used with kernel_size bigger than one.
init_weights_and_state(input_signature)

Randomly initializes this layer’s weights.

Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.

Parameters:input_signature – ShapeDtype instance characterizing the input this layer should compute on.
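A shape-only sketch of the ‘VALID’ padding behavior described above; the sizes are illustrative assumptions:

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.ones((4, 10, 8), dtype=np.float32)     # (batch, L=10 steps, D=8)

local = tl.LocallyConnected1d(filters=16, kernel_size=3, padding='VALID')
local.init(shapes.signature(x))
y = local(x)
print(y.shape)   # (4, 8, 16): L shrinks to L - kernel_size + 1 = 8 under 'VALID' padding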
trax.layers.core.Flatten(n_axes_to_keep=1)

Returns a layer that combines one or more trailing axes of a tensor.

Flattening keeps all the values of the input tensor, but reshapes it by collapsing one or more trailing axes into a single axis. For example, a Flatten(n_axes_to_keep=2) layer would map a tensor with shape (2, 3, 5, 7, 11) to the same values with shape (2, 3, 385).

Parameters:n_axes_to_keep – Number of leading axes to leave unchanged when reshaping; collapse only the axes after these.
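For instance, reproducing the example above (Flatten has no weights, so it can be applied directly):

import numpy as np
from trax import layers as tl

x = np.zeros((2, 3, 5, 7, 11))
y = tl.Flatten(n_axes_to_keep=2)(x)
print(y.shape)   # (2, 3, 385): the trailing axes 5, 7, 11 collapse into 5 * 7 * 11 = 385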
trax.layers.core.Exp()

Returns a layer that computes the element-wise exponential of a tensor.

trax.layers.core.LogSoftmax(axis=-1)

Returns a layer that applies log softmax along one tensor axis.

LogSoftmax acts on a group of values and normalizes them to look like a set of log probability values. (Probability values must be non-negative, and as a set must sum to 1. A group of log probability values can be seen as the natural logarithm function applied to a set of probability values.)

Parameters:axis – Axis along which values are grouped for computing log softmax.
trax.layers.core.Softmax(axis=-1)

Returns a layer that applies softmax along one tensor axis.

Softmax acts on a group of values and normalizes them to look like a set of probability values. (Probability values must be non-negative, and as a set must sum to 1.)

Parameters:axis – Axis along which values are grouped for computing softmax.
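A small sketch of both Softmax and LogSoftmax on one row of values (numbers rounded):

import numpy as np
from trax import layers as tl

x = np.array([[1.0, 2.0, 3.0]])
probs = tl.Softmax()(x)          # approx [[0.09, 0.24, 0.67]]; each row sums to 1
log_probs = tl.LogSoftmax()(x)   # natural log of the values above, approx [[-2.41, -1.41, -0.41]]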
trax.layers.core.ToFloat()

Returns a layer that changes the dtype of a tensor to float32.

trax.layers.core.Mean(axis=-1, keepdims=False)

Returns a layer that computes mean values using one tensor axis.

Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group. The resulting values can either remain in their own size 1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.

Parameters:
  • axis – Axis along which values are grouped for computing a mean.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
trax.layers.core.Min(axis=-1, keepdims=False)

Returns a layer that applies min along one tensor axis.

Parameters:
  • axis – Axis along which values are grouped for computing minimum.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
trax.layers.core.Max(axis=-1, keepdims=False)

Returns a layer that applies max along one tensor axis.

Parameters:
  • axis – Axis along which values are grouped for computing maximum.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
trax.layers.core.Sum(axis=-1, keepdims=False)

Returns a layer that computes sums using one tensor axis.

Sum uses one tensor axis to form groups of values and replaces each group with the sum of that group. The resulting sum values can either remain in their own size 1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.

Parameters:
  • axis – Axis along which values are grouped for computing a sum.
  • keepdims – If True, keep the resulting size 1 axis as a separate tensor axis; else, remove that axis.
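A sketch of the axis and keepdims semantics shared by Mean, Min, Max, and Sum:

import numpy as np
from trax import layers as tl

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

tl.Mean(axis=-1)(x)                  # [2.0, 5.0]      (rank drops by one)
tl.Mean(axis=-1, keepdims=True)(x)   # [[2.0], [5.0]]  (size-1 axis kept)
tl.Sum(axis=0)(x)                    # [5.0, 7.0, 9.0] (grouped down the columns)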
trax.layers.core.Negate()

Returns a layer that computes the element-wise negation of a tensor.

trax.layers.core.StopGradient()

Returns an identity layer with a stop gradient.

trax.layers.core.log_gaussian_pdf(x, mu, sigma)

Returns log N(x | mu, sigma).

Parameters:
  • x – <tbd>
  • mu – <tbd>
  • sigma – <tbd>
trax.layers.core.log_gaussian_diag_pdf(x, mu, diag_sigma)

Returns log N(x | mu, eye(diag_sigma)).

Parameters:
  • x – <tbd>
  • mu – <tbd>
  • diag_sigma – <tbd>
trax.layers.core.multigaussian_loss(preds, targets, ngauss=1)

Returns a mixture of gaussians loss.

Parameters:
  • preds – <tbd>
  • targets – <tbd>
  • ngauss – <tbd>
trax.layers.core.logsoftmax_sample(log_probs, temperature=1.0)

Returns a sample from a log-softmax output, with temperature.

Parameters:
  • log_probs – Logarithms of probabilities (often coming from LogSoftmax)
  • temperature – For scaling before sampling (1.0 = default, 0.0 = pick argmax)
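A sketch of temperature-controlled sampling from a log-softmax output, using the signature documented above; the input values are illustrative:

import numpy as np
from trax import layers as tl
from trax.layers import core as tl_core

log_probs = tl.LogSoftmax()(np.array([2.0, 1.0, 0.1]))

# temperature 1.0 samples in proportion to the probabilities;
# temperatures near 0.0 approach argmax (index 0 for these values).
sampled_index = tl_core.logsoftmax_sample(log_probs, temperature=0.5)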

initializers

Trax initializers.

trax.layers.initializers.InitializerFromFile(path)

Loads parameters from .npy file.

trax.layers.initializers.RandomNormalInitializer(stddev=0.01)

Returns an initializer for random normal coefficients.

trax.layers.initializers.RandomUniformInitializer(lim=1.0)

Returns an initializer for random uniform coefficients.

trax.layers.initializers.ScaledInitializer(out_dim, in_dim, scale, mode, distribution)

Returns an initializer that adjusts its scale based on weight shapes.

trax.layers.initializers.GlorotNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random Glorot-scaled coefficients.

trax.layers.initializers.GlorotUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random uniform Glorot-scaled coefficients.

trax.layers.initializers.LeCunNormalInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random LeCun-scaled coefficients.

trax.layers.initializers.LeCunUniformInitializer(out_dim=-1, in_dim=-2, scale=1.0)

Returns an initializer for random uniform LeCun-scaled coefficients.

trax.layers.initializers.KaimingNormalInitializer(out_dim=-1, in_dim=-2, param=0.0)

Returns an initializer for random Kaiming-scaled coefficients.

trax.layers.initializers.KaimingUniformInitializer(out_dim=-1, in_dim=-2, param=0.0)

Returns an initializer for random uniform Kaiming-scaled coefficients.

trax.layers.initializers.OrthogonalInitializer(stddev=1.0)

Returns an orthogonal initializer.

trax.layers.initializers.AtariConvInit(kernel_shape, rng, dtype=<sphinx.ext.autodoc.importer._MockObject object>)

The standard init for Conv layers and Atari.
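Each factory above returns an initializer function that takes a weight shape and an rng key. A sketch, assuming Trax’s fastmath PRNG helper for creating the key:

from trax import fastmath
from trax import layers as tl
from trax.layers import initializers as init

rng = fastmath.random.get_prng(0)        # PRNG key (assumes trax.fastmath.random.get_prng)

init_fn = init.GlorotUniformInitializer()
w = init_fn((128, 256), rng)             # 128 x 256 matrix, scaled by fan-in / fan-out

# Initializers also plug into layers, e.g. a standalone Weights layer:
bias_layer = tl.Weights(init.RandomNormalInitializer(stddev=0.01), shape=(256,))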

metrics

Trax layers for computing metrics (loss functions and evaluation metrics).

A metric layer takes three inputs:

  • model output: Batch of predicted values (typically vectors).
  • targets: Batch of target values (e.g., categories or vectors).
  • weights: Tensor that can assign different weights to different positions in the model output. One common use of weights is for masking – assigning weight 0 to positions that correspond to padding in the input so that they don’t affect metrics.

and returns a single scalar.

The L2Loss layer treats a batch as an unanalyzed tensor and computes an elementwise-weighted loss.

Other metric layers take into account the items that make up a batch. For each item in a batch, a raw metric value is computed by comparing (item-wise) the model output to the target value. These item-wise values are then combined into a single scalar for the batch by a weighted reduction function, typically weighted mean. For example:

  • Accuracy: Treat model output as giving different strength/votes to the possible categories; measure the category prediction as correct (value 1) if argmax(output) == target_category, else as incorrect (value 0). The accuracy for the batch is then the weighted mean of these 1’s and 0’s.
  • Cross Entropy: Treat model output and target values as two probability distributions; measure the cross entropy of the model output relative to the (assumed true) target distribution. The scalar value for the batch is then the weighted mean of the item-wise cross-entropy values.

In deriving a single scalar for the batch, there is flexibility to use reducing functions other than mean, for instance sum or a specialized sequence mean.
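As a sketch of the three-input convention (model output, targets, weights), assuming the model output is log-probabilities such as LogSoftmax produces; the values below are illustrative:

import numpy as np
from trax import layers as tl
from trax import shapes

logits = np.array([[[2.0, 0.1, 0.3],
                    [0.2, 3.0, 0.1]]])      # (batch=1, positions=2, categories=3)
log_probs = tl.LogSoftmax()(logits)         # assumed model output: log-probabilities
targets = np.array([[0, 1]])
weights = np.array([[1.0, 0.0]])            # weight 0 masks out the second position

loss_layer = tl.CrossEntropyLoss()
loss_layer.init(shapes.signature((log_probs, targets, weights)))
loss = loss_layer((log_probs, targets, weights))   # a single scalar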

trax.layers.metrics.L2Loss()

Returns a layer that computes total L2 loss for one batch.

trax.layers.metrics.SmoothL1Loss()

Returns a layer that computes total smooth L1 loss for one batch.

trax.layers.metrics.BinaryClassifier(threshold=0.5)

Returns a layer that performs binary classification of the model output.

trax.layers.metrics.MulticlassClassifier(axis=-1)

Multiclass classification of the model output.

trax.layers.metrics.Accuracy(classifier=MulticlassClassifier)

Returns a layer that computes mean category prediction accuracy.

trax.layers.metrics.SequenceAccuracy(classifier=MulticlassClassifier)

Returns a layer that computes mean sequence prediction accuracy.

trax.layers.metrics.BinaryCrossEntropyLoss()

Mean prediction-target cross entropy for binary classification.

trax.layers.metrics.CrossEntropyLoss()

Mean prediction-target cross entropy for multiclass classification.

trax.layers.metrics.BinaryCrossEntropySum()

Sum of prediction-target cross entropies for binary classification.

trax.layers.metrics.CrossEntropySum()

Sum of prediction-target cross entropies for multiclass classification.

trax.layers.metrics.SumOfWeights()

Returns a layer that computes sum of weights.

trax.layers.metrics.WeightedSum()

Returns a layer that computes a weighted sum of the given values.

trax.layers.metrics.one_hot(x, n_categories, dtype=<sphinx.ext.autodoc.importer._MockObject object>)

Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

normalization

Trax normalization layers.

class trax.layers.normalization.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train')

Bases: trax.layers.base.Layer

Layer that performs batch normalization.

In training, batch normalization keeps smoothed cumulative statistics across batches of input data and modifies each new batch so that its components are normally distributed. In eval or inference, a BatchNorm instance uses its stored mean and variance to approximately normalize each new batch of data.

See https://arxiv.org/abs/1502.03167 for the original presentation and motivation of batch normalization.

__init__(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train')

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(x)

Computes batch normalization as part of a forward pass in the model.

init_weights_and_state(input_signature)

Helper to initialize batch norm weights and state.
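A sketch of typical BatchNorm usage; the image-shaped input is an assumption for illustration. In ‘train’ mode the layer also updates its running statistics, which live in the layer state:

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.random.uniform(size=(8, 28, 28, 3)).astype(np.float32)   # (batch, height, width, channels)

batch_norm = tl.BatchNorm(mode='train')
batch_norm.init(shapes.signature(x))
y = batch_norm(x)    # normalized per channel over axes (0, 1, 2), the default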

class trax.layers.normalization.LayerNorm(epsilon=1e-06)

Bases: trax.layers.base.Layer

Layer normalization.

__init__(epsilon=1e-06)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(x)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.normalization.FilterResponseNorm(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001)

Bases: trax.layers.base.Layer

Filter Response Normalization layer without the Thresholded Linear Unit.

See https://arxiv.org/pdf/1911.09737.pdf

__init__(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.

pooling

Trax pooling layers.

trax.layers.pooling.MaxPool(pool_size=(2, 2), strides=None, padding='VALID')

Reduces each multi-dimensional window to the max of the window’s values.

Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).

Parameters:
  • pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
  • strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
  • padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the selection of max values.
Returns:

N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the max value from that window. An output array has the same number of dimensions as its input, but has fewer elements.
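A shape sketch of MaxPool for a typical image-shaped input (sizes illustrative):

import numpy as np
from trax import layers as tl

x = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape((2, 4, 4, 3))   # (batch, 4, 4, channels)

y = tl.MaxPool(pool_size=(2, 2), strides=(2, 2), padding='VALID')(x)
print(y.shape)   # (2, 2, 2, 3): each non-overlapping 2x2 window is reduced to its max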

trax.layers.pooling.SumPool(pool_size=(2, 2), strides=None, padding='VALID')

Reduces each multi-dimensional window to the sum of the window’s values.

Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).

Parameters:
  • pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
  • strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
  • padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the computation of sums.
Returns:

N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the sum of values in that window. An output array has the same number of dimensions as its input, but has fewer elements.

trax.layers.pooling.AvgPool(pool_size=(2, 2), strides=None, padding='VALID')

Reduces each multi-dimensional window to the mean of the window’s values.

Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).

Parameters:
  • pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
  • strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
  • padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed but is not counted in the computation of averages.
Returns:

N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the mean of values in that window. An output array has the same number of dimensions as its input, but has fewer elements.

reversible

Implementations of reversible layers.

class trax.layers.reversible.ReversibleLayer(n_in=1, n_out=1, name=None, sublayers_to_print=None)

Bases: trax.layers.base.Layer

Reversible Layer.

reverse(output, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

reverse_and_grad(output, grad, weights=(), state=(), new_state=(), rng=None)

Backward pass: computes the inverse of a layer and propagates gradients.

While you may choose to only implement reverse, some layers implement this function directly as computation may be shared between reversing and computing gradients.

Parameters:
  • output – Output activations; can be a (possibly nested) tuple.
  • grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
  • weights – layer weights
  • state – start state
  • new_state – updated state computed by the forward pass
  • rng – Single-use random number generator (JAX PRNG key).
Returns:

A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.

has_backward

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng)

Custom backward pass to propagate gradients in a custom way.

Parameters:
  • inputs – Input tensors; can be a (possibly nested) tuple.
  • output – The result of running this layer on inputs.
  • grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
  • weights – This layer’s weights.
  • state – This layer’s state prior to the current forward pass.
  • new_state – This layer’s state after the current forward pass.
  • rng – Single-use random number generator (JAX PRNG key).
Returns:

The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.

class trax.layers.reversible.ReversibleSelect(indices, n_in=None, name=None)

Bases: trax.layers.reversible.ReversibleLayer

Reversible version of the Select combinator.

__init__(indices, n_in=None, name=None)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(outputs, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

trax.layers.reversible.ReversibleSwap()
class trax.layers.reversible.ReversibleSerial(*layers)

Bases: trax.layers.reversible.ReversibleLayer, trax.layers.combinators.Serial

A reversible version of tl.Serial (requires reversible sub-layers).

__init__(*layers)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
reverse(output, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

reverse_and_grad(output, grad, weights=(), state=(), new_state=(), rng=None)

Backward pass: computes the inverse of a layer and propagates gradients.

While you may choose to only implement reverse, some layers implement this function directly as computation may be shared between reversing and computing gradients.

Parameters:
  • output – Output activations; can be a (possibly nested) tuple.
  • grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
  • weights – layer weights
  • state – start state
  • new_state – updated state computed by the forward pass
  • rng – Single-use random number generator (JAX PRNG key).
Returns:

A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.

class trax.layers.reversible.ReversibleHalfResidual(*residual_layers, attention_layer=None)

Bases: trax.layers.reversible.ReversibleLayer

Half of a RevNet-style residual that optionally performs attention.

When attention_layer is None, this layer has the signature

[accumulator, *context] -> [accumulator + f(context), *context]

The attention_layer must be an instance of EfficientAttentionBase or one of its subclasses (see efficient_attention.py), or None.

Attention is special-cased for the following two reasons:

  • LSH attention needs to save bucket assignments from the forward pass to the backward pass, for training stability. This requires special-casing it.
  • We can call attention_layer.forward_and_or_backward to compute its output (needed for inverting a reversible residual layer) while simultaneously performing the backward pass. Sharing computation between these two operations improves training speed.
__init__(*residual_layers, attention_layer=None)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(xs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
reverse(output, weights=(), state=(), new_state=(), rng=None)

Reverse this layer: compute input given output.

reverse_and_grad(output, ct, weights=(), state=(), new_state=(), rng=None)

Backward pass: computes the inverse of a layer and propagates gradients.

While you may choose to only implement reverse, some layers implement this function directly as computation may be shared between reversing and computing gradients.

Parameters:
  • output – Output activations; can be a (possibly nested) tuple.
  • grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
  • weights – layer weights
  • state – start state
  • new_state – updated state computed by the forward pass
  • rng – Single-use random number generator (JAX PRNG key).
Returns:

A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.

init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.

rnn

Implementations of common recurrent neural network cells (RNNs).

class trax.layers.rnn.LSTMCell(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

LSTM Cell.

For a nice overview of the motivation and (i, o, f) gates, see this tutorial: https://colah.github.io/posts/2015-08-Understanding-LSTMs/

See this paper for a description and detailed study of all gate types: https://arxiv.org/pdf/1503.04069.pdf

__init__(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
trax.layers.rnn.MakeZeroState(depth_multiplier=1)

Makes a tensor of zeros shaped like x, but with the length dimension (axis 1) removed.

trax.layers.rnn.LSTM(n_units)

LSTM running on axis 1.
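A sketch of a small sequence model built around LSTM; the layer sizes and surrounding model structure (Embedding, Dense, LogSoftmax) are illustrative assumptions, not prescribed by the API:

import numpy as np
from trax import layers as tl
from trax import shapes

model = tl.Serial(
    tl.Embedding(vocab_size=256, d_feature=32),
    tl.LSTM(n_units=32),      # recurrence runs along axis 1 (the length dimension)
    tl.Dense(10),
    tl.LogSoftmax(),
)

token_ids = np.zeros((8, 20), dtype=np.int32)   # (batch, length)
model.init(shapes.signature(token_ids))
log_probs = model(token_ids)                    # shape (8, 20, 10)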

class trax.layers.rnn.GRUCell(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Builds a traditional GRU cell with dense internal transformations.

Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555

__init__(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance.

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
trax.layers.rnn.GRU(n_units)

GRU running on axis 1.

trax.layers.rnn.ConvGRUCell(n_units, kernel_size=(3, 3))

Builds a convolutional GRU.

Paper: https://arxiv.org/abs/1511.06432.

Parameters:
  • n_units – Number of hidden units
  • kernel_size – Kernel size for convolution
Returns:

A Stax model representing a GRU cell with convolution transforms.

trax.layers.rnn.GeneralGRUCell(candidate_transform, memory_transform_fn=None, gate_nonlinearity=<function Sigmoid>, candidate_nonlinearity=<function Tanh>, dropout_rate_c=0.1, sigmoid_bias=0.5)

Parametrized Gated Recurrent Unit (GRU) cell construction.

GRU update equations for update gate, reset gate, candidate memory, and new state:

\[\begin{split}u_t &= \sigma(U' \times s_{t-1} + B') \\ r_t &= \sigma(U'' \times s_{t-1} + B'') \\ c_t &= \tanh(U \times (r_t \odot s_{t-1}) + B) \\ s_t &= u_t \odot s_{t-1} + (1 - u_t) \odot c_t\end{split}\]

See combinators.Gate for details on the gating function.

Parameters:
  • candidate_transform – Transform to apply inside the Candidate branch. Applied before nonlinearities.
  • memory_transform_fn – Optional transformation on the memory before gating.
  • gate_nonlinearity – Function to use as gate activation; allows trying alternatives to Sigmoid, such as HardSigmoid.
  • candidate_nonlinearity – Nonlinearity to apply after candidate branch; allows trying alternatives to traditional Tanh, such as HardTanh.
  • dropout_rate_c – Amount of dropout on the transform (c) gate. Dropout works best in a GRU when applied exclusively to this branch.
  • sigmoid_bias – Constant to add before sigmoid gates. Generally want to start off with a positive bias.
Returns:

A model representing a GRU cell with specified transforms.

trax.layers.rnn.InnerSRUCell()

The inner (non-parallel) computation of an SRU.

trax.layers.rnn.SRU(n_units, activation=None)

SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

As defined in the paper:

\[\begin{split}y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\ f_t &= \sigma(Wf x_t + bf) \\ r_t &= \sigma(Wr x_t + br) \\ c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\ h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t\end{split}\]

We assume the input has shape [batch, length, depth] and that recurrence happens along the length dimension (axis 1). This returns a single SRU layer; the paper recommends stacking at least two, except when used inside a Transformer.

Parameters:
  • n_units – output depth of the SRU layer.
  • activation – Optional activation function.
Returns:

The SRU layer.
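A usage sketch; as noted above, stacking at least two SRU layers is recommended. The sizes are illustrative, with the output depth kept equal to the input depth:

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.ones((4, 20, 32), dtype=np.float32)   # [batch, length, depth]

sru_stack = tl.Serial(tl.SRU(n_units=32), tl.SRU(n_units=32))
sru_stack.init(shapes.signature(x))
y = sru_stack(x)   # shape (4, 20, 32)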

research.efficient_attention

Attention Layers optimized for efficiency (second-pass implementation).

The approach taken in the first round of efficient attention implementations revealed several limitations, which this code attempts to address:

  1. Simultaneously instantiating queries, keys, and values for all heads can exceed the memory budget. Transformers are typically tuned such that n_heads * d_attention_key == d_model. Since attention involves queries, keys, AND values, the memory to store them can be ~3x the memory needed to store the input activations. Once the O(n^2) dot-product bottleneck is removed – as is the case in all of our efficient attention implementations – this becomes the next critical bottleneck for scaling up Transformer models.
  2. Attention masking is implemented by associating an integer (typically, the sequence position) with each query and key vector, and defining a function to compute attention masks from this information. The standard attention API (attention.py) is unscalable because it instantiates O(n^2)-size attention masks, and the previous efficient implementations (efficient_attention.py) only supported causal masking.
trax.layers.research.efficient_attention.tie_in(x, y)
trax.layers.research.efficient_attention.length_normalized(x, epsilon=1e-06)
trax.layers.research.efficient_attention.hash_vecs(vecs, n_buckets_in, n_hashes, rng)

Hash vectors into buckets.

Parameters:
  • vecs – vectors to hash, a tensor of shape [batch_size, depth]
  • n_buckets_in – an int or a list of ints, number of hash buckets; if it is a list, we do hierarchical hashing as specified by the list
  • n_hashes – number of hashes
  • rng – random generator to use for hashing
Returns:

A pair (buckets, n_buckets) where buckets is a tensor of shape [n_hashes, batch_size] of integers – the hash bucket ids, and n_buckets is an int, the total number of hash buckets, equal to the product of all items in n_buckets_in.

trax.layers.research.efficient_attention.look_adjacent(x, n_chunks_before, n_chunks_after)

Used to implement attention between consecutive chunks.

Parameters:
  • x – array of shape [n_chunks, chunk_len, …]
  • n_chunks_before – Number of previous chunks to attend to.
  • n_chunks_after – Number of subsequent chunks to attend to.
Returns:

array of shape [n_chunks, N * chunk_len, …], where N = (1 + n_chunks_before + n_chunks_after).

trax.layers.research.efficient_attention.mask_self_attention(dots, q_info, kv_info, causal=True, exclude_self=True, masked=False)

Performs masking for self-attention.

trax.layers.research.efficient_attention.attend(q, k=None, v=None, q_chunk_len=None, kv_chunk_len=None, n_chunks_before=0, n_chunks_after=0, mask_fn=None, q_info=None, kv_info=None, dropout=0.0, rng=None)

Dot-product attention, with optional chunking and/or masking.

Parameters:
  • q – Query vectors, shape [q_len, d_qk]
  • k – Key vectors, shape [kv_len, d_qk]; or None
  • v – Value vectors, shape [kv_len, d_v]
  • q_chunk_len – Set to non-zero to enable chunking for query vectors
  • kv_chunk_len – Set to non-zero to enable chunking for key/value vectors
  • n_chunks_before – Number of adjacent previous chunks to attend to
  • n_chunks_after – Number of adjacent subsequent chunks to attend to
  • mask_fn – TODO(kitaev) doc
  • q_info – Query-associated metadata for masking
  • kv_info – Key-associated metadata for masking
  • dropout – Dropout rate
  • rng – RNG for dropout
Returns:

A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and dots_logsumexp has shape [q_len]. The logsumexp of the attention probabilities is useful for combining multiple rounds of attention (as in LSH attention).

trax.layers.research.efficient_attention.apply_broadcasted_dropout(vecs, dropout_rate, rng)

Apply dropout, broadcasted across all but the last dimension of vecs.

trax.layers.research.efficient_attention.permute_via_gather(val, permutation, inverse_permutation, axis=0)

Permutation helper for LSH attention.

trax.layers.research.efficient_attention.permute_via_sort(val, keys, inverse_keys, axis=0)

Permutation helper for LSH attention.

class trax.layers.research.efficient_attention.EfficientAttentionBase(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.base.Layer

Base class for efficient attention.

This is a base class that implements memory-efficient batching for both the forward and backward passes. Subclasses should override create_weights_unbatched, create_state_unbatched, forward_unbatched, and optionally incremental_forward_unbatched to define the actual attention mechanism.

__init__(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)

Constructs an EfficientAttentionBase instance.

Parameters:
  • n_heads – Number of attention heads.
  • n_in – Number of inputs to the layer (default 1).
  • n_parallel_heads

    Number of attention heads to compute in parallel.

    • If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
    • If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
    • If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
    • If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
  • incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
  • predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
  • predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
  • use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
  • use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
create_weights_unbatched(input_signature, rng)
create_state_unbatched(input_signature, rng)
forward_unbatched(*inputs, weights, state)

Perform attention for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

incremental_forward_unbatched(*inputs, q_start, q_len, weights, state)

Perform fast inference for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • q_start – Index along the sequence-length dimension that points to the first input element that should be used as a query (and not just a key).
  • q_len – Number of new query elements in this call to the attention mechanism. This is typically 1 for autoregressive decoding, but may be longer if initializing a language model with a prefix.
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Parameters:inputs – Layer inputs (subclasses may use different inputs)
Returns:A tuple (output, new_state).
use_predict_mem(inputs, state)

Update input cache for fast inference.

has_backward

Returns True if this layer provides its own custom backward pass code.

A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)

Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)

Performs batched forward and/or backward passes.

See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.

Parameters:
  • inputs – inputs to the attention layer
  • weights – weights for the attention layer
  • state – state of the attention layer
  • rng – PRNG key for the layer (shared across all examples and heads)
  • output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
  • compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
  • update_state – bool: whether to return an updated layer state.
Returns:

A tuple (output, new_state, inputs_grad, weights_grad).

  • output is not None iff compute_output is True
  • new_state is not None iff update_state is True
  • inputs_grad & weights_grad are not None iff output_grad is not None

class trax.layers.research.efficient_attention.SelfAttention(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.research.efficient_attention.EfficientAttentionBase

Memory-efficient self-attention (second attempt).

__init__(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Construct a self-attention layer.

Parameters:
  • n_heads – int: Number of attention heads
  • d_qk – int: Depth of query and key vectors
  • d_v – int: Depth of value vectors
  • share_qk – bool: Set to True to share query and key projection weights
  • causal – bool: Set to True to mask out attention to future items
  • masked – bool: Set to True to accept an additional mask argument that allows masking out attention to padding tokens.
  • chunk_len (optional) – Number of tokens per chunk. Setting this option will enable chunked attention.
  • n_chunks_before – Number of previous chunks to attend to, when using chunked attention.
  • n_chunks_after – Number of subsequent chunks to attend to, when using chunked attention. Don’t use this option for causal attention, because attention to future tokens will be masked out anyway. However, note that cross-chunk attention “wraps around” in both directions, so this option is never a strict no-op.
  • bias – bool: Set to True to add bias vectors when computing query/key/value
  • mode – ‘train’, ‘eval’, or ‘predict’
  • predict_mem_len – int: Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten. When chunking is enabled, the default is to store chunk_len * (1 + n_chunks_before) elements.
  • predict_drop_len – int: Number of input elements to drop once the fast inference input cache fills up. When chunking is enabled, the default is to drop exactly chunk_len elements.
  • attention_dropout – Dropout probability for attention mask.
  • output_dropout – Dropout probability for the layer output.
  • n_parallel_heads – see EfficientAttentionBase. This option controls the trade-off between parallelism and memory usage.
  • use_python_loop – For testing/debugging (see EfficientAttentionBase)
  • use_reference_code – For testing/debugging (see EfficientAttentionBase)
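An illustrative SelfAttention configuration (construction only; the layer is normally placed inside a larger model). The specific sizes are assumptions, not recommended defaults:

from trax.layers.research import efficient_attention

# Causal, chunked self-attention: each position attends within its own
# 128-token chunk and to the previous chunk, keeping memory use roughly
# linear in sequence length. n_parallel_heads trades speed for memory.
self_attention = efficient_attention.SelfAttention(
    n_heads=4, d_qk=64, d_v=64,
    causal=True,
    chunk_len=128, n_chunks_before=1, n_chunks_after=0,
    attention_dropout=0.1,
    n_parallel_heads=1,
    mode='train',
)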
create_weights_unbatched(input_signature, rng)
forward_unbatched(x, mask=None, *, weights, state, rng, update_state)

Perform attention for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

incremental_forward_unbatched(x, mask=None, *, q_start, q_len, weights, state, rng, update_state)

Perform fast inference for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • q_start – Index along the sequence-length dimension that points to the first input element that should be used as a query (and not just a key).
  • q_len – Number of new query elements in this call to the attention mechanism. This is typically 1 for autoregressive decoding, but may be longer if initializing a language model with a prefix.
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

class trax.layers.research.efficient_attention.LSHSelfAttention(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=1, use_python_loop=False, use_reference_code=False, max_length_for_buckets=None)

Bases: trax.layers.research.efficient_attention.SelfAttention

LSH self-attention (second implementation).

__init__(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=1, use_python_loop=False, use_reference_code=False, max_length_for_buckets=None)

Construct an LSH self-attention layer.
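An illustrative construction, mirroring the documented signature; the sizes are assumptions:

from trax.layers.research import efficient_attention

# LSH attention hashes queries/keys into buckets and attends within chunks
# of the sorted sequence instead of forming the full O(n^2) attention matrix;
# more hashes (n_hashes) trade extra compute for better attention quality.
lsh_attention = efficient_attention.LSHSelfAttention(
    n_heads=4, d_qk=64, d_v=64,
    causal=True,
    chunk_len=128, n_chunks_before=1, n_chunks_after=0,
    n_hashes=2, n_buckets=64,
    mode='train',
)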

create_state_unbatched(input_signature, rng)
hash_vectors(vecs, rng, mask=None)
forward_unbatched(x, mask=None, *, weights, state, rng, update_state)

Perform attention for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

incremental_forward_unbatched(x, *, q_start, q_len, weights, state, rng, update_state)

Perform fast inference for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • q_start – Index along the sequence-length dimension that points to the first input element that should be used as a query (and not just a key).
  • q_len – Number of new query elements in this call to the attention mechanism. This is typically 1 for autoregressive decoding, but may be longer if initializing a language model with a prefix.
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

class trax.layers.research.efficient_attention.EncDecAttention(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Bases: trax.layers.research.efficient_attention.EfficientAttentionBase

Memory-efficient encoder-decoder attention.

__init__(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)

Constructs an EncDecAttention instance. (The parameter descriptions below are inherited from EfficientAttentionBase.__init__; entries such as n_in, incremental, predict_mem_len, and predict_drop_len are not arguments of this subclass.)

Parameters:
  • n_heads – Number of attention heads.
  • n_in – Number of inputs to the layer (default 1).
  • n_parallel_heads

    Number of attention heads to compute in parallel.

    • If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
    • If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
    • If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
    • If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage. (The sketch after this parameter list shows how these settings map onto the four modes.)
  • incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
  • predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
  • predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
  • use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
  • use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
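
The mapping from n_parallel_heads settings to the modes listed above can be summarized in plain Python; the values below are illustrative for a hypothetical layer with n_heads=8.

    # Illustrative only: which mode each n_parallel_heads setting selects.
    n_heads = 8
    for nph in (None, 1, 4, 8, 16):
        if nph is None:
            desc = 'all heads and all examples in parallel (fastest, most memory)'
        elif nph == 1:
            desc = 'one head of one example at a time (least memory)'
        elif nph % n_heads == 0:
            desc = f'sub-batches of {nph // n_heads} example(s) at a time'
        else:  # 1 < nph < n_heads; n_heads must be a multiple of nph
            desc = f'{nph} heads at a time within a single example'
        print(nph, '->', desc)
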
create_weights_unbatched(input_signature, rng)
forward_unbatched(q_antecedent, kv_antecedent, mask=None, *, weights, state, rng, update_state)

Perform attention for a single batch element and head.

Subclasses should override this method.

Parameters:
  • *inputs – Inputs for a single example (subclasses may use different inputs)
  • weights – Weights for a single attention head
  • state – State for a single example & attention head pair.
Returns:

A tuple (output, new_state) – output and new state for a single example and attention head.

trax.layers.research.efficient_attention.CausalFavor(d_feature, n_heads=1, dropout=0.0, numerical_stabilizer=0.001, precision=None, mode='train')

Returns a layer that maps activations to activations, with causal masking.

Like CausalAttention, this layer type represents one pass of multi-head causal attention, but using FAVOR fast attention as in the following paper: https://arxiv.org/abs/2006.03555

Parameters:
  • d_feature – Depth/dimensionality of feature embedding.
  • n_heads – Number of attention heads.
  • dropout – Probabilistic rate for internal dropout applied to attention activations (based on query-key pairs) before dotting them with values.
  • numerical_stabilizer – Small float used for numerical stability.
  • precision – Passed to np.einsum to define arithmetic precision.
  • mode – One of ‘train’, ‘eval’, or ‘predict’.
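
A hedged usage sketch: CausalFavor as the attention sub-layer of a decoder block, combined with standard trax.layers combinators. The sizes and rates are illustrative only.

    from trax import layers as tl
    from trax.layers.research import efficient_attention as ea

    attention = ea.CausalFavor(d_feature=512, n_heads=8, dropout=0.1, mode='train')
    decoder_block = tl.Serial(
        tl.Residual(tl.LayerNorm(), attention, tl.Dropout(rate=0.1, mode='train')),
    )
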
class trax.layers.research.efficient_attention.LSHFF(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Feed-forward block with LSH.

The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense that takes an input, makes it of size d_ff (usually larger than it was) and then brings it back to the original size after Relu. It is commonly used in Transformer models where it often accounts for most of the trainable weights.

The original block can be slow in decoding due to the need to fetch a lot of weights from memory. After the Relu, many of the d_ff activations are zero, so much of that weight traffic is wasted; the LSH block aims to exploit this sparsity. In the first Dense(d_ff) layer, instead of performing a full matrix multiplication, this block multiplies only by the parts of the weight matrix that have the highest chance of yielding a non-zero output after Relu. This is determined by taking a number of locality-sensitive hashes and masking to include only those weights that share at least one hash with the multiplied element.
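
The numpy sketch below illustrates this idea with a simple random-projection (argmax-bucket) hash as a stand-in; the actual LSHFF layer's hashing, masking, and batching details may differ.

    import numpy as np

    def lsh_bucket(v, projections):
        # One hash round: bucket = index of the random direction most aligned
        # with v.  projections: [n_buckets, d]; v: [..., d] -> bucket id(s).
        return np.argmax(v @ projections.T, axis=-1)

    def lsh_ff_sketch(x, w1, b1, w2, b2, n_hashes=4, n_buckets=16, seed=0):
        """Feed-forward with the first Dense restricted to hash-matching units.

        Shapes: x [d_model], w1 [d_model, d_ff], b1 [d_ff], w2 [d_ff, d_model],
        b2 [d_model].
        """
        rng = np.random.default_rng(seed)
        d_model, d_ff = w1.shape
        keep = np.zeros(d_ff, dtype=bool)
        for _ in range(n_hashes):
            proj = rng.standard_normal((n_buckets, d_model))
            # Keep a hidden unit if, in at least one hash round, its weight
            # column lands in the same bucket as the input x.
            keep |= lsh_bucket(w1.T, proj) == lsh_bucket(x, proj)
        h = np.maximum(x @ (w1 * keep) + b1, 0.0)  # masked Dense(d_ff) + Relu
        return h @ w2 + b2                         # Dense back to d_model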

__init__(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Returns an LSH feed-forward block.

forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor of same shape and dtype as the input.
init_weights_and_state(input_signature)

Randomly initializes this layer’s weights.

class trax.layers.research.efficient_attention.SparseFF(d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.7, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Feed-forward block with sparsity.

The original (non-sparse) feed-forward block is a triple Dense(d_ff)-Relu-Dense that takes an input, makes it of size d_ff (usually larger than it was) and then brings it back to the original size after Relu. It is commonly used in Transformer models where it often accounts for most of the trainable weights.

The original block can be slow in decoding due to the need to fetch a lot of weights from memory. This sparse block only allows one non-zero element in a block of a specified size. This is trained with the straight-through Gumbel-softmax trick.
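
A short JAX sketch of the straight-through Gumbel-softmax selection mentioned above (one active unit per block). It illustrates only the trick itself; the real layer also uses a low-rank controller (d_lowrank) and differs in detail.

    import jax
    import jax.numpy as jnp

    def st_gumbel_block_select(logits, rng_key, temperature=0.7):
        """One active unit per block via straight-through Gumbel-softmax.

        logits: [n_blocks, n_elements_in_block] scores. The forward pass uses
        the hard one-hot selection; gradients flow through the soft softmax.
        """
        u = jax.random.uniform(rng_key, logits.shape, minval=1e-6, maxval=1.0)
        gumbel = -jnp.log(-jnp.log(u))
        soft = jax.nn.softmax((logits + gumbel) / temperature, axis=-1)
        hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), logits.shape[-1])
        return soft + jax.lax.stop_gradient(hard - soft)

Multiplying a hidden activation of shape [n_blocks, n_elements_in_block] by this mask zeroes out all but one element per block.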

__init__(d_ff, n_elements_in_block=32, d_lowrank=64, temperature=0.7, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)

Returns a sparse feed-forward block.

forward(x)

Executes this layer as part of a forward pass through the model.

Parameters:x – Tensor of same shape and dtype as the input signature used to initialize this layer.
Returns:Tensor of same shape and dtype as the input.
init_weights_and_state(input_signature)

Randomly initializes this layer’s weights.

research.position_encodings

Experimenting with position encodings.

class trax.layers.research.position_encodings.AxialPositionalEncoding(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')

Bases: trax.layers.base.Layer

Axial positional encoding.
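
A small numpy sketch of the axial factorization suggested by the shape and d_embs arguments: each axis of a position grid gets its own embedding table, and the tables are broadcast over the other axes and concatenated along the feature dimension. Illustrative only; the real layer also handles dropout, batching, and 'predict' mode.

    import numpy as np

    def axial_position_embeddings(shape=(8, 8), d_embs=(12, 20), seed=0):
        """Return [prod(shape), sum(d_embs)] positional embeddings.

        Only sum(shape[i] * d_embs[i]) weights are needed, instead of
        prod(shape) * sum(d_embs) for a full positional embedding table.
        """
        rng = np.random.default_rng(seed)
        parts = []
        for axis, (length, d) in enumerate(zip(shape, d_embs)):
            table = rng.standard_normal((length, d))   # per-axis embedding
            view = [1] * len(shape) + [d]
            view[axis] = length
            parts.append(np.broadcast_to(table.reshape(view), shape + (d,)))
        emb = np.concatenate(parts, axis=-1)            # [*shape, sum(d_embs)]
        return emb.reshape(-1, emb.shape[-1])           # flatten the grid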

__init__(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')

Creates a partially initialized, unconnected layer instance. (The parameter descriptions below are inherited from the base Layer constructor; this layer's own arguments are those shown in its signature above.)

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.research.position_encodings.FixedBasePositionalEncoding(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=100, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)

Bases: trax.layers.base.Layer

Implements fixed-base positional encoding.
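
The bases and n_digits arguments suggest that each position is represented by its digits in several fixed bases, with each digit then embedded. The numpy sketch below shows only that digit expansion; it is an assumption about the mechanism (with a hypothetical helper name), not the layer's actual code.

    import numpy as np

    def fixed_base_digits(positions, bases=(11, 13, 14, 15), n_digits=8):
        """Return [len(positions), len(bases), n_digits] digit features.

        positions: 1-D integer array of sequence positions. Each position is
        expanded into its n_digits least-significant digits in every base.
        """
        positions = np.asarray(positions, dtype=np.int64)
        digits = np.empty(positions.shape + (len(bases), n_digits), dtype=np.int64)
        for i, base in enumerate(bases):
            p = positions.copy()
            for d in range(n_digits):
                digits[..., i, d] = p % base   # least-significant digit first
                p //= base
        return digits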

__init__(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=100, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)

Creates a partially initialized, unconnected layer instance. (The parameter descriptions below are inherited from the base Layer constructor; this layer's own arguments are those shown in its signature above.)

Parameters:
  • n_in – Number of inputs expected by this layer.
  • n_out – Number of outputs promised by this layer.
  • name – Class-like name for this layer; for use when printing this layer.
  • sublayers_to_print – Sublayers to display when printing out this layer; By default (when None) we display all sublayers.
forward(x)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
trax.layers.research.position_encodings.threefry_2x32_prf(key, x)

Apply the threefry PRF to an array of inputs.

This function is vectorized over x. For threefry_2x32: K = X = uint32[2]

Parameters:
  • key – uint32[2] the key of the PRF
  • x – uint32[…, 2] the inputs
Returns:

y – uint32[…, 2], the outputs
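
A minimal usage sketch based on the shapes documented above; the key and inputs are arbitrary.

    import jax.numpy as jnp
    from trax.layers.research import position_encodings as pe

    key = jnp.array([0, 42], dtype=jnp.uint32)          # uint32[2] PRF key
    x = jnp.arange(8, dtype=jnp.uint32).reshape(4, 2)   # uint32[..., 2] inputs
    y = pe.threefry_2x32_prf(key, x)                    # uint32[4, 2] outputs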

trax.layers.research.position_encodings.threefry_2x32_prange(key, lo: int = 0, hi: int = 2)

Splits a key into a stream of random keys.

This uses the little-endian counter mode.

Parameters:
  • key – uint32[2] the key to split
  • lo – start of the counter range to extract (inclusive)
  • hi – end of the counter range to extract (exclusive)
Returns:

keys – uint32[hi - lo, 2], the split keys
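
Continuing the sketch above, the same key can be expanded into a counter-mode stream of derived keys.

    keys = pe.threefry_2x32_prange(key, lo=0, hi=4)     # uint32[4, 2]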

class trax.layers.research.position_encodings.InfinitePositionalEncoding(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')

Bases: trax.layers.base.Layer

Infinite positional encoding.

__init__(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')

Initializes the encoding.

The encoding tries to roughly evenly traverse the latent space. The recurrence time is dependent on how many bits per dimension you use.

There are two parameters to control randomization:
  • randomizing the origin every 1/drift steps by letting it drift
  • randomizing the origin per call

Parameters:
  • drift – variance in position difference per unit of difference
  • affine – whether to randomize the origin every call
  • transform – learnable transform after encoding (any/diag/none)
  • time_bin_length – Add the features that AxialPositionalEncoding learns, if TimeBinCausalAttention is the first layer; time_bin_length should match TimeBinCausalAttention.bin_length. If you set transform='diag', this flag increases your model capacity to close to transform='any', though it will still train slower.
  • mode – if ‘predict’, allow evaluating one token at a time
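
A hedged construction sketch; the argument values are illustrative only.

    from trax.layers.research import position_encodings as pe

    encoding = pe.InfinitePositionalEncoding(
        drift=0.03,        # variance in position difference per unit of difference
        affine=True,       # re-randomize the origin on every call
        transform='diag',  # learn a diagonal transform after encoding
        time_bin_length=None,
        mode='train',
    )
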
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.
class trax.layers.research.position_encodings.TimeBinPositionalEncoding(time_bin_length, mode='train')

Bases: trax.layers.base.Layer

Just the engineered features from InfinitePositionalEncoding.

num_features = 3
__init__(time_bin_length, mode='train')

Initializes the encoding.

Parameters:
  • time_bin_length – TimeBinCausalAttention.bin_length of the first layer.
  • mode – if ‘predict’, allow evaluating one token at a time
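
A hedged construction sketch; time_bin_length=64 is illustrative and should match the bin_length of the accompanying TimeBinCausalAttention layer.

    from trax.layers.research import position_encodings as pe

    pos_encoding = pe.TimeBinPositionalEncoding(time_bin_length=64, mode='train')
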
forward(inputs)

Computes this layer’s output as part of a forward pass through the model.

Authors of new layer subclasses should override this method to define the forward computation that their layer performs. Use self.weights to access trainable weights of this layer. If you need to use local non-trainable state or randomness, use self.rng for the random seed (no need to set it) and use self.state for non-trainable state (and set it to the new value).

Parameters:inputs – Zero or more input tensors, packaged as described in the Layer class docstring.
Returns:Zero or more output tensors, packaged as described in the Layer class docstring.
init_weights_and_state(input_signature)

Initializes weights and state for inputs with the given signature.

Authors of new layer subclasses should override this method if their layer uses trainable weights or non-trainable state. To initialize trainable weights, set self.weights and to initialize non-trainable state, set self.state to the intended value.

Parameters:input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances; signatures of inputs.