Tim discusses using Swift’s `indirect` enum keyword to build binary tree nodes. Read on to learn more about how reference types and value semantics combine…
Today we’re writing a simple binary search tree in Swift. While binary search trees seem very powerful in theory, their performance is rather disappointing in practice: a plain binary search tree does no rebalancing, so a tree built from sorted input degenerates into a linked list with linear-time lookups, and every node lives in its own heap allocation, which is unfriendly to the CPU cache. Other, more advanced tree-shaped data structures such as red-black trees and B-trees solve some of the practical problems that binary search trees have, and are consequently much more useful in performance-critical code. Nevertheless, learning about binary search trees is the first step to getting more familiar with this class of data structures.
Probably the simplest way to represent a binary tree in Swift is by using an indirect enum, like so:
```swift
enum BinarySearchTree<Element: Comparable> {
    case empty
    indirect case node(left: BinarySearchTree, value: Element, right: BinarySearchTree)
}
```
Although indirect enum cases are stored in heap-allocated boxes (reference types) under the hood, those boxes are never mutated in place, so the enum itself still has value semantics. So that’s nice.
Of all the standard library’s protocols, `SetAlgebra` and `BidirectionalCollection` both seem a good fit for our `BinarySearchTree` type: `SetAlgebra` for inserting and removing elements, and `BidirectionalCollection` for traversing the tree (both forwards and backwards, hence the name). However, for the purpose of this post, we’ll stick to a couple of basic methods.
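Let’s start with insertion. Because an enum’s associated values can’t be modified in place, it’s easiest to implement the mutating `insert` method on top of a private, non-mutating `inserting` method that builds and returns a new tree; `insert` then simply replaces `self` with the result: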
```swift
extension BinarySearchTree {
    mutating func insert(_ element: Element) {
        self = inserting(element)
    }

    private func inserting(_ element: Element) -> BinarySearchTree {
        switch self {
        case .empty:
            // the tree is empty, so inserting an element results in a tree containing only that element
            return .node(left: .empty, value: element, right: .empty)
        case .node(_, element, _):
            // the element is already present in the tree
            return self
        case let .node(left, value, right) where element < value:
            // the element should be inserted into the left subtree
            return .node(left: left.inserting(element), value: value, right: right)
        case let .node(left, value, right):
            // the element should be inserted into the right subtree
            return .node(left: left, value: value, right: right.inserting(element))
        }
    }
}
```
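To see the value semantics mentioned earlier in action, here is a small self-contained sketch. It repeats the enum and insertion code from this post and adds a hypothetical `nodeCount` helper that isn’t part of the original. Copying a tree and inserting into the copy leaves the original untouched, because `insert` builds new nodes along the insertion path instead of modifying existing ones; untouched subtrees (here the left one) are simply shared.

```swift
enum BinarySearchTree<Element: Comparable> {
    case empty
    indirect case node(left: BinarySearchTree, value: Element, right: BinarySearchTree)
}

extension BinarySearchTree {
    mutating func insert(_ element: Element) {
        self = inserting(element)
    }

    private func inserting(_ element: Element) -> BinarySearchTree {
        switch self {
        case .empty:
            return .node(left: .empty, value: element, right: .empty)
        case .node(_, element, _):
            return self
        case let .node(left, value, right) where element < value:
            // only the path to the new element is rebuilt; `right` is reused as-is
            return .node(left: left.inserting(element), value: value, right: right)
        case let .node(left, value, right):
            // `left` is reused as-is
            return .node(left: left, value: value, right: right.inserting(element))
        }
    }

    // Hypothetical helper for this demo: the number of nodes in the tree.
    var nodeCount: Int {
        switch self {
        case .empty:
            return 0
        case let .node(left, _, right):
            return left.nodeCount + 1 + right.nodeCount
        }
    }
}

var original = BinarySearchTree<Int>.empty
original.insert(5)
original.insert(3)

var copy = original  // shares the same heap boxes for now
copy.insert(8)       // builds a new root and a new right child; the left subtree is shared

print(original.nodeCount, copy.nodeCount) // 2 3
```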
Now we can insert values into a tree, but we can’t read them in any way. So let’s add a `contains` method as well:
```swift
extension BinarySearchTree {
    func contains(_ element: Element) -> Bool {
        switch self {
        case .empty:
            // an empty tree obviously doesn't contain any elements!
            return false
        case .node(_, element, _):
            // the element is equal to this node's value
            return true
        case let .node(left, value, _) where element < value:
            // if the element is present in the tree, it must be in the left subtree
            return left.contains(element)
        case let .node(_, _, right):
            // if the element is present in the tree, it must be in the right subtree
            return right.contains(element)
        }
    }
}
```
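Since `contains` only ever descends into one subtree per step, the recursion can also be written as a loop. Swift doesn’t guarantee tail-call elimination, so a loop avoids growing the call stack on a deep (for example, degenerate) tree. This is a sketch of that alternative; `containsIteratively` is a hypothetical name, and the tree is built by hand so the block stands alone.

```swift
enum BinarySearchTree<Element: Comparable> {
    case empty
    indirect case node(left: BinarySearchTree, value: Element, right: BinarySearchTree)
}

extension BinarySearchTree {
    // Hypothetical loop-based variant of contains: walks a single path down from the root.
    func containsIteratively(_ element: Element) -> Bool {
        var current = self
        while case let .node(left, value, right) = current {
            if element == value {
                return true
            }
            // descend into the only subtree that could hold the element
            current = element < value ? left : right
        }
        // we fell off the bottom of the tree, so the element is absent
        return false
    }
}

// A small tree built by hand: 3 <- 5 -> 8
let leaf3 = BinarySearchTree<Int>.node(left: .empty, value: 3, right: .empty)
let leaf8 = BinarySearchTree<Int>.node(left: .empty, value: 8, right: .empty)
let tree = BinarySearchTree<Int>.node(left: leaf3, value: 5, right: leaf8)

print(tree.containsIteratively(8)) // true
print(tree.containsIteratively(4)) // false
```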
Let’s try it out!
```swift
var tree = BinarySearchTree<Int>.empty
tree.contains(5) // => false
tree.insert(5)
tree.contains(5) // => true
tree.insert(3)
tree.contains(3) // => true
tree.contains(5) // => true
```
Looks good. We can’t remove elements yet, but that is a real pain to implement and out of scope for this post.
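Finally, to iterate over a binary search tree, we make it conform to `Sequence`, which means providing an iterator type that conforms to `IteratorProtocol`. We’ll use an in-order traversal algorithm, which visits the elements in ascending order: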
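```swift
extension BinarySearchTree: Sequence {
    func makeIterator() -> BinarySearchTreeIterator<Element> {
        return BinarySearchTreeIterator(self)
    }
}

struct BinarySearchTreeIterator<Element: Comparable>: IteratorProtocol {
    // the subtree whose left spine hasn't been pushed onto the stack yet
    var node: BinarySearchTree<Element>
    // values still to be returned, each paired with its right subtree
    var stack: [(Element, BinarySearchTree<Element>)] = []

    init(_ node: BinarySearchTree<Element>) {
        self.node = node
    }

    mutating func next() -> Element? {
        // walk down the left spine, remembering every node we pass
        while case let .node(left, value, right) = node {
            stack.append((value, right))
            node = left
        }
        // the most recently pushed value is the smallest one not yet returned
        guard let (element, right) = stack.popLast() else { return nil }
        // its right subtree comes next in in-order
        node = right
        return element
    }
}
```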
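The iterator above is essentially the classic recursive in-order traversal (left subtree, node, right subtree) with the call stack replaced by an explicit `stack`, which is what lets `next()` pause and resume between elements. For comparison, here is a sketch of the recursive version. `forEachInOrder` is a hypothetical helper, not part of the post, and the tree is built by hand with the same shape that inserting 5, 2, 4, 8, 3 produces.

```swift
enum BinarySearchTree<Element: Comparable> {
    case empty
    indirect case node(left: BinarySearchTree, value: Element, right: BinarySearchTree)
}

extension BinarySearchTree {
    // Hypothetical recursive in-order traversal: left subtree, then this node, then right subtree.
    func forEachInOrder(_ body: (Element) -> Void) {
        guard case let .node(left, value, right) = self else { return }
        left.forEachInOrder(body)
        body(value)
        right.forEachInOrder(body)
    }
}

// Same shape as inserting 5, 2, 4, 8, 3 into an empty tree
let three = BinarySearchTree<Int>.node(left: .empty, value: 3, right: .empty)
let four = BinarySearchTree<Int>.node(left: three, value: 4, right: .empty)
let two = BinarySearchTree<Int>.node(left: .empty, value: 2, right: four)
let eight = BinarySearchTree<Int>.node(left: .empty, value: 8, right: .empty)
let tree = BinarySearchTree<Int>.node(left: two, value: 5, right: eight)

var visited: [Int] = []
tree.forEachInOrder { visited.append($0) }
print(visited) // [2, 3, 4, 5, 8]
```

The recursive form is shorter, but it can’t hand out one element at a time, which is why an `IteratorProtocol` conformance needs the explicit stack.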
Now we can do all kinds of sequence-y stuff with trees, like so:
```swift
var tree = BinarySearchTree<Int>.empty
[5, 2, 4, 8, 3, 2].forEach { tree.insert($0) }

for element in tree {
    print(element, terminator: " ") // => 2 3 4 5 8
}

tree.reduce(0, +) // => 22
tree.lazy.map(String.init).joined(separator: ", ") // => 2, 3, 4, 5, 8
```
And the list goes on.
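To make the “disappointing in practice” remark from the introduction a little more concrete: since this tree never rebalances, inserting elements in sorted order sends every new element to the right, and the tree ends up as a linked list. The sketch below repeats the insertion code and adds a hypothetical `height` helper (not part of the post) to compare a tree built from sorted input with one built from a well-mixed insertion order.

```swift
enum BinarySearchTree<Element: Comparable> {
    case empty
    indirect case node(left: BinarySearchTree, value: Element, right: BinarySearchTree)
}

extension BinarySearchTree {
    mutating func insert(_ element: Element) {
        self = inserting(element)
    }

    private func inserting(_ element: Element) -> BinarySearchTree {
        switch self {
        case .empty:
            return .node(left: .empty, value: element, right: .empty)
        case .node(_, element, _):
            return self
        case let .node(left, value, right) where element < value:
            return .node(left: left.inserting(element), value: value, right: right)
        case let .node(left, value, right):
            return .node(left: left, value: value, right: right.inserting(element))
        }
    }

    // Hypothetical helper: number of nodes on the longest root-to-leaf path.
    var height: Int {
        switch self {
        case .empty:
            return 0
        case let .node(left, _, right):
            return 1 + max(left.height, right.height)
        }
    }
}

// Sorted input: each element becomes the right child of the previous one.
var ascending = BinarySearchTree<Int>.empty
for i in 1...100 { ascending.insert(i) }

// Mixed input: this particular order happens to produce a perfectly balanced tree.
var mixed = BinarySearchTree<Int>.empty
for i in [50, 25, 75, 12, 37, 62, 87] { mixed.insert(i) }

print(ascending.height) // 100
print(mixed.height)     // 3
```

Lookups and insertions take time proportional to the height, so the sorted tree behaves like a linked list; this is exactly the problem that self-balancing structures such as red-black trees avoid.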
That’s it for now! If you enjoy this kind of stuff, make sure to check out Károly Lőrentey’s brand-new book Optimizing Collections, in which he goes into great detail on implementing several data structures in Swift, with a focus on performance. At the time of writing, it’s 25% off.
One Comment
Thank you for this post! This leads to a question whose answer I would find very useful. How would you write a class MarkovChain or TransitionTable in idiomatic Swift that takes a transition table like
|   | A   | B   | C   |
|---|-----|-----|-----|
| A | 0.1 | 0.6 | 0.3 |
| B | 0.2 | 0.0 | 0.8 |
| C | 0.2 | 0.6 | 0.2 |
(Except the actual transition tables will be quite a lot larger, maybe 16 × 16: for example, two pre-calculated tables of “all the intervals in a McCartney song” and “all the intervals in a Lennon song”, with the app modulating along a linear curve from McCartney-esque to Lennon-esque and back.)
The table would probably be stored as cumulative row sums to make lookups easier:
|   | A   | B   | C   |
|---|-----|-----|-----|
| A | 0.1 | 0.7 | 1.0 |
| B | 0.2 | 0.2 | 1.0 |
| C | 0.2 | 0.8 | 1.0 |
Then, given a current row and a new random number between 0.0 and 1.0, a function `func transition(_ currentRow: String, _ rand: Double) -> String` returns the next row (“A”, “B”, “C”), where “A” might point to “major tenth down”, “B” to “minor tenth down”, etc. The problem is making it as fast as possible, using binary search and anything else available.
Thanks!
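The approach described in the comment (cumulative row sums plus binary search) can be sketched in Swift as follows. This is a minimal sketch, not a tuned implementation: `TransitionTable`, `transition(from:random:)` and the state names are hypothetical, at least one state is assumed, and the caller is assumed to supply a random number in `0..<1` (for example from `Double.random(in: 0..<1)`). Each row is converted to running totals once, up front, so each transition is a dictionary lookup plus a binary search over one row.

```swift
// Hypothetical Markov transition table that stores cumulative row sums.
struct TransitionTable {
    let states: [String]
    // cumulative[i][j] = probability that the next state index is <= j, given current state i
    private let cumulative: [[Double]]
    // maps a state name to its row/column index
    private let indexOfState: [String: Int]

    init(states: [String], probabilities: [[Double]]) {
        precondition(probabilities.count == states.count, "expected one row per state")
        self.states = states
        self.indexOfState = Dictionary(uniqueKeysWithValues: states.enumerated().map { ($0.element, $0.offset) })

        // Convert each row of probabilities into running totals.
        var table: [[Double]] = []
        for row in probabilities {
            precondition(row.count == states.count, "expected one column per state")
            var runningTotal = 0.0
            var sums: [Double] = []
            for p in row {
                runningTotal += p
                sums.append(runningTotal)
            }
            table.append(sums)
        }
        self.cumulative = table
    }

    // Picks the next state for `random` in 0..<1 by binary search over the current row.
    func transition(from current: String, random: Double) -> String {
        guard let i = indexOfState[current] else {
            preconditionFailure("unknown state \(current)")
        }
        let row = cumulative[i]
        // Find the first column whose running total exceeds `random`. If none does
        // (e.g. the row sums to 0.9999... due to rounding), this ends on the last column.
        var low = 0
        var high = row.count - 1
        while low < high {
            let mid = (low + high) / 2
            if random < row[mid] {
                high = mid
            } else {
                low = mid + 1
            }
        }
        return states[low]
    }
}

let table = TransitionTable(
    states: ["A", "B", "C"],
    probabilities: [
        [0.1, 0.6, 0.3],
        [0.2, 0.0, 0.8],
        [0.2, 0.6, 0.2],
    ]
)

print(table.transition(from: "A", random: 0.05)) // A
print(table.transition(from: "A", random: 0.5))  // B
print(table.transition(from: "B", random: 0.5))  // C
```

The strict `<` comparison means a zero-probability column (like B → B) is never chosen. For a 16 × 16 table, each lookup takes four halving steps. Because running totals are linear in the probabilities, blending two tables by a factor t can be done by interpolating their cumulative rows directly, (1 − t)·a + t·b, without recomputing the sums.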