Writing a binary search tree

Tim discusses using Swift’s indirect keyword for enumerations to build binary tree nodes. Read on to learn more about how reference types and value semantics combine…

Today we’re writing a simple binary search tree in Swift. While binary search trees seem to be very powerful in theory, their performance is rather disappointing in practice. Other, more advanced tree-shaped data structures such as red-black trees and B-trees solve some of the practical problems that binary search trees have, and are consequently much more useful in performance-critical code. Nevertheless, learning about binary search trees is the first step to getting more familiar with this class of data structures.

Probably the simplest way to represent a binary tree in Swift is by using an indirect enum, like so:

enum BinarySearchTree<Element: Comparable> {
    case empty
    // left subtree, value, right subtree (unlabeled, so .node(.empty, element, .empty) compiles)
    indirect case node(BinarySearchTree, Element, BinarySearchTree)
}

Although the payload of an indirect case is stored behind a reference under the hood, the enum itself is still a value type with value semantics. So that’s nice.
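To get a feel for the representation, here is a minimal sketch that builds a small tree by hand and walks it with pattern matching. The type name `Tree`, the `leaf` function, and the `count` property are made up for this illustration and are not part of the code in this post.

```swift
// A stand-in for the enum above, renamed so it doesn't clash with the post's type.
enum Tree<Element: Comparable> {
    case empty
    indirect case node(Tree, Element, Tree)

    // Hypothetical helper: the number of values stored in the tree.
    var count: Int {
        switch self {
        case .empty:
            return 0
        case let .node(left, _, right):
            return left.count + 1 + right.count
        }
    }
}

// Hypothetical helper: a node without children.
func leaf(_ value: Int) -> Tree<Int> {
    return .node(.empty, value, .empty)
}

//     5
//    / \
//   3   8
let smallTree: Tree<Int> = .node(leaf(3), 5, leaf(8))
print(smallTree.count) // => 3
```

Every recursive function over the tree tends to follow this shape: one branch for `.empty`, one branch that destructures `.node` and recurses into the subtrees.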

Of all the standard library’s protocols, SetAlgebra and BidirectionalCollection both seem a good fit for our BinarySearchTree type: SetAlgebra for inserting and removing elements, and BidirectionalCollection for traversing the tree (both forwards and backwards, hence the name). However, for the purpose of this post, we’ll stick to only a couple of basic methods.

Let’s start with insertion. Because of the way enums with associated values work, it’s easiest to implement the mutating insert method in terms of a private, non-mutating inserting method that returns a new tree:

extension BinarySearchTree {
    mutating func insert(_ element: Element) {
        self = inserting(element)
    }
    
    private func inserting(_ element: Element) -> BinarySearchTree {
        switch self {
        case .empty:
            // the tree is empty, so inserting an element results in a tree containing only that element
            return .node(.empty, element, .empty)
            
        case .node(_, element, _):
            // the element is already present in the tree
            return self
            
        case let .node(left, value, right) where element < value:
            // the element should be inserted into the left subtree
            return .node(left.inserting(element), value, right)
            
        case let .node(left, value, right):
            // the element should be inserted into the right subtree
            return .node(left, value, right.inserting(element))
        }
    }
}
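Because insert simply replaces self with a freshly built tree, copies of a tree stay independent of each other. The sketch below illustrates this with a condensed stand-in type; `DemoTree` and its `values` property are invented for this example, and the insertion logic is a compressed variant of the code above rather than a copy of it.

```swift
enum DemoTree<Element: Comparable> {
    case empty
    indirect case node(DemoTree, Element, DemoTree)

    mutating func insert(_ element: Element) {
        self = inserting(element)
    }

    // Returns a new tree; the receiver is left untouched.
    private func inserting(_ element: Element) -> DemoTree {
        switch self {
        case .empty:
            return .node(.empty, element, .empty)
        case let .node(left, value, right):
            if element == value { return self }
            if element < value { return .node(left.inserting(element), value, right) }
            return .node(left, value, right.inserting(element))
        }
    }

    // Hypothetical helper: all values in ascending order.
    var values: [Element] {
        switch self {
        case .empty:
            return []
        case let .node(left, value, right):
            return left.values + [value] + right.values
        }
    }
}

var original = DemoTree<Int>.empty
original.insert(5)

var copy = original    // copies the value; later changes to copy don't affect original
copy.insert(3)

print(original.values) // => [5]
print(copy.values)     // => [3, 5]
```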

Now we can insert values into a tree, but we can’t read them in any way. So let’s add a contains method as well:

extension BinarySearchTree {
    func contains(_ element: Element) -> Bool {
        switch self {
        case .empty:
            // an empty tree obviously doesn't contain any elements!
            return false
            
        case .node(_, element, _):
            // the element is equal to this node's value
            return true
            
        case let .node(left, value, _) where element < value:
            // if the element is present in the tree, it must be in the left subtree
            return left.contains(element)
            
        case let .node(_, _, right):
            // if the element is present in the tree, it must be in the right subtree
            return right.contains(element)
        }
    }
}

Let’s try it out!

var tree = BinarySearchTree<Int>.empty

tree.contains(5) // => false
tree.insert(5)
tree.contains(5) // => true
tree.insert(3)
tree.contains(3) // => true
tree.contains(5) // => true

Looks good. We can’t remove elements yet, but removal is a real pain to implement and out of scope for this post.

Finally, to iterate over a binary search tree, we need to conform to Sequence, which in turn requires an iterator type that implements the IteratorProtocol protocol. We’ll use an in-order traversal algorithm:

extension BinarySearchTree: Sequence {
    func makeIterator() -> BinarySearchTreeIterator<Element> {
        return BinarySearchTreeIterator(self)
    }
}

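struct BinarySearchTreeIterator<Element: Comparable>: IteratorProtocol {
    // the subtree that still has to be visited
    var node: BinarySearchTree<Element>
    // values we passed on the way down, each paired with its right subtree
    var stack: [(Element, BinarySearchTree<Element>)] = []
    
    init(_ node: BinarySearchTree<Element>) {
        self.node = node
    }
    
    mutating func next() -> Element? {
        // walk down to the leftmost node, remembering every node we pass
        while case let .node(left, value, right) = node {
            stack.append((value, right))
            node = left
        }
        
        // the most recently passed node holds the next value in order
        guard let (element, right) = stack.popLast() else { return nil }
        
        // continue with that node's right subtree on the next call
        node = right
        return element
    }
}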
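For comparison, the same in-order traversal can be written recursively: visit the left subtree, then the node’s value, then the right subtree. The iterator above does the same work with an explicit stack so that it can pause between elements. A sketch, assuming a stand-in type `TraversalTree` and a hypothetical `inOrder()` method, neither of which is part of this post’s code:

```swift
enum TraversalTree<Element: Comparable> {
    case empty
    indirect case node(TraversalTree, Element, TraversalTree)

    // Hypothetical helper: collects all values in ascending order.
    func inOrder() -> [Element] {
        switch self {
        case .empty:
            return []
        case let .node(left, value, right):
            return left.inOrder() + [value] + right.inOrder()
        }
    }
}

// Built by hand; this is the shape you get from inserting 5, 2, 4, 8, 3:
//
//       5
//      / \
//     2   8
//      \
//       4
//      /
//     3
let three: TraversalTree<Int> = .node(.empty, 3, .empty)
let four: TraversalTree<Int> = .node(three, 4, .empty)
let two: TraversalTree<Int> = .node(.empty, 2, four)
let eight: TraversalTree<Int> = .node(.empty, 8, .empty)
let root: TraversalTree<Int> = .node(two, 5, eight)

print(root.inOrder()) // => [2, 3, 4, 5, 8]
```

The recursive version builds the whole array up front, whereas the stack-based iterator produces one element per call, which is what Sequence needs.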

Now we can do all kinds of sequence-y stuff with trees, like so:

var tree = BinarySearchTree<Int>.empty
[5, 2, 4, 8, 3, 2].forEach { tree.insert($0) }

for element in tree {
    print(element, terminator: " ") // => 2 3 4 5 8
}

tree.reduce(0, +)                                  // => 22
tree.lazy.map(String.init).joined(separator: ", ") // => 2, 3, 4, 5, 8

And the list goes on.

That’s it for now! If you enjoy this kind of stuff, make sure to check out Károly Lőrentey’s brand new book Optimizing Collections, in which he goes through implementing several data structures in Swift in great detail, focusing on performance. As of writing this post, it’s 25% off.

One Comment

  • Thank you for this post! This leads to a question whose answer I would find very useful. How would you write a class MarkovChain or TransitionTable in idiomatic Swift that takes a transition table like

        A    B    C
    A   0.1  0.6  0.3
    B   0.2  0.0  0.8
    C   0.2  0.6  0.2

    (except the actual transition table will be quite a lot larger (16 x 16?), like two pre-calculated transition tables of “all the intervals in a McCartney song” and “all the intervals in a Lennon song”, with the app modulating along a linear curve from McCartney-esque to Lennon-esque and back)

    (probably using the cumulative row sums to make it easier):

        A    B    C
    A   0.1  0.7  1.0
    B   0.2  0.2  1.0
    C   0.2  0.8  1.0

    then, given a current row and a new random number between 0.0 and 1.0
    func transition(currentRow: String, rand: Double) -> String

    returns the next row (“A”, “B”, “C”), where “A” might point to “major tenth down”, “B” to “minor tenth down”, etc. The problem is making it as fast as possible, using binary search and anything else available.

    Thanks!