grain.experimental.PackAndBatchOperation
- class grain.experimental.PackAndBatchOperation(length_struct, batch_size, _cur_batch=None)
PyGrain pack-and-batch operation; see the module docstring.
WARNING: This class is deprecated. Please use lazy_dataset.FirstFitPackIterDataset instead.
- Parameters:
length_struct (Any)
batch_size (int)
_cur_batch (_PackedBatch | None)
- batch_size
The batch size (B), i.e. the number of packed rows in each output batch.
- Type:
int
- length_struct
A pytree with the same structure as the elements of the input iterator, whose leaves are ints giving the packed length (T_packed) of the corresponding feature.
- Type:
Any
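As an illustration of how `length_struct` mirrors the input elements, here is a minimal sketch that uses plain nested dicts, with lists standing in for arrays. The feature names (`inputs`, `targets`) and the helper `packed_shapes` are hypothetical and not part of Grain. The sketch only combines `batch_size` with each leaf of `length_struct` to obtain the per-leaf output shape (B, T_packed) described for the packed data; the features here are 1-D, so there are no trailing dimensions.

```python
from typing import Any


def packed_shapes(length_struct: Any, batch_size: int) -> Any:
    """Map every int leaf T_packed of `length_struct` to the shape (B, T_packed).

    Simplification (hypothetical helper): only plain dicts, lists and tuples
    count as pytree nodes here, whereas JAX supports many more container types.
    """
    if isinstance(length_struct, dict):
        return {k: packed_shapes(v, batch_size) for k, v in length_struct.items()}
    if isinstance(length_struct, (list, tuple)):
        return type(length_struct)(packed_shapes(v, batch_size) for v in length_struct)
    # Leaf: the packed length of the corresponding feature.
    return (batch_size, int(length_struct))


# One input element with two 1-D features (hypothetical names and values).
element = {"inputs": [5, 6, 7], "targets": [8, 9]}
# Same structure as `element`; each leaf is that feature's packed length.
length_struct = {"inputs": 8, "targets": 4}
assert element.keys() == length_struct.keys()

print(packed_shapes(length_struct, batch_size=2))
# {'inputs': (2, 8), 'targets': (2, 4)}
```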
__call__() takes an input iterator, where elements are `Record`s containing:
- input_data: pytrees of arrays. For more information about pytrees, see https://jax.readthedocs.io/en/latest/pytrees.html. Packed leaves should be n-dimensional arrays with the sequence length as the leading dimension, i.e. shape (T_in, …), where T_in < T_packed. Leaves can, and often will, have ragged lengths along this leading dimension across different elements of the input iterator.
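The input requirements above can be checked with a small sketch. Plain nested dicts stand in for pytrees and Python lists stand in for arrays (only the leading, sequence dimension matters here). The helpers `leading_lengths` and `check_fits` and the feature names are hypothetical, not part of Grain, and the check applies the bound T_in < T_packed exactly as this page states it.

```python
from typing import Any, Dict


def leading_lengths(tree: Any) -> Any:
    """Replace every leaf with its leading (sequence) length T_in."""
    if isinstance(tree, dict):
        return {k: leading_lengths(v) for k, v in tree.items()}
    # Leaf: a list whose first axis is the sequence axis.
    return len(tree)


def check_fits(element: Dict[str, Any], length_struct: Dict[str, Any]) -> None:
    """Raise ValueError unless structures match and every leaf has T_in < T_packed."""

    def walk(lens: Any, limits: Any, path: str) -> None:
        if isinstance(lens, dict) != isinstance(limits, dict):
            raise ValueError(f"structure mismatch at {path or '/'}")
        if isinstance(limits, dict):
            if lens.keys() != limits.keys():
                raise ValueError(f"structure mismatch at {path or '/'}")
            for key in limits:
                walk(lens[key], limits[key], f"{path}/{key}")
        elif lens >= limits:
            raise ValueError(f"{path}: T_in={lens} is not < T_packed={limits}")

    walk(leading_lengths(element), length_struct, "")


length_struct = {"inputs": 8, "targets": 4}
# Ragged elements: each feature's length varies from element to element.
stream = [
    {"inputs": [1, 2, 3], "targets": [4]},
    {"inputs": [5, 6, 7, 8, 9], "targets": [10, 11, 12]},
]
for element in stream:
    check_fits(element, length_struct)  # both elements satisfy the bound
```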
The output of __call__() is an iterator over `Record`s, each containing a 3-tuple of pytrees:
- data: The batched and packed data. This is a pytree with the same structure as the elements of input_iterator; its leaves have shape (B, T_packed, …).
- segmentations: A pytree with the same structure as data, with leaves of shape (B, T_packed). Each entry records which example it comes from; this can be used, for example, to build Transformer attention masks.
- positions: A pytree with the same structure as data, with leaves of shape (B, T_packed). Each entry records its position within its original example; this can be used, for example, for Transformer absolute position embeddings.
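To make the three outputs concrete, here is a minimal single-feature sketch in plain Python, with lists standing in for arrays. It is not Grain's implementation: the packing strategy (each sequence goes into the first row of the current batch with room, otherwise the batch is emitted), the helper names `pack_and_batch`, `_materialize` and `attention_mask`, and the conventions that segment ids start at 1 within each row while padding uses data 0, segment id 0 and position 0 are all assumptions made for illustration.

```python
from typing import Iterable, Iterator, List, Tuple

Row = List[int]
# (data, segmentations, positions), each a list of batch_size rows.
Batch = Tuple[List[Row], List[Row], List[Row]]


def _materialize(rows: List[List[Row]], packed_len: int) -> Batch:
    """Turn rows of packed sequences into padded data/segmentations/positions."""
    data, segmentations, positions = [], [], []
    for row in rows:
        d: Row = []
        s: Row = []
        p: Row = []
        for segment_id, seq in enumerate(row, start=1):
            d.extend(seq)
            s.extend([segment_id] * len(seq))  # which example each token came from
            p.extend(range(len(seq)))          # position within that example
        pad = packed_len - len(d)
        data.append(d + [0] * pad)
        segmentations.append(s + [0] * pad)    # assumed: 0 marks padding
        positions.append(p + [0] * pad)
    return data, segmentations, positions


def pack_and_batch(
    sequences: Iterable[Row], batch_size: int, packed_len: int
) -> Iterator[Batch]:
    """Greedily pack 1-D sequences into batches of shape (batch_size, packed_len)."""
    rows: List[List[Row]] = [[] for _ in range(batch_size)]
    used = [0] * batch_size
    for seq in sequences:
        if len(seq) > packed_len:
            raise ValueError(f"length {len(seq)} cannot fit in packed_len={packed_len}")
        # First row of the current batch with enough free space, if any.
        slot = next((i for i in range(batch_size) if used[i] + len(seq) <= packed_len), None)
        if slot is None:
            # No room left: emit the current batch and start a new one.
            yield _materialize(rows, packed_len)
            rows = [[] for _ in range(batch_size)]
            used = [0] * batch_size
            slot = 0
        rows[slot].append(list(seq))
        used[slot] += len(seq)
    if any(rows):
        # Flush the final, partially filled batch.
        yield _materialize(rows, packed_len)


def attention_mask(segment_row: Row) -> List[List[bool]]:
    """mask[i][j] is True iff tokens i and j are real and from the same example."""
    return [[a == b and a != 0 for b in segment_row] for a in segment_row]


sequences = [[11, 12, 13], [21, 22], [31, 32, 33, 34], [41], [51, 52, 53]]
batches = list(pack_and_batch(sequences, batch_size=2, packed_len=5))
data, segmentations, positions = batches[0]
print(data)           # [[11, 12, 13, 21, 22], [31, 32, 33, 34, 41]]
print(segmentations)  # [[1, 1, 1, 2, 2], [1, 1, 1, 1, 2]]
print(positions)      # [[0, 1, 2, 0, 1], [0, 1, 2, 3, 0]]
```

The second batch holds only `[51, 52, 53]`, so the rest of it is padding. `attention_mask` shows one way `segmentations` can keep attention within a single example, while `positions` would index a position-embedding table.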
- __init__(length_struct, batch_size, _cur_batch=None)
- Parameters:
length_struct (Any)
batch_size (int)
_cur_batch (_PackedBatch | None)
- Return type:
None
Methods
__init__(length_struct, batch_size[, _cur_batch])
Attributes