# Joins & Reducers

Join nodes synchronize and aggregate data from parallel execution paths. They use **Reducers** to combine multiple inputs into a single output.

When you use [parallel execution](/docs/ai/graph/builder/parallel) (broadcasting or mapping), you often need to collect and combine the results. Join nodes serve this purpose by:

1.  Waiting for all parallel tasks to complete
2.  Aggregating their outputs using a `ReducerFunction`
3.  Passing the aggregated result to the next node

## Creating Joins

Create a join using `GraphBuilder.join` with a reducer function and initial value or factory:

basic\_join.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, StepContext, reduce_list_append


@dataclass
class SimpleState:
    pass


g = GraphBuilder(state_type=SimpleState, output_type=list[int])

@g.step
async def generate_numbers(ctx: StepContext[SimpleState, None, None]) -> list[int]:
    return [1, 2, 3, 4, 5]

@g.step
async def square(ctx: StepContext[SimpleState, None, int]) -> int:
    return ctx.inputs * ctx.inputs

# Create a join to collect all squared values
collect = g.join(reduce_list_append, initial_factory=list[int])

g.add(
    g.edge_from(g.start_node).to(generate_numbers),
    g.edge_from(generate_numbers).map().to(square),
    g.edge_from(square).to(collect),
    g.edge_from(collect).to(g.end_node),
)

graph = g.build()

async def main():
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> [1, 4, 9, 16, 25]
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Built-in Reducers

Pydantic Graph provides several common reducer types out of the box:

### `reduce_list_append`

`reduce_list_append` collects all inputs into a list:

list\_reducer.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, StepContext, reduce_list_append


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=list[str])

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [10, 20, 30]

    @g.step
    async def to_string(ctx: StepContext[SimpleState, None, int]) -> str:
        return f'value-{ctx.inputs}'

    collect = g.join(reduce_list_append, initial_factory=list[str])

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(to_string),
        g.edge_from(to_string).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> ['value-10', 'value-20', 'value-30']
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### `reduce_list_extend`

`reduce_list_extend` extends a list with an iterable of items:

list\_extend\_reducer.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, StepContext, reduce_list_extend


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=list[int])

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [1, 2, 3]

    @g.step
    async def create_range(ctx: StepContext[SimpleState, None, int]) -> list[int]:
        """Create a range from 0 to the input value."""
        return list(range(ctx.inputs))

    collect = g.join(reduce_list_extend, initial_factory=list[int])

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(create_range),
        g.edge_from(create_range).to(collect),
        g.edge_from(collect).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(sorted(result))
    #> [0, 0, 0, 1, 1, 2]
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### `reduce_dict_update`

`reduce_dict_update` merges dictionaries together:

dict\_reducer.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, StepContext, reduce_dict_update


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=dict[str, int])

    @g.step
    async def generate_keys(ctx: StepContext[SimpleState, None, None]) -> list[str]:
        return ['apple', 'banana', 'cherry']

    @g.step
    async def create_entry(ctx: StepContext[SimpleState, None, str]) -> dict[str, int]:
        return {ctx.inputs: len(ctx.inputs)}

    merge = g.join(reduce_dict_update, initial_factory=dict[str, int])

    g.add(
        g.edge_from(g.start_node).to(generate_keys),
        g.edge_from(generate_keys).map().to(create_entry),
        g.edge_from(create_entry).to(merge),
        g.edge_from(merge).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    result = {k: result[k] for k in sorted(result)}  # force deterministic ordering
    print(result)
    #> {'apple': 5, 'banana': 6, 'cherry': 6}
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### `reduce_null`

`reduce_null` discards all inputs and returns `None`. Useful when you only care about side effects:

null\_reducer.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, StepContext, reduce_null


@dataclass
class CounterState:
    total: int = 0


async def main():
    g = GraphBuilder(state_type=CounterState, output_type=int)

    @g.step
    async def generate(ctx: StepContext[CounterState, None, None]) -> list[int]:
        return [1, 2, 3, 4, 5]

    @g.step
    async def accumulate(ctx: StepContext[CounterState, None, int]) -> int:
        ctx.state.total += ctx.inputs
        return ctx.inputs

    # We don't care about the outputs, only the side effect on state
    ignore = g.join(reduce_null, initial=None)

    @g.step
    async def get_total(ctx: StepContext[CounterState, None, None]) -> int:
        return ctx.state.total

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(accumulate),
        g.edge_from(accumulate).to(ignore),
        g.edge_from(ignore).to(get_total),
        g.edge_from(get_total).to(g.end_node),
    )

    graph = g.build()
    state = CounterState()
    result = await graph.run(state=state)
    print(result)
    #> 15
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### `reduce_sum`

