trax.fastmath¶
ops¶
Trax accelerated math operations for fast computing on GPUs and TPUs.
Import these operations directly from fastmath and import fastmath.numpy as np:
from trax import fastmath
from trax.fastmath import numpy as np
x = np.array([1.0, 2.0]) # Use like numpy.
y = np.exp(x) # Common numpy ops are available and accelerated.
z = fastmath.logsumexp(y) # Special operations available from fastmath.
Trax uses either TensorFlow 2 or JAX as its backend for accelerating operations. You can select which one to use (e.g., for debugging) with use_backend.
-
class trax.fastmath.ops.Backend¶
Bases: enum.Enum
An enumeration.
-
JAX= 'jax'¶
-
TFNP= 'tensorflow-numpy'¶
-
NUMPY= 'numpy'¶
-
-
class trax.fastmath.ops.NumpyBackend¶
Bases: object
Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.
-
trax.fastmath.ops.numpy¶ Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.
-
class trax.fastmath.ops.RandomBackend¶
Bases: object
Backend providing random functions.
-
get_prng(seed)¶
-
split(prng, num=2)¶
-
fold_in(rng, data)¶
-
uniform(*args, **kwargs)¶
-
randint(*args, **kwargs)¶
-
normal(*args, **kwargs)¶
-
bernoulli(*args, **kwargs)¶
-
-
trax.fastmath.ops.logsumexp(*args, **kwargs)¶ Computes the log of the sum of exponentials of input elements.
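To make "log of the sum of exponentials" concrete, here is a minimal pure-Python sketch of the computation. It is not the Trax implementation; the helper name logsumexp_reference is hypothetical, and the max-shift trick shown is the usual way to avoid overflow, assumed here rather than confirmed by the source.

```python
import math

def logsumexp_reference(values):
    """Reference log(sum(exp(v) for v in values)) for a flat list of floats."""
    shift = max(values)
    if math.isinf(shift):
        # All inputs are -inf, or a +inf is present; the result is the max.
        return shift
    # Subtracting the max keeps every exponent <= 0, so exp() cannot overflow.
    return shift + math.log(sum(math.exp(v - shift) for v in values))

# math.exp(1000.0) would overflow, but the shifted form stays finite.
result = logsumexp_reference([1000.0, 1000.0])
```

The result equals 1000 + log(2), because both terms are identical after the shift.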
-
trax.fastmath.ops.expit(*args, **kwargs)¶ Computes the expit (sigmoid) function.
-
trax.fastmath.ops.sigmoid(*args, **kwargs)¶ Computes the sigmoid (expit) function.
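The expit and sigmoid entries describe the same function, 1 / (1 + exp(-x)). The sketch below is a scalar pure-Python reference, not the accelerated Trax op; sigmoid_reference is a hypothetical name, and the two-branch form is one common way to keep the computation stable for large negative inputs.

```python
import math

def sigmoid_reference(x):
    """Scalar sigmoid (expit): 1 / (1 + exp(-x)), written to avoid overflow."""
    if x >= 0:
        # exp(-x) is at most 1 here, so it cannot overflow.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, use the equivalent form exp(x) / (1 + exp(x)).
    e = math.exp(x)
    return e / (1.0 + e)

midpoint = sigmoid_reference(0.0)
far_left = sigmoid_reference(-1000.0)
far_right = sigmoid_reference(1000.0)
```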
-
trax.fastmath.ops.erf(*args, **kwargs)¶ Computes the erf function.
-
trax.fastmath.ops.conv(*args, **kwargs)¶ Computes a generalized convolution.
-
trax.fastmath.ops.avg_pool(*args, **kwargs)¶ Average pooling.
-
trax.fastmath.ops.max_pool(*args, **kwargs)¶ Max pooling.
-
trax.fastmath.ops.sum_pool(*args, **kwargs)¶ Sum pooling.
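The three pooling ops differ only in how each window is reduced. The following one-dimensional sketch illustrates that idea in plain Python. The source lists only (*args, **kwargs) for these ops, so the helper pool_1d and its window, stride, and reduce_fn parameters are hypothetical and do not reflect the real signatures.

```python
def pool_1d(values, window, stride, reduce_fn):
    """Applies reduce_fn to each full window of `values` (no padding)."""
    return [
        reduce_fn(values[start:start + window])
        for start in range(0, len(values) - window + 1, stride)
    ]

signal = [1.0, 3.0, 2.0, 6.0]
max_pooled = pool_1d(signal, window=2, stride=2, reduce_fn=max)
sum_pooled = pool_1d(signal, window=2, stride=2, reduce_fn=sum)
avg_pooled = pool_1d(signal, window=2, stride=2,
                     reduce_fn=lambda w: sum(w) / len(w))
```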
-
trax.fastmath.ops.top_k(*args, **kwargs)¶ Top k.
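As a rough illustration of a top-k selection, the sketch below picks the k largest entries of a flat list. The source gives no signature or return type for top_k, so the helper name top_k_reference and the choice to return both values and indices are assumptions made for this example.

```python
def top_k_reference(values, k):
    """Returns the k largest values (descending) and their original indices."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    return [values[i] for i in order], order

top_values, top_indices = top_k_reference([1, 5, 3, 4], k=2)
```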
-
trax.fastmath.ops.sort_key_val(*args, **kwargs)¶ Sorts keys along dimension and applies same permutation to values.
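A small pure-Python sketch may help show what "applies same permutation to values" means for the one-dimensional case. The helper sort_key_val_reference is hypothetical and ignores the dimension argument mentioned in the description.

```python
def sort_key_val_reference(keys, values):
    """Sorts `keys` ascending and reorders `values` with the same permutation."""
    order = sorted(range(len(keys)), key=lambda i: keys[i])
    return [keys[i] for i in order], [values[i] for i in order]

sorted_keys, sorted_values = sort_key_val_reference([3, 1, 2], ["c", "a", "b"])
```

Each value stays paired with its original key after sorting.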
-
trax.fastmath.ops.scan(*args, **kwargs)¶ Scan to make recurrent functions run faster on accelerators.
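A scan threads a carry value through a sequence while collecting one output per step. The sketch below shows that pattern as an ordinary Python loop. Because the source lists only (*args, **kwargs), the argument order and the (output, carry) convention of the step function are assumptions, and scan_reference is a hypothetical name.

```python
def scan_reference(step_fn, xs, init_carry):
    """Runs step_fn over xs, threading a carry and collecting per-step outputs."""
    carry = init_carry
    ys = []
    for x in xs:
        # Assumed convention: step_fn(x, carry) -> (output, new_carry).
        y, carry = step_fn(x, carry)
        ys.append(y)
    return ys, carry

def running_sum_step(x, carry):
    total = carry + x
    return total, total

outputs, final_carry = scan_reference(running_sum_step, [1, 2, 3, 4], 0)
```

On an accelerator backend, expressing the recurrence as a single scan lets the whole loop be compiled at once instead of being unrolled step by step.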
-
trax.fastmath.ops.map(*args, **kwargs)¶ Map a function over leading array axes.
-
trax.fastmath.ops.fori_loop(lower, upper, body_fn, init_val)¶ Loop from lower to upper running body_fn starting from init_val.
The semantics of fori_loop are as follows:
def fori_loop(lower, upper, body_fn, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fn(i, val)
  return val
Parameters:
- lower – an integer representing the loop index lower bound (inclusive)
- upper – an integer representing the loop index upper bound (exclusive)
- body_fn – function of type (int, a) -> a.
- init_val – initial loop carry value of type a.
Returns: Loop value from the final iteration.
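A short usage sketch, using the reference semantics given above rather than the accelerated op, shows how the loop carry works. The accumulator lambdas are illustrative only.

```python
def fori_loop(lower, upper, body_fn, init_val):
    """Pure-Python reference semantics, as stated in the documentation."""
    val = init_val
    for i in range(lower, upper):
        val = body_fn(i, val)
    return val

# Sum of squares of 0..4: the carry accumulates i * i at each step.
sum_of_squares = fori_loop(0, 5, lambda i, acc: acc + i * i, 0)

# With lower == upper the body never runs, so init_val comes back unchanged.
unchanged = fori_loop(3, 3, lambda i, acc: acc + 1, 42)
```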
-
trax.fastmath.ops.remat(*args, **kwargs)¶ Recompute everything in the backward pass to save memory.
-
trax.fastmath.ops.cond(*args, **kwargs)¶ Conditional computation to run on accelerators.
-
trax.fastmath.ops.lt(*args, **kwargs)¶ Less-than function for backends that do not override <.
-
trax.fastmath.ops.index_update(*args, **kwargs)¶
-
trax.fastmath.ops.index_add(*args, **kwargs)¶
-
trax.fastmath.ops.index_min(*args, **kwargs)¶
-
trax.fastmath.ops.index_max(*args, **kwargs)¶
-
trax.fastmath.ops.dynamic_slice(*args, **kwargs)¶
-
trax.fastmath.ops.dynamic_slice_in_dim(*args, **kwargs)¶
-
trax.fastmath.ops.dynamic_update_slice(*args, **kwargs)¶
-
trax.fastmath.ops.dynamic_update_slice_in_dim(*args, **kwargs)¶
-
trax.fastmath.ops.stop_gradient(*args, **kwargs)¶ Identity on the forward pass but 0 (no gradient) on the backward pass.
-
trax.fastmath.ops.jit(*args, **kwargs)¶ Just-In-Time compiles the given function for use on accelerators.
-
trax.fastmath.ops.disable_jit()¶ Disables JIT-compilation; helpful for debugging.
-
trax.fastmath.ops.vmap(*args, **kwargs)¶ Vectorizes the specified function (returns a function).
-
trax.fastmath.ops.grad(*args, **kwargs)¶ Computes the gradient of the specified function (returns a function).
-
trax.fastmath.ops.value_and_grad(*args, **kwargs)¶ Computes the gradient of the specified function together with the value.
-
trax.fastmath.ops.vjp(*args, **kwargs)¶ Computes the vector-Jacobian product for the specified function.
-
trax.fastmath.ops.custom_grad(*args, **kwargs)¶ Set a custom gradient computation (override the default) for a function.
-
trax.fastmath.ops.custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=())¶ Set a custom vjp computation (override the default) for a function.
-
trax.fastmath.ops.pmap(*args, **kwargs)¶ Parallel-map to apply a function on multiple accelerators in parallel.
-
trax.fastmath.ops.psum(*args, **kwargs)¶ Parallel-sum to use within a pmap’d function for aggregation.
-
trax.fastmath.ops.abstract_eval(*args, **kwargs)¶ Evaluates a function on just the signatures of its parameters and returns the resulting signatures.
-
trax.fastmath.ops.dataset_as_numpy(*args, **kwargs)¶ Convert a tf.data.Dataset to a stream of numpy arrays.
-
trax.fastmath.ops.global_device_count(*args, **kwargs)¶ Return the number of accelerators (GPUs or TPUs) in all hosts.
-
trax.fastmath.ops.local_device_count(*args, **kwargs)¶ Return the number of accelerators (GPUs or TPUs) available on this host.
-
trax.fastmath.ops.set_backend(name)¶ Sets the default backend to use in Trax.
-
trax.fastmath.ops.backend(name='jax')¶ Returns the backend used to provide fastmath ops (‘tf’ or ‘jax’).
-
trax.fastmath.ops.use_backend(name)¶ Call fastmath functions with a specified backend.
-
trax.fastmath.ops.backend_name()¶ Returns the name of the backend currently in use (‘tf’ or ‘jax’).
-
trax.fastmath.ops.is_backend(backend_)¶