trax.rl package

actor_critic

Classes for RL training in Trax.

class trax.rl.actor_critic.ActorCriticAgent(task, value_model=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, scale_value_targets=False, q_value=False, q_value_aggregate='logsumexp', q_value_temperature=1.0, q_value_n_samples=1, q_value_normalization=False, offline=False, **kwargs)

Bases: trax.rl.training.PolicyAgent

Trains policy and value models using actor-critic methods.

Attrs:
  • on_policy (bool) – Whether the algorithm is on-policy. Used in the data generators. Should be set in derived classes.
on_policy = None
__init__(task, value_model=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, n_shared_layers=0, added_policy_slice_length=0, n_replay_epochs=1, scale_value_targets=False, q_value=False, q_value_aggregate='logsumexp', q_value_temperature=1.0, q_value_n_samples=1, q_value_normalization=False, offline=False, **kwargs)

Configures the actor-critic trainer.

Parameters:
  • task – RLTask instance to use.
  • value_model – Model to use for the value function.
  • value_optimizer – Optimizer to train the value model.
  • value_lr_schedule – lr schedule for value model training.
  • value_batch_size – Batch size for value model training.
  • value_train_steps_per_epoch – Number of steps used to train the value model in each epoch.
  • value_evals_per_epoch – Number of value trainer evaluations per RL epoch. At every evaluation, the weights of the target network are also synchronized.
  • value_eval_steps – Number of value trainer steps per evaluation; only affects metric reporting.
  • n_shared_layers – Number of layers to share between value and policy models.
  • added_policy_slice_length – How much longer slices of trajectories should be for policy training than for value training; this is useful for TD calculations and only affects the length of elements produced for policy batches; value batches have maximum length set by max_slice_length in **kwargs.
  • n_replay_epochs – Number of most recent epochs to keep in the replay buffer; only makes sense for off-policy algorithms.
  • scale_value_targets – If True, scale value function targets by 1 / (1 - gamma).
  • q_value – If True, use Q-values as baselines.
  • q_value_aggregate – How to aggregate Q-values. Options: ‘mean’, ‘max’, ‘softmax’, ‘logsumexp’.
  • q_value_temperature – Temperature parameter for the ‘softmax’ and ‘logsumexp’ aggregation methods.
  • q_value_n_samples – Number of samples to average over when calculating baselines based on Q-values.
  • q_value_normalization – How to normalize Q-values before aggregation. Allowed values: ‘std’, ‘abs’, None. If None, don’t normalize.
  • offline – Whether to train in offline mode. This matters for some algorithms, e.g. QWR.
  • **kwargs – Arguments for PolicyAgent superclass.
value_mean

The mean value of the value function.

value_batches_stream()

Use the RLTask self._task to create inputs to the value model.

policy_inputs(trajectory, values)

Create inputs to policy model from a TimeStepBatch and values.

Parameters:
  • trajectory – a TimeStepBatch, the trajectory to create inputs from
  • values – a numpy array: value function computed on trajectory
Returns:

a tuple of numpy arrays of the form (inputs, x1, x2, …) that will be passed to the policy model; policy model will compute outputs from inputs and (outputs, x1, x2, …) will be passed to self.policy_loss which should be overridden accordingly.

policy_batches_stream()

Use the RLTask self._task to create inputs to the policy model.

train_epoch()

Trains RL for one epoch.

close()
class trax.rl.actor_critic.AdvantageBasedActorCriticAgent(task, advantage_estimator=<function td_lambda>, advantage_normalization=True, advantage_normalization_epsilon=1e-05, advantage_normalization_factor=1.0, added_policy_slice_length=0, **kwargs)

Bases: trax.rl.actor_critic.ActorCriticAgent

Base class for advantage-based actor-critic algorithms.

__init__(task, advantage_estimator=<function td_lambda>, advantage_normalization=True, advantage_normalization_epsilon=1e-05, advantage_normalization_factor=1.0, added_policy_slice_length=0, **kwargs)

Configures the actor-critic trainer.

Parameters:
  • task – RLTask instance to use.
  • value_model – Model to use for the value function.
  • value_optimizer – Optimizer to train the value model.
  • value_lr_schedule – lr schedule for value model training.
  • value_batch_size – Batch size for value model training.
  • value_train_steps_per_epoch – Number of steps used to train the value model in each epoch.
  • value_evals_per_epoch – Number of value trainer evaluations per RL epoch. At every evaluation, the weights of the target network are also synchronized.
  • value_eval_steps – Number of value trainer steps per evaluation; only affects metric reporting.
  • n_shared_layers – Number of layers to share between value and policy models.
  • added_policy_slice_length – How much longer slices of trajectories should be for policy training than for value training; this is useful for TD calculations and only affects the length of elements produced for policy batches; value batches have maximum length set by max_slice_length in **kwargs.
  • n_replay_epochs – Number of most recent epochs to keep in the replay buffer; only makes sense for off-policy algorithms.
  • scale_value_targets – If True, scale value function targets by 1 / (1 - gamma).
  • q_value – If True, use Q-values as baselines.
  • q_value_aggregate – How to aggregate Q-values. Options: ‘mean’, ‘max’, ‘softmax’, ‘logsumexp’.
  • q_value_temperature – Temperature parameter for the ‘softmax’ and ‘logsumexp’ aggregation methods.
  • q_value_n_samples – Number of samples to average over when calculating baselines based on Q-values.
  • q_value_normalization – How to normalize Q-values before aggregation. Allowed values: ‘std’, ‘abs’, None. If None, don’t normalize.
  • offline – Whether to train in offline mode. This matters for some algorithms, e.g. QWR.
  • **kwargs – Arguments for PolicyAgent superclass.
policy_inputs(trajectory, values)

Create inputs to policy model from a TimeStepBatch and values.

policy_loss_given_log_probs

Policy loss given action log-probabilities.

policy_loss

Policy loss.

policy_metrics
advantage_mean
advantage_std
trax.rl.actor_critic.every(n_steps)

Returns a function step -> bool that is True every n_steps steps, for use as the *_at arguments (e.g. value_sync_at, network_eval_at).
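A minimal sketch of how such a predicate might be built and used. The name every_sketch and the divisibility rule are assumptions made for illustration, not the library's source:

```python
def every_sketch(n_steps):
    """Hypothetical stand-in for a step -> bool predicate factory."""
    # Assumption: the predicate fires on steps divisible by n_steps.
    return lambda step: step % n_steps == 0


sync_at = every_sketch(3)
# Steps in 1..9 at which the predicate fires.
fired = [step for step in range(1, 10) if sync_at(step)]
print(fired)
```

A predicate of this shape could be passed wherever a function step -> bool is expected, such as value_sync_at or network_eval_at.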

class trax.rl.actor_critic.LoopActorCriticAgent(task, model_fn, optimizer=<class 'trax.optimizers.adam.Adam'>, policy_lr_schedule=<function multifactor>, policy_n_steps_per_epoch=1000, policy_weight_fn=<function LoopActorCriticAgent.<lambda>>, value_lr_schedule=<function multifactor>, value_n_steps_per_epoch=1000, value_sync_at=<function LoopActorCriticAgent.<lambda>>, advantage_estimator=<function monte_carlo>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, margin=0, n_replay_epochs=1, **kwargs)

Bases: trax.rl.training.Agent

Base class for actor-critic algorithms based on Loop.

on_policy = None
__init__(task, model_fn, optimizer=<class 'trax.optimizers.adam.Adam'>, policy_lr_schedule=<function multifactor>, policy_n_steps_per_epoch=1000, policy_weight_fn=<function LoopActorCriticAgent.<lambda>>, value_lr_schedule=<function multifactor>, value_n_steps_per_epoch=1000, value_sync_at=<function LoopActorCriticAgent.<lambda>>, advantage_estimator=<function monte_carlo>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, margin=0, n_replay_epochs=1, **kwargs)

Initializes LoopActorCriticAgent.

Parameters:
  • task – RLTask instance to use.
  • model_fn – Function mode -> Trax model, building a joint policy and value network.
  • optimizer – Optimizer for the policy and value networks.
  • policy_lr_schedule – Learning rate schedule for the policy network.
  • policy_n_steps_per_epoch – Number of steps to train the policy network for in each epoch.
  • policy_weight_fn – Function advantages -> weights for calculating the log probability weights in policy training.
  • value_lr_schedule – Learning rate schedule for the value network.
  • value_n_steps_per_epoch – Number of steps to train the value network for in each epoch.
  • value_sync_at – Function step -> bool indicating when to synchronize the target network with the trained network in value training.
  • advantage_estimator – Advantage estimator to use in policy and value training.
  • batch_size – Batch size for training the networks.
  • network_eval_at – Function step -> bool indicating when to evaluate the networks.
  • n_eval_batches – Number of batches to compute the network evaluation metrics on.
  • max_slice_length – Maximum length of a trajectory slice to train on.
  • margin – Number of timesteps to add at the end of each trajectory slice for better advantage estimation.
  • n_replay_epochs – Number of epochs of trajectories to store in the replay buffer.
  • **kwargs – Keyword arguments forwarded to Agent.
loop

Loop exposed for testing.

policy(trajectory, temperature=1.0)

Policy function that allows playing using this agent.

train_epoch()

Trains RL for one epoch.

class trax.rl.actor_critic.A2C(task, entropy_coeff=0.01, **kwargs)

Bases: trax.rl.actor_critic.AdvantageBasedActorCriticAgent

Trains policy and value models using the A2C algorithm.

on_policy = True
__init__(task, entropy_coeff=0.01, **kwargs)

Configures the A2C Trainer.

policy_loss_given_log_probs

Definition of the Advantage Actor Critic (A2C) loss.

class trax.rl.actor_critic.PPO(task, epsilon=0.2, entropy_coeff=0.01, **kwargs)

Bases: trax.rl.actor_critic.AdvantageBasedActorCriticAgent

The Proximal Policy Optimization Algorithm aka PPO.

Trains policy and value models using the PPO algorithm.

on_policy = True
__init__(task, epsilon=0.2, entropy_coeff=0.01, **kwargs)

Configures the PPO Trainer.

policy_loss_given_log_probs

Definition of the Proximal Policy Optimization loss.

trax.rl.actor_critic.awr_weights(advantages, beta, thresholds)
trax.rl.actor_critic.awr_metrics(beta, thresholds, preprocess_layer=None)
trax.rl.actor_critic.awr_weight_stat(stat_name, stat_fn, beta, thresholds, preprocess_layer)
trax.rl.actor_critic.AWRLoss(beta, w_max, thresholds)

Definition of the Advantage Weighted Regression (AWR) loss.
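As a rough illustration of the roles of beta and w_max, the sketch below computes an AWR-style loss on plain Python lists. It assumes weights of the form exp(advantage / beta) clipped at w_max and ignores thresholds and masking; awr_loss_sketch is a hypothetical name, not part of the library:

```python
import math


def awr_loss_sketch(log_probs, advantages, beta=1.0, w_max=20.0):
    """Illustrative AWR-style loss over a flat list of samples."""
    # Assumption: weights are exp(advantage / beta), clipped above by w_max.
    weights = [min(math.exp(a / beta), w_max) for a in advantages]
    # Negative weighted mean of the action log-probabilities.
    weighted = [w * lp for w, lp in zip(weights, log_probs)]
    return -sum(weighted) / len(weighted)


# Zero advantage gives weight 1; a very large advantage is clipped to w_max.
loss_neutral = awr_loss_sketch(log_probs=[-1.0], advantages=[0.0])
loss_clipped = awr_loss_sketch(log_probs=[-1.0], advantages=[100.0])
```

A smaller beta sharpens the weighting towards high-advantage actions, while w_max bounds the influence of any single sample.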

class trax.rl.actor_critic.AWR(task, beta=1.0, w_max=20.0, thresholds=None, **kwargs)

Bases: trax.rl.actor_critic.AdvantageBasedActorCriticAgent

Trains policy and value models using AWR.

on_policy = False
__init__(task, beta=1.0, w_max=20.0, thresholds=None, **kwargs)

Configures the AWR Trainer.

policy_loss_given_log_probs

Policy loss.

class trax.rl.actor_critic.LoopAWR(task, model_fn, beta=1.0, w_max=20, **kwargs)

Bases: trax.rl.actor_critic.LoopActorCriticAgent

Advantage Weighted Regression.

on_policy = False
__init__(task, model_fn, beta=1.0, w_max=20, **kwargs)

Initializes LoopActorCriticAgent.

Parameters:
  • task – RLTask instance to use.
  • model_fn – Function mode -> Trax model, building a joint policy and value network.
  • optimizer – Optimizer for the policy and value networks.
  • policy_lr_schedule – Learning rate schedule for the policy network.
  • policy_n_steps_per_epoch – Number of steps to train the policy network for in each epoch.
  • policy_weight_fn – Function advantages -> weights for calculating the log probability weights in policy training.
  • value_lr_schedule – Learning rate schedule for the value network.
  • value_n_steps_per_epoch – Number of steps to train the value network for in each epoch.
  • value_sync_at – Function step -> bool indicating when to synchronize the target network with the trained network in value training.
  • advantage_estimator – Advantage estimator to use in policy and value training.
  • batch_size – Batch size for training the networks.
  • network_eval_at – Function step -> bool indicating when to evaluate the networks.
  • n_eval_batches – Number of batches to compute the network evaluation metrics on.
  • max_slice_length – Maximum length of a trajectory slice to train on.
  • margin – Number of timesteps to add at the end of each trajectory slice for better advantage estimation.
  • n_replay_epochs – Number of epochs of trajectories to store in the replay buffer.
  • **kwargs – Keyword arguments forwarded to Agent.
trax.rl.actor_critic.SamplingAWRLoss(beta, w_max, thresholds, reweight=False, sampled_all_discrete=False)

Definition of the Advantage Weighted Regression (AWR) loss.

class trax.rl.actor_critic.SamplingAWR(task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs)

Bases: trax.rl.actor_critic.AdvantageBasedActorCriticAgent

Trains policy and value models using Sampling AWR.

on_policy = False
__init__(task, beta=1.0, w_max=20.0, thresholds=None, reweight=False, **kwargs)

Configures the AWR Trainer.

policy_metrics
policy_loss

Policy loss.

policy_batches_stream()

Use the RLTask self._task to create inputs to the policy model.

actor_critic_joint

Classes for RL training in Trax.

class trax.rl.actor_critic_joint.ActorCriticJointAgent(task, joint_model=None, optimizer=None, lr_schedule=<function multifactor>, batch_size=64, train_steps_per_epoch=500, supervised_evals_per_epoch=1, supervised_eval_steps=1, n_trajectories_per_epoch=50, max_slice_length=1, normalize_advantages=True, output_dir=None, n_replay_epochs=1)

Bases: trax.rl.training.Agent

Trains a joint policy-and-value model using actor-critic methods.

__init__(task, joint_model=None, optimizer=None, lr_schedule=<function multifactor>, batch_size=64, train_steps_per_epoch=500, supervised_evals_per_epoch=1, supervised_eval_steps=1, n_trajectories_per_epoch=50, max_slice_length=1, normalize_advantages=True, output_dir=None, n_replay_epochs=1)

Configures the joint trainer.

Parameters:
  • task – RLTask instance, which defines the environment to train on.
  • joint_model – Trax layer, representing the joint policy and value model.
  • optimizer – the optimizer to use to train the joint model.
  • lr_schedule – learning rate schedule to use to train the joint model.
  • batch_size – batch size used to train the joint model.
  • train_steps_per_epoch – how long to train the joint model in each RL epoch.
  • supervised_evals_per_epoch – number of value trainer evaluations per RL epoch - only affects metric reporting.
  • supervised_eval_steps – number of value trainer steps per evaluation - only affects metric reporting.
  • n_trajectories_per_epoch – how many trajectories to collect per epoch.
  • max_slice_length – the maximum length of trajectory slices to use.
  • normalize_advantages – if True, then normalize advantages - currently implemented only in PPO.
  • output_dir – Path telling where to save outputs (evals and checkpoints).
  • n_replay_epochs – how many of the most recent epochs to keep in the replay buffer; > 1 only makes sense for off-policy algorithms.
close()
batches_stream()

Use self.task to create inputs to the policy model.

joint_loss

Joint policy and value loss layer.

advantage_mean

Mean of advantages.

advantage_norm

Norm of advantages.

value_loss

Value loss - so far generic for all A2C.

explained_variance

Explained variance metric.

log_probs_mean

Mean of log_probs aka dist_inputs.

preferred_move

Preferred move - the mean of selected moves.

policy(trajectory, temperature=1.0)

Chooses an action to play after a trajectory.

train_epoch()

Trains RL for one epoch.

class trax.rl.actor_critic_joint.PPOJoint(task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)

Bases: trax.rl.actor_critic_joint.ActorCriticJointAgent

The Proximal Policy Optimization Algorithm aka PPO.

Trains policy and value models using the PPO algorithm.

on_policy = True
__init__(task, epsilon=0.2, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)

Configures the PPO Trainer.

batches_stream()

Use the RLTask self._task to create inputs to the value model.

joint_loss

Joint policy and value loss.

probs_ratio_mean

Mean of the probability ratio between the new and old policies.

clip_fraction

Fraction of probability ratios that fall outside the clipping range.

entropy_loss

Entropy layer.

approximate_kl_divergence

Approximate KL divergence.

unclipped_objective_mean
clipped_objective_mean
ppo_objective

PPO objective with local parameters.

ppo_objective_mean

PPO objective mean.

class trax.rl.actor_critic_joint.A2CJoint(task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)

Bases: trax.rl.actor_critic_joint.ActorCriticJointAgent

The A2C algorithm.

Trains policy and value models using the A2C algorithm.

on_policy = True
__init__(task, value_loss_coeff=0.1, entropy_coeff=0.01, **kwargs)

Configures the A2C Trainer.

batches_stream()

Use the RLTask self._task to create inputs to the value model.

joint_loss

Joint policy and value loss.

entropy_loss

Entropy layer.

approximate_kl_divergence

Approximate KL divergence.

a2c_objective

A2C objective with local parameters.

a2c_objective_mean

A2C objective mean.

class trax.rl.actor_critic_joint.AWRJoint(task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, thresholds=None, **kwargs)

Bases: trax.rl.actor_critic_joint.ActorCriticJointAgent

Trains a joint policy-and-value model using AWR.

__init__(task, value_loss_coeff=0.1, beta=1.0, w_max=20.0, thresholds=None, **kwargs)

Configures the joint AWR Trainer.

batches_stream()

Use the RLTask self._task to create inputs to the value model.

joint_loss

Joint policy and value loss.

advantages

RL advantage estimators.

trax.rl.advantages.mask_discount(discount, discount_mask)

Computes a discount to apply at a given timestep, based on the mask.

trax.rl.advantages.discounted_returns(rewards, gammas)

Computes discounted returns for a trajectory or a batch of them.
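The backward recursion behind discounted returns can be sketched for a single trajectory as follows. This is an illustration on plain lists, assuming gammas[i] is the discount applied to everything that follows step i; discounted_returns_sketch is a hypothetical name:

```python
def discounted_returns_sketch(rewards, gammas):
    """Illustrative discounted returns for one trajectory."""
    returns = [0.0] * len(rewards)
    ret = 0.0
    # Walk backwards so each return reuses the one after it.
    for i in reversed(range(len(rewards))):
        ret = rewards[i] + gammas[i] * ret
        returns[i] = ret
    return returns


returns = discounted_returns_sketch([1.0, 1.0, 1.0], [0.5, 0.5, 0.5])
print(returns)
```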

trax.rl.advantages.monte_carlo(gamma, margin)

Calculate Monte Carlo advantage.

We assume the values are a tensor of shape [batch_size, length] and this is the same shape as rewards and returns.

Parameters:
  • gamma – float, gamma parameter for TD from the underlying task
  • margin – number of extra steps in the sequence
Returns:

Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
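A simplified sketch of such an estimator, using nested lists of shape [batch_size, length]. It assumes the advantage is the return minus the value and does not handle episode ends; monte_carlo_sketch is a hypothetical name:

```python
def monte_carlo_sketch(margin):
    """Illustrative Monte Carlo advantage estimator factory."""
    def estimator(rewards, returns, values, dones):
        del rewards, dones  # Unused in this simplified sketch.
        advantages = []
        for ret_row, val_row in zip(returns, values):
            length = len(ret_row)
            row = [r - v for r, v in zip(ret_row, val_row)]
            # Drop the last `margin` steps of each trajectory slice.
            advantages.append(row[:length - margin])
        return advantages
    return estimator


estimator = monte_carlo_sketch(margin=1)
advantages = estimator(
    rewards=None,
    returns=[[3.0, 2.0, 1.0]],
    values=[[1.0, 1.0, 1.0]],
    dones=None,
)
```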

trax.rl.advantages.td_k(gamma, margin)

Calculate TD-k advantage.

The k parameter is assumed to be the same as margin.

We calculate advantage(s_i) as:

gamma^n_steps * value(s_{i + n_steps}) - value(s_i) + discounted_rewards

where discounted_rewards is the sum of rewards in these steps with discounting by powers of gamma.

Parameters:
  • gamma – float, gamma parameter for TD from the underlying task
  • margin – number of extra steps in the sequence
Returns:

Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
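The formula above can be written out for a single trajectory as in the sketch below. It takes k equal to margin, works on plain lists, and ignores episode ends; td_k_sketch is a hypothetical name:

```python
def td_k_sketch(gamma, margin):
    """Illustrative TD-k advantage estimator for one trajectory."""
    k = margin

    def estimator(rewards, values):
        advantages = []
        for i in range(len(rewards) - k):
            # Sum of the next k rewards, discounted by powers of gamma.
            discounted_rewards = sum(
                gamma ** j * rewards[i + j] for j in range(k))
            advantages.append(
                gamma ** k * values[i + k] - values[i] + discounted_rewards)
        return advantages
    return estimator


advantages = td_k_sketch(gamma=0.5, margin=2)(
    rewards=[1.0, 1.0, 1.0, 1.0], values=[0.0, 0.0, 0.0, 4.0])
print(advantages)
```

The output has length - margin entries, matching the shape described above for one batch element.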

trax.rl.advantages.td_lambda(gamma, margin, lambda_=0.95)

Calculate TD-lambda advantage.

The estimated return is an exponentially-weighted average of different TD-k returns.

Parameters:
  • gamma – float, gamma parameter for TD from the underlying task
  • margin – number of extra steps in the sequence
  • lambda_ – float, the lambda parameter of TD-lambda
Returns:

Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].

trax.rl.advantages.gae(gamma, margin, lambda_=0.95)

Calculate Generalized Advantage Estimation.

Calculates state values by bootstrapping off the following state values, as described in Generalized Advantage Estimation: https://arxiv.org/abs/1506.02438

Parameters:
  • gamma – float, gamma parameter for TD from the underlying task
  • margin – number of extra steps in the sequence
  • lambda_ – float, the lambda parameter of GAE
Returns:

Function (rewards, returns, values, dones) -> advantages, where advantages is an array of shape [batch_size, length - margin].
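For reference, the standard GAE recursion from the paper linked above can be sketched for one trajectory as follows. This is an illustration on plain lists that assumes margin >= 1 and no episode ends; it is not the library's implementation, and gae_sketch is a hypothetical name:

```python
def gae_sketch(gamma, margin, lambda_=0.95):
    """Illustrative GAE estimator for one trajectory."""
    assert margin >= 1  # Assumption: the last step is only a bootstrap.

    def estimator(rewards, values):
        length = len(rewards)
        advantages = [0.0] * length
        running = 0.0
        # Walk backwards, accumulating discounted TD residuals.
        for t in reversed(range(length - 1)):
            delta = rewards[t] + gamma * values[t + 1] - values[t]
            running = delta + gamma * lambda_ * running
            advantages[t] = running
        return advantages[:length - margin]
    return estimator


advantages = gae_sketch(gamma=0.5, margin=1, lambda_=0.5)(
    rewards=[1.0, 1.0, 1.0], values=[0.0, 0.0, 2.0])
print(advantages)
```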

distributions

Probability distributions for RL training in Trax.

class trax.rl.distributions.Distribution

Bases: object

Abstract class for parametrized probability distributions.

n_inputs

Returns the number of inputs to the distribution (i.e. parameters).

sample(inputs, temperature=1.0)

Samples a point from the distribution.

Parameters:
  • inputs (jnp.ndarray) – Distribution inputs. Shape is subclass-specific. Broadcasts along the first dimensions. For example, in the categorical distribution parameter shape is (C,), where C is the number of categories. If (B, C) is passed, the object will represent a batch of B categorical distributions with different parameters.
  • temperature – sampling temperature; 1.0 is default, at 0.0 chooses the most probable (preferred) action.
Returns:

Sampled point of shape dependent on the subclass and on the shape of inputs.
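To illustrate the temperature parameter, here is a sketch of temperature-scaled sampling for a single categorical distribution given as a list of logits. The function name and the use of Python's random module are assumptions for illustration only:

```python
import math
import random


def sample_categorical_sketch(logits, temperature=1.0, rng=random):
    """Illustrative temperature-scaled categorical sampling."""
    if temperature == 0.0:
        # Greedy choice: index of the most probable category.
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [logit / temperature for logit in logits]
    top = max(scaled)
    # Unnormalized probabilities, shifted by the max for stability.
    weights = [math.exp(s - top) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [0.1, 2.0, -1.0]
greedy = sample_categorical_sketch(logits, temperature=0.0)
sampled = sample_categorical_sketch(logits, rng=random.Random(0))
```

Lower temperatures concentrate samples on the highest-logit category; higher temperatures flatten the distribution.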

log_prob(inputs, point)

Retrieves log probability (or log probability density) of a point.

Parameters:
  • inputs (jnp.ndarray) – Distribution parameters.
  • point (jnp.ndarray) – Point from the distribution. Shape should be consistent with inputs.
Returns:

Array of log probabilities of points in the distribution.
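For a categorical distribution parametrized by logits, the log probability of a point is its log-softmax entry. A minimal sketch for one distribution, with a hypothetical function name:

```python
import math


def categorical_log_prob_sketch(logits, point):
    """Illustrative log probability of category `point` under `logits`."""
    top = max(logits)
    # Log of the softmax normalizer, computed stably.
    log_norm = top + math.log(sum(math.exp(l - top) for l in logits))
    return logits[point] - log_norm


uniform_log_prob = categorical_log_prob_sketch([0.0, 0.0], point=1)
total_prob = sum(
    math.exp(categorical_log_prob_sketch([0.3, -1.2, 2.0], p))
    for p in range(3))
```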

LogProb()

Builds a log probability layer for this distribution.

trax.rl.distributions.create_distribution(space)

Creates a Distribution for the given Gym space.

trax.rl.distributions.LogLoss(distribution, **unused_kwargs)

Builds a log loss layer for a Distribution.

normalization

Normalization helpers.

trax.rl.normalization.running_mean_init(shape, fill_value=0)
trax.rl.normalization.running_mean_update(x, state)
trax.rl.normalization.running_mean_get_mean(state)
trax.rl.normalization.running_mean_get_count(state)
trax.rl.normalization.running_mean_and_variance_init(shape)
trax.rl.normalization.running_mean_and_variance_update(x, state)
trax.rl.normalization.running_mean_and_variance_get_mean(state)
trax.rl.normalization.running_mean_and_variance_get_count(state)
trax.rl.normalization.running_mean_and_variance_get_variance(state)
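The init/update/get naming suggests a functional style in which the running statistics live in an explicit state value. The sketch below shows that pattern for a scalar running mean; the (mean, count) state layout and the function names are assumptions made for illustration:

```python
def running_mean_init_sketch(fill_value=0.0):
    # Assumed state layout: a (mean, count) pair with count starting at 0.
    return (fill_value, 0)


def running_mean_update_sketch(x, state):
    mean, count = state
    # Incremental update: move the mean towards x by 1 / (count + 1).
    return (mean + (x - mean) / (count + 1), count + 1)


state = running_mean_init_sketch()
for x in [2.0, 4.0, 6.0]:
    state = running_mean_update_sketch(x, state)
mean, count = state
```

Keeping the state explicit means an update returns a new state rather than mutating an object, which suits pure-function code.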
trax.rl.normalization.LayerNormSquash(mode, width=128)

Dense-LayerNorm-Tanh normalizer inspired by ACME.

rl_layers

A number of RL functions intended to be later wrapped as Trax layers.

Wrapping happens with the help of the function tl.Fn.

trax.rl.rl_layers.ValueLoss(values, returns, value_loss_coeff)

Definition of the loss of the value function.

trax.rl.rl_layers.ExplainedVariance(values, returns)

Definition of explained variance - an approach from OpenAI baselines.
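Explained variance is commonly defined as 1 - Var(returns - values) / Var(returns). A sketch on plain lists, assuming this definition and a non-constant returns sequence; the function name is hypothetical:

```python
import statistics


def explained_variance_sketch(values, returns):
    """Illustrative explained variance: 1 - Var(residuals) / Var(returns)."""
    residuals = [r - v for r, v in zip(returns, values)]
    return 1.0 - (
        statistics.pvariance(residuals) / statistics.pvariance(returns))


returns = [1.0, 2.0, 3.0]
# Perfect predictions explain all the variance; constant ones explain none.
perfect = explained_variance_sketch(values=returns, returns=returns)
constant = explained_variance_sketch(values=[0.0, 0.0, 0.0], returns=returns)
```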

trax.rl.rl_layers.PreferredMove(dist_inputs, sample)

Definition of the preferred move.

trax.rl.rl_layers.NewLogProbs(dist_inputs, actions, log_prob_fun)

Given distribution inputs and actions, calculates log probs.

trax.rl.rl_layers.EntropyLoss(dist_inputs, distribution, coeff)

Definition of the Entropy Layer.

trax.rl.rl_layers.ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun)

Probability Ratio from the PPO algorithm.

trax.rl.rl_layers.ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun)

Approximate KL divergence, computed from the old and new action log probabilities.

trax.rl.rl_layers.UnclippedObjective(probs_ratio, advantages)

Unclipped Objective from the PPO algorithm.

trax.rl.rl_layers.ClippedObjective(probs_ratio, advantages, epsilon)

Clipped Objective from the PPO algorithm.
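The sketch below shows, for a single sample, how the unclipped and clipped terms are usually combined in PPO by taking their minimum. It is a scalar illustration with hypothetical function names, not the signatures of the layers listed here:

```python
def clipped_objective_sketch(probs_ratio, advantage, epsilon):
    """Illustrative clipped term: clip(ratio, 1 - eps, 1 + eps) * advantage."""
    clipped_ratio = max(1.0 - epsilon, min(probs_ratio, 1.0 + epsilon))
    return clipped_ratio * advantage


def ppo_objective_sketch(probs_ratio, advantage, epsilon=0.2):
    """Illustrative PPO objective: minimum of unclipped and clipped terms."""
    unclipped = probs_ratio * advantage
    clipped = clipped_objective_sketch(probs_ratio, advantage, epsilon)
    return min(unclipped, clipped)


# Positive advantage: the gain from a large ratio is capped.
capped_gain = ppo_objective_sketch(probs_ratio=1.5, advantage=2.0)
# Negative advantage: the more pessimistic (clipped) term is kept.
pessimistic = ppo_objective_sketch(probs_ratio=0.5, advantage=-1.0)
```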

trax.rl.rl_layers.PPOObjective(dist_inputs, values, returns, dones, rewards, actions, old_log_probs, log_prob_fun, epsilon, normalize_advantages)

PPO Objective.

trax.rl.rl_layers.A2CObjective(dist_inputs, values, returns, dones, rewards, actions, mask, log_prob_fun, normalize_advantages)

Definition of the Advantage Actor Critic (A2C) loss.

serialization_utils

Utilities for serializing trajectories into discrete sequences.

trax.rl.serialization_utils.Serialize(serializer)

Layer that serializes a given array.

trax.rl.serialization_utils.Interleave()

Layer that interleaves and flattens two serialized sequences.

The first sequence can be longer by 1 than the second one. This is so we can interleave sequences of observations and actions, when there’s 1 extra observation at the end.

For serialized sequences [[x_1_1, …, x_1_R1], …, [x_L1_1, …, x_L1_R1]] and [[y_1_1, …, y_1_R2], …, [y_L2_1, …, y_L2_R2]], where L1 = L2 + 1, the result is [x_1_1, …, x_1_R1, y_1_1, …, y_1_R2, …, x_L2_1, …, x_L2_R1, y_L2_1, …, y_L2_R2, x_L1_1, …, x_L1_R1] (batch dimension omitted for clarity).

The layer inputs are a sequence pair of shapes (B, L1, R1) and (B, L2, R2), where B is batch size, L* is the length of the sequence and R* is the representation length of each element in the sequence.

Returns:

Layer that outputs the interleaved sequence of shape (B, L1 * R1 + L2 * R2).
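The interleaving order described above can be illustrated for a single batch element with plain lists. interleave_sketch is a hypothetical helper, not the layer itself:

```python
def interleave_sketch(x, y):
    """Interleaves x (L1 rows of R1 symbols) with y (L2 rows of R2 symbols)."""
    # The first sequence may be longer than the second by at most 1.
    assert len(x) - len(y) in (0, 1)
    out = []
    for i, x_row in enumerate(x):
        out.extend(x_row)
        if i < len(y):
            out.extend(y[i])
    return out


observations = [[1, 2], [3, 4], [5, 6]]  # L1 = 3, R1 = 2
actions = [[7], [8]]                     # L2 = 2, R2 = 1
flat = interleave_sketch(observations, actions)
print(flat)
```

The flattened length is L1 * R1 + L2 * R2, which is 8 in this example.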
trax.rl.serialization_utils.Deinterleave(x_size, y_size)

Layer that does the inverse of Interleave.

trax.rl.serialization_utils.RepresentationMask(serializer)

Upsamples a mask to cover the serialized representation.

trax.rl.serialization_utils.SignificanceWeights(serializer, decay)

Multiplies a binary mask with a symbol significance mask.
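One plausible reading, given that significance_decay is described elsewhere in this module as exponential weighting of symbols, is that each symbol is weighted by decay raised to its significance index. The sketch below illustrates that assumption on plain lists; it is not taken from the library source:

```python
def significance_weights_sketch(mask, significance_map, decay):
    """Illustrative weights: mask * decay ** significance, per symbol."""
    # Assumption: 0 marks the most significant symbol, so it gets weight 1.
    return [m * decay ** s for m, s in zip(mask, significance_map)]


weights = significance_weights_sketch(
    mask=[1, 1, 0], significance_map=[0, 1, 0], decay=0.5)
print(weights)
```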

class trax.rl.serialization_utils.SerializedModel(seq_model, observation_serializer, action_serializer, significance_decay, mode='train')

Bases: trax.layers.combinators.Serial

Wraps a world model in serialization machinery for training.

The resulting model takes as input the observation and action sequences, serializes them and interleaves into one sequence, which is fed into a given autoregressive model. The resulting logit sequence is deinterleaved into observations and actions, and the observation logits are returned together with computed symbol significance weights.

The model has a signature (obs, act, obs, mask) -> (obs_logits, obs_repr, weights), where obs are observations (the second occurrence is the target), act are actions, mask is the observation mask, obs_logits are logits of the output observation representation, obs_repr is the target observation representation and weights are the target weights.

__init__(seq_model, observation_serializer, action_serializer, significance_decay, mode='train')

Initializes SerializedModel.

Parameters:
  • seq_model – Trax autoregressive model taking as input a sequence of symbols and outputting a sequence of symbol logits.
  • observation_serializer – Serializer to use for observations.
  • action_serializer – Serializer to use for actions.
  • significance_decay – Float from (0, 1) for exponential weighting of symbols in the representation.
  • mode – ‘train’ or ‘eval’.
observation_serializer
action_serializer
make_predict_model()

Returns a predict-mode model of the same architecture.

seq_model_weights

Extracts the weights of the underlying sequence model.

seq_model_state

Extracts the state of the underlying sequence model.

trax.rl.serialization_utils.TimeSeriesModel(seq_model, low=0.0, high=1.0, precision=2, vocab_size=64, significance_decay=0.7, mode='train')

Simplified constructor for SerializedModel, for time series prediction.

trax.rl.serialization_utils.RawPolicy(seq_model, n_controls, n_actions)

Wraps a sequence model in a policy interface.

The resulting model takes as input observation and action sequences, but only uses the observations. Adds output heads for action logits and value predictions.

Parameters:
  • seq_model – Trax sequence model taking as input and outputting a sequence of continuous vectors.
  • n_controls – Number of controls.
  • n_actions – Number of action categories in each control.
Returns:

A model of signature (obs, act) -> (act_logits, values), with shapes:
  • obs: (batch_size, length + 1, obs_depth)
  • act: (batch_size, length, n_controls)
  • act_logits: (batch_size, length, n_controls, n_actions)
  • values: (batch_size, length)

trax.rl.serialization_utils.substitute_inner_policy_raw(raw_policy, inner_policy)

Substitutes the weights/state of the inner model in a RawPolicy.

trax.rl.serialization_utils.SerializedPolicy(seq_model, n_controls, n_actions, observation_serializer, action_serializer)

Wraps a policy in serialization machinery for training.

The resulting model takes as input observation and action sequences, and serializes them into one sequence similar to SerializedModel, before passing to the given sequence model. Adds output heads for action logits and value predictions.

Parameters:
  • seq_model – Trax sequence model taking as input a sequence of symbols and outputting a sequence of continuous vectors.
  • n_controls – Number of controls.
  • n_actions – Number of action categories in each control.
  • observation_serializer – Serializer to use for observations.
  • action_serializer – Serializer to use for actions.
Returns:

A model of signature (obs, act) -> (act_logits, values), same as in RawPolicy.

trax.rl.serialization_utils.substitute_inner_policy_serialized(serialized_policy, inner_policy)

Substitutes the weights/state of the inner model in a SerializedPolicy.

trax.rl.serialization_utils.analyze_action_space(action_space)

Returns the number of controls and actions for an action space.

trax.rl.serialization_utils.wrap_policy(seq_model, observation_space, action_space, vocab_size)

Wraps a sequence model in either RawPolicy or SerializedPolicy.

Parameters:
  • seq_model – Trax sequence model.
  • observation_space – Gym observation space.
  • action_space – Gym action space.
  • vocab_size – Either the number of symbols for a serialized policy, or None.
Returns:

RawPolicy if vocab_size is None, else SerializedPolicy.

trax.rl.serialization_utils.substitute_inner_policy(wrapped_policy, inner_policy, vocab_size)

Substitutes the inner weights/state in a {Raw,Serialized}Policy.

Parameters:
  • wrapped_policy (pytree) – Weights or state of a wrapped policy.
  • inner_policy (pytree) – Weights or state of an inner policy.
  • vocab_size (int or None) – Vocabulary size of a serialized policy, or None in case of a raw policy.
Returns:

New weights or state of wrapped_policy, with the inner weights/state copied from inner_policy.

space_serializer

Serialization of elements of Gym spaces into discrete sequences.

class trax.rl.space_serializer.SpaceSerializer(space, vocab_size)

Bases: object

Base class for Gym space serializers.

Attrs:
  • space_type (type) – Gym space class that this SpaceSerializer corresponds to. Should be defined in subclasses.
  • representation_length (int) – Number of symbols in the representation of every element of the space.
  • significance_map (np.ndarray) – Integer array of the same size as the discrete representation, where elements describe the significance of symbols, e.g. in fixed-precision encoding. 0 is the most significant symbol, 1 the second most significant, etc.
space_type = None
representation_length = None
significance_map = None
__init__(space, vocab_size)

Creates a SpaceSerializer.

Subclasses should retain the signature.

Parameters:
  • space – (gym.Space) Gym space of type self.space_type.
  • vocab_size – (int) Number of symbols in the vocabulary.
vocab_size
serialize(data)

Serializes a batch of space elements into discrete sequences.

Should be defined in subclasses.

Parameters:
  • data – A batch of batch_size elements of the Gym space to be serialized.
Returns:

int32 array of shape (batch_size, self.representation_length).
deserialize(representation)

Deserializes a batch of discrete sequences into space elements.

Should be defined in subclasses.

Parameters:
  • representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized.
Returns:

A batch of batch_size deserialized elements of the Gym space.
trax.rl.space_serializer.create(space, vocab_size)

Creates a SpaceSerializer for the given Gym space.

class trax.rl.space_serializer.DiscreteSpaceSerializer(space, vocab_size)

Bases: trax.rl.space_serializer.SpaceSerializer

Serializer for gym.spaces.Discrete.

Assumes that the size of the space fits in the number of symbols.
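Since representation_length is 1 for this serializer, each element maps to a single symbol. The sketch below shows a round trip under that assumption, using plain lists instead of Gym spaces and arrays; the class name and constructor arguments are hypothetical:

```python
class DiscreteSerializerSketch:
    """Illustrative one-symbol-per-element serializer."""

    representation_length = 1

    def __init__(self, n, vocab_size):
        # The n values of the space must fit in the vocabulary.
        assert n <= vocab_size
        self.vocab_size = vocab_size

    def serialize(self, data):
        # Shape (batch_size,) -> (batch_size, representation_length).
        return [[int(x)] for x in data]

    def deserialize(self, representation):
        # Shape (batch_size, representation_length) -> (batch_size,).
        return [row[0] for row in representation]


serializer = DiscreteSerializerSketch(n=4, vocab_size=8)
symbols = serializer.serialize([0, 3, 2])
restored = serializer.deserialize(symbols)
```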

space_type


representation_length = 1
__init__(space, vocab_size)

Creates a SpaceSerializer.

Subclasses should retain the signature.

Parameters:
  • space – (gym.Space) Gym space of type self.space_type.
  • vocab_size – (int) Number of symbols in the vocabulary.
serialize(data)

Serializes a batch of space elements into discrete sequences.

Should be defined in subclasses.

Parameters:data – A batch of batch_size elements of the Gym space to be serialized.
Returns:int32 array of shape (batch_size, self.representation_length).
deserialize(representation)

Deserializes a batch of discrete sequences into space elements.

Should be defined in subclasses.

Parameters:representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized.
Returns:A batch of batch_size deserialized elements of the Gym space.
significance_map
class trax.rl.space_serializer.MultiDiscreteSpaceSerializer(space, vocab_size)

Bases: trax.rl.space_serializer.SpaceSerializer

Serializer for gym.spaces.MultiDiscrete.

Assumes that the number of categories in each dimension fits in the number of symbols.

space_type


__init__(space, vocab_size)

Creates a SpaceSerializer.

Subclasses should retain the signature.

Parameters:
  • space – (gym.Space) Gym space of type self.space_type.
  • vocab_size – (int) Number of symbols in the vocabulary.
serialize(data)

Serializes a batch of space elements into discrete sequences.

Should be defined in subclasses.

Parameters:data – A batch of batch_size elements of the Gym space to be serialized.
Returns:int32 array of shape (batch_size, self.representation_length).
deserialize(representation)

Deserializes a batch of discrete sequences into space elements.

Should be defined in subclasses.

Parameters:representation – int32 Numpy array of shape (batch_size, self.representation_length) to be deserialized.
Returns:A batch of batch_size deserialized elements of the Gym space.
representation_length
significance_map
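
The serialize/deserialize contract for a multi-discrete space can be sketched as below. This is an illustration under stated assumptions, not the library code: it assumes one symbol per dimension (so the representation length equals the number of dimensions), and the function names and the nvec argument (number of categories per dimension) are hypothetical.

```python
def serialize_multidiscrete(batch, nvec, vocab_size):
    """Maps a batch of category-index vectors to symbol sequences.

    Each dimension becomes one symbol, so every category count must fit
    in the vocabulary.
    """
    if max(nvec) > vocab_size:
        raise ValueError("Number of categories exceeds vocab_size.")
    for row in batch:
        if len(row) != len(nvec):
            raise ValueError("Element has the wrong number of dimensions.")
        if any(not 0 <= x < n for x, n in zip(row, nvec)):
            raise ValueError("Category index out of range.")
    # Result has shape (batch_size, representation_length).
    return [[int(x) for x in row] for row in batch]


def deserialize_multidiscrete(representation):
    """Inverse of serialize_multidiscrete: symbols are the category indices."""
    return [list(row) for row in representation]
```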

task

Classes for defining RL tasks in Trax.

class trax.rl.task.TimeStepBatch(observation, action, reward, done, mask, dist_inputs, env_info, return_)

Bases: tuple

action

Alias for field number 1

dist_inputs

Alias for field number 5

done

Alias for field number 3

env_info

Alias for field number 6

mask

Alias for field number 4

observation

Alias for field number 0

return_

Alias for field number 7

reward

Alias for field number 2
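
The field numbers above follow from TimeStepBatch being a named tuple. The sketch below rebuilds an equivalent tuple with collections.namedtuple purely to show the field order; the values are made up, and the real class should be imported from trax.rl.task rather than redefined.

```python
import collections

# Stand-in with the same field order as documented above.
TimeStepBatch = collections.namedtuple(
    "TimeStepBatch",
    [
        "observation",  # field 0
        "action",       # field 1
        "reward",       # field 2
        "done",         # field 3
        "mask",         # field 4
        "dist_inputs",  # field 5
        "env_info",     # field 6
        "return_",      # field 7
    ],
)

step = TimeStepBatch(
    observation=[0.0, 1.0],
    action=1,
    reward=0.5,
    done=False,
    mask=1,
    dist_inputs=None,
    env_info=None,
    return_=None,
)

# Fields can be read by name or by position.
assert step.reward == step[2]

# Named tuples are immutable; _replace returns a modified copy.
step_with_return = step._replace(return_=0.5)
```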

class trax.rl.task.EnvInfo(control_mask, discount_mask)

Bases: tuple

control_mask

Alias for field number 0

discount_mask

Alias for field number 1

class trax.rl.task.Trajectory(observation)

Bases: object

A trajectory of interactions with an RL environment.

Trajectories are created when interacting with an RL environment. They can be extended and sliced and, once completed, allow returns to be recalculated.

__init__(observation)

Initialize self. See help(type(self)) for accurate signature.

suffix(length)

Returns a Trajectory with the last length observations.

timesteps
total_return

Sum of all rewards in this trajectory.

last_observation

Return the last observation in this trajectory.

done

Returns whether the trajectory is finished.

extend(new_observation, mask=1, **kwargs)

Take action in the last state, getting reward and going to new state.

calculate_returns(gamma)

Calculate discounted returns.
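
As a reminder of what a discounted return is, here is a minimal standalone sketch. It operates on a plain list of rewards rather than on a Trajectory, and the function name discounted_returns is hypothetical.

```python
def discounted_returns(rewards, gamma):
    """Computes G_t = r_t + gamma * G_{t+1} for every timestep t."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each return reuses the one after it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```

With gamma = 1 the first entry equals the plain sum of rewards, which matches the description of total_return above.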

to_np(margin=1, timestep_to_np=None)

Create a tuple of numpy arrays from a given trajectory.

Parameters:
  • margin (int) – Number of dummy timesteps past the trajectory end to include. By default we include 1, which contains the last observation.
  • timestep_to_np (callable or None) – Optional function TimeStepBatch[Any] -> TimeStepBatch[np.array], converting the timestep data into numpy arrays.
Returns:

TimeStepBatch, where all fields have shape (len(self) + margin - 1, …).

trax.rl.task.play(env, policy, dm_suite=False, max_steps=None, last_observation=None)

Play an episode in env taking actions according to the given policy.

The environment is first reset, and from then on a game proceeds. At each step, the policy is asked to choose an action and the environment moves forward. A Trajectory is built up this way and returned when the episode finishes, which is either when env returns done or when max_steps is reached.

Parameters:
  • env – the environment to play in, conforming to gym.Env or DeepMind suite interfaces.
  • policy – a function taking a Trajectory and returning a pair consisting of an action (int or float) and the confidence in that action (float, defined as the log of the probability of taking that action).
  • dm_suite – whether we are using the DeepMind suite or the gym interface
  • max_steps – for how many steps to play.
  • last_observation – last observation from a previous trajectory slice, used to begin a new one. Controls whether we reset the environment at the beginning: if None, resets the env and starts the slice from the observation returned by reset().
Returns:

a completed trajectory slice that was just played.
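
The interaction loop described above can be sketched without any dependencies. Everything here is illustrative: CountdownEnv is a toy environment with a gym-like reset()/step() interface, the policy receives a plain list of observations instead of a Trajectory, and play_sketch returns lists rather than a Trajectory object.

```python
import math


class CountdownEnv:
    """Toy environment (hypothetical): counts down from `start` to 0."""

    def __init__(self, start=3):
        self._start = start
        self._state = start

    def reset(self):
        self._state = self._start
        return self._state

    def step(self, action):
        self._state -= 1
        done = self._state == 0
        # Classic gym convention: (observation, reward, done, info).
        return self._state, 1.0, done, {}


def always_zero_policy(observations):
    """Returns (action, log-probability of that action)."""
    return 0, math.log(1.0)  # Deterministic policy: probability 1.


def play_sketch(env, policy, max_steps=None):
    """Plays until the env reports done or max_steps actions were taken."""
    observations = [env.reset()]
    actions, rewards = [], []
    done = False
    while not done and (max_steps is None or len(actions) < max_steps):
        action, _ = policy(observations)
        observation, reward, done, _ = env.step(action)
        observations.append(observation)
        actions.append(action)
        rewards.append(reward)
    return observations, actions, rewards, done
```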

training

Classes for RL training in Trax.

class trax.rl.training.Agent(task, n_trajectories_per_epoch=None, n_interactions_per_epoch=None, n_eval_episodes=0, eval_steps=None, eval_temperatures=(0.0, ), only_eval=False, output_dir=None, timestep_to_np=None)

Bases: object

Abstract class for RL agents, presenting the required API.

__init__(task, n_trajectories_per_epoch=None, n_interactions_per_epoch=None, n_eval_episodes=0, eval_steps=None, eval_temperatures=(0.0, ), only_eval=False, output_dir=None, timestep_to_np=None)

Configures the Agent.

Note that subclasses can have many more arguments, which will be configured using defaults and gin. But task and output_dir are passed explicitly.

Parameters:
  • task – RLTask instance, which defines the environment to train on.
  • n_trajectories_per_epoch – How many new trajectories to collect in each epoch.
  • n_interactions_per_epoch – How many interactions to collect in each epoch.
  • n_eval_episodes – Number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only.
  • eval_steps – an optional list of max_steps to use for evaluation (defaults to task.max_steps).
  • eval_temperatures – we always train with temperature 1 and evaluate with the temperatures specified in the eval_temperatures list (defaults to (0.0,)).
  • only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
  • output_dir – Path telling where to save outputs such as checkpoints.
  • timestep_to_np – Timestep-to-numpy function to override in the task.
current_epoch

Returns the current epoch number in this training session.

task

Returns the task.

avg_returns
save_gin(summary_writer=None)
save_to_file(file_name='rl.pkl', task_file_name='trajectories.pkl')

Save current epoch number and average returns to file.

init_from_file(file_name='rl.pkl', task_file_name='trajectories.pkl')

Initialize epoch number and average returns from file.

policy(trajectory, temperature=1.0)

Policy function that makes it possible to play using this trainer.

Parameters:
  • trajectory – an instance of trax.rl.task.Trajectory
  • temperature – temperature used to sample from the policy (default=1.0)
Returns:

a pair (action, dist_inputs) where action is the action taken and dist_inputs are the parameters of the policy distribution, which will later be used for training.

train_epoch()

Trains this Agent for one epoch – main RL logic goes here.

run(n_epochs=1, n_epochs_is_total_epochs=False)

Runs this loop for n epochs.

Parameters:
  • n_epochs – Stop training after completing n_epochs epochs.
  • n_epochs_is_total_epochs – if True, consider n_epochs as the total number of epochs to train, including previously trained ones
close()
class trax.rl.training.PolicyAgent(task, policy_model=None, policy_optimizer=None, policy_lr_schedule=<function multifactor>, policy_batch_size=64, policy_train_steps_per_epoch=500, policy_evals_per_epoch=1, policy_eval_steps=1, n_eval_episodes=0, only_eval=False, max_slice_length=1, output_dir=None, **kwargs)

Bases: trax.rl.training.Agent

Agent that uses a deep learning model for policy.

Many deep RL methods, such as policy gradient (REINFORCE) or actor-critic, fall into this category, so a lot of classes will be subclasses of this one. Methods that only have a value or Q function are different and are not covered by this class.

__init__(task, policy_model=None, policy_optimizer=None, policy_lr_schedule=<function multifactor>, policy_batch_size=64, policy_train_steps_per_epoch=500, policy_evals_per_epoch=1, policy_eval_steps=1, n_eval_episodes=0, only_eval=False, max_slice_length=1, output_dir=None, **kwargs)

Configures the policy trainer.

Parameters:
  • task – RLTask instance, which defines the environment to train on.
  • policy_model – Trax layer, representing the policy model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
  • policy_optimizer – the optimizer to use to train the policy model.
  • policy_lr_schedule – learning rate schedule to use to train the policy.
  • policy_batch_size – batch size used to train the policy model.
  • policy_train_steps_per_epoch – how long to train policy in each RL epoch.
  • policy_evals_per_epoch – number of policy trainer evaluations per RL epoch - only affects metric reporting.
  • policy_eval_steps – number of policy trainer steps per evaluation - only affects metric reporting.
  • n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only
  • only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
  • max_slice_length – the maximum length of trajectory slices to use.
  • output_dir – Path telling where to save outputs (evals and checkpoints).
  • **kwargs – arguments for the superclass Agent.
policy_loss

Policy loss.

policy_metrics
policy_batches_stream()

Use self.task to create inputs to the policy model.

policy(trajectory, temperature=1.0)

Chooses an action to play after a trajectory.

train_epoch()

Trains RL for one epoch.

close()
trax.rl.training.remaining_evals(cur_step, epoch, train_steps_per_epoch, evals_per_epoch)

Helper function to calculate remaining evaluations for a trainer.

Parameters:
  • cur_step – current step of the supervised trainer
  • epoch – current epoch of the RL trainer
  • train_steps_per_epoch – supervised trainer steps per RL epoch
  • evals_per_epoch – supervised trainer evals per RL epoch
Returns:

number of remaining evals to do this epoch

Raises:

ValueError if the provided numbers indicate a step mismatch
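
One way the bookkeeping described above could work is sketched below. This is a reconstruction from the parameter descriptions, not the library source: it assumes RL epochs are numbered from 1 and that evaluations are spaced evenly within an epoch, and the name remaining_evals_sketch is hypothetical.

```python
def remaining_evals_sketch(cur_step, epoch, train_steps_per_epoch, evals_per_epoch):
    """Returns how many evals are still due in the given RL epoch."""
    if epoch < 1:
        raise ValueError("Epoch must be at least 1, got %d." % epoch)
    # Steps the supervised trainer should have done before this epoch.
    prev_steps = (epoch - 1) * train_steps_per_epoch
    done_steps_this_epoch = cur_step - prev_steps
    if done_steps_this_epoch < 0:
        raise ValueError("Current step is behind the start of this epoch.")
    train_steps_per_eval = train_steps_per_epoch // evals_per_epoch
    if done_steps_this_epoch % train_steps_per_eval != 0:
        raise ValueError("Current step does not fall on an eval boundary.")
    return evals_per_epoch - done_steps_this_epoch // train_steps_per_eval
```

For example, with 500 training steps and 2 evals per epoch, a trainer at step 250 of epoch 1 has one evaluation left in this sketch.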

class trax.rl.training.LoopPolicyAgent(task, model_fn, value_fn, weight_fn, n_replay_epochs, n_train_steps_per_epoch, advantage_normalization, optimizer=<class 'trax.optimizers.adam.Adam'>, lr_schedule=<function multifactor>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, trajectory_stream_preprocessing_fn=None, **kwargs)

Bases: trax.rl.training.Agent

Base class for policy-only Agents based on Loop.

__init__(task, model_fn, value_fn, weight_fn, n_replay_epochs, n_train_steps_per_epoch, advantage_normalization, optimizer=<class 'trax.optimizers.adam.Adam'>, lr_schedule=<function multifactor>, batch_size=64, network_eval_at=None, n_eval_batches=1, max_slice_length=1, trajectory_stream_preprocessing_fn=None, **kwargs)

Initializes LoopPolicyAgent.

Parameters:
  • task – Instance of trax.rl.task.RLTask.
  • model_fn – Function (policy_distribution, mode) -> policy_model.
  • value_fn – Function TimeStepBatch -> array (batch_size, seq_len) calculating the baseline for advantage calculation.
  • weight_fn – Function float -> float to apply to advantages when calculating policy loss.
  • n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
  • n_train_steps_per_epoch – Number of steps to train the policy network for in each epoch.
  • advantage_normalization – Whether to normalize the advantages before passing them to weight_fn.
  • optimizer – Optimizer for network training.
  • lr_schedule – Learning rate schedule for network training.
  • batch_size – Batch size for network training.
  • network_eval_at – Function step -> bool indicating the training steps, when network evaluation should be performed.
  • n_eval_batches – Number of batches to run during network evaluation.
  • max_slice_length – The length of trajectory slices to run the network on.
  • trajectory_stream_preprocessing_fn – Function to apply to the trajectory stream before batching. Can be used e.g. to filter trajectories.
  • **kwargs – Keyword arguments passed to the superclass.
loop

Loop exposed for testing.

train_epoch()

Trains RL for one epoch.

class trax.rl.training.PolicyGradient(task, model_fn, **kwargs)

Bases: trax.rl.training.LoopPolicyAgent

Trains a policy model using policy gradient on the given RLTask.

__init__(task, model_fn, **kwargs)

Initializes PolicyGradient.

Parameters:
  • task – Instance of trax.rl.task.RLTask.
  • model_fn – Function (policy_distribution, mode) -> policy_model.
  • **kwargs – Keyword arguments passed to the superclass.
policy(trajectory, temperature=1.0)

Policy function that samples from the trained network.

trax.rl.training.sharpened_network_policy(temperature, temperature_multiplier=1.0, **kwargs)

Expert function that runs a policy network with lower temperature.

Parameters:
  • temperature – Temperature passed from the Agent.
  • temperature_multiplier – Multiplier to apply to the temperature to “sharpen” the policy distribution. Should be <= 1, but this is not a requirement.
  • **kwargs – Keyword arguments passed to network_policy.
Returns:

Pair (action, dist_inputs) where action is the action taken and dist_inputs are the parameters of the policy distribution, which will later be used for training.
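
To see why a temperature multiplier below 1 "sharpens" a distribution, consider a categorical policy given by logits. The sketch below is a standalone illustration with hypothetical helper names; it does not call network_policy and requires a strictly positive effective temperature.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [logit / temperature for logit in logits]
    shift = max(scaled)  # Subtract the max before exponentiating.
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sharpened_probs(logits, temperature, temperature_multiplier=1.0):
    """Action probabilities at temperature * temperature_multiplier."""
    return softmax_with_temperature(logits, temperature * temperature_multiplier)
```

A lower effective temperature moves probability mass toward the highest-scoring action, so the sharpened policy acts closer to greedily than the policy it is derived from.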

class trax.rl.training.ExpertIteration(task, model_fn, expert_policy_fn=<function sharpened_network_policy>, quantile=0.9, n_replay_epochs=10, n_train_steps_per_epoch=1000, filter_buffer_size=256, **kwargs)

Bases: trax.rl.training.LoopPolicyAgent

Trains a policy model using expert iteration with a given expert.

__init__(task, model_fn, expert_policy_fn=<function sharpened_network_policy>, quantile=0.9, n_replay_epochs=10, n_train_steps_per_epoch=1000, filter_buffer_size=256, **kwargs)

Initializes ExpertIteration.

Parameters:
  • task – Instance of trax.rl.task.RLTask.
  • model_fn – Function (policy_distribution, mode) -> policy_model.
  • expert_policy_fn – Function of the same signature as network_policy, to be used as an expert. The policy will be trained to mimic the expert on the “solved” trajectories.
  • quantile – Quantile of best trajectories to be marked as “solved”. They will be used to train the policy.
  • n_replay_epochs – Number of last epochs to include in the replay buffer.
  • n_train_steps_per_epoch – Number of policy training steps to run in each epoch.
  • filter_buffer_size – Number of trajectories in the trajectory filter buffer, used to select the best trajectories based on the quantile.
  • **kwargs – Keyword arguments passed to the superclass.
policy(trajectory, temperature=1.0)

Policy function that runs the expert.
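
The role of the quantile parameter can be illustrated with a small filter over trajectory returns. This is only a sketch: the exact thresholding rule used by the library is not stated here, so the rule below (keep trajectories whose return is at or above the empirical quantile of the buffer) and the name select_solved are assumptions.

```python
def select_solved(trajectory_returns, quantile=0.9):
    """Returns indices of trajectories at or above the given return quantile."""
    if not trajectory_returns:
        return []
    ordered = sorted(trajectory_returns)
    # Simple empirical quantile without interpolation.
    threshold = ordered[int(quantile * (len(ordered) - 1))]
    return [i for i, ret in enumerate(trajectory_returns) if ret >= threshold]
```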

trax.rl.training.network_policy(collect_model, policy_distribution, loop, trajectory_np, head_index=0, temperature=1.0)

Policy function powered by a neural network.

Used to implement Agent.policy() in policy-based agents.

Parameters:
  • collect_model – the model used for collecting trajectories
  • policy_distribution – an instance of trax.rl.distributions.Distribution
  • loop – trax.supervised.training.Loop used to train the policy network
  • trajectory_np – an instance of trax.rl.task.TimeStepBatch
  • head_index – index of the policy head in a multihead model.
  • temperature – temperature used to sample from the policy (default=1.0)
Returns:

a pair (action, dist_inputs) where action is the action taken and dist_inputs are the parameters of the policy distribution, which will later be used for training.

class trax.rl.training.ValueAgent(task, value_body=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, exploration_rate=functools.partial(<function multifactor>, factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1), n_eval_episodes=0, only_eval=False, n_replay_epochs=1, max_slice_length=1, sync_freq=1000, scale_value_targets=True, output_dir=None, **kwargs)

Bases: trax.rl.training.Agent

Trainer that uses a deep learning model for the value function.

Computes the loss using variants of the Bellman equation.

__init__(task, value_body=None, value_optimizer=None, value_lr_schedule=<function multifactor>, value_batch_size=64, value_train_steps_per_epoch=500, value_evals_per_epoch=1, value_eval_steps=1, exploration_rate=functools.partial(<function multifactor>, factors='constant * decay_every', constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1), n_eval_episodes=0, only_eval=False, n_replay_epochs=1, max_slice_length=1, sync_freq=1000, scale_value_targets=True, output_dir=None, **kwargs)

Configures the value trainer.

Parameters:
  • task – RLTask instance, which defines the environment to train on.
  • value_body – Trax layer, representing the body of the value model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
  • value_optimizer – the optimizer to use to train the value model.
  • value_lr_schedule – learning rate schedule to use to train the value model.
  • value_batch_size – batch size used to train the value model.
  • value_train_steps_per_epoch – how long to train the value model in each RL epoch.
  • value_evals_per_epoch – number of value trainer evaluations per RL epoch - only affects metric reporting.
  • value_eval_steps – number of value trainer steps per evaluation - only affects metric reporting.
  • exploration_rate – exploration rate schedule - used in the policy method.
  • n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only
  • only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
  • n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
  • max_slice_length – the maximum length of trajectory slices to use; it is the second dimension of the value network output: (batch, max_slice_length, number of actions). A higher max_slice_length implies that the network has to predict more values into the future.
  • sync_freq – frequency when to synchronize the target network with the trained network. This is necessary for training the network on bootstrapped targets, e.g. using n-step returns.
  • scale_value_targets – If True, scale value function targets by 1 / (1 - gamma). This addresses the problem of very large returns in some games without introducing additional hyperparameters.
  • output_dir – Path telling where to save outputs (evals and checkpoints).
  • **kwargs – arguments for the superclass Agent.
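
The default exploration_rate in the signature above is a decaying schedule with a floor. The sketch below shows one reading of those defaults (constant=1.0, decay_factor=0.99, steps_per_decay=1, minimum=0.1); how the schedule function combines the factors is an assumption here, and exploration_rate_sketch is a hypothetical stand-in rather than the library function.

```python
def exploration_rate_sketch(step, constant=1.0, decay_factor=0.99,
                            steps_per_decay=1, minimum=0.1):
    """Exponentially decayed rate, never dropping below `minimum`."""
    n_decays = step // steps_per_decay
    return max(minimum, constant * decay_factor ** n_decays)
```

Under this reading the agent starts fully exploratory and settles at a rate of 0.1 after a few hundred steps.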
value_batches_stream()

Use self.task to create inputs to the value model.

policy(trajectory, temperature=1)

Chooses an action to play after a trajectory.

train_epoch()

Trains RL for one epoch.

close()
value_mean

The mean value of actions selected by the behavioral policy.

returns_mean

The mean value of actions selected by the behavioral policy.

class trax.rl.training.DQN(task, advantage_estimator=<function monte_carlo>, max_slice_length=1, smoothl1loss=True, double_dqn=False, **kwargs)

Bases: trax.rl.training.ValueAgent

Trains a value model using DQN on the given RLTask.

Note that the algorithm and the parameters significantly diverge from the original DQN paper. In particular, we have separated learning and data collection.

The Bellman loss is computed in the value_loss method. The formula takes the state-action values tensors Q and n-step returns R:

\[L(s,a) = Q(s,a) - R(s,a)\]

where R is computed in value_batches_stream. In the simplest case of the 1-step returns we are getting

\[L(s,a) = Q(s,a) - r(s,a) - \gamma \max_{a'} Q'(s',a')\]

where s' is the state reached after taking action a in state s, Q' is the target network, gamma is the discount factor and the maximum is taken with respect to all actions available in the state s'. The tensor Q' is updated using the sync_freq parameter.

In code, the maximum is visible in the policy method, where we take sample = jnp.argmax(values). The epsilon-greedy policy takes a random move with probability epsilon; otherwise, in state s, it takes the action argmax_a Q(s,a).
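
The 1-step formula and the epsilon-greedy rule can be checked on plain Python numbers. This is a toy sketch with hypothetical function names; it works on lists of Q-values for a single state rather than on the batched tensors the agent uses, and it ignores terminal-state handling.

```python
import random


def one_step_target(reward, next_q_values, gamma):
    """r(s,a) + gamma * max_a' Q'(s',a'), with Q' from the target network."""
    return reward + gamma * max(next_q_values)


def bellman_error(q_value, reward, next_q_values, gamma):
    """L(s,a) = Q(s,a) - r(s,a) - gamma * max_a' Q'(s',a')."""
    return q_value - one_step_target(reward, next_q_values, gamma)


def epsilon_greedy(q_values, epsilon, rng=random):
    """Random action with probability epsilon, otherwise argmax_a Q(s,a)."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

With epsilon set to 0 the rule reduces to the greedy argmax policy; the resulting error is what a smooth L1 or L2 loss would then penalize.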

__init__(task, advantage_estimator=<function monte_carlo>, max_slice_length=1, smoothl1loss=True, double_dqn=False, **kwargs)

Configures the value trainer.

Parameters:
  • task – RLTask instance, which defines the environment to train on.
  • value_body – Trax layer, representing the body of the value model. Loss functions and eval functions (a.k.a. metrics) are considered to be outside the core model, taking core model output and data labels as their two inputs.
  • value_optimizer – the optimizer to use to train the value model.
  • value_lr_schedule – learning rate schedule to use to train the value model.
  • value_batch_size – batch size used to train the value model.
  • value_train_steps_per_epoch – how long to train the value model in each RL epoch.
  • value_evals_per_epoch – number of value trainer evaluations per RL epoch - only affects metric reporting.
  • value_eval_steps – number of value trainer steps per evaluation - only affects metric reporting.
  • exploration_rate – exploration rate schedule - used in the policy method.
  • n_eval_episodes – number of episodes to play with policy at temperature 0 in each epoch – used for evaluation only
  • only_eval – If set to True, then trajectories are collected only for evaluation purposes, but they are not recorded.
  • n_replay_epochs – Number of last epochs to take into the replay buffer; only makes sense for off-policy algorithms.
  • max_slice_length – the maximum length of trajectory slices to use; it is the second dimension of the value network output: (batch, max_slice_length, number of actions). A higher max_slice_length implies that the network has to predict more values into the future.
  • sync_freq – frequency when to synchronize the target network with the trained network. This is necessary for training the network on bootstrapped targets, e.g. using n-step returns.
  • scale_value_targets – If True, scale value function targets by 1 / (1 - gamma). This addresses the problem of very large returns in some games without introducing additional hyperparameters.
  • output_dir – Path telling where to save outputs (evals and checkpoints).
  • **kwargs – arguments passed on to the superclass.
value_loss

Value loss computed using smooth L1 loss or L2 loss.

value_batches_stream()

Use the RLTask self._task to create inputs to the value model.

policy(trajectory, temperature=1)

Chooses an action to play after a trajectory.

value_mean

The mean value of actions selected by the behavioral policy.