grain.experimental.ConcatThenSplitIterDataset

class grain.experimental.ConcatThenSplitIterDataset(parent, *, length_struct, meta_features=(), split_full_length_features=True, bos_handling=BOSHandling.DO_NOTHING, bos_features=(), bos_token_id=None)

Implements concat-then-split packing for sequence features.

This assumes that elements of the parent dataset are unnested dictionaries whose entries are either scalars or NumPy arrays. The first dimension is the sequence dimension and its size may vary between elements. All other dimensions must have the same size across elements. Scalars are treated as 1-dimensional arrays of size 1.

At a high level, this concatenates elements of the underlying dataset and then splits the result at target-sequence-length intervals. This is well defined for a single feature. For multiple features, we start with an empty buffer and concatenate elements until at least one feature reaches its target sequence length. As an optimization, elements from the parent dataset that are already fully packed are passed through with priority. Once the buffer contains enough data to fill at least one feature to its target sequence length, we pack the buffer. The last element might not fully fit, in which case it is split and the remainder stays in the buffer.
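For a single feature, the concat-then-split semantics can be sketched in plain NumPy. This is a simplified illustration (hypothetical helper, not part of the grain API); the real implementation streams elements, tracks multiple features, and carries the remainder across iterations:

```python
import numpy as np

def concat_then_split(sequences, target_length):
    """Concatenate variable-length sequences, then split at fixed intervals.

    Single-feature sketch of concat-then-split packing. The remainder of
    the split stays behind, as the buffer remainder does in the dataset.
    """
    flat = np.concatenate(sequences)
    n_full = len(flat) // target_length
    # Full chunks of target_length become packed elements.
    packed = [flat[i * target_length:(i + 1) * target_length]
              for i in range(n_full)]
    # The partial tail would remain in the buffer awaiting more input.
    remainder = flat[n_full * target_length:]
    return packed, remainder

packed, remainder = concat_then_split(
    [np.array([1, 2, 3]), np.array([4, 5]), np.array([6, 7, 8, 9])],
    target_length=4)
# packed -> [array([1, 2, 3, 4]), array([5, 6, 7, 8])], remainder -> array([9])
```

Note how the element `[4, 5]` is split across the first and second packed outputs; this is exactly the case where the dataset emits `*_segment_ids` so the boundary can be recovered.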

When packing features, we also create {feature_name}_positions and {feature_name}_segment_ids features. These are 1D arrays of the target sequence length. Segment IDs start at 1 and enumerate the original elements within the packed element. Positions give each entry's position within its original, unpacked sequence.
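A small sketch of how these companion features relate to the packed values (hypothetical helper for illustration; it assumes the sequences fill the row exactly, whereas the real dataset also splits the last element):

```python
import numpy as np

def pack_with_metadata(sequences, target_length):
    """Pack sequences into one row, emitting segment IDs and positions."""
    values, segment_ids, positions = [], [], []
    for seg_id, seq in enumerate(sequences, start=1):  # segment IDs start at 1
        values.extend(seq)
        segment_ids.extend([seg_id] * len(seq))
        positions.extend(range(len(seq)))  # position within the unpacked sequence
    assert len(values) == target_length  # simplifying assumption for the sketch
    return np.array(values), np.array(segment_ids), np.array(positions)

tokens, seg, pos = pack_with_metadata(
    [[7, 8, 9], [3, 4], [5, 6, 0]], target_length=8)
# tokens -> [7, 8, 9, 3, 4, 5, 6, 0]
# seg    -> [1, 1, 1, 2, 2, 3, 3, 3]
# pos    -> [0, 1, 2, 0, 1, 0, 1, 2]
```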

Features can be “meta features”, in which case they are never split and no *_positions or *_segment_ids features are created for them.

Parameters:
  • parent (dataset.IterDataset)

  • length_struct (Mapping[str, int])

  • meta_features (Collection[str])

  • split_full_length_features (bool)

  • bos_handling (BOSHandling)

  • bos_features (Collection[str])

  • bos_token_id (int | None)

__init__(parent, *, length_struct, meta_features=(), split_full_length_features=True, bos_handling=BOSHandling.DO_NOTHING, bos_features=(), bos_token_id=None)

Creates a dataset that concat-then-splits sequences from the parent.

Parameters:
  • parent (IterDataset) – The parent dataset.

  • length_struct (Mapping[str, int]) – Mapping from feature name to target sequence length.

  • meta_features (Collection[str]) – Set of feature names that are considered meta features. Meta features are never split and will be duplicated when other features of the same element are split. Otherwise, meta features are packed normally (they have their own sequence length). No *_positions and *_segment_ids features are created for meta features.

  • split_full_length_features (bool) – Whether features that already have the full target length are split, or are instead treated as fully packed and passed through with priority. Setting split_full_length_features=False is an optimization for when some sequences already have the target length and you do not want them to be split. This optimization is not used by default.

  • bos_handling (BOSHandling) – The instructions for handling BOS tokens (by default, no BOS token is added).

  • bos_features (Collection[str]) – The features to which BOS handling is applied in case BOS is used.

  • bos_token_id (int | None) – The token indicating BOS in case BOS is used.
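One plausible BOS-handling strategy is to overwrite the first token of each packed segment with bos_token_id. The sketch below illustrates that idea with NumPy (a hypothetical helper; the actual behavior is selected via the BOSHandling enum and applied only to bos_features):

```python
import numpy as np

def add_bos(tokens, segment_ids, bos_token_id):
    """Replace the first token of each packed segment with the BOS token."""
    tokens = np.asarray(tokens).copy()
    segment_ids = np.asarray(segment_ids)
    # A segment starts at index 0 and wherever the segment ID changes.
    starts = np.flatnonzero(np.diff(segment_ids, prepend=segment_ids[0] - 1))
    tokens[starts] = bos_token_id
    return tokens

out = add_bos([7, 8, 9, 3, 4, 5], [1, 1, 1, 2, 2, 2], bos_token_id=0)
# -> [0, 8, 9, 0, 4, 5]
```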

Methods

__init__(parent, *, length_struct[, ...])

Creates a dataset that concat-then-splits sequences from the parent.

apply(transformations)

Returns a dataset with the given transformation(s) applied.

batch(batch_size, *[, drop_remainder, batch_fn])

Returns a dataset of elements batched along a new first dimension.

filter(transform)

Returns a dataset containing only the elements that match the filter.

map(transform)

Returns a dataset containing the elements transformed by transform.

map_with_index(transform)

Returns a dataset of the elements transformed by the transform.

mp_prefetch([options, worker_init_fn])

Returns a dataset prefetching elements in multiple processes.

pipe(func, /, *args, **kwargs)

Syntactic sugar for applying a callable to this dataset.

prefetch(multiprocessing_options)

Deprecated, use mp_prefetch instead.

random_map(transform, *[, seed])

Returns a dataset containing the elements transformed by transform.

seed(seed)

Returns a dataset that uses the seed for default seed generation.

Attributes

parents