`reduce_sum` sums numeric values:

sum\_reducer.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, StepContext, reduce_sum


@dataclass
class SimpleState:
    pass


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=int)

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [10, 20, 30, 40]

    @g.step
    async def identity(ctx: StepContext[SimpleState, None, int]) -> int:
        return ctx.inputs

    sum_join = g.join(reduce_sum, initial=0)

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(identity),
        g.edge_from(identity).to(sum_join),
        g.edge_from(sum_join).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run(state=SimpleState())
    print(result)
    #> 100
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### `ReduceFirstValue`

`ReduceFirstValue` returns the first value it receives and cancels all other parallel tasks. This is useful for "race" scenarios where you want the first successful result:

first\_value\_reducer.py

```python
import asyncio
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, ReduceFirstValue, StepContext


@dataclass
class SimpleState:
    tasks_completed: int = 0


async def main():
    g = GraphBuilder(state_type=SimpleState, output_type=str)

    @g.step
    async def generate(ctx: StepContext[SimpleState, None, None]) -> list[int]:
        return [1, 12, 13, 14, 15]

    @g.step
    async def slow_process(ctx: StepContext[SimpleState, None, int]) -> str:
        """Simulate variable processing times."""
        # Simulate different delays
        await asyncio.sleep(ctx.inputs * 0.1)
        ctx.state.tasks_completed += 1
        return f'Result from task {ctx.inputs}'

    # Use ReduceFirstValue to get the first result and cancel the rest
    first_result = g.join(ReduceFirstValue[str](), initial=None, node_id='first_result')

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(slow_process),
        g.edge_from(slow_process).to(first_result),
        g.edge_from(first_result).to(g.end_node),
    )

    graph = g.build()
    state = SimpleState()
    result = await graph.run(state=state)

    print(result)
    #> Result from task 1
    print(f'Tasks completed: {state.tasks_completed}')
    #> Tasks completed: 1
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Custom Reducers

Create custom reducers by defining a `ReducerFunction`:

custom\_reducer.py

```python

from pydantic_graph import GraphBuilder, StepContext


def reduce_sum(current: int, inputs: int) -> int:
    """A reducer that sums numbers."""
    return current + inputs


async def main():
    g = GraphBuilder(output_type=int)

    @g.step
    async def generate(ctx: StepContext[None, None, None]) -> list[int]:
        return [5, 10, 15, 20]

    @g.step
    async def identity(ctx: StepContext[None, None, int]) -> int:
        return ctx.inputs

    sum_join = g.join(reduce_sum, initial=0)

    g.add(
        g.edge_from(g.start_node).to(generate),
        g.edge_from(generate).map().to(identity),
        g.edge_from(identity).to(sum_join),
        g.edge_from(sum_join).to(g.end_node),
    )

    graph = g.build()
    result = await graph.run()
    print(result)
    #> 50
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Reducers with State Access

Reducers can access and modify the graph state:

stateful\_reducer.py

