Parallel Strategy
Execute multiple operators concurrently for throughput.
See Also
- Operators Overview - All operator types
- Composite Operator - Operator composition
- Sequential Strategy - Sequential execution
- Control Flow - DAG parallelism
- Performance Tools - Optimization
datarax.operators.strategies.parallel
Parallel composition strategies.
ParallelStrategy
ParallelStrategy(merge_strategy: str | None = None, merge_axis: int = 0, merge_fn: Callable | None = None)
Bases: CompositionStrategyImpl
Applies operators in parallel and merges outputs.
This strategy executes all child operators on the same input data and merges their results according to the specified strategy.
Attributes:

| Name | Type | Description |
|---|---|---|
| `merge_strategy` | `str \| None` | How to merge outputs (`'concat'`, `'stack'`, `'sum'`, `'mean'`). |
| `merge_axis` | `int` | Axis along which to merge (for `'concat'` and `'stack'`). |
| `merge_fn` | `Callable \| None` | Custom merging function. |
Examples:

```python
from datarax.operators.strategies.parallel import ParallelStrategy

strategy = ParallelStrategy(merge_strategy='concat', merge_axis=-1)
# op1 returns shape (B, 10), op2 returns shape (B, 5)
# result shape will be (B, 15)
```
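The shape arithmetic in the example extends to all four named strategies. Below is a minimal NumPy sketch of what each one does to a list of operator outputs. `merge_outputs` is a hypothetical helper, not a datarax function, and NumPy stands in for `jax.numpy`, which provides the same `concatenate`, `stack`, `sum`, and `mean` functions. Treating `'sum'` and `'mean'` as element-wise reductions across operators that ignore `merge_axis` is an assumption.

```python
import numpy as np


def merge_outputs(outputs, merge_strategy, merge_axis=0):
    """Combine a list of per-operator arrays (hypothetical helper)."""
    if merge_strategy == "concat":
        # Joins along an existing axis; sizes on that axis add up.
        return np.concatenate(outputs, axis=merge_axis)
    if merge_strategy == "stack":
        # Inserts a new axis at merge_axis; all outputs must share one shape.
        return np.stack(outputs, axis=merge_axis)
    if merge_strategy == "sum":
        # Element-wise sum across operators (assumption: merge_axis unused).
        return np.sum(np.stack(outputs), axis=0)
    if merge_strategy == "mean":
        # Element-wise mean across operators (assumption: merge_axis unused).
        return np.mean(np.stack(outputs), axis=0)
    raise ValueError(f"unknown merge_strategy: {merge_strategy!r}")


B = 4
a = np.ones((B, 10))
b = np.ones((B, 5))
c = np.ones((B, 10))
print(merge_outputs([a, b], "concat", merge_axis=-1).shape)  # (4, 15)
print(merge_outputs([a, c], "stack", merge_axis=0).shape)    # (2, 4, 10)
print(merge_outputs([a, c], "sum").shape)                    # (4, 10)
```

`'concat'` only needs the outputs to agree on every axis except the merge axis, whereas `'stack'`, `'sum'`, and `'mean'` need identical shapes.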
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `merge_strategy` | `str \| None` | String identifier for the merge strategy; one of `'concat'`, `'stack'`, `'sum'`, `'mean'`. | `None` |
| `merge_axis` | `int` | Axis for concatenation or stacking. | `0` |
| `merge_fn` | `Callable \| None` | Optional custom callable to merge outputs. | `None` |
apply
apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]
Apply all operators on the same input and merge their outputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `operators` | `list[OperatorModule]` | Operators to execute in parallel on identical input. | required |
| `context` | `StrategyContext` | Execution context with input data, state, and RNG params. | required |
Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any]]` | Tuple of `(merged_data, last_state, last_metadata)`. |
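The return contract above (merged data, but only the *last* operator's state and metadata) can be sketched as a plain loop. `apply_parallel` and the `op(data, state) -> (data, state, metadata)` calling convention are hypothetical simplifications of `OperatorModule` and `StrategyContext`, and handing every operator the same incoming state is an assumption. The `merge` argument stands in for `merge_fn` or a built-in merge strategy.

```python
import numpy as np


def apply_parallel(operators, data, state, merge):
    """Run every operator on the same input, then merge the outputs.

    Hypothetical sketch: each operator is a plain callable
    op(data, state) -> (data, state, metadata), not a datarax OperatorModule.
    """
    outputs = []
    last_state, last_metadata = state, {}
    for op in operators:
        # Every operator receives the original input and state (assumption),
        # never the previous operator's output.
        out, last_state, last_metadata = op(data, state)
        outputs.append(out)
    # Only the final operator's state and metadata survive.
    return merge(outputs), last_state, last_metadata


# Two toy operators that also bump a call counter in the state.
def double(x, state):
    return x * 2, {"calls": state["calls"] + 1}, {"op": "double"}


def negate(x, state):
    return -x, {"calls": state["calls"] + 1}, {"op": "negate"}


# np.stack plays the role of a custom merge_fn here.
merged, new_state, meta = apply_parallel(
    [double, negate], np.arange(3.0), {"calls": 0}, merge=np.stack
)
print(merged.shape, new_state, meta)  # (2, 3) {'calls': 1} {'op': 'negate'}
```

Because each operator starts from the incoming state, the returned state reflects only `negate`'s update (`calls == 1`, not 2), and the metadata from `double` is discarded.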
WeightedParallelStrategy
Bases: CompositionStrategyImpl
Applies operators in parallel and merges their outputs as a weighted sum.
apply
apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]
Apply operators in parallel and combine with learned weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `operators` | `list[OperatorModule]` | Operators to execute in parallel. | required |
| `context` | `StrategyContext` | Must include | required |
Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any]]` | Tuple of `(weighted_sum, last_state, last_metadata)`. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If |
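A weighted sum of operator outputs, the quantity `apply` returns as `weighted_sum`, reduces to a contraction over the operator axis. In this NumPy sketch `weighted_merge` is hypothetical, the weights are plain numbers rather than learned parameters, and the length check that raises `ValueError` is an illustrative assumption, not the library's documented condition.

```python
import numpy as np


def weighted_merge(outputs, weights):
    """Weighted sum of same-shaped operator outputs (hypothetical helper)."""
    if len(weights) != len(outputs):
        # Illustrative validation only; not the library's documented condition.
        raise ValueError(f"expected {len(outputs)} weights, got {len(weights)}")
    stacked = np.stack(outputs)  # shape (num_ops, *output_shape)
    w = np.asarray(weights)
    # Contract the operator axis: result = sum_i w[i] * outputs[i]
    return np.tensordot(w, stacked, axes=1)


outs = [np.ones((2, 2)), np.full((2, 2), 3.0)]
print(weighted_merge(outs, [0.25, 0.75]))  # every entry is 0.25 * 1 + 0.75 * 3 = 2.5
```

`np.tensordot(w, stacked, axes=1)` contracts the last axis of `w` with the first axis of `stacked`, so the result keeps the per-operator output shape regardless of its rank.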
ConditionalParallelStrategy
ConditionalParallelStrategy(conditions: Sequence[Callable[[PyTree], bool | Array]], merge_strategy: str | None = None, merge_axis: int = 0, merge_fn: Callable | None = None)
Bases: CompositionStrategyImpl
Applies operators in parallel with conditions (vmap-compatible).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `conditions` | `Sequence[Callable[[PyTree], bool \| Array]]` | List of callables that determine whether each operator is applied. | required |
| `merge_strategy` | `str \| None` | Strategy for merging active outputs (e.g. `'concat'`, `'stack'`). | `None` |
| `merge_axis` | `int` | Axis along which to merge outputs. | `0` |
| `merge_fn` | `Callable \| None` | Custom merge function; overrides `merge_strategy` if provided. | `None` |
apply
apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]
Apply operators conditionally in parallel and merge active outputs.
Uses `jax.lax.cond` per operator for `vmap`/JIT compatibility.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `operators` | `list[OperatorModule]` | Operators to evaluate (must match the length of `conditions`). | required |
| `context` | `StrategyContext` | Execution context with input data, state, and RNG params. | required |
Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, PyTree, dict[str, Any]]` | Tuple of `(merged_data, last_state, last_metadata)`. |
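`jax.lax.cond` takes a scalar predicate; when it is mapped with `jax.vmap` over a batched predicate, JAX evaluates both branches and selects between them per example. The NumPy sketch below imitates that behavior with `np.where` so it runs without JAX. `conditional_parallel` is hypothetical, each condition here returns one boolean per example, and two choices are assumptions rather than documented library behavior: inactive operators contribute zeros (keeping the output shape independent of the data, which JIT compilation requires), and active outputs are merged by concatenation along the last axis.

```python
import numpy as np


def conditional_parallel(operators, conditions, batch):
    """Apply each operator only where its condition holds (hypothetical sketch).

    Mirrors what jax.lax.cond becomes under jax.vmap with a batched predicate:
    both branches are computed and a per-example select keeps one of them.
    Inactive operators contribute zeros (an assumption).
    """
    if len(operators) != len(conditions):
        raise ValueError("operators and conditions must have the same length")
    outputs = []
    for op, cond in zip(operators, conditions):
        active = np.asarray(cond(batch), dtype=bool)  # one flag per example, shape (B,)
        out = op(batch)  # always computed, like both branches of a select
        # Broadcast the (B,) mask over the trailing feature axes.
        mask = active.reshape(active.shape + (1,) * (out.ndim - 1))
        outputs.append(np.where(mask, out, np.zeros_like(out)))
    # Merge by concatenating along the last axis, for brevity.
    return np.concatenate(outputs, axis=-1)


batch = np.array([[1.0, 2.0], [-1.0, -2.0]])
conditions = [
    lambda x: x[:, 0] > 0,                  # active only when the first feature is positive
    lambda x: np.ones(len(x), dtype=bool),  # always active
]
operators = [lambda x: x * 10, lambda x: x + 1]
print(conditional_parallel(operators, conditions, batch))
# rows: [10, 20, 2, 3] and [0, 0, 0, -1]
```

Computing every operator and masking afterwards trades extra work for a fixed output shape; that trade-off is what makes a condition-dependent strategy usable inside `jit` and `vmap`.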