grain.samplers.IndexSampler

class grain.samplers.IndexSampler(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), shuffle=False, num_epochs=None, seed=None)

Base index sampler for training on a single datasource.

This index sampler supports the following operations:

  • Sharding of the dataset.

  • Global shuffle of the dataset.
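The two operations can be illustrated with a small, self-contained sketch. It is not Grain's implementation: the function `sketch_shard_indices` is hypothetical, and it assumes a strided split of one globally shuffled permutation. Every shard uses the same seed, so all shards agree on the permutation and receive disjoint records. Grain's actual shuffling and sharding scheme may differ.

```python
import random


def sketch_shard_indices(num_records, shard_index, shard_count,
                         drop_remainder=False, shuffle=False, seed=None):
    """Hypothetical sketch (not Grain's code): indices one shard visits in one epoch."""
    if shuffle and seed is None:
        # This sketch's choice: shards only agree on a permutation if they share a seed.
        raise ValueError("this sketch requires a seed when shuffle=True")
    indices = list(range(num_records))
    if shuffle:
        # Global shuffle: permute the *full* index range, identically on every shard.
        random.Random(seed).shuffle(indices)
    if drop_remainder:
        # Trim the tail so every shard receives the same number of records.
        indices = indices[: num_records - num_records % shard_count]
    # Assumed strided split: shard i takes positions i, i + shard_count, ...
    return indices[shard_index::shard_count]


shards = [sketch_shard_indices(10, i, 3, shuffle=True, seed=42) for i in range(3)]
print([len(s) for s in shards])                    # → [4, 3, 3]
print(sorted(sum(shards, [])) == list(range(10)))  # → True
```

With `drop_remainder=True`, the same call trims the 10 records to 9, so each of the three shards gets 3.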

Parameters:
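  • num_records (int) – Number of records in the data source.

  • shard_options (ShardOptions) – How records are split across shards (shard index, shard count, and whether to drop the remainder). Defaults to NoSharding().

  • shuffle (bool) – Whether to globally shuffle the records. Defaults to False.

  • num_epochs (int | None) – Number of passes over the records. Defaults to None.

  • seed (int | None) – Seed for the sampler's randomness, such as shuffling. Defaults to None.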
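How `num_epochs`, `shuffle`, and `seed` interact can be sketched for the unsharded case (the default `NoSharding()`). The generator `sketch_epoch_stream` is hypothetical. Two behaviors are assumptions of this sketch, not documented behavior of `IndexSampler`: treating `num_epochs=None` as "repeat indefinitely", and deriving one reproducible permutation per epoch from the seed.

```python
import itertools
import random


def sketch_epoch_stream(num_records, shuffle=False, num_epochs=None, seed=None):
    """Hypothetical sketch (not Grain's code): yield (epoch, record_index) pairs.

    Assumption of this sketch: num_epochs=None repeats the data indefinitely.
    """
    if shuffle and seed is None:
        raise ValueError("this sketch requires a seed when shuffle=True")
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for epoch in epochs:
        order = list(range(num_records))
        if shuffle:
            # Reproducible per-epoch permutation derived from (seed, epoch).
            random.Random(f"{seed}:{epoch}").shuffle(order)
        for index in order:
            yield epoch, index


# Unbounded stream (num_epochs=None): take only the first 7 elements.
print(list(itertools.islice(sketch_epoch_stream(3), 7)))
# → [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0)]
```

In this sketch each epoch visits every record exactly once, and the same seed reproduces the same order on every run.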

__init__(num_records, shard_options=NoSharding(shard_index=0, shard_count=1, drop_remainder=False), shuffle=False, num_epochs=None, seed=None)
Parameters:
  • num_records (int)

  • shard_options (ShardOptions)

  • shuffle (bool)

  • num_epochs (int | None)

  • seed (int | None)

Methods

__init__(num_records[, shard_options, ...])