
Branching Strategy

Split data flow into multiple processing branches.

datarax.operators.strategies.branching

Branching composition strategy.

logger module-attribute

logger = getLogger(__name__)

BranchingStrategy

BranchingStrategy(router: Callable[[PyTree], int | Array])

Bases: CompositionStrategyImpl

Composition strategy that routes each input to one of several operators, using an integer index that is compatible with jax.jit and jax.vmap.

Parameters:

router (Callable[[PyTree], int | Array], required)
    Function that maps the input PyTree to an integer branch index (0, 1, 2, ...). It may return a Python int or a JAX integer array.
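A router has to stay traceable so that routing works under jax.jit and jax.vmap: it should compute the index with jnp operations rather than a Python `if` on traced values. The sketch below is a hypothetical router (the name `route_by_brightness` and the `"image"` key are assumptions for illustration, not part of datarax) that picks branch 0, 1 or 2 from the mean pixel value and is batched with jax.vmap.

```python
import jax
import jax.numpy as jnp

# Hypothetical router: picks branch 0, 1 or 2 from the mean of the "image"
# field. Only jnp ops are used (no Python `if` on traced values), so it can
# be traced by jax.jit and batched by jax.vmap.
def route_by_brightness(data):
    mean = jnp.mean(data["image"])
    return jnp.where(mean < 0.3, 0, jnp.where(mean < 0.7, 1, 2)).astype(jnp.int32)

# A batch of three 4x4 "images" with different brightness levels.
batch = {"image": jnp.stack([
    jnp.full((4, 4), 0.1),
    jnp.full((4, 4), 0.5),
    jnp.full((4, 4), 0.9),
])}

# vmap maps over the leading axis, giving one index per example.
indices = jax.vmap(route_by_brightness)(batch)  # [0, 1, 2]
```

Returning an int32 array rather than a Python int is what lets the same router run per example inside vmap.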

router instance-attribute

router = router

apply

apply(operators: list[OperatorModule], context: StrategyContext) -> tuple[PyTree, PyTree, dict[str, Any]]

Route input to exactly one operator via jax.lax.switch.

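The router function returns an integer index selecting which operator to execute. Under jax.jit with a scalar index, only the selected branch is executed at runtime. Under jax.vmap with a per-example (batched) index, JAX converts the switch into a select, so every branch is evaluated and each example's result is picked from the outputs. jax.lax.switch clamps out-of-range indices to the valid range rather than raising an error.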
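The routing step can be sketched in plain JAX. This is not datarax's implementation: the three functions below are hypothetical stand-ins for operators, and the router is an illustrative assumption. The sketch shows the `index = router(data)` followed by `jax.lax.switch(index, branches, data)` pattern, both for a single input and under jax.vmap.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for operators. jax.lax.switch requires every branch
# to return the same PyTree structure, shapes, and dtypes.
def brighten(data):
    return {"image": data["image"] + 0.2}

def passthrough(data):
    return {"image": data["image"]}

def darken(data):
    return {"image": data["image"] - 0.2}

BRANCHES = [brighten, passthrough, darken]

def router(data):
    # Traceable routing: returns an int32 scalar index in {0, 1, 2}.
    mean = jnp.mean(data["image"])
    return jnp.where(mean < 0.3, 0, jnp.where(mean < 0.7, 1, 2)).astype(jnp.int32)

@jax.jit
def route_and_apply(data):
    index = router(data)
    # With a scalar (unbatched) index, only the selected branch runs.
    return jax.lax.switch(index, BRANCHES, data)

# Mean 0.1 -> index 0 -> brighten, so every element becomes about 0.3.
single = route_and_apply({"image": jnp.full((2, 2), 0.1)})

# Under vmap the index is batched: all branches are evaluated and each
# example's result is selected, so outputs are still per-example correct.
batch = {"image": jnp.stack([jnp.full((2, 2), 0.1), jnp.full((2, 2), 0.9)])}
batched = jax.vmap(route_and_apply)(batch)  # example 0: brighten, example 1: darken
```

Because all branches must agree on output structure, operators that change shapes or add keys cannot be mixed freely in one branching composition.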

Parameters:

operators (list[OperatorModule], required)
    Candidate operators, indexed by the router output.

context (StrategyContext, required)
    Execution context carrying the input data, state, and RNG parameters.

Returns:

tuple[PyTree, PyTree, dict[str, Any]]
    Tuple of (data, state, metadata) from the selected branch.
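A minimal sketch of the (data, state, metadata) return shape, assuming branches take and return `(data, state)`. The helper `apply_branching`, the branch functions, and the metadata keys are hypothetical illustrations, not datarax's API; they only show how state can be threaded through jax.lax.switch alongside the data.

```python
import jax
import jax.numpy as jnp

# Hypothetical branch signature: (data, state) -> (data, state).
# Both branches return the same PyTree structure, shapes, and dtypes,
# which jax.lax.switch requires.
def double_branch(data, state):
    return data * 2.0, {"calls": state["calls"] + 1}

def shift_branch(data, state):
    return data + 1.0, {"calls": state["calls"] + 1}

def sign_router(data):
    # Index 1 when the sum is positive, else index 0.
    return (jnp.sum(data) > 0).astype(jnp.int32)

def apply_branching(router, branches, data, state):
    """Sketch of a branching apply returning (data, state, metadata).

    Hypothetical helper, not datarax's implementation; the metadata keys
    are illustrative assumptions.
    """
    index = router(data)
    # switch forwards every operand (data and state) to the chosen branch.
    new_data, new_state = jax.lax.switch(index, branches, data, state)
    metadata = {"strategy": "branching", "num_branches": len(branches)}
    return new_data, new_state, metadata

# Sum is 3.0 > 0, so the router picks index 1 (shift_branch).
data, state, meta = apply_branching(
    sign_router,
    [double_branch, shift_branch],
    jnp.array([1.0, 2.0]),
    {"calls": jnp.array(0)},
)
```

Keeping metadata as plain Python values outside the switch avoids forcing it to match across branches; whether datarax does the same is not stated here.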

describe

describe() -> dict[str, Any]

Return a serializable description of this strategy.