trax.data

inputs

Data sources and input processing.

Trax authors recommend constructing input pipelines using layer-like functions and combinators. For example, the following is an input pipeline for training sentiment analysis tasks on the IMDB dataset:

from trax import data

inputs = data.Serial(
  data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
  data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
  data.Shuffle(),
  data.FilterByLength(max_length=2048, length_keys=[0]),
  data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                      batch_sizes=[128,  32,   8,    2, 1],
                      length_keys=[0]),
  data.AddLossWeights()
)

Each of these functions creates a Python generator of tuples of data arrays. For example:

data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)

creates a generator of examples (tuples of NumPy ndarray objects) from the TFDS imdb_reviews dataset (see https://www.tensorflow.org/datasets/catalog/imdb_reviews).

As the catalog page shows, this dataset has 'text' and 'label' fields. By specifying keys=('text', 'label'), train=True, we create tuples containing the text and the label from the training split.

Other functions, like Tokenize and Shuffle, take a generator and output another generator, in this way converting tuples into other tuples or mixing the training stream. For example, Tokenize(..., keys=[0]) tokenizes the first element of a tuple – converting it from text to a NumPy integer array. And Shuffle randomizes the order of examples.

Note that all elements in the data pipeline are just functions on generators, so you can use Python’s map and filter and other native functions too. For example, you can create an input pipeline for a language model reading lines from my_file.txt as follows:

inputs = data.Serial(
  lambda _: open('my_file.txt'),
  lambda g: map(lambda line: line.strip(), g),
  data.Tokenize(vocab_file='en_8k.subword'),
  lambda g: filter(lambda x: x.shape[0] < 513, g),  # At most 512 tokens.
  data.Shuffle(),
  lambda g: map(lambda x: (x, x), g),  # Language models have inputs = targets.
  data.BucketByLength(boundaries=[  32, 64, 128, 256, 512],
                      batch_sizes=[ 32, 16,  8,    4,   2, 1]),
  data.AddLossWeights(id_to_mask=0)
)
trax.data.inputs.Serial(*fns)

Combines generator functions into one that runs them serially.
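To make serial composition concrete, here is a minimal pure-Python sketch of the idea. The helper name `serial` is hypothetical and this is not the trax source; it only assumes what the introduction describes: the first function creates the stream (its argument is ignored), and every later function maps a generator to a new generator.

```python
def serial(*fns):
    """Compose stream functions left to right (a sketch of the idea behind data.Serial)."""
    def composed(generator=None):
        for fn in fns:
            generator = fn(generator)  # each step wraps the previous stream
        return generator
    return composed


pipeline = serial(
    lambda _: iter(range(10)),                  # source: ignores its input
    lambda g: map(lambda x: x * x, g),          # square every example
    lambda g: filter(lambda x: x % 2 == 0, g),  # keep only even squares
)
print(list(pipeline()))  # [0, 4, 16, 36, 64]
```

Because the source function builds a fresh iterator on every call, calling `pipeline()` again restarts the stream.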

trax.data.inputs.Parallel(fns=None, counters=None, reweight_by_minimum=False, gradually_reweight=False, use_remainders=False)

Combines generator functions into one that runs them in parallel.

Parameters:
  • fns – a sequence of datasets which are combined in parallel.
  • counters – a sequence of ints with the same length as fns; see the notes on its use below.
  • reweight_by_minimum – if set to True, then we re-weight every counter by the minimal counter. E.g. counters (10000, 100000) are translated to (1, 10), and hence for every 10 examples from the second dataset we get 1 example from the first dataset. Without reweighting, we would first see 20,000 examples alternating between the two datasets (10,000 from each) and then 90,000 examples only from the second dataset.
  • gradually_reweight – if set to True, then we loop through the generators using a recursive rule defined in emit_examples. First we sort the generators by their counters. If we have datasets with counters 1, 20, 40 (after sorting), then we yield examples in the pattern (a(b c^2)^20)^*, where examples of type a come from the first dataset, of type b from the second and of type c from the third. The exponents are obtained by dividing each counter by the previous one.
  • use_remainders – if set to True together with gradually_reweight, and the counters are 1, 20, 45, then after dealing with the examples in the pattern (a(b c^2)^20), the generator yields the remaining 5 examples from the dataset with counter 45.
Returns:

the generator yields samples according to the given counters; if counters are not given, samples are generated uniformly.

Return type:

parallel_generator
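The gradually_reweight ordering can be reproduced with a small recursive generator. In this hypothetical sketch, letters stand in for the datasets (the real function would draw from each generator instead), `emit_labels` and `gradual_schedule` are illustrative names modeled on the emit_examples rule described above, and remainder handling (use_remainders) is left out.

```python
import itertools
import math


def emit_labels(sorted_counters, prev_counter):
    """Repeat the smallest remaining counter counter/prev times, nesting the rest."""
    if not sorted_counters:
        return
    (label, counter), rest = sorted_counters[0], sorted_counters[1:]
    for _ in range(math.floor(counter / prev_counter)):
        yield label
        yield from emit_labels(rest, counter)


def gradual_schedule(labels, counters):
    """Infinite stream of labels; one round for counters 1, 20, 40 is a(b c^2)^20."""
    pairs = sorted(zip(labels, counters), key=lambda p: p[1])
    while True:
        yield from emit_labels(pairs, min(counters))


round_one = ''.join(itertools.islice(gradual_schedule('abc', (1, 20, 40)), 61))
print(round_one == 'a' + 'bcc' * 20)  # True
```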

Example 1:

gen = data.Parallel([dataset1, dataset2, dataset3], counters=(2, 1, 3))

defines a generator that yields about 33% of its examples from dataset1, about 17% from dataset2 and 50% from dataset3.

Example 2:

gen = data.Parallel([dataset1, dataset2, dataset3], counters=(20, 50, 30))

defines a generator that yields 20% of its examples from dataset1, 50% from dataset2 and 30% from dataset3.
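One way to obtain these proportions without gradual reweighting is a round-robin in which dataset i may emit at most counters[i] examples per cycle, with all quotas reset once the cycle is complete. The sketch below assumes that scheme (it is an illustration, not the trax implementation) and uses letters in place of datasets to check the 2 : 1 : 3 split of Example 1.

```python
import itertools
from collections import Counter


def weighted_round_robin(labels, counters):
    """Round-robin where dataset i emits up to counters[i] items per cycle."""
    current = [0] * len(counters)
    while True:
        for i, label in enumerate(labels):
            if current[i] < counters[i]:
                current[i] += 1
                if sum(current) == sum(counters):  # cycle complete: reset quotas
                    current = [0] * len(counters)
                yield label


sample = list(itertools.islice(weighted_round_robin('xyz', (2, 1, 3)), 600))
print(''.join(sample[:6]))             # xyzxzz
print(sorted(Counter(sample).items()))  # [('x', 200), ('y', 100), ('z', 300)]
```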

trax.data.inputs.Shuffle(queue_size=1024)

Returns a shuffle function with the given queue size.

trax.data.inputs.Batch(batch_size)

Returns a batching function with given batch size.

trax.data.inputs.Dup()

Duplicates (copies) the top element (inputs).

The generator stream is augmented in the following way:

  • If the stream consists of a single element (inputs, ), the inputs simply get copied to (inputs, inputs).
  • If the stream consists of multiple elements, for example (inputs, weights), the rest of elements get moved toward the right side (inputs, inputs, weights).
Returns: the duplicating function.
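The stream transformation above amounts to copying the first element of every tuple. A minimal sketch with a hypothetical `dup_stream` helper; star-unpacking handles both single-element and longer tuples:

```python
def dup_stream(generator):
    """Yield (inputs, inputs, *rest) for every example tuple (inputs, *rest)."""
    for inputs, *rest in generator:
        yield (inputs, inputs, *rest)


examples = [('a',), ('b', 0.5)]
print(list(dup_stream(examples)))  # [('a', 'a'), ('b', 'b', 0.5)]
```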
trax.data.inputs.FilterEmptyExamples(axes=None, debug=False)

Filters empty examples.

Filters out any example that has an array of shape (0,) (if axes=None). Alternatively, checks only the indices listed in `axes`. Unlike FilterByLength used with several elements and length_axis, here an example is filtered out if ANY of the elements listed in `axes` contains an empty array.

Parameters:
  • axes – list of indices to check, if None, all of them.
  • debug – If True, emits a log every time an empty example is filtered out.
Returns:

Function filtering empty examples.

trax.data.inputs.FilterByLength(max_length, min_length=0, length_keys=None, length_axis=0)

Returns a function that filters out examples by length.

Parameters:
  • max_length – int. If not None, indicates maximum length.
  • min_length – int. If not None, indicates minimum length.
  • length_keys – (list) which example keys to take into account.
  • length_axis – which shape axis to take into account.
Returns:

a function that filters out examples by length.
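A sketch of the filtering rule. It assumes that when several length_keys are given, the example's length is the maximum of their sizes along length_axis; that aggregation, the helper name and its defaults are assumptions of this sketch rather than documented behavior.

```python
import numpy as np


def filter_by_length(generator, max_length, min_length=0,
                     length_keys=(0, 1), length_axis=0):
    """Keep examples whose length (assumed: max over length_keys) is in bounds."""
    for example in generator:
        length = max(example[k].shape[length_axis] for k in length_keys)
        if max_length is not None and length > max_length:
            continue
        if min_length is not None and length < min_length:
            continue
        yield example


stream = [(np.zeros(3), np.zeros(5)), (np.zeros(2), np.zeros(9))]
print([ex[1].shape[0] for ex in filter_by_length(stream, max_length=6)])  # [5]
```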

trax.data.inputs.TruncateToLength(len_map=None)

Returns a stream function that resizes items as specified by len_map.

Parameters: len_map – Dictionary that specifies maximum shapes for potentially multiple features per stream item. For example, given a stream of tokenized string pairs, one could enforce a maximum length of 256 tokens for each string by using len_map={0: (256,), 1: (256,)}.
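For the len_map example above, truncation amounts to slicing each listed feature to at most the given shape. A minimal NumPy sketch (the helper is hypothetical, not the trax source):

```python
import numpy as np


def truncate_to_length(generator, len_map):
    """Slice each feature listed in len_map down to at most the given shape."""
    for example in generator:
        example = list(example)
        for key, max_shape in len_map.items():
            # One slice per dimension, e.g. (slice(0, 256),) for a 1-D feature.
            example[key] = example[key][tuple(slice(0, n) for n in max_shape)]
        yield tuple(example)


pairs = [(np.arange(300), np.arange(10))]
out = next(truncate_to_length(pairs, {0: (256,), 1: (256,)}))
print(out[0].shape, out[1].shape)  # (256,) (10,)
```

Shorter features are left unchanged, since slicing past the end of an array is not an error in NumPy.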
trax.data.inputs.PadToLength(len_map=None, pad_value=0, multiple=False)

Pads the values to lengths given in `len_map`.

len_map contains a dictionary of example keys to dimension sizes.

Parameters:
  • len_map – dict of int to int; examples are padded to the lengths given by the values of the dict. If multiple is True, the dimensions are padded to a multiple of this value.
  • pad_value – dict of int to int. The value is passed as constant_values to numpy.pad for the given example key.
  • multiple – boolean. If False, pads to the value of len_map. If True, pads to the closest multiple of the len_map value.
Returns:

Function to pad examples to given lengths.
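The padding semantics can be sketched with numpy.pad. This hypothetical helper works on one example tuple and pads along axis 0 only; it reads "closest multiple" as the smallest multiple of the len_map value that is not below the current length, which is an interpretation rather than something stated above.

```python
import numpy as np


def pad_to_length(example, len_map, pad_value=None, multiple=False):
    """Pad axis 0 of selected features to len_map[key] (or a multiple of it)."""
    pad_value = pad_value or {}
    example = list(example)
    for key, target in len_map.items():
        length = example[key].shape[0]
        if multiple:
            target = -(-length // target) * target  # round up to a multiple
        pad_width = [(0, max(target - length, 0))] + [(0, 0)] * (example[key].ndim - 1)
        example[key] = np.pad(example[key], pad_width, mode='constant',
                              constant_values=pad_value.get(key, 0))
    return tuple(example)


padded = pad_to_length((np.array([1, 2, 3]), np.array([7])), {0: 4, 1: 4},
                       pad_value={1: -1}, multiple=True)
print(padded[0].tolist(), padded[1].tolist())  # [1, 2, 3, 0] [7, -1, -1, -1]
```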

trax.data.inputs.BucketByLength(boundaries, batch_sizes, length_keys=None, length_axis=0, strict_pad_on_len=False)

Returns a function for bucketing inputs, see bucket_by_length.

trax.data.inputs.MLM(vocab_size=None, max_length=None, noise_density=0.15, mean_noise_span_length=3.0)

Pipeline that just does MLM.

trax.data.inputs.PrefixLM(input_length=128, output_length=512)

Chunks examples so as to make inputs/outputs of specified lengths.

trax.data.inputs.ConcatenateToLMInput(pad_to_length=None)

Prepares the input needed for training of Language Models.

Each example needs to contain two elements (input and target). The input is concatenated with the target and, if pad_to_length is given, padded to the provided length. The loss weights cover only the target, excluding both the input and the padding.

Parameters: pad_to_length – int, the total length to which the concatenated input and target arrays are padded.
Returns: Function that produces the input for an LM.
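A sketch of the concatenation and loss-weight layout for a single (input, target) pair. The helper is hypothetical, and returning (concatenation, concatenation, loss_weights), so that inputs equal targets as in the language-model example at the top of this page, is an assumption about the output format.

```python
import numpy as np


def concatenate_to_lm_input(inputs, targets, pad_to_length=None):
    """Concatenate input and target; give weight 1 only to target positions."""
    concat = np.concatenate([inputs, targets])
    weights = np.concatenate([np.zeros_like(inputs), np.ones_like(targets)])
    if pad_to_length is not None:
        extra = max(pad_to_length - len(concat), 0)
        concat = np.pad(concat, (0, extra))    # pad with zeros
        weights = np.pad(weights, (0, extra))  # padding gets weight 0
    return concat, concat, weights


x, y, w = concatenate_to_lm_input(np.array([5, 6]), np.array([7, 8, 9]), pad_to_length=7)
print(x.tolist())  # [5, 6, 7, 8, 9, 0, 0]
print(w.tolist())  # [0, 0, 1, 1, 1, 0, 0]
```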
trax.data.inputs.CastTo(dtype=..., indices=(0, 1), debug=False)

Casts the given indices to the given dtype.

trax.data.inputs.AppendValue(val=None)

Appends the values provided in `val` to inputs.

val is keyed by example keys; its values are the tensors to append.

Parameters: val – dict of int to tensors. The tensors given as values are appended to the elements at the corresponding keys.
Returns: Function to append tensors to examples.
trax.data.inputs.AddLossWeights(id_to_mask=None)

Returns a function to add loss weights; see add_loss_weights.

trax.data.inputs.UnBatch()

Returns a function which unbatches.

trax.data.inputs.Prefetch(n_prefetch=2)

Pre-fetches a number of examples from generator in a separate process.

trax.data.inputs.UniformlySeek(name=None, host_id=None, n_hosts=None, dataset_size=None)

Sets each host at (dataset_size/n_hosts)-th of the dataset.

trax.data.inputs.CountAndSkip(name)

Returns a function that counts and skips examples (see count_and_skip).

trax.data.inputs.Log(n_steps_per_example=1, only_shapes=True)

Creates a logging component of the input pipeline.

trax.data.inputs.shuffle(samples, queue_size)

Shuffles a sample stream using a random-out next-in queue of given size.

Parameters:
  • samples – Stream of samples for eventual use as training data or eval data.
  • queue_size – Minimum number of samples within which the streamed shuffling takes place.
Yields:

Shuffled stream of samples, ready for further processing, e.g., grouping into batches.
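The random-out next-in queue can be sketched as follows. The `seed` argument is a hypothetical addition for reproducibility. The sketch fills the queue first, then repeatedly emits a random queue element and puts the next incoming sample in its slot, and finally drains the remaining queue in random order.

```python
import random


def shuffle(samples, queue_size, seed=None):
    """Streamed shuffle with a fixed-size random-out next-in queue."""
    rng = random.Random(seed)
    queue = []
    for sample in samples:
        if len(queue) < queue_size:
            queue.append(sample)  # warm-up: fill the queue first
            continue
        i = rng.randrange(queue_size)
        yield queue[i]            # random out ...
        queue[i] = sample         # ... next in
    rng.shuffle(queue)            # input exhausted: drain in random order
    yield from queue


out = list(shuffle(range(100), queue_size=10, seed=0))
print(sorted(out) == list(range(100)))  # True: a permutation of the input
```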

trax.data.inputs.batch(generator, batch_size)

Batches and pads examples from a generator, as in tf.data.Dataset.padded_batch.

trax.data.inputs.pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False)

Pad a tuple of tensors to a joint dimension and return their batch.

For example, a pair of tensors of shape (2, 10) and (3, 9) will be padded to (3, 10) both and the returned tensor will have shape (2, 3, 10).

When boundary is specified, we try to pad all unknown dimensions to boundary if possible, which can help reduce the number of different shapes occurring in the tensors and speed up XLA compilation. So, for example, a pair of tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12).

One special case occurs when boundary is much higher than the padding length that we’d use without boundary. For example, tensors (2, 10) and (3, 9) with boundary=12 could end up padded to (12, 12), but this is very wasteful in the first dimension. In that case, we use the closest power of 2 instead of the boundary, so we end up padding to (4, 12) instead of (12, 12).

Parameters:
  • tensors – a tuple or list of tensors to pad
  • boundary – int or None; if given, expand the padded dimensions to this size
  • strict_pad_on_len – bool; if True, we pad the length dimension, dim[0], strictly to a multiple of boundary.
Returns:

a tensor, the tensors padded together
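The padding rules above can be sketched with NumPy. The helper is hypothetical: it ignores strict_pad_on_len, assumes all tensors have the same rank, and uses "2 * max_len < boundary" as its test for "boundary is much higher", which is an assumption. It reproduces the three shape examples given in the text.

```python
import math

import numpy as np


def pad_to_max_dims(tensors, boundary=None):
    """Pad same-rank arrays to a joint shape and stack them into one batch."""
    target = []
    for d in range(tensors[0].ndim):
        sizes = [t.shape[d] for t in tensors]
        max_len = max(sizes)
        if boundary is None or min(sizes) == max_len:
            target.append(max_len)  # sizes already agree (or no boundary given)
        elif 2 * max_len < boundary:
            # Boundary would be wasteful here: use the next power of 2 instead.
            target.append(2 ** math.ceil(math.log2(max(max_len, 1))))
        else:
            target.append(max(max_len, boundary))
    padded = [np.pad(t, [(0, n - s) for n, s in zip(target, t.shape)]) for t in tensors]
    return np.stack(padded)


a, b = np.ones((2, 10)), np.ones((3, 9))
print(pad_to_max_dims([a, b]).shape)               # (2, 3, 10)
print(pad_to_max_dims([a, b], boundary=12).shape)  # (2, 4, 12)
print(pad_to_max_dims([np.ones((8, 10)), np.ones((8, 9))], boundary=12).shape)  # (2, 8, 12)
```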

trax.data.inputs.bucket_by_length(generator, length_fn, boundaries, batch_sizes, strict_pad_on_len=False)

Bucket by length, like tf.data.experimental.bucket_by_sequence_length.

This function draws examples from the provided generator and puts each example into a bucket depending on l = length_fn(example). The bucket is chosen according to which pair of boundaries l falls between. When a bucket reaches its batch size, as specified by batch_sizes, a batch of padded examples is generated from that bucket.

Parameters:
  • generator – python generator to draw data from.
  • length_fn – a function taking the example and returning the length.
  • boundaries – a list of bucket boundaries.
  • batch_sizes – a list of batch sizes.
  • strict_pad_on_len – bool; if True, we pad the length dimension, dim[0], strictly to a multiple of boundary.
Yields:

An input batch, which comes from one of the buckets.
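A simplified sketch of the bucketing loop. For brevity each example here is a single 1-D array rather than a tuple, a length equal to a boundary goes into the lower bucket, and examples left in partially filled buckets are never emitted; these choices are assumptions of the sketch, and the helper is hypothetical.

```python
import numpy as np


def bucket_by_length(generator, length_fn, boundaries, batch_sizes):
    """Collect examples per length bucket; emit a padded batch when one fills."""
    assert len(batch_sizes) == len(boundaries) + 1  # last bucket: above all boundaries
    buckets = [[] for _ in batch_sizes]
    for example in generator:
        length = length_fn(example)
        idx = next((i for i, b in enumerate(boundaries) if length <= b), len(boundaries))
        buckets[idx].append(example)
        if len(buckets[idx]) == batch_sizes[idx]:
            max_len = max(len(x) for x in buckets[idx])
            yield np.stack([np.pad(x, (0, max_len - len(x))) for x in buckets[idx]])
            buckets[idx] = []


stream = [np.ones(n) for n in (3, 10, 4, 12, 5, 2)]
for batch in bucket_by_length(iter(stream), len, boundaries=[8], batch_sizes=[2, 2]):
    print(batch.shape)  # (2, 4), then (2, 12), then (2, 5)
```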

trax.data.inputs.add_loss_weights(generator, id_to_mask=None)

Adds weights to inputs that lack them, and masks by id if requested.

The generator stream is augmented in the following way:

  • If the stream consists of pairs (inputs, targets), a loss mask is added, created as a tensor of ones of the same shape as targets.
  • If id_to_mask is not None, and the stream (after the previous point) has triples (inputs, targets, weights), the weights are multiplied by a 0/1 mask that is 0 iff targets is equal to id_to_mask (1 otherwise).
Parameters:
  • generator – Stream of tuples.
  • id_to_mask – If not None, int-valued id that represents padding, as opposed to true target IDs.
Yields:

Examples from the augmented stream.
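The two augmentation steps can be sketched as follows. The helper is hypothetical, and using float32 for the weights is an assumption of the sketch.

```python
import numpy as np


def add_loss_weights(generator, id_to_mask=None):
    """Add all-ones weights to (inputs, targets); zero them where targets == id_to_mask."""
    for example in generator:
        if len(example) == 2:
            inputs, targets = example
            example = (inputs, targets, np.ones_like(targets, dtype=np.float32))
        if id_to_mask is not None and len(example) == 3:
            inputs, targets, weights = example
            example = (inputs, targets, weights * (targets != id_to_mask))
        yield example


_, _, w = next(add_loss_weights([(np.array([4, 5]), np.array([7, 8, 0, 0]))], id_to_mask=0))
print(w.tolist())  # [1.0, 1.0, 0.0, 0.0]
```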

trax.data.inputs.generate_random_noise_mask(noise_density=0.15, mean_noise_span_length=3.0, seed1=None, seed2=None)

Returns a function that generates a random noise mask.

trax.data.inputs.consume_noise_mask(vocab_size=32100)

Consumes (tokens, noise mask) and returns (inputs, targets).

trax.data.inputs.generate_sequential_chunks(max_length=None)

Returns a function that generates chunks of at most max_length length.

trax.data.inputs.addition_input_stream(vocab_size=..., batch_size=..., min_length=..., max_length=..., pad_to_multiple=32, encdec=False)

Data stream for the add problem: <S>x+y<S>(x+y).

Parameters:
  • vocab_size – how many symbols to use.
  • batch_size – how large are the batches.
  • min_length – minimal length of w.
  • max_length – maximal length of w.
  • pad_to_multiple – int, pad length to be multiple of this number.
  • encdec – bool, if True return encoder-decoder style inputs (default: False)
Returns:

python generator of tuples of data examples

trax.data.inputs.random_spans_noise_mask(length, noise_density=0.15, mean_noise_span_length=3.0, seed1=None, seed2=None, example=None)

Computes span corruption masks given input parameters.

trax.data.inputs.lower_endian_to_number(l, base)

Helper function: convert a list of digits in the given base to a number.

trax.data.inputs.number_to_lower_endian(n, base)

Helper function: convert a number to a list of digits in the given base.

trax.data.inputs.random_number_lower_endian(length, base)

Helper function: generate a random number as a lower-endian digits list.
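Sketches of the three helpers, under the reading that "lower-endian" means least-significant digit first (so 123 in base 10 is [3, 2, 1]). Keeping the last digit of a random number non-zero, so that it has exactly `length` digits, is an assumption of this sketch.

```python
import random


def lower_endian_to_number(digits, base):
    """[3, 2, 1] in base 10 -> 123 (least-significant digit first)."""
    return sum(d * base ** i for i, d in enumerate(digits))


def number_to_lower_endian(n, base):
    """123 in base 10 -> [3, 2, 1]."""
    if n < base:
        return [n]
    return [n % base] + number_to_lower_endian(n // base, base)


def random_number_lower_endian(length, base):
    """Random digits list; the most significant (last) digit is non-zero."""
    if length == 1:  # a single digit may be 0
        return [random.randint(0, base - 1)]
    prefix = [random.randint(0, base - 1) for _ in range(length - 1)]
    return prefix + [random.randint(1, base - 1)]


print(number_to_lower_endian(123, 10))        # [3, 2, 1]
print(lower_endian_to_number([3, 2, 1], 10))  # 123
```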

trax.data.inputs.count_and_skip(generator, name)

Count the number of items in the generator, skip already counted ones.

This function counts the number of processed examples and puts it into the global variable counters. This variable can be saved and restored, and if restored, this function will skip examples until the restored counter is reached. When the data generator is deterministic, this allows restoring the data-reading process from a checkpoint.

Parameters:
  • generator – generator for examples in the dataset.
  • name – string, a unique id that we use to count the examples
Yields:

The examples from generator but first skip the number specified in the global variable counters[name] and next increment this variable every time a new example appears.
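A minimal sketch of the count-and-skip logic. The module-level dictionary `counters` mirrors the global variable named above, but this standalone version and its names are illustrative only, not the trax source.

```python
counters = {}  # hypothetical stand-in for the global counters dictionary


def count_and_skip(generator, name):
    """Skip examples already counted under `name`, then count every new one."""
    local_count = 0
    for example in generator:
        local_count += 1
        counters.setdefault(name, 0)
        if local_count > counters[name]:  # not consumed before the restart
            counters[name] += 1
            yield example


counters['train'] = 3  # e.g. restored from a checkpoint
print(list(count_and_skip(iter(range(6)), 'train')))  # [3, 4, 5]
print(counters['train'])                               # 6
```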

trax.data.inputs.save_data_counters(output_dir, host_id=None)

Checkpoint data counters.

trax.data.inputs.load_data_counters(output_dir, host_id=None)

Load data counters from a checkpoint.

class trax.data.inputs.Inputs(train_stream, eval_stream=None, train_eval_stream=None)

Bases: object

Inputs bundle.

Inputs bundle holds input streams and shapes for a training run. It contains stream-creating functions that return python generators of (input_batch, target_batch) tuples.

  • train_stream: training data that will be used for training;
    it may include all the augmentation or selection that training needs. The shape of examples is [batch_fn.batch_size, …]
  • train_eval_stream: training data used for evaluation;
    examples from training data, but usually without augmentation. The shape of examples is [batch_fn.eval_batch_size, …]
  • eval_stream: evaluation data stream;
    examples from evaluation data, usually without augmentation. The shape of examples is [batch_fn.eval_batch_size, …]
  • input_shape: the shape of inputs
    the […] above, without batch size
  • input_dtype: the data type of inputs
  • target_shape: the shape of targets
    the […] above, without batch size
  • target_dtype: the data type of targets
__init__(train_stream, eval_stream=None, train_eval_stream=None)

Initialize a new set of inputs.

Parameters:
  • train_stream – a function taking n_devices (an int) and returning a python generator of training batches.
  • eval_stream – a function taking n_devices (an int) and returning a python generator of validation batches; if None, then the training generator will be used for evaluation.
  • train_eval_stream – a function taking n_devices (an int) and returning a python generator of batches from the training set used for evaluation (if None, use train_stream).
train_stream(n_devices)
eval_stream(n_devices)
train_eval_stream(n_devices)
input_shape

Example input shape, without batch dimension.

target_shape

Example target shape, without batch dimension.

input_dtype

Dtype of the input.

target_dtype

Dtype of the target.

example_shape_dtype

Shape and Dtype of an example batch.

trax.data.inputs.make_inputs(train_stream=..., eval_stream=None)

Create Inputs from two streams; mostly for use in gin configs.

trax.data.inputs.make_additional_stream(stream=...)

Create a stream mostly for use in gin configs for additional tasks.

trax.data.inputs.make_parallel_stream(streams=..., counters=None)

Create a parallel stream for use in gin configs for additional tasks.

trax.data.inputs.batcher(data_streams=..., variable_shapes=True, batch_size_per_device=32, batch_size=None, eval_batch_size=32, bucket_length=32, buckets=None, buckets_include_inputs_in_length=False, batch_shuffle_size=None, max_eval_length=None, id_to_mask=None, strict_pad_on_len=False)

Batcher: create trax Inputs from single-example data-streams.

trax.data.inputs.batch_fn(dataset, training, n_devices, variable_shapes, batch_size_per_device=32, batch_size=None, eval_batch_size=32, bucket_length=32, buckets=None, buckets_include_inputs_in_length=False, batch_shuffle_size=None, max_eval_length=None, id_to_mask=None, strict_pad_on_len=False)

Batching function.

trax.data.inputs.random_inputs(input_shape=..., input_dtype=..., input_range=(0, 255), output_shape=..., output_dtype=..., output_range=(0, 9))

Make random Inputs for debugging.

Parameters:
  • input_shape – the shape of inputs (including batch dimension).
  • input_dtype – the type of the inputs (int32 by default).
  • input_range – the range of inputs (defaults to (0, 255)).
  • output_shape – the shape of outputs (including batch dimension).
  • output_dtype – the type of the outputs (int32 by default).
  • output_range – the range of outputs (defaults to (0, 9)).
Returns:

trax.data.inputs.Inputs

trax.data.inputs.sequence_copy_inputs(vocab_size=..., batch_size=..., train_length=..., eval_min_length=..., eval_max_length=..., reverse=False, pad_to_multiple=32)

Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*.

Parameters:
  • vocab_size – how many symbols to use.
  • batch_size – how large are the batches.
  • train_length – maximum length of w for training.
  • eval_min_length – minimum length of w for eval.
  • eval_max_length – maximum length of w for eval.
  • reverse – bool (optional, false by default): reverse the second sequence.
  • pad_to_multiple – int, pad length to be multiple of this number.
Returns:

trax.data.inputs.Inputs

trax.data.inputs.simple_sequence_copy_inputs(vocab_size=..., batch_size=..., train_length=..., eval_min_length=..., eval_max_length=..., pad_to_multiple=32)

Inputs for the sequence copy problem: w for w in [1..vocab_size-1]*.

Parameters:
  • vocab_size – how many symbols to use.
  • batch_size – how large are the batches.
  • train_length – maximum length of w for training.
  • eval_min_length – minimum length of w for eval.
  • eval_max_length – maximum length of w for eval.
  • pad_to_multiple – int, pad length to be multiple of this number.
Returns:

trax.data.inputs.Inputs

trax.data.inputs.addition_inputs(vocab_size=..., batch_size=..., train_length=..., eval_min_length=..., eval_max_length=..., pad_to_multiple=32, encdec=False)

Inputs for the add problem: <S>x+y<S>(x+y).

Parameters:
  • vocab_size – how many symbols to use.
  • batch_size – how large are the batches.
  • train_length – maximal length of w for training.
  • eval_min_length – minimal length of w for eval.
  • eval_max_length – maximal length of w for eval.
  • pad_to_multiple – int, pad length to be multiple of this number.
  • encdec – bool, if True return encoder-decoder style inputs (default: False)
Returns:

trax.data.inputs.Inputs

trax.data.inputs.sine_inputs(batch_size=..., length=..., max_phase=6.283185307179586, min_period=0.1, max_period=10.0)

Sinusoids of random period and phase.

Parameters:
  • batch_size (int) – Number of examples in a batch.
  • length (int) – Length of each sequence.
  • max_phase (float) – Maximum phase of the sinusoids.
  • min_period (float) – Minimum period of the sinusoids.
  • max_period (float) – Maximum period of the sinusoids.
Returns:

trax.data.inputs.Inputs

tf_inputs

TensorFlow data sources and associated preprocessing functions.

trax.data.tf_inputs.t5_data()

Get the T5 data module if available.

trax.data.tf_inputs.no_preprocess(dataset, training)
trax.data.tf_inputs.t2t_problems()
trax.data.tf_inputs.data_streams(dataset_name, data_dir=None, preprocess_fn=<function no_preprocess>, bare_preprocess_fn=None, shuffle_buffer_size=1024, eval_holdout_size=0, input_name=None, target_name=None)

Creates (train, eval) data sources from dataset_name.

Parameters:
  • dataset_name – Name of dataset belonging to TFDS or T2T. T2T dataset names must start with 't2t_'.
  • data_dir – Directory where the data is located.
  • preprocess_fn – Function to use for pre-processing after appending targets to inputs.
  • bare_preprocess_fn – Function to use for pre-processing before appending targets to inputs.
  • shuffle_buffer_size – Size of the shuffle buffer.
  • eval_holdout_size – If greater than 0, specifies a fraction of training data to siphon off and use as eval data, in place of a separate eval split.
  • input_name – Name of the inputs from the dictionary.
  • target_name – Name of the outputs either from the dictionary or as a result of post-processing.
Returns:

A pair of functions, (f, g) for use as data sources; call f() to get an iterator of training data samples, and call g() to get an iterator of eval data samples.

trax.data.tf_inputs.dataset_to_stream(dataset, input_name)

Takes a tf.Dataset and creates a numpy stream of ready batches.

trax.data.tf_inputs.TFDS(dataset_name, data_dir=None, tfds_preprocess_fn=None, keys=None, train=True, use_alt_eval=False, shuffle_train=True, host_id=None, n_hosts=None, eval_holdout_size=0)

Creates a data source from TensorFlow dataset dataset_name.

Parameters:
  • dataset_name – Name of the dataset, as registered in TensorFlow datasets (e.g., 'glue/mnli').
  • data_dir – Directory where the data is located.
  • tfds_preprocess_fn – If specified, function that applies to items in raw dataset (before selecting specific features).
  • keys – Tuple of dataset-specific strings that select features from the dataset.
  • train – If True, select the training split from the dataset; else select an eval split.
  • use_alt_eval – If True, and if train is False, select the dataset’s alternate eval split if it has one (or fall back to the dataset’s only eval split). This currently affects only the glue/mnli dataset.
  • shuffle_train – If True, have TensorFlow pre-shuffle the training data; else receive training data in deterministic sequence.
  • host_id – Integer id used for tracking data subsplits, in cases where n_hosts > 1.
  • n_hosts – If greater than 1, prepare data subsplits for the given number of hosts.
  • eval_holdout_size – If greater than 0, specifies a fraction of training data to siphon off and use as eval data, in place of a separate eval split.
Returns:

A function f for use as a training or eval data source; call f() to get an iterator of data samples.

trax.data.tf_inputs.tokenize(stream, keys=None, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)

Tokenize examples from the stream.

This function assumes that stream generates either strings or tuples/dicts containing strings at some keys. This function maps these strings to numpy arrays of integers – the tokenized version of each string.

Parameters:
  • stream – A python generator yielding strings, tuples or dicts.
  • keys – which keys of the tuple/dict to tokenize (by default: all)
  • vocab_type – Type of vocabulary, one of: ‘subword’, ‘sentencepiece’, ‘char’.
  • vocab_file – Name of the vocabulary file.
  • vocab_dir – Directory which contains the vocabulary file.
  • n_reserved_ids – An int, offset added so 0, …, n_reserved_ids-1 are unused; this is common, for example, when reserving 0 for padding and 1 for EOS, but it’s only needed if these symbols are not already included (and thus reserved) in the vocab_file.
Yields:

Examples from stream with strings at keys replaced by np.arrays of integers – the tokenized version of these strings.

trax.data.tf_inputs.Tokenize(keys=None, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)

Returns a function that maps text to integer arrays; see tokenize.

trax.data.tf_inputs.detokenize(x, vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)

Maps integer arrays to text; the opposite of tokenize.

In many cases (all char- and subword-type vocabularies and most sentencepiece ones) the tokenization is invertible, so detokenize(tokenize(x)) = x. In rarer cases, detokenization can lose some spacing, but it is still often useful to run detokenize to get a readable version of a tokenized string.

Parameters:
  • x – a list or numpy array of integers.
  • vocab_type – Type of vocabulary, one of: ‘subword’, ‘sentencepiece’, ‘char’.
  • vocab_file – Name of the vocabulary file.
  • vocab_dir – Directory which contains the vocabulary file.
  • n_reserved_ids – An int, offset added so 0, …, n_reserved_ids-1 are unused; this is common, for example, when reserving 0 for padding and 1 for EOS, but it’s only needed if these symbols are not already included (and thus reserved) in the vocab_file.
Returns:

A string corresponding to the de-tokenized version of x.

trax.data.tf_inputs.ConvertToUnicode(keys=None)

Converts elements of an example to Unicode (UTF-8).

Useful when TFDS outputs byte arrays. All conversion errors are ignored.

Parameters: keys – tuple/list of example dimensions to convert.
Returns: Function converting the chosen elements of an example to UTF-8.
trax.data.tf_inputs.vocab_size(vocab_type='subword', vocab_file=None, vocab_dir=None, n_reserved_ids=0)

Returns the size of the vocabulary (number of symbols used).

This function can be used to set the size of the final layers of a model that needs to predict symbols from a given vocabulary. More precisely, if this function returns N then the last layer size should be set to at least N (it can be more). Note that this function does take reserved IDs into account.

Parameters:
  • vocab_type – Type of vocabulary, one of: ‘subword’, ‘sentencepiece’, ‘char’.
  • vocab_file – Name of the vocabulary file.
  • vocab_dir – Directory which contains the vocabulary file.
  • n_reserved_ids – An int, offset added so 0, …, n_reserved_ids-1 are unused.
Returns:

An integer, the number of symbols used (including reserved IDs).

trax.data.tf_inputs.cifar10_no_augmentation_preprocess(dataset, training)
trax.data.tf_inputs.cifar10_augmentation_preprocess(dataset, training)

Preprocessing for cifar10 with augmentation.

trax.data.tf_inputs.cifar10_augmentation_flatten_preprocess(dataset, training, predict_image_train_weight=0.01)

Preprocessing for cifar10 that flattens it and appends targets.

trax.data.tf_inputs.downsampled_imagenet_flatten_bare_preprocess(dataset, training)

Preprocessing for downsampled_imagenet.

Parameters:
  • dataset – the dataset.
  • training – unused option.
Returns:

Flattened dataset.

Preprocessing for downsampled_imagenet 32x32 and 64x64 generation from http://arxiv.org/abs/1601.06759 (page 8).

trax.data.tf_inputs.concat_preprocess(dataset, training, pad_symbol=0)

Pre-processing function that concatenates input and target for LM.

trax.data.tf_inputs.squeeze_targets_preprocess(dataset, training)

Pre-processing function that squeezes last axis of targets.

trax.data.tf_inputs.lm1b_preprocess(dataset, training, max_target_length=-1, max_eval_target_length=-1)

Preprocessing for LM1B: filter out targets exceeding maximum length.

trax.data.tf_inputs.wmt_preprocess(dataset, training, max_length=-1, max_eval_length=-1)

Preprocessing for WMT: filter out targets exceeding maximum length.

trax.data.tf_inputs.wmt_concat_preprocess(dataset, training, max_length=-1, max_eval_length=-1)

Preprocessing for WMT: filter exceeding maximum length and concatenate.

trax.data.tf_inputs.lm_token_preprocessing(dataset, training)

Concatenates inputs, 0, targets, with masking only for targets.

trax.data.tf_inputs.bair_robot_pushing_hparams(hparams=None, video_num_input_frames=1, video_num_target_frames=15)
trax.data.tf_inputs.bair_robot_pushing_preprocess(dataset, training)

Pre-processing function that concatenates input and target frames.

trax.data.tf_inputs.sentencepiece_tokenize(stream, spm_path=None, extra_ids=0)

Sentencepiece tokenization.

trax.data.tf_inputs.SentencePieceTokenize(spm_path=None, extra_ids=0)

Returns a function that maps text to integer arrays.

trax.data.tf_inputs.c4_preprocess(dataset, training, max_target_length=-1, tokenization=None, spm_path=None)

Pre-processing function for C4 dataset.

trax.data.tf_inputs.c4_bare_preprocess_fn(dataset, training=True, spm_path=None, copy_pretokenized=True, sequence_length=None)

Returns a dataset that contains ‘inputs’ and ‘targets’ from C4.

trax.data.tf_inputs.filter_dataset_on_len(dataset, training, len_map=None, filter_on_eval=False)

Filters a dataset by the feature lengths given in len_map.

Parameters:
  • dataset – tf.data.Dataset, the dataset to filter.
  • training – bool, True if we are in training mode.
  • len_map – optional dict of str to (int, int). We filter out examples where a feature’s size is beyond the specified bounds. E.g. {'inputs': (1, 512), 'targets': (64, 128)} will keep only those examples where 1 <= len(inputs) <= 512 and 64 <= len(targets) <= 128.
  • filter_on_eval – bool; if True, we also filter in eval mode.
Returns:

a filtered tf.data.Dataset.

trax.data.tf_inputs.truncate_dataset_on_len(dataset, training, len_map=None, truncate_on_eval=False)

Truncates features in an example to lengths given in len_map.

Parameters:
  • dataset – tf.data.Dataset, the dataset to truncate.
  • training – bool, true if we are in training mode.
  • len_map – optional dict of str to int. We truncate examples where a feature’s size is beyond the maximum. For example, {‘inputs’: 512, ‘targets’: 64} truncates examples to be within those bounds.
  • truncate_on_eval – bool; if True, we also truncate in eval mode.
Returns:

a truncated tf.data.Dataset.
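A plain-Python sketch of the truncation rule, with dicts of lists standing in for tf.data examples; truncate_to_len is a hypothetical name.

```python
def truncate_to_len(example, len_map):
  """Hypothetical helper: cuts each feature in len_map to at most its max length."""
  truncated = dict(example)  # shallow copy; the input example is left unchanged
  for key, max_len in len_map.items():
    truncated[key] = truncated[key][:max_len]
  return truncated


example = {'inputs': list(range(600)), 'targets': list(range(10))}
short = truncate_to_len(example, {'inputs': 512, 'targets': 64})
print(len(short['inputs']), len(short['targets']))  # 512 10
```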

trax.data.tf_inputs.pad_dataset_to_length(dataset, training, len_map=None)

Pads features shorter than the specified length up to that length.
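A plain-Python sketch of the padding rule. It assumes right padding with 0 and uses lists in place of tensors; pad_to_len and its pad_value parameter are hypothetical.

```python
def pad_to_len(example, len_map, pad_value=0):
  """Hypothetical helper: right-pads each feature in len_map up to its length."""
  padded = dict(example)
  for key, target_len in len_map.items():
    missing = target_len - len(padded[key])
    if missing > 0:  # features that are already long enough are left as they are
      padded[key] = list(padded[key]) + [pad_value] * missing
  return padded


print(pad_to_len({'targets': [5, 6]}, {'targets': 4}))  # {'targets': [5, 6, 0, 0]}
```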

trax.data.tf_inputs.add_eos_to_output_features(dataset, training, output_features='targets', eos=1)

Adds EOS to all features in output_features.
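A plain-Python sketch, assuming EOS is appended at the end of each listed feature and that output_features may be a single name or a list of names; add_eos is a hypothetical name.

```python
def add_eos(example, output_features='targets', eos=1):
  """Hypothetical helper: appends the EOS id to the end of each output feature."""
  if isinstance(output_features, str):
    output_features = [output_features]
  out = dict(example)
  for key in output_features:
    out[key] = list(out[key]) + [eos]
  return out


print(add_eos({'inputs': [3, 4], 'targets': [7, 8]}))
# {'inputs': [3, 4], 'targets': [7, 8, 1]}
```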

trax.data.tf_inputs.generic_text_dataset_preprocess_fn(dataset, training=True, text_preprocess_fns=None, token_preprocess_fns=None, spm_path=None, copy_pretokenized=False, debug_print_examples=False, debug_print_examples_rate=0.01)

Pre-processes, tokenizes and post-processes a tf.data.Dataset.

Parameters:
  • dataset – tf.data.Dataset to process.
  • training – boolean, set to True if training, False otherwise.
  • text_preprocess_fns – None or list of callables (tf.data.Dataset, bool) -> tf.data.Dataset; these operate before tokenization. Typically used to select which fields we want to learn over or to change something into “text to text” form.
  • token_preprocess_fns – None or list of callables (tf.data.Dataset, bool) -> tf.data.Dataset; these operate after tokenization. Since they can view the tokenized fields, they can be used to filter on length, etc.
  • spm_path – None or str, path to a SentencePiece model to use for tokenization; by default, the 32k vocabulary from T5 is used.
  • copy_pretokenized – bool, if True retains the original fields after tokenization.
  • debug_print_examples – bool, if True this prints examples to the logging stream for inspection, both before and after tokenization.
  • debug_print_examples_rate – float in [0, 1.0]; on average, this fraction of dataset examples is printed in each phase, i.e., before and after tokenization.
Returns:

a tf.data.Dataset with all the preprocessing and tokenization performed.
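The callable contract for text_preprocess_fns and token_preprocess_fns can be sketched with a plain list of dicts standing in for tf.data.Dataset. Both apply_preprocess_fns and to_text_to_text are hypothetical names used only for illustration.

```python
def apply_preprocess_fns(dataset, training, fns=None):
  """Applies each (dataset, training) -> dataset callable in turn."""
  for fn in fns or []:
    dataset = fn(dataset, training)
  return dataset


def to_text_to_text(dataset, training):
  """Hypothetical text_preprocess_fn: maps raw fields to 'inputs'/'targets'."""
  del training  # unused by this step
  return [{'inputs': 'question: ' + ex['question'], 'targets': ex['answer']}
          for ex in dataset]


raw = [{'question': 'Is the sky blue?', 'answer': 'yes'}]
print(apply_preprocess_fns(raw, True, [to_text_to_text]))
# [{'inputs': 'question: Is the sky blue?', 'targets': 'yes'}]
```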

trax.data.tf_inputs.get_t5_preprocessor_by_name(name=None, fn_kwargs=None)

Returns a closure of any T5 preprocessor function with its arguments.

The main use-case is to use this (with gin scopes) to make any preprocessor function available in a gin file to configure and use.

See: TFInputs.test_gin_configurable_preprocessors

Parameters:
  • name – str, name of the preprocessor function to configure.
  • fn_kwargs – optional dictionary, the arguments to configure, these will be partially applied to the function given by name.
Returns:

a closure of the preprocessor function with its arguments bound. The closure takes exactly two arguments, dataset and the boolean training; it ignores training and calls the T5 preprocessor with the dataset and the closed-over arguments only.
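The closure described above can be sketched with functools.partial. The real function looks the preprocessor up by name among the T5 preprocessors; here the function is passed in directly, and make_preprocessor and take_first are hypothetical stand-ins.

```python
import functools


def make_preprocessor(fn, fn_kwargs=None):
  """Hypothetical helper: binds fn_kwargs and adapts fn(dataset) to (dataset, training)."""
  if fn_kwargs is not None:
    fn = functools.partial(fn, **fn_kwargs)
  return lambda dataset, unused_training: fn(dataset)  # training is ignored


def take_first(dataset, n=1):
  """Stand-in for a T5 preprocessor that only accepts the dataset."""
  return dataset[:n]


preprocess = make_preprocessor(take_first, {'n': 2})
print(preprocess([10, 20, 30], True))  # [10, 20]
```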

trax.data.tf_inputs.download_and_prepare(dataset_name, data_dir)

Downloads and prepares a T2T or TFDS dataset.

Parameters:
  • dataset_name – TFDS dataset name, or T2T problem name prefixed by ‘t2t_’.
  • data_dir – location of existing dataset or None.
Returns:

path string of downloaded data.

Return type:

data_dir

trax.data.tf_inputs.BertSingleSentenceInputs(batch, labeled=True, cls_id=101, sep_id=102)

Prepares inputs for BERT: adds [CLS] and [SEP] tokens and creates segment embeddings.

trax.data.tf_inputs.BertDoubleSentenceInputs(batch, labeled=True, cls_id=101, sep_id=102)

Prepares inputs for BERT models by adding [SEP] and [CLS] tokens and creating segment embeddings.
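The standard BERT packing of a sentence pair, [CLS] sent1 [SEP] sent2 [SEP], can be sketched in NumPy using the default cls_id=101 and sep_id=102. The helper bert_sentence_pair is hypothetical and returns only tokens and segment ids; the Trax functions may yield additional fields per row.

```python
import numpy as np


def bert_sentence_pair(sent1, sent2, cls_id=101, sep_id=102):
  """Hypothetical helper: [CLS] sent1 [SEP] sent2 [SEP] plus segment ids."""
  sent1 = np.asarray(sent1, dtype=np.int32)
  sent2 = np.asarray(sent2, dtype=np.int32)
  tokens = np.concatenate(
      ([cls_id], sent1, [sep_id], sent2, [sep_id])).astype(np.int32)
  segments = np.zeros(len(tokens), dtype=np.int32)
  # [CLS], sentence 1 and the first [SEP] are segment 0; the rest is segment 1.
  segments[len(sent1) + 2:] = 1
  return tokens, segments


tokens, segments = bert_sentence_pair([7, 8], [9])
print(tokens.tolist())    # [101, 7, 8, 102, 9, 102]
print(segments.tolist())  # [0, 0, 0, 0, 1, 1]
```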

trax.data.tf_inputs.CreateBertInputs(double_sentence=True, labeled=True, cls_id=101, sep_id=102)
trax.data.tf_inputs.mask_random_tokens(batch, explicit_vocab_size=30522, masking_prob=0.15, cls_id=101, sep_id=102, mask_id=103, vocab_start_id=999)

Prepares input for the masking task.

Preparation consists of masking a masking_prob fraction of the non-special tokens in each input row: round(masking_prob * num_nonspecial_tokens) random tokens are selected, and each selected token is
  • replaced with the [MASK] token with 80% probability,
  • replaced with a random token with 10% probability,
  • or left unchanged with 10% probability.
The implementation is based on https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L342

Examples:
  • batch is a stream with each row having the tuple (token_ids,). The function yields rows of the form (modified_token_ids, original_tokens, token_weights), where modified_token_ids contain [MASK] tokens or random tokens according to the procedure described above.
  • batch is a stream with each row having the tuple (token_ids, segment_embeddings, nsp_label, nsp_weight). The function yields rows of the form (modified_token_ids, segment_embeddings, nsp_label, nsp_weight, original_tokens, token_weights).

Parameters:
  • batch – stream of inputs. Each row in the stream is a tuple whose first element is an array of tokens.
  • explicit_vocab_size – the total size of the vocabulary.
  • masking_prob – fraction of non-special tokens to select for masking.
  • cls_id – id of the special CLS token.
  • sep_id – id of the special SEP token.
  • mask_id – id of the special MASK token.
  • vocab_start_id – id of first non-special token in the vocabulary.
Yields:

a stream with tokens masked for MLM training and 2 appended arrays:
  • original tokens: a copy of the original tokens, used as a label for MLM training;
  • token_weights: weights distributed uniformly over the selected tokens (summing to 1); other tokens have weight 0.
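The 80/10/10 procedure for a single row can be sketched in NumPy. This is a hypothetical helper, not the Trax implementation. It assumes that only [CLS] and [SEP] count as special tokens and that random replacements are drawn from [vocab_start_id, explicit_vocab_size).

```python
import numpy as np


def mask_row(token_ids, rng, masking_prob=0.15, cls_id=101, sep_id=102,
             mask_id=103, vocab_start_id=999, vocab_size=30522):
  """Hypothetical helper: BERT-style 80/10/10 masking of one row of token ids."""
  original = np.array(token_ids, dtype=np.int64)
  modified = original.copy()
  # Assumption: only [CLS] and [SEP] count as special tokens here.
  candidates = np.flatnonzero((original != cls_id) & (original != sep_id))
  num_selected = int(round(masking_prob * len(candidates)))
  selected = rng.choice(candidates, size=num_selected, replace=False)
  weights = np.zeros(len(original), dtype=np.float32)
  if num_selected > 0:
    weights[selected] = 1.0 / num_selected  # uniform weights that sum to 1
  for pos in selected:
    draw = rng.random()
    if draw < 0.8:
      modified[pos] = mask_id  # 80%: replace with [MASK]
    elif draw < 0.9:
      modified[pos] = rng.integers(vocab_start_id, vocab_size)  # 10%: random token
    # Remaining 10%: keep the original token.
  return modified, original, weights


row = [101] + list(range(2000, 2020)) + [102]
modified, original, weights = mask_row(row, np.random.default_rng(0))
print(np.count_nonzero(weights))  # 3, i.e. round(0.15 * 20)
```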

trax.data.tf_inputs.BertNextSentencePredictionInputs(dataset_name, data_dir=None, text_key='text', train=True, shuffle_size=50000)

Defines a stream for the next sentence prediction task.

trax.data.tf_inputs.CorpusToRandomChunks(dataset_name, num_tokens=512, train=True)
trax.data.tf_inputs.BertGlueTrainStream(benchmark=gin.REQUIRED)

Returns a Bert-preprocessed training stream for benchmark.

Parameters: benchmark – Simple lower-case name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'.

trax.data.tf_inputs.BertGlueEvalStream(benchmark=gin.REQUIRED)

Returns a Bert-preprocessed eval data stream for benchmark.

Parameters: benchmark – Simple lower-case name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'. If the benchmark includes an alternate eval (e.g., MNLI’s “mismatched” eval/validation split), you can specify it with an '_e2' suffix, e.g., 'mnli_e2'.

trax.data.tf_inputs.T5GlueTrainStream(benchmark=gin.REQUIRED)

Returns a T5-preprocessed training data stream for benchmark.

Parameters: benchmark – Simple lower-case name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'.

trax.data.tf_inputs.T5GlueTrainStreamsParallel(benchmark_list=gin.REQUIRED, counters=None, reweight_by_minimum=False, gradually_reweight=False)

Returns a parallel set of training streams, based on benchmark_list.

Parameters:
  • benchmark_list – List of simple lower-case names of GLUE benchmarks, e.g., 'cola', 'mnli', 'rte'.
  • counters – a list of counters to be passed to data.Parallel, e.g., [8551, 392702, 2490] would be a reasonable counterpart to benchmark_list = ["cola", "mnli", "rte"]; see https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/glue_utils.py#L42 for more details on counters.
  • reweight_by_minimum – if True, divide the counters by the minimal counter.
  • gradually_reweight – a more refined reweighting policy; see inputs.py for more details.

trax.data.tf_inputs.T5GlueEvalStream(benchmark=gin.REQUIRED)

Returns a T5-preprocessed eval data stream for benchmark.

Parameters: benchmark – Simple lower-case name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'. If the benchmark includes an alternate eval (e.g., MNLI’s “mismatched” eval/validation split), you can specify it with an '_e2' suffix, e.g., 'mnli_e2'.

trax.data.tf_inputs.T5GlueEvalStreamsParallel(benchmark_list=gin.REQUIRED)

Returns a parallel set of T5 eval streams, based on benchmark_list.

Parameters: benchmark_list – List of strings, each of which is a simple lower-case name of a GLUE benchmark, e.g., 'cola', 'mnli', 'rte'. If a benchmark includes an alternate eval (e.g., MNLI’s “mismatched” eval/validation split), you can specify it with an '_e2' suffix, e.g., 'mnli_e2'.

trax.data.tf_inputs.T5GlueEvalTasks(benchmark_list=gin.REQUIRED)

Returns a list of T5 GLUE eval tasks, based on benchmark_list.

Parameters: benchmark_list – List of strings, each of which indicates the name and data split of a GLUE benchmark. Data splits are indicated as underscore suffixes, e.g., 'cola_t' (Cola benchmark, training split), 'rte_e' (RTE benchmark, eval/validation split), and 'mnli_e2' (MNLI alternate “mismatched” eval/validation split).

trax.data.tf_inputs.compute_single_result(op_name, num_args)

An implementation of the most popular ops from the MathQA dataset.

trax.data.tf_inputs.compute_result(list_op, list_num)

Python execution of MathQA ops.

trax.data.tf_inputs.single_op_to_python_command(op_name, num_args)

An implementation of the most popular ops from the MathQA dataset.

trax.data.tf_inputs.compute_program(list_op)

Python execution of MathQA ops.

trax.data.tf_inputs.compute_nums(question)

Finds numbers in a string and converts them to floats.
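A regex-based sketch of number extraction. The pattern is an assumption: it covers unsigned integers with optional thousands separators and optional decimals, and the regex used by Trax may differ, for example in how it handles signs. find_numbers is a hypothetical name.

```python
import re

# Assumption: integers with optional thousands separators and optional decimals.
NUMBER_RE = re.compile(r'\d+(?:,\d{3})*(?:\.\d+)?')


def find_numbers(text):
  """Hypothetical helper: returns the numbers in text as floats, in order."""
  return [float(match.replace(',', '')) for match in NUMBER_RE.findall(text)]


print(find_numbers('sold 4,000 units at 3.5 dollars each over 12 days'))
# [4000.0, 3.5, 12.0]
```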

trax.data.tf_inputs.compute_ops(linear_formula)
trax.data.tf_inputs.process_single_mathqa_example(example)

Executes a single example and verifies the coherence of a MathQA problem.

Parameters:
  • example – a dictionary with the following fields:
      ◦ Problem – a natural language formulation of the problem;
      ◦ Rationale – a natural language solution of the problem;
      ◦ options – five possible answers, labeled a) through e);
      ◦ correct – the letter representing the correct answer;
      ◦ annotated_formula – formula representing the full solution;
      ◦ linear_formula – a string of operations separated by the | character, e.g. multiply(n2,const_100)|multiply(n0,n1)|divide(#0,#1)|multiply(#2,const_100)|divide(#3,#1)|
      ◦ category – a natural language description of the category to which a given problem belongs.
Returns:
  • answer_num – numerical answer contained in the example.
  • python_result – numerical answers computed in Python, including intermediate results; answer_num should be close to python_result[-1].
  • list_op – list of arithmetic operations.
  • list_num – list of identified numbers in the text.

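The linear_formula notation can be illustrated with a small interpreter that returns every intermediate result, so the last entry plays the role of python_result[-1]. This sketch is hypothetical and not the Trax implementation. It supports only four binary ops and numeric const_X constants, and it assumes nK indexes the numbers found in the problem and #K indexes earlier results, both 0-based.

```python
import re

# Only four binary ops are sketched; MathQA defines many more.
OPS = {
    'add': lambda a, b: a + b,
    'subtract': lambda a, b: a - b,
    'multiply': lambda a, b: a * b,
    'divide': lambda a, b: a / b,
}


def resolve_arg(arg, numbers, results):
  """Maps nK, #K and numeric const_X arguments to values."""
  if arg.startswith('n'):  # nK: K-th number found in the problem text (0-based)
    return numbers[int(arg[1:])]
  if arg.startswith('#'):  # #K: result of the K-th operation (0-based)
    return results[int(arg[1:])]
  if arg.startswith('const_'):  # const_100 -> 100.0, const_3_13 -> 3.13
    return float(arg[len('const_'):].replace('_', '.'))
  raise ValueError('unsupported argument: ' + arg)


def run_linear_formula(linear_formula, numbers):
  """Hypothetical helper: executes a '|'-separated linear_formula step by step."""
  results = []
  for step in filter(None, linear_formula.split('|')):
    name, args = re.fullmatch(r'(\w+)\((.*)\)', step.strip()).groups()
    values = [resolve_arg(a.strip(), numbers, results) for a in args.split(',')]
    results.append(OPS[name](*values))
  return results  # intermediate results; the last entry is the answer


print(run_linear_formula('multiply(n0,n1)|divide(#0,const_100)|', [20.0, 30.0]))
# [600.0, 6.0]
```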
trax.data.tf_inputs.convert_float_to_mathqa(number)
trax.data.tf_inputs.convert_to_subtract(const_string)
trax.data.tf_inputs.execute_mathqa_dsl_program(problem, dsl_code)

Executes the DSL code for a given problem.

Parameters:
  • problem – problem formulation (needed to get parameters).
  • dsl_code – DSL code.
Returns:

the result of executing the DSL code.

trax.data.tf_inputs.is_number(s)
trax.data.tf_inputs.execute_mathqa_program(problem, program)

Executes the Python code for a given problem.

Parameters:
  • problem – problem formulation (not needed, but we want the same API as in the DSL case).
  • program – Python code.
Returns:

the result of executing the Python code.

trax.data.tf_inputs.CreateMathQAInputs(dataset_path=None, train=True, test=False, challenge=False, tolerance=0.01, cumulative=True, python_code=False, full_dict=False, partial_results=True, nlp_rationale=False, correct_answer=False, answer_in_mathqa_format=True, correct_answer_given_reasoning=False, category=False, order_prediction=False, reduced_operation_name=True, qed=False)

Prepares MathQA inputs.

The generation procedure leaves a lot of parameters to be set by the user. Currently we support only correct examples in the following sense: Python execution agrees with the declared answer up to the given tolerance (1% by default).

According to this criterion, wrong examples such as the problem “calculate 85184 ÷ ? = 352” with operations [‘multiply(n0,n1)’] are ignored (the operation should be divide(n0,n1) in this case).
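A sketch of the correctness check applied to this example. The exact formula is an assumption: the relative difference is measured against the declared answer, with a small guard against a zero answer. within_tolerance is a hypothetical name. The declared answer is 242, because 85184 ÷ 242 = 352.

```python
def within_tolerance(python_result, declared_answer, tolerance=0.01):
  """Hypothetical check: relative difference measured against the declared answer."""
  scale = max(abs(declared_answer), 1e-12)  # assumption: guard against a zero answer
  return abs(python_result - declared_answer) / scale <= tolerance


declared = 242.0  # 85184 ÷ 242 = 352
print(within_tolerance(85184 / 352, declared))  # True: divide(n0,n1)
print(within_tolerance(85184 * 352, declared))  # False: multiply(n0,n1)
```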

Parameters:
  • dataset_path – a path with the MathQA dataset.
  • train – if True, then generate training examples; if train, test and challenge are all set to False, generate validation examples.
  • test – if train is set to False and test is set to True, then generate test examples.
  • challenge – if train and test are set to False and challenge is set to True, then generate challenge examples.
  • tolerance – if for a given example relative difference between Python result and the result declared in the dataset exceeds the level, then the example is dropped; tolerances ranging from 0.1 to 0.001 yield from 18K to 21K examples.
  • cumulative – if set to True, then generate examples in the format input: problem + numbers + op1 + op2 + op3, target: op4. If set to False, then examples are in the format input: problem + numbers, target: all operations.
  • python_code – if set to True, then generate Python code instead of MathQA commands.
  • full_dict – if set to True, then Python examples are returned together with the DSL code and the NLP rationale.
  • partial_results – if set to True, then partial results are reported as part of the input, e.g. input: problem + numbers + op1 + #1 + op2 + #2 + op3 + #3, target: op4, where #k is the partial result of operation opk. Active only if cumulative is set to True.
  • nlp_rationale – if set to True, then input is the problem and the target is the nlp rationale.
  • correct_answer – if set to True, then input is the problem plus all possible answers and the target is the correct answer.
  • answer_in_mathqa_format – if set to True, then convert numerical answer to the MathQA format and wrap it in the subtract operation. E.g. “3.13” is converted to “subtract(const_3_13,const_0)”.
  • correct_answer_given_reasoning – if set to True, then input is the problem plus linear formula plus all possible answers and the target is the correct answer.
  • category – if set to True, then input is the problem and the target is its category.
  • order_prediction – if set to True, then input is the problem and a list of all operations; with probability 0.5 two operations are swapped; the task consists in detecting whether the operations were swapped. See the order prediction task in CreateAquaInputs in this file.
  • reduced_operation_name – if set to True, then in order prediction consider only the operation token without parameters.
  • qed – if set to True, then the reasoning is finished with an additional operation qed.
Returns:

a generator of MathQA examples; the generator yields non-tokenized examples - they can be further processed using for example the tokenize function from this module

Return type:

mathqa_yield_examples
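The two example formats selected by cumulative can be sketched as (input, target) string pairs. This assumes one cumulative pair per operation and space-joined text; make_examples is a hypothetical name, and it ignores partial_results and the other options.

```python
def make_examples(problem, numbers, ops, cumulative=True):
  """Hypothetical helper: builds (input, target) pairs from a solved problem."""
  prefix = ' '.join([problem] + [str(n) for n in numbers])
  if not cumulative:
    return [(prefix, ' '.join(ops))]  # input: problem + numbers, target: all ops
  # One pair per operation: the input holds the ops so far, the target is the next op.
  return [(' '.join([prefix] + ops[:k]), op) for k, op in enumerate(ops)]


for inp, target in make_examples('p', [1.0, 2.0], ['op1', 'op2', 'op3']):
  print(repr(inp), '->', repr(target))
# 'p 1.0 2.0' -> 'op1'
# 'p 1.0 2.0 op1' -> 'op2'
# 'p 1.0 2.0 op1 op2' -> 'op3'
```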

trax.data.tf_inputs.CreateAquaInputs(dataset_path=None, train=True, cumulative=False, rationale=False, correct_answer=False, correct_answer_given_reasoning=False, partial_reasoning=True, order_prediction=False)

Prepares Aqua inputs.

Parameters:
  • dataset_path – a path with the Aqua dataset.
  • train – if True, then generate training examples, otherwise generate validation examples (the dataset also has a test set).
  • cumulative – if set to True, then generate examples in the format input: problem + step1 + step2 + step3, target: step4. If set to False, then examples are in the format input: problem, target: all operations.
  • rationale – if set to True, then input is the problem and the target is the rationale.
  • correct_answer – if set to True, then input is the problem plus all possible answers and the target is the correct answer.
  • correct_answer_given_reasoning – if set to True, then input is the problem plus reasoning (aka rationale) plus all possible answers and the target is the correct answer.
  • partial_reasoning – an additional option related to correct_answer_given_reasoning; if set to True, then we take a random prefix of the reasoning.
  • order_prediction – if set to True, then input is the problem and a list of all operations; with probability 0.5 two operations are swapped; the task consists in detecting whether the operations were swapped. A similar additional task was considered in https://arxiv.org/pdf/1909.11942.pdf and in a recent work of Piotr Piękos, henrykm@ and mateuszm@.
Returns:

a generator of Aqua examples; the generator yields non-tokenized examples - they can be further processed using for example the tokenize function from this module

Return type:

aqua_yield_examples
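The order prediction task can be sketched as follows. With probability 0.5, two distinct positions in the operation list are swapped, and the label records whether that happened. The helper name, the space-joined text format and the integer label are assumptions for illustration.

```python
import random


def order_prediction_example(problem, ops, rng):
  """Hypothetical helper: returns (input_text, label), label 1 if two ops were swapped."""
  ops = list(ops)
  swapped = len(ops) >= 2 and rng.random() < 0.5
  if swapped:
    i, j = rng.sample(range(len(ops)), 2)  # two distinct positions
    ops[i], ops[j] = ops[j], ops[i]
  return ' '.join([problem] + ops), int(swapped)


text, label = order_prediction_example(
    'problem', ['multiply(n0,n1)', 'divide(#0,n2)'], random.Random(0))
print(label, text)
```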

trax.data.tf_inputs.CreateDropInputs(train=True, mathqa_format=False)

Prepares Drop inputs.

Parameters:
  • train – if True, then generate training examples, otherwise generate validation examples (the dataset also has a test set).
  • mathqa_format – if True, then floats in targets are converted to the MathQA convention and wrapped in the subtract operation. E.g. “3.13” is converted to “subtract(const_3_13,const_0)”.
Returns:

a generator of Drop examples; the generator yields non-tokenized examples - they can be further processed using for example the tokenize function from this module

Return type:

drop_yield_examples
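The MathQA answer format described for mathqa_format (and answer_in_mathqa_format above) can be sketched as a string transformation. It assumes a non-negative number given as a string; the handling of negative numbers is not covered. to_mathqa_answer is a hypothetical name.

```python
def to_mathqa_answer(number_text):
  """Hypothetical helper: '3.13' -> 'subtract(const_3_13,const_0)'."""
  # Assumption: a non-negative number given as a string; '.' becomes '_'.
  return 'subtract(const_{},const_0)'.format(number_text.replace('.', '_'))


print(to_mathqa_answer('3.13'))  # subtract(const_3_13,const_0)
```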

trax.data.tf_inputs.CreateAnnotatedDropInputs(dataset_path=None, train=True, single_file=True, unique=False, total_number_of_samples=None, percentile=1.0)

Prepares annotated Drop inputs.

Example of an annotated input which can be used with this interface:

{
‘passage’: ‘The Armenian Prelature of Cyprus was established in 973 by Catholicos Khatchig I. Historically, the Prelature has been under the jurisdiction of the Catholicosate of the Great House of Cilicia, while today it is the oldest theme that falls under its jurisdiction. Since 2014 the Prelate, a Catholicosal Vicar General, has been Archbishop Nareg Alemezian. The parish priest in Nicosia is Fr. Momik Habeshian, while the parish priest in Larnaca and Limassol is Fr. Mashdots Ashkarian. For centuries, the Prelature building was located within the Armenian compound in Victoria street in walled Nicosia; when that area was taken over by Turkish-Cypriot extremists in 1963-1964, the Prelature was temporarily housed in Aram Ouzounian street and, later on, in Kyriakos Matsis street in Ayios Dhometios. Thanks to the efforts of Bishop Zareh Aznavorian and with financial aid from the Evangelical Church of Westphalia, the new Prelature building was erected in 1983, next to the Virgin Mary church and the Nareg school in Nicosia, by architects Athos Dikaios & Alkis Dikaios; it was officially inaugurated on 4 March 1984, during the pastoral visit of Catholicos Karekin II. By initiative of Archbishop Varoujan Hergelian, in 1998 the basement of the building was renovated and the “Vahram Utidjian” Hall was formed; previously a store room, it became a reality from the proceeds of the auction in 1994 of the art collection that Vahram Utidjian had donated to the Prelature in 1954. It was inaugurated on 3 February 1999 by Catholicos Aram I; numerous charity, communal and cultural events take place there. The Prelature’s consistory houses a collection of ecclesiastical relics, some of which were previously in the old Virgin Mary church or the Magaravank.’, ‘question’: ‘How many years after the Vahram Utidjian was donated to the Prelature was it sold at an auction?’, ‘answer’: 40, ‘calculation’: ‘subtract(n8,n9)’

}

In this example the calculation is formulated using the notation from the MathQA dataset, but this is not required. subtract(n8,n9) means that the answer 40 can be obtained by subtracting the 10th number in the input from the 9th (numbers are indexed from 0). The input consists of the passage concatenated with the question. The annotations can be generated using, for example, a method from the paper https://arxiv.org/abs/1909.00109.
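The index convention can be checked with a small evaluator for one-operation annotations. This is a hypothetical helper, not part of Trax. It assumes numbers are indexed from 0 in order of appearance and that only unsigned numbers are matched, so the range “1963-1964” yields 1963 and 1964.

```python
import re

OPS = {
    'add': lambda a, b: a + b,
    'subtract': lambda a, b: a - b,
    'multiply': lambda a, b: a * b,
    'divide': lambda a, b: a / b,
}


def evaluate_annotation(text, calculation):
  """Hypothetical helper: evaluates a one-op annotation such as 'subtract(n8,n9)'."""
  # nK is the K-th number (0-based) in order of appearance; signs are ignored.
  numbers = [float(x) for x in re.findall(r'\d+(?:\.\d+)?', text)]
  name, i, j = re.fullmatch(r'(\w+)\(n(\d+),n(\d+)\)', calculation).groups()
  return OPS[name](numbers[int(i)], numbers[int(j)])


print(evaluate_annotation(
    'the auction in 1994 of the collection donated in 1954', 'subtract(n0,n1)'))
# 40.0
```

For the passage above, the numbers in order of appearance are 973, 2014, 1963, 1964, 1983, 4, 1984, 1998, 1994, 1954, 3 and 1999. So n8 = 1994 and n9 = 1954, and subtract(n8,n9) = 40 matches the annotated answer.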

Parameters:
  • dataset_path – a path with the annotated Drop dataset.
  • train – if True, then generate training examples, otherwise generate validation examples (the dataset also has a test set).
  • single_file – if True, then look just for one file. If False, read all json files in a given directory and assume that each file contains one example. Applied only to training data.
  • unique – if set to True, then the generator will provide at most one question per passage.
  • total_number_of_samples – if set to a positive integer, then the total number of unique samples will be bounded by total_number_of_samples.
  • percentile – the percentile of the train dataset used for training; defaults to 1.0, though a lower value can be useful when the training data is combined with another source of data.
Returns:

a generator of annotated Drop examples; the generator yields non-tokenized examples - they can be further processed using for example the tokenize function from this module.

Return type:

drop_annotated_yield_examples