diff --git a/trie/iterator.go b/trie/iterator.go
index ddc674d2bceb0dbcb0b8a1ac890c7ae94051060a..42149a7d387aa1c8a920d2c7d27a079137bbe303 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -18,7 +18,7 @@ package trie
 
 import (
 	"bytes"
-
+	"container/heap"
 	"github.com/ethereum/go-ethereum/common"
 )
 
@@ -268,6 +268,26 @@ outer:
 	return nil
 }
 
+func compareNodes(a, b NodeIterator) int {
+	cmp := bytes.Compare(a.Path(), b.Path())
+	if cmp != 0 {
+		return cmp
+	}
+
+	if a.Leaf() && !b.Leaf() {
+		return -1
+	} else if b.Leaf() && !a.Leaf() {
+		return 1
+	}
+
+	cmp = bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes())
+	if cmp != 0 {
+		return cmp
+	}
+
+	return bytes.Compare(a.LeafBlob(), b.LeafBlob())
+}
+
 type differenceIterator struct {
 	a, b  NodeIterator // Nodes returned are those in b - a.
 	eof   bool         // Indicates a has run out of elements
@@ -321,8 +341,7 @@ func (it *differenceIterator) Next(bool) bool {
 	}
 
 	for {
-		apath, bpath := it.a.Path(), it.b.Path()
-		switch bytes.Compare(apath, bpath) {
+		switch compareNodes(it.a, it.b) {
 		case -1:
 			// b jumped past a; advance a
 			if !it.a.Next(true) {
@@ -334,15 +353,6 @@ func (it *differenceIterator) Next(bool) bool {
 			// b is before a
 			return true
 		case 0:
-			if it.a.Hash() != it.b.Hash() || it.a.Leaf() != it.b.Leaf() {
-				// Keys are identical, but hashes or leaf status differs
-				return true
-			}
-			if it.a.Leaf() && it.b.Leaf() && !bytes.Equal(it.a.LeafBlob(), it.b.LeafBlob()) {
-				// Both are leaf nodes, but with different values
-				return true
-			}
-
 			// a and b are identical; skip this whole subtree if the nodes have hashes
 			hasHash := it.a.Hash() == common.Hash{}
 			if !it.b.Next(hasHash) {
@@ -364,3 +374,107 @@ func (it *differenceIterator) Error() error {
 	}
 	return it.b.Error()
 }
+
+type nodeIteratorHeap []NodeIterator
+
+func (h nodeIteratorHeap) Len() int            { return len(h) }
+func (h nodeIteratorHeap) Less(i, j int) bool  { return compareNodes(h[i], h[j]) < 0 }
+func (h nodeIteratorHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
+func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) }
+func (h *nodeIteratorHeap) Pop() interface{} {
+	n := len(*h)
+	x := (*h)[n-1]
+	*h = (*h)[0 : n-1]
+	return x
+}
+
+type unionIterator struct {
+	items *nodeIteratorHeap // Nodes returned are the union of the ones in these iterators
+	count int               // Number of nodes scanned across all tries
+	err   error             // The error, if one has been encountered
+}
+
+// NewUnionIterator constructs a NodeIterator that iterates over elements in the union
+// of the provided NodeIterators. Returns the iterator, and a pointer to an integer
+// recording the number of nodes visited.
+func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
+	h := make(nodeIteratorHeap, len(iters))
+	copy(h, iters)
+	heap.Init(&h)
+
+	ui := &unionIterator{
+		items: &h,
+	}
+	return ui, &ui.count
+}
+
+func (it *unionIterator) Hash() common.Hash {
+	return (*it.items)[0].Hash()
+}
+
+func (it *unionIterator) Parent() common.Hash {
+	return (*it.items)[0].Parent()
+}
+
+func (it *unionIterator) Leaf() bool {
+	return (*it.items)[0].Leaf()
+}
+
+func (it *unionIterator) LeafBlob() []byte {
+	return (*it.items)[0].LeafBlob()
+}
+
+func (it *unionIterator) Path() []byte {
+	return (*it.items)[0].Path()
+}
+
+// Next returns the next node in the union of tries being iterated over.
+//
+// It does this by maintaining a heap of iterators, sorted by the iteration
+// order of their next elements, with one entry for each source trie. Each
+// time Next() is called, it takes the least element from the heap to return,
+// advancing any other iterators that also point to that same element. These
+// iterators are called with descend=false, since we know that any nodes under
+// these nodes will also be duplicates, found in the currently selected iterator.
+// Whenever an iterator is advanced, it is pushed back into the heap if it still
+// has elements remaining.
+//
+// In the case that descend=false - eg, we're asked to ignore all subnodes of the
+// current node - we also advance any iterators in the heap that have the current
+// path as a prefix.
+func (it *unionIterator) Next(descend bool) bool {
+	if len(*it.items) == 0 {
+		return false
+	}
+
+	// Get the next key from the union
+	least := heap.Pop(it.items).(NodeIterator)
+
+	// Skip over other nodes as long as they're identical, or, if we're not descending, as
+	// long as they have the same prefix as the current node.
+	for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) {
+		skipped := heap.Pop(it.items).(NodeIterator)
+		// Skip the whole subtree if the nodes have hashes; otherwise just skip this node
+		if skipped.Next(skipped.Hash() == common.Hash{}) {
+			it.count += 1
+			// If there are more elements, push the iterator back on the heap
+			heap.Push(it.items, skipped)
+		}
+	}
+
+	if least.Next(descend) {
+		it.count += 1
+		heap.Push(it.items, least)
+	}
+
+	return len(*it.items) > 0
+}
+
+func (it *unionIterator) Error() error {
+	for i := 0; i < len(*it.items); i++ {
+		if err := (*it.items)[i].Error(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 0ad9711eddb3c593da6a0e1c9e7de9c38e8f1327..c101bb7b0973c56ef5f96dba2d8318797139b1e2 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -117,36 +117,38 @@ func TestNodeIteratorCoverage(t *testing.T) {
 	}
 }
 
+var testdata1 = []struct{ k, v string }{
+	{"bar", "b"},
+	{"barb", "ba"},
+	{"bars", "bb"},
+	{"bard", "bc"},
+	{"fab", "z"},
+	{"foo", "a"},
+	{"food", "ab"},
+	{"foos", "aa"},
+}
+
+var testdata2 = []struct{ k, v string }{
+	{"aardvark", "c"},
+	{"bar", "b"},
+	{"barb", "bd"},
+	{"bars", "be"},
+	{"fab", "z"},
+	{"foo", "a"},
+	{"foos", "aa"},
+	{"food", "ab"},
+	{"jars", "d"},
+}
+
 func TestDifferenceIterator(t *testing.T) {
 	triea := newEmpty()
-	valsa := []struct{ k, v string }{
-		{"bar", "b"},
-		{"barb", "ba"},
-		{"bars", "bb"},
-		{"bard", "bc"},
-		{"fab", "z"},
-		{"foo", "a"},
-		{"food", "ab"},
-		{"foos", "aa"},
-	}
-	for _, val := range valsa {
+	for _, val := range testdata1 {
 		triea.Update([]byte(val.k), []byte(val.v))
 	}
 	triea.Commit()
 
 	trieb := newEmpty()
-	valsb := []struct{ k, v string }{
-		{"aardvark", "c"},
-		{"bar", "b"},
-		{"barb", "bd"},
-		{"bars", "be"},
-		{"fab", "z"},
-		{"foo", "a"},
-		{"foos", "aa"},
-		{"food", "ab"},
-		{"jars", "d"},
-	}
-	for _, val := range valsb {
+	for _, val := range testdata2 {
 		trieb.Update([]byte(val.k), []byte(val.v))
 	}
 	trieb.Commit()
@@ -166,10 +168,57 @@ func TestDifferenceIterator(t *testing.T) {
 	}
 	for _, item := range all {
 		if found[item.k] != item.v {
-			t.Errorf("iterator value mismatch for %s: got %q want %q", item.k, found[item.k], item.v)
+			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
 		}
 	}
 	if len(found) != len(all) {
 		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
 	}
 }
+
+func TestUnionIterator(t *testing.T) {
+	triea := newEmpty()
+	for _, val := range testdata1 {
+		triea.Update([]byte(val.k), []byte(val.v))
+	}
+	triea.Commit()
+
+	trieb := newEmpty()
+	for _, val := range testdata2 {
+		trieb.Update([]byte(val.k), []byte(val.v))
+	}
+	trieb.Commit()
+
+	di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)})
+	it := NewIteratorFromNodeIterator(di)
+
+	all := []struct{ k, v string }{
+		{"aardvark", "c"},
+		{"barb", "bd"},
+		{"barb", "ba"},
+		{"bard", "bc"},
+		{"bars", "bb"},
+		{"bars", "be"},
+		{"bar", "b"},
+		{"fab", "z"},
+		{"food", "ab"},
+		{"foos", "aa"},
+		{"foo", "a"},
+		{"jars", "d"},
+	}
+
+	for i, kv := range all {
+		if !it.Next() {
+			t.Errorf("Iterator ends prematurely at element %d", i)
+		}
+		if kv.k != string(it.Key) {
+			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
+		}
+		if kv.v != string(it.Value) {
+			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
+		}
+	}
+	if it.Next() {
+		t.Errorf("Iterator returned extra values.")
+	}
+}