trax.fastmath
ops
Trax accelerated math operations for fast computing on GPUs and TPUs.

Import these operations directly from fastmath, and import fastmath.numpy as np:

from trax import fastmath
from trax.fastmath import numpy as np

x = np.array([1.0, 2.0])   # Use like numpy.
y = np.exp(x)              # Common numpy ops are available and accelerated.
z = fastmath.logsumexp(y)  # Special operations are available from fastmath.

Trax uses either TensorFlow 2 or JAX as the backend for accelerating operations. You can select which one to use (e.g., for debugging) with use_backend.
class trax.fastmath.ops.Backend
Bases: enum.Enum
An enumeration of the available backends.

JAX = 'jax'
TFNP = 'tensorflow-numpy'
NUMPY = 'numpy'
class trax.fastmath.ops.NumpyBackend
Bases: object
Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.

trax.fastmath.ops.numpy
Numpy functions accelerated to run on GPUs and TPUs. Use like numpy.
class trax.fastmath.ops.RandomBackend
Bases: object
Backend providing random functions.

get_prng(seed)
split(prng, num=2)
fold_in(rng, data)
uniform(*args, **kwargs)
randint(*args, **kwargs)
normal(*args, **kwargs)
bernoulli(*args, **kwargs)
trax.fastmath.ops.logsumexp(*args, **kwargs)
Computes the log of the sum of exponentials of the input elements.
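A plain-NumPy sketch of the numerically stable computation that logsumexp performs (not trax itself):

```python
import numpy as np

def logsumexp(x):
    # Subtract the max before exponentiating to avoid overflow,
    # then add it back after the log.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

x = np.array([1000.0, 1000.0])
print(logsumexp(x))  # ~1000.693; a naive log(sum(exp(x))) would overflow to inf
```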
trax.fastmath.ops.expit(*args, **kwargs)
Computes the expit (sigmoid) function.
trax.fastmath.ops.sigmoid(*args, **kwargs)
Computes the sigmoid (expit) function.
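expit and sigmoid are two names for the same function, 1 / (1 + exp(-x)); a plain-NumPy sketch:

```python
import numpy as np

def sigmoid(x):
    # Maps any real input into (0, 1); sigmoid(0) == 0.5.
    return 1.0 / (1.0 + np.exp(-x))

print(sigmoid(0.0))  # 0.5
```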
trax.fastmath.ops.erf(*args, **kwargs)
Computes the erf function.

trax.fastmath.ops.conv(*args, **kwargs)
Computes a generalized convolution.

trax.fastmath.ops.avg_pool(*args, **kwargs)
Average pooling.

trax.fastmath.ops.max_pool(*args, **kwargs)
Max pooling.

trax.fastmath.ops.sum_pool(*args, **kwargs)
Sum pooling.
trax.fastmath.ops.top_k(*args, **kwargs)
Top k.
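Assuming top_k follows the usual JAX convention of returning the k largest values together with their indices, a plain-NumPy sketch of the semantics:

```python
import numpy as np

def top_k(x, k):
    # Indices of the k largest entries, in descending order of value.
    idx = np.argsort(x)[::-1][:k]
    return x[idx], idx

values, indices = top_k(np.array([1.0, 5.0, 3.0, 4.0]), k=2)
print(values, indices)  # [5. 4.] [1 3]
```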
trax.fastmath.ops.sort_key_val(*args, **kwargs)
Sorts keys along a dimension and applies the same permutation to the values.
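A plain-NumPy sketch of the sort_key_val semantics for 1-D inputs: sort the keys, then permute the values identically:

```python
import numpy as np

def sort_key_val(keys, values):
    # One permutation, applied to both arrays, so key/value pairs stay aligned.
    perm = np.argsort(keys)
    return keys[perm], values[perm]

k, v = sort_key_val(np.array([3, 1, 2]), np.array([30, 10, 20]))
print(k, v)  # [1 2 3] [10 20 30]
```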
trax.fastmath.ops.scan(*args, **kwargs)
Scan to make recurrent functions run faster on accelerators.
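Assuming scan follows the jax.lax.scan convention — thread a carry through the leading axis of the inputs while collecting per-step outputs — a pure-Python reference sketch of its semantics:

```python
def scan_reference(f, init, xs):
    # f: (carry, x) -> (new_carry, y). Returns the final carry and all ys.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

# Running sum as a scan: the carry is the sum so far, y is the prefix sum.
total, prefix = scan_reference(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
print(total, prefix)  # 10 [1, 3, 6, 10]
```

The accelerated version compiles this loop rather than running it step by step in Python, which is what makes recurrent functions fast.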
trax.fastmath.ops.map(*args, **kwargs)
Map a function over leading array axes.
trax.fastmath.ops.fori_loop(lower, upper, body_fn, init_val)
Loop from lower to upper, running body_fn starting from init_val.

The semantics of fori_loop are as follows:

def fori_loop(lower, upper, body_fn, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fn(i, val)
    return val

Parameters:
- lower – an integer representing the loop index lower bound (inclusive)
- upper – an integer representing the loop index upper bound (exclusive)
- body_fn – function of type (int, a) -> a
- init_val – initial loop carry value of type a

Returns: Loop value from the final iteration.
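For instance, using the reference semantics from the docstring, 5! can be computed as a fori_loop:

```python
def fori_loop(lower, upper, body_fn, init_val):
    # Pure-Python reference semantics, as given in the docstring.
    val = init_val
    for i in range(lower, upper):
        val = body_fn(i, val)
    return val

# Multiply the carry by each index in [1, 6): 1*1*2*3*4*5 = 120.
print(fori_loop(1, 6, lambda i, val: val * i, 1))  # 120
```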
trax.fastmath.ops.remat(*args, **kwargs)
Recompute everything in the backward pass to save memory.

trax.fastmath.ops.cond(*args, **kwargs)
Conditional computation to run on accelerators.

trax.fastmath.ops.lt(*args, **kwargs)
Less-than function for backends that do not override <.
trax.fastmath.ops.index_update(*args, **kwargs)

trax.fastmath.ops.index_add(*args, **kwargs)

trax.fastmath.ops.index_min(*args, **kwargs)

trax.fastmath.ops.index_max(*args, **kwargs)

trax.fastmath.ops.dynamic_slice(*args, **kwargs)

trax.fastmath.ops.dynamic_slice_in_dim(*args, **kwargs)

trax.fastmath.ops.dynamic_update_slice(*args, **kwargs)

trax.fastmath.ops.dynamic_update_slice_in_dim(*args, **kwargs)
trax.fastmath.ops.stop_gradient(*args, **kwargs)
Identity on the forward pass but 0 (no gradient) on the backward pass.

trax.fastmath.ops.jit(*args, **kwargs)
Just-In-Time compiles the given function for use on accelerators.

trax.fastmath.ops.disable_jit()
Disables JIT-compilation; helpful for debugging.
trax.fastmath.ops.vmap(*args, **kwargs)
Vectorizes the specified function (returns a function).
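Semantically, vmap lifts a per-example function to a batched one, equivalent to applying it to each row and stacking the results. A plain-NumPy reference sketch (the real vmap fuses the loop into one vectorized call instead of iterating in Python):

```python
import numpy as np

def vmap_reference(f):
    # Returns a function that maps f over the leading axis and stacks results.
    def batched(xs):
        return np.stack([f(x) for x in xs])
    return batched

dot_self = lambda x: np.dot(x, x)       # per-example: squared L2 norm
batch = np.array([[1.0, 2.0], [3.0, 4.0]])
print(vmap_reference(dot_self)(batch))  # [ 5. 25.]
```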
trax.fastmath.ops.grad(*args, **kwargs)
Computes the gradient of the specified function (returns a function).
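As a sketch of what the returned function computes, here is a central-difference approximation in plain Python (not trax's autodiff, which is exact):

```python
def numerical_grad(f, eps=1e-6):
    # Central-difference approximation of the derivative of a scalar function.
    def df(x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df

f = lambda x: x ** 2
print(numerical_grad(f)(3.0))  # ~6.0, the value grad(f)(3.0) would return
```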
trax.fastmath.ops.value_and_grad(*args, **kwargs)
Computes the gradient of the specified function together with its value.

trax.fastmath.ops.vjp(*args, **kwargs)
Computes the vector-Jacobian product for the specified function.

trax.fastmath.ops.custom_grad(*args, **kwargs)
Sets a custom gradient computation (overriding the default) for a function.

trax.fastmath.ops.custom_vjp(f, f_fwd, f_bwd, nondiff_argnums=())
Sets a custom vjp computation (overriding the default) for a function.

trax.fastmath.ops.pmap(*args, **kwargs)
Parallel-map to apply a function on multiple accelerators in parallel.

trax.fastmath.ops.psum(*args, **kwargs)
Parallel-sum to use within a pmap'd function for aggregation.
trax.fastmath.ops.abstract_eval(*args, **kwargs)
Evaluates a function only on the signatures of its parameters; returns the output signatures.
trax.fastmath.ops.dataset_as_numpy(*args, **kwargs)
Convert a tf.data.Dataset to a stream of numpy arrays.
trax.fastmath.ops.global_device_count(*args, **kwargs)
Return the number of accelerators (GPUs or TPUs) across all hosts.

trax.fastmath.ops.local_device_count(*args, **kwargs)
Return the number of accelerators (GPUs or TPUs) available on this host.
trax.fastmath.ops.set_backend(name)
Sets the default backend to use in Trax.

trax.fastmath.ops.backend(name='jax')
Returns the backend used to provide fastmath ops ('tf' or 'jax').

trax.fastmath.ops.use_backend(name)
Call fastmath functions with a specified backend.

trax.fastmath.ops.backend_name()
Returns the name of the backend currently in use ('tf' or 'jax').

trax.fastmath.ops.is_backend(backend_)