```python
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, ReducerContext, StepContext


@dataclass
class MetricsState:
    total_count: int = 0
    total_sum: int = 0


@dataclass
class ReducedMetrics:
    count: int = 0
    sum: int = 0


def reduce_metrics_sum(ctx: ReducerContext[MetricsState, None], current: ReducedMetrics, inputs: int) -> ReducedMetrics:
    ctx.state.total_count += 1
    ctx.state.total_sum += inputs
    return ReducedMetrics(count=current.count + 1, sum=current.sum + inputs)

def reduce_metrics_max(current: ReducedMetrics, inputs: ReducedMetrics) -> ReducedMetrics:
    return ReducedMetrics(count=max(current.count, inputs.count), sum=max(current.sum, inputs.sum))


async def main():
    g = GraphBuilder(state_type=MetricsState, output_type=dict[str, int])

    @g.step
    async def generate(ctx: StepContext[object, None, None]) -> list[int]:
        return [1, 3, 5, 7, 9, 10, 20, 30, 40]

    @g.step
    async def process_even(ctx: StepContext[MetricsState, None, int]) -> int:
        return ctx.inputs * 2

    @g.step
    async def process_odd(ctx: StepContext[MetricsState, None, int]) -> int:
        return ctx.inputs * 3

    metrics_even = g.join(reduce_metrics_sum, initial_factory=ReducedMetrics, node_id='metrics_even')
    metrics_odd = g.join(reduce_metrics_sum, initial_factory=ReducedMetrics, node_id='metrics_odd')
    metrics_max = g.join(reduce_metrics_max, initial_factory=ReducedMetrics, node_id='metrics_max')

    g.add(
        g.edge_from(g.start_node).to(generate),
        # Send even and odd numbers to their respective `process` steps
        g.edge_from(generate).map().to(
            g.decision()
            .branch(g.match(int, matches=lambda x: x % 2 == 0).label('even').to(process_even))
            .branch(g.match(int, matches=lambda x: x % 2 == 1).label('odd').to(process_odd))
        ),
        # Reduce metrics for even and odd numbers separately
        g.edge_from(process_even).to(metrics_even),
        g.edge_from(process_odd).to(metrics_odd),
        # Aggregate the max values for each field
        g.edge_from(metrics_even).to(metrics_max),
        g.edge_from(metrics_odd).to(metrics_max),
        # Finish the graph run with the final reduced value
        g.edge_from(metrics_max).to(g.end_node),
    )

    graph = g.build()
    state = MetricsState()
    result = await graph.run(state=state)

    print(f'Result: {result}')
    #> Result: ReducedMetrics(count=5, sum=200)
    print(f'State total_count: {state.total_count}')
    #> State total_count: 9
    print(f'State total_sum: {state.total_sum}')
    #> State total_sum: 275
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

### Canceling Sibling Tasks

Reducers with access to `ReducerContext` can call `ctx.cancel_sibling_tasks()` to cancel all other parallel tasks in the same fork. This is useful for early termination when you've found what you need:

cancel\_siblings.py

```python
import asyncio
from dataclasses import dataclass

from pydantic_graph import GraphBuilder, ReducerContext, StepContext


@dataclass
class SearchState:
    searches_completed: int = 0


def reduce_find_match(ctx: ReducerContext[SearchState, None], current: str | None, inputs: str) -> str | None:
    """Return the first input that contains 'target' and cancel remaining tasks."""
    if current is not None:
        # We already found a match, ignore subsequent inputs
        return current
    if 'target' in inputs:
        # Found a match! Cancel all other parallel tasks
        ctx.cancel_sibling_tasks()
        return inputs
    return None


async def main():
    g = GraphBuilder(state_type=SearchState, output_type=str | None)

    @g.step
    async def generate_searches(ctx: StepContext[SearchState, None, None]) -> list[str]:
        return ['item1', 'item2', 'target_item', 'item4', 'item5']

    @g.step
    async def search(ctx: StepContext[SearchState, None, str]) -> str:
        """Simulate a slow search operation."""
        # make the search artificially slower for 'item4' and 'item5'
        search_duration = 0.1 if ctx.inputs not in {'item4', 'item5'} else 1.0
        await asyncio.sleep(search_duration)
        ctx.state.searches_completed += 1
        return ctx.inputs

    find_match = g.join(reduce_find_match, initial=None)

    g.add(
        g.edge_from(g.start_node).to(generate_searches),
        g.edge_from(generate_searches).map().to(search),
        g.edge_from(search).to(find_match),
        g.edge_from(find_match).to(g.end_node),
    )

    graph = g.build()
    state = SearchState()
    result = await graph.run(state=state)

    print(f'Found: {result}')
    #> Found: target_item
    print(f'Searches completed: {state.searches_completed}')
    #> Searches completed: 3
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

