trax.rl package¶
actor_critic¶
Classes for RL training in Trax.

class
trax.rl.actor_critic.
ActorCriticAgent
(task, value_model=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, scale_value_targets=False, q_value=False, q_value_aggregate='logsumexp', q_value_temperature=1.0, q_value_n_samples=1, q_value_normalization=False, offline=False, **kwargs)¶ Bases:
trax.rl.training.PolicyAgent
Trains policy and value models using actor-critic methods.
 Attrs:
 on_policy (bool): Whether the algorithm is on-policy. Used in the data generators. Should be set in derived classes.

on_policy
= None¶

__init__
(task, value_model=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, scale_value_targets=False, q_value=False, q_value_aggregate='logsumexp', q_value_temperature=1.0, q_value_n_samples=1, q_value_normalization=False, offline=False, **kwargs)¶ Configures the actor-critic trainer.
Parameters:  task – RLTask instance to use.
 value_model – Model to use for the value function.
 value_optimizer – Optimizer to train the value model.
 value_lr_schedule – Learning rate schedule for value model training.
 value_batch_size – Batch size for value model training.
 value_train_steps_per_epoch – Number of steps used to train the value model in each epoch.
 value_evals_per_epoch – Number of value trainer evaluations per RL epoch. At every evaluation, the weights of the target network are also synchronized.
 value_eval_steps – Number of value trainer steps per evaluation; only affects metric reporting.
 n_shared_layers – Number of layers to share between the value and policy models.
 added_policy_slice_length – How much longer trajectory slices should be for policy training than for value training. This is useful for TD calculations and only affects the length of elements produced for policy batches; value batches have their maximum length set by max_slice_length in **kwargs.
 n_replay_epochs – Number of most recent epochs to include in the replay buffer; only makes sense for off-policy algorithms.
 scale_value_targets – If True, scale value function targets by 1 / (1 - gamma).
 q_value – If True, use Q-values as baselines.
 q_value_aggregate – How to aggregate Q-values. Options: ‘mean’, ‘max’, ‘softmax’, ‘logsumexp’.
 q_value_temperature – Temperature parameter for the ‘softmax’ and ‘logsumexp’ aggregation methods.
 q_value_n_samples – Number of samples to average over when calculating baselines based on Q-values.
 q_value_normalization – How to normalize Q-values before aggregation. Allowed values: ‘std’, ‘abs’, None. If None, don’t normalize.
 offline – Whether to train in offline mode. This matters for some algorithms, e.g. QWR.
 **kwargs – Arguments for the PolicyAgent superclass.
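To make the q_value_aggregate and q_value_temperature options above concrete, here is a minimal plain-Python sketch of how a set of sampled Q-values for one state could be collapsed into a single baseline. The helper name aggregate_q_values and the exact formulas (in particular the log-mean-exp form assumed for 'logsumexp') are illustrative assumptions, not Trax's implementation; q_value_normalization is not modeled.

```python
import math


def aggregate_q_values(q_values, method="logsumexp", temperature=1.0):
    """Collapses sampled Q-values Q(s, a_1), ..., Q(s, a_n) into one baseline.

    Hypothetical helper: option names mirror q_value_aggregate, but the
    formulas are assumptions made for illustration.
    """
    n = len(q_values)
    if method == "mean":
        return sum(q_values) / n
    if method == "max":
        return max(q_values)
    if method == "softmax":
        # Softmax-weighted average; subtracting the max keeps exp() stable.
        m = max(q_values)
        weights = [math.exp((q - m) / temperature) for q in q_values]
        return sum(w * q for w, q in zip(weights, q_values)) / sum(weights)
    if method == "logsumexp":
        # Temperature-scaled log-mean-exp: T * log(mean(exp(q / T))),
        # rewritten around the max for numerical stability.
        m = max(q_values)
        mean_exp = sum(math.exp((q - m) / temperature) for q in q_values) / n
        return m + temperature * math.log(mean_exp)
    raise ValueError(f"Unknown aggregation method: {method}")
```

With the temperature going to 0, this 'logsumexp' approaches 'max', and with large temperatures it approaches 'mean'. The temperature therefore interpolates between an optimistic baseline and an average one.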

value_mean
¶ The mean value of the value function.

value_batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

policy_inputs
(trajectory, values)¶ Create inputs to policy model from a TimeStepBatch and values.
Parameters:  trajectory – a TimeStepBatch, the trajectory to create inputs from
 values – a numpy array: value function computed on trajectory
Returns: a tuple of numpy arrays of the form (inputs, x1, x2, …) that will be passed to the policy model; policy model will compute outputs from inputs and (outputs, x1, x2, …) will be passed to self.policy_loss which should be overridden accordingly.

policy_batches_stream
()¶ Use the RLTask self._task to create inputs to the policy model.

train_epoch
()¶ Trains RL for one epoch.

close
()¶

class
trax.rl.actor_critic.
AdvantageBasedActorCriticAgent
(task, advantage_estimator=<function td_lambda>, advantage_normalization=True, advantage_normalization_epsilon=1e-05, advantage_normalization_factor=1.0, added_policy_slice_length=0, **kwargs)¶ Bases:
trax.rl.actor_critic.ActorCriticAgent
Base class for advantage-based actor-critic algorithms.

__init__
(task, advantage_estimator=<function td_lambda>, advantage_normalization=True, advantage_normalization_epsilon=1e-05, advantage_normalization_factor=1.0, added_policy_slice_length=0, **kwargs)¶ Configures the actor-critic trainer.
Parameters:  task – RLTask instance to use.
 added_policy_slice_length – How much longer trajectory slices should be for policy training than for value training (see ActorCriticAgent).
 **kwargs – Remaining keyword arguments, forwarded to ActorCriticAgent; their descriptions are the same as those listed for ActorCriticAgent.__init__ above.

policy_inputs
(trajectory, values)¶ Create inputs to policy model from a TimeStepBatch and values.

policy_loss_given_log_probs
¶ Policy loss given action log-probabilities.

policy_loss
¶ Policy loss.

policy_metrics
¶

advantage_mean
¶

advantage_std
¶


trax.rl.actor_critic.
every
(n_steps)¶ Returns a function step -> bool that is True every n_steps steps, for use as the *_at functions in various places (e.g. value_sync_at, network_eval_at).
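A minimal sketch of a predicate in the style of every. It assumes steps are counted from 0 and that the returned function is used wherever a step -> bool *_at argument is expected:

```python
def every(n_steps):
    """Returns a predicate that is True on steps 0, n_steps, 2 * n_steps, ..."""
    return lambda step: step % n_steps == 0


# Example: synchronize a target network every 100 steps.
sync_at = every(100)
```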

class
trax.rl.actor_critic.
LoopActorCriticAgent
(task, model_fn, optimizer=<class 'trax.optimizers.adam.Adam'>, policy_lr_schedule=<function multifactor>, policy_n_steps_per_epoch=1000, policy_weight_fn=<function LoopActorCriticAgent.<lambda>>, value_lr_schedule=<function multifactor>, value_n_steps_per_epoch=1000, value_sync_at=<function LoopActorCriticAgent.<lambda>>, advantage_estimator=<function monte_carlo>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, margin=0, n_replay_epochs=1, **kwargs)¶ Bases:
trax.rl.training.Agent
Base class for actor-critic algorithms based on Loop.

on_policy
= None¶

__init__
(task, model_fn, optimizer=<class 'trax.optimizers.adam.Adam'>, policy_lr_schedule=<function multifactor>, policy_n_steps_per_epoch=1000, policy_weight_fn=<function LoopActorCriticAgent.<lambda>>, value_lr_schedule=<function multifactor>, value_n_steps_per_epoch=1000, value_sync_at=<function LoopActorCriticAgent.<lambda>>, advantage_estimator=<function monte_carlo>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, margin=0, n_replay_epochs=1, **kwargs)¶ Initializes LoopActorCriticAgent.
Parameters:  task – RLTask instance to use.
 model_fn – Function mode -> Trax model, building a joint policy and value network.
 optimizer – Optimizer for the policy and value networks.
 policy_lr_schedule – Learning rate schedule for the policy network.
 policy_n_steps_per_epoch – Number of steps to train the policy network for in each epoch.
 policy_weight_fn – Function advantages -> weights for calculating the log probability weights in policy training.
 value_lr_schedule – Learning rate schedule for the value network.
 value_n_steps_per_epoch – Number of steps to train the value network for in each epoch.
 value_sync_at – Function step -> bool indicating when to synchronize the target network with the trained network in value training.
 advantage_estimator – Advantage estimator to use in policy and value training.
 batch_size – Batch size for training the networks.
 network_eval_at – Function step -> bool indicating when to evaluate the networks.
 n_eval_batches – Number of batches to compute the network evaluation metrics on.
 max_slice_length – Maximum length of a trajectory slice to train on.
 margin – Number of timesteps to add at the end of each trajectory slice for better advantage estimation.
 n_replay_epochs – Number of epochs of trajectories to store in the replay buffer.
 **kwargs – Keyword arguments forwarded to Agent.

loop
¶ Loop exposed for testing.

policy
(trajectory, temperature=1.0)¶ Policy function that allows playing with this agent.

train_epoch
()¶ Trains RL for one epoch.


class
trax.rl.actor_critic.
A2C
(task, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
Trains policy and value models using the A2C algorithm.

on_policy
= True¶

__init__
(task, entropy_coeff=0.01, **kwargs)¶ Configures the A2C Trainer.

policy_loss_given_log_probs
¶ Definition of the Advantage Actor Critic (A2C) loss.


class
trax.rl.actor_critic.
PPO
(task, epsilon=0.2, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
The Proximal Policy Optimization Algorithm aka PPO.
Trains policy and value models using the PPO algorithm.

on_policy
= True¶

__init__
(task, epsilon=0.2, entropy_coeff=0.01, **kwargs)¶ Configures the PPO Trainer.

policy_loss_given_log_probs
¶ Definition of the Proximal Policy Optimization loss.


trax.rl.actor_critic.
awr_weights
(advantages, beta, thresholds)¶

trax.rl.actor_critic.
awr_metrics
(beta, thresholds, preprocess_layer=None)¶

trax.rl.actor_critic.
awr_weight_stat
(stat_name, stat_fn, beta, thresholds, preprocess_layer)¶

trax.rl.actor_critic.
AWRLoss
(beta, w_max, thresholds)¶ Definition of the Advantage Weighted Regression (AWR) loss.
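The AWR weighting scheme can be sketched in plain Python as below. Each action's log-probability is weighted by exp(advantage / beta), capped at w_max, following the original AWR formulation. The helper names are hypothetical, and the thresholds argument of awr_weights and AWRLoss is not modeled.

```python
import math


def awr_weights(advantages, beta, w_max):
    """Per-sample AWR weights: exp(advantage / beta), capped at w_max."""
    weights = []
    for adv in advantages:
        # Cap in log space first so very large advantages cannot overflow exp().
        log_w = min(adv / beta, math.log(w_max))
        weights.append(math.exp(log_w))
    return weights


def awr_loss(log_probs, advantages, beta=1.0, w_max=20.0):
    """Negative advantage-weighted log-likelihood of the taken actions."""
    weights = awr_weights(advantages, beta, w_max)
    return -sum(w * lp for w, lp in zip(weights, log_probs)) / len(log_probs)
```

The cap keeps a few samples with very large advantages from dominating the gradient, which is the role w_max plays in the signatures above.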

class
trax.rl.actor_critic.
AWR
(task, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
Trains policy and value models using AWR.

on_policy
= False¶

__init__
(task, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Configures the AWR Trainer.

policy_loss_given_log_probs
¶ Policy loss.


class
trax.rl.actor_critic.
LoopAWR
(task, model_fn, beta=1.0, w_max=20, **kwargs)¶ Bases:
trax.rl.actor_critic.LoopActorCriticAgent
Advantage Weighted Regression.

on_policy
= False¶

__init__
(task, model_fn, beta=1.0, w_max=20, **kwargs)¶ Initializes LoopAWR.
Parameters:  task – RLTask instance to use.
 model_fn – Function mode -> Trax model, building a joint policy and value network.
 beta – Temperature of the exponential advantage weighting exp(advantage / beta).
 w_max – Maximum value of the advantage weights.
 **kwargs – Remaining keyword arguments, forwarded to LoopActorCriticAgent; see LoopActorCriticAgent.__init__ above for their descriptions.


trax.rl.actor_critic.
SamplingAWRLoss
(beta, w_max, thresholds, reweight=False, sampled_all_discrete=False)¶ Definition of the Advantage Weighted Regression (AWR) loss.

class
trax.rl.actor_critic.
SamplingAWR
(task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs)¶ Bases:
trax.rl.actor_critic.AdvantageBasedActorCriticAgent
Trains policy and value models using Sampling AWR.

on_policy
= False¶

__init__
(task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs)¶ Configures the AWR Trainer.

policy_metrics
¶

policy_loss
¶ Policy loss.

policy_batches_stream
()¶ Use the RLTask self._task to create inputs to the policy model.

actor_critic_joint¶
Classes for RL training in Trax.

class
trax.rl.actor_critic_joint.
ActorCriticJointAgent
(task, joint_model=None, optimizer=None, lr_schedule=<function multifactor>, batch_size=64, train_steps_per_epoch=500, supervised_evals_per_epoch=1, supervised_eval_steps=1, n_trajectories_per_epoch=50, max_slice_length=1, normalize_advantages=True, output_dir=None, n_replay_epochs=1)¶ Bases:
trax.rl.training.Agent
Trains a joint policy-and-value model using actor-critic methods.

__init__
(task, joint_model=None, optimizer=None, lr_schedule=<function multifactor>, batch_size=64, train_steps_per_epoch=500, supervised_evals_per_epoch=1, supervised_eval_steps=1, n_trajectories_per_epoch=50, max_slice_length=1, normalize_advantages=True, output_dir=None, n_replay_epochs=1)¶ Configures the joint trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 joint_model – Trax layer, representing the joint policy and value model.
 optimizer – the optimizer to use to train the joint model.
 lr_schedule – learning rate schedule to use to train the joint model.
 batch_size – batch size used to train the joint model.
 train_steps_per_epoch – how long to train the joint model in each RL epoch.
 supervised_evals_per_epoch – number of value trainer evaluations per RL epoch; only affects metric reporting.
 supervised_eval_steps – number of value trainer steps per evaluation; only affects metric reporting.
 n_trajectories_per_epoch – how many trajectories to collect per epoch.
 max_slice_length – the maximum length of trajectory slices to use.
 normalize_advantages – if True, then normalize advantages; currently implemented only in PPO.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 n_replay_epochs – how many of the most recent epochs to include in the replay buffer; values > 1 only make sense for off-policy algorithms.

close
()¶

batches_stream
()¶ Use self.task to create inputs to the policy model.

joint_loss
¶ Joint policy and value loss layer.

advantage_mean
¶ Mean of advantages.

advantage_norm
¶ Norm of advantages.

value_loss
¶ Value loss; currently the same for all A2C-style algorithms.

explained_variance
¶ Explained variance metric.

log_probs_mean
¶ Mean of log_probs aka dist_inputs.

preferred_move
¶ Preferred move: the mean of selected moves.

policy
(trajectory, temperature=1.0)¶ Chooses an action to play after a trajectory.

train_epoch
()¶ Trains RL for one epoch.


class
trax.rl.actor_critic_joint.
PPOJoint
(task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic_joint.ActorCriticJointAgent
The Proximal Policy Optimization Algorithm aka PPO.
Trains policy and value models using the PPO algorithm.

on_policy
= True¶

__init__
(task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Configures the PPO Trainer.

batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

joint_loss
¶ Joint policy and value loss.

probs_ratio_mean
¶ Mean of the probability ratio between the new and old policies.

clip_fraction
¶ Fraction of probability ratios that fall outside [1 - epsilon, 1 + epsilon] and are therefore clipped.

entropy_loss
¶ Entropy layer.

approximate_kl_divergence
¶ Approximate KL divergence.

unclipped_objective_mean
¶

clipped_objective_mean
¶

ppo_objective
¶ PPO objective with local parameters.

ppo_objective_mean
¶ PPO objective mean.


class
trax.rl.actor_critic_joint.
A2CJoint
(task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Bases:
trax.rl.actor_critic_joint.ActorCriticJointAgent
The A2C algorithm.
Trains policy and value models using the A2C algorithm.

on_policy
= True¶

__init__
(task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)¶ Configures the A2C Trainer.

batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

joint_loss
¶ Joint policy and value loss.

entropy_loss
¶ Entropy layer.

approximate_kl_divergence
¶ Approximate KL divergence.

a2c_objective
¶ A2C objective with local parameters.

a2c_objective_mean
¶ A2C objective mean.


class
trax.rl.actor_critic_joint.
AWRJoint
(task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Bases:
trax.rl.actor_critic_joint.ActorCriticJointAgent
Trains a joint policy-and-value model using AWR.

__init__
(task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, thresholds=None, **kwargs)¶ Configures the joint AWR Trainer.

batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

joint_loss
¶ Joint policy and value loss.

advantages¶
RL advantage estimators.

trax.rl.advantages.
mask_discount
(discount, discount_mask)¶ Computes a discount to apply at a given timestep, based on the mask.
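A one-line sketch of mask_discount. It assumes that a mask value of 1 means the discount applies at that timestep and a value of 0 means no discounting (a factor of 1.0). The exact semantics in Trax may differ.

```python
def mask_discount(discount, discount_mask):
    """Per-timestep discount: `discount` where the mask is 1, 1.0 where it is 0.

    Assumed semantics for illustration; works on a list of mask values.
    """
    return [1.0 - m * (1.0 - discount) for m in discount_mask]
```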

trax.rl.advantages.
discounted_returns
(rewards, gammas)¶ Computes discounted returns for a trajectory or a batch of them.
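The usual backward recursion for discounted returns, G_t = r_t + gamma_t * G_{t+1}, is sketched below for a single trajectory; for a batch, it would be applied row by row. The per-step gammas list is an assumption of this sketch: it lets a zero entry cut the return at an episode boundary.

```python
def discounted_returns(rewards, gammas):
    """Returns [G_0, ..., G_{T-1}] with G_t = r_t + gammas[t] * G_{t+1}.

    The return past the end of the trajectory is taken to be 0.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gammas[t] * running
        returns[t] = running
    return returns
```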

trax.rl.advantages.
monte_carlo
(gamma, margin)¶ Calculate Monte Carlo advantage.
We assume the values are a tensor of shape [batch_size, length] and this is the same shape as rewards and returns.
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
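Here is a simplified plain-Python sketch of the kind of estimator monte_carlo returns. It assumes returns are already discounted and ignores dones. It subtracts the value baseline from the empirical return and drops the last margin steps, which matches the [batch_size, length - margin] output shape described above.

```python
def monte_carlo(gamma, margin):
    """Returns an estimator (rewards, returns, values, dones) -> advantages."""
    del gamma  # Returns are assumed to be discounted already.

    def estimator(rewards, returns, values, dones):
        del rewards, dones  # Unused in this simplified sketch.
        advantages = []
        for row_g, row_v in zip(returns, values):
            keep = len(row_g) - margin
            # Advantage = empirical return minus the value baseline.
            advantages.append([g - v for g, v in zip(row_g[:keep], row_v[:keep])])
        return advantages

    return estimator
```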

trax.rl.advantages.
td_k
(gamma, margin)¶ Calculate TD-k advantage.
The k parameter is assumed to be the same as margin.
We calculate advantage(s_i) as:
$$\mathrm{advantage}(s_i) = \gamma^{k}\, \mathrm{value}(s_{i+k}) - \mathrm{value}(s_i) + \sum_{j=0}^{k-1} \gamma^{j} r_{i+j},$$
where k = n_steps, and the last term (discounted_rewards) is the sum of rewards in these steps, discounted by powers of gamma.
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
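The TD-k formula above can be sketched in plain Python as follows, with k = margin. This simplified version ignores dones (episode ends inside a slice) and is not Trax's vectorized implementation.

```python
def td_k(gamma, margin):
    """Returns an estimator (rewards, returns, values, dones) -> advantages."""
    k = margin

    def estimator(rewards, returns, values, dones):
        del returns, dones  # Episode ends are ignored in this sketch.
        advantages = []
        for r, v in zip(rewards, values):
            row = []
            for i in range(len(r) - k):
                # Discounted sum of the k rewards r_i, ..., r_{i+k-1}.
                discounted_rewards = sum(gamma**j * r[i + j] for j in range(k))
                row.append(gamma**k * v[i + k] - v[i] + discounted_rewards)
            advantages.append(row)
        return advantages

    return estimator
```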

trax.rl.advantages.
td_lambda
(gamma, margin, lambda_=0.95)¶ Calculate TD-lambda advantage.
The estimated return is an exponentially weighted average of different TD-k returns.
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
 lambda_ – float, the lambda parameter of TD-lambda
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
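One way to compute TD-lambda advantages is the backward lambda-return recursion G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}), sketched below. This sketch makes two assumptions: it bootstraps the last step with its own value, and it ignores dones. Because the bootstrapped step has zero advantage, margin should be at least 1 so that step is dropped.

```python
def td_lambda(gamma, margin, lambda_=0.95):
    """Returns an estimator (rewards, returns, values, dones) -> advantages."""

    def estimator(rewards, returns, values, dones):
        del returns, dones  # Unused in this simplified sketch.
        advantages = []
        for r, v in zip(rewards, values):
            length = len(r)
            lambda_returns = [0.0] * length
            # Bootstrap: the lambda-return of the last step is its value.
            g = v[length - 1]
            lambda_returns[length - 1] = g
            for t in reversed(range(length - 1)):
                # Mix the one-step bootstrap V(s_{t+1}) with the longer return.
                g = r[t] + gamma * ((1 - lambda_) * v[t + 1] + lambda_ * g)
                lambda_returns[t] = g
            advantages.append([lambda_returns[t] - v[t] for t in range(length - margin)])
        return advantages

    return estimator
```

Setting lambda_ = 0 recovers one-step TD advantages. Setting lambda_ = 1 recovers the discounted reward sum bootstrapped from the last value.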

trax.rl.advantages.
gae
(gamma, margin, lambda_=0.95)¶ Calculate Generalized Advantage Estimation.
Calculates state values by bootstrapping off the following state values (Generalized Advantage Estimation, https://arxiv.org/abs/1506.02438).
Parameters:  gamma – float, gamma parameter for TD from the underlying task
 margin – number of extra steps in the sequence
 lambda_ – float, the lambda parameter of GAE
Returns: Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
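Here is a plain-Python sketch of GAE as defined in the linked paper. One-step TD errors delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) are accumulated backwards with weight gamma * lambda. This sketch ignores dones and treats the last step as bootstrapped; both are assumptions, not necessarily Trax's exact behavior.

```python
def gae(gamma, margin, lambda_=0.95):
    """Returns an estimator (rewards, returns, values, dones) -> advantages."""

    def estimator(rewards, returns, values, dones):
        del returns, dones  # Unused in this simplified sketch.
        advantages = []
        for r, v in zip(rewards, values):
            length = len(r)
            row = [0.0] * length
            running = 0.0  # Advantage of the bootstrapped last step.
            for t in reversed(range(length - 1)):
                delta = r[t] + gamma * v[t + 1] - v[t]  # One-step TD error.
                running = delta + gamma * lambda_ * running
                row[t] = running
            advantages.append(row[: length - margin])
        return advantages

    return estimator
```

Mathematically, GAE advantages equal TD-lambda returns minus the value baseline. The two estimators therefore agree when they bootstrap the same way.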
distributions¶
Probability distributions for RL training in Trax.

class
trax.rl.distributions.
Distribution
¶ Bases:
object
Abstract class for parametrized probability distributions.

n_inputs
¶ Returns the number of inputs to the distribution (i.e. parameters).

sample
(inputs, temperature=1.0)¶ Samples a point from the distribution.
Parameters:  inputs (jnp.ndarray) – Distribution inputs. Shape is subclass-specific. Broadcasts along the first dimensions. For example, in the categorical distribution the parameter shape is (C,), where C is the number of categories. If (B, C) is passed, the object will represent a batch of B categorical distributions with different parameters.
 temperature – sampling temperature; 1.0 is the default, and 0.0 chooses the most probable (preferred) action.
Returns: Sampled point of shape dependent on the subclass and on the shape of inputs.

log_prob
(inputs, point)¶ Retrieves log probability (or log probability density) of a point.
Parameters:  inputs (jnp.ndarray) – Distribution parameters.
 point (jnp.ndarray) – Point from the distribution. Shape should be consistent with inputs.
Returns: Array of log probabilities of points in the distribution.
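For a categorical distribution, sample and log_prob can be sketched in plain Python as below. This sketch makes two assumptions about the parametrization: the parameters are unnormalized logits, and temperature is applied by dividing the logits before the softmax. A temperature of 0.0 returns the most probable category, as described above.

```python
import math
import random


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sample(logits, temperature=1.0, rng=random):
    """Samples a category index; temperature 0.0 returns the argmax."""
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = [math.exp(lp) for lp in log_softmax([x / temperature for x in logits])]
    return rng.choices(range(len(logits)), weights=probs)[0]


def log_prob(logits, point):
    """Log probability of category `point` under softmax(logits)."""
    return log_softmax(logits)[point]
```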

LogProb
()¶ Builds a log probability layer for this distribution.


trax.rl.distributions.
create_distribution
(space)¶ Creates a Distribution for the given Gym space.

trax.rl.distributions.
LogLoss
(distribution, **unused_kwargs)¶ Builds a log loss layer for a Distribution.
normalization¶
Normalization helpers.

trax.rl.normalization.
running_mean_init
(shape, fill_value=0)¶

trax.rl.normalization.
running_mean_update
(x, state)¶

trax.rl.normalization.
running_mean_get_mean
(state)¶

trax.rl.normalization.
running_mean_get_count
(state)¶

trax.rl.normalization.
running_mean_and_variance_init
(shape)¶

trax.rl.normalization.
running_mean_and_variance_update
(x, state)¶

trax.rl.normalization.
running_mean_and_variance_get_mean
(state)¶

trax.rl.normalization.
running_mean_and_variance_get_count
(state)¶

trax.rl.normalization.
running_mean_and_variance_get_variance
(state)¶
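The running_mean_* and running_mean_and_variance_* functions above follow an init / update / get pattern over an explicit state. A scalar sketch of that pattern is shown below. It merges each batch's statistics into the state with the standard parallel variance update. The state layout (count, mean, sum of squared deviations) and the function names are hypothetical, not Trax's.

```python
def init_state():
    # Hypothetical state layout: (count, mean, sum of squared deviations).
    return (0, 0.0, 0.0)


def update(batch, state):
    """Merges a batch of scalars into the state (parallel variance update)."""
    count, mean, m2 = state
    n = len(batch)
    if n == 0:
        return state
    batch_mean = sum(batch) / n
    batch_m2 = sum((x - batch_mean) ** 2 for x in batch)
    total = count + n
    delta = batch_mean - mean
    new_mean = mean + delta * n / total
    new_m2 = m2 + batch_m2 + delta**2 * count * n / total
    return (total, new_mean, new_m2)


def get_count(state):
    return state[0]


def get_mean(state):
    return state[1]


def get_variance(state):
    count, _, m2 = state
    # Population variance; defined as 0.0 before any data is seen.
    return m2 / count if count > 0 else 0.0
```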

trax.rl.normalization.
LayerNormSquash
(mode, width=128)¶ Dense-LayerNorm-Tanh normalizer inspired by ACME.
rl_layers¶
A number of RL functions intended to be later wrapped as Trax layers.
Wrapping happens with help of the function tl.Fn.

trax.rl.rl_layers.
ValueLoss
(values, returns, value_loss_coeff)¶ Definition of the loss of the value function.

trax.rl.rl_layers.
ExplainedVariance
(values, returns)¶ Definition of explained variance, an approach from OpenAI baselines.
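Explained variance, as used in OpenAI Baselines, is 1 - Var[returns - values] / Var[returns]. A value of 1.0 means the value function predicts the returns perfectly, and 0.0 means it does no better than a constant. The sketch below is in plain Python. Returning NaN for constant returns follows the Baselines convention; Trax's handling may differ.

```python
def explained_variance(values, returns):
    """1 - Var[returns - values] / Var[returns] over a flat list of steps."""

    def variance(xs):
        mean = sum(xs) / len(xs)
        return sum((x - mean) ** 2 for x in xs) / len(xs)

    var_returns = variance(returns)
    if var_returns == 0.0:
        return float("nan")  # Undefined when the returns are constant.
    residuals = [g - v for g, v in zip(returns, values)]
    return 1.0 - variance(residuals) / var_returns
```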

trax.rl.rl_layers.
PreferredMove
(dist_inputs, sample)¶ Definition of the preferred move.

trax.rl.rl_layers.
NewLogProbs
(dist_inputs, actions, log_prob_fun)¶ Given distribution inputs and actions, calculates log probabilities.

trax.rl.rl_layers.
EntropyLoss
(dist_inputs, distribution, coeff)¶ Definition of the Entropy Layer.

trax.rl.rl_layers.
ProbsRatio
(dist_inputs, actions, old_log_probs, log_prob_fun)¶ Probability Ratio from the PPO algorithm.

trax.rl.rl_layers.
ApproximateKLDivergence
(dist_inputs, actions, old_log_probs, log_prob_fun)¶ Approximate KL divergence between the old and new policies, from the PPO algorithm.

trax.rl.rl_layers.
UnclippedObjective
(probs_ratio, advantages)¶ Unclipped Objective from the PPO algorithm.

trax.rl.rl_layers.
ClippedObjective
(probs_ratio, advantages, epsilon)¶ Clipped Objective from the PPO algorithm.
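The probability ratio and the unclipped and clipped objectives combine into the standard PPO surrogate, sketched below in plain Python from per-action log-probabilities. The KL estimator is the mean of old minus new log-probabilities over actions sampled from the old policy. That is one common choice, and it is an assumption about what ApproximateKLDivergence computes.

```python
import math


def ppo_clipped_objective(new_log_probs, old_log_probs, advantages, epsilon=0.2):
    """Mean PPO clipped surrogate objective (to be maximized)."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        ratio = math.exp(new_lp - old_lp)  # pi_new(a|s) / pi_old(a|s)
        unclipped = ratio * adv
        clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon) * adv
        # Pessimistic bound: take the smaller of the two surrogates.
        total += min(unclipped, clipped)
    return total / len(advantages)


def approximate_kl(new_log_probs, old_log_probs):
    """Sample estimate of KL(pi_old || pi_new) for actions drawn from pi_old."""
    diffs = [o - n for n, o in zip(new_log_probs, old_log_probs)]
    return sum(diffs) / len(diffs)
```

Taking the minimum removes any incentive to push the ratio beyond 1 + epsilon for positive advantages, or below 1 - epsilon for negative ones.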

trax.rl.rl_layers.
PPOObjective
(dist_inputs, values, returns, dones, rewards, actions, old_log_probs, log_prob_fun, epsilon, normalize_advantages)¶ PPO Objective.

trax.rl.rl_layers.
A2CObjective
(dist_inputs, values, returns, dones, rewards, actions, mask, log_prob_fun, normalize_advantages)¶ Definition of the Advantage Actor Critic (A2C) loss.
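Here is a plain-Python sketch of a masked A2C policy objective. Unlike A2CObjective, it takes precomputed advantages instead of deriving them from values, returns, dones and rewards. Its reading of normalize_advantages (mean/std normalization over unmasked steps with a small epsilon) is an assumption.

```python
import math


def a2c_objective(log_probs, advantages, mask, normalize_advantages=True, eps=1e-5):
    """Masked mean of log pi(a|s) * advantage (maximized during training)."""
    if normalize_advantages:
        # Normalize using statistics of the unmasked steps only.
        valid = [a for a, m in zip(advantages, mask) if m]
        mean = sum(valid) / len(valid)
        std = math.sqrt(sum((a - mean) ** 2 for a in valid) / len(valid))
        advantages = [(a - mean) / (std + eps) for a in advantages]
    total = sum(lp * a * m for lp, a, m in zip(log_probs, advantages, mask))
    return total / sum(mask)
```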
serialization_utils¶
Utilities for serializing trajectories into discrete sequences.

trax.rl.serialization_utils.
Serialize
(serializer)¶ Layer that serializes a given array.

trax.rl.serialization_utils.
Interleave
()¶ Layer that interleaves and flattens two serialized sequences.
The first sequence can be one element longer than the second one. This is so we can interleave sequences of observations and actions when there’s 1 extra observation at the end.
For serialized sequences [[x_1_1, …, x_1_R1], …, [x_L1_1, …, x_L1_R1]] and [[y_1_1, …, y_1_R2], …, [y_L2_1, …, y_L2_R2]], where L1 = L2 + 1, the result is [x_1_1, …, x_1_R1, y_1_1, …, y_1_R2, …, x_L2_1, …, x_L2_R1, y_L2_1, …, y_L2_R2, x_L1_1, …, x_L1_R1] (batch dimension omitted for clarity).
The layer inputs are a sequence pair of shapes (B, L1, R1) and (B, L2, R2), where B is batch size, L* is the length of the sequence and R* is the representation length of each element in the sequence.
Returns: Layer that outputs the interleaved sequence, of shape (B, L1 * R1 + L2 * R2).
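The interleaving described above can be sketched on plain nested lists, with the batch dimension omitted. The deinterleave helper assumes fixed representation lengths x_size and y_size for the two sequences. Both helpers are illustrations, not the Trax layers.

```python
def interleave(xs, ys):
    """Interleaves per-step representations (batch dimension omitted).

    xs: L1 lists of R1 symbols; ys: L2 lists of R2 symbols; L1 is L2 or L2 + 1.
    """
    assert len(xs) in (len(ys), len(ys) + 1)
    out = []
    for x, y in zip(xs, ys):
        out.extend(x)
        out.extend(y)
    if len(xs) > len(ys):
        out.extend(xs[-1])  # The extra trailing observation.
    return out


def deinterleave(seq, x_size, y_size):
    """Inverse of interleave for fixed representation lengths."""
    xs, ys = [], []
    i = 0
    while i < len(seq):
        xs.append(seq[i : i + x_size])
        i += x_size
        if i < len(seq):
            ys.append(seq[i : i + y_size])
            i += y_size
    return xs, ys
```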

trax.rl.serialization_utils.
Deinterleave
(x_size, y_size)¶ Layer that does the inverse of Interleave.

trax.rl.serialization_utils.
RepresentationMask
(serializer)¶ Upsamples a mask to cover the serialized representation.

trax.rl.serialization_utils.
SignificanceWeights
(serializer, decay)¶ Multiplies a binary mask with a symbol significance mask.
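The following sketch combines RepresentationMask and SignificanceWeights. A per-element mask is upsampled to every symbol of the representation, and each symbol is scaled by decay ** significance. Less significant symbols (higher significance index) therefore get smaller weights. The function is hypothetical, and the exact weighting formula is an assumption.

```python
def significance_weights(mask, significance_map, decay):
    """Expands a per-element mask to per-symbol weights.

    mask: L values (0 or 1), one per sequence element.
    significance_map: R integers, 0 for the most significant symbol.
    Returns L lists of R weights, mask[i] * decay ** significance_map[j].
    """
    return [[m * decay**s for s in significance_map] for m in mask]
```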

class
trax.rl.serialization_utils.
SerializedModel
(seq_model, observation_serializer, action_serializer, significance_decay, mode='train')¶ Bases:
trax.layers.combinators.Serial
Wraps a world model in serialization machinery for training.
The resulting model takes as input the observation and action sequences, serializes them and interleaves into one sequence, which is fed into a given autoregressive model. The resulting logit sequence is deinterleaved into observations and actions, and the observation logits are returned together with computed symbol significance weights.
The model has a signature (obs, act, obs, mask) -> (obs_logits, obs_repr, weights), where obs are observations (the second occurrence is the target), act are actions, mask is the observation mask, obs_logits are logits of the output observation representation, obs_repr is the target observation representation and weights are the target weights.

__init__
(seq_model, observation_serializer, action_serializer, significance_decay, mode='train')¶ Initializes SerializedModel.
Parameters:  seq_model – Trax autoregressive model taking as input a sequence of symbols and outputting a sequence of symbol logits.
 observation_serializer – Serializer to use for observations.
 action_serializer – Serializer to use for actions.
 significance_decay – Float from (0, 1) for exponential weighting of symbols in the representation.
 mode – ‘train’ or ‘eval’.

observation_serializer
¶

action_serializer
¶

make_predict_model
()¶ Returns a predict-mode model of the same architecture.

seq_model_weights
¶ Extracts the weights of the underlying sequence model.

seq_model_state
¶ Extracts the state of the underlying sequence model.


trax.rl.serialization_utils.
TimeSeriesModel
(seq_model, low=0.0, high=1.0, precision=2, vocab_size=64, significance_decay=0.7, mode='train')¶ Simplified constructor for SerializedModel, for time series prediction.

trax.rl.serialization_utils.
RawPolicy
(seq_model, n_controls, n_actions)¶ Wraps a sequence model in a policy interface.
The resulting model takes as input observation and action sequences, but only uses the observations. Adds output heads for action logits and value predictions.
Parameters:  seq_model – Trax sequence model taking as input and outputting a sequence of continuous vectors.
 n_controls – Number of controls.
 n_actions – Number of action categories in each control.
Returns: A model of signature (obs, act) -> (act_logits, values), with shapes obs: (batch_size, length + 1, obs_depth), act: (batch_size, length, n_controls), act_logits: (batch_size, length, n_controls, n_actions), values: (batch_size, length).

trax.rl.serialization_utils.
substitute_inner_policy_raw
(raw_policy, inner_policy)¶ Substitutes the weights/state of the inner model in a RawPolicy.

trax.rl.serialization_utils.
SerializedPolicy
(seq_model, n_controls, n_actions, observation_serializer, action_serializer)¶ Wraps a policy in serialization machinery for training.
The resulting model takes as input observation and action sequences and serializes them into one sequence, similar to SerializedModel, before passing it to the given sequence model. Adds output heads for action logits and value predictions.
Parameters:  seq_model – Trax sequence model taking as input a sequence of symbols and outputting a sequence of continuous vectors.
 n_controls – Number of controls.
 n_actions – Number of action categories in each control.
 observation_serializer – Serializer to use for observations.
 action_serializer – Serializer to use for actions.
Returns: A model of signature (obs, act) -> (act_logits, values), same as in RawPolicy.

trax.rl.serialization_utils.
substitute_inner_policy_serialized
(serialized_policy, inner_policy)¶ Substitutes the weights/state of the inner model in a SerializedPolicy.

trax.rl.serialization_utils.
analyze_action_space
(action_space)¶ Returns the number of controls and actions for an action space.

trax.rl.serialization_utils.
wrap_policy
(seq_model, observation_space, action_space, vocab_size)¶ Wraps a sequence model in either RawPolicy or SerializedPolicy.
Parameters:  seq_model – Trax sequence model.
 observation_space – Gym observation space.
 action_space – Gym action space.
 vocab_size – Either the number of symbols for a serialized policy, or None.
Returns: RawPolicy if vocab_size is None, else SerializedPolicy.

trax.rl.serialization_utils.
substitute_inner_policy
(wrapped_policy, inner_policy, vocab_size)¶ Substitutes the inner weights/state in a {Raw,Serialized}Policy.
Parameters:  wrapped_policy (pytree) – Weights or state of a wrapped policy.
 inner_policy (pytree) – Weights or state of an inner policy.
 vocab_size (int or None) – Vocabulary size of a serialized policy, or None in case of a raw policy.
Returns:  New weights or state of wrapped_policy, with the inner weights/state copied from inner_policy.
space_serializer¶
Serialization of elements of Gym spaces into discrete sequences.

class
trax.rl.space_serializer.
SpaceSerializer
(space, vocab_size)¶ Bases:
object
Base class for Gym space serializers.
 Attrs:
 space_type: (type) Gym space class that this SpaceSerializer corresponds to. Should be defined in subclasses.
 representation_length: (int) Number of symbols in the representation of every element of the space.
 significance_map: (np.ndarray) Integer array of the same size as the discrete representation, where elements describe the significance of symbols, e.g. in fixed-precision encoding. 0 is the most significant symbol, 1 the second most significant, etc.
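To illustrate significance_map, here is a sketch of a fixed-precision encoding. It assumes a float in [low, high] is written as precision base-vocab_size digits, most significant first; the parameter names are borrowed from TimeSeriesModel. The encoder, decoder and significance_map helper are hypothetical, not the Trax serializers.

```python
def encode_fixed_precision(x, low, high, precision, vocab_size):
    """Encodes a float from [low, high] as `precision` base-vocab_size digits."""
    n_levels = vocab_size**precision
    # Normalize to [0, 1], clipping so out-of-range values saturate.
    frac = min(max((x - low) / (high - low), 0.0), 1.0)
    level = min(int(frac * n_levels), n_levels - 1)
    digits = []
    for _ in range(precision):
        level, digit = divmod(level, vocab_size)
        digits.append(digit)
    return digits[::-1]  # Most significant symbol first.


def decode_fixed_precision(digits, low, high, vocab_size):
    """Maps digits back to the lower edge of the encoded interval."""
    level = 0
    for digit in digits:
        level = level * vocab_size + digit
    return low + (high - low) * level / vocab_size ** len(digits)


def significance_map(precision):
    # Symbol j has significance j: 0 is the most significant symbol.
    return list(range(precision))
```

In this encoding, an error in symbol 0 moves the decoded value by a factor of vocab_size more than an error in symbol 1. That is the kind of ordering significance_map is meant to capture.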

space_type
= None¶

representation_length
= None¶

significance_map
= None¶

__init__
(space, vocab_size)¶ Creates a SpaceSerializer.
Subclasses should retain the signature.
Parameters:  space – (gym.Space) Gym space of type self.space_type.
 vocab_size – (int) Number of symbols in the vocabulary.

vocab_size
¶

serialize
(data)¶ Serializes a batch of space elements into discrete sequences.
Should be defined in subclasses.
Parameters: data – A batch of batch_size elements of the Gym space to be serialized. Returns: int32 array of shape (batch_size, self.representation_length).

deserialize
(representation)¶ Deserializes a batch of discrete sequences into space elements.
Should be defined in subclasses.
Parameters: representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized. Returns: A batch of batch_size deserialized elements of the Gym space.

trax.rl.space_serializer.
create
(space, vocab_size)¶ Creates a SpaceSerializer for the given Gym space.

class
trax.rl.space_serializer.
DiscreteSpaceSerializer
(space, vocab_size)¶ Bases:
trax.rl.space_serializer.SpaceSerializer
Serializer for gym.spaces.Discrete.
Assumes that the size of the space fits in the number of symbols.

space_type
¶ The Gym space class this serializer handles (gym.spaces.Discrete).

representation_length
= 1¶

__init__
(space, vocab_size)¶ Creates a SpaceSerializer.
Subclasses should retain the signature.
Parameters:  space – (gym.Space) Gym space of type self.space_type.
 vocab_size – (int) Number of symbols in the vocabulary.

serialize
(data)¶ Serializes a batch of space elements into discrete sequences.
Parameters: data – A batch of batch_size elements of the Gym space to be serialized. Returns: int32 array of shape (batch_size, self.representation_length).

deserialize
(representation)¶ Deserializes a batch of discrete sequences into space elements.
Parameters: representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized. Returns: A batch of batch_size deserialized elements of the Gym space.

significance_map
¶


class
trax.rl.space_serializer.
MultiDiscreteSpaceSerializer
(space, vocab_size)¶ Bases:
trax.rl.space_serializer.SpaceSerializer
Serializer for gym.spaces.MultiDiscrete.
Assumes that the number of categories in each dimension fits in the number of symbols.
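The natural extension to MultiDiscrete uses one symbol per dimension, so representation_length equals the number of dimensions. The sketch below assumes this layout. ToyMultiDiscreteSerializer is hypothetical, and the all-zero significance map (every dimension treated as equally significant) is an assumption.

```python
from typing import List, Sequence


class ToyMultiDiscreteSerializer:
  """Hypothetical stand-in for MultiDiscreteSpaceSerializer."""

  def __init__(self, nvec: Sequence[int], vocab_size: int):
    # Every dimension's number of categories must fit in the vocabulary.
    if max(nvec) > vocab_size:
      raise ValueError('Some dimension has more categories than symbols.')
    self.nvec = list(nvec)
    # One symbol per dimension.
    self.representation_length = len(self.nvec)
    # Assumption: dimensions are independent, so all are equally significant.
    self.significance_map = [0] * len(self.nvec)

  def serialize(self, data: Sequence[Sequence[int]]) -> List[List[int]]:
    for element in data:
      if not all(0 <= v < n for v, n in zip(element, self.nvec)):
        raise ValueError(f'Element {element} is outside the space.')
    return [[int(v) for v in element] for element in data]

  def deserialize(self, representation: Sequence[Sequence[int]]) -> List[List[int]]:
    return [[int(v) for v in row] for row in representation]


serializer = ToyMultiDiscreteSerializer(nvec=[3, 5], vocab_size=8)
print(serializer.serialize([[2, 4], [0, 1]]))  # [[2, 4], [0, 1]]
```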

space_type
¶ The Gym space class this serializer handles: gym.spaces.MultiDiscrete.

__init__
(space, vocab_size)¶ Creates a SpaceSerializer.
Subclasses should retain the signature.
Parameters:  space – (gym.Space) Gym space of type self.space_type.
 vocab_size – (int) Number of symbols in the vocabulary.

serialize
(data)¶ Serializes a batch of space elements into discrete sequences.
Parameters: data – A batch of batch_size elements of the Gym space to be serialized. Returns: int32 array of shape (batch_size, self.representation_length).

deserialize
(representation)¶ Deserializes a batch of discrete sequences into space elements.
Parameters: representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized. Returns: A batch of batch_size deserialized elements of the Gym space.

representation_length
¶

significance_map
¶

task¶
Classes for defining RL tasks in Trax.

class
trax.rl.task.
TimeStepBatch
(observation, action, reward, done, mask, dist_inputs, env_info, return_)¶ Bases:
tuple

action
¶ Alias for field number 1

dist_inputs
¶ Alias for field number 5

done
¶ Alias for field number 3

env_info
¶ Alias for field number 6

mask
¶ Alias for field number 4

observation
¶ Alias for field number 0

return_
¶ Alias for field number 7

reward
¶ Alias for field number 2


class
trax.rl.task.
EnvInfo
(control_mask, discount_mask)¶ Bases:
tuple

control_mask
¶ Alias for field number 0

discount_mask
¶ Alias for field number 1
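Both classes are namedtuples, since each field is documented as an alias for a positional index. A batch can therefore be read either by field name or by position. The sketch below rebuilds equivalent namedtuples with collections.namedtuple, using the field order from the signatures above, so it runs without trax. The field values are made up.

```python
import collections

# Field order copied from the TimeStepBatch and EnvInfo signatures above.
TimeStepBatch = collections.namedtuple('TimeStepBatch', [
    'observation', 'action', 'reward', 'done', 'mask', 'dist_inputs',
    'env_info', 'return_'])
EnvInfo = collections.namedtuple('EnvInfo', ['control_mask', 'discount_mask'])

step = TimeStepBatch(
    observation=[0.0, 1.0], action=1, reward=0.5, done=False, mask=1,
    dist_inputs=None, env_info=EnvInfo(control_mask=1, discount_mask=1),
    return_=0.5)

# Named and positional access refer to the same fields.
print(step.reward == step[2], step.env_info.discount_mask == step[6][1])  # True True
```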


class
trax.rl.task.
Trajectory
(observation)¶ Bases:
object
A trajectory of interactions with a RL environment.
Trajectories are created when interacting with an RL environment. They can be extended and sliced, and once completed, they allow returns to be recalculated.

__init__
(observation)¶ Creates a trajectory that starts from the given observation.

suffix
(length)¶ Returns a Trajectory containing only the last length observations of this trajectory.

timesteps
¶

total_return
¶ Sum of all rewards in this trajectory.

last_observation
¶ Returns the last observation in this trajectory.

done
¶ Returns whether the trajectory is finished.

extend
(new_observation, mask=1, **kwargs)¶ Records taking an action in the last state, receiving a reward and moving to new_observation.

calculate_returns
(gamma)¶ Calculates discounted returns using the discount factor gamma.
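Discounted returns satisfy the recursion G_t = r_t + gamma * G_{t+1}, so a single backward pass over the rewards computes them all. The sketch below shows only that recursion on a plain list of rewards. It is not the trax implementation and ignores details such as masks or the discount_mask in EnvInfo.

```python
from typing import List, Sequence


def discounted_returns(rewards: Sequence[float], gamma: float) -> List[float]:
  """Computes G_t = r_t + gamma * G_{t+1}, with G_T = 0 after the last reward."""
  returns = [0.0] * len(rewards)
  running = 0.0
  # Walk backwards so each G_{t+1} is available when computing G_t.
  for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    returns[t] = running
  return returns


print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```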

to_np
(margin=1, timestep_to_np=None)¶ Creates a tuple of numpy arrays from a given trajectory.
Parameters:  margin (int) – Number of dummy timesteps past the trajectory end to include. By default we include 1, which contains the last observation.
 timestep_to_np (callable or None) – Optional function TimeStepBatch[Any] -> TimeStepBatch[np.array], converting the timestep data into numpy arrays.
Returns: TimeStepBatch, where all fields have shape (len(self) + margin - 1, …).


trax.rl.task.
play
(env, policy, dm_suite=False, max_steps=None, last_observation=None)¶ Play an episode in env taking actions according to the given policy.
The environment is first reset, and from then on the game proceeds. At each step, the policy chooses an action and the environment moves forward. The resulting Trajectory is returned when the episode finishes, either because env returns done or because max_steps is reached.
Parameters:  env – the environment to play in, conforming to gym.Env or DeepMind suite interfaces.
 policy – a function taking a Trajectory and returning a pair consisting of an action (int or float) and the confidence in that action (float, defined as the log of the probability of taking that action).
 dm_suite – whether we are using the DeepMind suite or the gym interface.
 max_steps – maximum number of steps to play.
 last_observation – last observation from a previous trajectory slice, used to begin a new one. Controls whether we reset the environment at the beginning: if None, resets the env and starts the slice from the observation returned by reset().
Returns: a completed trajectory slice that was just played.
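The episode loop described above can be sketched with a gym-style environment. In this sketch, reset() returns an observation and step(action) returns (observation, reward, done, info). CountdownEnv and the tuple-based trajectory are hypothetical simplifications. The real play builds a trax Trajectory, passes it to the policy, and also supports the DeepMind suite interface.

```python
import math
from typing import List, Optional, Sequence, Tuple


class CountdownEnv:
  """Hypothetical toy env: the episode ends once the counter reaches 0."""

  def __init__(self, start: int = 3):
    self._start = start
    self._state = start

  def reset(self) -> int:
    self._state = self._start
    return self._state

  def step(self, action: int):
    self._state -= action
    done = self._state <= 0
    reward = 1.0 if done else 0.0
    return self._state, reward, done, {}


def play(env, policy, max_steps: Optional[int] = None,
         last_observation=None) -> List[Tuple]:
  # Reset only when we are not continuing from a previous slice.
  obs = env.reset() if last_observation is None else last_observation
  steps = []  # (observation, action, log_prob, reward, done) tuples.
  done = False
  while not done and (max_steps is None or len(steps) < max_steps):
    # The policy sees all observations so far and returns (action, log_prob).
    action, log_prob = policy([s[0] for s in steps] + [obs])
    new_obs, reward, done, _ = env.step(action)
    steps.append((obs, action, log_prob, reward, done))
    obs = new_obs
  return steps


def always_one(observations: Sequence[int]) -> Tuple[int, float]:
  return 1, math.log(1.0)  # Deterministic policy: probability 1, log-prob 0.


trajectory = play(CountdownEnv(start=3), always_one)
print(len(trajectory), trajectory[-1][3])  # 3 1.0
```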
training¶
Classes for RL training in Trax.

class
trax.rl.training.
Agent
(task: trax.rl.task.RLTask, n_trajectories_per_epoch=None, n_interactions_per_epoch=None, n_eval_episodes=0, eval_steps=None, eval_temperatures=(0.0, ), only_eval=False, output_dir=None, timestep_to_np=None)¶ Bases:
object
Abstract class for RL agents, presenting the required API.

__init__
(task: trax.rl.task.RLTask, n_trajectories_per_epoch=None, n_interactions_per_epoch=None, n_eval_episodes=0, eval_steps=None, eval_temperatures=(0.0, ), only_eval=False, output_dir=None, timestep_to_np=None)¶ Configures the Agent.
Subclasses may take many more arguments, which are configured using defaults and gin; task and output_dir, however, are passed explicitly.
Parameters:  task – RLTask instance, which defines the environment to train on.
 n_trajectories_per_epoch – How many new trajectories to collect in each epoch.
 n_interactions_per_epoch – How many interactions to collect in each epoch.
 n_eval_episodes – Number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only.
 eval_steps – an optional list of max_steps to use for evaluation (defaults to task.max_steps).
 eval_temperatures – we always train with temperature 1 and evaluate with the temperatures specified in the eval_temperatures list (defaults to (0.0,)).
 only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
 output_dir – Path telling where to save outputs such as checkpoints.
 timestep_to_np – Timestep-to-numpy function to override in the task.

current_epoch
¶ Returns the current epoch number in this training session.

task
¶ Returns the task.

avg_returns
¶

save_gin
(summary_writer=None)¶

save_to_file
(file_name='rl.pkl', task_file_name='trajectories.pkl')¶ Save current epoch number and average returns to file.

init_from_file
(file_name='rl.pkl', task_file_name='trajectories.pkl')¶ Initialize epoch number and average returns from file.

policy
(trajectory, temperature=1.0)¶ Policy function used to play with this trainer.
Parameters:  trajectory – an instance of trax.rl.task.Trajectory
 temperature – temperature used to sample from the policy (default=1.0)
Returns: a pair (action, dist_inputs) where action is the action taken and dist_inputs is the parameters of the policy distribution, which will later be used for training.
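Temperature-based sampling for a discrete action space can be sketched as follows. The sketch assumes a categorical policy whose dist_inputs are unnormalized logits. That assumption and the function name sample_action are illustrative and not part of the Agent API. Temperature 0 is treated as greedy action selection, matching its use for evaluation above.

```python
import math
import random
from typing import List, Sequence, Tuple


def sample_action(logits: Sequence[float], temperature: float,
                  rng: random.Random) -> Tuple[int, List[float]]:
  """Returns (action, dist_inputs) for a categorical policy given its logits."""
  dist_inputs = list(logits)
  if temperature == 0.0:
    # Temperature 0 means acting greedily: pick the most likely action.
    action = max(range(len(logits)), key=lambda i: logits[i])
    return action, dist_inputs
  # Dividing logits by the temperature flattens (t > 1) or sharpens (t < 1)
  # the softmax distribution; subtracting the max keeps exp() stable.
  scaled = [logit / temperature for logit in logits]
  top = max(scaled)
  weights = [math.exp(s - top) for s in scaled]
  action = rng.choices(range(len(logits)), weights=weights, k=1)[0]
  return action, dist_inputs


rng = random.Random(0)
print(sample_action([0.1, 2.0, -1.0], temperature=0.0, rng=rng))  # (1, [0.1, 2.0, -1.0])
```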

train_epoch
()¶ Trains this Agent for one epoch – main RL logic goes here.

run
(n_epochs=1, n_epochs_is_total_epochs=False)¶ Runs this loop for n epochs.
Parameters:  n_epochs – Stop training after completing n steps.
 n_epochs_is_total_epochs – if True, consider n_epochs as the total number of epochs to train, including previously trained ones

close
()¶


class
trax.rl.training.
PolicyAgent
(task, policy_model=None, policy_optimizer=None, policy_lr_schedule=<function multifactor>, policy_batch_size=64, policy_train_steps_per_epoch=500, policy_evals_per_epoch=1, policy_eval_steps=1, n_eval_episodes=0, only_eval=False, max_slice_length=1, output_dir=None, **kwargs)¶ Bases:
trax.rl.training.Agent
Agent that uses a deep learning model for policy.
Many deep RL methods, such as policy gradient (REINFORCE) or actor-critic, fall into this category, so many classes subclass this one. Some methods, however, only have a value or Q function; these are handled differently.

__init__
(task, policy_model=None, policy_optimizer=None, policy_lr_schedule=<function multifactor>, policy_batch_size=64, policy_train_steps_per_epoch=500, policy_evals_per_epoch=1, policy_eval_steps=1, n_eval_episodes=0, only_eval=False, max_slice_length=1, output_dir=None, **kwargs)¶ Configures the policy trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 policy_model – Trax layer, representing the policy model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 policy_optimizer – the optimizer to use to train the policy model.
 policy_lr_schedule – learning rate schedule to use to train the policy.
 policy_batch_size – batch size used to train the policy model.
 policy_train_steps_per_epoch – how long to train policy in each RL epoch.
 policy_evals_per_epoch – number of policy trainer evaluations per RL epoch; only affects metric reporting.
 policy_eval_steps – number of policy trainer steps per evaluation; only affects metric reporting.
 n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only.
 only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
 max_slice_length – the maximum length of trajectory slices to use.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 **kwargs – arguments for the superclass Agent.

policy_loss
¶ Policy loss.

policy_metrics
¶

policy_batches_stream
()¶ Use self.task to create inputs to the policy model.

policy
(trajectory, temperature=1.0)¶ Chooses an action to play after a trajectory.

train_epoch
()¶ Trains RL for one epoch.

close
()¶


trax.rl.training.
remaining_evals
(cur_step, epoch, train_steps_per_epoch, evals_per_epoch)¶ Helper function to calculate remaining evaluations for a trainer.
Parameters:  cur_step – current step of the supervised trainer
 epoch – current epoch of the RL trainer
 train_steps_per_epoch – supervised trainer steps per RL epoch
 evals_per_epoch – supervised trainer evals per RL epoch
Returns: number of remaining evals to do this epoch
Raises: ValueError if the provided numbers indicate a step mismatch
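The bookkeeping behind remaining_evals can be reconstructed from its parameters. The idea is to count the supervised steps already taken in the current RL epoch and subtract the evaluations they cover. The sketch below is a hypothetical re-derivation. It assumes RL epochs are numbered from 1 and that each runs exactly train_steps_per_epoch steps. The real function's arithmetic and error conditions may differ.

```python
def remaining_evals_sketch(cur_step: int, epoch: int, train_steps_per_epoch: int,
                           evals_per_epoch: int) -> int:
  """Hypothetical re-derivation; assumes RL epochs are numbered from 1."""
  steps_per_eval = train_steps_per_epoch // evals_per_epoch
  if steps_per_eval == 0:
    raise ValueError('Fewer train steps than evals per epoch.')
  # Supervised steps already taken during the current RL epoch.
  done_steps = cur_step - (epoch - 1) * train_steps_per_epoch
  if not 0 <= done_steps <= train_steps_per_epoch:
    raise ValueError(f'Step mismatch: {done_steps} steps done in epoch {epoch}.')
  if done_steps % steps_per_eval != 0:
    raise ValueError('Step mismatch: not at an evaluation boundary.')
  return evals_per_epoch - done_steps // steps_per_eval


print(remaining_evals_sketch(cur_step=0, epoch=1, train_steps_per_epoch=500,
                             evals_per_epoch=5))  # 5
print(remaining_evals_sketch(cur_step=700, epoch=2, train_steps_per_epoch=500,
                             evals_per_epoch=5))  # 3
```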

class
trax.rl.training.
LoopPolicyAgent
(task, model_fn, value_fn, weight_fn, n_replay_epochs, n_train_steps_per_epoch, advantage_normalization, optimizer=<class 'trax.optimizers.adam.Adam'>, lr_schedule=<function multifactor>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, trajectory_stream_preprocessing_fn=None, **kwargs)¶ Bases:
trax.rl.training.Agent
Base class for policy-only Agents based on Loop.

__init__
(task, model_fn, value_fn, weight_fn, n_replay_epochs, n_train_steps_per_epoch, advantage_normalization, optimizer=<class 'trax.optimizers.adam.Adam'>, lr_schedule=<function multifactor>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, trajectory_stream_preprocessing_fn=None, **kwargs)¶ Initializes LoopPolicyAgent.
Parameters:  task – Instance of trax.rl.task.RLTask.
 model_fn – Function (policy_distribution, mode) -> policy_model.
 value_fn – Function TimeStepBatch -> array (batch_size, seq_len) calculating the baseline for advantage calculation.
 weight_fn – Function float -> float to apply to advantages when calculating policy loss.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
 n_train_steps_per_epoch – Number of steps to train the policy network for in each epoch.
 advantage_normalization – Whether to normalize the advantages before passing them to weight_fn.
 optimizer – Optimizer for network training.
 lr_schedule – Learning rate schedule for network training.
 batch_size – Batch size for network training.
 network_eval_at – Function step -> bool indicating the training steps at which network evaluation should be performed.
 n_eval_batches – Number of batches to run during network evaluation.
 max_slice_length – The length of trajectory slices to run the network on.
 trajectory_stream_preprocessing_fn – Function to apply to the trajectory stream before batching. Can be used e.g. to filter trajectories.
 **kwargs – Keyword arguments passed to the superclass.
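These parameters describe a three-stage pipeline for the policy loss. value_fn supplies a baseline, the advantages are optionally normalized, and weight_fn maps each advantage to a weight. The minimal sketch below implements that pipeline on plain lists. advantage_weights is a hypothetical helper. The two weight functions shown are standard choices from the literature; they do not indicate which weight_fn any particular trax subclass uses.

```python
import math
from typing import Callable, List, Sequence


def advantage_weights(returns: Sequence[float], baselines: Sequence[float],
                      weight_fn: Callable[[float], float],
                      advantage_normalization: bool,
                      eps: float = 1e-8) -> List[float]:
  """Per-timestep weights for the policy loss (hypothetical helper)."""
  # Advantage = observed return minus the baseline produced by value_fn.
  advantages = [r - b for r, b in zip(returns, baselines)]
  if advantage_normalization:
    mean = sum(advantages) / len(advantages)
    std = math.sqrt(sum((a - mean) ** 2 for a in advantages) / len(advantages))
    advantages = [(a - mean) / (std + eps) for a in advantages]
  return [weight_fn(a) for a in advantages]


returns, baselines = [1.0, 3.0], [0.0, 1.0]
# Identity weights: the REINFORCE-with-baseline choice.
print(advantage_weights(returns, baselines, lambda a: a, False))  # [1.0, 2.0]
# Exponentiated weights, as in advantage-weighted regression (beta = 1).
print(advantage_weights(returns, baselines, math.exp, False))
```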

loop
¶ Loop exposed for testing.

train_epoch
()¶ Trains RL for one epoch.


class
trax.rl.training.
PolicyGradient
(task, model_fn, **kwargs)¶ Bases:
trax.rl.training.LoopPolicyAgent
Trains a policy model using policy gradient on the given RLTask.

__init__
(task, model_fn, **kwargs)¶ Initializes PolicyGradient.
Parameters:  task – Instance of trax.rl.task.RLTask.
 model_fn – Function (policy_distribution, mode) -> policy_model.
 **kwargs – Keyword arguments passed to the superclass.

policy
(trajectory, temperature=1.0)¶ Policy function that samples from the trained network.


trax.rl.training.
sharpened_network_policy
(temperature, temperature_multiplier=1.0, **kwargs)¶ Expert function that runs a policy network with lower temperature.
Parameters:  temperature – Temperature passed from the Agent.
 temperature_multiplier – Multiplier to apply to the temperature to “sharpen” the policy distribution. Should be <= 1, but this is not a requirement.
 **kwargs – Keyword arguments passed to network_policy.
Returns: Pair (action, dist_inputs) where action is the action taken and dist_inputs is the parameters of the policy distribution, which will later be used for training.
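Multiplying the temperature by a factor below 1 sharpens the sampling distribution. The sketch below illustrates the effect on a softmax over logits. It assumes, as is common but not stated here, that the policy network outputs logits and that these are divided by the temperature. The numbers are arbitrary.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float) -> List[float]:
  scaled = [logit / temperature for logit in logits]
  top = max(scaled)
  exps = [math.exp(s - top) for s in scaled]
  total = sum(exps)
  return [e / total for e in exps]


logits = [1.0, 2.0, 0.5]
agent_temperature = 1.0
temperature_multiplier = 0.5  # <= 1 sharpens the distribution.

agent_probs = softmax(logits, agent_temperature)
expert_probs = softmax(logits, agent_temperature * temperature_multiplier)
# The sharpened "expert" puts more mass on the most likely action.
print(max(agent_probs) < max(expert_probs))  # True
```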

class
trax.rl.training.
ExpertIteration
(task, model_fn, expert_policy_fn=<function sharpened_network_policy>, quantile=0.9, n_replay_epochs=10, n_train_steps_per_epoch=1000, filter_buffer_size=256, **kwargs)¶ Bases:
trax.rl.training.LoopPolicyAgent
Trains a policy model using expert iteration with a given expert.

__init__
(task, model_fn, expert_policy_fn=<function sharpened_network_policy>, quantile=0.9, n_replay_epochs=10, n_train_steps_per_epoch=1000, filter_buffer_size=256, **kwargs)¶ Initializes ExpertIteration.
Parameters:  task – Instance of trax.rl.task.RLTask.
 model_fn – Function (policy_distribution, mode) -> policy_model.
 expert_policy_fn – Function of the same signature as network_policy, to be used as an expert. The policy will be trained to mimic the expert on the “solved” trajectories.
 quantile – Quantile of best trajectories to be marked as “solved”. They will be used to train the policy.
 n_replay_epochs – Number of last epochs to include in the replay buffer.
 n_train_steps_per_epoch – Number of policy training steps to run in each epoch.
 filter_buffer_size – Number of trajectories in the trajectory filter buffer, used to select the best trajectories based on the quantile.
 **kwargs – Keyword arguments passed to the superclass.
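One way to mark the best trajectories as “solved” is to keep the returns of the most recent filter_buffer_size trajectories. A trajectory is then accepted if its return reaches the given quantile of that buffer. The sketch below follows this reading of quantile and filter_buffer_size. QuantileFilter is hypothetical, and the exact thresholding and tie handling in trax may differ.

```python
import collections
from typing import Deque


class QuantileFilter:
  """Hypothetical filter that marks high-return trajectories as solved."""

  def __init__(self, quantile: float = 0.9, filter_buffer_size: int = 256):
    self._quantile = quantile
    # Only the most recent filter_buffer_size returns are kept.
    self._returns: Deque[float] = collections.deque(maxlen=filter_buffer_size)

  def is_solved(self, trajectory_return: float) -> bool:
    self._returns.append(trajectory_return)
    ranked = sorted(self._returns)
    # Threshold at the requested quantile of the buffered returns.
    index = min(int(self._quantile * len(ranked)), len(ranked) - 1)
    return trajectory_return >= ranked[index]


flt = QuantileFilter(quantile=0.5, filter_buffer_size=4)
print([flt.is_solved(r) for r in (4.0, 3.0, 2.0, 1.0, 10.0)])
# [True, False, False, False, True]
```

Only solved trajectories would then be used as training targets for the policy.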

policy
(trajectory, temperature=1.0)¶ Policy function that runs the expert.


trax.rl.training.
network_policy
(collect_model, policy_distribution, loop, trajectory_np, head_index=0, temperature=1.0)¶ Policy function powered by a neural network.
Used to implement Agent.policy() in policy-based agents.
Parameters:  collect_model – the model used for collecting trajectories
 policy_distribution – an instance of trax.rl.distributions.Distribution
 loop – trax.supervised.training.Loop used to train the policy network
 trajectory_np – an instance of trax.rl.task.TimeStepBatch
 head_index – index of the policy head in a multi-head model.
 temperature – temperature used to sample from the policy (default=1.0)
Returns: a pair (action, dist_inputs) where action is the action taken and dist_inputs is the parameters of the policy distribution, which will later be used for training.

class
trax.rl.training.
ValueAgent
(task, value_body=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, exploration_rate=functools.partial(<function multifactor>, factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1), n_eval_episodes=0, only_eval=False, n_replay_epochs=1, max_slice_length=1, sync_freq=1000, scale_value_targets=True, output_dir=None, **kwargs)¶ Bases:
trax.rl.training.Agent
Trainer that uses a deep learning model for the value function.
Computes the loss using variants of the Bellman equation.

__init__
(task, value_body=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, exploration_rate=functools.partial(<function multifactor>, factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1), n_eval_episodes=0, only_eval=False, n_replay_epochs=1, max_slice_length=1, sync_freq=1000, scale_value_targets=True, output_dir=None, **kwargs)¶ Configures the value trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 value_body – Trax layer, representing the body of the value model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 value_optimizer – the optimizer to use to train the value model.
 value_lr_schedule – learning rate schedule to use to train the value model.
 value_batch_size – batch size used to train the value model.
 value_train_steps_per_epoch – how long to train the value model in each RL epoch.
 value_evals_per_epoch – number of value trainer evaluations per RL epoch; only affects metric reporting.
 value_eval_steps – number of value trainer steps per evaluation; only affects metric reporting.
 exploration_rate – exploration rate schedule, used in the policy method.
 n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only.
 only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
 max_slice_length – the maximum length of trajectory slices to use; it is the second dimension of the value network output, (batch, max_slice_length, number of actions). A higher max_slice_length means the network has to predict more values into the future.
 sync_freq – how often to synchronize the target network with the trained network. This is necessary for training the network on bootstrapped targets, e.g. using n-step returns.
 scale_value_targets – If True, scale value function targets by 1 / (1 - gamma). This addresses the problem of very large returns in some games without introducing additional hyperparameters.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 **kwargs – arguments for the superclass Agent.
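The default exploration_rate is a multifactor schedule with factors='constant * decay_every'. The sketch below reduces it to a closed form under two assumptions. The first is the usual meaning of those factors: start from constant and multiply by decay_factor once every steps_per_decay steps. The second is that minimum acts as a lower bound. The function name is hypothetical.

```python
def exploration_rate(step: int, constant: float = 1.0, decay_factor: float = 0.99,
                     steps_per_decay: int = 1, minimum: float = 0.1) -> float:
  """Hypothetical closed form of the default exploration_rate schedule."""
  # 'constant * decay_every': decay once every steps_per_decay steps.
  rate = constant * decay_factor ** (step // steps_per_decay)
  # Assumption: `minimum` is a floor, so exploration never drops below it.
  return max(rate, minimum)


for step in (0, 1, 100, 1000):
  print(step, round(exploration_rate(step), 4))
```

With these defaults, the rate reaches the 0.1 floor after about 230 decays.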

value_batches_stream
()¶ Use self.task to create inputs to the value model.

policy
(trajectory, temperature=1)¶ Chooses an action to play after a trajectory.

train_epoch
()¶ Trains RL for one epoch.

close
()¶

value_mean
¶ The mean value of actions selected by the behavioral policy.

returns_mean
¶ The mean return of actions selected by the behavioral policy.


class
trax.rl.training.
DQN
(task, advantage_estimator=<function monte_carlo>, max_slice_length=1, smoothl1loss=True, double_dqn=False, **kwargs)¶ Bases:
trax.rl.training.ValueAgent
Trains a value model using DQN on the given RLTask.
Notice that the algorithm and the parameters significantly diverge from the original DQN paper. In particular, learning and data collection are separated.
The Bellman loss is computed in the value_loss method. The formula takes the state-action value tensors Q and n-step returns R:
\[L(s,a) = Q(s,a) - R(s,a)\] where R is computed in value_batches_stream. In the simplest case of 1-step returns we get
\[L(s,a) = Q(s,a) - r(s,a) - \gamma \max_{a'} Q'(s',a')\] where s’ is the state reached after taking action a in state s, Q’ is the target network, gamma is the discount factor and the maximum is taken with respect to all actions available in state s’. The tensor Q’ is updated using the sync_freq parameter.
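The sketch below computes the 1-step target from the formula above for a single transition, using plain Python lists. It assumes next_q_target holds Q'(s', ·) from the target network. It also does not bootstrap past terminal states; this is the standard convention but is not spelled out above. The optional next_q_online argument shows the standard Double DQN variant, which selects a' with the trained network and evaluates it with Q'. Whether the double_dqn flag does exactly this is an assumption.

```python
from typing import Optional, Sequence


def one_step_target(reward: float, next_q_target: Sequence[float], gamma: float,
                    done: bool,
                    next_q_online: Optional[Sequence[float]] = None) -> float:
  """1-step target r(s,a) + gamma * Q'(s', a') for a single transition."""
  if done:
    # Assumption: no bootstrapping past the end of an episode.
    return reward
  if next_q_online is None:
    # Plain DQN: a' maximizes the target network's values.
    bootstrap = max(next_q_target)
  else:
    # Double DQN: the trained network picks a', the target network scores it.
    best = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
    bootstrap = next_q_target[best]
  return reward + gamma * bootstrap


q_sa = 2.5
target = one_step_target(reward=1.0, next_q_target=[0.0, 2.0], gamma=0.5,
                         done=False)
print(target, q_sa - target)  # 2.0 0.5  (the error L(s, a) = Q(s, a) - target)
```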
In code, the maximum is visible in the policy method, where we take sample = jnp.argmax(values). The epsilon-greedy policy takes a random action with probability epsilon; otherwise, in state s, it takes the action argmax_a Q(s,a).
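The epsilon-greedy rule described above can be sketched as follows. In an agent, epsilon would come from the exploration_rate schedule. The function name and the plain-list Q-values are illustrative assumptions.

```python
import random
from typing import Sequence


def epsilon_greedy(q_values: Sequence[float], epsilon: float,
                   rng: random.Random) -> int:
  """Returns an action index for the current state s."""
  if rng.random() < epsilon:
    # Explore: uniformly random action.
    return rng.randrange(len(q_values))
  # Exploit: argmax_a Q(s, a), the counterpart of jnp.argmax(values).
  return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
print(epsilon_greedy([0.2, 1.5, -0.3], epsilon=0.0, rng=rng))  # 1
```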

__init__
(task, advantage_estimator=<function monte_carlo>, max_slice_length=1, smoothl1loss=True, double_dqn=False, **kwargs)¶ Configures the value trainer.
Parameters:  task – RLTask instance, which defines the environment to train on.
 value_body – Trax layer, representing the body of the value model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
 value_optimizer – the optimizer to use to train the value model.
 value_lr_schedule – learning rate schedule to use to train the value model.
 value_batch_size – batch size used to train the value model.
 value_train_steps_per_epoch – how long to train the value model in each RL epoch.
 value_evals_per_epoch – number of value trainer evaluations per RL epoch; only affects metric reporting.
 value_eval_steps – number of value trainer steps per evaluation; only affects metric reporting.
 exploration_rate – exploration rate schedule, used in the policy method.
 n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only.
 only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
 n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
 max_slice_length – the maximum length of trajectory slices to use; it is the second dimension of the value network output, (batch, max_slice_length, number of actions). A higher max_slice_length means the network has to predict more values into the future.
 sync_freq – how often to synchronize the target network with the trained network. This is necessary for training the network on bootstrapped targets, e.g. using n-step returns.
 scale_value_targets – If True, scale value function targets by 1 / (1 - gamma). This addresses the problem of very large returns in some games without introducing additional hyperparameters.
 output_dir – Path telling where to save outputs (evals and checkpoints).
 **kwargs – arguments for the superclass ValueAgent.

value_loss
¶ Value loss computed using smooth L1 loss or L2 loss.
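The two value loss options differ in how they penalize the Bellman error L(s,a). The sketch below applies both to a batch of errors. It assumes the standard definitions of smooth L1 (Huber with threshold 1) and squared error, and an unweighted mean over the batch. trax's exact reduction and masking are not documented here.

```python
from typing import Sequence


def smooth_l1(x: float) -> float:
  # Quadratic for |x| < 1, linear beyond, so large errors get gradients of size 1.
  ax = abs(x)
  return 0.5 * x * x if ax < 1.0 else ax - 0.5


def value_loss(errors: Sequence[float], smoothl1loss: bool = True) -> float:
  per_element = [smooth_l1(e) if smoothl1loss else e * e for e in errors]
  return sum(per_element) / len(per_element)


errors = [0.5, 3.0]  # e.g. Q(s, a) - R(s, a) for two transitions
print(value_loss(errors, smoothl1loss=True))   # 1.3125
print(value_loss(errors, smoothl1loss=False))  # 4.625
```

The smooth L1 option keeps occasional very large bootstrapped errors from dominating the update.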

value_batches_stream
()¶ Use the RLTask self._task to create inputs to the value model.

policy
(trajectory, temperature=1)¶ Chooses an action to play after a trajectory.

value_mean
¶ The mean value of actions selected by the behavioral policy.
