Skip to content

Minimum Spanning Tree

Problem

Given an undirected weighted graph, find a subset of edges that connects all vertices with minimum total weight and no cycles.

Approach

Kruskal: Sort edges by weight, greedily add edges that don't form a cycle (checked via Union-Find). Prim: Grow the MST from an arbitrary node using a min-heap.

When to Use

Minimum cost to connect all nodes — cable/road/pipeline routing, network backbone design. Kruskal for sparse graphs, Prim for dense. Aviation: minimum-cost ground infrastructure linking airports.

Implementation

minimum_spanning_tree

Minimum spanning tree: Kruskal's and Prim's algorithms.

Problem

Given an undirected weighted graph, find a subset of edges that connects all vertices with minimum total weight and no cycles.

Approach

Kruskal: Sort edges by weight, greedily add edges that don't form a cycle (checked via Union-Find). Prim: Grow the MST from an arbitrary node using a min-heap.

When to use

Minimum cost to connect all nodes — cable/road/pipeline routing, network backbone design. Kruskal for sparse graphs, Prim for dense. Aviation: minimum-cost ground infrastructure linking airports.

Complexity

Kruskal: O(E log E) (sort-dominated)
Prim: O(E log V) (heap-dominated)

UnionFind

Disjoint-set / Union-Find with path compression and union by rank.

