trax.layers
acceleration
Modifications to data and computation to use accelerators (better).

class trax.layers.acceleration.Accelerate(layer, n_devices=None)
Bases: trax.layers.base.Layer
Accelerates a layer, running it in a data-parallel way on multiple devices.
By default it uses all available accelerators, splits the input on the first (batch) axis, and runs each part on the corresponding accelerator. If only one accelerator is available, this layer JIT-compiles the underlying layer, which makes it run faster.
The output is guaranteed to be the same as the output of the original layer if the batch dimension is divisible by the number of devices. If it is not, zero-padding is added to make it divisible, and the output may be affected if the layer relies on computations across the batch, such as batch normalization.
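The zero-padding and splitting described above can be illustrated with a minimal NumPy sketch. The helper pad_and_split is hypothetical and is not a Trax function. It only mimics the idea of padding the batch axis to a multiple of n_devices and giving each device an equal shard.

```python
import numpy as np


def pad_and_split(x, n_devices):
    """Hypothetical helper: zero-pads axis 0 to a multiple of n_devices,
    then reshapes to (n_devices, per_device_batch, ...)."""
    remainder = x.shape[0] % n_devices
    if remainder:
        pad = n_devices - remainder
        # Append `pad` all-zero examples along the batch axis only.
        x = np.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
    return x.reshape((n_devices, -1) + x.shape[1:])


x = np.arange(10, dtype=np.float32).reshape(5, 2)  # batch of 5 examples
shards = pad_and_split(x, n_devices=2)
print(shards.shape)  # (2, 3, 2): one all-zero example was appended
```

The padded examples are ordinary zeros. A layer that mixes information across the batch, such as a batch normalization layer computing batch statistics, would therefore see them. This is why the output can differ when the batch size is not divisible by the number of devices.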
This layer does not require calling init if the underlying layer has already been initialized, so it can be used as follows:
```python
from trax import layers as tl

layer = tl.Serial(...)
layer.init(...)
fast_layer = tl.Accelerate(layer)
y = fast_layer(x)  # Split x on batch and run data-parallel.
```
To set new weights for the wrapped layer, use the replicate_weights method, which sets the sublayer’s weights and replicates them for this layer:
```python
# Instead of layer.weights = new_weights:
fast_layer.replicate_weights(new_weights)
```

__init__(layer, n_devices=None)
Creates an Accelerate layer that wraps layer.
Parameters:  layer – The layer to accelerate.
 n_devices – Number of devices to run on; if None, all available accelerators are used.

sublayer
Returns the unique sublayer managed by this layer.

pure_fn(x, weights, state, rng, use_cache=False)
Calls self.sublayer.pure_fn in an accelerated way.

init(input_signature)
Calls self.sublayer.init and replicates its values onto devices.

replicate_weights(weights)
Sets the weights of the sublayer and replicates them for this layer.

replicate_state(state)
Sets the state of the sublayer and replicates it for this layer.

weights
Returns this layer’s weights.
Depending on the layer, the weights can be in the form of:
 an empty tuple
 a tensor (ndarray)
 a nested structure of tuples and tensors
If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.

state
Returns a tuple containing this layer’s state; may be empty.
If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.


trax.layers.acceleration.mean_or_pmean(n_devices, x, axis=None)
Computes the mean of a distributed value x.
Parameters:  n_devices – Number of devices.
 x – Distributed array.
 axis – Axis along which to compute means; can only be 0 or None.
Returns: A local array.

trax.layers.acceleration.jit_forward(forward, n_devices, do_mean=True)
Returns a JIT-compiled forward function running on n_devices.

trax.layers.acceleration.reshape_by_device(x, n_devices, pure_np=False)
Reshapes possibly nested x into a shape (n_devices, ...).

trax.layers.acceleration.for_n_devices(x, n_devices)
Replicates/broadcasts x for n_devices.
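The reshaping and replication described for reshape_by_device and for_n_devices can be pictured with plain NumPy. Both functions below are hypothetical stand-ins and are not the Trax implementations. This sketch handles only list/tuple nesting, and it assumes that each leaf’s first axis is divisible by n_devices.

```python
import numpy as np


def reshape_leading_axis(x, n_devices):
    """Hypothetical stand-in for reshape_by_device on a nested structure.

    Assumes every leaf's first axis is divisible by n_devices.
    """
    if isinstance(x, (list, tuple)):
        # Recurse into the nesting, preserving list vs. tuple.
        return type(x)(reshape_leading_axis(leaf, n_devices) for leaf in x)
    return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])


def replicate(x, n_devices):
    """Hypothetical stand-in for for_n_devices: adds a leading device axis."""
    return np.broadcast_to(x, (n_devices,) + x.shape)


batch = (np.zeros((8, 5)), np.zeros((8,), dtype=np.int32))  # (inputs, targets)
per_device = reshape_leading_axis(batch, n_devices=4)
weights = replicate(np.ones((5, 3)), n_devices=4)
print([a.shape for a in per_device], weights.shape)
# [(4, 2, 5), (4, 2)] (4, 5, 3)
```

Data is split across devices, and weights are copied to every device. This matches the data-parallel scheme of Accelerate.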

trax.layers.acceleration.on_cpu(x)
Puts x in CPU memory in JAX.

trax.layers.acceleration.on_accelerator(x)
Puts x in (single) accelerator memory in JAX.

activation_fns
Layers that compute activation functions.
An activation layer computes, elementwise, a nonlinear function of the preceding layer’s output. Historically, an activation function was considered part of each node in each layer of the neural network. Trax follows the common current practice of separating the activation function as its own layer, which enables easier experimentation across different activation functions.

trax.layers.activation_fns.Relu()
Returns a layer that computes the Rectified Linear Unit (ReLU) function.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]

trax.layers.activation_fns.ParametricRelu(a=1.0)
Returns a layer that computes a ReLU function with the given slope.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ ax & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters: a – Slope of line for positive inputs.

trax.layers.activation_fns.LeakyRelu(a=0.01)
Returns a ReLU-like layer with linear nonzero outputs for negative inputs.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} ax & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters: a – Slope of line for negative inputs.

trax.layers.activation_fns.Elu(a=1.0)
Returns a ReLU-like layer with exponential outputs for negative inputs.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} a \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ x & \text{otherwise}. \end{array} \right.\end{split}\]
(Asymptotically, \(f(x)\rightarrow -a\) as \(x\rightarrow -\infty\).)
Parameters: a – Coefficient multiplying the exponential, for negative inputs.

trax.layers.activation_fns.Selu(alpha=1.6732632423543772, lmbda=1.0507009873554805)
Returns an Elu-like layer with an additional scaling/slope parameter.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} \lambda \cdot \alpha \cdot (e^x - 1) & \text{if}\ x \leq 0, \\ \lambda \cdot x & \text{otherwise}. \end{array} \right.\end{split}\]
Parameters:  alpha – Coefficient multiplying the exponential, for negative inputs.
 lmbda – Coefficient scaling the whole function.
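The piecewise definitions of Relu, LeakyRelu, Elu, and Selu translate directly into NumPy. These are illustrative re-implementations of the formulas with hypothetical function names, not the Trax layers themselves. Note that np.where evaluates both branches, so np.exp may emit an overflow warning for very large positive inputs.

```python
import numpy as np


def relu(x):
    return np.where(x <= 0, 0.0, x)


def leaky_relu(x, a=0.01):
    return np.where(x <= 0, a * x, x)


def elu(x, a=1.0):
    return np.where(x <= 0, a * (np.exp(x) - 1.0), x)


def selu(x, alpha=1.6732632423543772, lmbda=1.0507009873554805):
    # Selu is Elu with coefficient alpha, scaled by lmbda on both branches.
    return lmbda * np.where(x <= 0, alpha * (np.exp(x) - 1.0), x)


x = np.array([-50.0, -1.0, 0.0, 2.0])
print(relu(x))        # [0. 0. 0. 2.]
print(leaky_relu(x))  # small negative slope a for x <= 0
print(elu(x)[0])      # close to -a = -1.0, the asymptote as x -> -infinity
print(selu(x))
```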

trax.layers.activation_fns.Gelu()
Returns a layer that computes the Gaussian Error Linear Unit function.
\[f(x) = \frac{x}{2} \cdot (1 + \hbox{erf}(\frac{x}{\sqrt{2}}))\]

trax.layers.activation_fns.FastGelu()
Returns a layer that computes a fast approximation to Gelu.
\[f(x) = \frac{x}{2} \cdot (1 + \tanh(ax + abx^3))\]
where \(a = 0.7978845608\) and \(b = 0.044715\).

trax.layers.activation_fns.Sigmoid()
Returns a layer that computes the sigmoid function.
\[f(x) = \frac{1}{1 + e^{-x}}\]

trax.layers.activation_fns.Tanh()
Returns a layer that computes the hyperbolic tangent function.
\[f(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]
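The Sigmoid and Tanh formulas are linked by the identity tanh(x) = 2 sigmoid(2x) - 1. The following standard-library check confirms this numerically at a few points. The helper names are hypothetical. The naive sigmoid shown here raises OverflowError for inputs below roughly -709, so it is not meant for production use.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tanh_from_formula(x):
    return (math.exp(x) - math.exp(-x)) / (math.exp(x) + math.exp(-x))


for x in (-2.0, -0.5, 0.0, 1.5):
    # The explicit formula agrees with math.tanh ...
    assert math.isclose(tanh_from_formula(x), math.tanh(x), abs_tol=1e-12)
    # ... and tanh is a shifted, rescaled sigmoid.
    assert math.isclose(2 * sigmoid(2 * x) - 1, math.tanh(x), abs_tol=1e-12)

print(sigmoid(0.0), tanh_from_formula(0.0))  # 0.5 0.0
```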

trax.layers.activation_fns.HardSigmoid()
Returns a layer that computes a linear approximation to Sigmoid.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} 0 & \text{if}\ x \leq 0, \\ x & \text{if}\ 0 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]

trax.layers.activation_fns.HardTanh()
Returns a layer that computes a linear approximation to Tanh.
\[\begin{split}f(x) = \left\{ \begin{array}{cl} -1 & \text{if}\ x \leq -1, \\ x & \text{if}\ -1 < x < 1, \\ 1 & \text{otherwise}. \end{array} \right.\end{split}\]
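Read as piecewise-linear functions, the HardSigmoid and HardTanh definitions amount to clipping. The following NumPy sketch implements those formulas. The helper names are hypothetical.

```python
import numpy as np


def hard_sigmoid(x):
    # 0 for x <= 0, x on (0, 1), 1 otherwise.
    return np.clip(x, 0.0, 1.0)


def hard_tanh(x):
    # -1 for x <= -1, x on (-1, 1), 1 otherwise.
    return np.clip(x, -1.0, 1.0)


x = np.array([-2.0, -0.5, 0.25, 0.75, 3.0])
print(hard_sigmoid(x).tolist())  # [0.0, 0.0, 0.25, 0.75, 1.0]
print(hard_tanh(x).tolist())     # [-1.0, -0.5, 0.25, 0.75, 1.0]
```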

trax.layers.activation_fns.Softplus()
Returns a layer that computes the softplus function.
\[f(x) = \ln(e^x + 1)\]
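Evaluating ln(e^x + 1) literally overflows for large x. A common remedy is np.logaddexp(x, 0), which computes the same quantity stably. It is shown here as a NumPy sketch and does not describe Trax’s implementation.

```python
import numpy as np


def softplus_naive(x):
    # Direct transcription of ln(e^x + 1); exp overflows for large x.
    return np.log(np.exp(x) + 1.0)


def softplus_stable(x):
    # logaddexp(x, 0) = ln(e^x + e^0), computed without overflow.
    return np.logaddexp(x, 0.0)


x = np.array([-30.0, 0.0, 30.0])
print(np.allclose(softplus_naive(x), softplus_stable(x)))  # True
print(np.isfinite(softplus_stable(np.array([1000.0]))).all())  # True
```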

trax.layers.activation_fns.Exp()
Returns a layer that computes the elementwise exponential of a tensor.

trax.layers.activation_fns.Log()
Returns a layer that computes the elementwise logarithm of a tensor.

trax.layers.activation_fns.Swish()
Returns a layer that computes the Swish function.
\[f(x) = x \cdot \text{sigmoid}(x)\]

trax.layers.activation_fns.Glu()
Returns a layer that computes the Gated Linear Unit function.
\[f(x) = a \cdot \text{sigmoid}(b)\]
where a and b are formed by splitting the input in half along the last axis.
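The following is a NumPy sketch of Glu. It assumes the convention of jax.nn.glu: a is the first half and b is the second half of the input along the split axis, and that axis must have an even size. The helper name is hypothetical.

```python
import numpy as np


def glu(x, axis=-1):
    # Assumption: a is the first half and b the second half along `axis`.
    a, b = np.split(x, 2, axis=axis)
    return a * (1.0 / (1.0 + np.exp(-b)))


x = np.array([[1.0, 2.0, 0.0, 0.0]])  # last axis has even size 4
print(glu(x).tolist())  # [[0.5, 1.0]]: sigmoid(0) = 0.5 halves a = [1, 2]
```

The output is half the size of the input along the split axis.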

class trax.layers.activation_fns.ThresholdedLinearUnit(n_in=1, n_out=1, name=None, sublayers_to_print=None)
Bases: trax.layers.base.Layer
Thresholded Linear Unit, cf. https://arxiv.org/pdf/1911.09737.pdf.

init_weights_and_state(input_signature)
Initializes this layer’s single weight to zero.

forward(inputs)
Executes this layer as part of a forward pass through the model.
Parameters: inputs – Tensor.
Returns: Tensor of same shape and dtype as the input.
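Following the cited paper, a TLU computes max(x, τ) elementwise, where the threshold τ is learned. With τ initialized to zero, as described above, the layer initially behaves like Relu. This NumPy sketch uses a scalar τ. The shape of the actual Trax weight (scalar or per-feature) is not stated here, and the helper name is hypothetical.

```python
import numpy as np


def thresholded_linear_unit(x, tau):
    # Elementwise max(x, tau); tau is the layer's learned threshold.
    return np.maximum(x, tau)


x = np.array([-2.0, -0.5, 0.0, 1.5])
tau = np.zeros(())  # zero initialization, as described for this layer
print(np.array_equal(thresholded_linear_unit(x, tau), np.maximum(x, 0.0)))  # True
print(thresholded_linear_unit(x, np.array(-1.0)).tolist())  # [-1.0, -0.5, 0.0, 1.5]
```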

attention
Attention-related layers, as used in Transformer(-like) models.
Attention is a trainable mechanism for mapping between collections of vectors:
whereas classic neural networks assemble nodes of numbers with weighted connections:
 node activations: floating point values (one float per node)
 inter-node connections: trainable weights (one float per connection),
attention lets one assemble nodes of vectors and use further vectors to calculate connection strengths:
 node activations: floating point vectors, and
 inter-node connections: computed using trainable vectors.
Computing connection strengths involves several concepts – queries, keys, values, masks, attention heads – that factor heavily into the API below.
NOTE: Attention, positional encoding, and shift layers in this module include mode-dependent behavior. The possible modes are:
 'train': in training – dropouts and position shifts active
 'eval': in evals – dropouts inactive, position shifts active
 'predict': in prediction – dropouts and position shifts inactive

trax.layers.attention.Attention(d_feature, n_heads=1, dropout=0.0, mode='train')
Returns a layer that maps (vectors, mask) to (new_vectors, mask).
This layer type represents one pass of multi-head self-attention, from vector set to vector set, using masks to represent out-of-bound (e.g., padding) positions. It:
 makes three copies of incoming activations and maps these to multi-head query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
 for each head, computes the scaled dot product of each Q-K pair;
 applies mask to screen out positions that come from padding tokens (indicated by 0 value);
 [in 'train' mode] applies dropout to Q-K dot products;
 for each head, computes Q-K attention strengths using a per-query softmax of the Q-K dot products;
 for each head, for each query position, combines V vectors according to the Q-K attention strengths; and
 concatenates and fuses resulting per-head vectors into outgoing activations matching original input activation shapes.
Parameters:  d_feature – Last/innermost dimension of activations in the input to and output from this layer.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', or 'predict'.

trax.layers.attention.AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train', cache_KV_in_predict=False, q_sparsity=None, result_sparsity=None)
Returns a layer that maps (AQ, AK, AV, mask) to (newA, mask).
Unlike Attention above, AttentionQKV allows the incoming activations (AQ, AK, and AV) to come from different sources. This is used, for instance, in encoder-decoder attention (Q-related activations AQ from the decoder, K- and V-related activations – AK and AV – from the encoder). Otherwise, see the Attention description for further context/details.
Parameters:  d_feature – Last/innermost dimension of activations in the input to and output from this layer.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', or 'predict'.
 cache_KV_in_predict – Whether to cache K/V arrays in 'predict' mode.
 q_sparsity – Sparsity with which to process queries. If None, Dense is used; if 'noop', no processing is used.
 result_sparsity – Sparsity with which to process the result of the attention. If None, Dense is used; if 'noop', no processing is used.

class trax.layers.attention.PureAttention(n_heads=1, dropout=0.0, mode='train')
Bases: trax.layers.base.Layer
Returns a layer that maps (Q, K, V, mask) to (activations, mask).
This layer type performs the inner workings of one pass of multi-head self-attention. It:
 subdivides incoming Q/K/V activations into multi-head versions;
 for each head, computes the scaled dot product of each Q-K pair;
 applies mask to screen out positions that come from padding tokens (indicated by 0 value);
 [in 'train' mode] applies dropout to Q-K dot products;
 for each head, computes Q-K attention strengths using a per-query softmax of the Q-K dot products;
 for each head, for each query position, combines V vectors according to the Q-K attention strengths; and
 concatenates and fuses resulting per-head vectors into outgoing activations matching original input activation shapes.

__init__(n_heads=1, dropout=0.0, mode='train')
Returns a new PureAttention instance.
Parameters:  n_heads – Number of attention heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', or 'predict'.

forward(inputs)
Returns attention-computed activations and unmodified mask.
Parameters: inputs – A (Q, K, V, mask) tuple, whose query, key, and value activations have not yet been subdivided into heads.

class trax.layers.attention.DotProductAttention(dropout=0.0, mode='train')
Bases: trax.layers.base.Layer
Returns a layer that computes per-head attention (via scaled dot-product).
This layer computes the core of the attention mechanism. Given per-head queries (Q), keys (K), values (V), and mask, it:
 computes the scaled dot product of each Q-K pair;
 applies mask to screen out positions that come from padding tokens (indicated by 0 value);
 [if created in 'train' mode] applies dropout to Q-K dot products;
 computes Q-K attention strengths using a per-query softmax of the Q-K dot products; and
 for each query position, combines V vectors according to the Q-K attention strengths.
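The steps listed above can be sketched in NumPy for per-head inputs in eval mode, with no dropout. The function name is hypothetical. Masked positions are suppressed with a large negative logit (-1e9). This is a common choice, and this sketch does not claim to match Trax’s exact masking constant.

```python
import numpy as np


def dot_product_attention(q, k, v, mask):
    """Per-head attention, eval-mode sketch (no dropout).

    q: (batch, q_len, d); k, v: (batch, kv_len, d);
    mask: broadcastable to (batch, q_len, kv_len); 0/False marks padding.
    """
    d = q.shape[-1]
    # Scaled dot product of each Q-K pair.
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(d)
    # Screen out masked positions before the softmax.
    logits = np.where(mask, logits, -1e9)
    # Per-query softmax over key positions (shifted for numerical stability).
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Combine V vectors according to the attention strengths.
    return weights @ v, weights


rng = np.random.default_rng(0)
q = rng.normal(size=(1, 3, 4))
k = rng.normal(size=(1, 3, 4))
v = rng.normal(size=(1, 3, 4))
mask = np.array([[[1, 1, 0]]], dtype=bool)  # last key position is padding
out, w = dot_product_attention(q, k, v, mask)
print(out.shape)  # (1, 3, 4)
```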

__init__(dropout=0.0, mode='train')
Creates a DotProductAttention instance in a specific mode.
Parameters:  dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 mode – One of 'train', 'eval', 'predict', or 'viz'.

forward(inputs)
Returns attention-computed per-head activations and unchanged mask.
Parameters: inputs – A (Q, K, V, mask) tuple, whose query, key, and value activations have been subdivided into heads.

trax.layers.attention.SplitIntoHeads(n_heads, merged_batch_and_head=True)
Returns a layer that reshapes an array for multi-head computation.

trax.layers.attention.MergeHeads(n_heads, merged_batch_and_head=True)
Returns a layer that rejoins heads, after multi-head computation.
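The following NumPy sketch shows one plausible layout for SplitIntoHeads and MergeHeads with merged_batch_and_head=True. It assumes inputs of shape (batch, seq_len, d_feature) and that each head owns a contiguous slice of the feature axis. The helper names are hypothetical, and merging is taken to be the exact inverse of splitting.

```python
import numpy as np


def split_into_heads(x, n_heads):
    # (batch, seq_len, d_feature) -> (batch * n_heads, seq_len, d_head)
    batch, seq_len, d_feature = x.shape
    if d_feature % n_heads:
        raise ValueError("d_feature must be divisible by n_heads")
    d_head = d_feature // n_heads
    x = x.reshape(batch, seq_len, n_heads, d_head).transpose(0, 2, 1, 3)
    return x.reshape(batch * n_heads, seq_len, d_head)


def merge_heads(x, n_heads):
    # (batch * n_heads, seq_len, d_head) -> (batch, seq_len, n_heads * d_head)
    batch_heads, seq_len, d_head = x.shape
    batch = batch_heads // n_heads
    x = x.reshape(batch, n_heads, seq_len, d_head).transpose(0, 2, 1, 3)
    return x.reshape(batch, seq_len, n_heads * d_head)


x = np.arange(2 * 5 * 8, dtype=np.float32).reshape(2, 5, 8)
heads = split_into_heads(x, n_heads=4)
print(heads.shape)  # (8, 5, 2)
```

Merging the heads back restores the original array, so attention can run independently per head in between.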

trax.layers.attention.ConfigurableAttention(q_layer, k_layer, v_layer, final_layer, qkv_attention_layer, n_heads=1)
Returns a configured multi-head self-attention layer.
A ConfigurableAttention layer acts similarly to Attention layers, but with configurable components. It:
 makes three copies of incoming activations and uses q_layer, k_layer, and v_layer to map activations to multi-head query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
 uses qkv_attention_layer to compute per-head attention, similar to DotProductAttention or DotProductCausalAttention;
 concatenates and fuses resulting per-head vectors into activations matching original input activation shapes; and
 applies a final layer, final_layer, mapping activations to activations (with shape matching the original input activations).
Parameters:  q_layer – Layer that maps input activations to per-head query activations.
 k_layer – Layer that maps input activations to per-head key activations.
 v_layer – Layer that maps input activations to per-head value activations.
 final_layer – After main multi-head computation and rejoining of heads, layer that maps activations to activations (with shape matching the original input activations).
 qkv_attention_layer – Layer that does the core multi-head self-attention computation.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.

trax.layers.attention.CausalAttention(d_feature, n_heads=1, dropout=0.0, max_inference_length=2048, mode='train')
Returns a layer that maps activations to activations, with causal masking.
Like Attention, this layer type represents one pass of multi-head self-attention, but with causal masking rather than padding-based masking.
Parameters:  d_feature – Last/innermost dimension of activations in the input to and output from this layer.
 n_heads – Number of attention heads. Attention heads effectively split activation vectors into n_heads subvectors, of size d_feature / n_heads.
 dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 max_inference_length – Maximum sequence length allowed in non-training modes.
 mode – One of 'train', 'eval', or 'predict'.

class trax.layers.attention.DotProductCausalAttention(dropout=0.0, max_inference_length=2048, mode='train')
Bases: trax.layers.base.Layer
Layer that computes attention strengths by masking out the “future”.
Causal attention uses masking to prevent a given sequence position from attending to positions greater than / following it. This is used, for example, when training autoregressive sequence models, or when decoding a sequence symbol by symbol.
This layer performs the core per-head attention calculation. The layer assumes that any splitting into attention heads precedes it, and that any merging of attention heads will follow it.
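Masking out the “future” amounts to a lower-triangular mask over (query, key) positions. The NumPy sketch below uses a hypothetical helper name. It builds such a mask and applies it to equal raw scores, so each query spreads its attention uniformly over itself and earlier positions.

```python
import numpy as np


def causal_mask(length):
    # True where query position i may attend to key position j, i.e. j <= i.
    return np.tril(np.ones((length, length), dtype=bool))


print(causal_mask(4).astype(int).tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]

scores = np.zeros((4, 4))  # equal raw scores, for illustration only
scores = np.where(causal_mask(4), scores, -np.inf)
weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(weights[2])  # positions 0..2 each get 1/3; position 3 gets 0
```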

__init__(dropout=0.0, max_inference_length=2048, mode='train')
Creates a DotProductCausalAttention instance.
Parameters:  dropout – Probabilistic rate for attention dropout, which overrides (sets to zero) some attention strengths derived from query-key matching. As a result, on a given forward pass, some value vectors don’t contribute to the output, analogous to how regular dropout can cause some node activations to be ignored. Applies only if layer is created in 'train' mode.
 max_inference_length – Maximum sequence length allowed in non-training modes.
 mode – One of 'train', 'eval', or 'predict'.

monkey_patched_mask()

forward(inputs)
Returns attention-computed activations.
Parameters: inputs – A (queries, keys, values) tuple.

init_weights_and_state(input_signature)
Initializes this layer for fast inference, if in 'predict' mode.


trax.layers.attention.ShiftRight(n_positions=1, mode='train')
Returns a layer that can insert padding to shift the input sequence.
Parameters:  n_positions – Number of positions to shift the input sequence rightward; initial positions freed by the shift get padded with zeros. Applies only if layer is created in a non-'predict' mode.
 mode – One of 'train', 'eval', or 'predict'.
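The following NumPy sketch shows the right shift. It assumes batch-major inputs whose axis 1 is the sequence axis. The helper is hypothetical. Following the mode note at the top of this module, it leaves inputs unshifted in 'predict' mode.

```python
import numpy as np


def shift_right(x, n_positions=1, mode='train'):
    """Hypothetical sketch: shift along axis 1, padding freed slots with 0."""
    if mode == 'predict':
        return x  # position shifts are inactive in 'predict' mode
    pad_width = [(0, 0), (n_positions, 0)] + [(0, 0)] * (x.ndim - 2)
    padded = np.pad(x, pad_width)  # zeros prepended on the sequence axis
    return padded[:, :-n_positions]  # drop the overflow at the end


tokens = np.array([[5, 6, 7, 8]])
print(shift_right(tokens))  # [[0 5 6 7]]
```

A typical use is teacher forcing in decoders, where position i sees the target token from position i - 1.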

trax.layers.attention.PaddingMask(pad=0)
Returns a layer that maps integer sequences to padding masks.
The layer expects as input a batch of integer sequences. The layer output is an N-D array that marks for each sequence position whether the integer (e.g., a token ID) in that position represents padding – value pad – versus text/content – all other values. The padding mask shape is (batch_size, 1, 1, encoder_sequence_length), such that axis 1 will broadcast to cover any number of attention heads and axis 2 will broadcast to cover decoder sequence positions.
Parameters: pad – Integer that represents padding rather than a token/content ID.
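The following NumPy sketch builds a padding mask with the shape described above. The helper name is hypothetical. Here True marks content and False marks padding, matching the convention that 0 values in a mask screen out positions. The last lines show how axes 1 and 2 broadcast against a (batch, n_heads, decoder_len, encoder_len) array of attention scores.

```python
import numpy as np


def padding_mask(token_ids, pad=0):
    # True for content, False for padding; shape (batch, 1, 1, seq_len).
    mask = token_ids != pad
    return mask[:, np.newaxis, np.newaxis, :]


ids = np.array([[7, 3, 9, 0, 0],
                [4, 0, 0, 0, 0]])
mask = padding_mask(ids)
print(mask.shape)  # (2, 1, 1, 5)

scores = np.zeros((2, 8, 3, 5))  # (batch, n_heads, decoder_len, encoder_len)
print(np.broadcast_shapes(mask.shape, scores.shape))  # (2, 8, 3, 5)
```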

trax.layers.attention.EncoderDecoderMask()
Returns a layer that creates a mask for encoder-decoder cross attention.
The layer expects two inputs:
 decoder_input: batch of integer (e.g., token ID) sequences
 mask: padding mask from the encoder
The layer output is a mask that marks for each sequence position (for both encoder and decoder) whether that position can be attended to or not. The encoder-decoder mask shape is (batch_size, 1, decoder_sequence_length, encoder_sequence_length), such that axis 1 will automatically broadcast to cover any number of attention heads.
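The following NumPy sketch covers the shape bookkeeping. It assumes the mask depends only on encoder padding and is repeated for every decoder position. The helper name is hypothetical, and this sketch ignores any decoder-side padding.

```python
import numpy as np


def encoder_decoder_mask(decoder_input, encoder_mask):
    """Hypothetical sketch: expand a (batch, 1, 1, enc_len) padding mask to
    (batch, 1, dec_len, enc_len), one identical row per decoder position."""
    batch, dec_len = decoder_input.shape
    enc_len = encoder_mask.shape[-1]
    return np.broadcast_to(encoder_mask, (batch, 1, dec_len, enc_len))


enc_ids = np.array([[7, 3, 0]])
enc_mask = (enc_ids != 0)[:, None, None, :]  # (1, 1, 1, 3); last token is padding
dec_ids = np.array([[1, 2, 3, 4]])
mask = encoder_decoder_mask(dec_ids, enc_mask)
print(mask.shape)  # (1, 1, 4, 3)
```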

class trax.layers.attention.PositionalEncoding(max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2,), use_bfloat16=False, start_from_zero_prob=1.0, max_offset_to_add=0, d_feature=None, mode='train')
Bases: trax.layers.base.Layer
Implements bare positional encoding.
Positional encoding includes a kind of dropout, if the layer is created in 'train' mode with a nonzero dropout value. For such a layer, on each forward pass a subset of sequence positions selected at random will not receive positional marking.

__init__(max_len=2048, dropout=0.0, dropout_broadcast_dims=(-2,), use_bfloat16=False, start_from_zero_prob=1.0, max_offset_to_add=0, d_feature=None, mode='train')
Creates a PositionalEncoding instance in a given mode.
Parameters:  max_len – Maximum input sequence length.
 dropout – Probability of not adding positional encoding to a sequence position. Applies only if layer is created in 'train' mode.
 dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.
 start_from_zero_prob – How often to start from position 0 during training (if 1.0, always start from position 0; if less, randomize the starting position).
 max_offset_to_add – Maximum offset to add to the positions during training when randomizing; this offset plus the input length must still be less than max_len for all training examples.
 d_feature – int or None; have this dimension for embeddings + shared FF if not None.
 mode – One of 'train', 'eval', or 'predict'.

forward(inputs)
Returns the input activations, with added positional information.

init_weights_and_state(input_signature)
Randomly initializes the positional encoding vectors.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.

base
The key layer abstraction (Layer class) and supporting machinery.

class trax.layers.base.Layer(n_in=1, n_out=1, name=None, sublayers_to_print=None)
Bases: object
Base class for composable layers in a deep learning network.
Layers are the basic building blocks for deep learning models. A layer computes a function from zero or more inputs to zero or more outputs, optionally using trainable weights (common) and non-parameter state (not common).
Layer subclasses typically override at most two methods of the base Layer class:
 forward(inputs):
 Computes the layer’s output as part of a forward pass through the model.
 init_weights_and_state(input_signature):
 Initializes the layer’s weights and state to handle input with the given signature (number, shapes and dtypes of input arguments).
A small number of layer types are combinators – they organize the computation of their sublayers, e.g., applying their sublayers in series or in parallel.
All layers have the following properties, with default values implemented in the base Layer class:
 n_in: int (default 1)
 n_out: int (default 1)
 weights: tuple (default empty – the layer has no weights)
 state: tuple (default empty – the layer has no non-parameter state)
 sublayers: tuple (default empty – the layer has no sublayers)
The inputs to a layer are tensors, packaged according to how many there are:
 n_in = 0: an empty tuple
 n_in = 1: one tensor (NOT wrapped in a tuple)
 n_in > 1: a tuple of tensors
(The special treatment of the single-input case is meant to simplify the work of layer writers; this design choice may be revisited in the future.)
The outputs from a layer are also tensors, packaged the same as layer inputs:
 n_out = 0: an empty tuple
 n_out = 1: the tensor (NOT wrapped in a tuple)
 n_out > 1: a tuple of tensors
The Trax runtime maintains a data stack with which layer calls are composed. For more complex data network architectures, possibly involving multiple data flows, one can view each layer as a function from stack state to stack state, where the function’s inputs are a slice from the stack, and the function’s outputs are spliced back into the stack.
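The stack discipline can be modeled in a few lines of plain Python. This is a hypothetical sketch, not the Trax runtime. Each layer is a (function, n_in, n_out) triple, index 0 is the top of the stack, and inputs are passed as positional arguments for simplicity. Outputs follow the packaging rules above: a single output is not wrapped in a tuple.

```python
def run_serial(layers, stack):
    """Hypothetical model of stack-based composition.

    Each layer is (fn, n_in, n_out). fn takes n_in positional arguments and
    returns one value (n_out == 1) or a tuple of n_out values.
    """
    stack = list(stack)
    for fn, n_in, n_out in layers:
        args, stack = stack[:n_in], stack[n_in:]  # slice inputs off the top
        out = fn(*args)
        outputs = [out] if n_out == 1 else list(out)  # unwrap per packaging rule
        stack = outputs + stack  # splice outputs back onto the top
    return stack[0] if len(stack) == 1 else tuple(stack)


layers = [
    (lambda x: (x, x), 1, 2),    # duplicate the top item
    (lambda a, b: a * b, 2, 1),  # consume two items, produce one
    (lambda a, b: a + b, 2, 1),  # add the item that was left underneath
]
print(run_serial(layers, [3, 10]))  # 19, i.e. 3 * 3 + 10
```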

__init__(n_in=1, n_out=1, name=None, sublayers_to_print=None)
Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

__call__(x, weights=None, state=None, rng=None)
Makes layers callable; for use in tests or interactive settings.
This convenience method helps library users play with, test, or otherwise probe the behavior of layers outside of a full training environment. It presents the layer as a callable function from inputs to outputs, with the option of manually specifying weights and non-parameter state per individual call. For convenience, weights and non-parameter state are cached per layer instance, starting from default values of EMPTY_WEIGHTS and EMPTY_STATE, and acquiring non-empty values either by initialization or from values explicitly provided via the weights and state keyword arguments, in which case the old weights will be preserved, and the state will be updated.
Parameters:  x – Zero or more input tensors, packaged as described in the Layer class docstring.
 weights – Weights or None; if None, use self’s cached weights value.
 state – State or None; if None, use self’s cached state value.
 rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

forward(inputs)
Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state(input_signature)
Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

has_backward
Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward(inputs, output, grad, weights, state, new_state, rng)
Custom backward pass to propagate gradients in a custom way.
Parameters:  inputs – Input tensors; can be a (possibly nested) tuple.
 output – The result of running this layer on inputs.
 grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
 weights – This layer’s weights.
 state – This layer’s state prior to the current forward pass.
 new_state – This layer’s state after the current forward pass.
 rng – Single-use random number generator (JAX PRNG key).
Returns: The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.

init(input_signature, rng=None, use_cache=False)
Initializes weights/state of this layer and its sublayers recursively.
Initialization creates layer weights and state, for layers that use them. It derives the necessary array shapes and data types from the layer’s input signature, which is itself just shape and data type information.
For layers without weights or state, this method safely does nothing.
This method is designed to create weights/state only once for each layer instance, even if the same layer instance occurs in multiple places in the network. This enables weight sharing to be implemented as layer sharing.
Parameters:  input_signature – ShapeDtype instance (if this layer takes one input) or list/tuple of ShapeDtype instances.
 rng – Single-use random number generator (JAX PRNG key), or None; if None, use a default computed from an integer 0 seed.
 use_cache – If True, and if this layer instance has already been initialized elsewhere in the network, then return special marker values – tuple (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE). Else return this layer’s newly initialized weights and state.
Returns: A (weights, state) tuple.

init_from_file(file_name, weights_only=False, input_signature=None)
Initializes this layer and its sublayers from a pickled checkpoint.
In the common case (weights_only=False), the file must be a gzipped pickled dictionary containing items with keys ‘flat_weights’, ‘flat_state’, and ‘input_signature’, which are used to initialize this layer. If input_signature is specified, it’s used instead of the one in the file. If weights_only is True, the dictionary does not need to have the ‘flat_state’ item, and the state is not restored either.
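The standard library can illustrate the file layout described for init_from_file: a gzip-compressed pickle of a dictionary with the keys 'flat_weights', 'flat_state', and 'input_signature'. The values below are placeholders. This sketch does not reproduce how Trax flattens weights or encodes the input signature. Pickle files should only be loaded from trusted sources.

```python
import gzip
import os
import pickle
import tempfile

# Placeholder contents; a real checkpoint holds arrays, not small lists.
checkpoint = {
    'flat_weights': [[0.1, 0.2], [0.3]],
    'flat_state': [],
    'input_signature': None,
}

path = os.path.join(tempfile.mkdtemp(), 'model.pkl.gz')
with gzip.open(path, 'wb') as f:  # gzip-compressed ...
    pickle.dump(checkpoint, f)    # ... pickled dictionary

with gzip.open(path, 'rb') as f:
    loaded = pickle.load(f)

print(sorted(loaded))  # ['flat_state', 'flat_weights', 'input_signature']
```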
Parameters:  file_name – Name/path of the pickled weights/state file.
 weights_only – If True, initialize only the layer’s weights. Else initialize both weights and state.
 input_signature – Input signature to be used instead of the one from file.
Returns: A (weights, state) tuple.

save_to_file(file_name, weights_only=False, input_signature=None)
Saves this layer and its sublayers to a pickled checkpoint.
Parameters:  file_name – Name/path of the pickled weights/state file.
 weights_only – If True, save only the layer’s weights. Else save both weights and state.
 input_signature – Input signature to be used.

name
¶ Returns the name of this layer.

n_in
¶ Returns how many tensors this layer expects as input.

n_out
¶ Returns how many tensors this layer promises as output.

sublayers
¶ Returns a tuple containing this layer’s sublayers; may be empty.

weights
¶ Returns this layer’s weights.
Depending on the layer, the weights can be in the form of:
 an empty tuple
 a tensor (ndarray)
 a nested structure of tuples and tensors
If the layer has sublayers, the weights by convention will be a tuple of length len(sublayers) containing the weights of sublayers. Note that in this case self._weights only marks which ones are shared.

state
¶ Returns a tuple containing this layer’s state; may be empty.
If the layer has sublayers, the state by convention will be a tuple of length len(sublayers) containing sublayer states. Note that in this case self._state only marks which ones are shared.

weights_and_state_signature
(input_signature, unsafe=False)¶ Return a pair containing the signatures of weights and state.

rng
¶ Returns this layer’s current single-use random number generator.
Code that wants to base random samples on this generator must explicitly split off new generators from it. (See, for example, the rng setter code below.)

pure_fn
(x, weights, state, rng, use_cache=False)¶ Applies this layer as a pure function with no optional args.
This method exposes the layer’s computation as a pure function. This is especially useful for JIT compilation. Do not override, use forward instead.
Parameters:  x – Zero or more input tensors, packaged as described in the Layer class docstring.
 weights – A tuple or list of trainable weights, with one element for this layer if this layer has no sublayers, or one for each sublayer if this layer has sublayers. If a layer (or sublayer) has no trainable weights, the corresponding weights element is an empty tuple.
 state – Layer-specific non-parameter state that can update between batches.
 rng – Single-use random number generator (JAX PRNG key).
 use_cache – if True, cache weights and state in the layer object; used to implement layer sharing in combinators.
Returns: A tuple of (tensors, state). The tensors match the number (n_out) promised by this layer, and are packaged as described in the Layer class docstring.

output_signature
(input_signature)¶ Returns output signature this layer would give for input_signature.

class
trax.layers.base.
PureLayer
(forward_fn, n_in=1, n_out=1, name='PureLayer')¶ Bases:
trax.layers.base.Layer
Pure function from inputs to outputs, packaged as neural network layer.
The PureLayer class represents the simplest kinds of layers: layers with no trainable weights and no randomness, hence pure functions from inputs to outputs.

__init__
(forward_fn, n_in=1, n_out=1, name='PureLayer')¶ Creates an unconnected PureLayer instance.
Parameters:  forward_fn – Pure function from input tensors to output tensors, where inputs and outputs are packaged as specified for forward.
 n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use only in debugging.

forward
(inputs)¶ Overrides Layer.forward.
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.


trax.layers.base.
Fn
(name, f, n_out=1)¶ Returns a layer with no weights that applies the function f.
f can take and return any number of arguments, and takes only positional arguments – no default or keyword arguments. It often uses JAX-numpy (jnp). The following, for example, would create a layer that takes two inputs and returns two outputs – elementwise sums and maxima:
Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)
The layer’s number of inputs (n_in) is automatically set to the number of positional arguments in f, but you must explicitly set the number of outputs (n_out) whenever it’s not the default value 1.
Parameters:  name – Classlike name for the resulting layer; for use in debugging.
 f – Pure function from input tensors to output tensors, where each input tensor is a separate positional arg, e.g., f(x0, x1) –> x0 + x1. Output tensors must be packaged as specified in the Layer class docstring.
 n_out – Number of outputs promised by the layer; default value 1.
Returns: Layer executing the function f.
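The rule that n_in equals the number of positional arguments in f can be checked with Python’s inspect module. The helper infer_n_in is a hypothetical sketch of that rule, not the code trax uses internally. Python’s built-in max replaces jnp.maximum so that the sketch runs without JAX.

```python
import inspect


def infer_n_in(f):
    """Hypothetical helper: counts f's positional parameters.

    Rejects defaults and keyword-only or variadic parameters, mirroring the
    restriction that f takes only plain positional arguments.
    """
    params = list(inspect.signature(f).parameters.values())
    for p in params:
        if p.default is not inspect.Parameter.empty:
            raise ValueError(f"argument {p.name!r} has a default value")
        if p.kind not in (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD):
            raise ValueError(f"argument {p.name!r} is not a plain positional arg")
    return len(params)


def sum_and_max(x0, x1):
    # Same shape of function as the SumAndMax example above.
    return x0 + x1, max(x0, x1)


n_in = infer_n_in(sum_and_max)  # 2
```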

exception
trax.layers.base.
LayerError
(layer_name, function_name, caller, input_signature, traceback_string)¶ Bases:
Exception
Exception raised in the layer stack.

__init__
(layer_name, function_name, caller, input_signature, traceback_string)¶ Initialize self. See help(type(self)) for accurate signature.

message
¶ Assembles current layer context into an error message.


trax.layers.base.
flatten_weights_and_state
(weights, state)¶ Flatten weights and state into lists, excluding empty and cached ones.

trax.layers.base.
unflatten_weights_and_state
(flat_weights, flat_state, weights_and_state_signature, weights_only=False)¶ Unflatten weights and state given their signatures.

trax.layers.base.
np_to_file
(list_of_nparrays, file_path, compresslevel)¶ Save numpy arrays to file_path with gzipping and failure protection.

trax.layers.base.
np_from_file
(file_path, compresslevel)¶ Load numpy arrays from file_path with gzipping.

trax.layers.base.
to_list
(outputs)¶ Converts layer outputs to a nested list, for easier equality testing.
Parameters: outputs – A tensor or tuple/list of tensors coming from the forward application of a layer. Each tensor is NumPy ndarray-like, which complicates simple equality testing (e.g., via assertEquals): such tensors require equality testing to use either all (all elements match) or any (at least one element matches), which is not directly supported in absltest. Returns: A nested list structure containing all the output values, but now directly testable using assertEquals.
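A minimal NumPy sketch of this conversion. to_nested_list is a hypothetical name, and the real to_list may differ in how it treats deeper nested structures.

```python
import numpy as np


def to_nested_list(outputs):
    """Hypothetical helper: converts an array, or a tuple/list of arrays, to lists."""
    if isinstance(outputs, (list, tuple)):
        return [to_nested_list(x) for x in outputs]
    return np.asarray(outputs).tolist()


outputs = (np.array([1, 2]), np.array([[3], [4]]))
# Plain == on nested lists yields a single bool, usable directly in assertEqual.
converted = to_nested_list(outputs)  # [[1, 2], [[3], [4]]]
```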

trax.layers.base.
shard
(tensors, n_shards=None)¶ Shard tensors across n_shards.

trax.layers.base.
unshard_in_pmap
(tensors, n_shards)¶ Unshard tensors that were sharded into n_shards (call inside pmap).

trax.layers.base.
unshard
(tensors, n_shards=None)¶ Unshard tensors that were sharded into n_shards (outside of pmap).
combinators¶
Combinators for composing layers.

class
trax.layers.combinators.
Serial
(*sublayers, name=None, sublayers_to_print=None)¶ Bases:
trax.layers.base.Layer
Combinator that applies layers serially (by function composition).
This combinator is commonly used to construct deep networks, e.g., like this:
mlp = tl.Serial(
    tl.Dense(128),
    tl.Relu(),
    tl.Dense(10),
)
A Serial combinator uses stack semantics to manage data for its sublayers. Each sublayer sees only the inputs it needs and returns only the outputs it has generated. The sublayers interact via the data stack. For instance, a sublayer k, following sublayer j, gets called with the data stack in the state left after layer j has applied. The Serial combinator then:
 takes n_in items off the top of the stack (n_in = k.n_in) and calls layer k, passing those items as arguments; and
 takes layer k’s n_out return values (n_out = k.n_out) and pushes them onto the data stack.
A Serial instance with no sublayers acts as a special-case (but useful) 1-input, 1-output no-op.
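The stack discipline above can be simulated in plain Python. In this hypothetical toy model, each layer is an (n_in, n_out, fn) triple and index 0 of the list is the top of the stack. It illustrates the semantics only and is not how trax implements Serial.

```python
def run_serial(layers, inputs):
    """Hypothetical helper: runs (n_in, n_out, fn) triples with stack semantics.

    Index 0 of the stack is the top; fn takes n_in positional args and
    returns a tuple of exactly n_out values.
    """
    stack = list(inputs)
    for n_in, n_out, fn in layers:
        args, stack = stack[:n_in], stack[n_in:]  # take n_in items off the top
        outputs = tuple(fn(*args))
        if len(outputs) != n_out:
            raise ValueError(f"expected {n_out} outputs, got {len(outputs)}")
        stack = list(outputs) + stack  # push the n_out results onto the top
    return stack


dup = (1, 2, lambda x: (x, x))
add = (2, 1, lambda a, b: (a + b,))
mul = (2, 1, lambda a, b: (a * b,))
result = run_serial([dup, add, mul], [3, 10])  # [60], i.e. (3 + 3) * 10
```

Items below the n_in a layer consumes are untouched, which is what lets later layers pick up values that earlier layers ignored.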

__init__
(*sublayers, name=None, sublayers_to_print=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

class
trax.layers.combinators.
Parallel
(*sublayers, name=None)¶ Bases:
trax.layers.base.Layer
Combinator that applies a list of layers in parallel to its inputs.
Layers in the list apply to successive spans of inputs, where the spans are determined by how many inputs each layer takes. The resulting output is the (flattened) concatenation of the respective layer outputs.
For example, suppose one has three layers:
 F: 1 input, 1 output
 G: 3 inputs, 1 output
 H: 2 inputs, 2 outputs (h1, h2)
Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:
 inputs: a, b, c, d, e, f
 outputs: F(a), G(b, c, d), h1, h2 where h1, h2 = H(e, f)
As an important special case, a None argument to Parallel acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.) For example:
Parallel(None, F)
creates a layer that passes its first input unchanged and applies F to the following input(s).
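A toy simulation of the span rule, reusing the F, G, H example from the text. Layers are hypothetical (n_in, fn) pairs, and strings record which inputs each layer saw.

```python
def run_parallel(layers, inputs):
    """Hypothetical helper: applies each (n_in, fn) layer to its own input span.

    None stands for a one-input no-op. Outputs are concatenated in order.
    """
    outputs, pos = [], 0
    for layer in layers:
        n_in, fn = (1, lambda x: (x,)) if layer is None else layer
        outputs.extend(fn(*inputs[pos:pos + n_in]))
        pos += n_in
    return outputs


F = (1, lambda a: (f"F({a})",))
G = (3, lambda b, c, d: (f"G({b},{c},{d})",))
H = (2, lambda e, f: (f"h1({e},{f})", f"h2({e},{f})"))
out = run_parallel([F, G, H], ["a", "b", "c", "d", "e", "f"])
# out == ['F(a)', 'G(b,c,d)', 'h1(e,f)', 'h2(e,f)']
```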

__init__
(*sublayers, name=None)¶ The constructor.
Parameters:  *sublayers – A list of sublayers.
 name – Descriptive name for this layer.
Returns: A new layer in which each of the given sublayers applies to its corresponding span of elements in the dataflow stack.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

class
trax.layers.combinators.
Concatenate
(n_items=2, axis=-1)¶ Bases:
trax.layers.base.Layer
Concatenates a number of tensors into a single tensor.
For example:
x = np.array([1, 2])
y = np.array([3, 4])
z = np.array([5, 6])
concat3 = tl.Concatenate(n_items=3)
z = concat3((x, y, z))  # z = [1, 2, 3, 4, 5, 6]
Use the axis argument to specify on which axis to concatenate the tensors. By default it’s the last axis, axis=-1, and n_items=2.

__init__
(n_items=2, axis=-1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.


class
trax.layers.combinators.
Split
(n_items=2, axis=-1)¶ Bases:
trax.layers.base.Layer
Splits the input into n items along an axis.
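A rough NumPy analogue of Split and its inverse, Concatenate, assuming Split cuts its input into n_items equal pieces along the axis the way np.split does with an integer section count. The array is illustrative.

```python
import numpy as np

x = np.arange(12).reshape(2, 6)

# Split into n_items=3 equal pieces along the last axis (axis=-1).
parts = np.split(x, 3, axis=-1)  # three arrays of shape (2, 2)

# Concatenating along the same axis undoes the split.
restored = np.concatenate(parts, axis=-1)
```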

__init__
(n_items=2, axis=-1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.


class
trax.layers.combinators.
Scan
(layer, axis=0, n_carry=1, remat=False, mode='train')¶ Bases:
trax.layers.base.Layer
Applies a layer progressively/cumulatively to an axis-derived sequence.
Conceptually, this is a function from a list to a same-length list of partial (cumulative) results. For instance, a list of values ([1, 2, 3, 4, 5]) can transform to a list of cumulative sums ([1, 3, 6, 10, 15]). Functions for the same concept are called scan in Scala, scanl in Haskell, and accumulate* in Factor.
In more detail, we assume the layer takes a tuple of inputs of the following form:
(input1, …, inputN, carry1, …, carryM)
and returns:
(output1, …, outputK, new_carry1, …, new_carryM)
The scanned version applies the layer iteratively to a tensor, treating values at the given axis as if they were a list. For example, to calculate all sums of prefixes of a tensor, we can do this:
def f(x, carry):
  res = x + carry
  return res, res  # output and new carry are the same
add = tl.Fn('add', f, n_out=2)
# Conceptually, Scan(add) maps the inputs ([1, 2, 3], 0) to ([1, 3, 6], 6).
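The same prefix-sum computation written as an explicit Python loop, as a sketch of the semantics only. The hypothetical scan helper handles a single input and a single carry, whereas the layer handles N inputs and M carries along a tensor axis.

```python
def scan(fn, xs, carry):
    """Hypothetical helper: applies fn(x, carry) -> (output, new_carry) in order."""
    outputs = []
    for x in xs:
        out, carry = fn(x, carry)
        outputs.append(out)
    return outputs, carry  # all partial results plus the final carry


def add(x, carry):
    res = x + carry
    return res, res  # output and new carry are the same


prefix_sums, total = scan(add, [1, 2, 3], 0)  # [1, 3, 6], 6
```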

__init__
(layer, axis=0, n_carry=1, remat=False, mode='train')¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

sublayer
¶ Returns the unique sublayer managed by this layer.

state
¶ Returns a tuple containing this layer’s state.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.


class
trax.layers.combinators.
Cond
(cond, true, false=None, name=None)¶ Bases:
trax.layers.base.Layer
Applies layers conditionally.
Given layers cond, true, and false, this layer runs the equivalent of true(y) if cond(x) else false(y), where x is the first cond.n_in elements from the front of the stack and y is the rest of the stack. Exactly one of the true and false branches is executed, so it can be used to run long computations conditionally. The state of the non-executed branch is not updated. Note that different branches may be executed on different devices if cond returns different values on them. By default, the false branch is the identity.
cond must return exactly one element: a Boolean value. true and false must have the same n_in, and the same n_out.
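A plain-Python sketch of the branch-selection rule. run_cond and its argument conventions are hypothetical, and the sketch ignores weights, state, and devices.

```python
def run_cond(cond, true_fn, false_fn, cond_n_in, stack):
    """Hypothetical helper: runs true_fn(*y) if cond(*x) else false_fn(*y).

    x is the first cond_n_in stack items; y is the rest of the stack.
    A false_fn of None means the identity, as in the default described above.
    """
    x, y = stack[:cond_n_in], stack[cond_n_in:]
    if false_fn is None:
        false_fn = lambda *ys: ys
    branch = true_fn if cond(*x) else false_fn  # only this branch is called
    return tuple(branch(*y))


is_positive = lambda v: v > 0
double = lambda a: (2 * a,)
taken = run_cond(is_positive, double, None, 1, [1, 5])       # (10,)
not_taken = run_cond(is_positive, double, None, 1, [-1, 5])  # (5,)
```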

__init__
(cond, true, false=None, name=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.
Parameters: xs – Tensors as required by the branches of this conditional. Returns: Tensors resulting from running the chosen branch.


trax.layers.combinators.
Chunk
(layer, chunk_size, pass_unchunkable=True)¶ Executes layer using batch chunks of size chunk_size to save memory.
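A minimal NumPy sketch of the chunking idea: fn runs on one chunk at a time, so its intermediate values only ever exist for chunk_size examples, and the per-chunk outputs are concatenated along the batch axis. The helper run_chunked is hypothetical. Its handling of batch sizes not divisible by chunk_size (run on the whole batch when pass_unchunkable is True, otherwise raise) is an assumption inferred from the parameter name.

```python
import numpy as np


def run_chunked(fn, x, chunk_size, pass_unchunkable=True):
    """Hypothetical helper: applies fn to batch chunks and re-joins the results."""
    batch = x.shape[0]
    if batch % chunk_size != 0:
        if not pass_unchunkable:
            raise ValueError("batch size is not divisible by chunk_size")
        return fn(x)  # assumed fallback: run on the whole batch at once
    chunks = [fn(x[i:i + chunk_size]) for i in range(0, batch, chunk_size)]
    return np.concatenate(chunks, axis=0)


x = np.arange(6.0).reshape(6, 1)
y = run_chunked(lambda c: c * 2.0, x, chunk_size=2)  # three chunks of 2
```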

trax.layers.combinators.
Branch
(*layers, name='Branch')¶ Combinator that applies a list of layers in parallel to copies of inputs.
Each layer in the input list is applied to as many inputs from the stack as it needs, and their outputs are successively combined on stack.
For example, suppose one has three layers:
 F: 1 input, 1 output
 G: 3 inputs, 1 output
 H: 2 inputs, 2 outputs (h1, h2)
Then Branch(F, G, H) will take 3 inputs and give 4 outputs:
 inputs: a, b, c
 outputs: F(a), G(a, b, c), h1, h2 where h1, h2 = H(a, b)
As an important special case, a None argument to Branch acts as if it takes one argument, which it leaves unchanged. (It acts as a one-arg no-op.)
Parameters:  *layers – List of layers.
 name – Descriptive name for this layer.
Returns: A branch layer built from the given sublayers.
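A toy simulation of Branch on the same F, G, H example. The helpers are hypothetical. The key difference from Parallel is that every layer reads a copy of the top of the stack instead of its own span, so the branch needs as many inputs as its widest sublayer.

```python
def run_branch(layers, inputs):
    """Hypothetical helper: applies each (n_in, fn) layer to the top n_in inputs.

    None stands for a one-input no-op. Returns the concatenated outputs.
    """
    outputs = []
    for layer in layers:
        n_in, fn = (1, lambda x: (x,)) if layer is None else layer
        outputs.extend(fn(*inputs[:n_in]))  # every layer starts from the top
    return outputs


def branch_n_in(layers):
    """Hypothetical helper: the branch needs as many inputs as its widest layer."""
    return max(1 if layer is None else layer[0] for layer in layers)


F = (1, lambda a: (f"F({a})",))
G = (3, lambda a, b, c: (f"G({a},{b},{c})",))
H = (2, lambda a, b: (f"h1({a},{b})", f"h2({a},{b})"))
out = run_branch([F, G, H], ["a", "b", "c"])
# out == ['F(a)', 'G(a,b,c)', 'h1(a,b)', 'h2(a,b)']
```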

trax.layers.combinators.
Residual
(*layers, shortcut=None)¶ Wraps a series of layers with a residual connection.
Parameters:  *layers – One or more layers, to be applied in series.
 shortcut – If None (the usual case), the Residual layer computes the elementwise sum of the stack-top input with the output of the layer series. If specified, the shortcut layer applies to a copy of the inputs and (elementwise) adds its output to the output from the main layer series.
Returns: A layer representing a residual connection paired with a layer series.
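A NumPy sketch of the two residual variants. residual is a hypothetical helper, and plain functions stand in for layers.

```python
import numpy as np


def residual(layers, x, shortcut=None):
    """Hypothetical helper: elementwise sum of the series output and the skip path."""
    y = x
    for layer in layers:  # apply the layers in series
        y = layer(y)
    skip = x if shortcut is None else shortcut(x)
    return y + skip


x = np.array([1.0, 2.0])
plain = residual([lambda v: 3.0 * v], x)                                 # 3x + x
with_shortcut = residual([lambda v: 3.0 * v], x, shortcut=lambda v: -v)  # 3x - x
```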

trax.layers.combinators.
Select
(indices, n_in=None, name=None)¶ Copies, reorders, or deletes stack elements according to indices.
Parameters:  indices – A list or tuple of 0based indices to select elements relative to the top of the stack.
 n_in – Number of input elements to pop from the stack, and replace with those specified by indices. If not specified, its value will be calculated as max(indices) + 1.
 name – Descriptive name for this layer.
Returns: Tensors, matching the number selected (n_out = len(indices)). Specifically:
 n_out = 0: an empty tuple
 n_out = 1: one tensor (NOT wrapped in a tuple)
 n_out > 1: a tuple of tensors, with n_out items
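A toy model of Select on a list whose index 0 is the stack top. select is a hypothetical helper. In this model, the Dup, Swap, and Drop layers documented below have the same stack effect as the three selections shown.

```python
def select(stack, indices, n_in=None):
    """Hypothetical helper: pops n_in items, pushes stack[i] for each index i.

    Index 0 is the top of the stack; n_in defaults to max(indices) + 1.
    """
    if n_in is None:
        n_in = max(indices) + 1
    chosen = [stack[i] for i in indices]
    return chosen + stack[n_in:]


stack = ["a", "b", "c"]
dup = select(stack, [0, 0])       # ['a', 'a', 'b', 'c']
swap = select(stack, [1, 0])      # ['b', 'a', 'c']
drop = select(stack, [], n_in=1)  # ['b', 'c']
```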

trax.layers.combinators.
Drop
()¶ Drops the top stack element.

trax.layers.combinators.
Dup
()¶ Duplicates (copies) the top element on the data stack.

trax.layers.combinators.
Swap
()¶ Swaps the top two stack elements.

trax.layers.combinators.
SerialWithSideOutputs
(layers, n_side_outputs=1)¶ Serial layer with side outputs.
This layer makes it easier to manage the stack when layers have side outputs.
In the simplest case of layers with n_in=1, n_out=2 and with n_side_outputs=1, this layer runs the following computation on x:
side_outputs = []
for i in range(len(layers)):
  x, side_output = layers[i](x)
  side_outputs.append(side_output)
return [x] + side_outputs
In the general case of layers with variable n_in and n_out and n_side_outputs being a list of N integers, it does the following:
side_outputs = []
for i in range(N):
  res = layers[i](cur_stack)  # remove n_in from stack
  n_keep = layers[i].n_out - n_side_outputs[i]
  cur_stack.append(res[:n_keep])  # put the first n_keep outputs back on stack
  side_outputs.extend(res[n_keep:])  # set the last n_side_outputs[i] aside
return cur_stack + side_outputs
Parameters:  layers – a list of layers to execute
 n_side_outputs – an int or a list of ints, how many outputs of each layer to put aside
Returns: A layer that performs the above computation.

trax.layers.combinators.
FlattenList
()¶ Flatten lists.

trax.layers.combinators.
Add
()¶ Adds two tensors.

trax.layers.combinators.
SubtractTop
()¶ Subtracts the first tensor from the second.

trax.layers.combinators.
Multiply
()¶ Multiplies two tensors.

trax.layers.combinators.
Gate
()¶ Returns a gating layer on a (memory, gate, candidate) tuple.
Final update is memory * gate + (1 - gate) * candidate.
This gating equation may also be referred to as a Highway Network gate. Highway Networks: https://arxiv.org/abs/1505.00387
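The gating equation in NumPy, as a sketch. gate is a hypothetical helper operating elementwise on arrays of matching shape.

```python
import numpy as np


def gate(memory, gate_values, candidate):
    """Hypothetical helper: memory * gate + (1 - gate) * candidate, elementwise."""
    return memory * gate_values + (1.0 - gate_values) * candidate


memory = np.array([1.0, 1.0, 1.0])
candidate = np.array([5.0, 5.0, 5.0])
g = np.array([1.0, 0.0, 0.5])
out = gate(memory, g, candidate)  # [1., 5., 3.]: keep, replace, average
```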

class
trax.layers.combinators.
Cache
(layer)¶ Bases:
trax.layers.base.Layer
Applies a layer on the first run and returns the outputs on next calls.

__init__
(layer)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

sublayer
¶ Returns the unique sublayer managed by this layer.

state
¶ Returns a tuple containing this layer’s state; may be empty.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.
Parameters: inputs – Tensors required by the sublayer. Returns: Tensors resulting from running the sublayer the first time.


class
trax.layers.combinators.
BatchLeadingAxes
(layer, n_last_axes_to_keep=1)¶ Bases:
trax.layers.base.Layer
Applies a layer after flattening all but n_last_axes_to_keep to batch.
This can be used to make layers accept an arbitrary number of leading axes (dimensions) as batch. For example, a Convolution layer may normally only operate on tensors of shape [B, W, H, C]. In this case, the layer
BatchLeadingAxes(Convolution(), n_last_axes_to_keep=3)
will operate on any tensor […, W, H, C] and treat the leading axes as batch.
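A NumPy sketch of the reshape trick, assuming n_last_axes_to_keep >= 1 and a wrapped function that keeps the batch size. batch_leading_axes and last_axis_sum are hypothetical names.

```python
import numpy as np


def batch_leading_axes(fn, x, n_last_axes_to_keep=1):
    """Hypothetical helper: folds leading axes into one batch axis around fn."""
    lead_shape = x.shape[:-n_last_axes_to_keep]
    kept_shape = x.shape[-n_last_axes_to_keep:]
    flat = x.reshape((-1,) + kept_shape)  # [..., *kept] -> [B, *kept]
    out = fn(flat)
    return out.reshape(lead_shape + out.shape[1:])  # [B, *new] -> [..., *new]


seen_shapes = []


def last_axis_sum(batch):
    seen_shapes.append(batch.shape)  # record what the wrapped fn receives
    return batch.sum(axis=-1, keepdims=True)


x = np.ones((2, 3, 4, 5))
y = batch_leading_axes(last_axis_sum, x, n_last_axes_to_keep=2)
```

Here the wrapped function only ever sees a rank-3 batch of shape (6, 4, 5), and the result is restored to shape (2, 3, 4, 1).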

__init__
(layer, n_last_axes_to_keep=1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

sublayer
¶ Returns the unique sublayer managed by this layer.

forward
(inputs)¶ Executes this layer as part of a forward pass through the model.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.


trax.layers.combinators.
Bidirectional
(forward_layer, axis=1, merge_layer=Concatenate_in2)¶ Bidirectional combinator for RNNs.
Parameters:  forward_layer – A layer, such as trax.layers.LSTM or trax.layers.GRU.
 axis – a time axis of the inputs. Default value is 1.
 merge_layer – A combinator used to combine outputs of the forward and backward RNNs. Default value is ‘trax.layers.Concatenate’.
Example
Bidirectional(RNN(n_units=8))
Returns: The Bidirectional combinator for RNNs.

trax.layers.combinators.
inputs_from_stack
(stack, n)¶ Returns n inputs from stack.

trax.layers.combinators.
outputs_onto_stack
(outputs, stack, n)¶ Returns the new stack after removing n items and pushing outputs there.
convolution¶
Trax convolution layers.

class
trax.layers.convolution.
Conv
(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Bases:
trax.layers.base.Layer
Layer constructor function for a general convolution layer.

__init__
(filters, kernel_size, strides=None, padding='VALID', dimension_numbers=('NHWC', 'HWIO', 'NHWC'), kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.convolution.
CausalConv
(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Bases:
trax.layers.convolution.Conv
Causal (masked) convolution for [batch x time x depth] sequences.
Maintains causality along time axis. Used in language modeling tasks.
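A single-channel NumPy sketch of one way to make a convolution causal: left-pad the time axis with kernel_width - 1 zeros, so output step t only sees inputs 0..t. That trax’s CausalConv uses this exact padding is an assumption here, and the real layer works on [batch x time x depth] tensors with learned filters.

```python
import numpy as np


def causal_conv1d(x, kernel):
    """Hypothetical helper: single-channel causal cross-correlation over time.

    Left-pads the sequence with width - 1 zeros so that output[t] only
    depends on x[0..t].
    """
    width = kernel.shape[0]
    padded = np.concatenate([np.zeros(width - 1), x])
    return np.array([padded[t:t + width] @ kernel for t in range(x.shape[0])])


x = np.array([1.0, 2.0, 3.0, 4.0])
shifted = causal_conv1d(x, np.array([1.0, 0.0, 0.0]))  # [0., 0., 1., 2.]
```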

__init__
(filters, kernel_width=3, kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.


trax.layers.convolution.
Conv1d
(filters, kernel_size, stride=1, padding='VALID', kernel_initializer=None, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True)¶
core¶
Core layer types and key functions used by various layers.

class
trax.layers.core.
Dense
(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, use_bfloat16=False)¶ Bases:
trax.layers.base.Layer
A dense (a.k.a. fully-connected, affine) layer.
Dense layers are the prototypical example of a trainable layer, i.e., a layer with trainable weights. Each node in a dense layer computes a weighted sum of all node values from the preceding layer and adds to that sum a node-specific bias term. The full layer computation is expressed compactly in linear algebra as an affine map y = Wx + b, where W is a matrix and y, x, and b are vectors. The layer is trained, or “learns”, by updating the values in W and b.
Less commonly, a dense layer can omit the bias term and be a pure linear map: y = Wx.
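A NumPy sketch of the affine map. With a batch of row vectors, y = Wx + b is computed as x @ w + b, where w has shape (m, n). Storing the kernel in that (input, output) layout is an assumption of this sketch, and dense is a hypothetical helper.

```python
import numpy as np


def dense(x, w, b=None):
    """Hypothetical helper: affine map x @ w + b on the last axis, or linear x @ w."""
    y = x @ w  # x: (..., m), w: (m, n) -> (..., n)
    return y if b is None else y + b


x = np.ones((4, 3))        # batch of 4 vectors in R^3
w = np.ones((3, 2))        # maps R^3 -> R^2 (n_units = 2)
b = np.array([1.0, -1.0])
y = dense(x, w, b)         # shape (4, 2); every row is [4., 2.]
```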

__init__
(n_units, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, use_bfloat16=False)¶ Returns a dense (fully connected) layer of width n_units.
A dense layer maps collections of R^m vectors to R^n, where n (= n_units) is fixed at layer creation time, and m is set at layer initialization time.
Parameters:  n_units – Number of nodes in the layer, also known as the width of the layer.
 kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
 bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
 use_bias – If True, compute an affine map y = Wx + b; else compute a linear map y = Wx.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input, except the final dimension is the layer’s n_units value.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.
Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.


class
trax.layers.core.
Embedding
(vocab_size, d_feature, use_bfloat16=False, kernel_initializer=<function ScaledInitializer.<locals>.Init>)¶ Bases:
trax.layers.base.Layer
Trainable layer that maps discrete tokens/IDs to vectors.
Embedding layers are commonly used to map discrete data, like words in NLP, into vectors. Here is a canonical example:
vocab_size = 5
word_ids = np.array([1, 2, 3, 4], dtype=np.int32)  # word_ids < vocab_size
embedding_layer = tl.Embedding(vocab_size, 32)
embedding_layer.init(trax.shapes.signature(word_ids))
embedded = embedding_layer(word_ids)  # embedded.shape = (4, 32)

__init__
(vocab_size, d_feature, use_bfloat16=False, kernel_initializer=<function ScaledInitializer.<locals>.Init>)¶ Returns an embedding layer with given vocabulary size and vector size.
The layer clips input values (token IDs) to the range [0, vocab_size). That is, negative token IDs all clip to 0 before being mapped to a vector, and token IDs with value vocab_size or greater all clip to vocab_size - 1 before being mapped to a vector.
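A NumPy sketch of this clipping rule. embed is a hypothetical helper, and the small table stands in for the learned embedding weights.

```python
import numpy as np


def embed(token_ids, table):
    """Hypothetical helper: looks up rows of table after clipping IDs."""
    vocab_size = table.shape[0]
    ids = np.clip(token_ids, 0, vocab_size - 1)  # e.g. -3 -> 0, 7 -> 4 for vocab 5
    return table[ids]


table = np.arange(10).reshape(5, 2)  # vocab_size = 5, d_feature = 2
vectors = embed(np.array([-3, 2, 7]), table)
# vectors.tolist() == [[0, 1], [4, 5], [8, 9]]
```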
Parameters:  vocab_size – Size of the input vocabulary. The layer will assign a unique vector to each id in range(vocab_size).
 d_feature – Dimensionality/depth of the output vectors.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.
 kernel_initializer – Function that creates (random) initial vectors for the embedding.

forward
(x)¶ Returns embedding vectors corresponding to input token IDs.
Parameters: x – Tensor of token IDs. Returns: Tensor of embedding vectors.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.


class
trax.layers.core.
Dropout
(rate=0.0, shared_axes=None, mode='train')¶ Bases:
trax.layers.base.Layer
A layer that stochastically ignores a subset of inputs each training step.
In training, to compensate for the fraction of input values dropped (rate), all surviving values are multiplied by 1 / (1 - rate).
The parameter shared_axes specifies a list of axes on which the dropout mask is shared: the mask has size 1 on those axes and is broadcast along them. Sharing reduces randomness, but can save memory.
This layer is active only during training (mode='train'). In other circumstances it is a no-op.
Originally introduced in the paper “Dropout: A Simple Way to Prevent Neural Networks from Overfitting” available under the following link: https://www.cs.toronto.edu/~hinton/absps/JMLRdropout.pdf
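A NumPy sketch of inverted dropout with an optional shared mask. The helper dropout and its NumPy Generator argument are assumptions standing in for the layer and its JAX PRNG key, and shared_axes holds non-negative axis indices here.

```python
import numpy as np


def dropout(x, rate, rng, shared_axes=(), mode="train"):
    """Hypothetical helper: inverted dropout, mask broadcast along shared_axes."""
    if mode != "train" or rate == 0.0:
        return x  # no-op outside training
    mask_shape = [1 if axis in shared_axes else dim
                  for axis, dim in enumerate(x.shape)]
    keep = rng.random(mask_shape) >= rate  # keep with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)  # survivors scaled by 1/(1 - rate)


rng = np.random.default_rng(0)
x = np.ones((4, 6))
y = dropout(x, 0.5, rng)                           # entries are 0.0 or 2.0
y_shared = dropout(x, 0.5, rng, shared_axes=(0,))  # same mask for every row
```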

__init__
(rate=0.0, shared_axes=None, mode='train')¶ Creates a dropout layer with the given target drop rate.
Parameters:  rate – Stochastic rate (probability) for dropping an activation value from the preceding layer (setting it to zero).
 shared_axes – List of axes on which the mask is shared.
 mode – If ‘train’, this layer will perform dropout; else, it will pass all values through unaltered.

init_weights_and_state
(input_signature)¶ Sets layer-specific internal state.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of activations. Returns: Tensor of same shape and dtype as the input.


class
trax.layers.core.
Weights
(initializer, shape=(), use_bfloat16=False)¶ Bases:
trax.layers.base.Layer
Learnable weights as a layer.
It takes no input and returns a single tensor: weights.

__init__
(initializer, shape=(), use_bfloat16=False)¶ Returns a learnable tensor of shape shape.
Parameters:  initializer – Function taking shape and rng as arguments.
 shape – Shape of the learnable weights.
 use_bfloat16 – If True, use bfloat16 weights instead of the default float32; this can save memory but may (rarely) lead to numerical issues.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor with previously specified shape and dtype.

init_weights_and_state
(input_signature)¶ Returns newly initialized weights for this layer.
Weights is a single w tensor with previously specified shape.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.


trax.layers.core.
PrintShape
(n_in=1, msg='')¶ Prints the shapes of n_in inputs and returns them unchanged.

class
trax.layers.core.
SummaryImage
(name, n_in, num_summaries=5, recover_fn=None)¶ Bases:
trax.layers.base.Layer
A layer receiving a tensor, and adding it to TensorBoard as an image.
It takes an input and returns it unchanged. It stores this input as a state to be used as a metric in TensorBoard. It converts a tensor to a scalar by running a given aggregation function (mean by default). On TensorBoard, results for each device will be reported separately.

__init__
(name, n_in, num_summaries=5, recover_fn=None)¶ Takes a tensor and returns it.
Parameters:  name – Name of the metric to be reported.
 n_in – Number of inputs.
 num_summaries – Number of images to show.
 recover_fn – Function for converting a tensor to a displayable image.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor with previously specified shape and dtype.

init_weights_and_state
(input_signature)¶ Returns newly initialized weights for this layer.
Weights is a single w tensor with previously specified shape.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.


class
trax.layers.core.
SummaryScalar
(name, aggregation_fun=jnp.mean)¶ Bases:
trax.layers.base.Layer
A layer receiving a tensor, and adding it to TensorBoard as a scalar.
It takes an input and returns it unchanged. It stores this input as a state to be used as a metric in TensorBoard. It converts a tensor to a scalar by running a given aggregation function (mean by default). On TensorBoard, results for each device will be reported separately.

__init__
(name, aggregation_fun=jnp.mean)¶ Takes a tensor and returns it.
Parameters:  name – Name of the metric to be reported.
 aggregation_fun – Aggregation function to be used.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor with previously specified shape and dtype.

init_weights_and_state
(input_signature)¶ Returns newly initialized weights for this layer.
Weights is a single w tensor with previously specified shape.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on. Unused.


class
trax.layers.core.
RandomUniform
(min_val=0.0, max_val=1.0, shape=(), dtype=jnp.float32, sync=False)¶ Bases:
trax.layers.base.Layer
Layer returning a tensor with random values distributed uniformly.

__init__
(min_val=0.0, max_val=1.0, shape=(), dtype=jnp.float32, sync=False)¶ Layer returning a tensor with random values distributed uniformly.
Parameters:  min_val – Lower end of uniform distribution.
 max_val – Upper end of uniform distribution.
 shape – Shape of the tensor to return. Values are sampled independently.
 dtype – Type of value to return.
 sync – Whether to synchronise rng across devices.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.
Parameters: xs – Unused tensors. Returns: Random uniform tensor of the shape and type specified in constructor.


class
trax.layers.core.
LocallyConnected1d
(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID')¶ Bases:
trax.layers.base.Layer
Locally-connected layer for 1D inputs.
The LocallyConnected1d layer applies a different set of filters to each patch of the input. This is similar to a convolution layer, except that a locally-connected layer uses a different set of weights for each patch.
The size of each patch is determined by the kernel size. The stride is currently fixed at one. This means that for an input of shape (…, L, D), the output shape is (…, L, filters) for paddings ‘SAME’ and ‘WRAP’, and (…, L-kernel_size+1, filters) for padding ‘VALID’, where L is the number of “pixels” or “steps” in the input and D is the size of the embedding.
Note that, since the weights for different patches are not shared, the number of “pixels” or “steps” cannot change after calling init_weights_and_state. This is because each “pixel” is assigned its own set of weights.
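The ‘VALID’ shape rule can be illustrated with a minimal pure-Python sketch. It assumes a single batch item given as nested lists and no bias. It also assumes a hypothetical weight layout of one kernel_size × D × filters kernel per output position. The function name is hypothetical, and this is not the Trax implementation.

```python
def locally_connected_1d_valid(x, kernels):
    """Hypothetical sketch of a locally-connected 1D layer with 'VALID' padding.

    x: list of L input steps, each a list of D floats.
    kernels: one kernel per output position; each kernel is a nested list
      of shape (kernel_size, D, filters).
    Returns L - kernel_size + 1 output steps, each a list of `filters` floats.
    """
    kernel_size = len(kernels[0])
    n_out = len(x) - kernel_size + 1
    # Unlike a convolution, every output position needs its own kernel.
    if len(kernels) != n_out:
        raise ValueError(f"expected {n_out} kernels, got {len(kernels)}")
    outputs = []
    for pos, kernel in enumerate(kernels):
        n_filters = len(kernel[0][0])
        y = [0.0] * n_filters
        # The patch for this position covers steps pos .. pos + kernel_size - 1 (stride 1).
        for offset in range(kernel_size):
            for d, value in enumerate(x[pos + offset]):
                for f in range(n_filters):
                    y[f] += value * kernel[offset][d][f]
        outputs.append(y)
    return outputs


# L = 5, D = 2, kernel_size = 3, filters = 1, so the output has 5 - 3 + 1 = 3 steps.
x = [[1.0, 1.0]] * 5
# Position p uses a kernel filled with the constant p + 1.
kernels = [[[[float(p + 1)]] * 2] * 3 for p in range(3)]
print(locally_connected_1d_valid(x, kernels))  # [[6.0], [12.0], [18.0]]
```

The kernel count check encodes the point made above: once the weights exist, the number of input steps is fixed.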

__init__
(filters, kernel_size, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, use_bias=True, padding='VALID')¶ Returns a locally-connected conv-like layer.
Parameters:  filters – Number of output filters in the convolution.
 kernel_size – Length of the convolution window; must be an odd number.
 kernel_initializer – Function that creates a matrix of (random) initial connection weights W for the layer.
 bias_initializer – Function that creates a vector of (random) initial bias weights b for the layer.
 use_bias – If True, the layer uses a bias vector.
 padding – The type of padding to use; must be ‘VALID’, ‘SAME’, or ‘WRAP’.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input, except that the final dimension is the layer’s filters value, and the second-to-last dimension is shrunk if ‘VALID’ padding is used with kernel_size bigger than one.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.
Weights are a (w, b) tuple for layers created with use_bias=True (the default case), or a w tensor for layers created with use_bias=False.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.


trax.layers.core.
Flatten
(n_axes_to_keep=1)¶ Returns a layer that combines one or more trailing axes of a tensor.
Flattening keeps all the values of the input tensor, but reshapes it by collapsing one or more trailing axes into a single axis. For example, a Flatten(n_axes_to_keep=2) layer would map a tensor with shape (2, 3, 5, 7, 11) to the same values with shape (2, 3, 385).
Parameters: n_axes_to_keep – Number of leading axes to leave unchanged when reshaping; collapse only the axes after these.
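Here is a small sketch of the shape arithmetic, using a hypothetical helper `flattened_shape`. It computes only the output shape, not the reshaped values. Rejecting an n_axes_to_keep value that leaves no axes to collapse is a choice made for this sketch.

```python
from math import prod


def flattened_shape(shape, n_axes_to_keep=1):
    """Shape after collapsing every axis past the first n_axes_to_keep into one axis."""
    if not 0 <= n_axes_to_keep < len(shape):
        raise ValueError("n_axes_to_keep must be in [0, rank - 1]")
    kept = tuple(shape[:n_axes_to_keep])
    # The collapsed axis holds all trailing values, so its size is their product.
    return kept + (prod(shape[n_axes_to_keep:]),)


print(flattened_shape((2, 3, 5, 7, 11), n_axes_to_keep=2))  # (2, 3, 385)
```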

trax.layers.core.
LogSoftmax
(axis=-1)¶ Returns a layer that applies log softmax along one tensor axis.
Note that the implementation actually computes x - LogSumExp(x), which is mathematically equal to LogSoftmax(x).
LogSoftmax acts on a group of values and normalizes them to look like a set of log probability values. (Probability values must be non-negative, and as a set must sum to 1. A group of log probability values can be seen as the natural logarithm function applied to a set of probability values.)
Parameters: axis – Axis along which values are grouped for computing log softmax.
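The x - LogSumExp(x) identity is usually paired with the max-subtraction trick, which keeps the computation numerically stable. Below is a minimal pure-Python sketch for a single vector. The axis argument is omitted and the names are hypothetical; it illustrates the math, not Trax's code.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over a list of floats."""
    m = max(xs)
    # Subtracting the max keeps exp() from overflowing; it is added back outside the log.
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_softmax_vector(xs):
    """Log softmax of one vector, computed as x - LogSumExp(x)."""
    lse = logsumexp(xs)
    return [x - lse for x in xs]


print(log_softmax_vector([1000.0, 1000.0]))  # both entries equal -log(2), no overflow
```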

trax.layers.core.
LogSumExp
(axis=-1)¶ Returns a layer that computes log(sum(exp(x))) along one tensor axis.
Parameters: axis – Axis along which values are grouped for computing log-sum-exp.

trax.layers.core.
Softmax
(axis=-1)¶ Returns a layer that applies softmax along one tensor axis.
Softmax acts on a group of values and normalizes them to look like a set of probability values. (Probability values must be non-negative, and as a set must sum to 1.)
Parameters: axis – Axis along which values are grouped for computing softmax.

trax.layers.core.
ToFloat
()¶ Returns a layer that changes the dtype of a tensor to float32.

trax.layers.core.
Mean
(axis=-1, keepdims=False)¶ Returns a layer that computes mean values using one tensor axis.
Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group. The resulting values can either remain in their own size-1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.
Parameters:  axis – Axis along which values are grouped for computing a mean.
 keepdims – If True, keep the resulting size-1 axis as a separate tensor axis; else, remove that axis.

trax.layers.core.
Min
(axis=-1, keepdims=False)¶ Returns a layer that applies min along one tensor axis.
Parameters:  axis – Axis along which values are grouped for computing minimum.
 keepdims – If True, keep the resulting size-1 axis as a separate tensor axis; else, remove that axis.

trax.layers.core.
Max
(axis=-1, keepdims=False)¶ Returns a layer that applies max along one tensor axis.
Parameters:  axis – Axis along which values are grouped for computing maximum.
 keepdims – If True, keep the resulting size-1 axis as a separate tensor axis; else, remove that axis.

trax.layers.core.
Sum
(axis=None, keepdims=False)¶ Returns a layer that computes sums using one tensor axis.
Sum uses one tensor axis to form groups of values and replaces each group with the sum of that group. The resulting sum values can either remain in their own size-1 axis (keepdims=True), or that axis can be removed from the overall tensor (default keepdims=False), lowering the rank of the tensor by one.
Parameters:  axis – Axis along which values are grouped for computing a sum; if None, compute sum over all elements in tensor.
 keepdims – If True, keep the resulting size-1 axis as a separate tensor axis; else, remove that axis.
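The effect of keepdims is easiest to see from the shape arithmetic alone. The rule is the same for Mean, Min, Max, and Sum. This sketch uses a hypothetical helper `reduced_shape`, which takes one integer axis (negative values count from the end) or None for all axes.

```python
def reduced_shape(shape, axis, keepdims=False):
    """Output shape of a reduction over one axis (or over all axes if axis is None)."""
    rank = len(shape)
    # Negative axes count from the end, as in NumPy-style APIs.
    reduced = set(range(rank)) if axis is None else {axis % rank}
    if keepdims:
        # Each reduced axis stays in place with size 1.
        return tuple(1 if i in reduced else d for i, d in enumerate(shape))
    # Each reduced axis is dropped from the shape.
    return tuple(d for i, d in enumerate(shape) if i not in reduced)


print(reduced_shape((2, 3, 4), axis=-1))                 # (2, 3)
print(reduced_shape((2, 3, 4), axis=-1, keepdims=True))  # (2, 3, 1)
print(reduced_shape((2, 3, 4), axis=None))               # ()
```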

trax.layers.core.
ThresholdToBinary
(threshold=0.5)¶ Returns a layer that thresholds inputs to yield outputs in {0, 1}.

trax.layers.core.
ArgMax
(axis=-1)¶ Returns a layer that calculates argmax along the given axis.

trax.layers.core.
Negate
()¶ Returns a layer that computes the elementwise negation of a tensor.

trax.layers.core.
StopGradient
()¶ Returns an identity layer with a stop gradient.

trax.layers.core.
one_hot
(x, n_categories, dtype=jnp.float32)¶ Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).
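This is a pure-Python sketch of the one-hot conversion on nested lists. It shows how an n-dimensional integer input becomes an (n+1)-dimensional output. The name `one_hot_lists` is hypothetical, the dtype argument is omitted, and floats are always returned.

```python
def one_hot_lists(x, n_categories):
    """One-hot encodes an int, or nested lists of ints, adding one trailing dimension."""
    if isinstance(x, int):
        return [1.0 if i == x else 0.0 for i in range(n_categories)]
    # Recurse so that an n-dimensional input yields an (n+1)-dimensional output.
    return [one_hot_lists(item, n_categories) for item in x]


print(one_hot_lists([0, 2], 3))  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```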

trax.layers.core.
log_softmax
(x, axis=-1)¶ Transforms activation vectors to log-probability vectors.
Log-probability vectors are derived by, in effect, applying softmax to raw activation vectors and then applying log element-wise. The actual implementation uses a mathematically valid simplification of this.
Parameters:  x – An ndarray with activation vectors along the given axis.
 axis – Axis along which values are grouped for computing log softmax.
Returns: An ndarray containing log-probability vectors derived from the raw activation vectors in x.

trax.layers.core.
log_gaussian_pdf
(x, mu, sigma)¶ Returns log N(x | mu, sigma).
Parameters:  x – <tbd>
 mu – <tbd>
 sigma – <tbd>
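For reference, here is the standard log-density of a \(d\)-dimensional Gaussian. The parameters above are still marked <tbd>, so this reading is an assumption: sigma is taken to be the covariance matrix \(\Sigma\), and \(d\) the size of the last axis of mu.
\[\log \mathcal{N}(x \mid \mu, \Sigma) = -\frac{1}{2}\left[d \log(2\pi) + \log\det\Sigma + (x - \mu)^\top \Sigma^{-1} (x - \mu)\right]\]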

trax.layers.core.
log_gaussian_diag_pdf
(x, mu, diag_sigma)¶ Returns log N(x | mu, eye(diag_sigma)).
Parameters:  x – <tbd>
 mu – <tbd>
 diag_sigma – <tbd>
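With a diagonal covariance, the standard multivariate Gaussian log-density factorizes into a sum of per-dimension terms. This is a pure-Python sketch. It assumes that diag_sigma holds the diagonal entries (variances) of the covariance matrix, which is one reading of eye(diag_sigma); the function name is hypothetical.

```python
import math


def log_gaussian_diag_pdf_sketch(x, mu, diag_sigma):
    """log N(x | mu, Sigma) for a diagonal Sigma whose diagonal entries are diag_sigma."""
    d = len(mu)
    # log det(Sigma) is the sum of the logs of the diagonal entries.
    log_det = sum(math.log(s) for s in diag_sigma)
    # (x - mu)^T Sigma^{-1} (x - mu) becomes a sum of per-dimension terms.
    quad = sum((xi - mi) ** 2 / s for xi, mi, s in zip(x, mu, diag_sigma))
    return -0.5 * (d * math.log(2.0 * math.pi) + log_det + quad)


print(log_gaussian_diag_pdf_sketch([0.0], [0.0], [1.0]))  # -0.5 * log(2 * pi)
```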

trax.layers.core.
multigaussian_loss
(preds, targets, ngauss=1)¶ Returns a mixture-of-Gaussians loss.
Parameters:  preds – <tbd>
 targets – <tbd>
 ngauss – <tbd>

trax.layers.core.
logsoftmax_sample
(log_probs, temperature=1.0)¶ Returns a sample from a log-softmax output, with temperature.
Parameters:  log_probs – Logarithms of probabilities (often coming from LogSoftmax).
 temperature – For scaling before sampling (1.0 = default, 0.0 = pick argmax).
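One common way to sample from log-probabilities with a temperature is the Gumbel-max trick. It adds independent Gumbel noise to log_probs / temperature and takes the argmax. The sketch below illustrates that technique in pure Python, under the hypothetical name `sample_from_log_probs`. It is not necessarily how Trax implements logsoftmax_sample.

```python
import math
import random


def sample_from_log_probs(log_probs, temperature=1.0, rng=random):
    """Draws an index from softmax(log_probs / temperature) using the Gumbel-max trick."""
    if temperature == 0.0:
        # Zero temperature: deterministic argmax.
        return max(range(len(log_probs)), key=lambda i: log_probs[i])
    scores = []
    for lp in log_probs:
        # Keep u away from 0 and 1 so that -log(-log(u)) stays finite.
        u = rng.uniform(1e-6, 1.0 - 1e-6)
        scores.append(lp / temperature - math.log(-math.log(u)))
    return max(range(len(scores)), key=lambda i: scores[i])


rng = random.Random(0)
draws = [sample_from_log_probs([math.log(0.7), math.log(0.3)], rng=rng) for _ in range(10_000)]
print(draws.count(0) / len(draws))  # roughly 0.7
```

Temperatures above 1.0 flatten the distribution and values below 1.0 sharpen it. The limit 0.0 is handled separately as an argmax.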
initializers¶
Trax initializers.

trax.layers.initializers.
InitializerFromFile
(path)¶ Loads parameters from .npy file.

trax.layers.initializers.
RandomNormalInitializer
(stddev=0.01)¶ Returns an initializer for random normal coefficients.

trax.layers.initializers.
RandomUniformInitializer
(lim=1.0)¶ Returns an initializer for random uniform coefficients.

trax.layers.initializers.
ScaledInitializer
(out_dim, in_dim, scale, mode, distribution)¶ Returns an initializer that adjusts its scale based on weight shapes.

trax.layers.initializers.
GlorotNormalInitializer
(out_dim=-1, in_dim=-2, scale=1.0)¶ Returns an initializer for random Glorot-scaled coefficients.

trax.layers.initializers.
GlorotUniformInitializer
(out_dim=-1, in_dim=-2, scale=1.0)¶ Returns an initializer for random uniform Glorot-scaled coefficients.
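As a sketch of what “Glorot-scaled” usually means, this uses the standard Glorot/Xavier uniform rule. Weights are drawn from U(-a, a), with a chosen so that the variance equals scale divided by the mean of fan_in and fan_out. The sketch assumes a 2-D weight shape with the input axis at position -2 and the output axis at position -1. How Trax handles higher-rank shapes (for example conv kernels) is not shown, and the function names are hypothetical.

```python
import math
import random


def glorot_uniform_limit(fan_in, fan_out, scale=1.0):
    """Half-width a of U(-a, a) with variance scale / ((fan_in + fan_out) / 2)."""
    # Var(U(-a, a)) = a**2 / 3, so a = sqrt(3 * scale / ((fan_in + fan_out) / 2)).
    return math.sqrt(6.0 * scale / (fan_in + fan_out))


def glorot_uniform_matrix(shape, rng=random, scale=1.0):
    """Samples a (fan_in, fan_out) matrix: in_dim is axis -2, out_dim is axis -1."""
    fan_in, fan_out = shape
    a = glorot_uniform_limit(fan_in, fan_out, scale)
    return [[rng.uniform(-a, a) for _ in range(fan_out)] for _ in range(fan_in)]


w = glorot_uniform_matrix((3, 5), rng=random.Random(0))
print(glorot_uniform_limit(3, 5))  # sqrt(6 / 8), about 0.866
```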

trax.layers.initializers.
LeCunNormalInitializer
(out_dim=-1, in_dim=-2, scale=1.0)¶ Returns an initializer for random LeCun-scaled coefficients.

trax.layers.initializers.
LeCunUniformInitializer
(out_dim=-1, in_dim=-2, scale=1.0)¶ Returns an initializer for random uniform LeCun-scaled coefficients.

trax.layers.initializers.
KaimingNormalInitializer
(out_dim=-1, in_dim=-2, param=0.0)¶ Returns an initializer for random Kaiming-scaled coefficients.

trax.layers.initializers.
KaimingUniformInitializer
(out_dim=-1, in_dim=-2, param=0.0)¶ Returns an initializer for random uniform Kaiming-scaled coefficients.

trax.layers.initializers.
OrthogonalInitializer
(stddev=1.0)¶ Returns an orthogonal initializer.

trax.layers.initializers.
AtariConvInit
(kernel_shape, rng, dtype=jnp.float32)¶ The standard init for Conv layers in Atari models.
metrics¶
Layers for computing loss functions and evaluation metrics.
A metric layer computes a scalar value from two or three ndarray inputs:
 model outputs: Batch of predicted values (typically vectors).
 targets: Batch of target values (e.g., categories or vectors).
 weights: Float values that allow for uneven weighting of batch items, sequence positions, or vector components when computing an overall scalar value for the batch.
Most metric computations take into account the items that make up a batch. For each item in a batch, a raw metric value is computed by comparing (item-wise) the model output to the target value. These item-wise values are then combined into a single scalar for the batch by a function such as sum, average, or weighted average. For example:
 CategoryAccuracy: Treat model output as vectors whose components correspond to the possible categories; measure a vector as correct (value 1) if its largest component is the target category, else as incorrect (value 0). The accuracy for the batch is then the average across vectors of these 1’s and 0’s.
 CategoryCrossEntropy: Treat model output and target values as the source of two probability distributions; measure the cross-entropy of the model’s predicted distribution relative to the (assumed true) target distribution. The scalar value for the batch is then the average of the item-wise cross-entropy values.

trax.layers.metrics.
CategoryAccuracy
()¶ Returns a layer that computes category prediction accuracy.
The layer takes two inputs:
 A batch of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index (category) having the highest probability.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
The predicted category from each vector is the index of the highest-valued vector component. The layer returns the accuracy of these predictions averaged over the batch.

trax.layers.metrics.
WeightedCategoryAccuracy
()¶ Returns a layer that computes a weighted category prediction accuracy.
The layer takes three inputs:
 A batch of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index (category) having the highest probability.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
 A batch of weights, which matches or can be broadcast to match the shape of the target ndarray. This arg can give uneven weighting to different items in the batch (depending, for instance, on the item’s target category).
The predicted category from each vector is the index of the highest-valued vector component. The layer returns a weighted average accuracy of these predictions.
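This is a pure-Python sketch of the weighted accuracy computation for a batch given as nested lists. It assumes that “weighted average” means dividing the weighted sum of 0/1 correctness values by the sum of the weights. With all weights equal to 1, it reduces to plain category accuracy. The function names are hypothetical.

```python
def argmax(vector):
    """Index of the largest component (first one on ties)."""
    return max(range(len(vector)), key=lambda i: vector[i])


def weighted_category_accuracy_sketch(model_output, targets, weights):
    """Weighted average of per-item 0/1 correctness of argmax predictions."""
    correct = [float(argmax(v) == t) for v, t in zip(model_output, targets)]
    return sum(w * c for w, c in zip(weights, correct)) / sum(weights)


model_output = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # predictions: 1, 0, 1
targets = [1, 1, 1]
print(weighted_category_accuracy_sketch(model_output, targets, [1.0, 1.0, 2.0]))  # 0.75
```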

trax.layers.metrics.
CategoryCrossEntropy
(label_smoothing=None)¶ Returns a layer that computes cross-entropy from activations and integers.
The layer takes two inputs:
 A batch of activation vectors. The components in a given vector should be pre-softmax activations (mappable to a probability distribution via softmax). For performance reasons, the softmax and cross-entropy computations are combined inside the layer.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
To compute cross-entropy per batch item, the layer derives probability distributions:
 from model output (vectors): \(\ q = \text{softmax}(v)\)
 from target categories (integers): \(\ p = \text{one_hot}(n)\) or \(p = (1-\varepsilon)\cdot\text{one_hot}(n) + \frac{\varepsilon}{N}\), where \(\varepsilon\) is the label smoothing factor.
(The conversion of integer category targets to one-hot vectors amounts to assigning all the probability mass to the target category.) Cross-entropy per batch item is computed between the resulting distributions:
\[\text{cross_entropy} = -\sum_{i=0}^{N-1} p_i \log q_i\]
The layer returns the average of these cross-entropy values over all items in the batch.
Parameters: label_smoothing – Creates soft targets if provided. Must be between 0 and 1.
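This pure-Python sketch follows the formulas for q, p (with optional label smoothing), and the per-item cross-entropy term by term. The softmax and the log are combined through log-sum-exp, in line with the note that the two computations are fused. The function name is hypothetical; this is an illustration, not the Trax code.

```python
import math


def category_cross_entropy_sketch(model_output, targets, label_smoothing=None):
    """Mean cross-entropy between softmax(v) and (optionally smoothed) one-hot targets."""
    eps = label_smoothing or 0.0
    total = 0.0
    for v, n in zip(model_output, targets):
        n_categories = len(v)
        # log q = log softmax(v), using the max-subtraction trick for stability.
        m = max(v)
        lse = m + math.log(sum(math.exp(a - m) for a in v))
        log_q = [a - lse for a in v]
        # p = (1 - eps) * one_hot(n) + eps / N
        p = [(1.0 - eps) * (1.0 if i == n else 0.0) + eps / n_categories
             for i in range(n_categories)]
        total += -sum(p_i * lq_i for p_i, lq_i in zip(p, log_q))
    return total / len(targets)


print(category_cross_entropy_sketch([[0.0, 0.0]], [0]))  # log(2), about 0.693
```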

trax.layers.metrics.
WeightedCategoryCrossEntropy
(label_smoothing=None, cutoff=0.0)¶ Returns a layer like CategoryCrossEntropy, with weights as a third input.
The layer takes three inputs:
 A batch of activation vectors. The components in a given vector should be pre-softmax activations (mappable to a probability distribution via softmax). For performance reasons, the softmax and cross-entropy computations are combined inside the layer.
 A batch of target categories; each target is an integer in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality.
 A batch of weights, which matches or can be broadcast to match the shape of the target ndarray. This arg can give uneven weighting to different items in the batch (depending, for instance, on the item’s target category).
The layer returns the weighted average of the per-item cross-entropy values over all items in the batch.
Parameters:  label_smoothing – Creates soft targets if provided. Must be between 0 and 1.
 cutoff – Prevents the loss from going lower than this cutoff (the default 0.0 means no cutoff).

trax.layers.metrics.
BinaryCrossEntropy
()¶ Returns a layer that computes cross-entropy for binary classification.
The layer takes two inputs:
 A batch of activation values; each batch item \(x\) is a float in \((-\infty, \infty)\).
 A batch of binary targets; each target \(t\) is an integer in \(\{0, 1\}\).
The layer maps each activation value into the range \((0, 1)\), interpreted as the model-predicted probability that the item’s category is 1:
\[q = \frac 1 {1 + e^{-x}} \ \ \text{[model-predicted probability]}\]
and computes cross-entropy (per batch item) by treating the target category as having probability 1:
\[\begin{split}\text{cross_entropy} = \left\{ \begin{array}{cl} -\log q & \text{if}\ t = 1, \\ -\log (1 - q) & \text{if}\ t = 0. \end{array} \right.\end{split}\]
The layer returns the average of these cross-entropy values over all items in the batch.
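Since \(-\log q = \log(1 + e^{-x})\) and \(-\log(1 - q) = \log(1 + e^{x})\), both cases of the per-item formula can be written with a softplus. A stable softplus avoids overflow for large \(|x|\). The names in this sketch are hypothetical, and it is not the Trax code.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def binary_cross_entropy_sketch(activations, targets):
    """Mean binary cross-entropy; activations are raw floats, targets are 0 or 1."""
    total = 0.0
    for x, t in zip(activations, targets):
        # With q = 1 / (1 + e^{-x}): -log q = softplus(-x) and -log(1 - q) = softplus(x).
        total += softplus(-x) if t == 1 else softplus(x)
    return total / len(targets)


print(binary_cross_entropy_sketch([0.0, 0.0], [1, 0]))  # log(2), about 0.693
```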

trax.layers.metrics.
MaskedSequenceAccuracy
()¶ Returns a layer that computes sequence prediction accuracy with masking.
This layer type is intended for variable-length sequences, especially text, represented as a batch of fixed-length sequences via padding for unused positions.
The layer takes three inputs:
 A batch of sequences of activation vectors. The components in a given vector should be mappable to a probability distribution in the following loose sense: within a vector, a higher component value corresponds to a higher probability, such that argmax within a vector (axis=-1) picks the index having the highest probability. In text modeling, the index represents a token id from a predetermined token vocabulary (or padding).
 A batch of target integer sequences, with values in \(\{0, ..., N-1\}\), where \(N\) is the activation vector depth/dimensionality. In text modeling, these sequences typically represent token ids from a predetermined token vocabulary (or padding).
 A batch of weights/masks, which matches or can be broadcast to match the shape of the target ndarray. This arg is used to give weight 0 to padding positions, which masks those positions out of the calculation. Only the zero/non-zero distinction matters; all non-zero values are treated alike as signaling non-masked (i.e., valid/in-use) positions.
The predicted integer value for each sequence position is the index of the highest-valued component of the position’s vector. A predicted integer sequence is judged correct if it matches the target integer sequence in all non-zero-weighted positions. The layer returns the accuracy of predicted sequences averaged over the batch.
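This pure-Python sketch shows the masking rule: a sequence counts as correct only if every position with non-zero weight matches its target. Zero-weight (padding) positions are skipped. The names are hypothetical.

```python
def argmax(vector):
    """Index of the largest component (first one on ties)."""
    return max(range(len(vector)), key=lambda i: vector[i])


def masked_sequence_accuracy_sketch(model_output, targets, weights):
    """Fraction of sequences whose predictions match at every non-masked position."""
    n_correct = 0
    for seq_out, seq_tgt, seq_w in zip(model_output, targets, weights):
        # Positions with weight 0 (padding) are ignored; any non-zero weight counts as valid.
        if all(argmax(v) == t for v, t, w in zip(seq_out, seq_tgt, seq_w) if w != 0):
            n_correct += 1
    return n_correct / len(targets)


model_output = [
    [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]],  # predictions: 0, 1, 0
    [[0.9, 0.1], [0.9, 0.1], [0.1, 0.9]],  # predictions: 0, 0, 1
]
targets = [[0, 1, 1], [0, 1, 0]]
weights = [[1, 1, 0], [1, 1, 0]]  # last position of each sequence is padding
print(masked_sequence_accuracy_sketch(model_output, targets, weights))  # 0.5
```

The first sequence counts as correct because its only mismatch falls on a padding position.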

trax.layers.metrics.
Accuracy
(classifier=ArgMax)¶ Returns a layer that computes mean category prediction accuracy.
DEPRECATED; use WeightedCategoryAccuracy instead.
Parameters: classifier – Layer that transforms activation vectors into category predictions.

trax.layers.metrics.
SequenceAccuracy
(classifier=ArgMax)¶ Returns a layer that computes mean sequence prediction accuracy.
DEPRECATED; use MaskedSequenceAccuracy instead.
Parameters: classifier – Layer that transforms activation vectors into category predictions.

trax.layers.metrics.
CrossEntropyLoss
()¶ Returns a layer that outputs multiclass prediction-target cross-entropy.
DEPRECATED; refactor to use WeightedCategoryCrossEntropy or CategoryCrossEntropy instead.
(CrossEntropyLoss by itself does not compute cross-entropy. In older code, this layer had to be preceded by LogSoftmax, and the two layers together did the work of converting category information to probability distributions and computing the cross-entropy between those distributions. All this is now done by WeightedCategoryCrossEntropy.)

trax.layers.metrics.
CrossEntropyLossWithLogSoftmax
()¶ Mean prediction-target cross-entropy for multiclass classification.

trax.layers.metrics.
BinaryCrossEntropyLoss
()¶ Returns a layer that outputs binary prediction-target cross-entropy.
DEPRECATED; refactor to use BinaryCrossEntropy instead. (The newer BinaryCrossEntropy does not use weights, so refactor accordingly. Unless and until clear motivating use cases arise, the library will not include a binary cross-entropy function with weights.)

trax.layers.metrics.
L2Loss
()¶ Returns a layer that computes an L2-like loss for one batch.
The layer takes three inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of element-wise target values, which matches the shape of the model output.
 A batch of weights, which matches the shape of the model output.
The layer returns a weighted average of element-wise squared error terms \((y_i - t_i)^2\).
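This is a minimal sketch over flat lists of floats. It assumes the weighted average is normalized by the sum of the weights; the text above does not spell out the normalization. The name `l2_loss_sketch` is hypothetical.

```python
def l2_loss_sketch(model_output, targets, weights):
    """Weighted average of element-wise squared errors, over flat lists of floats."""
    if not len(model_output) == len(targets) == len(weights):
        raise ValueError("model_output, targets and weights must have the same shape")
    weighted_sq = sum(w * (y - t) ** 2 for y, t, w in zip(model_output, targets, weights))
    # Normalizing by the total weight makes zero-weight elements drop out entirely.
    return weighted_sq / sum(weights)


print(l2_loss_sketch([1.0, 2.0], [0.0, 0.0], [1.0, 1.0]))  # 2.5
```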

trax.layers.metrics.
SmoothL1Loss
()¶ Returns a layer that computes a weighted, smoothed L1 loss for one batch.
The layer takes three inputs:
 Model output from one batch, an ndarray of floatvalued elements.
 A batch of elementwise target values, which matches the shape of the model output.
 A batch of weights, which matches the shape of the model output.
The layer computes a “smooth” L1 loss (a.k.a. Huber loss), for model output float \(y_i\) and target float \(t_i\):
\[\begin{split}\text{output} = \left\{ \begin{array}{cl} \frac 1 2 (y_i - t_i)^2, & \text{if}\ |y_i - t_i| < 1, \\ |y_i - t_i| - \frac 1 2, & \text{otherwise}. \end{array} \right.\end{split}\]The layer returns a weighted average of these element-wise values.
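This pure-Python sketch implements the piecewise formula, with the threshold at \(|y_i - t_i| = 1\) where both branches equal 1/2. As with the L2 case, normalizing by the sum of the weights is an assumption about “weighted average”. The names are hypothetical.

```python
def smooth_l1_term(diff):
    """Per-element smooth L1 (Huber, threshold 1) value for diff = y_i - t_i."""
    a = abs(diff)
    return 0.5 * diff * diff if a < 1.0 else a - 0.5


def smooth_l1_loss_sketch(model_output, targets, weights):
    """Weighted average of smooth L1 terms over flat lists of floats."""
    terms = [w * smooth_l1_term(y - t) for y, t, w in zip(model_output, targets, weights)]
    return sum(terms) / sum(weights)


print(smooth_l1_term(0.5), smooth_l1_term(-3.0))  # 0.125 2.5
```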

trax.layers.metrics.
MacroAveragedFScore
(beta=1.0, initial_category_index=0)¶ Returns a layer that computes a macro-averaged F-score.
The macro-averaged F-score summarizes how well the classifier’s predictions of each class k align with the observed/gold instances of k. It weights all classes equally, regardless of their size.
Parameters:  beta – Parameter that determines the weight of recall in the F-score.
 initial_category_index – Index of the initial category.
The layer takes two inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of element-wise target values, which matches the shape of the model output.
The layer returns a macro-averaged F-score across all the classes.
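This pure-Python sketch uses the standard F-beta definition, \(F_\beta = (1+\beta^2)PR / (\beta^2 P + R)\), computed per class and then averaged with equal weight per class. Predictions are taken as the argmax of each output vector. Two choices here are assumptions, not documented behavior. First, it scores every category at or above initial_category_index that occurs in the targets or predictions. Second, it treats an undefined precision or recall as 0. The names are hypothetical.

```python
def f_beta(tp, fp, fn, beta=1.0):
    """F-beta score from counts; 0.0 when precision and recall are both 0."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    b2 = beta * beta
    denom = b2 * precision + recall
    return (1 + b2) * precision * recall / denom if denom else 0.0


def macro_averaged_f_score_sketch(model_output, targets, beta=1.0, initial_category_index=0):
    preds = [max(range(len(v)), key=lambda i: v[i]) for v in model_output]
    # Assumption: score each category >= initial_category_index seen in targets or predictions.
    classes = sorted({c for c in preds + list(targets) if c >= initial_category_index})
    scores = []
    for c in classes:
        tp = sum(p == c and t == c for p, t in zip(preds, targets))
        fp = sum(p == c and t != c for p, t in zip(preds, targets))
        fn = sum(p != c and t == c for p, t in zip(preds, targets))
        scores.append(f_beta(tp, fp, fn, beta))
    # Macro-averaging: every class contributes equally, regardless of its size.
    return sum(scores) / len(scores)


outputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # preds 0, 1, 1, 2
print(macro_averaged_f_score_sketch(outputs, [0, 1, 2, 2]))  # (1 + 2/3 + 2/3) / 3 = 7/9
```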

trax.layers.metrics.
WeightedFScore
(beta=1.0, initial_category_index=0)¶ Returns a layer that computes a weighted F-score.
The weighted F-score summarizes how well the classifier’s predictions of each class k align with the observed/gold instances of k. It additionally weights the per-class scores by the number of observed/gold and predicted examples in each class.
Parameters:  beta – Parameter that determines the weight of recall in the F-score.
 initial_category_index – Index of the initial category.
The layer takes two inputs:
 Model output from one batch, an ndarray of float-valued elements.
 A batch of element-wise target values, which matches the shape of the model output.
The layer returns a weighted F-score across all the classes.

trax.layers.metrics.
WeightedSum
()¶ Returns a layer that computes a weighted sum of the given values.

trax.layers.metrics.
CrossEntropySum
()¶ Sum of prediction-target cross-entropies for multiclass classification.

trax.layers.metrics.
BinaryCrossEntropySum
()¶ Sum of prediction-target cross-entropies for binary classification.
normalization¶
Trax normalization layers.

class
trax.layers.normalization.
BatchNorm
(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train')¶ Bases:
trax.layers.base.Layer
Layer that performs batch normalization.
In training, batch normalization keeps smoothed cumulative statistics across batches of input data and modifies each new batch so that its components are normalized (approximately zero mean and unit variance). In eval or inference, a BatchNorm instance uses its stored mean and variance to approximately normalize each new batch of data.
See https://arxiv.org/abs/1502.03167 for the original presentation and motivation of batch normalization.

__init__
(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, momentum=0.999, mode='train')¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes batch normalization as part of a forward pass in the model.

init_weights_and_state
(input_signature)¶ Helper to initialize batch norm weights and state.


class
trax.layers.normalization.
LayerNorm
(center=True, epsilon=1e-06)¶ Bases:
trax.layers.base.Layer
Layer normalization.
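The one-line description can be made concrete with the standard layer-normalization formula. Each feature vector is normalized by its own mean and variance, then scaled and shifted per feature. This pure-Python sketch handles one vector, with epsilon added to the variance for stability. The name `layer_norm_sketch` is hypothetical. Passing scale and bias explicitly, and reading center=True as “add a learned bias”, are assumptions.

```python
import math


def layer_norm_sketch(x, scale=None, bias=None, epsilon=1e-6):
    """Normalizes one feature vector to zero mean and unit variance, then rescales."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    normed = [(v - mean) / math.sqrt(var + epsilon) for v in x]
    # Defaults make the learned scale/bias an identity transform.
    scale = scale if scale is not None else [1.0] * n
    bias = bias if bias is not None else [0.0] * n
    return [s * v + b for v, s, b in zip(normed, scale, bias)]


print(layer_norm_sketch([1.0, 2.0, 3.0]))  # about [-1.2247, 0.0, 1.2247]
```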

__init__
(center=True, epsilon=1e-06)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.normalization.
FilterResponseNorm
(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001)¶ Bases:
trax.layers.base.Layer
Filter Response Normalization layer without Threshold Linear Unit.
Cf. https://arxiv.org/pdf/1911.09737.pdf

__init__
(mode=None, learn_epsilon=False, init_epsilon=1e-06, init_learnt_epsilon=0.0001)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

pooling¶
Trax pooling layers.

trax.layers.pooling.
MaxPool
(pool_size=(2, 2), strides=None, padding='VALID')¶ Reduces each multidimensional window to the max of the window’s values.
Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).
Parameters:  pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
 strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
 padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the selection of max values.
Returns: N-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the max value from that window. An output array has the same number of dimensions as its input, but has fewer elements.

trax.layers.pooling.
SumPool
(pool_size=(2, 2), strides=None, padding='VALID')¶ Reduces each multidimensional window to the sum of the window’s values.
Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).
Parameters:  pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
 strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
 padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed to avoid partial windows but does not otherwise affect the computation of sums.
Returns: An n-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the sum of values in that window. The output array has the same number of dimensions as its input, but generally has fewer elements (with 'SAME' padding and unit strides, the shape is preserved).

trax.layers.pooling.
AvgPool
(pool_size=(2, 2), strides=None, padding='VALID')¶ Reduces each multidimensional window to the mean of the window’s values.
Windows, as specified by pool_size and strides, involve all axes of an n-dimensional array except the first and last: \((d_1, ..., d_{n-2})\) from shape \((d_0, d_1, ..., d_{n-2}, d_{n-1})\).
Parameters:  pool_size – Shape of window that gets reduced to a single vector value. If the layer inputs are \(n\)-dimensional arrays, then pool_size must be a tuple of length \(n-2\).
 strides – Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as pool_size. If None, then offsets of 1 along each window axis, \((1, ..., 1)\), will be used.
 padding – ‘VALID’ or ‘SAME’. If ‘VALID’, no padding is done, and only full windows get reduced; partial windows are discarded. If ‘SAME’, padding is added at array edges as needed but is not counted in the computation of averages.
Returns: An n-dimensional array in which each valid (or padded-valid) window position in the input is reduced to / replaced by the mean of values in that window. The output array has the same number of dimensions as its input, but generally has fewer elements (with 'SAME' padding and unit strides, the shape is preserved).
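For intuition about how pool_size and strides select windows, here is a minimal pure-Python sketch of 'VALID' pooling on a single 2-D grid (one example, one channel, so the first and last axes described above are omitted). The helper `pool2d_valid` is hypothetical and not part of Trax; passing `max`, `sum`, or `mean` as the reducer mimics MaxPool, SumPool, and AvgPool respectively.

```python
from statistics import mean


def pool2d_valid(grid, pool_size=(2, 2), strides=None, reduce_fn=max):
    """Hypothetical 'VALID' pooling over a 2-D list of numbers.

    Only full windows are reduced; partial windows at the edges are dropped.
    """
    if strides is None:
        strides = (1, 1)  # mirrors the documented default: offset 1 per window axis
    ph, pw = pool_size
    sh, sw = strides
    rows, cols = len(grid), len(grid[0])
    out = []
    for r in range(0, rows - ph + 1, sh):
        out_row = []
        for c in range(0, cols - pw + 1, sw):
            window = [grid[r + i][c + j] for i in range(ph) for j in range(pw)]
            out_row.append(reduce_fn(window))
        out.append(out_row)
    return out


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]

print(pool2d_valid(x, strides=(2, 2), reduce_fn=max))   # [[6, 8], [14, 16]]
print(pool2d_valid(x, strides=(2, 2), reduce_fn=sum))   # [[14, 22], [46, 54]]
print(pool2d_valid(x, strides=(2, 2), reduce_fn=mean))  # [[3.5, 5.5], [11.5, 13.5]]
```

With the default strides of 1, the same 4 x 4 grid yields a 3 x 3 output: windows overlap, but partial windows are still discarded.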
reversible¶
Layers that can run in reverse to compute inputs from outputs.
Reversible layers reduce the memory required for backpropagation-based training, especially for deep networks. In a series of reversible layers, input activations from a forward pass don’t need to be stored: they can be reconstructed on the backward pass, layer by layer, from outputs to inputs.
See, e.g., [The Reversible Residual Network: Backpropagation Without Storing Activations](https://arxiv.org/abs/1707.04585) and [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451).

class
trax.layers.reversible.
ReversibleLayer
(n_in=1, n_out=1, name=None, sublayers_to_print=None)¶ Bases:
trax.layers.base.Layer
Reversible Layer.

reverse
(output, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.

reverse_and_grad
(output, grad, weights=(), state=(), new_state=(), rng=None)¶ Backward pass: computes the inverse of a layer and propagates gradients.
While you may choose to implement only reverse, some layers implement this method directly, because computation can be shared between reversing and computing gradients.
Parameters:  output – Output activations; can be a (possibly nested) tuple.
 grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
 weights – layer weights
 state – start state
 new_state – updated state computed by the forward pass
 rng – Single-use random number generator (JAX PRNG key).
Returns: A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng)¶ Custom backward pass to propagate gradients in a custom way.
Parameters:  inputs – Input tensors; can be a (possibly nested) tuple.
 output – The result of running this layer on inputs.
 grad – Gradient signal computed based on subsequent layers; its structure and shape must match output.
 weights – This layer’s weights.
 state – This layer’s state prior to the current forward pass.
 new_state – This layer’s state after the current forward pass.
 rng – Single-use random number generator (JAX PRNG key).
Returns: The custom gradient signal for the input. Note that we need to return a gradient for each argument of forward, so it will usually be a tuple of signals: the gradient for inputs and weights.


class
trax.layers.reversible.
ReversibleConcatenatePair
¶ Bases:
trax.layers.reversible.ReversibleLayer
Maps (x, y) -> ([x, y], [x, y]); [x, y] is concatenation on the last axis.
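A minimal sketch of this mapping and its inverse, using plain Python lists to stand in for tensors concatenated on their last axis. The functions are hypothetical illustrations, not Trax code; the reverse direction assumes x and y have the same size along the concatenated axis, so the first output can simply be split in half.

```python
def concatenate_pair(x, y):
    """Forward: (x, y) -> ([x, y], [x, y]), concatenating on the last axis."""
    r = list(x) + list(y)
    return r, list(r)


def concatenate_pair_reverse(outputs):
    """Reverse: recover (x, y) by splitting the first output in half.

    Assumes x and y had equal length along the concatenated axis.
    """
    r, _ = outputs  # the second output is a redundant copy
    half = len(r) // 2
    return r[:half], r[half:]


outputs = concatenate_pair([1.0, 2.0], [3.0, 4.0])
print(outputs)                            # ([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0])
print(concatenate_pair_reverse(outputs))  # ([1.0, 2.0], [3.0, 4.0])
```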

__init__
()¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.


class
trax.layers.reversible.
ReversibleSelect
(indices, n_in=None, name=None)¶ Bases:
trax.layers.reversible.ReversibleLayer
Reversible version of the Select combinator.

__init__
(indices, n_in=None, name=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.


trax.layers.reversible.
ReversibleSwap
()¶

class
trax.layers.reversible.
ReversibleReshape
(shape1, shape2, n_in=1)¶ Bases:
trax.layers.reversible.ReversibleLayer
Reversible reshaping layer.

__init__
(shape1, shape2, n_in=1)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.


class
trax.layers.reversible.
ReversiblePrintShape
(n_in=1, msg='')¶ Bases:
trax.layers.reversible.ReversibleLayer
Reversible PrintShape for debugging reversible serial layers.

__init__
(n_in=1, msg='')¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(outputs, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.


class
trax.layers.reversible.
ReversibleSerial
(*layers)¶ Bases:
trax.layers.reversible.ReversibleLayer
,trax.layers.combinators.Serial
A reversible version of tl.Serial (requires reversible sublayers).

__init__
(*layers)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

reverse
(output, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.

reverse_and_grad
(output, grad, weights=(), state=(), new_state=(), rng=None)¶ Backward pass: computes the inverse of a layer and propagates gradients.
While you may choose to implement only reverse, some layers implement this method directly, because computation can be shared between reversing and computing gradients.
Parameters:  output – Output activations; can be a (possibly nested) tuple.
 grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
 weights – layer weights
 state – start state
 new_state – updated state computed by the forward pass
 rng – Single-use random number generator (JAX PRNG key).
Returns: A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.


class
trax.layers.reversible.
ReversibleHalfResidual
(*residual_layers, attention_layer=None, name=None)¶ Bases:
trax.layers.reversible.ReversibleLayer
Half of a RevNet-style residual that optionally performs attention.
When attention_layer is None, this layer has the signature
[accumulator, *context] -> [accumulator + f(context), *context]
The attention_layer must be an instance of EfficientAttentionBase or one of its subclasses (see efficient_attention.py), or None.
Attention is special-cased for the following two reasons:
 LSH attention needs to save bucket assignments from the forward pass to the backward pass, for training stability. This requires special-casing it.
 We can call attention_layer.forward_and_or_backward to compute its output (needed for inverting a reversible residual layer) while simultaneously performing the backward pass. Sharing computation between these two operations improves training speed.
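The forward/reverse relationship for the attention_layer=None case can be sketched with scalars in pure Python. Here `f` is an arbitrary, hypothetical residual function (it need not be invertible), and a RevNet-style block is assumed to be two half-residuals with a swap between them. Reversing the chain applies each layer's inverse in reverse order, which mirrors the layer-by-layer reconstruction from outputs to inputs described for reversible layers.

```python
def f(x):
    # Hypothetical residual function; reversibility does not require f to be invertible.
    return 3.0 * x * x + 1.0


def half_residual(acc, ctx):
    """[accumulator, context] -> [accumulator + f(context), context]."""
    return acc + f(ctx), ctx


def half_residual_reverse(acc_out, ctx):
    """Recover the original accumulator by subtracting f(context)."""
    return acc_out - f(ctx), ctx


def swap(a, b):
    return b, a  # swapping is its own inverse


# (forward, reverse) pairs: half-residual, swap, half-residual.
layers = [(half_residual, half_residual_reverse),
          (swap, swap),
          (half_residual, half_residual_reverse)]

x = (0.5, -2.0)
y = x
for fwd, _ in layers:
    y = fwd(*y)

# Reconstruct the inputs from the outputs alone, running inverses in reverse order.
x_rec = y
for _, rev in reversed(layers):
    x_rec = rev(*x_rec)

print(y, x_rec)  # (545.75, 13.5) (0.5, -2.0)
```

Only the final outputs need to be kept; every intermediate activation is recomputed during the reverse sweep, which is where the memory saving comes from.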

__init__
(*residual_layers, attention_layer=None, name=None)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(xs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

reverse
(output, weights=(), state=(), new_state=(), rng=None)¶ Reverse this layer: compute input given output.

reverse_and_grad
(output, ct, weights=(), state=(), new_state=(), rng=None)¶ Backward pass: computes the inverse of a layer and propagates gradients.
While you may choose to implement only reverse, some layers implement this method directly, because computation can be shared between reversing and computing gradients.
Parameters:  output – Output activations; can be a (possibly nested) tuple.
 grad – gradient signal (cotangent) computed based on subsequent layers. The structure and shape must match the output.
 weights – layer weights
 state – start state
 new_state – updated state computed by the forward pass
 rng – Single-use random number generator (JAX PRNG key).
Returns: A tuple (x, (x_grad, weights_grad)), where x is the reconstructed input, x_grad is the gradient signal for the input, and weights_grad is the gradient signal for the weights.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
rnn¶
Implementations of common recurrent neural network cells (RNNs).

class
trax.layers.rnn.
LSTMCell
(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Bases:
trax.layers.base.Layer
LSTM Cell.
For a nice overview of the motivation and (i, o, f) gates, see this tutorial: https://colah.github.io/posts/2015-08-Understanding-LSTMs/
See this paper for a description and detailed study of all gate types: https://arxiv.org/pdf/1503.04069.pdf

__init__
(n_units, forget_bias=1.0, kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.layers.rnn.
MakeZeroState
(depth_multiplier=1)¶ Makes zeros of shape like x but removing the length (axis 1).
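The docstring describes a shape rule rather than a computation. Assuming a 3-D input of shape [batch, length, depth], and assuming depth_multiplier scales the last axis (the docstring does not say which axis it affects), the resulting shape can be sketched with a hypothetical helper:

```python
def zero_state_shape(input_shape, depth_multiplier=1):
    """Hypothetical helper: shape of the zeros made for a [batch, length, depth] input.

    Assumes the length axis (axis 1) is dropped and the last axis is scaled by
    depth_multiplier; this is a reading of the docstring, not Trax code.
    """
    batch, _length, depth = input_shape
    return (batch, depth_multiplier * depth)


print(zero_state_shape((8, 100, 512)))                      # (8, 512)
print(zero_state_shape((8, 100, 512), depth_multiplier=2))  # (8, 1024)
```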

trax.layers.rnn.
LSTM
(n_units, mode='train')¶ LSTM running on axis 1.

class
trax.layers.rnn.
GRUCell
(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Bases:
trax.layers.base.Layer
Builds a traditional GRU cell with dense internal transformations.
Gated Recurrent Unit paper: https://arxiv.org/abs/1412.3555

__init__
(n_units, forget_bias=0.0, kernel_initializer=<function RandomUniformInitializer.<locals>.<lambda>>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.layers.rnn.
GRU
(n_units, mode='train')¶ GRU running on axis 1.

trax.layers.rnn.
ConvGRUCell
(n_units, kernel_size=(3, 3))¶ Builds a convolutional GRU.
Paper: https://arxiv.org/abs/1511.06432.
Parameters:  n_units – Number of hidden units
 kernel_size – Kernel size for convolution
Returns: A Stax model representing a GRU cell with convolution transforms.

trax.layers.rnn.
GeneralGRUCell
(candidate_transform, memory_transform_fn=None, gate_nonlinearity=<function Sigmoid>, candidate_nonlinearity=<function Tanh>, dropout_rate_c=0.1, sigmoid_bias=0.5)¶ Parametrized Gated Recurrent Unit (GRU) cell construction.
GRU update equations for update gate, reset gate, candidate memory, and new state:
\[\begin{split}u_t &= \sigma(U' \times s_{t-1} + B') \\ r_t &= \sigma(U'' \times s_{t-1} + B'') \\ c_t &= \tanh(U \times (r_t \odot s_{t-1}) + B) \\ s_t &= u_t \odot s_{t-1} + (1 - u_t) \odot c_t\end{split}\]
See combinators.Gate for details on the gating function.
Parameters:  candidate_transform – Transform to apply inside the Candidate branch. Applied before nonlinearities.
 memory_transform_fn – Optional transformation on the memory before gating.
 gate_nonlinearity – Function to use as gate activation; allows trying alternatives to Sigmoid, such as HardSigmoid.
 candidate_nonlinearity – Nonlinearity to apply after candidate branch; allows trying alternatives to traditional Tanh, such as HardTanh.
 dropout_rate_c – Amount of dropout on the transform (c) gate. Dropout works best in a GRU when applied exclusively to this branch.
 sigmoid_bias – Constant to add before sigmoid gates. Starting with a positive bias is generally advisable.
Returns: A model representing a GRU cell with specified transforms.
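The four update equations can be traced numerically with scalars. This sketch assumes a one-dimensional state, so each matrix product becomes a scalar multiplication and \(\odot\) becomes ordinary multiplication; the parameter names map onto \(U', B', U'', B'', U, B\) as noted in the docstring, and the numeric values are arbitrary. It follows the equations exactly as written, so no separate input term, dropout, or memory transform appears.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gru_step(s_prev, U_u, B_u, U_r, B_r, U_c, B_c):
    """One scalar GRU update following the equations above.

    U_u/B_u play the role of U'/B', U_r/B_r of U''/B'', and U_c/B_c of U/B.
    """
    u = sigmoid(U_u * s_prev + B_u)          # update gate u_t
    r = sigmoid(U_r * s_prev + B_r)          # reset gate r_t
    c = math.tanh(U_c * (r * s_prev) + B_c)  # candidate memory c_t
    return u * s_prev + (1.0 - u) * c        # new state s_t


s = 0.0
for _ in range(3):
    s = gru_step(s, U_u=0.5, B_u=0.5, U_r=1.0, B_r=0.0, U_c=2.0, B_c=0.3)
print(s)
```

Because \(s_t\) is a convex combination of \(s_{t-1}\) and \(c_t \in (-1, 1)\), a state that starts in \((-1, 1)\) stays there. A large update-gate bias pushes \(u_t\) toward 1, so the cell mostly copies its previous state.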

trax.layers.rnn.
InnerSRUCell
()¶ The inner (non-parallel) computation of an SRU.

trax.layers.rnn.
ScanSRUCell
(mode, monkey_patched_mask=None)¶ The inner (non-parallel) computation of an SRU.

trax.layers.rnn.
SRU
(n_units, activation=None, mode='train')¶ SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.
As defined in the paper:
\[\begin{split}y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\ f_t &= \sigma(Wf x_t + bf) \\ r_t &= \sigma(Wr x_t + br) \\ c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\ h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t\end{split}\]
The input is assumed to have shape [batch, length, depth], with recurrence along the length dimension. This returns a single SRU layer; according to the paper, it is best to stack at least two, except inside a Transformer.
Parameters:  n_units – output depth of the SRU layer.
 activation – Optional activation function.
 mode – If ‘predict’, the previous state is saved for one-by-one inference.
Returns: The SRU layer.
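A scalar sketch of the SRU equations, assuming one-dimensional inputs and state so every weight is a plain number. The parameter names mirror \(W, B, Wf, bf, Wr, br\) above, and tanh as the default activation is an arbitrary choice for illustration. The code separates the projections that depend only on \(x_t\) from the single recurrence on \(c_t\); that split is what lets an SRU compute most of its work in parallel across time steps.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sru_scan(xs, W, B, Wf, bf, Wr, br, activation=math.tanh, c0=0.0):
    """Scalar SRU over a sequence xs, following the equations above.

    Returns (list of h_t, final c_t).
    """
    # Parallelizable part: every projection depends only on x_t.
    ys = [W * x + B for x in xs]
    fs = [sigmoid(Wf * x + bf) for x in xs]
    rs = [sigmoid(Wr * x + br) for x in xs]
    # Sequential part: the only recurrence is on c_t.
    c, hs = c0, []
    for x, y, f, r in zip(xs, ys, fs, rs):
        c = f * c + (1.0 - f) * y
        hs.append(r * activation(c) + (1.0 - r) * x)  # highway-style output
    return hs, c


hs, c_final = sru_scan([0.1, -0.4, 0.8], W=1.0, B=0.0, Wf=0.5, bf=1.0, Wr=0.5, br=0.0)
print(hs, c_final)
```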
research.efficient_attention¶
Attention layers optimized for efficiency (second-pass implementation).
The approach taken in the first round of efficient attention implementations revealed several limitations, which this code attempts to address:
 Simultaneously instantiating queries, keys, and values for all heads can exceed the memory budget. Transformers are typically tuned such that n_heads * d_attention_key == d_model. Since attention involves queries, keys, AND values, the memory to store them can be ~3x the memory needed to store the input activations. Once the O(n^2) dot-product bottleneck is removed – as is the case in all of our efficient attention implementations – this becomes the next critical bottleneck for scaling up Transformer models.
 Attention masking is implemented by associating an integer (typically, the sequence position) with each query and key vector, and defining a function to compute attention masks from this information. The standard attention API (attention.py) is unscalable because it instantiates O(n^2)-size attention masks, and the previous efficient implementations (efficient_attention.py) only supported causal masking.

trax.layers.research.efficient_attention.
length_normalized
(x, epsilon=1e-06)¶

trax.layers.research.efficient_attention.
hash_vecs
(vecs, n_buckets_in, n_hashes, rng)¶ Hash vectors into buckets.
Parameters:  vecs – vectors to hash, a tensor of shape [batch_size, depth]
 n_buckets_in – an int or a list of ints, number of hash buckets; if it is a list, we do hierarchical hashing as specified by the list
 n_hashes – number of hashes
 rng – random generator to use for hashing
Returns: A pair (buckets, n_buckets) where buckets is a tensor of shape [n_hashes, batch_size] of integers – the hash bucket IDs, and n_buckets is an int, the total number of hash buckets, equal to the product of all items in n_buckets_in.
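As one illustration of hashing vectors into buckets, the sketch below uses the random-projection (angular) LSH scheme described in the Reformer paper: project each vector onto n_buckets / 2 random directions, concatenate the projections with their negations, and take the argmax. That hash_vecs works this way internally is an assumption; the sketch also ignores hierarchical hashing (a list-valued n_buckets_in) and uses Python's `random` module in place of a JAX PRNG key. The function name is hypothetical.

```python
import random


def hash_vecs_sketch(vecs, n_buckets, n_hashes, seed=0):
    """Hypothetical angular LSH: returns (buckets[n_hashes][batch_size], n_buckets).

    n_buckets must be even, because each random direction yields two buckets (+/-).
    """
    if n_buckets % 2 != 0:
        raise ValueError("n_buckets must be even")
    rng = random.Random(seed)
    depth, half = len(vecs[0]), n_buckets // 2
    buckets = []
    for _ in range(n_hashes):
        # A fresh random projection matrix of shape [depth, n_buckets // 2] per round.
        rot = [[rng.gauss(0.0, 1.0) for _ in range(half)] for _ in range(depth)]
        round_ids = []
        for v in vecs:
            proj = [sum(v[d] * rot[d][j] for d in range(depth)) for j in range(half)]
            scores = proj + [-p for p in proj]  # concatenate [xR, -xR]
            round_ids.append(max(range(n_buckets), key=lambda j: scores[j]))
        buckets.append(round_ids)
    return buckets, n_buckets


vecs = [[1.0, 0.0, 0.5], [2.0, 0.0, 1.0], [-1.0, 0.0, -0.5]]
buckets, n_buckets = hash_vecs_sketch(vecs, n_buckets=4, n_hashes=3)
print(buckets)
```

Under this scheme, a positively rescaled vector always lands in the same bucket as the original, and a negated vector always lands n_buckets / 2 buckets away.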

trax.layers.research.efficient_attention.
look_adjacent
(x, n_chunks_before, n_chunks_after)¶ Used to implement attention between consecutive chunks.
Parameters:  x – array of shape [n_chunks, chunk_len, …]
 n_chunks_before – Number of previous chunks to attend to.
 n_chunks_after – Number of subsequent chunks to attend to.
Returns: array of shape [n_chunks, N * chunk_len, …], where N = (1 + n_chunks_before + n_chunks_after).
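A list-based sketch of this behavior, assuming neighbors wrap around at the ends of the sequence (the SelfAttention documentation notes that cross-chunk attention wraps around in both directions). The function name is hypothetical. Each output chunk concatenates the chunks at offsets -n_chunks_before through +n_chunks_after, so its length is N * chunk_len with N = 1 + n_chunks_before + n_chunks_after.

```python
def look_adjacent_sketch(chunks, n_chunks_before, n_chunks_after):
    """Hypothetical list-based look_adjacent; chunks has shape [n_chunks][chunk_len]."""
    n = len(chunks)
    out = []
    for i in range(n):
        row = []
        for offset in range(-n_chunks_before, n_chunks_after + 1):
            row.extend(chunks[(i + offset) % n])  # wrap around at both ends
        out.append(row)
    return out


chunks = [[0, 1], [2, 3], [4, 5]]  # 3 chunks of length 2
print(look_adjacent_sketch(chunks, 1, 0))
# [[4, 5, 0, 1], [0, 1, 2, 3], [2, 3, 4, 5]]
```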

trax.layers.research.efficient_attention.
mask_self_attention
(dots, q_info, kv_info, causal=True, exclude_self=True, masked=False)¶ Performs masking for self-attention.
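The masking rules are not spelled out here. Assuming q_info and kv_info hold sequence positions (as the module overview describes), a plausible reading is: with causal=True a query may not attend to keys at later positions, and with exclude_self=True it may not attend to the key at its own position. The sketch applies those two rules to a matrix of attention logits by substituting a large negative number; the constant, the function name, and the omission of masked=True handling are all assumptions, not Trax's actual values or behavior.

```python
MASK_VALUE = -1e9  # hypothetical large negative logit for masked entries


def mask_self_attention_sketch(dots, q_info, kv_info, causal=True, exclude_self=True):
    """Return a copy of dots[q][k] with disallowed query/key pairs masked out."""
    masked = []
    for q_row, q_pos in zip(dots, q_info):
        new_row = []
        for logit, k_pos in zip(q_row, kv_info):
            if causal and k_pos > q_pos:             # key lies in the future
                logit = MASK_VALUE
            elif exclude_self and k_pos == q_pos:    # key is the query's own position
                logit = MASK_VALUE
            new_row.append(logit)
        masked.append(new_row)
    return masked


dots = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0]]
positions = [0, 1, 2]
result = mask_self_attention_sketch(dots, positions, positions)
```

With both flags set, each query keeps only keys at strictly earlier positions, so the query at position 0 ends up with every key masked.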

trax.layers.research.efficient_attention.
attend
(q, k=None, v=None, q_chunk_len=None, kv_chunk_len=None, n_chunks_before=0, n_chunks_after=0, mask_fn=None, q_info=None, kv_info=None, dropout=0.0, rng=None)¶ Dot-product attention, with optional chunking and/or masking.
Parameters:  q – Query vectors, shape [q_len, d_qk]
 k – Key vectors, shape [kv_len, d_qk]; or None
 v – Value vectors, shape [kv_len, d_v]
 q_chunk_len – Set to nonzero to enable chunking for query vectors
 kv_chunk_len – Set to nonzero to enable chunking for key/value vectors
 n_chunks_before – Number of adjacent previous chunks to attend to
 n_chunks_after – Number of adjacent subsequent chunks to attend to
 mask_fn – TODO(kitaev) doc
 q_info – Query-associated metadata for masking
 kv_info – Key-associated metadata for masking
 dropout – Dropout rate
 rng – RNG for dropout
Returns: A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and dots_logsumexp has shape [q_len]. The logsumexp of the attention probabilities is useful for combining multiple rounds of attention (as in LSH attention).
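Ignoring chunking, masking, and dropout, the core computation and the role of dots_logsumexp can be sketched in pure Python. The 1/sqrt(d_qk) scaling is an assumption (standard scaled dot-product attention; the docstring does not say whether attend scales), and both function names are hypothetical. The second function shows why the log-sum-exp is useful: two attention results computed over disjoint sets of keys can be merged into exactly the result of attending over all keys at once, by weighting each partial output with a softmax over the partial log-sum-exps.

```python
import math


def attend_sketch(q, k, v):
    """Attention for lists q[q_len][d_qk], k[kv_len][d_qk], v[kv_len][d_v].

    Returns (output[q_len][d_v], dots_logsumexp[q_len]).
    """
    scale = 1.0 / math.sqrt(len(q[0]))  # assumed 1/sqrt(d_qk) scaling
    outputs, lses = [], []
    for qi in q:
        dots = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(dots)
        lse = m + math.log(sum(math.exp(d - m) for d in dots))  # stable log-sum-exp
        probs = [math.exp(d - lse) for d in dots]
        outputs.append([sum(p * vj[c] for p, vj in zip(probs, v)) for c in range(len(v[0]))])
        lses.append(lse)
    return outputs, lses


def combine(out_a, lse_a, out_b, lse_b):
    """Merge two attention results over disjoint key sets using their log-sum-exps."""
    merged = []
    for oa, la, ob, lb in zip(out_a, lse_a, out_b, lse_b):
        m = max(la, lb)
        total = m + math.log(math.exp(la - m) + math.exp(lb - m))
        wa, wb = math.exp(la - total), math.exp(lb - total)
        merged.append([wa * x + wb * y for x, y in zip(oa, ob)])
    return merged


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [0.5, -1.0], [-1.0, 0.0], [2.0, 0.5]]
v = [[1.0], [2.0], [3.0], [4.0]]

full, _ = attend_sketch(q, k, v)
a, lse_a = attend_sketch(q, k[:2], v[:2])
b, lse_b = attend_sketch(q, k[2:], v[2:])
merged = combine(a, lse_a, b, lse_b)
print(full)
print(merged)  # matches `full` up to floating-point rounding
```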

trax.layers.research.efficient_attention.
apply_broadcasted_dropout
(vecs, dropout_rate, rng)¶ Apply dropout, broadcasted across all but the last dimension of vecs.
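A sketch of broadcasted dropout on a 2-D list: one keep/drop decision is drawn per position of the last axis and shared by every row. Scaling kept values by 1 / keep_prob (inverted dropout) is an assumption, as is the use of Python's `random` module in place of a JAX PRNG key; the helper name is hypothetical.

```python
import random


def broadcasted_dropout_sketch(vecs, dropout_rate, seed=0):
    """Hypothetical dropout on vecs[rows][depth] with one mask shared by all rows."""
    if dropout_rate <= 0.0:
        return vecs
    keep_prob = 1.0 - dropout_rate
    rng = random.Random(seed)
    depth = len(vecs[0])
    # One multiplier per last-axis position: 1 / keep_prob if kept, else 0.
    multiplier = [(1.0 / keep_prob) if rng.random() < keep_prob else 0.0
                  for _ in range(depth)]
    return [[x * m for x, m in zip(row, multiplier)] for row in vecs]


vecs = [[1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0],
        [9.0, 10.0, 11.0, 12.0]]
out = broadcasted_dropout_sketch(vecs, 0.5)
print(out)  # the same columns are zeroed in every row
```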

trax.layers.research.efficient_attention.
permute_via_gather
(val, permutation, inverse_permutation, axis=0)¶ Permutation helper for LSH attention.

trax.layers.research.efficient_attention.
permute_via_sort
(val, keys, inverse_keys, axis=0)¶ Permutation helper for LSH attention.
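In LSH attention, helpers like these reorder positions so that items in the same bucket become contiguous, and later undo that reordering. A pure-Python sketch of the idea with a stable sort and its inverse permutation (the names are hypothetical, and this is not the Trax implementation, which operates on arrays along a given axis):

```python
def sort_permutation(buckets):
    """Permutation that groups positions by bucket id (stable, so ties keep order)."""
    return sorted(range(len(buckets)), key=lambda i: buckets[i])


def inverse_permutation(perm):
    inv = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    return inv


def gather(values, perm):
    return [values[i] for i in perm]


buckets = [2, 0, 1, 0, 2, 1]
values = ["a", "b", "c", "d", "e", "f"]

perm = sort_permutation(buckets)       # [1, 3, 2, 5, 0, 4]
inv = inverse_permutation(perm)        # [4, 0, 2, 1, 5, 3]
sorted_values = gather(values, perm)   # ['b', 'd', 'c', 'f', 'a', 'e']
restored = gather(sorted_values, inv)  # ['a', 'b', 'c', 'd', 'e', 'f']
```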

class
trax.layers.research.efficient_attention.
EfficientAttentionBase
(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)¶ Bases:
trax.layers.base.Layer
Base class for efficient attention.
This is a base class that implements memory-efficient batching for both the forward and backward passes. Subclasses should override create_weights_unbatched, create_state_unbatched, forward_unbatched, and optionally incremental_forward_unbatched to define the actual attention mechanism.

__init__
(n_heads, n_in=1, n_parallel_heads=None, incremental=False, predict_mem_len=None, predict_drop_len=None, use_python_loop=False, use_reference_code=False)¶ Constructs an EfficientAttentionBase instance.
Parameters:  n_heads – Number of attention heads.
 n_in – Number of inputs to the layer (default 1).
 n_parallel_heads –
Number of attention heads to compute in parallel.
 If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
 If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
 If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
 If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example; n_heads must be a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
 incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
 predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
 predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
 use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
 use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
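The four n_parallel_heads modes above can be restated as a small decision function. The helper below is hypothetical (not part of Trax) and only re-expresses the documented rules as (examples per step, heads per step) for a batch of batch_size examples.

```python
def parallelism_plan(n_heads, n_parallel_heads, batch_size):
    """Hypothetical helper: (examples_per_step, heads_per_step) for each documented mode."""
    if n_parallel_heads is None:
        return batch_size, n_heads  # maximum parallelism: everything at once
    if n_parallel_heads == 1:
        return 1, 1  # one head of one example at a time
    if n_parallel_heads % n_heads == 0:
        return n_parallel_heads // n_heads, n_heads  # sub-batches of whole examples
    if 1 < n_parallel_heads < n_heads and n_heads % n_parallel_heads == 0:
        return 1, n_parallel_heads  # several heads within a single example
    raise ValueError("unsupported combination of n_heads and n_parallel_heads")


print(parallelism_plan(4, 8, 16))  # (2, 4)
print(parallelism_plan(8, 2, 16))  # (1, 2)
```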

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_weights_unbatched
(input_signature, rng)¶

create_state_unbatched
(input_signature, rng)¶

forward_unbatched
(*inputs, weights, state)¶ Perform attention for a single batch element and head.
Subclasses should override this method.
Parameters:  *inputs – Inputs for a single example (subclasses may use different inputs)
 weights – Weights for a single attention head
 state – State for a single example & attention head pair.
Returns: A tuple (output, new_state) – output and new state for a single example and attention head.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs) Returns: A tuple (output, new_state).

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)¶ Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)¶ Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
SelfAttention
(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)¶ Bases:
trax.layers.base.Layer
Memory-efficient self-attention (second attempt).

__init__
(n_heads=2, d_qk=64, d_v=64, share_qk=False, causal=False, masked=False, chunk_len=None, n_chunks_before=0, n_chunks_after=0, bias=False, mode='train', predict_mem_len=None, predict_drop_len=None, attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)¶ Constructs a self-attention layer.
Parameters:  n_heads – int: Number of attention heads
 d_qk – int: Depth of query and key vectors
 d_v – int: Depth of value vectors
 share_qk – bool: Set to True to share query and key projection weights
 causal – bool: Set to True to mask out attention to future items
 masked – bool: Set to True to accept an additional mask argument, that allows masking out attention to padding tokens.
 chunk_len (optional) – Number of tokens per chunk. Setting this option will enable chunked attention.
 n_chunks_before – Number of previous chunks to attend to, when using chunked attention.
 n_chunks_after – Number of subsequent chunks to attend to, when using chunked attention. Don’t use this option for causal attention, because attention to future tokens will be masked out anyway. However, note that cross-chunk attention “wraps around” in both directions, so this option is never a strict no-op.
 bias – bool: Set to True to add bias vectors when computing query/key/value
 mode – ‘train’, ‘eval’, or ‘predict’
 predict_mem_len – int: Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten. When chunking is enabled, the default is to store chunk_len * (1 + n_chunks_before) elements.
 predict_drop_len – int: Number of input elements to drop once the fast inference input cache fills up. When chunking is enabled, the default is to drop exactly chunk_len elements.
 attention_dropout – Dropout probability for attention mask.
 output_dropout – Dropout probability for the layer output.
 n_parallel_heads –
Number of attention heads to compute in parallel.
 If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
 If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
 If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
 If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
 use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
 use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
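The chunking and fast-inference cache options above can be sketched in plain Python. All three helpers are hypothetical, not trax functions: attended_chunks lists the chunks a query chunk can see, using modulo arithmetic to model the wrap-around; default_cache_sizes applies the documented defaults; and ToyInferenceCache assumes that the oldest elements are the ones forgotten when the cache fills up.

```python
from collections import deque
from typing import List, Optional, Tuple


def attended_chunks(chunk_idx: int, n_chunks: int,
                    n_chunks_before: int, n_chunks_after: int) -> List[int]:
    """Chunk indices visible from chunk `chunk_idx` (hypothetical helper).

    Python's % maps out-of-range offsets back into [0, n_chunks), which
    models cross-chunk attention wrapping around in both directions.
    """
    return [(chunk_idx + off) % n_chunks
            for off in range(-n_chunks_before, n_chunks_after + 1)]


def default_cache_sizes(chunk_len: int, n_chunks_before: int,
                        predict_mem_len: Optional[int] = None,
                        predict_drop_len: Optional[int] = None) -> Tuple[int, int]:
    """Documented defaults for fast inference when chunking is enabled."""
    mem_len = (predict_mem_len if predict_mem_len is not None
               else chunk_len * (1 + n_chunks_before))
    drop_len = predict_drop_len if predict_drop_len is not None else chunk_len
    return mem_len, drop_len


class ToyInferenceCache:
    """Hypothetical FIFO cache: once full, forget the `drop_len` oldest items."""

    def __init__(self, mem_len: int, drop_len: int):
        self.mem_len, self.drop_len = mem_len, drop_len
        self.items = deque()

    def append(self, x) -> None:
        if len(self.items) >= self.mem_len:
            for _ in range(min(self.drop_len, len(self.items))):
                self.items.popleft()
        self.items.append(x)


# 8 tokens with chunk_len=2 give 4 chunks; chunk 0 also sees chunk 3 via wrap-around.
print(attended_chunks(0, n_chunks=4, n_chunks_before=1, n_chunks_after=0))  # [3, 0]
print(default_cache_sizes(chunk_len=64, n_chunks_before=1))  # (128, 64)
```

With causal=True, the tokens of the wrapped-around chunk 3 lie in the future of chunk 0, so one would expect the causal mask to hide them; this is an inference from the description above, not something the sketch enforces.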

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_weights_unbatched
(input_signature, rng)¶

create_state_unbatched
(input_signature, rng)¶

forward_unbatched
(x, mask=None, *, weights, state, rng, update_state)¶ Perform attention for a single batch element and head.
Parameters:  x – Inputs for a single example (subclasses may use different inputs)
 mask – Mask for the inputs.
 weights – Weights for a single attention head
 state – State for a single example & attention head pair.
 rng – PRNG key for the layer (shared across all examples and heads)
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state) – output and new state for a single example and attention head.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs) Returns: A tuple (output, new_state).

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)¶ Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)¶ Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
LSHSelfAttention
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)¶ Bases:
trax.layers.base.Layer
LSH self-attention (second implementation).

__init__
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)¶ Constructs an LSH self-attention layer.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_weights_unbatched
(input_signature, rng)¶

create_state_unbatched
(input_signature, rng)¶

hash_vectors
(vecs, rng, mask=None)¶
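hash_vectors is not documented here. For orientation, the Reformer paper describes angular LSH: project a vector with a random matrix R of shape [d, n_buckets/2] and take the bucket as argmax([xR; -xR]), drawing a fresh R for each of the n_hashes rounds. The sketch below implements that published scheme in plain Python; whether trax's hash_vectors matches it in every detail (for example its use of mask or max_length_for_buckets) is an assumption, and toy_hash_vectors is a hypothetical name.

```python
import random
from typing import List, Sequence


def angular_hash(vec: Sequence[float], rotation: List[List[float]]) -> int:
    """Bucket id = argmax([x R, -x R]) for a d x (n_buckets // 2) matrix R."""
    n_half = len(rotation[0])
    proj = [sum(vec[i] * rotation[i][j] for i in range(len(vec)))
            for j in range(n_half)]
    scores = proj + [-p for p in proj]
    return max(range(len(scores)), key=scores.__getitem__)


def toy_hash_vectors(vecs: Sequence[Sequence[float]], n_buckets: int,
                     n_hashes: int, seed: int = 0) -> List[List[int]]:
    """Returns buckets[round][i]: the bucket of vecs[i] in each hashing round."""
    if n_buckets % 2:
        raise ValueError("angular LSH needs an even number of buckets")
    rng = random.Random(seed)
    d = len(vecs[0])
    buckets = []
    for _ in range(n_hashes):
        # A fresh Gaussian random projection for every round.
        rotation = [[rng.gauss(0.0, 1.0) for _ in range(n_buckets // 2)]
                    for _ in range(d)]
        buckets.append([angular_hash(v, rotation) for v in vecs])
    return buckets


# Vectors pointing the same way always share a bucket; opposite vectors
# land in the "mirror" bucket n_buckets // 2 positions away.
print(toy_hash_vectors([[1.0, 2.0], [2.0, 4.0], [-1.0, -2.0]],
                       n_buckets=4, n_hashes=3))
```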

forward_unbatched
(x, mask=None, *, weights, state, rng, update_state)¶

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs) Returns: A tuple (output, new_state).

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)¶ Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)¶ Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
PureLSHSelfAttention
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)¶ Bases:
trax.layers.base.Layer
LSH self-attention without weights.

__init__
(n_heads=2, d_qk=64, d_v=64, share_qk='unused', causal=False, masked=False, chunk_len=128, n_chunks_before=1, n_chunks_after=0, n_hashes=1, n_buckets=None, mode='train', predict_mem_len=2048, predict_drop_len=256, attention_dropout=0.0, output_dropout=0.0, max_length_for_buckets=None, bias=False, n_parallel_heads=1, use_python_loop=False, use_reference_code=False)¶ Constructs an LSH self-attention layer.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.

create_state_unbatched
(input_signature, rng)¶

hash_vectors
(vecs, rng, mask=None)¶

forward_unbatched
(qk, v, mask=None, *, state, rng, update_state)¶

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
Parameters: inputs – Layer inputs (subclasses may use different inputs) Returns: A tuple (output, new_state).

has_backward
¶ Returns True if this layer provides its own custom backward pass code.
A layer subclass that provides custom backward pass code (for custom gradients) must override this method to return True.

backward
(inputs, output, grad, weights, state, new_state, rng=None, **kwargs)¶ Custom backward pass, for efficiency (see forward_and_or_backward).

forward_and_or_backward
(inputs, state, rng, output_grad=None, compute_output=True, update_state=True)¶ Performs batched forward and/or backward passes.
See forward for a reference implementation of what this layer does. The reference implementation is not very efficient, however, and this method provides a more performant version.
Parameters:  inputs – inputs to the attention layer: a tuple (qk, v, mask)
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
MixedLSHSelfAttention
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, std_length=None, mode='train', output_dropout=0.0, attention_dropout=0.0, force_no_dropout=False, **pure_lsh_implementation_kwargs)¶ Bases:
trax.layers.base.Layer
LSH attention mixed with standard attention used until std_length.

__init__
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, std_length=None, mode='train', output_dropout=0.0, attention_dropout=0.0, force_no_dropout=False, **pure_lsh_implementation_kwargs)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

init_weights_and_state
(input_signature)¶ Initializes weights and state for inputs with the given signature.

forward
(xs)¶ Executes this layer as part of a forward pass through the model.

forward_and_or_backward
(inputs, state, rng, output_grad=None, compute_output=True, update_state=True)¶ Performs batched forward and/or backward passes.


class
trax.layers.research.efficient_attention.
PureLSHSelfAttentionWrapper
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, output_dropout=0.0, attention_dropout=0.0, pure_lsh_implementation=None, bias=True, mode='train', num_weights=3, sparsity=16, weights_format='model', rotary_position_emb=False, **pure_lsh_implementation_kwargs)¶ Bases:
trax.layers.combinators.Serial
Pure LSH serial.

__init__
(n_heads=1, d_qk=64, d_v=64, causal=False, masked=False, output_dropout=0.0, attention_dropout=0.0, pure_lsh_implementation=None, bias=True, mode='train', num_weights=3, sparsity=16, weights_format='model', rotary_position_emb=False, **pure_lsh_implementation_kwargs)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward_and_or_backward
(inputs, weights, state, rng, output_grad=None, compute_output=True, update_state=True)¶ Performs batched forward and/or backward passes.
Parameters:  inputs – inputs to the attention layer
 weights – weights for the attention layer
 state – state of the attention layer
 rng – PRNG key for the layer (shared across all examples and heads)
 output_grad – gradient of the loss wrt the output of the layer, or None. This function performs the backward pass iff output_grad is not None.
 compute_output – bool: whether to return the output of the forward pass (for example, a pure backwards pass does not need to return the output).
 update_state – bool: whether to return an updated layer state.
Returns: A tuple (output, new_state, inputs_grad, weights_grad).
 output is not None iff compute_output is True
 new_state is not None iff update_state is True
 inputs_grad & weights_grad are not None iff output_grad is not None


class
trax.layers.research.efficient_attention.
EncDecAttention
(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)¶ Bases:
trax.layers.research.efficient_attention.EfficientAttentionBase
Memory-efficient encoder-decoder attention.

__init__
(n_heads=2, d_qk=64, d_v=64, masked=True, mode='train', attention_dropout=0.0, output_dropout=0.0, n_parallel_heads=None, use_python_loop=False, use_reference_code=False)¶ Constructs an EfficientAttentionBase instance.
Parameters:  n_heads – Number of attention heads.
 n_in – Number of inputs to the layer (default 1).
 n_parallel_heads –
Number of attention heads to compute in parallel.
 If n_parallel_heads is None (default), the entire layer is computed with maximum parallelism. This mode is the fastest, but also uses the most memory. Start with this mode, but switch to one of the others if memory runs out.
 If n_parallel_heads is 1, attention is computed one head at a time, and one example at a time. This mode uses the least memory but is not as fast as batched attention. Use this mode when working with very long sequences, such that any amount of parallelism won’t fit in memory.
 If n_parallel_heads is a multiple of n_heads, attention is computed for sub-batches of (n_parallel_heads // n_heads) examples at a time.
 If 1 < n_parallel_heads < n_heads, attention is computed for several heads at a time, but only within a single example. It must be the case that n_heads is a multiple of n_parallel_heads. Use this mode for long sequences, to strike a balance between parallelism and memory usage.
 incremental – If True, enable fast inference for self-attention types. Note that this flag should not be set when doing encoder-decoder attention, but only when doing self-attention.
 predict_mem_len – Number of input positions to remember in a cache when doing fast inference. Whenever the cache fills up, some input elements will be forgotten.
 predict_drop_len – Number of input elements to drop once the fast inference input cache fills up.
 use_python_loop – Set to True to use a Python loop when iterating over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging. In particular, note that enabling this option on TPU can decrease the maximum model size that will fit in memory.
 use_reference_code – Set to True to fall back to the reference implementation of batched attention. This option will increase compilation time and jitted code size, potentially drastically. Using it is not recommended except for testing/debugging.
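The four n_parallel_heads modes can be summarized as a schedule of (examples, heads) groups that are computed together. The hypothetical parallel_schedule below sketches only that grouping logic as described above; it is not trax's internal loop.

```python
from typing import List, Optional, Tuple

Group = Tuple[List[int], List[int]]  # (example indices, head indices)


def parallel_schedule(batch_size: int, n_heads: int,
                      n_parallel_heads: Optional[int]) -> List[Group]:
    """Groups of (examples, heads) computed together, per the modes above."""
    examples, heads = list(range(batch_size)), list(range(n_heads))
    if n_parallel_heads is None:
        # Maximum parallelism: a single group covering everything.
        return [(examples, heads)]
    if n_parallel_heads % n_heads == 0:
        # Sub-batches of n_parallel_heads // n_heads examples, all heads at once.
        # (n_parallel_heads == n_heads == 1 also lands here: one example at a time.)
        step = n_parallel_heads // n_heads
        return [(examples[b:b + step], heads) for b in range(0, batch_size, step)]
    if n_heads % n_parallel_heads != 0:
        raise ValueError("n_heads must be a multiple of n_parallel_heads")
    # 1 <= n_parallel_heads < n_heads: a few heads at a time within one example.
    return [([b], heads[h:h + n_parallel_heads])
            for b in range(batch_size)
            for h in range(0, n_heads, n_parallel_heads)]


print(parallel_schedule(batch_size=4, n_heads=2, n_parallel_heads=4))
# [([0, 1], [0, 1]), ([2, 3], [0, 1])]
print(len(parallel_schedule(batch_size=2, n_heads=4, n_parallel_heads=1)))
# 8: one head of one example per step
```

The trade-off is visible in the group sizes: fewer, larger groups mean more parallel work per step (faster, more memory), while many small groups bound peak memory at the cost of more sequential steps.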

create_weights_unbatched
(input_signature, rng)¶

forward_unbatched
(q_antecedent, kv_antecedent, mask=None, *, weights, state, rng, update_state)¶ Perform attention for a single batch element and head.
Subclasses should override this method.
Parameters:  *inputs – Inputs for a single example (subclasses may use different inputs)
 weights – Weights for a single attention head
 state – State for a single example & attention head pair.
Returns: A tuple (output, new_state) – output and new state for a single example and attention head.
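For a single example and head, encoder-decoder attention lets each decoder query attend over all encoder positions, with the mask hiding positions such as padding. Below is a minimal plain-Python sketch, assuming the query/key/value projections of q_antecedent and kv_antecedent have already been applied and that mask[j] is True for positions that may be attended to (both are assumptions; cross_attention is a hypothetical name).

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def cross_attention(queries: Matrix, keys: Matrix, values: Matrix,
                    mask: Optional[Sequence[bool]] = None) -> Matrix:
    """Single-head attention of decoder queries over encoder keys/values."""
    d_qk = len(queries[0])
    outputs = []
    for q in queries:
        # Scaled dot-product logits, one per encoder position.
        logits = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_qk)
                  for k in keys]
        if mask is not None:
            # Masked-out encoder positions get a very large negative logit.
            logits = [lg if keep else -1e9 for lg, keep in zip(logits, mask)]
        top = max(logits)  # Subtract the max for a numerically stable softmax.
        probs = [math.exp(lg - top) for lg in logits]
        total = sum(probs)
        outputs.append([sum(p * v[c] for p, v in zip(probs, values)) / total
                        for c in range(len(values[0]))])
    return outputs


# The second encoder position is padding, so the query copies values[0].
print(cross_attention([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]],
                      [[1.0, 2.0], [5.0, 7.0]], mask=[True, False]))
# [[1.0, 2.0]]
```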


class
trax.layers.research.efficient_attention.
LSHFF
(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Bases:
trax.layers.base.Layer
Feed-forward block with LSH.
The original (non-LSH) feed-forward block is a triple Dense(d_ff)-Relu-Dense that takes an input, makes it of size d_ff (usually larger than it was) and then brings it back to the original size after Relu. It is commonly used in Transformer models, where it often accounts for most of the trainable weights.
The original block can be slow in decoding due to the need to fetch a lot of weights from memory. Because Relu maps many of the d_ff intermediate activations to zero, the LSH block aims to exploit this sparsity: in the first Dense(d_ff) layer, instead of performing a full matrix multiplication, it multiplies only by the parts of the weight matrix that have the highest chance to give a non-zero value after Relu. These parts are determined by computing a number of locality-sensitive hashes and masking so as to include only the weights that share at least one hash with the multiplied element.
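A minimal sketch of the masking idea, under stated assumptions: it uses the angular hash from the Reformer paper (argmax over [xR; -xR] with a random R), hashes each column of the first Dense layer's weight matrix alongside the input, and evaluates Relu only for hidden units that land in the same bucket as the input in at least one of the n_hashes rounds. lsh_first_dense is a hypothetical helper, not the trax implementation.

```python
import random
from typing import List, Sequence


def angular_bucket(vec: Sequence[float], rotation: List[List[float]]) -> int:
    """Bucket = argmax over [xR, -xR]; `rotation` holds the columns of R."""
    proj = [sum(v * r for v, r in zip(vec, col)) for col in rotation]
    scores = proj + [-p for p in proj]
    return max(range(len(scores)), key=scores.__getitem__)


def lsh_first_dense(x: Sequence[float], w1_cols: List[List[float]],
                    b1: Sequence[float], n_buckets: int, n_hashes: int,
                    seed: int = 0) -> List[float]:
    """Relu(x @ W1 + b1), computed only for hidden units that share a bucket with x.

    w1_cols[j] is the d_model-dimensional weight vector feeding hidden unit j.
    Units that never share a bucket with x are skipped and left at 0.
    """
    rng = random.Random(seed)
    keep = set()
    for _ in range(n_hashes):
        # A fresh random projection per hashing round (n_buckets assumed even).
        rotation = [[rng.gauss(0.0, 1.0) for _ in range(len(x))]
                    for _ in range(n_buckets // 2)]
        x_bucket = angular_bucket(x, rotation)
        keep.update(j for j, w in enumerate(w1_cols)
                    if angular_bucket(w, rotation) == x_bucket)
    hidden = [0.0] * len(w1_cols)
    for j in keep:
        pre_activation = sum(a * b for a, b in zip(x, w1_cols[j])) + b1[j]
        hidden[j] = max(pre_activation, 0.0)
    return hidden


# Unit 0 points the same way as x, so it is always kept; unit 1 points the
# opposite way, so it never shares a bucket (and Relu would zero it anyway).
print(lsh_first_dense([1.0, 2.0], [[2.0, 4.0], [-1.0, -2.0]], [0.0, 0.0],
                      n_buckets=4, n_hashes=2))
# [10.0, 0.0]
```

In a real decoder the weight hashes would presumably be computed once, since W1 does not change between decoding steps; the sketch rehashes them on every call for brevity.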

__init__
(d_ff, n_buckets, n_hashes=4, mode='train', kernel_initializer=<function ScaledInitializer.<locals>.Init>, bias_initializer=<function RandomNormalInitializer.<locals>.<lambda>>)¶ Returns an LSH feed-forward block.

forward
(x)¶ Executes this layer as part of a forward pass through the model.
Parameters: x – Tensor of same shape and dtype as the input signature used to initialize this layer. Returns: Tensor of same shape and dtype as the input.

init_weights_and_state
(input_signature)¶ Randomly initializes this layer’s weights.

research.position_encodings¶
Experimenting with position encodings.

class
trax.layers.research.position_encodings.
AxialPositionalEncoding
(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')¶ Bases:
trax.layers.base.Layer
Axial positional encoding.
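With the default arguments, shape=(64, 64, 3) covers 64 × 64 × 3 = 12,288 positions, and d_embs=(384, 384, 256) adds up to 1,024 dimensions. Assuming, as in the Reformer paper, that a position's encoding is the concatenation of one learned vector per axis, indexed by the position's row-major multi-index, only 64 × 384 + 64 × 384 + 3 × 256 = 49,920 parameters are needed instead of 12,288 × 1,024 = 12,582,912. The helpers below are hypothetical and illustrate only that indexing.

```python
from typing import List, Sequence, Tuple


def axial_index(position: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major multi-index of a flat position (assumed layout)."""
    idx = []
    for size in reversed(shape):
        idx.append(position % size)
        position //= size
    return tuple(reversed(idx))


def axial_encoding(position: int,
                   tables: Sequence[Sequence[List[float]]]) -> List[float]:
    """Concatenates tables[a][i_a] over all axes a, where i_a is the axis index."""
    shape = [len(table) for table in tables]
    encoding: List[float] = []
    for table, i in zip(tables, axial_index(position, shape)):
        encoding.extend(table[i])
    return encoding


shape, d_embs = (64, 64, 3), (384, 384, 256)
print(axial_index(12287, shape))  # (63, 63, 2): the last position
print(sum(n * d for n, d in zip(shape, d_embs)))  # 49920 learned parameters
```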

__init__
(shape=(64, 64, 3), d_embs=(384, 384, 256), kernel_initializer=<function RandomNormalInitializer.<locals>.<lambda>>, dropout=0.0, dropout_broadcast_dims=(), mode='train')¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.research.position_encodings.
SinCosPositionalEncoding
(add_offset=2048, dropout=0.0, dropout_broadcast_dims=(2, ), start_from_zero_one_in=2, mode='train')¶ Bases:
trax.layers.base.Layer
Implements the sin-cos positional encoding.

__init__
(add_offset=2048, dropout=0.0, dropout_broadcast_dims=(2, ), start_from_zero_one_in=2, mode='train')¶ Creates a SinCosPositionalEncoding instance.
Parameters:  add_offset – Maximum number to add to positions during training.
 dropout – Probability of not adding positional encoding to a sequence position.
 dropout_broadcast_dims – Axes along which dropout mask values are broadcast rather than individually set at random.
 start_from_zero_one_in – How often to start positions from 0 during training: once in every start_from_zero_one_in times (e.g., if 4, then 25% of the time).
 mode – One of ‘train’, ‘eval’, or ‘predict’.
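A minimal sketch of the ingredients, assuming the standard Transformer layout (sine on even dimensions, cosine on odd ones, wavelengths set by a base of 10000) and assuming that the training offset is drawn uniformly from 0 to add_offset. The exact layout and sampling in trax may differ, and both helper names are hypothetical.

```python
import math
import random
from typing import List


def sincos_encoding(position: int, d_model: int) -> List[float]:
    """Sine on even dimensions, cosine on odd ones (assumed layout)."""
    encoding = []
    for i in range(d_model):
        # Dimensions 2k and 2k+1 share the frequency 1 / 10000**(2k / d_model).
        angle = position / (10000.0 ** (2 * (i // 2) / d_model))
        encoding.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return encoding


def sample_offset(add_offset: int, start_from_zero_one_in: int,
                  rng: random.Random) -> int:
    """Offset 0 with probability 1/start_from_zero_one_in, else uniform in [0, add_offset]."""
    if rng.randrange(start_from_zero_one_in) == 0:
        return 0
    return rng.randint(0, add_offset)


rng = random.Random(0)
offset = sample_offset(add_offset=2048, start_from_zero_one_in=2, rng=rng)
# During training, token t of a sequence is encoded as position offset + t.
encodings = [sincos_encoding(offset + t, d_model=8) for t in range(4)]
print(offset, [round(v, 3) for v in encodings[0]])
```

Random offsets expose the model to absolute positions far beyond the training sequence length, while the occasional zero offset keeps it familiar with sequences that start at position 0.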

forward
(inputs)¶ Returns the input activations, with added positional information.

init_weights_and_state
(input_signature)¶ Randomly initializes the positional encoding vectors.
Parameters: input_signature – ShapeDtype instance characterizing the input this layer should compute on.


class
trax.layers.research.position_encodings.
FixedBasePositionalEncoding
(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=2, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)¶ Bases:
trax.layers.base.Layer
Implements fixed-base positional encoding.
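One way to read bases and n_digits: each position is written out in every base, keeping n_digits digits per base, and the layer then works with those digits. The sketch below shows only that digit decomposition; the least-significant-first order and the helper names are assumptions, and how trax turns the digits into vectors is not shown.

```python
from typing import Dict, List, Sequence


def base_digits(position: int, base: int, n_digits: int) -> List[int]:
    """The n_digits lowest digits of `position` in `base`, least significant first."""
    digits = []
    for _ in range(n_digits):
        digits.append(position % base)
        position //= base
    return digits


def fixed_base_digits(position: int, bases: Sequence[int] = (11, 13, 14, 15),
                      n_digits: int = 8) -> Dict[int, List[int]]:
    """Digits of `position` in every base (hypothetical helper)."""
    return {base: base_digits(position, base, n_digits) for base in bases}


print(fixed_base_digits(100, n_digits=3))
# {11: [1, 9, 0], 13: [9, 7, 0], 14: [2, 7, 0], 15: [10, 6, 0]}
```

With n_digits=8, even the smallest default base, 11, distinguishes 11**8 = 214,358,881 positions before its digits repeat.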

__init__
(bases=[11, 13, 14, 15], n_digits=8, start_from_zero_one_in=2, base_dropout_one_in=100, mode='train', initializer=<function RandomUniformInitializer.<locals>.<lambda>>)¶ Creates a partially initialized, unconnected layer instance.
Parameters:  n_in – Number of inputs expected by this layer.
 n_out – Number of outputs promised by this layer.
 name – Class-like name for this layer; for use when printing this layer.
 sublayers_to_print – Sublayers to display when printing out this layer; if None (the default), display all sublayers.

forward
(x)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


trax.layers.research.position_encodings.
threefry_2x32_prf
(key, x)¶ Apply the threefry PRF to an array of inputs.
This function is vectorized over x. For threefry_2x32: K = X = uint32[2]
Parameters:  key – uint32[2] the key of the PRF
 x – uint32[…, 2] the inputs
Returns: y – uint32[…, 2], the outputs

trax.layers.research.position_encodings.
threefry_2x32_prange
(key, lo: int = 0, hi: int = 2)¶ Splits a key into a stream of random keys.
This uses the little-endian counter mode.
Parameters:  key – uint32[2] the key to split
 lo – the range to start extracting from
 hi – the range to stop extracting from
Returns: keys – uint32[hi - lo, 2], the split keys
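Both functions build on Threefry-2x32, the counter-based PRF from the Random123 family that JAX also uses for its PRNG. The pure-Python sketch below processes a single uint32 pair with the standard 20 rounds (rotation constants 13, 15, 26, 6 and 17, 29, 16, 24; key-schedule parity 0x1BD11BDA) and, unlike threefry_2x32_prf, is not vectorized. The prange sketch rests on an assumption about what little-endian counter mode means here: counter c is split into the pair (low 32 bits, high 32 bits) before being encrypted. Both helper names are hypothetical.

```python
from typing import List, Tuple

MASK32 = 0xFFFFFFFF
# Threefry-2x32 rotation constants, applied in alternating groups of four rounds.
ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))


def _rotl32(v: int, r: int) -> int:
    return ((v << r) | (v >> (32 - r))) & MASK32


def threefry_2x32_pair(key: Tuple[int, int], x: Tuple[int, int]) -> Tuple[int, int]:
    """20-round Threefry-2x32 of one uint32 pair `x` under `key` (sketch)."""
    ks = (key[0], key[1], key[0] ^ key[1] ^ 0x1BD11BDA)  # extended key schedule
    x0 = (x[0] + ks[0]) & MASK32
    x1 = (x[1] + ks[1]) & MASK32
    for i in range(5):
        for r in ROTATIONS[i % 2]:
            # One mix round: add, rotate, xor (all modulo 2**32).
            x0 = (x0 + x1) & MASK32
            x1 = _rotl32(x1, r) ^ x0
        # Key injection after every four rounds.
        x0 = (x0 + ks[(i + 1) % 3]) & MASK32
        x1 = (x1 + ks[(i + 2) % 3] + i + 1) & MASK32
    return x0, x1


def prange_sketch(key: Tuple[int, int], lo: int = 0,
                  hi: int = 2) -> List[Tuple[int, int]]:
    """hi - lo derived keys: counter c is encoded as (low 32 bits, high 32 bits)."""
    return [threefry_2x32_pair(key, (c & MASK32, (c >> 32) & MASK32))
            for c in range(lo, hi)]


print([hex(v) for v in threefry_2x32_pair((0, 0), (0, 0))])
print(prange_sketch((42, 7), lo=0, hi=3))
```

Because each derived key depends only on the parent key and its counter, any sub-range [lo, hi) can be produced independently, without generating the keys before lo.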

class
trax.layers.research.position_encodings.
InfinitePositionalEncoding
(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')¶ Bases:
trax.layers.base.Layer
Infinite positional encoding.

__init__
(drift=0.03, affine=True, transform='any', time_bin_length=None, mode='train')¶ Initializes the encoding.
The encoding tries to roughly evenly traverse the latent space. The recurrence time is dependent on how many bits per dimension you use.
There are two parameters to control randomization:
 randomizing the origin every 1/drift steps by letting it drift
 randomizing the origin per call
Parameters:  drift – variance in position difference per unit of difference
 affine – whether to randomize the origin every call
 transform – learnable transform after encoding (any/diag/none)
 time_bin_length – Adds the features that AxialPositionalEncoding learns if TimeBinCausalAttention is the first layer. This value should match TimeBinCausalAttention.bin_length. If you set transform=’diag’, this flag increases your model capacity to close to that of transform=’any’, though it will still train slower.
 mode – if ‘predict’, allow evaluating one token at a time

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.


class
trax.layers.research.position_encodings.
TimeBinPositionalEncoding
(time_bin_length, mode='train')¶ Bases:
trax.layers.base.Layer
Just the engineered features from InfinitePositionalEncoding.

num_features
= 3¶

__init__
(time_bin_length, mode='train')¶ Initializes the encoding.
Parameters:  time_bin_length – TimeBinCausalAttention.bin_length of the first layer.
 mode – if ‘predict’, allow evaluating one token at a time

forward
(inputs)¶ Computes this layer’s output as part of a forward pass through the model.
A layer subclass overrides this method to define how the layer computes outputs from inputs. If the layer depends on weights, state, or randomness as part of the computation, the needed information can be accessed as properties of the layer object: self.weights, self.state, and self.rng. (See numerous examples in trax.layers.core.)
Parameters: inputs – Zero or more input tensors, packaged as described in the Layer class docstring. Returns: Zero or more output tensors, packaged as described in the Layer class docstring.

init_weights_and_state
(input_signature)¶ Initializes weights and state, to handle input with the given signature.
A layer subclass must override this method if the layer uses weights or state. To initialize weights, set self.weights to desired (typically random) values. To initialize state (uncommon), set self.state to desired starting values.
Parameters: input_signature – A ShapeDtype instance (if this layer takes one input) or a list/tuple of ShapeDtype instances.
