Sharder Protocol
Core protocol for data sharding.
See Also
- Core Overview - All core protocols
- Sharding - Sharder implementations
- Distributed - Distributed training
- Sharding Quick Reference
datarax.core.sharder
Base implementation for sharder modules in Datarax.
This module provides the SharderModule base class, which is the foundation for all NNX-based sharder implementations in Datarax. Sharders are responsible for distributing batches of data across JAX devices.
SharderModuleConfig (dataclass)
SharderModuleConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, sharding_rules: ShardingRules | None = None)
Bases: DataraxModuleConfig
Configuration for SharderModule.
Attributes:

| Name | Type | Description |
|---|---|---|
| sharding_rules | ShardingRules \| None | Optional mapping from logical axis names to physical device mesh axis names. If provided, logical axis names can be used in sharding specifications. |
SharderModule
SharderModule(config: SharderModuleConfig | None = None, *, rngs: Rngs | None = None, name: str | None = None)
Bases: DataraxModule
Base class for NNX-based sharder modules in Datarax.
SharderModule provides a foundation for implementing data sharders that can be integrated with NNX-based components. It handles the distribution of data batches across JAX devices based on a specified sharding configuration.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | SharderModuleConfig \| None | Optional configuration for the sharder. | None |
| rngs | Rngs \| None | Optional Rngs for random operations (typically not needed for sharders). | None |
| name | str \| None | Optional name for the module. | None |
get_partition_spec
get_partition_spec(logical_spec: LogicalAxisSpec | PartitionSpec) -> PartitionSpec
Convert a logical axis specification to a physical PartitionSpec.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| logical_spec | LogicalAxisSpec \| PartitionSpec | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
|---|---|
| PartitionSpec | A PartitionSpec using physical device mesh axis names. |
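To make the logical-to-physical conversion concrete, here is a minimal pure-Python sketch. It assumes the sharding rules behave like a mapping from logical axis names to mesh axis names (the real ShardingRules type is not documented on this page), and it returns a plain tuple where get_partition_spec returns a PartitionSpec; with JAX the result could be wrapped as jax.sharding.PartitionSpec(*physical). The helper name to_physical_spec and the example rule names are hypothetical.

```python
from typing import Mapping, Optional, Sequence, Tuple

# Hypothetical stand-in for ShardingRules: logical axis name -> mesh axis name.
# A value of None means "do not shard this dimension" (replicate it).
Rules = Mapping[str, Optional[str]]


def to_physical_spec(
    logical_spec: Sequence[Optional[str]], rules: Rules
) -> Tuple[Optional[str], ...]:
    """Translate logical axis names to mesh axis names, one entry per array dim.

    Returns a plain tuple standing in for jax.sharding.PartitionSpec.
    """
    physical = []
    for axis in logical_spec:
        if axis is None:
            physical.append(None)  # dimension already marked as replicated
        elif axis in rules:
            physical.append(rules[axis])
        else:
            # Assumption: an unknown logical name is an error, not silent replication.
            raise KeyError(f"no sharding rule for logical axis {axis!r}")
    return tuple(physical)


rules = {"batch": "data", "embed": None, "heads": "model"}
print(to_physical_spec(("batch", "embed"), rules))  # ('data', None)
```

Raising on unknown names is a design choice of this sketch: it surfaces typos in logical axis names early instead of quietly replicating the dimension.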
get_named_sharding
get_named_sharding(mesh: Mesh, logical_spec: LogicalAxisSpec | PartitionSpec) -> NamedSharding
Create a NamedSharding from a logical axis specification and mesh.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mesh | Mesh | The device mesh to use for sharding. | required |
| logical_spec | LogicalAxisSpec \| PartitionSpec | A tuple of logical axis names or a PartitionSpec. | required |

Returns:

| Type | Description |
|---|---|
| NamedSharding | A NamedSharding using the provided mesh and converted partition spec. |
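A NamedSharding pairs a mesh with a PartitionSpec, and together they determine which block of each array lands on each device. The pure-Python sketch below computes that per-device block shape from a mesh shape (a dict standing in for a jax.sharding.Mesh) and a spec tuple. It assumes every sharded dimension divides evenly by the number of devices along its mesh axes. The helper shard_shape is hypothetical; for a real sharding object, JAX sharding classes expose a shard_shape(global_shape) method for the same purpose.

```python
from math import prod
from typing import Mapping, Optional, Sequence, Tuple, Union

# A spec entry is None (replicated), one mesh axis name, or a tuple of axis names.
AxisEntry = Optional[Union[str, Tuple[str, ...]]]


def shard_shape(
    global_shape: Sequence[int],
    mesh_shape: Mapping[str, int],
    spec: Sequence[AxisEntry],
) -> Tuple[int, ...]:
    """Per-device block shape for an array laid out as NamedSharding(mesh, spec)."""
    # Dimensions without a spec entry are replicated, like a short PartitionSpec.
    padded = list(spec) + [None] * (len(global_shape) - len(spec))
    out = []
    for dim, entry in zip(global_shape, padded):
        if entry is None:
            out.append(dim)  # replicated: every device holds the full dimension
            continue
        axes = (entry,) if isinstance(entry, str) else tuple(entry)
        n = prod(mesh_shape[a] for a in axes)  # devices this dimension is split over
        if dim % n:
            raise ValueError(f"dimension {dim} is not divisible by {n} devices")
        out.append(dim // n)
    return tuple(out)


mesh_shape = {"data": 4, "model": 2}  # e.g. 8 devices arranged as a 4x2 mesh
print(shard_shape((32, 512), mesh_shape, ("data", None)))  # (8, 512)
print(shard_shape((32, 512), mesh_shape, (("data", "model"),)))  # (4, 512)
```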
shard
Distribute a batch of data across JAX devices.
This method should be implemented by subclasses to perform the specific sharding logic for different types of data.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | A batch of data elements (PyTree). | required |
| sharding | Sharding | A JAX Sharding object specifying how to distribute arrays. | required |

Returns:

| Type | Description |
|---|---|
| Batch | A sharded batch (PyTree of jax.Array objects with the specified sharding). |
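In JAX, a concrete shard implementation would plausibly hand the batch to jax.device_put(batch, sharding), which accepts a PyTree and a Sharding. That call needs real devices, so the sketch below only simulates the common data-parallel case in pure Python: every leaf of a nested dict batch is split along its leading (batch) axis into one contiguous slice per device. The function split_batch and the dict-of-lists batch format are hypothetical.

```python
from typing import Any, Dict, List


def split_batch(batch: Dict[str, Any], num_devices: int) -> List[Dict[str, Any]]:
    """Split every leaf of a nested dict batch along its leading axis.

    Returns one sub-batch per device, mimicking what a sharding over the
    batch dimension would place on each device.
    """

    def split_leaf(leaf: list) -> list:
        if len(leaf) % num_devices:
            raise ValueError(f"batch size {len(leaf)} not divisible by {num_devices}")
        step = len(leaf) // num_devices
        return [leaf[i * step:(i + 1) * step] for i in range(num_devices)]

    def split_tree(tree: Any) -> list:
        if isinstance(tree, dict):
            # Split each child, then regroup so entry i holds device i's slices.
            per_key = {k: split_tree(v) for k, v in tree.items()}
            return [{k: parts[i] for k, parts in per_key.items()} for i in range(num_devices)]
        return split_leaf(tree)

    return split_tree(batch)


batch = {"image": [0, 1, 2, 3, 4, 5, 6, 7], "meta": {"label": [10, 11, 12, 13, 14, 15, 16, 17]}}
shards = split_batch(batch, 4)
print(shards[1])  # {'image': [2, 3], 'meta': {'label': [12, 13]}}
```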
parallel_transform
parallel_transform(batch: Batch, transform_fn: Callable[[Batch], Batch], mesh: Mesh, in_spec: LogicalAxisSpec | PartitionSpec, out_spec: LogicalAxisSpec | PartitionSpec | None = None) -> Batch
Apply a transformation to each shard of the batch in parallel.
This uses nnx.shard_map to efficiently process data in parallel across multiple devices.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | Batch | The batch of data to transform. | required |
| transform_fn | Callable[[Batch], Batch] | A function that takes a batch as input and returns a transformed batch. | required |
| mesh | Mesh | The device mesh to use for sharding. | required |
| in_spec | LogicalAxisSpec \| PartitionSpec | The input sharding specification. | required |
| out_spec | LogicalAxisSpec \| PartitionSpec \| None | Optional output sharding specification. If not provided, the input spec is used. | None |

Returns:

| Type | Description |
|---|---|
| Batch | The transformed batch with the specified sharding. |
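The property of shard_map that parallel_transform relies on is that transform_fn runs once per shard and sees only that device's local block; the per-shard outputs are then reassembled along the axes named in out_spec (in_spec by default). The pure-Python simulation below makes this visible with a shard-local normalization. The name simulated_shard_map is hypothetical, and a real run executes the shards concurrently on devices rather than in a loop.

```python
from typing import Callable, List, Sequence


def simulated_shard_map(
    fn: Callable[[List[float]], List[float]], data: Sequence[float], num_shards: int
) -> List[float]:
    """Apply fn to each contiguous shard independently, then concatenate results."""
    if len(data) % num_shards:
        raise ValueError("data length must divide evenly into shards")
    step = len(data) // num_shards
    shards = [list(data[i * step:(i + 1) * step]) for i in range(num_shards)]
    outputs = [fn(s) for s in shards]  # with shard_map, one call per device
    return [x for out in outputs for x in out]  # reassemble along the same axis


def center(shard: List[float]) -> List[float]:
    """Subtract the shard-local mean (fn never sees other shards' data)."""
    mean = sum(shard) / len(shard)
    return [x - mean for x in shard]


result = simulated_shard_map(center, [1.0, 3.0, 10.0, 30.0], num_shards=2)
print(result)  # [-1.0, 1.0, -10.0, 10.0]
```

Because each shard subtracts its own mean, the result differs from centering with the global mean (11.0). A transform that needs global statistics has to communicate across devices, for example with a collective such as jax.lax.pmean inside the mapped function.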
get_state
set_state
get_operation_stats
reset_operation_stats
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data to compute statistics from. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, Any] \| None | Dictionary of statistics, or None if no batch_stats_fn is configured. |
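batch_stats_fn is a callable (or Module) that maps input data to a dict of statistics. As a hypothetical example, the function below computes the mean and population standard deviation of each field in a dict-of-lists batch; a real pipeline would more likely operate on JAX arrays. The name mean_std_stats and the batch format are assumptions made for illustration.

```python
import statistics
from typing import Any, Dict, List


def mean_std_stats(data: Dict[str, List[float]]) -> Dict[str, Any]:
    """Hypothetical batch_stats_fn: per-field mean and population std."""
    return {
        key: {"mean": statistics.fmean(values), "std": statistics.pstdev(values)}
        for key, values in data.items()
    }


stats = mean_std_stats({"x": [2, 4, 4, 4, 5, 5, 7, 9]})
print(stats)  # {'x': {'mean': 5.0, 'std': 2.0}}
```

Configured as batch_stats_fn, a function like this is what compute_statistics would call and whose result it would cache.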
get_statistics
set_statistics
reset_statistics
Reset all statistics to None.
This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
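The interaction between precomputed statistics, computed statistics, and reset can be modeled as a small state machine. The class below is a hypothetical pure-Python model of the documented behavior, not Datarax's implementation. Two details are assumptions: that computed statistics take precedence over precomputed ones, and that set_statistics stores its value in the same slot as computed statistics.

```python
from typing import Any, Callable, Dict, Optional

Stats = Dict[str, Any]


class StatisticsState:
    """Hypothetical model of the statistics lifecycle described above."""

    def __init__(
        self,
        precomputed_stats: Optional[Stats] = None,
        batch_stats_fn: Optional[Callable[[Any], Stats]] = None,
    ):
        self.precomputed_stats = precomputed_stats
        self.batch_stats_fn = batch_stats_fn
        self._computed_stats: Optional[Stats] = None
        self._ignore_precomputed = False  # stands in for the internal flag

    def compute_statistics(self, data: Any) -> Optional[Stats]:
        if self.batch_stats_fn is None:
            return None  # documented: no batch_stats_fn means None
        self._computed_stats = self.batch_stats_fn(data)  # cached
        return self._computed_stats

    def set_statistics(self, stats: Stats) -> None:
        self._computed_stats = stats  # assumption: shares the computed slot

    def get_statistics(self) -> Optional[Stats]:
        if self._computed_stats is not None:  # assumption: computed wins
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self.precomputed_stats

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._ignore_precomputed = True


state = StatisticsState(precomputed_stats={"mean": 0.5})
print(state.get_statistics())  # {'mean': 0.5}
state.reset_statistics()
print(state.get_statistics())  # None
state.set_statistics({"mean": 1.0})
print(state.get_statistics())  # {'mean': 1.0}
```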
copy
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | DataraxModuleConfig \| None | New config (if None, uses the current config). | None |
| rngs | Rngs \| None | New RNG state (if None, uses the current rngs). | None |
| name | str \| None | New name (if None, uses the current name). | None |

Returns:

| Type | Description |
|---|---|
| DataraxModule | New module instance with updated parameters. |
Examples:

```python
# Change configuration
new_config = DataraxModuleConfig(cacheable=True)
new_module = module.copy(config=new_config)

# Change name only
renamed = module.copy(name="new_name")
```

Note: Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclasses.replace().
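Without such an override, changing a single config field means building the new config yourself. A minimal sketch with the standard library's dataclasses.replace, using a hypothetical ExampleConfig in place of a real DataraxModuleConfig subclass (it is frozen here only to show why replace() is needed rather than attribute assignment):

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical stand-in for a DataraxModuleConfig subclass."""

    cacheable: bool = False
    batch_stats_fn: Optional[Callable[..., Any]] = None


base = ExampleConfig()
# Build a new config that changes one field and keeps all others as-is.
tuned = dataclasses.replace(base, cacheable=True)
print(base.cacheable, tuned.cacheable)  # False True
```

The resulting config could then be passed as module.copy(config=tuned).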
clone
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.

Returns:

| Type | Description |
|---|---|
| DataraxModule | A new module instance with the same state. |
requires_rng_streams
ensure_rng_streams
Ensure that the required RNG streams are available.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| stream_names | list[str] | A list of available RNG stream names. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | If a required RNG stream is not available. |
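A minimal sketch of the check ensure_rng_streams performs, written as a hypothetical standalone function check_rng_streams that takes both lists explicitly. Two details are assumptions: that the required names come from something like requires_rng_streams (whose return value is not documented here), and the exact error message. With Flax NNX, the available names correspond to the keyword names given to nnx.Rngs; for example, nnx.Rngs(default=0, dropout=1) provides the streams "default" and "dropout".

```python
from typing import Sequence


def check_rng_streams(required: Sequence[str], available: Sequence[str]) -> None:
    """Raise ValueError if any required RNG stream is missing from `available`."""
    available_set = set(available)
    missing = [name for name in required if name not in available_set]
    if missing:
        raise ValueError(
            f"Missing required RNG streams {missing}; "
            f"available streams: {sorted(available_set)}"
        )


check_rng_streams(["dropout"], ["default", "dropout"])  # passes silently
try:
    check_rng_streams(["augment"], ["default"])
except ValueError as err:
    print(err)
```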