Trax
stable
Introductory Notebooks
Trax Quick Intro
Trax Layers Intro
Using Trax with TensorFlow NumPy and Keras
Packages/modules
trax.*
Trax
Docs
»
Index
Edit on GitHub
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
J
|
K
|
L
|
M
|
N
|
O
|
P
|
Q
|
R
|
S
|
T
|
U
|
V
|
W
_
__call__() (trax.layers.base.Layer method)
__init__() (trax.data.inputs.Inputs method)
(trax.layers.acceleration.Accelerate method)
(trax.layers.attention.DotProductAttention method)
(trax.layers.attention.DotProductCausalAttention method)
(trax.layers.attention.PositionalEncoding method)
(trax.layers.attention.PureAttention method)
(trax.layers.base.Layer method)
(trax.layers.base.LayerError method)
(trax.layers.base.PureLayer method)
(trax.layers.combinators.BatchLeadingAxes method)
(trax.layers.combinators.Cache method)
(trax.layers.combinators.Concatenate method)
(trax.layers.combinators.Cond method)
(trax.layers.combinators.Parallel method)
(trax.layers.combinators.Scan method)
(trax.layers.combinators.Serial method)
(trax.layers.combinators.Split method)
(trax.layers.convolution.CausalConv method)
(trax.layers.convolution.Conv method)
(trax.layers.core.Dense method)
(trax.layers.core.Dropout method)
(trax.layers.core.Embedding method)
(trax.layers.core.LocallyConnected1d method)
(trax.layers.core.RandomUniform method)
(trax.layers.core.SummaryImage method)
(trax.layers.core.SummaryScalar method)
(trax.layers.core.Weights method)
(trax.layers.normalization.BatchNorm method)
(trax.layers.normalization.FilterResponseNorm method)
(trax.layers.normalization.LayerNorm method)
(trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.EncDecAttention method)
(trax.layers.research.efficient_attention.LSHFF method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.MixedLSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttentionWrapper method)
(trax.layers.research.efficient_attention.SelfAttention method)
(trax.layers.research.position_encodings.AxialPositionalEncoding method)
(trax.layers.research.position_encodings.FixedBasePositionalEncoding method)
(trax.layers.research.position_encodings.InfinitePositionalEncoding method)
(trax.layers.research.position_encodings.SinCosPositionalEncoding method)
(trax.layers.research.position_encodings.TimeBinPositionalEncoding method)
(trax.layers.reversible.ReversibleConcatenatePair method)
(trax.layers.reversible.ReversibleHalfResidual method)
(trax.layers.reversible.ReversiblePrintShape method)
(trax.layers.reversible.ReversibleReshape method)
(trax.layers.reversible.ReversibleSelect method)
(trax.layers.reversible.ReversibleSerial method)
(trax.layers.rnn.GRUCell method)
(trax.layers.rnn.LSTMCell method)
(trax.models.research.bert.PretrainedBERT method)
(trax.optimizers.adafactor.Adafactor method)
(trax.optimizers.adam.Adam method)
(trax.optimizers.base.Optimizer method)
(trax.optimizers.momentum.Momentum method)
(trax.optimizers.rms_prop.RMSProp method)
(trax.optimizers.sm3.SM3 method)
(trax.rl.actor_critic.A2C method)
(trax.rl.actor_critic.AWR method)
(trax.rl.actor_critic.ActorCriticAgent method)
(trax.rl.actor_critic.AdvantageBasedActorCriticAgent method)
(trax.rl.actor_critic.LoopAWR method)
(trax.rl.actor_critic.LoopActorCriticAgent method)
(trax.rl.actor_critic.PPO method)
(trax.rl.actor_critic.SamplingAWR method)
(trax.rl.actor_critic_joint.A2CJoint method)
(trax.rl.actor_critic_joint.AWRJoint method)
(trax.rl.actor_critic_joint.ActorCriticJointAgent method)
(trax.rl.actor_critic_joint.PPOJoint method)
(trax.rl.serialization_utils.SerializedModel method)
(trax.rl.space_serializer.DiscreteSpaceSerializer method)
(trax.rl.space_serializer.MultiDiscreteSpaceSerializer method)
(trax.rl.space_serializer.SpaceSerializer method)
(trax.rl.task.Trajectory method)
(trax.rl.training.Agent method)
(trax.rl.training.DQN method)
(trax.rl.training.ExpertIteration method)
(trax.rl.training.LoopPolicyAgent method)
(trax.rl.training.PolicyAgent method)
(trax.rl.training.PolicyGradient method)
(trax.rl.training.ValueAgent method)
(trax.shapes.ShapeDtype method)
(trax.supervised.training.Loop method)
(trax.trax2keras.AsKeras method)
A
A2C (class in trax.rl.actor_critic)
a2c_objective (trax.rl.actor_critic_joint.A2CJoint attribute)
a2c_objective_mean (trax.rl.actor_critic_joint.A2CJoint attribute)
A2CJoint (class in trax.rl.actor_critic_joint)
A2CObjective() (in module trax.rl.rl_layers)
abstract_eval() (in module trax.fastmath.ops)
Accelerate (class in trax.layers.acceleration)
Accuracy() (in module trax.layers.metrics)
action (trax.rl.task.TimeStepBatch attribute)
action_serializer (trax.rl.serialization_utils.SerializedModel attribute)
ActorCriticAgent (class in trax.rl.actor_critic)
ActorCriticJointAgent (class in trax.rl.actor_critic_joint)
Adafactor (class in trax.optimizers.adafactor)
Adam (class in trax.optimizers.adam)
Add() (in module trax.layers.combinators)
add_eos_to_output_features() (in module trax.data.tf_inputs)
add_loss_weights() (in module trax.data.inputs)
AddBias (class in trax.models.research.bert)
addition_input_stream() (in module trax.data.inputs)
addition_inputs() (in module trax.data.inputs)
AddLossWeights() (in module trax.data.inputs)
advantage_mean (trax.rl.actor_critic.AdvantageBasedActorCriticAgent attribute)
(trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
advantage_norm (trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
advantage_std (trax.rl.actor_critic.AdvantageBasedActorCriticAgent attribute)
AdvantageBasedActorCriticAgent (class in trax.rl.actor_critic)
Agent (class in trax.rl.training)
analyze_action_space() (in module trax.rl.serialization_utils)
AppendValue() (in module trax.data.inputs)
apply_broadcasted_dropout() (in module trax.layers.research.efficient_attention)
approximate_kl_divergence (trax.rl.actor_critic_joint.A2CJoint attribute)
(trax.rl.actor_critic_joint.PPOJoint attribute)
ApproximateKLDivergence() (in module trax.rl.rl_layers)
ArgMax() (in module trax.layers.core)
as_tuple() (trax.shapes.ShapeDtype method)
AsKeras (class in trax.trax2keras)
assert_same_shape() (in module trax.shapes)
assert_shape_equals() (in module trax.shapes)
AtariCnn() (in module trax.models.atari_cnn)
AtariCnnBody() (in module trax.models.atari_cnn)
AtariConvInit() (in module trax.layers.initializers)
attend() (in module trax.layers.research.efficient_attention)
Attention() (in module trax.layers.attention)
AttentionQKV() (in module trax.layers.attention)
autoregressive_sample() (in module trax.supervised.decoding)
autoregressive_sample_stream() (in module trax.supervised.decoding)
avg_pool() (in module trax.fastmath.ops)
avg_returns (trax.rl.training.Agent attribute)
AvgPool() (in module trax.layers.pooling)
AWR (class in trax.rl.actor_critic)
awr_metrics() (in module trax.rl.actor_critic)
awr_weight_stat() (in module trax.rl.actor_critic)
awr_weights() (in module trax.rl.actor_critic)
AWRJoint (class in trax.rl.actor_critic_joint)
AWRLoss() (in module trax.rl.actor_critic)
AxialPositionalEncoding (class in trax.layers.research.position_encodings)
B
Backend (class in trax.fastmath.ops)
backend() (in module trax.fastmath.ops)
backend_name() (in module trax.fastmath.ops)
backward() (trax.layers.base.Layer method)
(trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.SelfAttention method)
(trax.layers.reversible.ReversibleLayer method)
bair_robot_pushing_hparams() (in module trax.data.tf_inputs)
bair_robot_pushing_preprocess() (in module trax.data.tf_inputs)
Batch() (in module trax.data.inputs)
batch() (in module trax.data.inputs)
batch_fn() (in module trax.data.inputs)
batcher() (in module trax.data.inputs)
batches_stream() (trax.rl.actor_critic_joint.A2CJoint method)
(trax.rl.actor_critic_joint.AWRJoint method)
(trax.rl.actor_critic_joint.ActorCriticJointAgent method)
(trax.rl.actor_critic_joint.PPOJoint method)
BatchLeadingAxes (class in trax.layers.combinators)
BatchNorm (class in trax.layers.normalization)
beam_search() (in module trax.supervised.decoding)
bernoulli() (trax.fastmath.ops.RandomBackend method)
BERT() (in module trax.models.research.bert)
BERTClassifierHead() (in module trax.models.research.bert)
BertDoubleSentenceInputs() (in module trax.data.tf_inputs)
BertGlueEvalStream() (in module trax.data.tf_inputs)
BertGlueTrainStream() (in module trax.data.tf_inputs)
BERTMLMHead() (in module trax.models.research.bert)
BertNextSentencePredictionInputs() (in module trax.data.tf_inputs)
BERTPretrainingHead() (in module trax.models.research.bert)
BERTPretrainingLoss() (in module trax.models.research.bert)
BERTRegressionHead() (in module trax.models.research.bert)
BertSingleSentenceInputs() (in module trax.data.tf_inputs)
BinaryCrossEntropy() (in module trax.layers.metrics)
BinaryCrossEntropyLoss() (in module trax.layers.metrics)
BinaryCrossEntropySum() (in module trax.layers.metrics)
Branch() (in module trax.layers.combinators)
bucket_by_length() (in module trax.data.inputs)
BucketByLength() (in module trax.data.inputs)
build() (trax.trax2keras.AsKeras method)
C
c4_bare_preprocess_fn() (in module trax.data.tf_inputs)
c4_preprocess() (in module trax.data.tf_inputs)
Cache (class in trax.layers.combinators)
calculate_returns() (trax.rl.task.Trajectory method)
call() (trax.trax2keras.AsKeras method)
CastTo() (in module trax.data.inputs)
CategoryAccuracy() (in module trax.layers.metrics)
CategoryCrossEntropy() (in module trax.layers.metrics)
CausalAttention() (in module trax.layers.attention)
CausalConv (class in trax.layers.convolution)
Chunk() (in module trax.layers.combinators)
cifar10_augmentation_flatten_preprocess() (in module trax.data.tf_inputs)
cifar10_augmentation_preprocess() (in module trax.data.tf_inputs)
cifar10_no_augmentation_preprocess() (in module trax.data.tf_inputs)
clip_fraction (trax.rl.actor_critic_joint.PPOJoint attribute)
clip_grads() (in module trax.optimizers.base)
clipped_objective_mean (trax.rl.actor_critic_joint.PPOJoint attribute)
ClippedObjective() (in module trax.rl.rl_layers)
close() (trax.rl.actor_critic.ActorCriticAgent method)
(trax.rl.actor_critic_joint.ActorCriticJointAgent method)
(trax.rl.training.Agent method)
(trax.rl.training.PolicyAgent method)
(trax.rl.training.ValueAgent method)
compute_nums() (in module trax.data.tf_inputs)
compute_ops() (in module trax.data.tf_inputs)
compute_program() (in module trax.data.tf_inputs)
compute_result() (in module trax.data.tf_inputs)
compute_single_result() (in module trax.data.tf_inputs)
concat_preprocess() (in module trax.data.tf_inputs)
Concatenate (class in trax.layers.combinators)
ConcatenateToLMInput() (in module trax.data.inputs)
Cond (class in trax.layers.combinators)
cond() (in module trax.fastmath.ops)
ConfigurableAttention() (in module trax.layers.attention)
constant() (in module trax.supervised.lr_schedules)
consume_noise_mask() (in module trax.data.inputs)
control_mask (trax.rl.task.EnvInfo attribute)
Conv (class in trax.layers.convolution)
conv() (in module trax.fastmath.ops)
Conv1d() (in module trax.layers.convolution)
ConvBlock() (in module trax.models.resnet)
ConvDiagonalGRU() (in module trax.models.neural_gpu)
convert_float_to_mathqa() (in module trax.data.tf_inputs)
convert_to_subtract() (in module trax.data.tf_inputs)
ConvertToUnicode() (in module trax.data.tf_inputs)
ConvGRUCell() (in module trax.layers.rnn)
CorpusToRandomChunks() (in module trax.data.tf_inputs)
count_and_skip() (in module trax.data.inputs)
CountAndSkip() (in module trax.data.inputs)
create() (in module trax.rl.space_serializer)
create_distribution() (in module trax.rl.distributions)
create_state_unbatched() (trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.SelfAttention method)
create_weights_unbatched() (trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.EncDecAttention method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.SelfAttention method)
CreateAnnotatedDropInputs() (in module trax.data.tf_inputs)
CreateAquaInputs() (in module trax.data.tf_inputs)
CreateBertInputs() (in module trax.data.tf_inputs)
CreateDropInputs() (in module trax.data.tf_inputs)
CreateMathQAInputs() (in module trax.data.tf_inputs)
CrossEntropyLoss() (in module trax.layers.metrics)
CrossEntropyLossWithLogSoftmax() (in module trax.layers.metrics)
CrossEntropySum() (in module trax.layers.metrics)
current_epoch (trax.rl.training.Agent attribute)
custom_grad() (in module trax.fastmath.ops)
custom_vjp() (in module trax.fastmath.ops)
D
data_streams() (in module trax.data.tf_inputs)
dataset_as_numpy() (in module trax.fastmath.ops)
dataset_to_stream() (in module trax.data.tf_inputs)
DecoderBlock() (in module trax.models.reformer.reformer)
Deinterleave() (in module trax.rl.serialization_utils)
Dense (class in trax.layers.core)
deserialize() (trax.rl.space_serializer.DiscreteSpaceSerializer method)
(trax.rl.space_serializer.MultiDiscreteSpaceSerializer method)
(trax.rl.space_serializer.SpaceSerializer method)
detokenize() (in module trax.data.tf_inputs)
DiagonalGate() (in module trax.models.neural_gpu)
disable_jit() (in module trax.fastmath.ops)
discount_mask (trax.rl.task.EnvInfo attribute)
discounted_returns() (in module trax.rl.advantages)
DiscreteSpaceSerializer (class in trax.rl.space_serializer)
dist_inputs (trax.rl.task.TimeStepBatch attribute)
Distribution (class in trax.rl.distributions)
done (trax.rl.task.TimeStepBatch attribute)
(trax.rl.task.Trajectory attribute)
DotProductAttention (class in trax.layers.attention)
DotProductCausalAttention (class in trax.layers.attention)
download_and_prepare() (in module trax.data.tf_inputs)
download_model() (trax.models.research.bert.PretrainedBERT class method)
downsampled_imagenet_flatten_bare_preprocess() (in module trax.data.tf_inputs)
DQN (class in trax.rl.training)
Drop() (in module trax.layers.combinators)
Dropout (class in trax.layers.core)
dtype (trax.shapes.ShapeDtype attribute)
Dup() (in module trax.layers.combinators)
dynamic_slice() (in module trax.fastmath.ops)
dynamic_slice_in_dim() (in module trax.fastmath.ops)
dynamic_update_slice() (in module trax.fastmath.ops)
dynamic_update_slice_in_dim() (in module trax.fastmath.ops)
E
EfficientAttentionBase (class in trax.layers.research.efficient_attention)
Elu() (in module trax.layers.activation_fns)
EMA (trax.optimizers.sm3.MomentumType attribute)
Embedding (class in trax.layers.core)
EncDecAttention (class in trax.layers.research.efficient_attention)
EncoderBlock() (in module trax.models.reformer.reformer)
EncoderDecoderBlock() (in module trax.models.reformer.reformer)
EncoderDecoderMask() (in module trax.layers.attention)
entropy_loss (trax.rl.actor_critic_joint.A2CJoint attribute)
(trax.rl.actor_critic_joint.PPOJoint attribute)
EntropyLoss() (in module trax.rl.rl_layers)
env_info (trax.rl.task.TimeStepBatch attribute)
EnvInfo (class in trax.rl.task)
erf() (in module trax.fastmath.ops)
eval_model (trax.supervised.training.Loop attribute)
eval_stream() (trax.data.inputs.Inputs method)
eval_tasks (trax.supervised.training.Loop attribute)
every() (in module trax.rl.actor_critic)
example_shape_dtype (trax.data.inputs.Inputs attribute)
Exp() (in module trax.layers.activation_fns)
ExpertIteration (class in trax.rl.training)
expit() (in module trax.fastmath.ops)
explained_variance (trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
ExplainedVariance() (in module trax.rl.rl_layers)
extend() (trax.rl.task.Trajectory method)
F
FastGelu() (in module trax.layers.activation_fns)
filter_dataset_on_len() (in module trax.data.tf_inputs)
FilterByLength() (in module trax.data.inputs)
FilterEmptyExamples() (in module trax.data.inputs)
FilterResponseNorm (class in trax.layers.normalization)
FixedBasePositionalEncoding (class in trax.layers.research.position_encodings)
Flatten() (in module trax.layers.core)
flatten_weights_and_state() (in module trax.layers.base)
FlattenList() (in module trax.layers.combinators)
Fn() (in module trax.layers.base)
fold_in() (trax.fastmath.ops.RandomBackend method)
for_n_devices() (in module trax.layers.acceleration)
fori_loop() (in module trax.fastmath.ops)
forward() (trax.layers.activation_fns.ThresholdedLinearUnit method)
(trax.layers.attention.DotProductAttention method)
(trax.layers.attention.DotProductCausalAttention method)
(trax.layers.attention.PositionalEncoding method)
(trax.layers.attention.PureAttention method)
(trax.layers.base.Layer method)
(trax.layers.base.PureLayer method)
(trax.layers.combinators.BatchLeadingAxes method)
(trax.layers.combinators.Cache method)
(trax.layers.combinators.Concatenate method)
(trax.layers.combinators.Cond method)
(trax.layers.combinators.Parallel method)
(trax.layers.combinators.Scan method)
(trax.layers.combinators.Serial method)
(trax.layers.combinators.Split method)
(trax.layers.convolution.CausalConv method)
(trax.layers.convolution.Conv method)
(trax.layers.core.Dense method)
(trax.layers.core.Dropout method)
(trax.layers.core.Embedding method)
(trax.layers.core.LocallyConnected1d method)
(trax.layers.core.RandomUniform method)
(trax.layers.core.SummaryImage method)
(trax.layers.core.SummaryScalar method)
(trax.layers.core.Weights method)
(trax.layers.normalization.BatchNorm method)
(trax.layers.normalization.FilterResponseNorm method)
(trax.layers.normalization.LayerNorm method)
(trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.LSHFF method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.MixedLSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.SelfAttention method)
(trax.layers.research.position_encodings.AxialPositionalEncoding method)
(trax.layers.research.position_encodings.FixedBasePositionalEncoding method)
(trax.layers.research.position_encodings.InfinitePositionalEncoding method)
(trax.layers.research.position_encodings.SinCosPositionalEncoding method)
(trax.layers.research.position_encodings.TimeBinPositionalEncoding method)
(trax.layers.reversible.ReversibleConcatenatePair method)
(trax.layers.reversible.ReversibleHalfResidual method)
(trax.layers.reversible.ReversiblePrintShape method)
(trax.layers.reversible.ReversibleReshape method)
(trax.layers.reversible.ReversibleSelect method)
(trax.layers.rnn.GRUCell method)
(trax.layers.rnn.LSTMCell method)
(trax.models.research.bert.AddBias method)
forward_and_or_backward() (trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.MixedLSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttentionWrapper method)
(trax.layers.research.efficient_attention.SelfAttention method)
forward_unbatched() (trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.EncDecAttention method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.SelfAttention method)
FrameStackMLP() (in module trax.models.atari_cnn)
G
gae() (in module trax.rl.advantages)
Gate() (in module trax.layers.combinators)
Gelu() (in module trax.layers.activation_fns)
GeneralGRUCell() (in module trax.layers.rnn)
generate_random_noise_mask() (in module trax.data.inputs)
generate_sequential_chunks() (in module trax.data.inputs)
generic_text_dataset_preprocess_fn() (in module trax.data.tf_inputs)
get_prng() (trax.fastmath.ops.RandomBackend method)
get_t5_preprocessor_by_name() (in module trax.data.tf_inputs)
global_device_count() (in module trax.fastmath.ops)
GlorotNormalInitializer() (in module trax.layers.initializers)
GlorotUniformInitializer() (in module trax.layers.initializers)
Glu() (in module trax.layers.activation_fns)
grad() (in module trax.fastmath.ops)
GRU() (in module trax.layers.rnn)
GRUCell (class in trax.layers.rnn)
GRULM() (in module trax.models.rnn)
H
HardSigmoid() (in module trax.layers.activation_fns)
HardTanh() (in module trax.layers.activation_fns)
has_backward (trax.layers.base.Layer attribute)
(trax.layers.research.efficient_attention.EfficientAttentionBase attribute)
(trax.layers.research.efficient_attention.LSHSelfAttention attribute)
(trax.layers.research.efficient_attention.PureLSHSelfAttention attribute)
(trax.layers.research.efficient_attention.SelfAttention attribute)
(trax.layers.reversible.ReversibleLayer attribute)
hash_vecs() (in module trax.layers.research.efficient_attention)
hash_vectors() (trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
HEAVY_BALL (trax.optimizers.sm3.MomentumType attribute)
history (trax.supervised.training.Loop attribute)
I
IdentityBlock() (in module trax.models.resnet)
index_add() (in module trax.fastmath.ops)
index_max() (in module trax.fastmath.ops)
index_min() (in module trax.fastmath.ops)
index_update() (in module trax.fastmath.ops)
InfinitePositionalEncoding (class in trax.layers.research.position_encodings)
init() (trax.layers.acceleration.Accelerate method)
(trax.layers.base.Layer method)
(trax.optimizers.adafactor.Adafactor method)
(trax.optimizers.adam.Adam method)
(trax.optimizers.base.Optimizer method)
(trax.optimizers.base.SGD method)
(trax.optimizers.momentum.Momentum method)
(trax.optimizers.rms_prop.RMSProp method)
(trax.optimizers.sm3.SM3 method)
init_from_file() (trax.layers.base.Layer method)
(trax.rl.training.Agent method)
init_host_and_devices() (in module trax.supervised.training)
init_weights_and_state() (trax.layers.activation_fns.ThresholdedLinearUnit method)
(trax.layers.attention.DotProductCausalAttention method)
(trax.layers.attention.PositionalEncoding method)
(trax.layers.base.Layer method)
(trax.layers.combinators.BatchLeadingAxes method)
(trax.layers.combinators.Cache method)
(trax.layers.combinators.Cond method)
(trax.layers.combinators.Parallel method)
(trax.layers.combinators.Scan method)
(trax.layers.combinators.Serial method)
(trax.layers.convolution.Conv method)
(trax.layers.core.Dense method)
(trax.layers.core.Dropout method)
(trax.layers.core.Embedding method)
(trax.layers.core.LocallyConnected1d method)
(trax.layers.core.SummaryImage method)
(trax.layers.core.SummaryScalar method)
(trax.layers.core.Weights method)
(trax.layers.normalization.BatchNorm method)
(trax.layers.normalization.FilterResponseNorm method)
(trax.layers.normalization.LayerNorm method)
(trax.layers.research.efficient_attention.EfficientAttentionBase method)
(trax.layers.research.efficient_attention.LSHFF method)
(trax.layers.research.efficient_attention.LSHSelfAttention method)
(trax.layers.research.efficient_attention.MixedLSHSelfAttention method)
(trax.layers.research.efficient_attention.PureLSHSelfAttention method)
(trax.layers.research.efficient_attention.SelfAttention method)
(trax.layers.research.position_encodings.AxialPositionalEncoding method)
(trax.layers.research.position_encodings.FixedBasePositionalEncoding method)
(trax.layers.research.position_encodings.InfinitePositionalEncoding method)
(trax.layers.research.position_encodings.SinCosPositionalEncoding method)
(trax.layers.research.position_encodings.TimeBinPositionalEncoding method)
(trax.layers.reversible.ReversibleHalfResidual method)
(trax.layers.rnn.GRUCell method)
(trax.layers.rnn.LSTMCell method)
(trax.models.research.bert.AddBias method)
(trax.models.research.bert.PretrainedBERT method)
InitializerFromFile() (in module trax.layers.initializers)
InnerSRUCell() (in module trax.layers.rnn)
input_dtype (trax.data.inputs.Inputs attribute)
input_shape (trax.data.inputs.Inputs attribute)
Inputs (class in trax.data.inputs)
inputs_from_stack() (in module trax.layers.combinators)
Interleave() (in module trax.rl.serialization_utils)
is_backend() (in module trax.fastmath.ops)
is_chief (trax.supervised.training.Loop attribute)
J
JAX (trax.fastmath.ops.Backend attribute)
jit() (in module trax.fastmath.ops)
jit_forward() (in module trax.layers.acceleration)
joint_loss (trax.rl.actor_critic_joint.A2CJoint attribute)
(trax.rl.actor_critic_joint.AWRJoint attribute)
(trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
(trax.rl.actor_critic_joint.PPOJoint attribute)
K
KaimingNormalInitializer() (in module trax.layers.initializers)
KaimingUniformInitializer() (in module trax.layers.initializers)
L
l2_norm() (in module trax.optimizers.base)
L2Loss() (in module trax.layers.metrics)
last_observation (trax.rl.task.Trajectory attribute)
Layer (class in trax.layers.base)
LayerError
LayerNorm (class in trax.layers.normalization)
LayerNormSquash() (in module trax.rl.normalization)
LeakyRelu() (in module trax.layers.activation_fns)
LeCunNormalInitializer() (in module trax.layers.initializers)
LeCunUniformInitializer() (in module trax.layers.initializers)
length_normalized() (in module trax.layers.research.efficient_attention)
lm1b_preprocess() (in module trax.data.tf_inputs)
lm_token_preprocessing() (in module trax.data.tf_inputs)
load_checkpoint() (trax.supervised.training.Loop method)
load_data_counters() (in module trax.data.inputs)
local_device_count() (in module trax.fastmath.ops)
LocallyConnected1d (class in trax.layers.core)
Log() (in module trax.data.inputs)
(in module trax.layers.activation_fns)
log_gaussian_diag_pdf() (in module trax.layers.core)
log_gaussian_pdf() (in module trax.layers.core)
log_prob() (trax.rl.distributions.Distribution method)
log_probs_mean (trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
log_softmax() (in module trax.layers.core)
log_summary() (trax.supervised.training.Loop method)
LogLoss() (in module trax.rl.distributions)
LogProb() (trax.rl.distributions.Distribution method)
LogSoftmax() (in module trax.layers.core)
logsoftmax_sample() (in module trax.layers.core)
logsumexp() (in module trax.fastmath.ops)
LogSumExp() (in module trax.layers.core)
look_adjacent() (in module trax.layers.research.efficient_attention)
Loop (class in trax.supervised.training)
loop (trax.rl.actor_critic.LoopActorCriticAgent attribute)
(trax.rl.training.LoopPolicyAgent attribute)
LoopActorCriticAgent (class in trax.rl.actor_critic)
LoopAWR (class in trax.rl.actor_critic)
LoopPolicyAgent (class in trax.rl.training)
lower_endian_to_number() (in module trax.data.inputs)
LSHFF (class in trax.layers.research.efficient_attention)
LSHSelfAttention (class in trax.layers.research.efficient_attention)
LSTM() (in module trax.layers.rnn)
LSTMCell (class in trax.layers.rnn)
LSTMSeq2SeqAttn() (in module trax.models.rnn)
lt() (in module trax.fastmath.ops)
M
MacroAveragedFScore() (in module trax.layers.metrics)
main() (in module trax.rl_trainer)
(in module trax.trainer)
make_additional_stream() (in module trax.data.inputs)
make_inputs() (in module trax.data.inputs)
make_parallel_stream() (in module trax.data.inputs)
make_predict_model() (trax.rl.serialization_utils.SerializedModel method)
MakeZeroState() (in module trax.layers.rnn)
map() (in module trax.fastmath.ops)
mask (trax.rl.task.TimeStepBatch attribute)
mask_discount() (in module trax.rl.advantages)
mask_random_tokens() (in module trax.data.tf_inputs)
mask_self_attention() (in module trax.layers.research.efficient_attention)
MaskedSequenceAccuracy() (in module trax.layers.metrics)
Max() (in module trax.layers.core)
max_pool() (in module trax.fastmath.ops)
MaxPool() (in module trax.layers.pooling)
Mean() (in module trax.layers.core)
mean_or_pmean() (in module trax.layers.acceleration)
MergeHeads() (in module trax.layers.attention)
message (trax.layers.base.LayerError attribute)
Min() (in module trax.layers.core)
MixedLSHSelfAttention (class in trax.layers.research.efficient_attention)
MLM() (in module trax.data.inputs)
MLP() (in module trax.models.mlp)
model (trax.supervised.training.Loop attribute)
Momentum (class in trax.optimizers.momentum)
MomentumType (class in trax.optimizers.sm3)
monte_carlo() (in module trax.rl.advantages)
MultiDiscreteSpaceSerializer (class in trax.rl.space_serializer)
multifactor() (in module trax.supervised.lr_schedules)
multigaussian_loss() (in module trax.layers.core)
Multiply() (in module trax.layers.combinators)
N
n_devices (trax.supervised.training.Loop attribute)
n_in (trax.layers.base.Layer attribute)
n_inputs (trax.rl.distributions.Distribution attribute)
n_out (trax.layers.base.Layer attribute)
name (trax.layers.base.Layer attribute)
Negate() (in module trax.layers.core)
NESTEROV (trax.optimizers.sm3.MomentumType attribute)
network_policy() (in module trax.rl.training)
NeuralGPU() (in module trax.models.neural_gpu)
new_rng() (trax.supervised.training.Loop method)
NewLogProbs() (in module trax.rl.rl_layers)
no_preprocess() (in module trax.data.tf_inputs)
normal() (trax.fastmath.ops.RandomBackend method)
np_from_file() (in module trax.layers.base)
np_to_file() (in module trax.layers.base)
num_features (trax.layers.research.position_encodings.TimeBinPositionalEncoding attribute)
number_to_lower_endian() (in module trax.data.inputs)
numpy (in module trax.fastmath.ops)
NUMPY (trax.fastmath.ops.Backend attribute)
NumpyBackend (class in trax.fastmath.ops)
O
observation (trax.rl.task.TimeStepBatch attribute)
observation_serializer (trax.rl.serialization_utils.SerializedModel attribute)
on_accelerator() (in module trax.layers.acceleration)
on_cpu() (in module trax.layers.acceleration)
on_policy (trax.rl.actor_critic.A2C attribute)
(trax.rl.actor_critic.AWR attribute)
(trax.rl.actor_critic.ActorCriticAgent attribute)
(trax.rl.actor_critic.LoopAWR attribute)
(trax.rl.actor_critic.LoopActorCriticAgent attribute)
(trax.rl.actor_critic.PPO attribute)
(trax.rl.actor_critic.SamplingAWR attribute)
(trax.rl.actor_critic_joint.A2CJoint attribute)
(trax.rl.actor_critic_joint.PPOJoint attribute)
one_hot() (in module trax.layers.core)
opt_params (trax.optimizers.base.Optimizer attribute)
Optimizer (class in trax.optimizers.base)
OrthogonalInitializer() (in module trax.layers.initializers)
output_dir (trax.supervised.training.Loop attribute)
output_signature() (trax.layers.base.Layer method)
outputs_onto_stack() (in module trax.layers.combinators)
P
pad_dataset_to_length() (in module trax.data.tf_inputs)
pad_to_max_dims() (in module trax.data.inputs)
PaddingMask() (in module trax.layers.attention)
PadToLength() (in module trax.data.inputs)
Parallel (class in trax.layers.combinators)
Parallel() (in module trax.data.inputs)
ParametricRelu() (in module trax.layers.activation_fns)
permute_via_gather() (in module trax.layers.research.efficient_attention)
permute_via_sort() (in module trax.layers.research.efficient_attention)
pickle_to_file() (in module trax.supervised.training)
play() (in module trax.rl.task)
pmap() (in module trax.fastmath.ops)
Policy() (in module trax.models.rl)
policy() (trax.rl.actor_critic.LoopActorCriticAgent method)
(trax.rl.actor_critic_joint.ActorCriticJointAgent method)
(trax.rl.training.Agent method)
(trax.rl.training.DQN method)
(trax.rl.training.ExpertIteration method)
(trax.rl.training.PolicyAgent method)
(trax.rl.training.PolicyGradient method)
(trax.rl.training.ValueAgent method)
policy_batches_stream() (trax.rl.actor_critic.ActorCriticAgent method)
(trax.rl.actor_critic.SamplingAWR method)
(trax.rl.training.PolicyAgent method)
policy_inputs() (trax.rl.actor_critic.ActorCriticAgent method)
(trax.rl.actor_critic.AdvantageBasedActorCriticAgent method)
policy_loss (trax.rl.actor_critic.AdvantageBasedActorCriticAgent attribute)
(trax.rl.actor_critic.SamplingAWR attribute)
(trax.rl.training.PolicyAgent attribute)
policy_loss_given_log_probs (trax.rl.actor_critic.A2C attribute)
(trax.rl.actor_critic.AWR attribute)
(trax.rl.actor_critic.AdvantageBasedActorCriticAgent attribute)
(trax.rl.actor_critic.PPO attribute)
policy_metrics (trax.rl.actor_critic.AdvantageBasedActorCriticAgent attribute)
(trax.rl.actor_critic.SamplingAWR attribute)
(trax.rl.training.PolicyAgent attribute)
PolicyAgent (class in trax.rl.training)
PolicyAndValue() (in module trax.models.rl)
PolicyGradient (class in trax.rl.training)
PositionalEncoding (class in trax.layers.attention)
PPO (class in trax.rl.actor_critic)
ppo_objective (trax.rl.actor_critic_joint.PPOJoint attribute)
ppo_objective_mean (trax.rl.actor_critic_joint.PPOJoint attribute)
PPOJoint (class in trax.rl.actor_critic_joint)
PPOObjective() (in module trax.rl.rl_layers)
preferred_move (trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
PreferredMove() (in module trax.rl.rl_layers)
Prefetch() (in module trax.data.inputs)
PrefixLM() (in module trax.data.inputs)
PretrainedBERT (class in trax.models.research.bert)
PrintShape() (in module trax.layers.core)
probs_ratio_mean (trax.rl.actor_critic_joint.PPOJoint attribute)
ProbsRatio() (in module trax.rl.rl_layers)
process_single_mathqa_example() (in module trax.data.tf_inputs)
psum() (in module trax.fastmath.ops)
pure_fn() (trax.layers.acceleration.Accelerate method)
(trax.layers.base.Layer method)
PureAttention (class in trax.layers.attention)
PureLayer (class in trax.layers.base)
PureLSHSelfAttention (class in trax.layers.research.efficient_attention)
PureLSHSelfAttentionWrapper (class in trax.layers.research.efficient_attention)
Q
Quality() (in module trax.models.rl)
R
randint() (trax.fastmath.ops.RandomBackend method)
random_inputs() (in module trax.data.inputs)
random_number_lower_endian() (in module trax.data.inputs)
random_spans_noise_mask() (in module trax.data.inputs)
RandomBackend (class in trax.fastmath.ops)
RandomNormalInitializer() (in module trax.layers.initializers)
RandomUniform (class in trax.layers.core)
RandomUniformInitializer() (in module trax.layers.initializers)
RawPolicy() (in module trax.rl.serialization_utils)
read_values() (in module trax.trax2keras)
Reformer() (in module trax.models.reformer.reformer)
ReformerLM() (in module trax.models.reformer.reformer)
ReformerShortenLM() (in module trax.models.reformer.reformer)
Relu() (in module trax.layers.activation_fns)
remaining_evals() (in module trax.rl.training)
remat() (in module trax.fastmath.ops)
replace() (trax.shapes.ShapeDtype method)
replicate_state() (trax.layers.acceleration.Accelerate method)
replicate_weights() (trax.layers.acceleration.Accelerate method)
representation_length (trax.rl.space_serializer.DiscreteSpaceSerializer attribute)
(trax.rl.space_serializer.MultiDiscreteSpaceSerializer attribute)
(trax.rl.space_serializer.SpaceSerializer attribute)
RepresentationMask() (in module trax.rl.serialization_utils)
reshape_by_device() (in module trax.layers.acceleration)
Residual() (in module trax.layers.combinators)
Resnet50() (in module trax.models.resnet)
return_ (trax.rl.task.TimeStepBatch attribute)
returns_mean (trax.rl.training.ValueAgent attribute)
reverse() (trax.layers.reversible.ReversibleConcatenatePair method)
(trax.layers.reversible.ReversibleHalfResidual method)
(trax.layers.reversible.ReversibleLayer method)
(trax.layers.reversible.ReversiblePrintShape method)
(trax.layers.reversible.ReversibleReshape method)
(trax.layers.reversible.ReversibleSelect method)
(trax.layers.reversible.ReversibleSerial method)
reverse_and_grad() (trax.layers.reversible.ReversibleHalfResidual method)
(trax.layers.reversible.ReversibleLayer method)
(trax.layers.reversible.ReversibleSerial method)
ReversibleConcatenatePair (class in trax.layers.reversible)
ReversibleHalfResidual (class in trax.layers.reversible)
ReversibleLayer (class in trax.layers.reversible)
ReversiblePrintShape (class in trax.layers.reversible)
ReversibleReshape (class in trax.layers.reversible)
ReversibleSelect (class in trax.layers.reversible)
ReversibleSerial (class in trax.layers.reversible)
ReversibleSwap() (in module trax.layers.reversible)
reward (trax.rl.task.TimeStepBatch attribute)
RMSProp (class in trax.optimizers.rms_prop)
rng (trax.layers.base.Layer attribute)
RNNLM() (in module trax.models.rnn)
run() (trax.rl.training.Agent method)
(trax.supervised.training.Loop method)
run_evals() (trax.supervised.training.Loop method)
running_mean_and_variance_get_count() (in module trax.rl.normalization)
running_mean_and_variance_get_mean() (in module trax.rl.normalization)
running_mean_and_variance_get_variance() (in module trax.rl.normalization)
running_mean_and_variance_init() (in module trax.rl.normalization)
running_mean_and_variance_update() (in module trax.rl.normalization)
running_mean_get_count() (in module trax.rl.normalization)
running_mean_get_mean() (in module trax.rl.normalization)
running_mean_init() (in module trax.rl.normalization)
running_mean_update() (in module trax.rl.normalization)
S
sample() (trax.rl.distributions.Distribution method)
SamplingAWR (class in trax.rl.actor_critic)
SamplingAWRLoss() (in module trax.rl.actor_critic)
SaturationCost() (in module trax.models.neural_gpu)
save_checkpoint() (trax.supervised.training.Loop method)
save_data_counters() (in module trax.data.inputs)
save_gin() (trax.rl.training.Agent method)
save_to_file() (trax.layers.base.Layer method)
(trax.rl.training.Agent method)
ScaledInitializer() (in module trax.layers.initializers)
Scan (class in trax.layers.combinators)
scan() (in module trax.fastmath.ops)
Select() (in module trax.layers.combinators)
SelfAttention (class in trax.layers.research.efficient_attention)
Selu() (in module trax.layers.activation_fns)
sentencepiece_tokenize() (in module trax.data.tf_inputs)
SentencePieceTokenize() (in module trax.data.tf_inputs)
seq_model_state (trax.rl.serialization_utils.SerializedModel attribute)
seq_model_weights (trax.rl.serialization_utils.SerializedModel attribute)
sequence_copy_inputs() (in module trax.data.inputs)
SequenceAccuracy() (in module trax.layers.metrics)
Serial (class in trax.layers.combinators)
Serial() (in module trax.data.inputs)
Serialize() (in module trax.rl.serialization_utils)
serialize() (trax.rl.space_serializer.DiscreteSpaceSerializer method)
(trax.rl.space_serializer.MultiDiscreteSpaceSerializer method)
(trax.rl.space_serializer.SpaceSerializer method)
SerializedModel (class in trax.rl.serialization_utils)
SerializedPolicy() (in module trax.rl.serialization_utils)
SerialWithSideOutputs() (in module trax.layers.combinators)
set_backend() (in module trax.fastmath.ops)
SGD (class in trax.optimizers.base)
shape (trax.shapes.ShapeDtype attribute)
ShapeDtype (class in trax.shapes)
shard() (in module trax.layers.base)
sharpened_network_policy() (in module trax.rl.training)
ShiftRight() (in module trax.layers.attention)
Shuffle() (in module trax.data.inputs)
shuffle() (in module trax.data.inputs)
sigmoid() (in module trax.fastmath.ops)
Sigmoid() (in module trax.layers.activation_fns)
signature() (in module trax.shapes)
significance_map (trax.rl.space_serializer.DiscreteSpaceSerializer attribute)
(trax.rl.space_serializer.MultiDiscreteSpaceSerializer attribute)
(trax.rl.space_serializer.SpaceSerializer attribute)
SignificanceWeights() (in module trax.rl.serialization_utils)
simple_sequence_copy_inputs() (in module trax.data.inputs)
SinCosPositionalEncoding (class in trax.layers.research.position_encodings)
sine_inputs() (in module trax.data.inputs)
single_op_to_python_command() (in module trax.data.tf_inputs)
slots (trax.optimizers.base.Optimizer attribute)
SM3 (class in trax.optimizers.sm3)
SmoothL1Loss() (in module trax.layers.metrics)
Softmax() (in module trax.layers.core)
Softplus() (in module trax.layers.activation_fns)
sort_key_val() (in module trax.fastmath.ops)
space_type (trax.rl.space_serializer.DiscreteSpaceSerializer attribute)
(trax.rl.space_serializer.MultiDiscreteSpaceSerializer attribute)
(trax.rl.space_serializer.SpaceSerializer attribute)
SpaceSerializer (class in trax.rl.space_serializer)
splice_signatures() (in module trax.shapes)
Split (class in trax.layers.combinators)
split() (trax.fastmath.ops.RandomBackend method)
SplitIntoHeads() (in module trax.layers.attention)
squeeze_targets_preprocess() (in module trax.data.tf_inputs)
SRU() (in module trax.layers.rnn)
state (trax.layers.acceleration.Accelerate attribute)
(trax.layers.base.Layer attribute)
(trax.layers.combinators.Cache attribute)
(trax.layers.combinators.Scan attribute)
step (trax.supervised.training.Loop attribute)
stop_gradient() (in module trax.fastmath.ops)
StopGradient() (in module trax.layers.core)
sublayer (trax.layers.acceleration.Accelerate attribute)
(trax.layers.combinators.BatchLeadingAxes attribute)
(trax.layers.combinators.Cache attribute)
(trax.layers.combinators.Scan attribute)
sublayers (trax.layers.base.Layer attribute)
substitute_inner_policy() (in module trax.rl.serialization_utils)
substitute_inner_policy_raw() (in module trax.rl.serialization_utils)
substitute_inner_policy_serialized() (in module trax.rl.serialization_utils)
SubtractTop() (in module trax.layers.combinators)
suffix() (trax.rl.task.Trajectory method)
Sum() (in module trax.layers.core)
sum_pool() (in module trax.fastmath.ops)
SummaryImage (class in trax.layers.core)
SummaryScalar (class in trax.layers.core)
SumPool() (in module trax.layers.pooling)
Swap() (in module trax.layers.combinators)
Swish() (in module trax.layers.activation_fns)
T
t2t_problems() (in module trax.data.tf_inputs)
T5GlueEvalStream() (in module trax.data.tf_inputs)
T5GlueEvalStreamsParallel() (in module trax.data.tf_inputs)
T5GlueEvalTasks() (in module trax.data.tf_inputs)
T5GlueTrainStream() (in module trax.data.tf_inputs)
T5GlueTrainStreamsParallel() (in module trax.data.tf_inputs)
Tanh() (in module trax.layers.activation_fns)
target_dtype (trax.data.inputs.Inputs attribute)
target_shape (trax.data.inputs.Inputs attribute)
task (trax.rl.training.Agent attribute)
tasks (trax.supervised.training.Loop attribute)
td_k() (in module trax.rl.advantages)
td_lambda() (in module trax.rl.advantages)
tensor_shapes_to_shape_dtypes() (in module trax.trax2keras)
tf_init_tpu() (in module trax.trainer)
TFDS() (in module trax.data.tf_inputs)
TFNP (trax.fastmath.ops.Backend attribute)
threefry_2x32_prange() (in module trax.layers.research.position_encodings)
threefry_2x32_prf() (in module trax.layers.research.position_encodings)
ThresholdedLinearUnit (class in trax.layers.activation_fns)
ThresholdToBinary() (in module trax.layers.core)
TimeBinPositionalEncoding (class in trax.layers.research.position_encodings)
TimeSeriesModel() (in module trax.rl.serialization_utils)
TimeStepBatch (class in trax.rl.task)
timesteps (trax.rl.task.Trajectory attribute)
to_arrays() (in module trax.trax2keras)
to_list() (in module trax.layers.base)
to_np() (trax.rl.task.Trajectory method)
to_tensors() (in module trax.trax2keras)
ToFloat() (in module trax.layers.core)
Tokenize() (in module trax.data.tf_inputs)
tokenize() (in module trax.data.tf_inputs)
top_k() (in module trax.fastmath.ops)
total_return (trax.rl.task.Trajectory attribute)
train_epoch() (trax.rl.actor_critic.ActorCriticAgent method)
(trax.rl.actor_critic.LoopActorCriticAgent method)
(trax.rl.actor_critic_joint.ActorCriticJointAgent method)
(trax.rl.training.Agent method)
(trax.rl.training.LoopPolicyAgent method)
(trax.rl.training.PolicyAgent method)
(trax.rl.training.ValueAgent method)
train_eval_stream() (trax.data.inputs.Inputs method)
train_rl() (in module trax.rl_trainer)
train_stream() (trax.data.inputs.Inputs method)
Trajectory (class in trax.rl.task)
Transformer() (in module trax.models.transformer)
TransformerDecoder() (in module trax.models.transformer)
TransformerEncoder() (in module trax.models.transformer)
TransformerLM() (in module trax.models.transformer)
trax.data.inputs (module)
trax.data.tf_inputs (module)
trax.fastmath.ops (module)
trax.layers.acceleration (module)
trax.layers.activation_fns (module)
trax.layers.attention (module)
trax.layers.base (module)
trax.layers.combinators (module)
trax.layers.convolution (module)
trax.layers.core (module)
trax.layers.initializers (module)
trax.layers.metrics (module)
trax.layers.normalization (module)
trax.layers.pooling (module)
trax.layers.research.efficient_attention (module)
trax.layers.research.position_encodings (module)
trax.layers.reversible (module)
trax.layers.rnn (module)
trax.models.atari_cnn (module)
trax.models.mlp (module)
trax.models.neural_gpu (module)
trax.models.reformer.reformer (module)
trax.models.research.bert (module)
trax.models.resnet (module)
trax.models.rl (module)
trax.models.rnn (module)
trax.models.transformer (module)
trax.optimizers.adafactor (module)
trax.optimizers.adam (module)
trax.optimizers.base (module)
trax.optimizers.momentum (module)
trax.optimizers.rms_prop (module)
trax.optimizers.sm3 (module)
trax.rl.actor_critic (module)
trax.rl.actor_critic_joint (module)
trax.rl.advantages (module)
trax.rl.distributions (module)
trax.rl.normalization (module)
trax.rl.rl_layers (module)
trax.rl.serialization_utils (module)
trax.rl.space_serializer (module)
trax.rl.task (module)
trax.rl.training (module)
trax.rl_trainer (module)
trax.shapes (module)
trax.supervised.decoding (module)
trax.supervised.lr_schedules (module)
trax.supervised.training (module)
trax.trainer (module)
trax.trax2keras (module)
tree_init() (trax.optimizers.base.Optimizer method)
tree_update() (trax.optimizers.base.Optimizer method)
truncate_dataset_on_len() (in module trax.data.tf_inputs)
TruncateToLength() (in module trax.data.inputs)
U
UnBatch() (in module trax.data.inputs)
unclipped_objective_mean (trax.rl.actor_critic_joint.PPOJoint attribute)
UnclippedObjective() (in module trax.rl.rl_layers)
unflatten_weights_and_state() (in module trax.layers.base)
uniform() (trax.fastmath.ops.RandomBackend method)
UniformlySeek() (in module trax.data.inputs)
unpickle_from_file() (in module trax.supervised.training)
unshard() (in module trax.layers.base)
unshard_in_pmap() (in module trax.layers.base)
update() (trax.optimizers.adafactor.Adafactor method)
(trax.optimizers.adam.Adam method)
(trax.optimizers.base.Optimizer method)
(trax.optimizers.base.SGD method)
(trax.optimizers.momentum.Momentum method)
(trax.optimizers.rms_prop.RMSProp method)
(trax.optimizers.sm3.SM3 method)
update_weights_and_state() (trax.supervised.training.Loop method)
use_backend() (in module trax.fastmath.ops)
V
Value() (in module trax.models.rl)
value_and_grad() (in module trax.fastmath.ops)
value_batches_stream() (trax.rl.actor_critic.ActorCriticAgent method)
(trax.rl.training.DQN method)
(trax.rl.training.ValueAgent method)
value_loss (trax.rl.actor_critic_joint.ActorCriticJointAgent attribute)
(trax.rl.training.DQN attribute)
value_mean (trax.rl.actor_critic.ActorCriticAgent attribute)
(trax.rl.training.DQN attribute)
(trax.rl.training.ValueAgent attribute)
ValueAgent (class in trax.rl.training)
ValueLoss() (in module trax.rl.rl_layers)
vjp() (in module trax.fastmath.ops)
vmap() (in module trax.fastmath.ops)
vocab_size (trax.rl.space_serializer.SpaceSerializer attribute)
vocab_size() (in module trax.data.tf_inputs)
W
warmup() (in module trax.supervised.lr_schedules)
warmup_and_rsqrt_decay() (in module trax.supervised.lr_schedules)
WeightedCategoryAccuracy() (in module trax.layers.metrics)
WeightedCategoryCrossEntropy() (in module trax.layers.metrics)
WeightedFScore() (in module trax.layers.metrics)
WeightedSum() (in module trax.layers.metrics)
Weights (class in trax.layers.core)
weights (trax.layers.acceleration.Accelerate attribute)
(trax.layers.base.Layer attribute)
weights_and_state_signature() (trax.layers.base.Layer method)
WideResnet() (in module trax.models.resnet)
WideResnetBlock() (in module trax.models.resnet)
WideResnetGroup() (in module trax.models.resnet)
wmt_concat_preprocess() (in module trax.data.tf_inputs)
wmt_preprocess() (in module trax.data.tf_inputs)
wrap_policy() (in module trax.rl.serialization_utils)
Read the Docs
v: stable
Versions
latest
stable
Downloads
html
epub
On Read the Docs
Project Home
Builds
Free document hosting provided by Read the Docs.