Trax Layers Intro

This notebook introduces the core concepts of the Trax library through a series of code samples and explanations. The topics covered in the following sections are:

  1. Layers: the basic building blocks and how to combine them
  2. Inputs and Outputs: how data streams flow through layers
  3. Defining New Layer Classes (if combining existing layers isn’t enough)
  4. Testing and Debugging Layer Classes

General Setup

Execute the following few cells (once) before running any of the code samples in this notebook.

[ ]:
# Copyright 2018 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np



[ ]:
# Import Trax

! pip install -q -U trax
! pip install -q tensorflow

from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp  # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature
[ ]:
# Settings and utilities for handling inputs, outputs, and object properties.

np.set_printoptions(precision=3)  # Reduce visual noise from extra digits.

def show_layer_properties(layer_obj, layer_name):
  template = ('{}.n_in:  {}\n'
              '{}.n_out: {}\n'
              '{}.sublayers: {}\n'
              '{}.weights:    {}\n')
  print(template.format(layer_name, layer_obj.n_in,
                        layer_name, layer_obj.n_out,
                        layer_name, layer_obj.sublayers,
                        layer_name, layer_obj.weights))

1. Layers

The Layer class represents Trax’s basic building blocks:

class Layer:
  """Base class for composable layers in a deep learning network.

  Layers are the basic building blocks for deep learning models. A Trax layer
  computes a function from zero or more inputs to zero or more outputs,
  optionally using trainable weights (common) and non-parameter state (not
  common).  ...

  ...

Layers compute functions.

A layer computes a function from zero or more inputs to zero or more outputs. The inputs and outputs are NumPy arrays or JAX objects behaving as NumPy arrays.

The simplest layers, those with no weights or sublayers, can be used without initialization. You can think of them as (pure) mathematical functions that can be plugged into neural networks.

For ease of testing and interactive exploration, layer objects implement the __call__ method, so you can call them directly on input data:

y = my_layer(x)

Layers are also objects, so you can inspect their properties. For example:

print(f'Number of inputs expected by this layer: {my_layer.n_in}')

Example 1. tl.Relu \([n_{in} = 1, n_{out} = 1]\)

[ ]:
relu = tl.Relu()

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
y = relu(x)

# Show input, output, and two layer properties.
print(f'x:\n{x}\n\n'
      f'relu(x):\n{y}\n\n'
      f'Number of inputs expected by this layer: {relu.n_in}\n'
      f'Number of outputs promised by this layer: {relu.n_out}')
x:
[[ -2  -1   0   1   2]
 [-20 -10   0  10  20]]

relu(x):
[[ 0  0  0  1  2]
 [ 0  0  0 10 20]]

Number of inputs expected by this layer: 1
Number of outputs promised by this layer: 1

Example 2. tl.Concatenate \([n_{in} = 2, n_{out} = 1]\)

[ ]:
concat = tl.Concatenate()

x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[10, 20, 30],
               [40, 50, 60]])
y = concat([x0, x1])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'concat([x0, x1]):\n{y}\n\n'
      f'Number of inputs expected by this layer: {concat.n_in}\n'
      f'Number of outputs promised by this layer: {concat.n_out}')
x0:
[[1 2 3]
 [4 5 6]]

x1:
[[10 20 30]
 [40 50 60]]

concat([x0, x1]):
[[ 1  2  3 10 20 30]
 [ 4  5  6 40 50 60]]

Number of inputs expected by this layer: 2
Number of outputs promised by this layer: 1

Layers are configurable.

Many layer types have creation-time parameters for flexibility. The Concatenate layer type, for instance, has two optional parameters:

  • axis: index of axis along which to concatenate the tensors; default value of -1 means to use the last axis.
  • n_items: number of tensors to join into one by concatenation; default value is 2.

The following example shows Concatenate configured for 3 input tensors, and concatenation along the initial \((0^{th})\) axis.

Example 3. tl.Concatenate(n_items=3, axis=0)

[ ]:
concat3 = tl.Concatenate(n_items=3, axis=0)

x0 = np.array([[1, 2, 3],
               [4, 5, 6]])
x1 = np.array([[10, 20, 30],
               [40, 50, 60]])
x2 = np.array([[100, 200, 300],
               [400, 500, 600]])

y = concat3([x0, x1, x2])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'x2:\n{x2}\n\n'
      f'concat3([x0, x1, x2]):\n{y}')
x0:
[[1 2 3]
 [4 5 6]]

x1:
[[10 20 30]
 [40 50 60]]

x2:
[[100 200 300]
 [400 500 600]]

concat3([x0, x1, x2]):
[[  1   2   3]
 [  4   5   6]
 [ 10  20  30]
 [ 40  50  60]
 [100 200 300]
 [400 500 600]]

Layers are trainable.

Many layer types include weights that affect the computation of outputs from inputs, and they use back-propagated gradients to update those weights.

🚧🚧 A very small subset of layer types, such as ``BatchNorm``, also include modifiable weights (called ``state``) that are updated based on forward-pass inputs/computation rather than back-propagated gradients.

Initialization

Trainable layers must be initialized before use. Trax can take care of this as part of the overall training process. In other settings (e.g., in tests or interactively in a Colab notebook), you need to initialize the outermost/topmost layer explicitly. For this, use init:

def init(self, input_signature, rng=None, use_cache=False):
  """Initializes weights/state of this layer and its sublayers recursively.

  Initialization creates layer weights and state, for layers that use them.
  It derives the necessary array shapes and data types from the layer's input
  signature, which is itself just shape and data type information.

  For layers without weights or state, this method safely does nothing.

  This method is designed to create weights/state only once for each layer
  instance, even if the same layer instance occurs in multiple places in the
  network. This enables weight sharing to be implemented as layer sharing.

  Args:
    input_signature: `ShapeDtype` instance (if this layer takes one input)
        or list/tuple of `ShapeDtype` instances.
    rng: Single-use random number generator (JAX PRNG key), or `None`;
        if `None`, use a default computed from an integer 0 seed.
    use_cache: If `True`, and if this layer instance has already been
        initialized elsewhere in the network, then return special marker
        values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
        Else return this layer's newly initialized weights and state.

  Returns:
    A `(weights, state)` tuple.
  """

Input signatures can be built from scratch using ShapeDtype objects, or can be derived from data via the signature function (in module shapes):

def signature(obj):
  """Returns a `ShapeDtype` signature for the given `obj`.

  A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
  instances. Note that this function is permissive with respect to its inputs
  (accepts lists or tuples or dicts, and underlying objects can be any type
  as long as they have shape and dtype attributes) and returns the corresponding
  nested structure of `ShapeDtype`.

  Args:
    obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
        of such objects.

  Returns:
    A corresponding nested structure of `ShapeDtype` instances.
  """

Example 4. tl.LayerNorm \([n_{in} = 1, n_{out} = 1]\)

[ ]:
layer_norm = tl.LayerNorm()

x = np.array([[-2, -1, 0, 1, 2],
              [1, 2, 3, 4, 5],
              [10, 20, 30, 40, 50]]).astype(np.float32)
layer_norm.init(shapes.signature(x))

y = layer_norm(x)

print(f'x:\n{x}\n\n'
      f'layer_norm(x):\n{y}\n')
print(f'layer_norm.weights:\n{layer_norm.weights}')
x:
[[-2. -1.  0.  1.  2.]
 [ 1.  2.  3.  4.  5.]
 [10. 20. 30. 40. 50.]]

layer_norm(x):
[[-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]
 [-1.414 -0.707  0.     0.707  1.414]]

layer_norm.weights:
(DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32))
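
Similarly, a layer that uses non-parameter state, such as BatchNorm (mentioned earlier), exposes that state after initialization. The following is a minimal sketch; it assumes tl.BatchNorm's default settings, which normalize over the first three axes of an image-like input:

batch_norm = tl.BatchNorm()

x = np.zeros((2, 3, 3, 5), dtype=np.float32)  # (batch, height, width, channels)
batch_norm.init(shapes.signature(x))

# Running statistics live in state and are updated by forward passes in
# 'train' mode, not by back-propagated gradients.
print(f'batch_norm.state:\n{batch_norm.state}')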

Layers combine into layers.

The Trax library authors encourage users to build networks and network components as combinations of existing layers, by means of a small set of combinator layers. A combinator makes a list of layers behave as a single layer – by combining the sublayer computations yet looking from the outside like any other layer. The combined layer, like other layers, can:

  • compute outputs from inputs,
  • update parameters from gradients, and
  • combine with yet more layers.

Combine with ``Serial``

The most common way to combine layers is with the Serial combinator:

class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  This combinator is commonly used to construct deep networks, e.g., like this::

      mlp = tl.Serial(
        tl.Dense(128),
        tl.Relu(),
        tl.Dense(10),
      )

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  A Serial instance with no sublayers acts as a special-case (but useful)
  1-input 1-output no-op.
  """

If one layer has the same number of outputs as the next layer has inputs (which is the usual case), the successive layers behave like function composition:

#  h(.) = g(f(.))
layer_h = Serial(
    layer_f,
    layer_g,
)

Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed.

Example 5. y = layer_norm(relu(x)) \([n_{in} = 1, n_{out} = 1]\)

[ ]:
layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)

print(f'x:\n{x}\n\n'
      f'layer_block(x):\n{y}')
x:
[[ -2.  -1.   0.   1.   2.]
 [-20. -10.   0.  10.  20.]]

layer_block(x):
[[-0.75 -0.75 -0.75  0.5   1.75]
 [-0.75 -0.75 -0.75  0.5   1.75]]

And we can inspect the block as a whole, as if it were just another layer:

Example 5’. Inspecting a Serial layer.

[ ]:
print(f'layer_block: {layer_block}\n\n'
      f'layer_block.weights: {layer_block.weights}')
layer_block: Serial[
  Relu
  LayerNorm
]

layer_block.weights: ((), (DeviceArray([1., 1., 1., 1., 1.], dtype=float32), DeviceArray([0., 0., 0., 0., 0.], dtype=float32)))

Combine with ``Branch``

The Branch combinator arranges layers into parallel computational channels:

def Branch(*layers, name='Branch'):
  """Combinator that applies a list of layers in parallel to copies of inputs.

  Each layer in the input list is applied to as many inputs from the stack
  as it needs, and their outputs are successively combined on stack.

  For example, suppose one has three layers:

    - F: 1 input, 1 output
    - G: 3 inputs, 1 output
    - H: 2 inputs, 2 outputs (h1, h2)

  Then Branch(F, G, H) will take 3 inputs and give 4 outputs:

    - inputs: a, b, c
    - outputs: F(a), G(a, b, c), h1, h2    where h1, h2 = H(a, b)

  As an important special case, a None argument to Branch acts as if it takes
  one argument, which it leaves unchanged. (It acts as a one-arg no-op.)

  Args:
    *layers: List of layers.
    name: Descriptive name for this layer.

  Returns:
    A branch layer built from the given sublayers.
  """

Residual blocks, for example, are implemented using Branch:

def Residual(*layers, shortcut=None):
  """Wraps a series of layers with a residual connection.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
      A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else Serial(layers)
  return Serial(
      Branch(shortcut, layer),
      Add(),
  )
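
For instance, wrapping a Relu in a Residual computes x plus relu(x) (a minimal sketch; as with other combinators, the combined layer is initialized before use):

residual_relu = tl.Residual(tl.Relu())

x = np.array([[-2., -1., 0., 1., 2.],
              [-20., -10., 0., 10., 20.]])
residual_relu.init(shapes.signature(x))
y = residual_relu(x)

# Negative values pass through unchanged (x + 0); positive values double (x + x).
print(f'residual_relu(x):\n{y}')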

Here’s a simple code example to highlight the mechanics of Branch itself.

Example 6. Branch

[ ]:
relu = tl.Relu()
times_100 = tl.Fn("Times100", lambda x: x * 100.0)
branch_relu_t100 = tl.Branch(relu, times_100)

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]])
branch_relu_t100.init(shapes.signature(x))

y0, y1 = branch_relu_t100(x)

print(f'x:\n{x}\n\n'
      f'y0:\n{y0}\n\n'
      f'y1:\n{y1}')
x:
[[ -2  -1   0   1   2]
 [-20 -10   0  10  20]]

y0:
[[ 0  0  0  1  2]
 [ 0  0  0 10 20]]

y1:
[[ -200.  -100.     0.   100.   200.]
 [-2000. -1000.     0.  1000.  2000.]]

2. Inputs and Outputs

Trax allows layers to have multiple input streams and output streams. When designing a network, you have the flexibility to use layers that:

  • process a single data stream (\(n_{in} = n_{out} = 1\)),
  • process multiple parallel data streams (\(n_{in} = n_{out} = 2, 3, \ldots\)),
  • split or inject data streams (\(n_{in} < n_{out}\)), or
  • merge or remove data streams (\(n_{in} > n_{out}\)).
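
For a quick sense of which built-in layers fall into which category, you can inspect their n_in and n_out properties (a minimal sketch; it assumes tl.Dup and tl.Add are available, as in current Trax versions):

# Dup splits (copies) a stream; Concatenate and Add merge two streams into one.
for layer in [tl.Relu(), tl.Dup(), tl.Concatenate(), tl.Add()]:
  print(f'{layer}: n_in = {layer.n_in}, n_out = {layer.n_out}')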

We saw in section 1 the example of Residual, which involves both a split and a merge:

...
return Serial(
    Branch(shortcut, layer),
    Add(),
)

In other words, layer by layer:

  • Branch(shortcut, layer): makes two copies of the single incoming data stream, passes one copy via the shortcut (typically a no-op), and processes the other copy via the given layers (applied in series). [\(n_{in} = 1\), \(n_{out} = 2\)]
  • Add(): combines the two streams back into one by adding two tensors elementwise. [\(n_{in} = 2\), \(n_{out} = 1\)]

Data Stack

Trax supports flexible data flows through a network via a data stack, which is managed by the Serial combinator:

class Serial(base.Layer):
  """Combinator that applies layers serially (by function composition).

  ...

  A Serial combinator uses stack semantics to manage data for its sublayers.
  Each sublayer sees only the inputs it needs and returns only the outputs it
  has generated. The sublayers interact via the data stack. For instance, a
  sublayer k, following sublayer j, gets called with the data stack in the
  state left after layer j has applied. The Serial combinator then:

    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
      layer k, passing those items as arguments; and

    - takes layer k's n_out return values (n_out = k.n_out) and pushes
      them onto the data stack.

  ...

  """

Simple Case 1 – Each layer takes one input and has one output.

This is in effect a single data stream pipeline, and the successive layers behave like function composition:

#  s(.) = h(g(f(.)))
layer_s = Serial(
    layer_f,
    layer_g,
    layer_h,
)

Note how, inside Serial, function composition is expressed naturally as a succession of operations, so that no nested parentheses are needed and the order of operations matches the textual order of layers.

Simple Case 2 – Each layer consumes all outputs of the preceding layer.

This is still a single pipeline, but data streams internal to it can split and merge. The Residual example above illustrates this kind of composition.

General Case – Successive layers interact via the data stack.

As described in the Serial class docstring, each layer gets its inputs from the data stack after the preceding layer has put its outputs onto the stack. This covers the simple cases above, but also allows for more flexible data interactions between non-adjacent layers. The following example is schematic:

x, y_target = get_batch_of_labeled_data()

model_plus_eval = Serial(
    my_fancy_deep_model(),  # Takes one arg (x) and has one output (y_hat)
    my_eval(),  # Takes two args (y_hat, y_target) and has one output (score)
)

eval_score = model_plus_eval((x, y_target))

Here is the corresponding progression of stack states:

  1. At start: –empty–
  2. After get_batch_of_labeled_data(): x, y_target
  3. After my_fancy_deep_model(): y_hat, y_target
  4. After my_eval(): score

Note in particular how the application of the model (between stack states 2 and 3) only uses and affects the top element on the stack: x --> y_hat. The rest of the data stack (y_target) comes into use only later, for the eval function.
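
To make the schematic concrete, here is a runnable variant in which tiny Fn layers stand in for the placeholder names above (FakeModel and FakeEval are hypothetical stand-ins, not part of Trax):

fake_model = tl.Fn('FakeModel', lambda x: 2.0 * x)  # Stands in for my_fancy_deep_model().
fake_eval = tl.Fn('FakeEval',                       # Stands in for my_eval().
                  lambda y_hat, y_target: jnp.mean((y_hat - y_target)**2))

model_plus_eval = tl.Serial(fake_model, fake_eval)

x = np.array([1., 2., 3.])
y_target = np.array([2., 4., 7.])
model_plus_eval.init(shapes.signature((x, y_target)))

# FakeModel consumes x and pushes y_hat; y_target stays on the stack until
# FakeEval consumes both.
eval_score = model_plus_eval((x, y_target))
print(f'eval_score: {eval_score}')  # Mean squared error of [2. 4. 6.] vs [2. 4. 7.]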

3. Defining New Layer Classes

If you need a layer type that is not easily defined as a combination of existing layer types, you can define your own layer classes in a couple of different ways.

With the Fn layer-creating function.

Many layer types needed in deep learning compute pure functions from inputs to outputs, using neither weights nor randomness. You can use Trax’s Fn function to define your own pure layer types:

def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """

Example 7. Use Fn to define a new layer type:

[ ]:
# Define new layer type.
def Gcd():
  """Returns a layer to compute the greatest common divisor, elementwise."""
  return tl.Fn('Gcd', lambda x0, x1: jnp.gcd(x0, x1))

# Use it.
gcd = Gcd()

x0 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x1 = np.array([11, 12, 13, 14, 15, 16, 17, 18, 19, 20])

y = gcd((x0, x1))

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'gcd((x0, x1)):\n{y}')
x0:
[ 1  2  3  4  5  6  7  8  9 10]

x1:
[11 12 13 14 15 16 17 18 19 20]

gcd((x0, x1)):
[ 1  2  1  2  5  2  1  2  1 10]

The Fn function infers n_in (number of inputs) as the length of f’s arg list. Fn does not infer n_out (number of outputs) though. If your f has more than one output, you need to give an explicit value using the n_out keyword arg.
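
As a quick check, the Gcd layer from Example 7 reports the inferred values:

print(f'gcd.n_in:  {gcd.n_in}')   # 2, inferred from the lambda's two positional args.
print(f'gcd.n_out: {gcd.n_out}')  # 1, the default.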

Example 8. Fn with multiple outputs:

[ ]:
# Define new layer type.
def SumAndMax():
  """Returns a layer to compute sums and maxima of two input tensors."""
  return tl.Fn('SumAndMax',
               lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)),
               n_out=2)

# Use it.
sum_and_max = SumAndMax()

x0 = np.array([1, 2, 3, 4, 5])
x1 = np.array([10, -20, 30, -40, 50])

y0, y1 = sum_and_max([x0, x1])

print(f'x0:\n{x0}\n\n'
      f'x1:\n{x1}\n\n'
      f'y0:\n{y0}\n\n'
      f'y1:\n{y1}')
x0:
[1 2 3 4 5]

x1:
[ 10 -20  30 -40  50]

y0:
[ 11 -18  33 -36  55]

y1:
[10  2 30  4 50]

Example 9. Use Fn to define a configurable layer:

[ ]:
# Function defined in trax/layers/core.py:
def Flatten(n_axes_to_keep=1):
  """Returns a layer that combines one or more trailing axes of a tensor.

  Flattening keeps all the values of the input tensor, but reshapes it by
  collapsing one or more trailing axes into a single axis. For example, a
  `Flatten(n_axes_to_keep=2)` layer would map a tensor with shape
  `(2, 3, 5, 7, 11)` to the same values with shape `(2, 3, 385)`.

  Args:
    n_axes_to_keep: Number of leading axes to leave unchanged when reshaping;
        collapse only the axes after these.
  """
  layer_name = f'Flatten_keep{n_axes_to_keep}'
  def f(x):
    in_rank = len(x.shape)
    if in_rank <= n_axes_to_keep:
      raise ValueError(f'Input rank ({in_rank}) must exceed the number of '
                       f'axes to keep ({n_axes_to_keep}) after flattening.')
    return jnp.reshape(x, (x.shape[:n_axes_to_keep] + (-1,)))
  return tl.Fn(layer_name, f)

flatten_keep_1_axis = Flatten(n_axes_to_keep=1)
flatten_keep_2_axes = Flatten(n_axes_to_keep=2)

x = np.array([[[1, 2, 3],
               [10, 20, 30],
               [100, 200, 300]],
              [[4, 5, 6],
               [40, 50, 60],
               [400, 500, 600]]])

y1 = flatten_keep_1_axis(x)
y2 = flatten_keep_2_axes(x)

print(f'x:\n{x}\n\n'
      f'flatten_keep_1_axis(x):\n{y1}\n\n'
      f'flatten_keep_2_axes(x):\n{y2}')


x:
[[[  1   2   3]
  [ 10  20  30]
  [100 200 300]]

 [[  4   5   6]
  [ 40  50  60]
  [400 500 600]]]

flatten_keep_1_axis(x):
[[  1   2   3  10  20  30 100 200 300]
 [  4   5   6  40  50  60 400 500 600]]

flatten_keep_2_axes(x):
[[[  1   2   3]
  [ 10  20  30]
  [100 200 300]]

 [[  4   5   6]
  [ 40  50  60]
  [400 500 600]]]

By defining a Layer subclass

If you need a layer type that uses trainable weights (or state), you can extend the base Layer class:

class Layer:
  """Base class for composable layers in a deep learning network.

  ...

  Authors of new layer subclasses typically override at most two methods of
  the base `Layer` class:

    `forward(inputs)`:
      Computes this layer's output as part of a forward pass through the model.

    `init_weights_and_state(self, input_signature)`:
      Initializes weights and state for inputs with the given signature.

  ...

The forward method uses weights stored in the layer object (self.weights) to compute outputs from inputs. For example, here is the definition of forward for Trax’s Dense layer:

def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input, except the final dimension
    is the layer's `n_units` value.
  """
  if self._use_bias:
    if not isinstance(self.weights, (tuple, list)):
      raise ValueError(f'Weights should be a (w, b) tuple or list; '
                       f'instead got: {self.weights}')
    w, b = self.weights
    return jnp.dot(x, w) + b  # Affine map.
  else:
    w = self.weights
    return jnp.dot(x, w)  # Linear map.

Layer weights must be initialized before the layer can be used; the init_weights_and_state method specifies how. Continuing the Dense example, here is the corresponding initialization code:

def init_weights_and_state(self, input_signature):
  """Randomly initializes this layer's weights.

  Weights are a `(w, b)` tuple for layers created with `use_bias=True` (the
  default case), or a `w` tensor for layers created with `use_bias=False`.

  Args:
    input_signature: `ShapeDtype` instance characterizing the input this layer
        should compute on.
  """
  shape_w = (input_signature.shape[-1], self._n_units)
  shape_b = (self._n_units,)
  rng_w, rng_b = fastmath.random.split(self.rng, 2)
  w = self._kernel_initializer(shape_w, rng_w)

  if self._use_bias:
    b = self._bias_initializer(shape_b, rng_b)
    self.weights = (w, b)
  else:
    self.weights = w
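
Putting the two methods together, here is a minimal sketch of a complete subclass: a hypothetical ScalarMultiply layer (not part of Trax) with a single trainable scalar weight:

class ScalarMultiply(tl.Layer):
  """Computes w * x, where w is a single trainable scalar weight."""

  def forward(self, x):
    w = self.weights
    return w * x

  def init_weights_and_state(self, input_signature):
    del input_signature  # The weight's shape doesn't depend on the input.
    self.weights = jnp.ones(())  # Start at 1.0, i.e., the identity map.

# Use it.
scalar_multiply = ScalarMultiply()

x = np.array([1., 2., 3.])
scalar_multiply.init(shapes.signature(x))
y = scalar_multiply(x)  # Same as x, since w starts at 1.0.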

By defining a Combinator subclass

TBD

4. Testing and Debugging Layer Classes

TBD