grain.experimental.batch_and_pad
- grain.experimental.batch_and_pad(values, *, batch_size, pad_value=0)
Batches the given values and, if needed, pads the batch to the given size.
Pass it to ds.batch as batch_fn to pad the final, partial batch to full size instead of dropping the remainder.
Example usage:
    import functools

    import grain
    import numpy as np

    ds = grain.MapDataset.range(1, 5)  # elements 1, 2, 3, 4
    batch_size = 3
    batch_fn = functools.partial(
        grain.experimental.batch_and_pad, batch_size=batch_size)
    ds = ds.batch(batch_size, batch_fn=batch_fn)
    # list(ds) yields [np.array([1, 2, 3]), np.array([4, 0, 0])]
- Parameters:
values (Sequence[T]) – The values to batch.
batch_size (int) – Target batch size. If the number of values is smaller than this, the batch is padded with pad_value to the given size.
pad_value (Any) – The value to use for padding. Defaults to 0.
- Returns:
A batch of values with a new batch dimension at the front.
- Return type:
T
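The padding behavior described above can be illustrated without Grain. The following is a minimal sketch in plain Python, not Grain's implementation: `pad_batch_sketch` and `batch_with_padding` are hypothetical names, plain lists stand in for the arrays the real function returns, and rejecting more than `batch_size` values is a choice made for this sketch rather than documented behavior.

```python
from typing import Any, List, Sequence


def pad_batch_sketch(
    values: Sequence[Any], batch_size: int, pad_value: Any = 0
) -> List[Any]:
    """Hypothetical stand-in: return `values` as a list padded to `batch_size`."""
    if len(values) > batch_size:
        # Assumption of this sketch only; not taken from the Grain docs.
        raise ValueError("got more values than batch_size")
    # Append as many pad values as needed to reach the target size.
    return list(values) + [pad_value] * (batch_size - len(values))


def batch_with_padding(
    items: Sequence[Any], batch_size: int, pad_value: Any = 0
) -> List[List[Any]]:
    """Split `items` into chunks of `batch_size`, padding the last chunk."""
    return [
        pad_batch_sketch(items[i:i + batch_size], batch_size, pad_value)
        for i in range(0, len(items), batch_size)
    ]


# Same input as the usage example: elements 1..4 with a batch size of 3.
batches = batch_with_padding(list(range(1, 5)), batch_size=3)
print(batches)  # [[1, 2, 3], [4, 0, 0]]
```

Only the final chunk is padded, so no element is dropped and every batch has the same length.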