
Sharder Protocol

Core protocol for data sharding.

See Also


datarax.core.sharder

Base implementation for sharder modules in Datarax.

This module provides the SharderModule base class, which is the foundation for all NNX-based sharder implementations in Datarax. Sharders are responsible for distributing batches of data across JAX devices.

logger module-attribute

logger = getLogger(__name__)

AxisName module-attribute

AxisName = str

LogicalAxisSpec module-attribute

LogicalAxisSpec = tuple[AxisName | None, ...]

ShardingRules module-attribute

ShardingRules = list[tuple[AxisName, AxisName | None]]

SharderModuleConfig dataclass

SharderModuleConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, sharding_rules: ShardingRules | None = None)

Bases: DataraxModuleConfig

Configuration for SharderModule.

Attributes:

sharding_rules (ShardingRules | None): Optional mapping from logical axis names to physical device mesh axis names, given as a list of (logical_axis, physical_axis) pairs. If provided, logical axis names can be used in sharding specifications.

sharding_rules class-attribute instance-attribute

sharding_rules: ShardingRules | None = None

cacheable class-attribute instance-attribute

cacheable: bool = False

batch_stats_fn class-attribute instance-attribute

batch_stats_fn: Callable | Module | None = None

precomputed_stats class-attribute instance-attribute

precomputed_stats: dict[str, Any] | None = None

SharderModule

SharderModule(config: SharderModuleConfig | None = None, *, rngs: Rngs | None = None, name: str | None = None)

Bases: DataraxModule

Base class for NNX-based sharder modules in Datarax.

SharderModule provides a foundation for implementing data sharders that integrate with NNX-based components. It distributes data batches across JAX devices according to a specified sharding configuration.

Parameters:

config (SharderModuleConfig | None, default None): Optional configuration for the sharder.

rngs (Rngs | None, default None): Optional Rngs for random operations (typically not needed for sharders).

name (str | None, default None): Optional name for the module.

config instance-attribute

config: SharderModuleConfig = config

rngs instance-attribute

rngs = rngs

name instance-attribute

name = static(name)

get_partition_spec

get_partition_spec(logical_spec: LogicalAxisSpec | PartitionSpec) -> PartitionSpec

Convert a logical axis specification to a physical PartitionSpec.

Parameters:

logical_spec (LogicalAxisSpec | PartitionSpec, required): A tuple of logical axis names, or a PartitionSpec.

Returns:

PartitionSpec: A PartitionSpec using physical device mesh axis names.
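The conversion can be pictured with the standalone sketch below. It is not the Datarax implementation: the helper name logical_to_partition_spec is hypothetical, and three behaviors are assumptions (an incoming PartitionSpec is returned unchanged, the first rule for a logical name wins, and names without a rule pass through as physical mesh axis names).

```python
from jax.sharding import PartitionSpec


def logical_to_partition_spec(logical_spec, sharding_rules=None):
    """Hypothetical stand-in for get_partition_spec; not the Datarax code."""
    # Assumption: a PartitionSpec is treated as already physical.
    if isinstance(logical_spec, PartitionSpec):
        return logical_spec
    # Build a lookup table from the (logical, physical) pairs; first rule wins.
    lookup = {}
    for logical_axis, mesh_axis in sharding_rules or []:
        lookup.setdefault(logical_axis, mesh_axis)
    # Assumption: axis names with no rule are passed through unchanged.
    physical = [
        None if axis is None else lookup.get(axis, axis)
        for axis in logical_spec
    ]
    return PartitionSpec(*physical)


rules = [("batch", "data"), ("embed", None)]
print(logical_to_partition_spec(("batch", None, "embed"), rules))
```

In a PartitionSpec, a None entry means that dimension is not split across any mesh axis.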

get_named_sharding

get_named_sharding(mesh: Mesh, logical_spec: LogicalAxisSpec | PartitionSpec) -> NamedSharding

Create a NamedSharding from a logical axis specification and mesh.

Parameters:

mesh (Mesh, required): The device mesh to use for sharding.

logical_spec (LogicalAxisSpec | PartitionSpec, required): A tuple of logical axis names, or a PartitionSpec.

Returns:

NamedSharding: A NamedSharding using the provided mesh and converted partition spec.
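Once a logical spec has been resolved to a physical PartitionSpec, the result is a plain jax.sharding.NamedSharding. The sketch below builds one directly with JAX, without Datarax; the 1-D mesh over all local devices and its axis name "data" are assumptions for illustration.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over every local device, with an assumed axis name "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Physical spec: split dimension 0 over "data", leave dimension 1 unsplit.
sharding = NamedSharding(mesh, PartitionSpec("data", None))

# The row count is a multiple of the device count so dimension 0 splits evenly.
x = np.arange(jax.device_count() * 2 * 4, dtype=np.float32).reshape(-1, 4)
x_sharded = jax.device_put(x, sharding)
print(x_sharded.sharding)
```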

shard

shard(batch: Batch, sharding: Sharding) -> Batch

Distribute a batch of data across JAX devices.

This method should be implemented by subclasses to perform the specific sharding logic for different types of data.

Parameters:

batch (Batch, required): A batch of data elements (PyTree).

sharding (Sharding, required): A JAX Sharding object specifying how to distribute arrays.

Returns:

Batch: A sharded batch (a PyTree of jax.Array objects with the specified sharding).
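A subclass's shard logic could be as simple as the standalone sketch below. The function name shard_batch and the batch keys are hypothetical; the key fact it relies on is that jax.device_put accepts a PyTree and applies a single Sharding to every leaf.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_batch(batch, sharding):
    """Hypothetical shard() body: place every leaf of the PyTree with `sharding`."""
    return jax.device_put(batch, sharding)


mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
# Split only the leading (batch) dimension; trailing dimensions stay whole.
sharding = NamedSharding(mesh, PartitionSpec("data"))

n = jax.device_count() * 2  # batch size divisible by the device count
batch = {
    "image": np.zeros((n, 8, 8, 3), dtype=np.float32),
    "label": np.arange(n, dtype=np.int32),
}
sharded = shard_batch(batch, sharding)
```

A PartitionSpec shorter than an array's rank leaves the remaining dimensions unsplit, so one spec can serve leaves of different ranks.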

parallel_transform

parallel_transform(batch: Batch, transform_fn: Callable[[Batch], Batch], mesh: Mesh, in_spec: LogicalAxisSpec | PartitionSpec, out_spec: LogicalAxisSpec | PartitionSpec | None = None) -> Batch

Apply a transformation to each shard of the batch in parallel.

This uses nnx.shard_map to efficiently process data in parallel across multiple devices.

Parameters:

batch (Batch, required): The batch of data to transform.

transform_fn (Callable[[Batch], Batch], required): A function that takes a batch as input and returns a transformed batch.

mesh (Mesh, required): The device mesh to use for sharding.

in_spec (LogicalAxisSpec | PartitionSpec, required): The input sharding specification.

out_spec (LogicalAxisSpec | PartitionSpec | None, default None): Optional output sharding specification. If not provided, the input spec is used.

Returns:

Batch: The transformed batch with the specified sharding.
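The per-shard pattern can be sketched with JAX's own shard_map. Note the swap: the method above is documented as using nnx.shard_map, while this sketch uses jax.shard_map (falling back to jax.experimental.shard_map on older JAX versions). The mesh axis name "data" and the rescaling transform are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


def transform_fn(batch):
    # Runs on each device's local shard of the batch.
    return {"x": batch["x"] * 2.0}


# in_specs is a tuple with one entry per positional argument (here: the batch);
# out_specs mirrors the input spec, as parallel_transform does by default.
parallel = jax.jit(
    shard_map(transform_fn, mesh=mesh, in_specs=(P("data"),), out_specs=P("data"))
)

x = jnp.arange(jax.device_count() * 4, dtype=jnp.float32).reshape(-1, 2)
out = parallel({"x": x})
```

Because transform_fn only sees its local shard, any reduction inside it (for example a mean) is computed per shard unless it explicitly uses collective operations.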

get_state

get_state() -> dict[str, Any]

Get the current state of the SharderModule for checkpointing.

Returns:

dict[str, Any]: A dictionary containing the internal state of the SharderModule.

set_state

set_state(state: dict[str, Any]) -> None

Restore internal state from a checkpoint.

Parameters:

state (dict[str, Any], required): A dictionary containing the internal state to restore.

get_operation_stats

get_operation_stats() -> dict[str, int]

Get operation statistics.

Note: This method converts JAX arrays to Python ints for introspection. It is intended for use outside of JIT-compiled functions.

Returns:

dict[str, int]: Dictionary with 'applied_count' and 'skipped_count' keys.

reset_operation_stats

reset_operation_stats() -> None

Reset operation statistics to zero.

Note: Creates new JAX arrays to reset the counters.

compute_statistics

compute_statistics(data: Any) -> dict[str, Any] | None

Compute statistics from data using batch_stats_fn.

If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.

Parameters:

data (Any, required): Input data to compute statistics from.

Returns:

dict[str, Any] | None: Dictionary of statistics, or None if no batch_stats_fn is configured.
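A batch_stats_fn is a callable that takes the data and returns a dictionary of statistics. The function below is a hypothetical example (per-feature mean and standard deviation over the leading axis) of the kind of callable that could be supplied through the batch_stats_fn config field; the key names "mean" and "std" are illustrative choices, not names Datarax requires.

```python
import jax.numpy as jnp


def batch_stats_fn(data):
    """Hypothetical statistics function: per-feature stats over axis 0."""
    return {
        "mean": jnp.mean(data, axis=0),
        "std": jnp.std(data, axis=0),  # population std (ddof=0)
    }


data = jnp.array([[1.0, 2.0], [3.0, 6.0]])
stats = batch_stats_fn(data)
```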

get_statistics

get_statistics() -> dict[str, Any] | None

Get the current statistics.

Returns precomputed_stats if configured (unless reset_statistics() has been called); otherwise returns the cached computed statistics, or None if no statistics are available.

Returns:

dict[str, Any] | None: Dictionary of statistics, or None if no statistics are available.

set_statistics

set_statistics(stats: dict[str, Any]) -> None

Manually set statistics.

This overwrites any previously computed statistics and clears the reset flag.

Parameters:

stats (dict[str, Any], required): Dictionary of statistics to set.

reset_statistics

reset_statistics() -> None

Reset all statistics to None.

This clears the computed statistics and sets an internal flag so that precomputed_stats is ignored. After a reset, get_statistics() returns None until new statistics are set or computed.
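The precedence between precomputed and computed statistics can be summarized with the toy class below. It is a hypothetical stand-in, not the Datarax class: StatsTracker and the flag name _stats_reset are invented for illustration (the documentation only mentions "an internal flag"), and set_statistics is deliberately left out because its interaction with precomputed_stats is not fully specified above.

```python
class StatsTracker:
    """Hypothetical model of the documented statistics precedence."""

    def __init__(self, batch_stats_fn=None, precomputed_stats=None):
        self.batch_stats_fn = batch_stats_fn
        self.precomputed_stats = precomputed_stats
        self._computed_stats = None
        self._stats_reset = False  # hypothetical name for the internal flag

    def compute_statistics(self, data):
        # No batch_stats_fn configured: nothing to compute.
        if self.batch_stats_fn is None:
            return None
        self._computed_stats = self.batch_stats_fn(data)  # cached
        return self._computed_stats

    def get_statistics(self):
        # Precomputed stats win unless a reset has happened.
        if self.precomputed_stats is not None and not self._stats_reset:
            return self.precomputed_stats
        return self._computed_stats

    def reset_statistics(self):
        self._computed_stats = None
        self._stats_reset = True


tracker = StatsTracker(batch_stats_fn=lambda d: {"n": len(d)},
                       precomputed_stats={"n": 100})
tracker.compute_statistics([1, 2, 3])  # cached, but precomputed still wins
tracker.reset_statistics()             # now get_statistics() returns None
```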

reset_cache

reset_cache() -> None

Clear the cache.

Has an effect only if cacheable=True in the config.

copy

copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule

Create a copy of this module, optionally with a different config, RNG state, or name.

This creates a new module instance with a modified configuration while preserving the other attributes, which is useful for hyperparameter tuning.

Parameters:

config (DataraxModuleConfig | None, default None): New config (if None, the current config is used).

rngs (Rngs | None, default None): New RNG state (if None, the current rngs are used).

name (str | None, default None): New name (if None, the current name is used).

Returns:

DataraxModule: A new module instance with the updated parameters.

Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note: Subclasses can override this method to provide finer-grained control over copying, such as allowing individual config fields to be updated without requiring dataclasses.replace().

clone

clone() -> DataraxModule

Create a new instance with the same state as this module.

Uses NNX's clone function for proper deep cloning of all state.

Returns:

DataraxModule: A new module instance with the same state.

requires_rng_streams

requires_rng_streams() -> list[str] | None

Get the list of RNG streams required by this module.

Returns:

list[str] | None: A list of required RNG stream names, or None if no RNG streams are required.

ensure_rng_streams

ensure_rng_streams(stream_names: list[str]) -> None

Ensure that the required RNG streams are available.

Parameters:

stream_names (list[str], required): A list of available RNG stream names.

Raises:

ValueError: If a required RNG stream is not available.
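The contract between requires_rng_streams() and ensure_rng_streams() can be sketched as a plain subset check. The helper check_rng_streams below is hypothetical and models only the documented behavior: None means nothing is required, and a missing stream raises ValueError. The stream name "augment" is illustrative.

```python
def check_rng_streams(required, available):
    """Hypothetical helper mirroring the documented ValueError contract."""
    # requires_rng_streams() may return None, meaning nothing is required.
    if required is None:
        return
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")


check_rng_streams(None, ["default"])                    # nothing required
check_rng_streams(["augment"], ["default", "augment"])  # requirement satisfied
```

Since the rngs parameter is described as typically not needed for sharders, a sharder would usually hit the first branch.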