trax.fastmath

ops

Trax accelerated math operations for fast computing on GPUs and TPUs.

Import these operations directly from fastmath and import fastmath.numpy as np:

from trax import fastmath
from trax.fastmath import numpy as np

x = np.array([1.0, 2.0])  # Use like numpy.
y = np.exp(x)  # Common numpy ops are available and accelerated.
z = fastmath.logsumexp(y)  # Special operations available from fastmath.

Trax uses either TensorFlow 2 or JAX as the backend for accelerating operations. You can select which one to use (e.g., for debugging) with use_backend.

class trax.fastmath.ops.Backend

Bases: enum.Enum

An enumeration of the backends that fastmath can use.

JAX = 'jax'
TFNP = 'tensorflow-numpy'
NUMPY = 'numpy'

class trax.fastmath.ops.NumpyBackend

Bases: object

Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.

trax.fastmath.ops.numpy

Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.

class trax.fastmath.ops.RandomBackend

Bases: object

Backend providing random functions.

get_prng(seed)
split(prng, num=2)
fold_in(rng, data)
uniform(*args, **kwargs)
randint(*args, **kwargs)
normal(*args, **kwargs)
bernoulli(*args, **kwargs)

trax.fastmath.ops.logsumexp(*args, **kwargs)

Computes the log of the sum of exponentials of input elements.
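For intuition, here is a plain-Python reference for what logsumexp computes on a flat list of numbers. It is a sketch of the math, not the accelerated trax implementation; the helper name `logsumexp_ref` is hypothetical. It uses the usual max-shift trick, log Σ exp(xᵢ) = m + log Σ exp(xᵢ − m) with m = max(xᵢ), so the largest exponent is exp(0) and nothing overflows.

```python
import math
from typing import Sequence


def logsumexp_ref(xs: Sequence[float]) -> float:
    """Reference log(sum(exp(x))) over a flat sequence (hypothetical helper)."""
    if not xs:
        return float("-inf")  # log of an empty sum is log(0) = -inf
    m = max(xs)
    if math.isinf(m):
        # All entries are -inf, or some entry is +inf: the answer is m itself,
        # and shifting by an infinite m would produce nan.
        return m
    # Shift by the max so every exponent is <= 0 (no overflow).
    return m + math.log(sum(math.exp(x - m) for x in xs))


print(logsumexp_ref([1000.0, 1000.0]))  # a naive exp(1000.0) would overflow
```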

trax.fastmath.ops.expit(*args, **kwargs)

Computes the expit (sigmoid) function.

trax.fastmath.ops.sigmoid(*args, **kwargs)

Computes the sigmoid (expit) function.
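expit and sigmoid name the same function, σ(x) = 1 / (1 + exp(−x)). A minimal plain-Python reference (the name `sigmoid_ref` is hypothetical, not part of trax) shows the numerically stable form that branches on the sign of x so that `exp` never receives a large positive argument:

```python
import math


def sigmoid_ref(x: float) -> float:
    """Reference logistic sigmoid 1 / (1 + exp(-x)) (hypothetical helper)."""
    # For x >= 0, exp(-x) <= 1, so this branch cannot overflow.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For x < 0, rewrite as exp(x) / (1 + exp(x)); exp(x) < 1 here.
    z = math.exp(x)
    return z / (1.0 + z)


print(sigmoid_ref(0.0), sigmoid_ref(-1000.0), sigmoid_ref(1000.0))
```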

trax.fastmath.ops.erf(*args, **kwargs)

Computes the erf function.
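erf is the Gauss error function, erf(x) = (2/√π) ∫₀ˣ exp(−t²) dt. One common use is the standard normal CDF, Φ(x) = ½(1 + erf(x/√2)). The sketch below uses Python's `math.erf` as a stand-in for the accelerated op; `normal_cdf` is a hypothetical helper, not a trax function.

```python
import math


def normal_cdf(x: float) -> float:
    """Standard normal CDF expressed through erf (hypothetical helper)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


print(math.erf(0.0), normal_cdf(0.0), normal_cdf(1.96))
```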

trax.fastmath.ops.conv(*args, **kwargs)

Computes a generalized convolution.
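The general op is assumed here to follow the usual deep-learning convention (as in JAX's `jax.lax.conv_general_dilated`): it handles batch and feature axes, several spatial dimensions, strides, padding and dilation, and it computes a cross-correlation, meaning the kernel is not flipped. The sketch below strips this down to one dimension, one channel, and "VALID" padding; `conv1d_valid_ref` is a hypothetical helper for illustration only.

```python
from typing import List, Sequence


def conv1d_valid_ref(x: Sequence[float], w: Sequence[float], stride: int = 1) -> List[float]:
    """1-D 'VALID' convolution, deep-learning style (no kernel flip)."""
    k = len(w)
    out = []
    # Slide the kernel over every window that fits entirely inside x.
    for start in range(0, len(x) - k + 1, stride):
        out.append(sum(w[j] * x[start + j] for j in range(k)))
    return out


print(conv1d_valid_ref([1, 2, 3, 4, 5], [1, 0, -1]))
```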

trax.fastmath.ops.avg_pool(*args, **kwargs)

Average pooling.

trax.fastmath.ops.max_pool(*args, **kwargs)

Max pooling.

trax.fastmath.ops.sum_pool(*args, **kwargs)

Sum pooling.
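The three pooling ops differ only in how each window is reduced: max, mean, or sum. A one-dimensional sketch with "VALID" padding makes this concrete; `pool1d_ref` is a hypothetical helper, and the real ops work on multi-dimensional arrays with configurable window shapes, strides and padding.

```python
from typing import Callable, List, Sequence


def pool1d_ref(x: Sequence[float], window: int, stride: int,
               reduce_fn: Callable[[Sequence[float]], float]) -> List[float]:
    """1-D pooling with 'VALID' padding: reduce each window of the input."""
    return [reduce_fn(x[i:i + window]) for i in range(0, len(x) - window + 1, stride)]


x = [1.0, 3.0, 2.0, 6.0, 5.0, 4.0]
print(pool1d_ref(x, 2, 2, max))                        # max pooling
print(pool1d_ref(x, 2, 2, sum))                        # sum pooling
print(pool1d_ref(x, 2, 2, lambda w: sum(w) / len(w)))  # average pooling
```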

trax.fastmath.ops.top_k(*args, **kwargs)

Returns the k largest entries along the last axis, together with their indices.
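Assuming the JAX-style convention of returning `(values, indices)` with values in descending order, top-k on a 1-D input behaves like this plain-Python reference (`top_k_ref` is a hypothetical helper, not the trax function):

```python
import heapq
from typing import List, Sequence, Tuple


def top_k_ref(x: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest values of a 1-D sequence and their indices."""
    # Rank positions by the value stored there, largest first.
    idx = heapq.nlargest(k, range(len(x)), key=lambda i: x[i])
    return [x[i] for i in idx], idx


print(top_k_ref([0.1, 0.7, 0.3, 0.9], 2))
```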

trax.fastmath.ops.sort_key_val(*args, **kwargs)

Sorts keys along a dimension and applies the same permutation to values.
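On 1-D inputs the effect is: compute the ascending order of the keys once, then reorder both arrays with it. The sketch below is a plain-Python reference under that assumption; `sort_key_val_ref` is hypothetical.

```python
from typing import List, Sequence, Tuple, TypeVar

V = TypeVar("V")


def sort_key_val_ref(keys: Sequence[float], values: Sequence[V]) -> Tuple[List[float], List[V]]:
    """Sort keys ascending and permute values the same way (1-D sketch)."""
    order = sorted(range(len(keys)), key=lambda i: keys[i])  # the shared permutation
    return [keys[i] for i in order], [values[i] for i in order]


print(sort_key_val_ref([3, 1, 2], ["c", "a", "b"]))
```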

trax.fastmath.ops.scan(*args, **kwargs)

Scans a function over the leading axis of its inputs while carrying state; this makes recurrent functions run faster on accelerators.
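Assuming the JAX backend forwards to `jax.lax.scan(f, init, xs)`, where `f(carry, x)` returns `(new_carry, y)`, the semantics are roughly the loop below. The real op stacks the `y` values into an array and compiles the loop body once instead of unrolling it; `scan_ref` is a hypothetical plain-Python helper.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

C = TypeVar("C")  # carry type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def scan_ref(f: Callable[[C, X], Tuple[C, Y]], init: C, xs: Sequence[X]) -> Tuple[C, List[Y]]:
    """Plain-Python scan semantics: thread a carry through f over xs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


# Running sum: the carry is the total so far; each y is the prefix sum at that step.
total, prefix_sums = scan_ref(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
print(total, prefix_sums)
```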

trax.fastmath.ops.map(*args, **kwargs)

Map a function over leading array axes.
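Assuming this follows `jax.lax.map`, the function is applied to one slice of the leading axis at a time and the results are stacked. In JAX the slices are processed sequentially (it is built on scan), whereas vmap vectorizes them. A nested-list sketch, where `map_ref` is a hypothetical helper:

```python
from typing import Callable, List, Sequence, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")


def map_ref(f: Callable[[X], Y], xs: Sequence[X]) -> List[Y]:
    """Apply f to each slice along the leading axis (here: each row)."""
    return [f(x) for x in xs]


rows = [[1, 2], [3, 4], [5, 6]]
print(map_ref(sum, rows))  # f sees one row at a time
```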

trax.fastmath.ops.fori_loop(lower, upper, body_fn, init_val)

Loop from lower to upper running body_fn starting from init_val.

The semantics of fori_loop are as follows:

def fori_loop(lower, upper, body_fn, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fn(i, val)
  return val

Parameters:
  • lower – an integer representing the loop index lower bound (inclusive)
  • upper – an integer representing the loop index upper bound (exclusive)
  • body_fn – function of type (int, a) -> a.
  • init_val – initial loop carry value of type a.
Returns:

Loop value from the final iteration.
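The reference semantics above can be exercised directly in plain Python. Here the loop-carried value `acc` plays the role of the type `a`; `fori_loop_ref` is simply the definition above renamed and given type hints, not the accelerated op.

```python
from typing import Callable, TypeVar

A = TypeVar("A")  # type of the loop-carried value


def fori_loop_ref(lower: int, upper: int, body_fn: Callable[[int, A], A], init_val: A) -> A:
    """Pure-Python version of the fori_loop semantics."""
    val = init_val
    for i in range(lower, upper):
        val = body_fn(i, val)
    return val


# Sum of squares 0^2 + 1^2 + ... + 4^2; the carry is the running total.
print(fori_loop_ref(0, 5, lambda i, acc: acc + i * i, 0))
# Factorial 5! via a product carried across iterations 1..5.
print(fori_loop_ref(1, 6, lambda i, acc: acc * i, 1))
```

When `lower >= upper` the body never runs and `init_val` is returned unchanged.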

trax.fastmath.ops.remat(*args, **kwargs)

Recompute everything in the backward pass to save memory.

trax.fastmath.ops.cond(*args, **kwargs)

Conditional computation to run on accelerators.
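Assuming the `cond(pred, true_fun, false_fun, operand)` calling convention of `jax.lax.cond`, exactly one branch is applied to the operand. Inside a jitted function `pred` may be a traced value, which is why a Python `if` cannot be used there; both branches must return values of the same structure, shape and dtype. A plain-Python sketch, with `cond_ref` hypothetical:

```python
from typing import Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def cond_ref(pred: bool, true_fun: Callable[[T], R], false_fun: Callable[[T], R], operand: T) -> R:
    """Call exactly one branch on the operand, selected by pred."""
    return true_fun(operand) if pred else false_fun(operand)


x = -3.0
print(cond_ref(x > 0, lambda v: v, lambda v: -v, x))  # absolute value of x
```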

trax.fastmath.ops.lt(*args, **kwargs)

Less-than function for backends that do not override <.

trax.fastmath.ops.index_update(*args, **kwargs)
trax.fastmath.ops.index_add(*args, **kwargs)
trax.fastmath.ops.index_min(*args, **kwargs)
trax.fastmath.ops.index_max(*args, **kwargs)
trax.fastmath.ops.dynamic_slice(*args, **kwargs)
trax.fastmath.ops.dynamic_slice_in_dim(*args, **kwargs)
trax.fastmath.ops.dynamic_update_slice(*args, **kwargs)
trax.fastmath.ops.dynamic_update_slice_in_dim(*args, **kwargs)
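These entries carry no description. Assuming they follow the functional JAX conventions they are named after, the `index_*` ops return a new array instead of mutating their input (`index_update` sets `x[idx] = y`, `index_add` adds `y`, and `index_min`/`index_max` combine with min/max), and the `dynamic_*` ops take start indices that may be traced values and clamp them so the slice stays in bounds. The 1-D sketch below illustrates those assumed semantics with hypothetical helpers; it is not the trax API.

```python
from typing import List, Sequence


def index_update_ref(x: Sequence[float], idx: int, y: float) -> List[float]:
    """Functional x[idx] = y: returns a new list; x is left unchanged."""
    out = list(x)
    out[idx] = y
    return out


def index_add_ref(x: Sequence[float], idx: int, y: float) -> List[float]:
    """Functional x[idx] += y."""
    out = list(x)
    out[idx] += y
    return out


def dynamic_slice_in_dim_ref(x: Sequence[float], start: int, size: int) -> List[float]:
    """1-D dynamic slice; start is clamped so a slice of `size` fits in x."""
    start = max(0, min(start, len(x) - size))
    return list(x[start:start + size])


x = [0.0, 1.0, 2.0, 3.0]
print(index_update_ref(x, 1, 9.0), index_add_ref(x, 1, 9.0), x)
print(dynamic_slice_in_dim_ref(x, 3, 2))  # start is clamped from 3 to 2
```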
trax.fastmath.ops.stop_gradient(*args, **kwargs)

Identity on the forward pass but 0 (no gradient) on the backward pass.

trax.fastmath.ops.jit(*args, **kwargs)

Just-In-Time compiles the given function for use on accelerators.

trax.fastmath.ops.disable_jit()

Disables JIT-compilation; helpful for debugging.

trax.fastmath.ops.vmap(*args, **kwargs)

Vectorizes the specified function (returns a function).

trax.fastmath.ops.grad(*args, **kwargs)

Computes the gradient of the specified function (returns a function).

trax.fastmath.ops.value_and_grad(*args, **kwargs)

Computes the gradient of the specified function together with the value.

trax.fastmath.ops.vjp(*args, **kwargs)

Computes the vector-Jacobian product for the specified function.
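Assuming the JAX convention, in which vjp returns the primal output together with a function that pulls a cotangent back through the Jacobian, for f : ℝⁿ → ℝᵐ the result is:

```latex
\operatorname{vjp}(f, x) = \bigl(f(x),\; v \mapsto v^{\top} J_f(x)\bigr),
\qquad J_f(x) \in \mathbb{R}^{m \times n},\quad v \in \mathbb{R}^{m}.

% Worked example: f(x_1, x_2) = (x_1 x_2,\; x_1 + x_2) at x = (2, 3)
J_f(x) = \begin{pmatrix} x_2 & x_1 \\ 1 & 1 \end{pmatrix}
       = \begin{pmatrix} 3 & 2 \\ 1 & 1 \end{pmatrix},
\qquad
v = (1, 0) \;\Rightarrow\; v^{\top} J_f(x) = (3, 2).
```

With v = (1, 0) the product recovers the gradient of the first output, ∇(x₁x₂) = (x₂, x₁) = (3, 2). The gradient of a scalar function (m = 1) is the special case v = 1.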

trax.fastmath.ops.custom_grad(*args, **kwargs)

Set a custom gradient computation (override the default) for a function.

trax.fastmath.ops.custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=())

Set a custom vjp computation (override the default) for a function.

trax.fastmath.ops.pmap(*args, **kwargs)

Parallel-map to apply a function on multiple accelerators in parallel.

trax.fastmath.ops.psum(*args, **kwargs)

Parallel-sum to use within a pmap’d function for aggregation.

trax.fastmath.ops.abstract_eval(*args, **kwargs)

Evaluates a function only on the signatures (shapes and dtypes) of its parameters and returns the signatures of its outputs.

trax.fastmath.ops.dataset_as_numpy(*args, **kwargs)

Convert a tf.data.Dataset to a stream of numpy arrays.

trax.fastmath.ops.global_device_count(*args, **kwargs)

Return the number of accelerators (GPUs or TPUs) in all hosts.

trax.fastmath.ops.local_device_count(*args, **kwargs)

Return the number of accelerators (GPUs or TPUs) available on this host.

trax.fastmath.ops.set_backend(name)

Sets the default backend to use in Trax.

trax.fastmath.ops.backend(name='jax')

Returns the backend used to provide fastmath ops (‘tf’ or ‘jax’).

trax.fastmath.ops.use_backend(name)

Call fastmath functions with a specified backend.

trax.fastmath.ops.backend_name()

Returns the name of the backend currently in use (e.g., ‘jax’, ‘tensorflow-numpy’ or ‘numpy’).

trax.fastmath.ops.is_backend(backend_)