Note that only 3 searches completed instead of all 5, because the reducer canceled the remaining tasks after finding a match.

## Multiple Joins

A graph can have multiple independent joins:

multiple\_joins.py

```python
from dataclasses import dataclass, field

from pydantic_graph import GraphBuilder, StepContext, reduce_list_append


@dataclass
class MultiState:
    results: dict[str, list[int]] = field(default_factory=dict)


async def main():
    g = GraphBuilder(state_type=MultiState, output_type=dict[str, list[int]])

    @g.step
    async def source_a(ctx: StepContext[MultiState, None, None]) -> list[int]:
        return [1, 2, 3]

    @g.step
    async def source_b(ctx: StepContext[MultiState, None, None]) -> list[int]:
        return [10, 20]

    @g.step
    async def process_a(ctx: StepContext[MultiState, None, int]) -> int:
        return ctx.inputs * 2

    @g.step
    async def process_b(ctx: StepContext[MultiState, None, int]) -> int:
        return ctx.inputs * 3

    join_a = g.join(reduce_list_append, initial_factory=list[int], node_id='join_a')
    join_b = g.join(reduce_list_append, initial_factory=list[int], node_id='join_b')

    @g.step
    async def store_a(ctx: StepContext[MultiState, None, list[int]]) -> None:
        ctx.state.results['a'] = ctx.inputs

    @g.step
    async def store_b(ctx: StepContext[MultiState, None, list[int]]) -> None:
        ctx.state.results['b'] = ctx.inputs

    @g.step
    async def combine(ctx: StepContext[MultiState, None, None]) -> dict[str, list[int]]:
        return ctx.state.results

    g.add(
        g.edge_from(g.start_node).to(source_a, source_b),
        g.edge_from(source_a).map().to(process_a),
        g.edge_from(source_b).map().to(process_b),
        g.edge_from(process_a).to(join_a),
        g.edge_from(process_b).to(join_b),
        g.edge_from(join_a).to(store_a),
        g.edge_from(join_b).to(store_b),
        g.edge_from(store_a, store_b).to(combine),
        g.edge_from(combine).to(g.end_node),
    )

    graph = g.build()
    state = MultiState()
    result = await graph.run(state=state)

    print(f"Group A: {sorted(result['a'])}")
    #> Group A: [2, 4, 6]
    print(f"Group B: {sorted(result['b'])}")
    #> Group B: [30, 60]
```

_(This example is complete, it can be run "as is" -- you'll need to add `import asyncio; asyncio.run(main())` to run `main`)_

## Customizing Join Nodes

### Custom Node IDs

Like steps, joins can have custom IDs:

join\_custom\_id.py

```python
from pydantic_graph import reduce_list_append

from basic_join import g

my_join = g.join(reduce_list_append, initial_factory=list[int], node_id='my_custom_join_id')
```

## How Joins Work

Internally, the graph tracks which "fork" each parallel task belongs to. A join:

1.  Identifies its parent fork (the fork that created the parallel paths)
2.  Waits for all tasks from that fork to reach the join
3.  Calls `reduce()` for each incoming value
4.  Calls `finalize()` once all values are received
5.  Passes the finalized result to downstream nodes

This ensures proper synchronization even with nested parallel operations.

## Next Steps

-   Learn about [parallel execution](/docs/ai/graph/builder/parallel) with broadcasting and mapping
-   Explore [conditional branching](/docs/ai/graph/builder/decisions) with decision nodes
-   See the API reference for complete reducer documentation