Trax accelerated math operations for fast computing on GPUs and TPUs.

Import these operations directly from fastmath and import fastmath.numpy as np:

from trax import fastmath
from trax.fastmath import numpy as np

x = np.array([1.0, 2.0])  # Use like numpy.
y = np.exp(x)  # Common numpy ops are available and accelerated.
z = fastmath.logsumexp(y)  # Special operations available from fastmath.

Trax uses either TensorFlow 2 or JAX as its backend for accelerating operations. You can select which one to use (e.g., for debugging) with use_backend.
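Backend switching follows a context-manager pattern: the backend is restored when the block exits. The sketch below is a minimal standalone illustration of that pattern, assuming hypothetical module-level names; it is not Trax's actual implementation.

```python
from contextlib import contextmanager

# Illustrative globals; Trax keeps equivalent state internally.
_default_backend = "jax"
_current_backend = _default_backend

def backend_name():
    # Report which backend fastmath ops would dispatch to.
    return _current_backend

@contextmanager
def use_backend(name):
    # Temporarily switch the active backend, restoring it on exit
    # even if the body raises.
    global _current_backend
    previous, _current_backend = _current_backend, name
    try:
        yield
    finally:
        _current_backend = previous
```

Inside `with use_backend("numpy"):` all fastmath calls dispatch to the plain NumPy backend; afterwards the previous backend is active again.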

class trax.fastmath.ops.Backend

Bases: enum.Enum

Enumeration of the computation backends Trax can use.

JAX = 'jax'
TFNP = 'tensorflow-numpy'
NUMPY = 'numpy'
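The enum maps each backend to its string name, which is handy when turning a config string into a backend choice. A standalone equivalent, reproducing the values listed above:

```python
import enum

# Same members and string values as trax.fastmath.ops.Backend.
class Backend(enum.Enum):
    JAX = "jax"
    TFNP = "tensorflow-numpy"
    NUMPY = "numpy"

# Look up a member from its config string, e.g. Backend("jax") is Backend.JAX.
```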

class trax.fastmath.ops.NumpyBackend

Bases: object

Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.


class trax.fastmath.ops.RandomBackend

Bases: object

Backend providing random functions.

split(prng, num=2)
fold_in(rng, data)
uniform(*args, **kwargs)
randint(*args, **kwargs)
normal(*args, **kwargs)
bernoulli(*args, **kwargs)
trax.fastmath.ops.logsumexp(*args, **kwargs)

Computes the log of the sum of exponentials of input elements.
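The point of a dedicated logsumexp is numerical stability: computing log(sum(exp(x))) naively overflows for large inputs. A scalar pure-Python reference of the standard max-shift trick (the fastmath version is accelerated and array-valued; this sketch only illustrates the numerics):

```python
import math

def logsumexp(xs):
    # Subtract the maximum before exponentiating so exp() cannot
    # overflow; add it back outside the log.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))
```

Naively, exp(1000.0) overflows to inf; the shifted form returns 1000 + log(2) for logsumexp([1000.0, 1000.0]) without trouble.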

trax.fastmath.ops.expit(*args, **kwargs)

Computes the expit (sigmoid) function.

trax.fastmath.ops.sigmoid(*args, **kwargs)

Computes the sigmoid (expit) function.
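expit and sigmoid are the same function, 1 / (1 + exp(-x)). A scalar reference in the two-branch form commonly used so that exp() is only ever called on non-positive arguments and cannot overflow (illustrative only; the fastmath versions are accelerated and array-valued):

```python
import math

def sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, rewrite as exp(x) / (1 + exp(x)).
    z = math.exp(x)
    return z / (1.0 + z)
```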

trax.fastmath.ops.erf(*args, **kwargs)

Computes the erf function.

trax.fastmath.ops.conv(*args, **kwargs)

Computes a generalized convolution.

trax.fastmath.ops.avg_pool(*args, **kwargs)

Average pooling.

trax.fastmath.ops.max_pool(*args, **kwargs)

Max pooling.

trax.fastmath.ops.sum_pool(*args, **kwargs)

Sum pooling.

trax.fastmath.ops.top_k(*args, **kwargs)

Returns the k largest values and their indices (top-k).
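A one-dimensional pure-Python sketch of top-k semantics, assuming the common convention (as in JAX's lax.top_k) of returning values and indices, largest first:

```python
def top_k(xs, k):
    # Indices of the k largest entries, in descending value order.
    order = sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]
    return [xs[i] for i in order], order
```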

trax.fastmath.ops.sort_key_val(*args, **kwargs)

Sorts keys along dimension and applies same permutation to values.
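A one-dimensional pure-Python sketch of sort_key_val's contract: sort the keys, and carry the values along under the same permutation.

```python
def sort_key_val(keys, values):
    # Permutation that sorts the keys ascending.
    perm = sorted(range(len(keys)), key=lambda i: keys[i])
    return [keys[i] for i in perm], [values[i] for i in perm]
```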

trax.fastmath.ops.scan(*args, **kwargs)

Scan to make recurrent functions run faster on accelerators.

trax.fastmath.ops.map(*args, **kwargs)

Map a function over leading array axes.

trax.fastmath.ops.fori_loop(lower, upper, body_fn, init_val)

Loop from lower to upper running body_fn starting from init_val.

The semantics of fori_loop is as follows:

def fori_loop(lower, upper, body_fn, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fn(i, val)
  return val

Parameters:
  • lower – an integer representing the loop index lower bound (inclusive)
  • upper – an integer representing the loop index upper bound (exclusive)
  • body_fn – function of type (int, a) -> a.
  • init_val – initial loop carry value of type a.

Returns:
  Loop value from the final iteration.
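For example, running the reference semantics above to accumulate the loop index into the carry:

```python
def fori_loop(lower, upper, body_fn, init_val):
    # Reference semantics, as documented above.
    val = init_val
    for i in range(lower, upper):
        val = body_fn(i, val)
    return val

total = fori_loop(0, 5, lambda i, val: val + i, 0)  # sums 0+1+2+3+4
```

An empty range (lower == upper) returns init_val unchanged.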

trax.fastmath.ops.remat(*args, **kwargs)

Recompute everything in the backward pass to save memory.

trax.fastmath.ops.cond(*args, **kwargs)

Conditional computation to run on accelerators.

trax.fastmath.ops.lt(*args, **kwargs)

Less-than function for backends that do not override <.

trax.fastmath.ops.index_update(*args, **kwargs)
trax.fastmath.ops.index_add(*args, **kwargs)
trax.fastmath.ops.index_min(*args, **kwargs)
trax.fastmath.ops.index_max(*args, **kwargs)
trax.fastmath.ops.dynamic_slice(*args, **kwargs)
trax.fastmath.ops.dynamic_slice_in_dim(*args, **kwargs)
trax.fastmath.ops.dynamic_update_slice(*args, **kwargs)
trax.fastmath.ops.dynamic_update_slice_in_dim(*args, **kwargs)
trax.fastmath.ops.stop_gradient(*args, **kwargs)

Identity on the forward pass but 0 (no gradient) on the backward pass.

trax.fastmath.ops.jit(*args, **kwargs)

Just-In-Time compiles the given function for use on accelerators.


trax.fastmath.ops.disable_jit(*args, **kwargs)

Disables JIT-compilation; helpful for debugging.

trax.fastmath.ops.vmap(*args, **kwargs)

Vectorizes the specified function (returns a function).

trax.fastmath.ops.grad(*args, **kwargs)

Computes the gradient of the specified function (returns a function).
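grad takes a function and returns a new function that computes its derivative. fastmath.grad does this by automatic differentiation on the backend; the central-difference stand-in below is only a numerical illustration of the same contract, with hypothetical names.

```python
def numerical_grad(f, eps=1e-6):
    # Returns a function approximating df/dx by central differences.
    # NOT how fastmath.grad works; it illustrates the f -> df contract.
    def df(x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df

df = numerical_grad(lambda x: x * x)  # d/dx x^2 = 2x
```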

trax.fastmath.ops.value_and_grad(*args, **kwargs)

Computes the gradient of the specified function together with the value.

trax.fastmath.ops.vjp(*args, **kwargs)

Computes the vector-Jacobian product for the specified function.

trax.fastmath.ops.custom_grad(*args, **kwargs)

Set a custom gradient computation (override the default) for a function.

trax.fastmath.ops.custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=())

Set a custom vjp computation (override the default) for a function.

trax.fastmath.ops.pmap(*args, **kwargs)

Parallel-map to apply a function on multiple accelerators in parallel.

trax.fastmath.ops.psum(*args, **kwargs)

Parallel-sum to use within a pmap’d function for aggregation.

trax.fastmath.ops.abstract_eval(*args, **kwargs)

Evaluates the function on parameter signatures only, returning the output signatures.

trax.fastmath.ops.dataset_as_numpy(*args, **kwargs)

Convert a tf.data.Dataset to a stream of numpy arrays.

trax.fastmath.ops.global_device_count(*args, **kwargs)

Return the number of accelerators (GPUs or TPUs) in all hosts.

trax.fastmath.ops.local_device_count(*args, **kwargs)

Return the number of accelerators (GPUs or TPUs) available on this host.


trax.fastmath.ops.set_backend(name)

Sets the default backend to use in Trax.

trax.fastmath.ops.backend(name='jax')

Returns the backend used to provide fastmath ops (‘tf’ or ‘jax’).

trax.fastmath.ops.use_backend(name)

Call fastmath functions with a specified backend.

trax.fastmath.ops.backend_name()

Returns the name of the backend currently in use (‘tf’ or ‘jax’).