Source code in src/algo/graphs/minimum_spanning_tree.py
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Every element starts in its own singleton set.  ``components`` always
    holds the number of disjoint sets that currently exist.
    """

    def __init__(self, n: int) -> None:
        # parent[i] == i marks i as the representative (root) of its set.
        self.parent = [i for i in range(n)]
        # rank[i] is only meaningful while i is a root.
        self.rank = [0 for _ in range(n)]
        self.components = n

    def find(self, x: int) -> int:
        """Return the root of the set containing *x*, compressing the path.

        Each node visited on the way up is re-pointed at its grandparent,
        and the walk continues from there.
        """
        links = self.parent
        current = x
        while (above := links[current]) != current:
            links[current] = links[above]
            current = links[current]
        return current

    def union(self, x: int, y: int) -> bool:
        """Join the sets holding *x* and *y*.

        Returns ``True`` when two distinct sets were merged and ``False``
        when *x* and *y* already shared a set.
        """
        root_a, root_b = self.find(x), self.find(y)
        if root_a == root_b:
            return False

        rank = self.rank
        # The root with the larger rank stays on top; ties keep x's root.
        if rank[root_a] >= rank[root_b]:
            winner, loser = root_a, root_b
        else:
            winner, loser = root_b, root_a

        self.parent[loser] = winner
        if rank[winner] == rank[loser]:
            rank[winner] += 1
        self.components -= 1
        return True

find

find(x: int) -> int

Find the root of x with path compression.

Source code in src/algo/graphs/minimum_spanning_tree.py
def find(self, x: int) -> int:
    """Return the root of the set containing *x*, compressing the path.

    Each node visited on the way up is re-pointed at its grandparent,
    and the walk continues from there.
    """
    links = self.parent
    current = x
    while (above := links[current]) != current:
        links[current] = links[above]
        current = links[current]
    return current

union

union(x: int, y: int) -> bool

Merge sets containing x and y. Return False if already same set.

Source code in src/algo/graphs/minimum_spanning_tree.py
def union(self, x: int, y: int) -> bool:
    """Join the sets holding *x* and *y*.

    Returns ``True`` when two distinct sets were merged and ``False``
    when *x* and *y* already shared a set.
    """
    root_a, root_b = self.find(x), self.find(y)
    if root_a == root_b:
        return False

    rank = self.rank
    # The root with the larger rank stays on top; ties keep x's root.
    if rank[root_a] >= rank[root_b]:
        winner, loser = root_a, root_b
    else:
        winner, loser = root_b, root_a

    self.parent[loser] = winner
    if rank[winner] == rank[loser]:
        rank[winner] += 1
    self.components -= 1
    return True

kruskal

kruskal(
    num_nodes: int, edges: Sequence[tuple[int, int, float]]
) -> list[tuple[int, int, float]]

Return MST edges using Kruskal's algorithm.

edges is a list of (u, v, weight) for an undirected graph. Returns the MST as a list of (u, v, weight) edges.

>>> kruskal(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)])
[(0, 1, 1), (1, 2, 2), (2, 3, 4)]

Source code in src/algo/graphs/minimum_spanning_tree.py
def kruskal(
    num_nodes: int,
    edges: Sequence[tuple[int, int, float]],
) -> list[tuple[int, int, float]]:
    """Build a minimum spanning tree with Kruskal's algorithm.

    *edges* holds ``(u, v, weight)`` triples of an undirected graph.
    The chosen edges are returned as ``(u, v, weight)`` triples in the
    order they were accepted (ascending weight).

    >>> kruskal(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)])
    [(0, 1, 1), (1, 2, 2), (2, 3, 4)]
    """
    by_weight = sorted(edges, key=lambda e: e[2])
    forest = UnionFind(num_nodes)
    wanted = num_nodes - 1
    chosen: list[tuple[int, int, float]] = []

    for a, b, cost in by_weight:
        # An edge whose endpoints already share a set would close a cycle.
        if not forest.union(a, b):
            continue
        chosen.append((a, b, cost))
        if len(chosen) == wanted:
            # A spanning tree has exactly num_nodes - 1 edges; stop early.
            break

    return chosen

prim

prim(
    num_nodes: int, edges: Sequence[tuple[int, int, float]]
) -> list[tuple[int, int, float]]

Return MST edges using Prim's algorithm.

edges is a list of (u, v, weight) for an undirected graph.

>>> sorted(
...     prim(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]), key=lambda e: e[2]
... )
[(0, 1, 1), (1, 2, 2), (2, 3, 4)]

Source code in src/algo/graphs/minimum_spanning_tree.py
def prim(
    num_nodes: int,
    edges: Sequence[tuple[int, int, float]],
) -> list[tuple[int, int, float]]:
    """Return MST edges using Prim's algorithm.

    *edges* is a list of (u, v, weight) for an undirected graph whose
    nodes are the integers ``0 .. num_nodes - 1``.  Each returned edge is
    ``(frm, to, weight)`` where *frm* was already in the tree when *to*
    was attached.

    If the graph is disconnected, the search restarts from the lowest
    numbered unvisited node, so the result is a minimum spanning forest
    (one tree per component) rather than only the tree containing node 0.

    >>> sorted(
    ...     prim(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]), key=lambda e: e[2]
    ... )
    [(0, 1, 1), (1, 2, 2), (2, 3, 4)]
    >>> prim(4, [(0, 1, 1), (2, 3, 2)])
    [(0, 1, 1), (2, 3, 2)]
    """
    adj: list[list[tuple[int, float]]] = [[] for _ in range(num_nodes)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    in_mst = [False] * num_nodes
    mst: list[tuple[int, int, float]] = []
    # A spanning tree of a connected graph has exactly this many edges.
    target = num_nodes - 1

    for root in range(num_nodes):
        if len(mst) >= target:
            break
        if in_mst[root]:
            continue

        # Seed the heap with a zero-weight pseudo-edge; frm == -1 marks it
        # so it is never reported as a real edge.
        heap: list[tuple[float, int, int]] = [(0, -1, root)]
        while heap and len(mst) < target:
            w, frm, to = heapq.heappop(heap)
            if in_mst[to]:
                # Stale entry: `to` was reached earlier by a cheaper edge.
                continue
            in_mst[to] = True
            if frm != -1:
                mst.append((frm, to, w))
            for neighbor, weight in adj[to]:
                if not in_mst[neighbor]:
                    heapq.heappush(heap, (weight, to, neighbor))

    return mst
tests/graphs/test_minimum_spanning_tree.py
"""Tests for minimum spanning tree algorithms."""

from algo.graphs.minimum_spanning_tree import UnionFind, kruskal, prim


class TestUnionFind:
    """Checks for the UnionFind disjoint-set structure."""

    def test_initial_components(self) -> None:
        # A fresh structure reports one set per element.
        assert UnionFind(5).components == 5

    def test_union_and_find(self) -> None:
        # Merging two singletons succeeds, links them, and drops the count.
        dsu = UnionFind(4)
        merged = dsu.union(0, 1)
        assert merged is True
        assert dsu.find(0) == dsu.find(1)
        assert dsu.components == 3

    def test_redundant_union(self) -> None:
        # A second merge of the same pair must be reported as a no-op.
        dsu = UnionFind(3)
        dsu.union(0, 1)
        repeated = dsu.union(0, 1)
        assert repeated is False

    def test_transitive(self) -> None:
        # 0-1 and 1-2 together put 0 and 2 in the same set.
        dsu = UnionFind(3)
        for a, b in ((0, 1), (1, 2)):
            dsu.union(a, b)
        assert dsu.find(0) == dsu.find(2)
        assert dsu.components == 1


def _mst_weight(mst: list[tuple[int, int, float]]) -> float:
    return sum(w for _, _, w in mst)


class TestKruskal:
    """Checks for kruskal() on small hand-verifiable graphs."""

    def test_simple_triangle(self) -> None:
        # The two cheapest sides of the triangle (1 + 2) form the tree.
        tree = kruskal(3, [(0, 1, 1), (1, 2, 2), (0, 2, 3)])
        assert _mst_weight(tree) == 3

    def test_four_nodes(self) -> None:
        graph = [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]
        tree = kruskal(4, graph)
        assert len(tree) == 3
        assert _mst_weight(tree) == 7

    def test_single_node(self) -> None:
        # One node and no edges: the tree is empty.
        tree = kruskal(1, [])
        assert tree == []

    def test_two_nodes(self) -> None:
        tree = kruskal(2, [(0, 1, 5)])
        assert _mst_weight(tree) == 5


class TestPrim:
    """Checks for prim() on small hand-verifiable graphs."""

    def test_simple_triangle(self) -> None:
        # The two cheapest sides of the triangle (1 + 2) form the tree.
        tree = prim(3, [(0, 1, 1), (1, 2, 2), (0, 2, 3)])
        assert _mst_weight(tree) == 3

    def test_four_nodes(self) -> None:
        graph = [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]
        tree = prim(4, graph)
        assert len(tree) == 3
        assert _mst_weight(tree) == 7

    def test_single_node(self) -> None:
        # One node and no edges: the tree is empty.
        tree = prim(1, [])
        assert tree == []

    def test_two_nodes(self) -> None:
        tree = prim(2, [(0, 1, 5)])
        assert _mst_weight(tree) == 5

Implement it yourself

Run: just challenge graphs minimum_spanning_tree

Then implement the functions to make all tests pass. Use `just study graphs` for watch mode.

Reveal Solution

minimum_spanning_tree

Minimum spanning tree: Kruskal's and Prim's algorithms.

Problem

Given an undirected weighted graph, find a subset of edges that connects all vertices with minimum total weight and no cycles.

Approach

Kruskal: Sort edges by weight, greedily add edges that don't form a cycle (checked via Union-Find). Prim: Grow the MST from an arbitrary node using a min-heap.

When to use

Minimum cost to connect all nodes — cable/road/pipeline routing, network backbone design. Kruskal for sparse graphs, Prim for dense. Aviation: minimum-cost ground infrastructure linking airports.

Complexity

Kruskal: O(E log E) (sort-dominated)
Prim: O(E log V) (heap-dominated)

UnionFind

Disjoint-set / Union-Find with path compression and union by rank.

Source code in src/algo/graphs/minimum_spanning_tree.py
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Every element starts in its own singleton set.  ``components`` always
    holds the number of disjoint sets that currently exist.
    """

    def __init__(self, n: int) -> None:
        # parent[i] == i marks i as the representative (root) of its set.
        self.parent = [i for i in range(n)]
        # rank[i] is only meaningful while i is a root.
        self.rank = [0 for _ in range(n)]
        self.components = n

    def find(self, x: int) -> int:
        """Return the root of the set containing *x*, compressing the path.

        Each node visited on the way up is re-pointed at its grandparent,
        and the walk continues from there.
        """
        links = self.parent
        current = x
        while (above := links[current]) != current:
            links[current] = links[above]
            current = links[current]
        return current

    def union(self, x: int, y: int) -> bool:
        """Join the sets holding *x* and *y*.

        Returns ``True`` when two distinct sets were merged and ``False``
        when *x* and *y* already shared a set.
        """
        root_a, root_b = self.find(x), self.find(y)
        if root_a == root_b:
            return False

        rank = self.rank
        # The root with the larger rank stays on top; ties keep x's root.
        if rank[root_a] >= rank[root_b]:
            winner, loser = root_a, root_b
        else:
            winner, loser = root_b, root_a

        self.parent[loser] = winner
        if rank[winner] == rank[loser]:
            rank[winner] += 1
        self.components -= 1
        return True

