
Parallel Strategy

Execute multiple operators in parallel on the same input and merge their outputs.

datarax.operators.strategies.parallel

Parallel composition strategies.

logger module-attribute

logger = getLogger(__name__)

ParallelStrategy

ParallelStrategy(merge_strategy: str | None = None, merge_axis: int = 0, merge_fn: Callable | None = None)

Bases: CompositionStrategyImpl

Applies operators in parallel and merges outputs.

This strategy executes all child operators on the same input data and merges their results according to the specified strategy.

Attributes:

  • merge_strategy (str | None): How to merge outputs ('concat', 'stack', 'sum', 'mean').
  • merge_axis (int): Axis along which to merge (for 'concat' and 'stack').
  • merge_fn (Callable | None): Custom merging function.

Examples:

from datarax.operators.strategies.parallel import ParallelStrategy

strategy = ParallelStrategy(merge_strategy='concat', merge_axis=-1)
# If op1 returns shape (B, 10) and op2 returns shape (B, 5),
# the merged result has shape (B, 15).

Parameters:

  • merge_strategy (str | None, default None): String identifier for the merge strategy. Available values:
      • 'concat': Concatenate along merge_axis.
      • 'stack': Stack along a new axis at merge_axis.
      • 'sum': Sum the outputs element-wise.
      • 'mean': Average the outputs element-wise.
  • merge_axis (int, default 0): Axis for concatenation or stacking.
  • merge_fn (Callable | None, default None): Optional custom callable that merges the outputs.
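The four merge options can be illustrated with a minimal JAX sketch. `merge_outputs` is a hypothetical helper, not part of datarax; it assumes each operator output is a PyTree of arrays, that merging happens leaf by leaf with `jax.tree_util.tree_map`, and that a custom `merge_fn` receives the full list of outputs. How the library itself dispatches these options is not documented on this page.

```python
import jax
import jax.numpy as jnp


def merge_outputs(outputs, merge_strategy=None, merge_axis=0, merge_fn=None):
    """Hypothetical helper: merge a list of operator outputs (PyTrees of arrays)."""
    if merge_fn is not None:
        # Assumption: the custom callable receives the whole list of outputs.
        return merge_fn(outputs)
    combine = {
        "concat": lambda *xs: jnp.concatenate(xs, axis=merge_axis),
        "stack": lambda *xs: jnp.stack(xs, axis=merge_axis),
        "sum": lambda *xs: jnp.sum(jnp.stack(xs), axis=0),
        "mean": lambda *xs: jnp.mean(jnp.stack(xs), axis=0),
    }
    if merge_strategy not in combine:
        raise ValueError(f"unknown merge_strategy: {merge_strategy!r}")
    # Merge leaf by leaf, so dict- or tuple-shaped outputs work as well as arrays.
    return jax.tree_util.tree_map(combine[merge_strategy], *outputs)


a = jnp.ones((4, 10))
b = jnp.ones((4, 5))
concat = merge_outputs([a, b], "concat", merge_axis=-1)  # shape (4, 15)
stacked = merge_outputs([a, a], "stack")                 # shape (2, 4, 10)
```

In this sketch 'sum' and 'mean' require every output to have the same shape, while 'concat' only requires the shapes to agree on every axis except merge_axis.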

merge_strategy instance-attribute

merge_strategy = merge_strategy

merge_axis instance-attribute

merge_axis = merge_axis

merge_fn instance-attribute

merge_fn = merge_fn

apply

apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]

Apply all operators on the same input and merge their outputs.

Parameters:

  • operators (list[OperatorModule], required): Operators to execute in parallel on identical input.
  • context (StrategyContext, required): Execution context with input data, state, and RNG params.

Returns:

  • tuple[PyTree, PyTree, dict[str, Any]]: Tuple of (merged_data, last_state, last_metadata).
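The apply contract (every operator sees the same input, outputs are merged, state and metadata come from the last operator) can be sketched as below. `ToyContext` and the `(data, state, metadata) -> (data, state, metadata)` operator signature are hypothetical simplifications of StrategyContext and OperatorModule, and reading last_state/last_metadata as "returned by the final operator" is an assumption based on the names.

```python
from dataclasses import dataclass, field
from typing import Any

import jax.numpy as jnp


@dataclass
class ToyContext:
    """Hypothetical stand-in for StrategyContext (field names are assumptions)."""
    data: Any
    state: Any = None
    metadata: dict = field(default_factory=dict)


def parallel_apply(operators, context, merge=lambda outs: jnp.concatenate(outs, axis=-1)):
    """Hypothetical sketch: run every operator on context.data, then merge."""
    outputs = []
    state, metadata = context.state, context.metadata
    for op in operators:
        # Each operator receives the original input, not a previous operator's output.
        out, state, metadata = op(context.data, context.state, context.metadata)
        outputs.append(out)
    # Assumption: only the final operator's state and metadata are returned.
    return merge(outputs), state, metadata


ops = [
    lambda d, s, m: (d * 3, s, m),
    lambda d, s, m: (d + 1, "s2", {"op": "add"}),
]
merged, last_state, last_metadata = parallel_apply(ops, ToyContext(jnp.ones((2, 3)), "s0", {}))
```

Here `merged` has shape (2, 6): the first three columns come from the tripling operator and the last three from the incrementing one.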

describe

describe() -> dict[str, Any]

Return a serializable description of this strategy.

WeightedParallelStrategy

Bases: CompositionStrategyImpl

Applies operators in parallel and merges with weights.

apply

apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]

Apply operators in parallel and combine with learned weights.

Parameters:

Name Type Description Default
operators list[OperatorModule]

Operators to execute in parallel.

required
context StrategyContext

Must include extra_params["weights"] JAX array.

required

Returns:

  • tuple[PyTree, PyTree, dict[str, Any]]: Tuple of (weighted_sum, last_state, last_metadata).

Raises:

  • ValueError: If "weights" is not found in context.extra_params.
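One possible reading of the weighted merge is sketched below, assuming one scalar weight per operator and a plain (unnormalized) weighted sum; whether datarax normalizes the weights is not stated on this page. `weighted_parallel_apply` is a hypothetical helper whose operators take and return a single array, and a plain dict stands in for context.extra_params. Because the result is differentiable in the weights, `jax.grad` can be used to learn them.

```python
import jax
import jax.numpy as jnp


def weighted_parallel_apply(operators, data, extra_params):
    """Hypothetical sketch of a weighted parallel merge."""
    if "weights" not in extra_params:
        # Mirrors the documented ValueError when weights are missing.
        raise ValueError('"weights" not found in extra_params')
    weights = jnp.asarray(extra_params["weights"])
    stacked = jnp.stack([op(data) for op in operators])  # (num_ops, *out_shape)
    # Reshape weights to (num_ops, 1, ..., 1) so each one scales its whole output.
    w = weights.reshape((-1,) + (1,) * (stacked.ndim - 1))
    return jnp.sum(w * stacked, axis=0)


ops = [lambda x: x, lambda x: 10.0 * x]
x = jnp.ones(3)
y = weighted_parallel_apply(ops, x, {"weights": jnp.array([0.5, 0.1])})

# Gradient of a scalar loss with respect to the weights.
grad_w = jax.grad(lambda w: weighted_parallel_apply(ops, x, {"weights": w}).sum())(
    jnp.array([0.5, 0.1])
)
```

With these inputs every element of `y` is 0.5 * 1 + 0.1 * 10 = 1.5, and the gradient of the summed output with respect to each weight is the sum of that operator's output: 3 and 30.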

describe

describe() -> dict[str, Any]

Return a serializable description of this strategy.

ConditionalParallelStrategy

ConditionalParallelStrategy(conditions: Sequence[Callable[[PyTree], bool | Array]], merge_strategy: str | None = None, merge_axis: int = 0, merge_fn: Callable | None = None)

Bases: CompositionStrategyImpl

Applies operators in parallel with conditions (vmap-compatible).

Parameters:

  • conditions (Sequence[Callable[[PyTree], bool | Array]], required): Callables, one per operator, that determine whether the corresponding operator is applied.
  • merge_strategy (str | None, default None): Strategy for merging active outputs (e.g. 'concat', 'stack').
  • merge_axis (int, default 0): Axis along which to merge outputs.
  • merge_fn (Callable | None, default None): Custom merge function; overrides merge_strategy if provided.
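For use under jax.jit and jax.vmap, a condition should return a JAX boolean scalar rather than a Python bool, so that it can be traced. The condition below is a hypothetical example, not one shipped with datarax.

```python
import jax
import jax.numpy as jnp


def is_positive(x):
    """Hypothetical condition: True when the example's mean is positive."""
    # A jnp comparison yields a traceable boolean array; calling bool() on it
    # inside jax.jit would raise an error.
    return jnp.mean(x) > 0


batch = jnp.array([[1.0, 2.0], [-3.0, -1.0], [0.5, -0.1]])
per_example = jax.vmap(is_positive)(batch)  # one predicate per row
under_jit = jax.jit(is_positive)(batch[0])
```

The row means are 1.5, -2.0 and 0.2, so `per_example` is [True, False, True].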

conditions instance-attribute

conditions = conditions

merge_strategy instance-attribute

merge_strategy = merge_strategy

merge_axis instance-attribute

merge_axis = merge_axis

merge_fn instance-attribute

merge_fn = merge_fn

apply

apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]

Apply operators conditionally in parallel and merge active outputs.

Uses jax.lax.cond per operator for vmap/JIT compatibility.

Parameters:

  • operators (list[OperatorModule], required): Operators to evaluate; the list must have the same length as conditions.
  • context (StrategyContext, required): Execution context with input data, state, and RNG params.

Returns:

  • tuple[PyTree, PyTree, dict[str, Any]]: Tuple of (merged_data, last_state, last_metadata).
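The per-operator jax.lax.cond pattern can be sketched as follows. This is an illustration under stated assumptions, not the datarax implementation: operators are hypothetical array-to-array functions, an inactive operator contributes zeros shaped like its output (lax.cond requires both branches to return the same shape and dtype), and outputs are stacked along merge_axis.

```python
import jax
import jax.numpy as jnp


def _zeros_branch(op):
    # Inactive branch: zeros with the same shape and dtype as op's output,
    # because lax.cond requires both branches to return matching types.
    return lambda x: jnp.zeros_like(op(x))


def conditional_parallel_apply(operators, conditions, data, merge_axis=0):
    """Hypothetical sketch: gate each operator with jax.lax.cond, then stack."""
    if len(operators) != len(conditions):
        raise ValueError("operators and conditions must have the same length")
    outputs = [
        jax.lax.cond(cond(data), op, _zeros_branch(op), data)
        for op, cond in zip(operators, conditions)
    ]
    return jnp.stack(outputs, axis=merge_axis)


ops = [lambda x: x * 2.0, lambda x: x + 100.0]
conds = [lambda x: jnp.mean(x) > 0, lambda x: jnp.mean(x) > 10]

single = conditional_parallel_apply(ops, conds, jnp.array([1.0, 2.0]))
batched = jax.vmap(lambda x: conditional_parallel_apply(ops, conds, x))(
    jnp.array([[1.0, 2.0], [20.0, 30.0]])
)
```

Under jax.vmap a batched predicate turns lax.cond into a select, so both branches are computed for every example; the result is still correct, but inactive operators are not actually skipped.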

describe

describe() -> dict[str, Any]

Return a serializable description of this strategy.