# Trax Layers Intro

This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:

1. Layers: the basic building blocks and how to combine them
2. Inputs and Outputs: how data streams flow through layers
3. Defining New Layer Classes (if combining existing layers isn’t enough)
4. Testing and Debugging Layer Classes

General Setup

Execute the following few cells (once) before running any of the code samples in this notebook.

[ ]:

# Copyright 2018 Google LLC.

# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import numpy as np

[ ]:

# Import Trax

! pip install -q -U trax
! pip install -q tensorflow

from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature


[ ]:

# Settings and utilities for handling inputs, outputs, and object properties.

np.set_printoptions(precision=3)  # Reduce visual noise from extra digits.

def show_layer_properties(layer_obj, layer_name):
  template = ('{}.n_in:  {}\n'
              '{}.n_out: {}\n'
              '{}.sublayers: {}\n'
              '{}.weights:    {}\n')
  print(template.format(layer_name, layer_obj.n_in,
                        layer_name, layer_obj.n_out,
                        layer_name, layer_obj.sublayers,
                        layer_name, layer_obj.weights))


## 1. Layers

The Layer class represents Trax’s basic building blocks:

class Layer:
  """Base class for composable layers in a deep learning network.

  Layers are the basic building blocks for deep learning models. A Trax layer
  computes a function from zero or more inputs to zero or more outputs,
  optionally using trainable weights (common) and non-parameter state (not
  common).  ...

  ...


### Layers compute functions.

A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.

The simplest layers, those with no weights or sublayers, can be used without initialization. You can think of them as (pure) mathematical functions that can be plugged into neural networks.

For ease of testing and interactive exploration, layer objects implement the __call__ method, so you can call them directly on input data:

y = my_layer(x)


Layers are also objects, so you can inspect their properties. For example:

print(f'Number of inputs expected by this layer: {my_layer.n_in}')


Example 1. tl.Relu $$[n_{in} = 1, n_{out} = 1]$$

[ ]:

relu = tl.Relu()

x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]])
y = relu(x)

# Show input, output, and two layer properties.
print(f'x:\n{x}\n\n'
f'relu(x):\n{y}\n\n'
f'Number of inputs expected by this layer: {relu.n_in}\n'
f'Number of outputs promised by this layer: {relu.n_out}')

x:
[[ -2  -1   0   1   2]
[-20 -10   0  10  20]]

relu(x):
[[ 0  0  0  1  2]
[ 0  0  0 10 20]]

Number of inputs expected by this layer: 1
Number of outputs promised by this layer: 1


Example 2. tl.Concatenate $$[n_{in} = 2, n_{out} = 1]$$

[ ]:

concat = tl.Concatenate()

x0 = np.array([[1, 2, 3],
[4, 5, 6]])
x1 = np.array([[10, 20, 30],
[40, 50, 60]])
y = concat([x0, x1])

print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'concat([x0, x1]):\n{y}\n\n'
f'Number of inputs expected by this layer: {concat.n_in}\n'
f'Number of outputs promised by this layer: {concat.n_out}')

x0:
[[1 2 3]
[4 5 6]]

x1:
[[10 20 30]
[40 50 60]]

concat([x0, x1]):
[[ 1  2  3 10 20 30]
[ 4  5  6 40 50 60]]

Number of inputs expected by this layer: 2
Number of outputs promised by this layer: 1


### Layers are configurable.

Many layer types have creation-time parameters for flexibility. The Concatenate layer type, for instance, has two optional parameters:

• axis: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.
• n_items: number of tensors to join into one by concatenation; default value is 2.

The following example shows Concatenate configured for 3 input tensors, and concatenation along the initial $$(0^{th})$$ axis.

Example 3. tl.Concatenate(n_items=3, axis=0)

[ ]:

concat3 = tl.Concatenate(n_items=3, axis=0)

x0 = np.array([[1, 2, 3],
[4, 5, 6]])
x1 = np.array([[10, 20, 30],
[40, 50, 60]])
x2 = np.array([[100, 200, 300],
[400, 500, 600]])

y = concat3([x0, x1, x2])

print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'x2:\n{x2}\n\n'
f'concat3([x0, x1, x2]):\n{y}')

x0:
[[1 2 3]
[4 5 6]]

x1:
[[10 20 30]
[40 50 60]]

x2:
[[100 200 300]
[400 500 600]]

concat3([x0, x1, x2]):
[[  1   2   3]
[  4   5   6]
[ 10  20  30]
[ 40  50  60]
[100 200 300]
[400 500 600]]
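For intuition, the effect of the `axis` parameter can be reproduced with plain NumPy. The sketch below is an illustration only: it uses `np.concatenate` as a stand-in, and `concat_sketch` is a hypothetical helper name, not part of Trax.

```python
import numpy as np

def concat_sketch(tensors, axis=-1):
    """Hypothetical stand-in: joins a list of arrays along the given axis."""
    return np.concatenate(tensors, axis=axis)

a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([[10, 20, 30], [40, 50, 60]])

last_axis = concat_sketch([a, b])            # Default: join along the last axis.
first_axis = concat_sketch([a, b], axis=0)   # Join along the initial axis.

print(last_axis.shape)   # (2, 6)
print(first_axis.shape)  # (4, 3)
```

Joining along the last axis widens each row, while joining along axis 0 stacks the inputs as additional rows.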


### Layers are trainable.

Many layer types include weights that affect the computation of outputs from inputs, and they use back-propagated gradients to update those weights.

🚧🚧 A very small subset of layer types, such as BatchNorm, also include modifiable weights (called state) that are updated based on forward-pass inputs/computation rather than back-propagated gradients.

Initialization

Trainable layers must be initialized before use. Trax can take care of this as part of the overall training process. In other settings (e.g., in tests or interactively in a Colab notebook), you need to initialize the outermost/topmost layer explicitly. For this, use init:

def init(self, input_signature, rng=None, use_cache=False):
  """Initializes weights/state of this layer and its sublayers recursively.

  Initialization creates layer weights and state, for layers that use them.
  It derives the necessary array shapes and data types from the layer's input
  signature, which is itself just shape and data type information.

  For layers without weights or state, this method safely does nothing.

  This method is designed to create weights/state only once for each layer
  instance, even if the same layer instance occurs in multiple places in the
  network. This enables weight sharing to be implemented as layer sharing.

  Args:
    input_signature: ShapeDtype instance (if this layer takes one input)
        or list/tuple of ShapeDtype instances.
    rng: Single-use random number generator (JAX PRNG key), or None;
        if None, use a default computed from an integer 0 seed.
    use_cache: If True, and if this layer instance has already been
        initialized elsewhere in the network, then return special marker
        values -- tuple (GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE).
        Else return this layer's newly initialized weights and state.

  Returns:
    A (weights, state) tuple.
  """


Input signatures can be built from scratch using ShapeDtype objects, or can be derived from data via the signature function (in module shapes):

def signature(obj):
  """Returns a ShapeDtype signature for the given obj.

  A signature is either a ShapeDtype instance or a tuple of ShapeDtype
  instances. Note that this function is permissive with respect to its inputs
  (accepts lists or tuples or dicts, and underlying objects can be any type
  as long as they have shape and dtype attributes) and returns the corresponding
  nested structure of ShapeDtype.

  Args:
    obj: An object that has shape and dtype attributes, or a list/tuple/dict
        of such objects.

  Returns:
    A corresponding nested structure of ShapeDtype instances.
  """
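
To make the idea of a signature concrete, here is a minimal sketch of the behavior the docstring describes. It is not Trax's implementation: `ShapeDtypeSketch` and `signature_sketch` are hypothetical stand-ins, defined here only to show that a signature carries shape and dtype information and mirrors the nesting of its input.

```python
import collections
import numpy as np

# Hypothetical stand-in for a shape/dtype record; for illustration only.
ShapeDtypeSketch = collections.namedtuple('ShapeDtypeSketch', ['shape', 'dtype'])

def signature_sketch(obj):
    """Maps arrays, or nested lists/tuples/dicts of arrays, to shape/dtype records."""
    if isinstance(obj, list):
        return [signature_sketch(x) for x in obj]
    if isinstance(obj, tuple):
        return tuple(signature_sketch(x) for x in obj)
    if isinstance(obj, dict):
        return {k: signature_sketch(v) for k, v in obj.items()}
    # Anything with shape and dtype attributes is accepted.
    return ShapeDtypeSketch(shape=obj.shape, dtype=obj.dtype)

x = np.zeros((3, 5), dtype=np.float32)
y = np.zeros((3,), dtype=np.int32)

sig_single = signature_sketch(x)      # One record for one array.
sig_pair = signature_sketch((x, y))   # Tuple of records for a tuple of arrays.
print(sig_single)
print(sig_pair)
```

Note that no array values survive in the result, which is why a signature is enough to size a layer's weights without touching real data.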


Example 4. tl.LayerNorm $$[n_{in} = 1, n_{out} = 1]$$

[ ]:

layer_norm = tl.LayerNorm()

x = np.array([[-2, -1, 0, 1, 2],
[1, 2, 3, 4, 5],
[10, 20, 30, 40, 50]]).astype(np.float32)
layer_norm.init(shapes.signature(x))

y = layer_norm(x)

print(f'x:\n{x}\n\n'
f'layer_norm(x):\n{y}\n')
print(f'layer_norm.weights:\n{layer_norm.weights}')

x:
[[-2. -1.  0.  1.  2.]
[ 1.  2.  3.  4.  5.]
[10. 20. 30. 40. 50.]]

layer_norm(x):
[[-1.414 -0.707  0.     0.707  1.414]
[-1.414 -0.707  0.     0.707  1.414]
[-1.414 -0.707  0.     0.707  1.414]]

layer_norm.weights:
(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
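
The output values above can be checked by hand. The following NumPy sketch assumes the usual layer-normalization arithmetic (subtract the mean and divide by the standard deviation along the last axis, then apply scale and bias); the helper name `layer_norm_sketch` and the `epsilon` value are assumptions made for illustration, not taken from Trax.

```python
import numpy as np

def layer_norm_sketch(x, scale=1.0, bias=0.0, epsilon=1e-6):
    """Normalizes along the last axis, then scales and shifts.

    epsilon is an assumed small constant that guards against division by zero.
    """
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon) * scale + bias

x = np.array([[-2, -1, 0, 1, 2],
              [10, 20, 30, 40, 50]], dtype=np.float32)
y = layer_norm_sketch(x)
print(np.round(y, 3))
```

Each row is an evenly spaced sequence of five values, so after normalization the rows coincide, which is why every row of `layer_norm(x)` in Example 4 is identical. The initial weights shown there (ones and zeros) correspond to `scale=1` and `bias=0` in this sketch.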


### Layers combine into layers.

The Trax library authors encourage users to build networks and network components as combinations of existing layers, by means of a small set of combinator layers. A combinator makes a list of layers behave as a single layer – by combining the sublayer computations yet looking from the outside like any other layer. The combined layer, like other layers, can:

• compute outputs from inputs,
• update parameters from gradients, and
• combine with yet more layers.

Combine with Serial

The most common way to combine layers is with the Serial combinator:

class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  This combinator is commonly used to construct deep networks, e.g., like this::

      mlp = tl.Serial(
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(10),
      )

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  A Serial instance with no sublayers acts as a special-case (but useful)
  1-input 1-output no-op.
  """


If one layer has the same number of outputs as the next layer has inputs (which is the usual case), the successive layers behave like function composition:

#  h(.) = g(f(.))
layer_h = Serial(
layer_f,
layer_g,
)


Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed.

Example 5. y = layer_norm(relu(x)) $$[n_{in} = 1, n_{out} = 1]$$

[ ]:

layer_block = tl.Serial(
tl.Relu(),
tl.LayerNorm(),
)

x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)

print(f'x:\n{x}\n\n'
f'layer_block(x):\n{y}')

x:
[[ -2.  -1.   0.   1.   2.]
[-20. -10.   0.  10.  20.]]

layer_block(x):
[[-0.75 -0.75 -0.75  0.5   1.75]
[-0.75 -0.75 -0.75  0.5   1.75]]


And we can inspect the block as a whole, as if it were just another layer:

Example 5’. Inspecting a Serial layer.

[ ]:

print(f'layer_block: {layer_block}\n\n'
f'layer_block.weights: {layer_block.weights}')

layer_block: Serial[
Relu
LayerNorm
]

layer_block.weights: ((), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))


Combine with Branch

The Branch combinator arranges layers into parallel computational channels:

def Branch(*layers, name='Branch'):
  """Combinator that applies a list of layers in parallel to copies of inputs.

  Each layer in the input list is applied to as many inputs from the stack
  as it needs, and their outputs are successively combined on stack.

  For example, suppose one has three layers:

    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)

  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:

    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)

  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

  Args:
    *layers: List of layers.
    name: Descriptive name for this layer.

  Returns:
    A branch layer built from the given sublayers.
  """


Residual blocks, for example, are implemented using Branch:

def Residual(*layers, shortcut=None):
  """Wraps a series of layers with a residual connection.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the shortcut layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
    A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else Serial(layers)
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )


Here’s a simple code example to highlight the mechanics.

Example 6. Branch

[ ]:

relu = tl.Relu()
times_100 = tl.Fn("Times100", lambda x: x * 100.0)
branch_relu_t100 = tl.Branch(relu, times_100)

x = np.array([[-2, -1, 0, 1, 2],
[-20, -10, 0, 10, 20]])
branch_relu_t100.init(shapes.signature(x))

y0, y1 = branch_relu_t100(x)

print(f'x:\n{x}\n\n'
f'y0:\n{y0}\n\n'
f'y1:\n{y1}')

x:
[[ -2  -1   0   1   2]
[-20 -10   0  10  20]]

y0:
[[ 0  0  0  1  2]
[ 0  0  0 10 20]]

y1:
[[ -200.  -100.     0.   100.   200.]
[-2000. -1000.     0.  1000.  2000.]]
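
The F, G, H case from the Branch docstring can also be traced in plain Python. This is a toy model of the semantics only: `branch_sketch` is a hypothetical helper, layers are represented as `(function, n_in)` pairs, and the three functions are arbitrary stand-ins chosen so the results are easy to check.

```python
def branch_sketch(layers, inputs):
    """Applies each (fn, n_in) pair to the first n_in inputs; collects all outputs.

    Hypothetical helper for illustration; each fn returns a tuple of outputs.
    """
    outputs = []
    for fn, n_in in layers:
        outputs.extend(fn(*inputs[:n_in]))  # Every layer reads from the same inputs.
    return tuple(outputs)

# F: 1 input, 1 output; G: 3 inputs, 1 output; H: 2 inputs, 2 outputs.
F = (lambda a: (-a,), 1)
G = (lambda a, b, c: (a + b + c,), 3)
H = (lambda a, b: (a * b, a - b), 2)

# The branch as a whole needs as many inputs as its most demanding sublayer.
n_in = max(n for _, n in [F, G, H])
result = branch_sketch([F, G, H], (1, 2, 3))
print(n_in, result)  # 3 (-1, 6, 2, -1)
```

The four outputs are F(a), G(a, b, c), and the two outputs of H(a, b), in that order, matching the docstring.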


## 2. Inputs and Outputs

Trax allows layers to have multiple input streams and output streams. When designing a network, you have the flexibility to use layers that:

• process a single data stream ($$n_{in} = n_{out} = 1$$),
• process multiple parallel data streams ($$n_{in} = n_{out} = 2, 3, \ldots$$),
• split or inject data streams ($$n_{in} < n_{out}$$), or
• merge or remove data streams ($$n_{in} > n_{out}$$).

We saw in section 1 the example of Residual, which involves both a split and a merge:

  ...
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )


In other words, layer by layer:

• Branch(shortcut, layer): makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [$$n_{in} = 1$$, $$n_{out} = 2$$]
• Add(): combines the two streams back into one by adding two tensors elementwise. [$$n_{in} = 2$$, $$n_{out} = 1$$]
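
The split-then-merge pattern can be written out in a few lines of NumPy. This is a sketch of the data flow only, with a hypothetical helper `residual_sketch` and a ReLU standing in for the main layer series; it assumes the usual case where the shortcut is a no-op.

```python
import numpy as np

def residual_sketch(f, x):
    """Split one stream into two, transform one copy, then add them back."""
    shortcut, main = x, f(x)   # Split step: n_in = 1, n_out = 2.
    return shortcut + main     # Merge step: n_in = 2, n_out = 1.

x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y = residual_sketch(lambda t: np.maximum(t, 0.0), x)
print(y)  # Values: -2, -1, 0, 2, 4
```

Negative inputs pass through unchanged because the ReLU branch contributes zero, while positive inputs are doubled.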

### Data Stack

Trax supports flexible data flows through a network via a data stack, which is managed by the Serial combinator:

class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  ...

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  ...

  """


Simple Case 1 – Each layer takes one input and has one output.

This is in effect a single data stream pipeline, and the successive layers behave like function composition:

#  s(.) = h(g(f(.)))
layer_s = Serial(
layer_f,
layer_g,
layer_h,
)


Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed and the order of operations matches the textual order of layers.

Simple Case 2 – Each layer consumes all outputs of the preceding layer.

This is still a single pipeline, but data streams internal to it can split and merge. The Residual example above illustrates this case.

General Case – Successive layers interact via the data stack.

As described in the Serial class docstring, each layer gets its inputs from the data stack after the preceding layer has put its outputs onto the stack. This covers the simple cases above, but also allows for more flexible data interactions between non-adjacent layers. The following example is schematic:

x, y_target = get_batch_of_labeled_data()

model_plus_eval = Serial(
my_fancy_deep_model(),  # Takes one arg (x) and has one output (y_hat)
my_eval(),  # Takes two args (y_hat, y_target) and has one output (score)
)

eval_score = model_plus_eval((x, y_target))


Here is the corresponding progression of stack states:

1. At start: –empty–
2. After get_batch_of_labeled_data(): x, y_target
3. After my_fancy_deep_model(): y_hat, y_target
4. After my_eval(): score

Note in particular how the application of the model (between stack states 2 and 3) only uses and affects the top element on the stack: x –> y_hat. The rest of the data stack (y_target) comes into use only later, for the eval function.
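
The stack progression above can be traced with a toy simulation. The sketch below is not how Trax implements Serial: `serial_sketch` is a hypothetical helper, layers are modeled as `(function, n_in)` pairs, and the model and eval functions are arbitrary stand-ins operating on plain numbers.

```python
def serial_sketch(layers, stack):
    """Runs (fn, n_in) pairs in order over a data stack (index 0 is the top).

    Hypothetical helper for illustration; each fn returns a tuple of outputs.
    """
    stack = list(stack)
    history = [tuple(stack)]
    for fn, n_in in layers:
        args, rest = stack[:n_in], stack[n_in:]  # Take n_in items off the top.
        stack = list(fn(*args)) + rest           # Push the outputs back on top.
        history.append(tuple(stack))
    return stack, history

# Stand-ins for the schematic layers (both are assumptions).
model = (lambda x: (2 * x,), 1)                                   # x -> y_hat
evaluate = (lambda y_hat, y_target: (abs(y_hat - y_target),), 2)  # -> score

final, history = serial_sketch([model, evaluate], [3, 7])  # x = 3, y_target = 7
for state in history:
    print(state)
# (3, 7)  x, y_target
# (6, 7)  y_hat, y_target
# (1,)    score
```

The second stack item (7) is untouched by the model step and is consumed only by the eval step, which is the interaction between non-adjacent layers described above.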

## 3. Defining New Layer Classes

If you need a layer type that is not easily defined as a combination of existing layer types, you can define your own layer classes in a couple of different ways.

### With the Fn layer-creating function.

Many layer types needed in deep learning compute pure functions from inputs to outputs, using neither weights nor randomness. You can use Trax’s Fn function to define your own pure layer types:

def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function f.

  f can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (jnp).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)

  The layer's number of inputs (n_in) is automatically set to number of
  positional arguments in f, but you must explicitly set the number of
  outputs (n_out) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., f(x0, x1) --> x0 + x1.
        Output tensors must be packaged as specified in the Layer class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function f.
  """


Example 7. Use Fn to define a new layer type:

[ ]:

# Define new layer type.
def Gcd():
  """Returns a layer to compute the greatest common divisor, elementwise."""
  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))

# Use it.
gcd = Gcd()

x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

y = gcd((x0, x1))

print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'gcd((x0, x1)):\n{y}')

x0:
[ 1  2  3  4  5  6  7  8  9 10]

x1:
[11 12 13 14 15 16 17 18 19 20]

gcd((x0, x1)):
[ 1  2  1  2  5  2  1  2  1 10]


The Fn function infers n_in (number of inputs) as the length of f’s arg list. Fn does not infer n_out (number of outputs) though. If your f has more than one output, you need to give an explicit value using the n_out keyword arg.
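
The asymmetry between n_in and n_out follows from what Python exposes about a function. As a rough analogue (an assumption about the general approach, not a description of Trax's internals), the standard `inspect` module can count a function's positional parameters without calling it, whereas the number of return values is not available from the signature. The helper name `count_positional_args` is hypothetical.

```python
import inspect

def count_positional_args(f):
    """Counts f's positional parameters, a rough analogue of inferring n_in."""
    params = inspect.signature(f).parameters.values()
    return sum(p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
               for p in params)

one_in = lambda x: x * 2
two_in_two_out = lambda x0, x1: (x0 + x1, max(x0, x1))

print(count_positional_args(one_in))          # 1
print(count_positional_args(two_in_two_out))  # 2
```

Nothing in the second lambda's signature reveals that it returns a pair, which is why a caller has to state n_out explicitly.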

Example 8. Fn with multiple outputs:

[ ]:

# Define new layer type.
def SumAndMax():
  """Returns a layer to compute sums and maxima of two input tensors."""
  return tl.Fn('SumAndMax',
               lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
               n_out=2)

# Use it.
sum_and_max = SumAndMax()

x0 = np.array([1, 2, 3, 4, 5])
x1 = np.array([10, -20, 30, -40, 50])

y0, y1 = sum_and_max([x0, x1])

print(f'x0:\n{x0}\n\n'
f'x1:\n{x1}\n\n'
f'y0:\n{y0}\n\n'
f'y1:\n{y1}')

x0:
[1 2 3 4 5]

x1:
[ 10 -20  30 -40  50]

y0:
[ 11 -18  33 -36  55]

y1:
[10  2 30  4 50]


Example 9. Use Fn to define a configurable layer:

[ ]:

# Function defined in trax/layers/core.py:
def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.

  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  Flatten(n_axes_to_keep=2) layer would map a tensor with shape
  (2, 3, 5, 7, 11) to the same values with shape (2, 3, 385).

  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return tl.Fn(layer_name, f)

flatten_keep_1_axis = Flatten(n_axes_to_keep=1)
flatten_keep_2_axes = Flatten(n_axes_to_keep=2)

x = np.array([[[1, 2, 3],
[10, 20, 30],
[100, 200, 300]],
[[4, 5, 6],
[40, 50, 60],
[400, 500, 600]]])

y1 = flatten_keep_1_axis(x)
y2 = flatten_keep_2_axes(x)

print(f'x:\n{x}\n\n'
f'flatten_keep_1_axis(x):\n{y1}\n\n'
f'flatten_keep_2_axes(x):\n{y2}')

x:
[[[  1   2   3]
[ 10  20  30]
[100 200 300]]

[[  4   5   6]
[ 40  50  60]
[400 500 600]]]

flatten_keep_1_axis(x):
[[  1   2   3  10  20  30 100 200 300]
[  4   5   6  40  50  60 400 500 600]]

flatten_keep_2_axes(x):
[[[  1   2   3]
[ 10  20  30]
[100 200 300]]

[[  4   5   6]
[ 40  50  60]
[400 500 600]]]
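
In the example above, keeping 2 axes of a rank-3 input leaves the shape unchanged, so the effect is easier to see on the higher-rank shape mentioned in the Flatten docstring. The NumPy sketch below reproduces only the reshape step; `flatten_sketch` is a hypothetical helper, not the Trax layer.

```python
import numpy as np

def flatten_sketch(x, n_axes_to_keep=1):
    """Keeps the leading axes and collapses the rest into one trailing axis."""
    return np.reshape(x, x.shape[:n_axes_to_keep] + (-1,))

x = np.zeros((2, 3, 5, 7, 11))
keep_1 = flatten_sketch(x, n_axes_to_keep=1)
keep_2 = flatten_sketch(x, n_axes_to_keep=2)

print(keep_1.shape)  # (2, 1155)
print(keep_2.shape)  # (2, 3, 385)
```

The collapsed axis has size 3 * 5 * 7 * 11 = 1155 when one axis is kept, and 5 * 7 * 11 = 385 when two are kept.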


### By defining a Layer subclass

If you need a layer type that uses trainable weights (or state), you can extend the base Layer class:

class Layer:
  """Base class for composable layers in a deep learning network.

  ...

  Authors of new layer subclasses typically override at most two methods of
  the base Layer class:

    forward(inputs):
      Computes this layer's output as part of a forward pass through the model.

    init_weights_and_state(self, input_signature):
      Initializes weights and state for inputs with the given signature.

  ...


The forward method uses weights stored in the layer object (self.weights) to compute outputs from inputs. For example, here is the definition of forward for Trax’s Dense layer:

def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input, except the final dimension
    is the layer's n_units value.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
    return jnp.dot(x, w) + b  # Affine map.
  else:
    w = self.weights
    return jnp.dot(x, w)  # Linear map.


Layer weights must be initialized before the layer can be used; the init_weights_and_state method specifies how. Continuing the Dense example, here is the corresponding initialization code:

def init_weights_and_state(self, input_signature):
  """Randomly initializes this layer's weights.

  Weights are a (w, b) tuple for layers created with use_bias=True (the
  default case), or a w tensor for layers created with use_bias=False.

  Args:
    input_signature: ShapeDtype instance characterizing the input this layer
        should compute on.
  """
  shape_w = (input_signature.shape[-1], self._n_units)
  shape_b = (self._n_units,)
  rng_w, rng_b = fastmath.random.split(self.rng, 2)
  w = self._kernel_initializer(shape_w, rng_w)

  if self._use_bias:
    b = self._bias_initializer(shape_b, rng_b)
    self.weights = (w, b)
  else:
    self.weights = w
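
To see how the shapes in these two methods fit together, here is a standalone NumPy sketch of the same arithmetic. It is an illustration under stated assumptions: the sizes are arbitrary, the weights come from NumPy's random generator rather than Trax's initializers, and the bias is set to zeros purely for convenience.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, d_in, n_units = 4, 3, 2   # Arbitrary sizes chosen for illustration.
x = rng.normal(size=(batch, d_in))

# Shapes mirror init_weights_and_state: w is (d_in, n_units), b is (n_units,).
w = rng.normal(size=(x.shape[-1], n_units))
b = np.zeros((n_units,))

y_affine = np.dot(x, w) + b   # Corresponds to the use_bias=True branch.
y_linear = np.dot(x, w)       # Corresponds to the use_bias=False branch.

print(y_affine.shape)  # (4, 2)
```

Only the last dimension of the input determines the shape of w, so the output keeps the leading (batch) dimension and replaces the final one with n_units.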


TBD