find

find(x: int) -> int

Find the root of x with path compression.

Source code in src/algo/graphs/minimum_spanning_tree.py
def find(self, x: int) -> int:
    """Return the root of the set containing *x*, compressing the path.

    Each node visited on the way up is re-pointed at its grandparent,
    and the walk continues from there.
    """
    links = self.parent
    current = x
    while (above := links[current]) != current:
        links[current] = links[above]
        current = links[current]
    return current

union

union(x: int, y: int) -> bool

Merge sets containing x and y. Return False if already same set.

Source code in src/algo/graphs/minimum_spanning_tree.py
def union(self, x: int, y: int) -> bool:
    """Join the sets holding *x* and *y*.

    Returns ``True`` when two distinct sets were merged and ``False``
    when *x* and *y* already shared a set.
    """
    root_a, root_b = self.find(x), self.find(y)
    if root_a == root_b:
        return False

    rank = self.rank
    # The root with the larger rank stays on top; ties keep x's root.
    if rank[root_a] >= rank[root_b]:
        winner, loser = root_a, root_b
    else:
        winner, loser = root_b, root_a

    self.parent[loser] = winner
    if rank[winner] == rank[loser]:
        rank[winner] += 1
    self.components -= 1
    return True

kruskal

kruskal(
    num_nodes: int, edges: Sequence[tuple[int, int, float]]
) -> list[tuple[int, int, float]]

