grain.experimental.device_put

grain.experimental.device_put(ds, device, *, cpu_buffer_size=4, device_buffer_size=2)

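Moves the elements of ds to the given device(s), prefetching them in two stages:

  • Stage 1: a CPU-side buffer that prefetches up to cpu_buffer_size elements.

  • Stage 2: per-device buffers holding up to device_buffer_size elements that have already been transferred to the device.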
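To make the two stages concrete, here is a simplified, stdlib-only sketch of the same buffering pattern. It is not Grain's implementation: `two_stage_prefetch` and its `transfer` argument are hypothetical names, with `transfer` standing in for the host-to-device copy that `jax.device_put` would perform. Each stage runs on its own thread and pushes into a bounded `queue.Queue`, so a stage blocks once roughly `cpu_buffer_size` (or `device_buffer_size`) elements are waiting downstream of it.

```python
import queue
import threading
from typing import Any, Callable, Iterable, Iterator

_DONE = object()  # sentinel marking the end of a stream


class _Failure:
    """Wraps an exception raised upstream so the consumer can re-raise it."""

    def __init__(self, exc: BaseException):
        self.exc = exc


def _pump(src: Iterable[Any], fn: Callable[[Any], Any], out: queue.Queue) -> None:
    # Apply fn to each element and push the result into a bounded queue.
    # put() blocks while the queue is full, capping how far this stage runs ahead.
    try:
        for x in src:
            out.put(fn(x))
    except BaseException as e:  # forward the error instead of hanging the consumer
        out.put(_Failure(e))
        return
    out.put(_DONE)


def _drain(q: queue.Queue) -> Iterator[Any]:
    # Yield elements from q until the end-of-stream sentinel arrives.
    while True:
        x = q.get()
        if x is _DONE:
            return
        if isinstance(x, _Failure):
            raise x.exc
        yield x


def two_stage_prefetch(
    source: Iterable[Any],
    transfer: Callable[[Any], Any],  # hypothetical stand-in for jax.device_put
    cpu_buffer_size: int = 4,
    device_buffer_size: int = 2,
) -> Iterator[Any]:
    cpu_q: queue.Queue = queue.Queue(maxsize=cpu_buffer_size)
    dev_q: queue.Queue = queue.Queue(maxsize=device_buffer_size)
    # Stage 1: read host-side elements ahead of the consumer.
    threading.Thread(target=_pump, args=(source, lambda x: x, cpu_q), daemon=True).start()
    # Stage 2: "transfer" buffered elements and keep them ready for the consumer.
    threading.Thread(target=_pump, args=(_drain(cpu_q), transfer, dev_q), daemon=True).start()
    return _drain(dev_q)


out = list(two_stage_prefetch(range(10), lambda x: x * 2))
print(out)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```

Because each stage is a single thread feeding a FIFO queue, element order is preserved, and an exception raised in either stage is re-raised in the consuming thread.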

Parameters:
  • ds (IterDataset) – Dataset to prefetch.

  • device – Target device(s) or sharding; accepts the same values as the device argument of jax.device_put.

  • cpu_buffer_size (int) – Number of elements to prefetch on CPU.

  • device_buffer_size (int) – Number of elements to prefetch per device.

Returns:

Dataset with the elements prefetched to the devices.

Return type:

IterDataset