grain.DataLoaderIterator

class grain.DataLoaderIterator(data_loader, state)

DataLoader iterator providing get/set state functionality.

This is the only iterator we expose to users. It wraps the underlying MultipleProcessIterator. To set state, it discards the underlying iterator and recreates it with the new state.
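The recreate-on-set-state pattern described above can be illustrated with a small toy model. This is a minimal sketch and not Grain's implementation: `ToyInnerIterator`, `ToyStatefulIterator`, and the `next_index` state key are hypothetical names introduced only for illustration.

```python
class ToyInnerIterator:
    """Hypothetical stand-in for the wrapped iterator: reads a list from a start position."""

    def __init__(self, data, start):
        self._data = data
        self._pos = start

    def __next__(self):
        if self._pos >= len(self._data):
            raise StopIteration
        value = self._data[self._pos]
        self._pos += 1
        return value


class ToyStatefulIterator:
    """Hypothetical wrapper exposing get_state/set_state over an inner iterator."""

    def __init__(self, data, state):
        self._data = data
        self._state = dict(state)
        self._inner = ToyInnerIterator(data, self._state["next_index"])

    def __iter__(self):
        return self

    def __next__(self):
        value = next(self._inner)  # StopIteration propagates before the state changes
        self._state["next_index"] += 1
        return value

    def get_state(self):
        # Return a copy so callers cannot mutate the live state.
        return dict(self._state)

    def set_state(self, state):
        # Rather than rewinding the inner iterator, throw it away and
        # build a fresh one from the supplied state.
        self._state = dict(state)
        self._inner = ToyInnerIterator(self._data, self._state["next_index"])


it = ToyStatefulIterator(list("abcdef"), {"next_index": 0})
first_two = [next(it), next(it)]
checkpoint = it.get_state()
rest = list(it)            # consume the remaining elements
it.set_state(checkpoint)   # go back to the checkpoint
replayed = list(it)        # the same elements are produced again
```

Rebuilding the inner iterator keeps the wrapper simple: any buffered or in-flight work held by the old iterator is dropped, and the new one starts from a known position.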

Checkpointing for DataLoaderIterator: DataLoaderIterator uses GrainPool, which distributes RecordMetadata from produced records among worker processes in a round-robin fashion. Generally, some workers can process more elements than others at a given training step. Checkpointing logic goes as follows:

  1. With each output batch produced, GrainPool emits the worker_index of the worker that processed the batch.

  2. DataLoaderIterator keeps track of the last_seen_index at each worker.

  3. When restoring from a state, DataLoaderIterator checks what the minimum last_seen_index is (among the last seen indices for all workers) and which worker processed that index. GrainPool is then instructed to start distributing indices to the next worker.
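The bookkeeping in the steps above can be simulated without Grain. The sketch below is a toy model under stated assumptions, not the library's code: `ToyPool` and `worker_to_resume_from` are hypothetical names, worker `w` is assumed to own record indices `w`, `w + worker_count`, and so on, and "the next worker" is read as the worker whose turn comes next, which under strict round robin is the one holding the smallest last-seen index.

```python
WORKER_COUNT = 3


class ToyPool:
    """Hypothetical stand-in for GrainPool: hands out results in round-robin order."""

    def __init__(self, worker_count, last_seen, start_worker):
        self.worker_count = worker_count
        # Each worker continues one stride past the last index it produced.
        self._next_index = {
            w: last_seen[w] + worker_count for w in range(worker_count)
        }
        self._turn = start_worker

    def get_next(self):
        worker_index = self._turn
        record_index = self._next_index[worker_index]
        self._next_index[worker_index] += self.worker_count
        self._turn = (worker_index + 1) % self.worker_count
        return worker_index, record_index


def worker_to_resume_from(last_seen):
    # Assumption: the worker with the minimum last-seen index is due next.
    return min(last_seen, key=last_seen.get)


# Assumed initial state: worker w has "seen" index w - WORKER_COUNT,
# so the first index it produces is w.
last_seen = {w: w - WORKER_COUNT for w in range(WORKER_COUNT)}
pool = ToyPool(WORKER_COUNT, last_seen, worker_to_resume_from(last_seen))

seen_before = []
for _ in range(4):
    worker_index, record_index = pool.get_next()  # step 1: pool reports the worker
    last_seen[worker_index] = record_index        # step 2: track last seen index
    seen_before.append(record_index)

checkpoint = dict(last_seen)

# Step 3: restore from the checkpoint with a fresh pool.
resume_worker = worker_to_resume_from(checkpoint)
restored_pool = ToyPool(WORKER_COUNT, checkpoint, resume_worker)
seen_after = [restored_pool.get_next()[1] for _ in range(3)]
```

In this run the checkpoint is taken after four outputs, so worker 0 is one element ahead of the others. Resuming from the worker with the smallest last-seen index continues the sequence without skipping or repeating a record.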

__init__(data_loader, state)

Methods

__init__(data_loader, state)

get_state()

load(directory)

Loads the iterator state from a directory.

save(directory)

Saves the iterator state to a directory.

set_state(state)

Sets the state for the underlying iterator.