Range Sampler¤
Sample from a specific range of indices.
See Also¤
- Samplers Overview - All sampling strategies
- Sequential Sampler - Sequential order
- Sharding - Distributed sampling
- Data Sources - Data loading
datarax.samplers.range_sampler ¤
Range sampler for Datarax.
This module provides a unified range sampler that generates a sequence of integers. It supports both static-method usage and instantiation as an NNX module.
RangeSamplerConfig dataclass ¤
RangeSamplerConfig(cacheable: bool = False, batch_stats_fn: Callable | Module | None = None, precomputed_stats: dict[str, Any] | None = None, stochastic: bool = False, stream_name: str | None = None, start: int = 0, stop: int | None = None, step: int = 1)
Bases: StructuralConfig
Configuration for RangeSampler.
Attributes:
| Name | Type | Description |
|---|---|---|
| `start` | `int` | The start of the range (inclusive, default: 0). |
| `stop` | `int \| None` | The end of the range (exclusive). If None, the value of start is used as stop and start is set to 0. |
| `step` | `int` | The step size between consecutive elements (default: 1). |
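
The `stop=None` convention follows Python's built-in `range()`, where a single argument is read as the exclusive upper bound. A minimal sketch of that normalization is shown here; `normalize_range` is a hypothetical helper written for illustration and is not part of Datarax.

```python
from typing import Optional


def normalize_range(
    start: int = 0, stop: Optional[int] = None, step: int = 1
) -> tuple:
    """Hypothetical helper (not a Datarax API) mirroring the documented defaults."""
    if stop is None:
        # A lone value is treated as the exclusive upper bound, as in range(n).
        start, stop = 0, start
    return start, stop, step


print(normalize_range(10))        # (0, 10, 1)
print(normalize_range(2, 10, 3))  # (2, 10, 3)
```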
precomputed_stats class-attribute instance-attribute ¤
RangeSampler ¤
RangeSampler(config: RangeSamplerConfig, *, rngs: Rngs | None = None, name: str | None = None)
Bases: SamplerModule
Unified range sampler implementation for Datarax.
This class provides methods for generating a sequence of integers, with support for both static method usage and NNX module instantiation. Similar to Python's built-in range(), but implements the SamplerModule interface for use with Datarax pipelines.
Attributes:
| Name | Type | Description |
|---|---|---|
| `start` | | The start of the range (inclusive). |
| `stop` | | The end of the range (exclusive). |
| `step` | | The step size between consecutive elements. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `RangeSamplerConfig` | Configuration for the sampler. | required |
| `rngs` | `Rngs \| None` | Optional Rngs object (not used for this deterministic sampler). | `None` |
| `name` | `str \| None` | Optional name for the module. | `None` |
create_static staticmethod ¤
Static method to create a range iterator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `start` | `int` | The start of the range (inclusive). If stop is None, this becomes the stop value, and start is set to 0. | `0` |
| `stop` | `int \| None` | The end of the range (exclusive). | `None` |
| `step` | `int` | The step size between consecutive elements. | `1` |

Returns:
| Type | Description |
|---|---|
| `Iterator[Element]` | An iterator that yields integers in the specified range. |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If step is 0, or if the range parameters would result in an empty range. |
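
The documented behaviour (integers in a half-open range, with a `ValueError` for a zero step or an empty range) can be sketched in plain Python. The function `range_indices` below is a hypothetical stand-in written for illustration, not the Datarax source; it validates eagerly and then returns an iterator.

```python
from collections.abc import Iterator
from typing import Optional


def range_indices(
    start: int = 0, stop: Optional[int] = None, step: int = 1
) -> Iterator:
    """Hypothetical stand-in for the documented behaviour (not the Datarax source)."""
    if step == 0:
        raise ValueError("step must not be 0")
    if stop is None:
        # Single-argument form: the value given as start is the upper bound.
        start, stop = 0, start
    indices = range(start, stop, step)
    if len(indices) == 0:
        raise ValueError("range parameters produce an empty range")
    return iter(indices)


print(list(range_indices(5)))          # [0, 1, 2, 3, 4]
print(list(range_indices(10, 0, -3)))  # [10, 7, 4, 1]
```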
get_length_static staticmethod ¤
Static method to get the length of a range.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `start` | `int` | The start of the range (inclusive). If stop is None, this becomes the stop value, and start is set to 0. | `0` |
| `stop` | `int \| None` | The end of the range (exclusive). | `None` |
| `step` | `int` | The step size between consecutive elements. | `1` |

Returns:
| Type | Description |
|---|---|
| `int` | The number of elements in the range. |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If step is 0. |
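
The element count of a half-open range is the ceiling of the span divided by the absolute step, clamped at zero. The sketch below works through that arithmetic; `range_length` is a hypothetical helper, and returning 0 for an empty range is an assumption based on the table above listing only the zero-step error.

```python
from typing import Optional


def range_length(start: int = 0, stop: Optional[int] = None, step: int = 1) -> int:
    """Hypothetical helper (not the Datarax source) computing the range length."""
    if step == 0:
        raise ValueError("step must not be 0")
    if stop is None:
        start, stop = 0, start
    # Distance covered in the direction of travel; negative if the range is empty.
    span = (stop - start) if step > 0 else (start - stop)
    # Ceiling division via negated floor division, clamped at zero.
    return max(0, -(-span // abs(step)))


print(range_length(2, 10, 3))   # 3  (elements 2, 5, 8)
print(range_length(10, 0, -3))  # 4  (elements 10, 7, 4, 1)
print(range_length(5, 2))       # 0
```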
set_state ¤
reset ¤
reset(seed: int | None = None) -> None
Reset the sampler to the beginning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seed` | `int \| None` | Optional seed (unused for range sampler but kept for API consistency). | `None` |
set_current_position ¤
set_current_position(position: int) -> None
Set the current position in the range.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `position` | `int` | The position to set. | required |
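
To illustrate how `reset` and `set_current_position` could interact with iteration, here is a minimal sketch. `PositionedRange` is a hypothetical class, not the Datarax implementation, and it assumes that "position" means the number of elements already emitted rather than a raw index value; the source does not state which is intended.

```python
class PositionedRange:
    """Hypothetical position-tracking range iterator (not the Datarax class)."""

    def __init__(self, start: int, stop: int, step: int = 1) -> None:
        self._indices = range(start, stop, step)
        self._position = 0  # assumed meaning: count of elements already emitted

    def __iter__(self) -> "PositionedRange":
        return self

    def __next__(self) -> int:
        if self._position >= len(self._indices):
            raise StopIteration
        value = self._indices[self._position]
        self._position += 1
        return value

    def reset(self, seed=None) -> None:
        # The seed is accepted for API symmetry and ignored.
        self._position = 0

    def set_current_position(self, position: int) -> None:
        if not 0 <= position <= len(self._indices):
            raise ValueError("position out of bounds")
        self._position = position


sampler = PositionedRange(0, 10, 2)
print(next(sampler), next(sampler))  # 0 2
sampler.set_current_position(4)
print(next(sampler))                 # 8
sampler.reset()
print(next(sampler))                 # 0
```

Tracking an offset instead of the raw value keeps resumption independent of `start` and `step`, which is convenient when restoring from a checkpoint.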
get_operation_stats ¤
reset_operation_stats ¤
Reset operation statistics to zero.
Note: Creates new JAX arrays to reset the counters.
compute_statistics ¤
Compute statistics from data using batch_stats_fn.
If batch_stats_fn is not configured, returns None. Computed statistics are cached in _computed_stats.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `Any` | Input data to compute statistics from | required |

Returns:
| Type | Description |
|---|---|
| `dict[str, Any] \| None` | Dictionary of statistics, or None if no batch_stats_fn configured |
get_statistics ¤
set_statistics ¤
reset_statistics ¤
Reset all statistics to None.
This clears both computed statistics and marks that precomputed_stats should be ignored (via internal flag). After reset, get_statistics() will return None until new statistics are set or computed.
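
The interplay between computed statistics, precomputed statistics, and the reset flag can be sketched as follows. `StatsHolder` is a hypothetical class written for illustration; in particular, the lookup order in `get_statistics` (computed first, then precomputed) is an assumption, not something the source confirms.

```python
from typing import Any, Callable, Optional


class StatsHolder:
    """Hypothetical sketch of the statistics lifecycle (not the Datarax class)."""

    def __init__(
        self,
        batch_stats_fn: Optional[Callable[[Any], dict]] = None,
        precomputed_stats: Optional[dict] = None,
    ) -> None:
        self._batch_stats_fn = batch_stats_fn
        self._precomputed_stats = precomputed_stats
        self._computed_stats: Optional[dict] = None
        self._ignore_precomputed = False  # set by reset_statistics()

    def compute_statistics(self, data: Any) -> Optional[dict]:
        if self._batch_stats_fn is None:
            return None
        self._computed_stats = self._batch_stats_fn(data)  # cache the result
        return self._computed_stats

    def get_statistics(self) -> Optional[dict]:
        # Assumed precedence: freshly computed stats win over precomputed ones.
        if self._computed_stats is not None:
            return self._computed_stats
        if self._ignore_precomputed:
            return None
        return self._precomputed_stats

    def reset_statistics(self) -> None:
        self._computed_stats = None
        self._ignore_precomputed = True


holder = StatsHolder(
    batch_stats_fn=lambda xs: {"mean": sum(xs) / len(xs)},
    precomputed_stats={"mean": 0.0},
)
print(holder.get_statistics())               # {'mean': 0.0}
print(holder.compute_statistics([1, 2, 3]))  # {'mean': 2.0}
holder.reset_statistics()
print(holder.get_statistics())               # None
```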
copy ¤
copy(*, config: DataraxModuleConfig | None = None, rngs: Rngs | None = None, name: str | None = None) -> DataraxModule
Create a copy of this module with optional config/parameter changes.
This allows creating a new module instance with modified configuration while preserving other attributes. Useful for hyperparameter tuning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `DataraxModuleConfig \| None` | New config (if None, uses current config) | `None` |
| `rngs` | `Rngs \| None` | New RNG state (if None, uses current rngs) | `None` |
| `name` | `str \| None` | New name (if None, uses current name) | `None` |

Returns:
| Type | Description |
|---|---|
| `DataraxModule` | New module instance with updated parameters |
Examples:

    # Change configuration
    new_config = DataraxModuleConfig(cacheable=True)
    new_module = module.copy(config=new_config)

    # Change name only
    renamed = module.copy(name="new_name")

Note
Subclasses can override this method to provide more fine-grained control over copying, such as allowing individual config field updates without requiring dataclass replace().
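
The note mentions dataclass `replace()`. As a rough sketch of how a modified config could be derived before passing it to `copy(config=...)`, the standard library's `dataclasses.replace` builds a new instance with selected fields changed. `DemoRangeConfig` is a hypothetical stand-in for a sampler config, and making it frozen is an assumption for the example.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class DemoRangeConfig:
    """Hypothetical stand-in for a sampler config (not a Datarax class)."""

    start: int = 0
    stop: Optional[int] = None
    step: int = 1


base = DemoRangeConfig(start=0, stop=100)
# replace() returns a new instance; the original is left untouched.
strided = replace(base, step=5)

print(strided)    # DemoRangeConfig(start=0, stop=100, step=5)
print(base.step)  # 1
```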
clone ¤
clone() -> DataraxModule
Create a new instance with the same state as this module.
Uses NNX's clone function for proper deep cloning of all state.
Returns:
| Type | Description |
|---|---|
| `DataraxModule` | A new module instance with the same state. |
requires_rng_streams ¤
ensure_rng_streams ¤
Ensure that the required RNG streams are available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stream_names` | `list[str]` | A list of available RNG stream names. | required |

Raises:
| Type | Description |
|---|---|
| `ValueError` | If a required RNG stream is not available. |
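
The check described above amounts to comparing the streams a module requires against the ones that are available. A minimal sketch, using a hypothetical free function `check_rng_streams` that takes both lists explicitly (the real method presumably obtains the required names from the module itself):

```python
def check_rng_streams(required: list, available: list) -> None:
    """Hypothetical helper (not a Datarax API): fail if a required stream is missing."""
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(f"Missing required RNG streams: {missing}")


# Passes silently: every required stream is available.
check_rng_streams(required=["shuffle"], available=["shuffle", "augment"])

# A deterministic sampler requires no streams, so any list is acceptable.
check_rng_streams(required=[], available=[])
```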
process ¤
Process input structure.
This method transforms the structure/organization of input data without modifying the data values themselves.
Subclasses MUST implement this method.
The input/output types depend on the specific structural processor:
- Batcher: list[Element] -> list[Batch]
- Sampler: int -> list[int]
- Sharder: Batch -> Sharded[Batch]
- Splitter: Dataset -> tuple[Dataset, Dataset]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input` | `Any` | Input to process (type varies by processor) | required |
| `*args` | `Any` | Additional positional arguments (processor-specific) | `()` |
| `**kwargs` | `Any` | Additional keyword arguments (processor-specific) | `{}` |

Returns:
| Type | Description |
|---|---|
| `Any` | Processed output (type varies by processor) |
Examples:
Batcher implementation:

    def process(self, elements: list[Element]) -> list[Batch]:
        # Group consecutive elements into fixed-size batches.
        batches = []
        for i in range(0, len(elements), self.config.batch_size):
            batch_elements = elements[i:i + self.config.batch_size]
            batches.append(Batch.from_elements(batch_elements))
        return batches

Sampler implementation (deterministic):

    def process(self, dataset_size: int) -> list[int]:
        # Take the first num_samples indices, capped at the dataset size.
        return list(range(min(self.config.num_samples, dataset_size)))

Sampler implementation (stochastic):
sample ¤
Return a list of sampled indices.
This method returns all indices that would be yielded by the iterator, collected into a list. This is useful when you need all indices upfront rather than iterating through them one by one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n` | `int` | The number of indices to sample (typically the dataset size). | required |

Returns:
| Type | Description |
|---|---|
| `list[int]` | A list of sampled indices. |

Note
The default implementation simply collects all indices from the iterator. Subclasses may override this for more efficient implementations.
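
Collecting an index iterator into a list can be sketched with `itertools.islice`. The helper `sample_from_iterator` is hypothetical, and it assumes that `n` acts as an upper bound on how many indices are collected; the source does not spell out how `n` is used.

```python
from collections.abc import Iterable
from itertools import islice


def sample_from_iterator(indices: Iterable, n: int) -> list:
    """Hypothetical helper (not a Datarax API): collect at most n indices."""
    # islice stops at n items or when the iterator is exhausted, whichever is first.
    return list(islice(indices, n))


print(sample_from_iterator(iter(range(0, 20, 5)), 3))   # [0, 5, 10]
print(sample_from_iterator(iter(range(0, 20, 5)), 10))  # [0, 5, 10, 15]
```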
index_spec ¤
index_spec() -> Any
Return a jax.ShapeDtypeStruct (or PyTree thereof) describing emitted indices.
The default implementation returns a scalar int32 spec, matching the
common case of one-index-per-call samplers (sequential, shuffle, range).
Specialized samplers (SlidingWindowSampler, BufferSampler)
override this to declare windowed or vectorized index shapes.
Returns:
| Type | Description |
|---|---|
| `Any` | A jax.ShapeDtypeStruct (or PyTree thereof) describing the shape and dtype of one emitted index. |