grain.experimental.batch_and_pad

grain.experimental.batch_and_pad(values, *, batch_size, pad_value=0)

Batches the given values and, if needed, pads the batch to the given size.

It can be passed to ds.batch as batch_fn so that the trailing partial batch is padded to batch_size instead of being dropped.
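The trade-off between dropping and padding the remainder can be illustrated in plain Python. The helpers batch_drop_remainder and batch_pad_remainder below are hypothetical: they are not part of Grain, operate on flat Python lists only, and merely mirror the two strategies for the same input used in the Grain example.

```python
from typing import List, Sequence


def batch_drop_remainder(values: Sequence[int], batch_size: int) -> List[List[int]]:
    """Hypothetical helper: keeps only full batches, discarding a trailing partial one."""
    num_full = len(values) // batch_size
    return [list(values[i * batch_size:(i + 1) * batch_size]) for i in range(num_full)]


def batch_pad_remainder(
    values: Sequence[int], batch_size: int, pad_value: int = 0
) -> List[List[int]]:
    """Hypothetical helper: keeps every value, padding a trailing partial batch."""
    batches = []
    for start in range(0, len(values), batch_size):
        chunk = list(values[start:start + batch_size])
        # Only the last chunk can be short; fill it up to batch_size.
        chunk += [pad_value] * (batch_size - len(chunk))
        batches.append(chunk)
    return batches


values = [1, 2, 3, 4]
print(batch_drop_remainder(values, 3))  # [[1, 2, 3]]  (the value 4 is lost)
print(batch_pad_remainder(values, 3))   # [[1, 2, 3], [4, 0, 0]]
```

Padding keeps all real values at the cost of some wasted slots (two here), while dropping keeps every batch free of filler but silently discards data.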

Example usage:

import functools

import grain

ds = grain.MapDataset.range(1, 5)
batch_size = 3
batch_fn = functools.partial(
    grain.experimental.batch_and_pad, batch_size=batch_size)
ds = ds.batch(batch_size, batch_fn=batch_fn)
# The trailing partial batch [4] is padded with pad_value (0 by default).
list(ds)  # [np.array([1, 2, 3]), np.array([4, 0, 0])]

Parameters:

  • values (Sequence[T]) – The values to batch.

  • batch_size (int) – Target batch size. If the number of values is smaller than this, the batch is padded with pad_value to the given size.

  • pad_value (Any) – The value to use for padding. Defaults to 0.
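To make the padding rule concrete, here is a minimal NumPy sketch of the documented behavior for array-like values. batch_and_pad_sketch is a hypothetical name: it is not Grain's implementation, it does not handle nested structures, and its rejection of empty or oversized inputs is an assumption of the sketch rather than documented behavior.

```python
from typing import Any, Sequence

import numpy as np


def batch_and_pad_sketch(
    values: Sequence[Any], *, batch_size: int, pad_value: Any = 0
) -> np.ndarray:
    """Hypothetical sketch: stacks values on a new leading axis, then pads to batch_size.

    Not Grain's code; handles plain array-like values only.
    """
    if not 0 < len(values) <= batch_size:
        # Assumption: this sketch rejects empty or oversized inputs.
        raise ValueError(f"expected 1..{batch_size} values, got {len(values)}")
    # New batch dimension at the front: shape (len(values), *element_shape).
    batch = np.stack([np.asarray(v) for v in values])
    num_missing = batch_size - batch.shape[0]
    # Padding rows share the element shape and dtype of the real values.
    padding = np.full((num_missing,) + batch.shape[1:], pad_value, dtype=batch.dtype)
    return np.concatenate([batch, padding], axis=0)


print(batch_and_pad_sketch([4], batch_size=3).tolist())  # [4, 0, 0]
print(batch_and_pad_sketch([np.array([1, 2])], batch_size=2, pad_value=-1).tolist())
# [[1, 2], [-1, -1]]
```

Building the padding with the same trailing shape and dtype as the stacked values is what lets a padded final batch have exactly the same shape as the full batches before it.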

Returns:

A batch of values with a new batch dimension at the front.

Return type:

T