grain.experimental.device_put
- grain.experimental.device_put(ds, device, *, cpu_buffer_size=4, device_buffer_size=2)
Moves the data to the given devices with two-stage prefetching:
Stage 1: A CPU-side prefetch buffer holding up to cpu_buffer_size elements.
Stage 2: Per-device buffers holding up to device_buffer_size elements that have already been transferred to the device.
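The two stages can be illustrated with a minimal pure-Python sketch. This is not Grain's implementation: `two_stage_prefetch` and its `transfer` argument are hypothetical names, `transfer` stands in for `jax.device_put`, and Stage 1 is assumed to run on a single background thread. Stage 2 calls `transfer` ahead of consumption; with `jax.device_put`, which dispatches asynchronously, this lets host-to-device copies overlap with the consumer's work.

```python
import collections
import queue
import threading
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")

_DONE = object()  # Sentinel marking the end of the source iterator.


def two_stage_prefetch(
    source: Iterable[T],
    transfer: Callable[[T], U],  # Hypothetical stand-in for jax.device_put.
    *,
    cpu_buffer_size: int = 4,
    device_buffer_size: int = 2,
) -> Iterator[U]:
    """Hypothetical sketch of CPU-side plus device-side prefetching."""
    if cpu_buffer_size < 1 or device_buffer_size < 1:
        raise ValueError("buffer sizes must be >= 1")

    # Stage 1: a background thread reads ahead into a bounded CPU buffer.
    cpu_buffer: "queue.Queue[object]" = queue.Queue(maxsize=cpu_buffer_size)

    def producer() -> None:
        for element in source:
            cpu_buffer.put(element)  # Blocks while the CPU buffer is full.
        cpu_buffer.put(_DONE)

    # Daemon thread: this simplified sketch does not handle producer errors
    # or early abandonment of the iterator.
    threading.Thread(target=producer, daemon=True).start()

    # Stage 2: keep up to `device_buffer_size` already-transferred elements.
    device_buffer: "collections.deque[U]" = collections.deque()
    exhausted = False
    while True:
        while not exhausted and len(device_buffer) < device_buffer_size:
            element = cpu_buffer.get()
            if element is _DONE:
                exhausted = True
            else:
                device_buffer.append(transfer(element))  # Transfer ahead of use.
        if not device_buffer:
            return
        yield device_buffer.popleft()  # FIFO, so element order is preserved.
```

For example, `list(two_stage_prefetch(range(3), str))` returns `['0', '1', '2']`. Bounding both buffers caps how much memory read-ahead can consume on the host and on each device.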
- Parameters:
ds (IterDataset) – Dataset to prefetch.
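device – Target device(s); accepts the same values as the device argument of jax.device_put (for example, a jax.Device or a jax.sharding.Sharding).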
cpu_buffer_size (int) – Number of elements to prefetch on CPU.
device_buffer_size (int) – Number of elements to prefetch per device.
- Returns:
Dataset with the elements prefetched to the devices.
- Return type: