grain.sharding.ShardByJaxProcess


class grain.sharding.ShardByJaxProcess(drop_remainder=False)

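Shards the data across JAX processes. Each process reads the shard whose index is jax.process_index(), out of jax.process_count() shards in total.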
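For intuition, the sketch below mirrors how this class fills in its attributes. It is a hypothetical stand-in, not Grain's code: ShardOptionsSketch and shard_by_process are illustrative names, and the process index and count are passed in explicitly instead of being read from jax.process_index() and jax.process_count(), so the sketch runs without a multi-process JAX setup.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ShardOptionsSketch:
    """Hypothetical stand-in holding the three attributes documented on this page."""

    shard_index: int
    shard_count: int
    drop_remainder: bool = False


def shard_by_process(process_index: int, process_count: int,
                     drop_remainder: bool = False) -> ShardOptionsSketch:
    """Builds sharding options for one process (hypothetical helper).

    The real class takes these two values from jax.process_index() and
    jax.process_count(); here they are parameters so the sketch is standalone.
    """
    if not 0 <= process_index < process_count:
        raise ValueError(
            f"process_index must be in [0, {process_count}), got {process_index}")
    return ShardOptionsSketch(
        shard_index=process_index,
        shard_count=process_count,
        drop_remainder=drop_remainder,
    )


# Simulate a 4-process job: every process gets its own shard index.
for i in range(4):
    print(shard_by_process(i, 4))
```

In a real multi-process JAX job you would instead construct grain.sharding.ShardByJaxProcess(drop_remainder=...) on every process and hand it to whichever Grain component accepts sharding options.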

Parameters:

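drop_remainder (bool) – If True, every shard gets the same number of records and the leftover records are dropped. If False (the default), the leftover records are spread over the first shards, one extra record each.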
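The sketch below illustrates the two drop_remainder behaviours for a contiguous, even split of num_records records. It assumes that drop_remainder=True gives every shard the same size and discards the leftover records, while False spreads the leftover records over the first shards. even_split_sketch is a hypothetical helper, not Grain's implementation; how a particular Grain data source or sampler maps records to shards may differ.

```python
def even_split_sketch(num_records: int, shard_index: int, shard_count: int,
                      drop_remainder: bool = False) -> range:
    """Returns the record indices assigned to one shard (hypothetical helper)."""
    per_shard = num_records // shard_count
    start = per_shard * shard_index
    end = start + per_shard
    remainder = num_records % shard_count
    if remainder and not drop_remainder:
        # The first `remainder` shards each take one extra record,
        # which shifts the boundaries of every later shard.
        start += min(shard_index, remainder)
        end += min(shard_index + 1, remainder)
    return range(start, end)


# 10 records split over 4 processes.
for drop in (False, True):
    print(drop, [list(even_split_sketch(10, i, 4, drop)) for i in range(4)])
# False [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
# True [[0, 1], [2, 3], [4, 5], [6, 7]]
```

With drop_remainder=True, records 8 and 9 are never read, but every process sees the same number of records, which keeps per-process step counts in lockstep.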

__init__(drop_remainder=False)
Parameters:

drop_remainder (bool) – Whether to drop leftover records so that all shards have the same size. Defaults to False.

Methods

__init__([drop_remainder])  Creates sharding options for the current JAX process.

Attributes

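drop_remainder – Whether leftover records are dropped so that all shards have the same size.

shard_index – Index of the shard read by this process; set from jax.process_index().

shard_count – Total number of shards; set from jax.process_count().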