Return MST edges using Kruskal's algorithm.

edges is a list of (u, v, weight) for an undirected graph. Returns the MST as a list of (u, v, weight) edges.

>>> kruskal(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)])
[(0, 1, 1), (1, 2, 2), (2, 3, 4)]

Source code in src/algo/graphs/minimum_spanning_tree.py
def kruskal(
    num_nodes: int,
    edges: Sequence[tuple[int, int, float]],
) -> list[tuple[int, int, float]]:
    """Build a minimum spanning tree with Kruskal's algorithm.

    *edges* holds ``(u, v, weight)`` triples of an undirected graph.
    The chosen edges are returned as ``(u, v, weight)`` triples in the
    order they were accepted (ascending weight).

    >>> kruskal(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)])
    [(0, 1, 1), (1, 2, 2), (2, 3, 4)]
    """
    by_weight = sorted(edges, key=lambda e: e[2])
    forest = UnionFind(num_nodes)
    wanted = num_nodes - 1
    chosen: list[tuple[int, int, float]] = []

    for a, b, cost in by_weight:
        # An edge whose endpoints already share a set would close a cycle.
        if not forest.union(a, b):
            continue
        chosen.append((a, b, cost))
        if len(chosen) == wanted:
            # A spanning tree has exactly num_nodes - 1 edges; stop early.
            break

    return chosen

prim

prim(
    num_nodes: int, edges: Sequence[tuple[int, int, float]]
) -> list[tuple[int, int, float]]

Return MST edges using Prim's algorithm.

edges is a list of (u, v, weight) for an undirected graph.

>>> sorted(
...     prim(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]), key=lambda e: e[2]
... )
[(0, 1, 1), (1, 2, 2), (2, 3, 4)]

Source code in src/algo/graphs/minimum_spanning_tree.py
def prim(
    num_nodes: int,
    edges: Sequence[tuple[int, int, float]],
) -> list[tuple[int, int, float]]:
    """Return MST edges using Prim's algorithm.

    *edges* is a list of (u, v, weight) for an undirected graph whose
    nodes are the integers ``0 .. num_nodes - 1``.  Each returned edge is
    ``(frm, to, weight)`` where *frm* was already in the tree when *to*
    was attached.

    If the graph is disconnected, the search restarts from the lowest
    numbered unvisited node, so the result is a minimum spanning forest
    (one tree per component) rather than only the tree containing node 0.

    >>> sorted(
    ...     prim(4, [(0, 1, 1), (1, 2, 2), (0, 2, 3), (2, 3, 4)]), key=lambda e: e[2]
    ... )
    [(0, 1, 1), (1, 2, 2), (2, 3, 4)]
    >>> prim(4, [(0, 1, 1), (2, 3, 2)])
    [(0, 1, 1), (2, 3, 2)]
    """
    adj: list[list[tuple[int, float]]] = [[] for _ in range(num_nodes)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    in_mst = [False] * num_nodes
    mst: list[tuple[int, int, float]] = []
    # A spanning tree of a connected graph has exactly this many edges.
    target = num_nodes - 1

    for root in range(num_nodes):
        if len(mst) >= target:
            break
        if in_mst[root]:
            continue

        # Seed the heap with a zero-weight pseudo-edge; frm == -1 marks it
        # so it is never reported as a real edge.
        heap: list[tuple[float, int, int]] = [(0, -1, root)]
        while heap and len(mst) < target:
            w, frm, to = heapq.heappop(heap)
            if in_mst[to]:
                # Stale entry: `to` was reached earlier by a cheaper edge.
                continue
            in_mst[to] = True
            if frm != -1:
                mst.append((frm, to, w))
            for neighbor, weight in adj[to]:
                if not in_mst[neighbor]:
                    heapq.heappush(heap, (weight, to, neighbor))

    return mst