From 4f4b395aa46b486332ce942a28345f93bb88cb13 Mon Sep 17 00:00:00 2001
From: Igor Mandrigin <mandrigin@users.noreply.github.com>
Date: Tue, 24 Mar 2020 01:10:36 +0300
Subject: [PATCH] Introduce code node & get rid of code map (#398)

* introduce code node

* replace codeMap with code touches

* fix a comment

* fixups to tests

* fix compile error

* fix getnodedata tests

* add tests and test stubs

* add more test stubs

* add test method bodies

* add and fix more tests on trie for new codenode

* add test change code between blocks

* fix crash in stateless

* remove unneded files

* remove comment

* fix deleted account code

* fix resolve set builder for code nodes
---
 cmd/pics/state.go                |  70 ++-----
 cmd/state/stateless/stateless.go |   5 +-
 core/state/database.go           | 143 ++++++++++----
 core/state/database_test.go      |  48 +++++
 core/state/stateless.go          |  45 +++--
 eth/handler.go                   |   3 +-
 trie/debug.go                    |   7 +
 trie/hashbuilder.go              |  25 ++-
 trie/node.go                     |   5 +
 trie/resolve_set.go              |  27 ++-
 trie/resolve_set_builder.go      |  42 ++--
 trie/resolver.go                 |  31 ++-
 trie/resolver_stateful.go        |  14 ++
 trie/resolver_stateless.go       |   2 +-
 trie/resolver_stateless_test.go  |   6 +-
 trie/trie.go                     |  96 ++++++++-
 trie/trie_from_witness.go        |  28 ++-
 trie/trie_test.go                | 324 +++++++++++++++++++++++++++++++
 trie/trie_transform.go           |   7 +-
 trie/trie_witness.go             |   8 +-
 trie/visual.go                   |  21 +-
 trie/witness_builder.go          |  16 +-
 trie/witness_builder_test.go     |   8 +-
 23 files changed, 774 insertions(+), 207 deletions(-)

diff --git a/cmd/pics/state.go b/cmd/pics/state.go
index e3ddaab0d3..799e7dc868 100644
--- a/cmd/pics/state.go
+++ b/cmd/pics/state.go
@@ -39,7 +39,7 @@ func constructCodeMap(tds *state.TrieDbState) (map[common.Hash][]byte, error) {
 	return codeMap, nil
 }
 
-func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number int, keyCompression int, codeCompressed bool, valCompressed bool,
+func statePicture(t *trie.Trie, number int, keyCompression int, codeCompressed bool, valCompressed bool,
 	quadTrie bool, quadColors bool, highlights [][]byte) (*trie.Trie, error) {
 	filename := fmt.Sprintf("state_%d.dot", number)
 	f, err := os.Create(filename)
@@ -62,7 +62,6 @@ func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number int, keyC
 		FontColors:     fontColors,
 		Values:         true,
 		CutTerminals:   keyCompression,
-		CodeMap:        codeMap,
 		CodeCompressed: codeCompressed,
 		ValCompressed:  valCompressed,
 		ValHex:         true,
@@ -332,14 +331,10 @@ func initialState1() error {
 		return err
 	}
 	t := tds.Trie()
-	var codeMap map[common.Hash][]byte
-	if codeMap, err = constructCodeMap(tds); err != nil {
+	if _, err = statePicture(t, 0, 0, false, false, false, false, nil); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 0, 0, false, false, false, false, nil); err != nil {
-		return err
-	}
-	if _, err = statePicture(t, codeMap, 1, 48, false, false, false, false, nil); err != nil {
+	if _, err = statePicture(t, 1, 48, false, false, false, false, nil); err != nil {
 		return err
 	}
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 0); err != nil {
@@ -460,13 +455,10 @@ func initialState1() error {
 	if _, err = blockchain.InsertChain(context.Background(), types.Blocks{blocks[0]}); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 1); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 2, 48, false, false, false, false, nil); err != nil {
+	if _, err = statePicture(t, 2, 48, false, false, false, false, nil); err != nil {
 		return err
 	}
 
@@ -476,13 +468,10 @@ func initialState1() error {
 	if _, err = blockchain.InsertChain(context.Background(), types.Blocks{blocks[1]}); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 2); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 3, 48, false, false, false, false, nil); err != nil {
+	if _, err = statePicture(t, 3, 48, false, false, false, false, nil); err != nil {
 		return err
 	}
 
@@ -492,19 +481,16 @@ func initialState1() error {
 	if _, err = blockchain.InsertChain(context.Background(), types.Blocks{blocks[2]}); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 3); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 4, 48, false, false, false, false, nil); err != nil {
+	if _, err = statePicture(t, 4, 48, false, false, false, false, nil); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 5, 48, true, false, false, false, nil); err != nil {
+	if _, err = statePicture(t, 5, 48, true, false, false, false, nil); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 6, 48, true, true, false, false, nil); err != nil {
+	if _, err = statePicture(t, 6, 48, true, true, false, false, nil); err != nil {
 		return err
 	}
 
@@ -517,10 +503,7 @@ func initialState1() error {
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 4); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
-	if _, err = statePicture(t, codeMap, 7, 48, true, true, false, false, nil); err != nil {
+	if _, err = statePicture(t, 7, 48, true, true, false, false, nil); err != nil {
 		return err
 	}
 
@@ -533,10 +516,7 @@ func initialState1() error {
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 5); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
-	if _, err = statePicture(t, codeMap, 8, 54, true, true, false, false, nil); err != nil {
+	if _, err = statePicture(t, 8, 54, true, true, false, false, nil); err != nil {
 		return err
 	}
 
@@ -549,13 +529,10 @@ func initialState1() error {
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 5); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
-	if _, err = statePicture(t, codeMap, 9, 54, true, true, false, false, nil); err != nil {
+	if _, err = statePicture(t, 9, 54, true, true, false, false, nil); err != nil {
 		return err
 	}
-	if _, err = statePicture(t, codeMap, 10, 110, true, true, true, true, nil); err != nil {
+	if _, err = statePicture(t, 10, 110, true, true, true, true, nil); err != nil {
 		return err
 	}
 
@@ -568,10 +545,7 @@ func initialState1() error {
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 7); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
-	quadTrie, err := statePicture(t, codeMap, 11, 110, true, true, true, true, nil)
+	quadTrie, err := statePicture(t, 11, 110, true, true, true, true, nil)
 	if err != nil {
 		return err
 	}
@@ -586,10 +560,7 @@ func initialState1() error {
 	if err = stateDatabaseComparison(snapshotDb, dbBolt, 8); err != nil {
 		return err
 	}
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
-	if _, err = statePicture(t, codeMap, 12, 110, true, true, true, true, nil); err != nil {
+	if _, err = statePicture(t, 12, 110, true, true, true, true, nil); err != nil {
 		return err
 	}
 
@@ -607,23 +578,18 @@ func initialState1() error {
 		touchQuads = append(touchQuads, touchQuad)
 	}
 
-	if codeMap, err = constructCodeMap(tds); err != nil {
-		return err
-	}
-
 	var witness *trie.Witness
 
-	if witness, err = quadTrie.ExtractWitness(0, false, rs, codeMap); err != nil {
+	if witness, err = quadTrie.ExtractWitness(0, false, rs); err != nil {
 		return err
 	}
 
 	var witnessTrie *trie.Trie
-	var witnessCodeMap map[common.Hash][]byte
 
-	if witnessTrie, witnessCodeMap, err = trie.BuildTrieFromWitness(witness, false, false); err != nil {
+	if witnessTrie, err = trie.BuildTrieFromWitness(witness, false, false); err != nil {
 		return err
 	}
-	if _, err = statePicture(witnessTrie, witnessCodeMap, 13, 110, true, true, false /*already quad*/, true, touchQuads); err != nil {
+	if _, err = statePicture(witnessTrie, 13, 110, true, true, false /*already quad*/, true, touchQuads); err != nil {
 		return err
 	}
 
@@ -632,7 +598,7 @@ func initialState1() error {
 	}
 
 	// Repeat the block witness illustration, but without any highlighted keys
-	if _, err = statePicture(witnessTrie, witnessCodeMap, 15, 110, true, true, false /*already quad*/, true, nil); err != nil {
+	if _, err = statePicture(witnessTrie, 15, 110, true, true, false /*already quad*/, true, nil); err != nil {
 		return err
 	}
 
diff --git a/cmd/state/stateless/stateless.go b/cmd/state/stateless/stateless.go
index 1bdaeda326..6e3c0de7f2 100644
--- a/cmd/state/stateless/stateless.go
+++ b/cmd/state/stateless/stateless.go
@@ -70,7 +70,7 @@ func runBlock(ibs *state.IntraBlockState, txnWriter state.StateWriter, blockWrit
 	return nil
 }
 
-func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number uint64) error {
+func statePicture(t *trie.Trie, number uint64) error {
 	filename := fmt.Sprintf("state_%d.dot", number)
 	f, err := os.Create(filename)
 	if err != nil {
@@ -85,7 +85,6 @@ func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number uint64) e
 		FontColors:     fontColors,
 		Values:         true,
 		CutTerminals:   0,
-		CodeMap:        codeMap,
 		CodeCompressed: false,
 		ValCompressed:  false,
 		ValHex:         true,
@@ -375,7 +374,7 @@ func Stateless(
 				return
 			}
 			if _, ok := starkBlocks[blockNum-1]; ok {
-				err = statePicture(s.GetTrie(), s.GetCodeMap(), blockNum-1)
+				err = statePicture(s.GetTrie(), blockNum-1)
 				check(err)
 			}
 			ibs := state.New(s)
diff --git a/core/state/database.go b/core/state/database.go
index 60f55deb76..824c0592ed 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -28,7 +28,6 @@ import (
 	"sync"
 	"sync/atomic"
 
-	lru "github.com/hashicorp/golang-lru"
 	"github.com/ledgerwatch/bolt"
 	"github.com/ledgerwatch/turbo-geth/common"
 	"github.com/ledgerwatch/turbo-geth/common/dbutils"
@@ -95,6 +94,8 @@ func (nw *NoopWriter) CreateContract(address common.Address) error {
 // Structure holding updates, deletes, and reads registered within one change period
 // A change period can be transaction within a block, or a block within group of blocks
 type Buffer struct {
+	codeReads      map[common.Hash]common.Hash
+	codeUpdates    map[common.Hash][]byte
 	storageUpdates map[common.Hash]map[common.Hash][]byte
 	storageReads   map[common.Hash]map[common.Hash]struct{}
 	accountUpdates map[common.Hash]*accounts.Account
@@ -105,6 +106,8 @@ type Buffer struct {
 
 // Prepares buffer for work or clears previous data
 func (b *Buffer) initialise() {
+	b.codeReads = make(map[common.Hash]common.Hash)
+	b.codeUpdates = make(map[common.Hash][]byte)
 	b.storageUpdates = make(map[common.Hash]map[common.Hash][]byte)
 	b.storageReads = make(map[common.Hash]map[common.Hash]struct{})
 	b.accountUpdates = make(map[common.Hash]*accounts.Account)
@@ -124,6 +127,14 @@ func (b *Buffer) detachAccounts() {
 
 // Merges the content of another buffer into this one
 func (b *Buffer) merge(other *Buffer) {
+	for addrHash, codeHash := range other.codeReads {
+		b.codeReads[addrHash] = codeHash
+	}
+
+	for addrHash, code := range other.codeUpdates {
+		b.codeUpdates[addrHash] = code
+	}
+
 	for addrHash, om := range other.storageUpdates {
 		m, ok := b.storageUpdates[addrHash]
 		if !ok {
@@ -167,8 +178,6 @@ type TrieDbState struct {
 	buffers                []*Buffer
 	aggregateBuffer        *Buffer // Merge of all buffers
 	currentBuffer          *Buffer
-	codeCache              *lru.Cache
-	codeSizeCache          *lru.Cache
 	historical             bool
 	noHistory              bool
 	resolveReads           bool
@@ -183,14 +192,6 @@ type TrieDbState struct {
 }
 
 func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieDbState, error) {
-	csc, err := lru.New(100000)
-	if err != nil {
-		return nil, err
-	}
-	cc, err := lru.New(10000)
-	if err != nil {
-		return nil, err
-	}
 	t := trie.New(root)
 	tp := trie.NewTriePruning(blockNr)
 
@@ -199,8 +200,6 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieD
 		tMu:               new(sync.Mutex),
 		db:                db,
 		blockNr:           blockNr,
-		codeCache:         cc,
-		codeSizeCache:     csc,
 		resolveSetBuilder: trie.NewResolveSetBuilder(),
 		tp:                tp,
 		savePreimages:     true,
@@ -573,8 +572,6 @@ func (tds *TrieDbState) WithNewBuffer() *TrieDbState {
 		buffers:           buffers,
 		aggregateBuffer:   aggregateBuffer,
 		currentBuffer:     currentBuffer,
-		codeCache:         tds.codeCache,
-		codeSizeCache:     tds.codeSizeCache,
 		historical:        tds.historical,
 		noHistory:         tds.noHistory,
 		resolveReads:      tds.resolveReads,
@@ -704,6 +701,10 @@ func (tds *TrieDbState) populateStorageBlockProof(storageTouches common.StorageK
 	return nil
 }
 
+func (tds *TrieDbState) buildCodeTouches(withReads bool) map[common.Hash]common.Hash {
+	return tds.aggregateBuffer.codeReads
+}
+
 // Builds a sorted list of all address hashes that were touched within the
 // period for which we are aggregating updates
 func (tds *TrieDbState) buildAccountTouches(withReads bool, withValues bool) (common.Hashes, []*accounts.Account) {
@@ -748,9 +749,30 @@ func (tds *TrieDbState) buildAccountTouches(withReads bool, withValues bool) (co
 	return accountTouches, aValues
 }
 
+func (tds *TrieDbState) resolveCodeTouches(codeTouches map[common.Hash]common.Hash, resolveFunc trie.ResolveFunc) error {
+	firstRequest := true
+	for address, codeHash := range codeTouches {
+		if need, req := tds.t.NeedResolutonForCode(address, codeHash); need {
+			if tds.resolver == nil {
+				tds.resolver = trie.NewResolver(0, true, tds.blockNr)
+				tds.resolver.SetHistorical(tds.historical)
+			} else if firstRequest {
+				tds.resolver.Reset(0, true, tds.blockNr)
+			}
+			firstRequest = false
+			tds.resolver.AddCodeRequest(req)
+		}
+	}
+
+	if !firstRequest {
+		return resolveFunc(tds.resolver)
+	}
+	return nil
+}
+
 // Expands the accounts trie (by loading data from the database) if it is required
 // for accessing accounts whose addresses are contained in the accountTouches
-func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes, resolveFunc func(*trie.Resolver) error) error {
+func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes, resolveFunc trie.ResolveFunc) error {
 	var firstRequest = true
 	for _, addrHash := range accountTouches {
 		if need, req := tds.t.NeedResolution(nil, addrHash[:]); need {
@@ -784,7 +806,7 @@ func (tds *TrieDbState) ExtractTouches() (accountTouches [][]byte, storageTouche
 	return tds.resolveSetBuilder.ExtractTouches()
 }
 
-func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc func(*trie.Resolver) error) error {
+func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc trie.ResolveFunc) error {
 	// Aggregating the current buffer, if any
 	if tds.currentBuffer != nil {
 		if tds.aggregateBuffer == nil {
@@ -805,12 +827,20 @@ func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc func(*trie.Resolver
 
 	// Prepare (resolve) accounts trie so that actual modifications can proceed without database access
 	accountTouches, _ := tds.buildAccountTouches(tds.resolveReads, false)
+
+	// Prepare (resolve) contract code reads so that actual modifications can proceed without database access
+	codeTouches := tds.buildCodeTouches(tds.resolveReads)
+
 	var err error
 
 	if err = tds.resolveAccountTouches(accountTouches, resolveFunc); err != nil {
 		return err
 	}
 
+	if err = tds.resolveCodeTouches(codeTouches, resolveFunc); err != nil {
+		return err
+	}
+
 	if tds.resolveReads {
 		tds.populateAccountBlockProof(accountTouches)
 	}
@@ -943,6 +973,13 @@ func (tds *TrieDbState) updateTrieRoots(forward bool) ([]common.Hash, error) {
 				tds.t.UpdateAccount(addrHash[:], account)
 			} else {
 				tds.t.Delete(addrHash[:])
+				delete(b.codeUpdates, addrHash)
+			}
+		}
+
+		for addrHash, newCode := range b.codeUpdates {
+			if err := tds.t.UpdateAccountCode(addrHash[:], newCode); err != nil {
+				return nil, err
 			}
 		}
 		for addrHash, m := range b.storageUpdates {
@@ -1331,18 +1368,35 @@ func (tds *TrieDbState) ReadAccountStorage(address common.Address, incarnation u
 	return enc, nil
 }
 
+func (tds *TrieDbState) ReadCodeByHash(codeHash common.Hash) (code []byte, err error) {
+	if bytes.Equal(codeHash[:], emptyCodeHash) {
+		return nil, nil
+	}
+
+	code, err = tds.db.Get(dbutils.CodeBucket, codeHash[:])
+	if tds.resolveReads {
+		// we have to be careful, because the code might change
+		// during the block executuion, so we are always
+		// storing the latest code hash
+		tds.resolveSetBuilder.ReadCode(codeHash)
+	}
+	return code, err
+}
+
 func (tds *TrieDbState) ReadAccountCode(address common.Address, codeHash common.Hash) (code []byte, err error) {
 	if bytes.Equal(codeHash[:], emptyCodeHash) {
 		return nil, nil
 	}
-	if cached, ok := tds.codeCache.Get(codeHash); ok {
-		code, err = cached.([]byte), nil
+
+	addrHash, err := tds.HashAddress(address, false /*save*/)
+	if err != nil {
+		return nil, err
+	}
+
+	if cached, ok := tds.t.GetAccountCode(addrHash[:]); ok {
+		code, err = cached, nil
 	} else {
 		code, err = tds.db.Get(dbutils.CodeBucket, codeHash[:])
-		if err == nil {
-			tds.codeSizeCache.Add(codeHash, len(code))
-			tds.codeCache.Add(codeHash, code)
-		}
 	}
 	if tds.resolveReads {
 		addrHash, err1 := common.HashData(address[:])
@@ -1352,25 +1406,23 @@ func (tds *TrieDbState) ReadAccountCode(address common.Address, codeHash common.
 		if _, ok := tds.currentBuffer.accountUpdates[addrHash]; !ok {
 			tds.currentBuffer.accountReads[addrHash] = struct{}{}
 		}
-		tds.resolveSetBuilder.ReadCode(codeHash, code)
+		// we have to be careful, because the code might change
+		// during the block executuion, so we are always
+		// storing the latest code hash
+		tds.currentBuffer.codeReads[addrHash] = codeHash
+		tds.resolveSetBuilder.ReadCode(codeHash)
 	}
 	return code, err
 }
 
 func (tds *TrieDbState) ReadAccountCodeSize(address common.Address, codeHash common.Hash) (codeSize int, err error) {
-	var code []byte
-	if cached, ok := tds.codeSizeCache.Get(codeHash); ok {
-		codeSize, err = cached.(int), nil
-		if tds.resolveReads {
-			if cachedCode, ok := tds.codeCache.Get(codeHash); ok {
-				code, err = cachedCode.([]byte), nil
-			} else {
-				code, err = tds.ReadAccountCode(address, codeHash)
-				if err != nil {
-					return 0, err
-				}
-			}
-		}
+	addrHash, err := tds.HashAddress(address, false /*save*/)
+	if err != nil {
+		return 0, err
+	}
+
+	if code, ok := tds.t.GetAccountCode(addrHash[:]); ok {
+		codeSize, err = len(code), nil
 	} else {
 		code, err = tds.ReadAccountCode(address, codeHash)
 		if err != nil {
@@ -1386,7 +1438,11 @@ func (tds *TrieDbState) ReadAccountCodeSize(address common.Address, codeHash com
 		if _, ok := tds.currentBuffer.accountUpdates[addrHash]; !ok {
 			tds.currentBuffer.accountReads[addrHash] = struct{}{}
 		}
-		tds.resolveSetBuilder.ReadCode(codeHash, code)
+		// we have to be careful, because the code might change
+		// during the block executuion, so we are always
+		// storing the latest code hash
+		tds.currentBuffer.codeReads[addrHash] = codeHash
+		tds.resolveSetBuilder.ReadCode(codeHash)
 	}
 	return codeSize, nil
 }
@@ -1519,8 +1575,9 @@ func (tsw *TrieStateWriter) DeleteAccount(_ context.Context, address common.Addr
 
 func (tsw *TrieStateWriter) UpdateAccountCode(addrHash common.Hash, incarnation uint64, codeHash common.Hash, code []byte) error {
 	if tsw.tds.resolveReads {
-		tsw.tds.resolveSetBuilder.CreateCode(codeHash, code)
+		tsw.tds.resolveSetBuilder.CreateCode(codeHash)
 	}
+	tsw.tds.currentBuffer.codeUpdates[addrHash] = code
 	return nil
 }
 
@@ -1551,12 +1608,12 @@ func (tsw *TrieStateWriter) WriteAccountStorage(_ context.Context, address commo
 
 // ExtractWitness produces block witness for the block just been processed, in a serialised form
 func (tds *TrieDbState) ExtractWitness(trace bool, isBinary bool) (*trie.Witness, error) {
-	rs, codeMap := tds.resolveSetBuilder.Build(isBinary)
+	rs := tds.resolveSetBuilder.Build(isBinary)
 
-	return tds.makeBlockWitness(trace, rs, codeMap, isBinary)
+	return tds.makeBlockWitness(trace, rs, isBinary)
 }
 
-func (tds *TrieDbState) makeBlockWitness(trace bool, rs *trie.ResolveSet, codeMap map[common.Hash][]byte, isBinary bool) (*trie.Witness, error) {
+func (tds *TrieDbState) makeBlockWitness(trace bool, rs *trie.ResolveSet, isBinary bool) (*trie.Witness, error) {
 	tds.tMu.Lock()
 	defer tds.tMu.Unlock()
 
@@ -1565,7 +1622,7 @@ func (tds *TrieDbState) makeBlockWitness(trace bool, rs *trie.ResolveSet, codeMa
 		t = trie.HexToBin(tds.t).Trie()
 	}
 
-	return t.ExtractWitness(tds.blockNr, trace, rs, codeMap)
+	return t.ExtractWitness(tds.blockNr, trace, rs)
 }
 
 func (tsw *TrieStateWriter) CreateContract(address common.Address) error {
diff --git a/core/state/database_test.go b/core/state/database_test.go
index 3c3d960e35..d23189d370 100644
--- a/core/state/database_test.go
+++ b/core/state/database_test.go
@@ -1370,3 +1370,51 @@ func TestClearTombstonesForReCreatedAccount(t *testing.T) {
 		assert.Equal(expect, ok, k)
 	}
 }
+
+func TestChangeAccountCodeBetweenBlocks(t *testing.T) {
+	contract := common.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624")
+
+	db := ethdb.NewMemDatabase()
+	tds, err := state.NewTrieDbState(common.Hash{}, db, 0)
+	if err != nil {
+		t.Errorf("could not create TrieDbState: %v", err)
+	}
+	tsw := tds.TrieStateWriter()
+	intraBlockState := state.New(tds)
+	ctx := context.Background()
+	// Start the 1st transaction
+	tds.StartNewBuffer()
+	intraBlockState.CreateAccount(contract, true)
+
+	oldCode := []byte{0x01, 0x02, 0x03, 0x04}
+
+	intraBlockState.SetCode(contract, oldCode)
+	intraBlockState.AddBalance(contract, big.NewInt(1000000000))
+	if err = intraBlockState.FinalizeTx(ctx, tsw); err != nil {
+		t.Errorf("error finalising 1st tx: %v", err)
+	}
+
+	tds.ComputeTrieRoots()
+
+	oldCodeHash := common.BytesToHash(crypto.Keccak256(oldCode))
+
+	trieCode, err := tds.ReadAccountCode(contract, oldCodeHash)
+	assert.NoError(t, err, "you can receive the new code")
+	assert.Equal(t, oldCode, trieCode, "new code should be received")
+
+	tds.StartNewBuffer()
+
+	newCode := []byte{0x04, 0x04, 0x04, 0x04}
+	intraBlockState.SetCode(contract, newCode)
+
+	if err = intraBlockState.FinalizeTx(ctx, tsw); err != nil {
+		t.Errorf("error finalising 1st tx: %v", err)
+	}
+
+	tds.ComputeTrieRoots()
+
+	newCodeHash := common.BytesToHash(crypto.Keccak256(newCode))
+	trieCode, err = tds.ReadAccountCode(contract, newCodeHash)
+	assert.NoError(t, err, "you can receive the new code")
+	assert.Equal(t, newCode, trieCode, "new code should be received")
+}
diff --git a/core/state/stateless.go b/core/state/stateless.go
index 6b5a8cfd21..720df5785e 100644
--- a/core/state/stateless.go
+++ b/core/state/stateless.go
@@ -34,7 +34,7 @@ import (
 // during the execution of block(s)
 type Stateless struct {
 	t              *trie.Trie             // State trie
-	codeMap        map[common.Hash][]byte // Lookup index from code hashes to corresponding bytecode
+	codeUpdates    map[common.Hash][]byte // Lookup index from code hashes to corresponding bytecode
 	blockNr        uint64                 // Current block number
 	storageUpdates map[common.Hash]map[common.Hash][]byte
 	accountUpdates map[common.Hash]*accounts.Account
@@ -47,10 +47,11 @@ type Stateless struct {
 // It deserialises the block witness and creates the state trie out of it, checking that the root of the constructed
 // state trie matches the value of `stateRoot` parameter
 func NewStateless(stateRoot common.Hash, blockWitness *trie.Witness, blockNr uint64, trace bool, isBinary bool) (*Stateless, error) {
-	t, codeMap, err := trie.BuildTrieFromWitness(blockWitness, isBinary, trace)
+	t, err := trie.BuildTrieFromWitness(blockWitness, isBinary, trace)
 	if err != nil {
 		return nil, err
 	}
+
 	if !isBinary {
 		if t.Hash() != stateRoot {
 			filename := fmt.Sprintf("root_%d.txt", blockNr)
@@ -64,7 +65,7 @@ func NewStateless(stateRoot common.Hash, blockWitness *trie.Witness, blockNr uin
 	}
 	return &Stateless{
 		t:              t,
-		codeMap:        codeMap,
+		codeUpdates:    make(map[common.Hash][]byte),
 		storageUpdates: make(map[common.Hash]map[common.Hash][]byte),
 		accountUpdates: make(map[common.Hash]*accounts.Account),
 		deleted:        make(map[common.Hash]struct{}),
@@ -113,7 +114,6 @@ func (s *Stateless) ReadAccountStorage(address common.Address, incarnation uint6
 }
 
 // ReadAccountCode is a part of the StateReader interface
-// This implementation looks the code up in the codeMap, failing if the code is not found.
 func (s *Stateless) ReadAccountCode(address common.Address, codeHash common.Hash) (code []byte, err error) {
 	if bytes.Equal(codeHash[:], emptyCodeHash) {
 		return nil, nil
@@ -121,10 +121,20 @@ func (s *Stateless) ReadAccountCode(address common.Address, codeHash common.Hash
 	if s.trace {
 		fmt.Printf("Getting code for %x\n", codeHash)
 	}
-	if code, ok := s.codeMap[codeHash]; ok {
+
+	addrHash, err := common.HashData(address[:])
+	if err != nil {
+		return nil, err
+	}
+
+	if code, ok := s.codeUpdates[addrHash]; ok {
+		return code, nil
+	}
+
+	if code, ok := s.t.GetAccountCode(addrHash[:]); ok {
 		return code, nil
 	}
-	return nil, fmt.Errorf("could not find bytecode for hash %x", codeHash)
+	return nil, fmt.Errorf("could not find bytecode for acc: %x hash %x", address, codeHash)
 }
 
 // ReadAccountCodeSize is a part of the StateReader interface
@@ -134,9 +144,20 @@ func (s *Stateless) ReadAccountCodeSize(address common.Address, codeHash common.
 	if bytes.Equal(codeHash[:], emptyCodeHash) {
 		return 0, nil
 	}
-	if code, ok := s.codeMap[codeHash]; ok {
+
+	addrHash, err := common.HashData(address[:])
+	if err != nil {
+		return 0, err
+	}
+
+	if code, ok := s.codeUpdates[addrHash]; ok {
+		return len(code), nil
+	}
+
+	if code, ok := s.t.GetAccountCode(addrHash[:]); ok {
 		return len(code), nil
 	}
+
 	return 0, fmt.Errorf("could not find bytecode for hash %x", codeHash)
 }
 
@@ -172,9 +193,9 @@ func (s *Stateless) DeleteAccount(_ context.Context, address common.Address, ori
 // UpdateAccountCode is a part of the StateWriter interface
 // This implementation adds the code to the codeMap to make it available for further accesses
 func (s *Stateless) UpdateAccountCode(addrHash common.Hash, incarnation uint64, codeHash common.Hash, code []byte) error {
-	if _, ok := s.codeMap[codeHash]; !ok {
-		s.codeMap[codeHash] = code
-	}
+
+	s.codeUpdates[codeHash] = code
+
 	if s.trace {
 		fmt.Printf("Stateless: UpdateAccountCode %x codeHash %x\n", addrHash, codeHash)
 	}
@@ -311,7 +332,3 @@ func (s *Stateless) CheckRoot(expected common.Hash) error {
 func (s *Stateless) GetTrie() *trie.Trie {
 	return s.t
 }
-
-func (s *Stateless) GetCodeMap() map[common.Hash][]byte {
-	return s.codeMap
-}
diff --git a/eth/handler.go b/eth/handler.go
index b7c4e7038a..0b2c038fea 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -712,8 +712,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			}
 
 			// Now attempt to get the byte code
-			var zeroAddress common.Address
-			code, err := tds.ReadAccountCode(zeroAddress, hash)
+			code, err := tds.ReadCodeByHash(hash)
 			if err == nil {
 				data = append(data, code)
 				bytes += len(code)
diff --git a/trie/debug.go b/trie/debug.go
index b9fb404cf7..d3ae43e2d4 100644
--- a/trie/debug.go
+++ b/trie/debug.go
@@ -252,6 +252,13 @@ func (n valueNode) print(w io.Writer) {
 	fmt.Fprintf(w, "v(%x)", []byte(n))
 }
 
+func (n codeNode) fstring(ind string) string {
+	return fmt.Sprintf("code: %x ", []byte(n))
+}
+func (n codeNode) print(w io.Writer) {
+	fmt.Fprintf(w, "code(%x)", []byte(n))
+}
+
 func (an accountNode) fstring(ind string) string {
 	encodedAccount := pool.GetBuffer(an.EncodingLengthForHashing())
 	an.EncodeForHashing(encodedAccount.B)
diff --git a/trie/hashbuilder.go b/trie/hashbuilder.go
index 55311b49e6..4929e4396a 100644
--- a/trie/hashbuilder.go
+++ b/trie/hashbuilder.go
@@ -1,6 +1,7 @@
 package trie
 
 import (
+	"bytes"
 	"fmt"
 	"io"
 	"math/big"
@@ -210,14 +211,26 @@ func (hb *HashBuilder) accountLeaf(length int, keyHex []byte, storageSize uint64
 		}
 		popped++
 	}
+	var accountCode codeNode
 	if fieldSet&uint32(8) != 0 {
 		copy(hb.acc.CodeHash[:], hb.hashStack[len(hb.hashStack)-popped*hashStackStride-common.HashLength:len(hb.hashStack)-popped*hashStackStride])
+		ok := false
+		if !bytes.Equal(hb.acc.CodeHash[:], EmptyCodeHash[:]) {
+			stackTop := hb.nodeStack[len(hb.nodeStack)-popped-1]
+			if stackTop != nil { // if we don't have any stack top it might be okay because we didn't resolve the code yet (stateful resolver)
+				// but if we have something on top of the stack that isn't `nil`, it has to be a codeNode
+				accountCode, ok = stackTop.(codeNode)
+				if !ok {
+					return fmt.Errorf("unexpected node type on the node stack, wanted codeNode, got %t:%s", stackTop, stackTop)
+				}
+			}
+		}
 		popped++
 	}
 	var accCopy accounts.Account
 	accCopy.Copy(&hb.acc)
 
-	s := &shortNode{Key: common.CopyBytes(key), Val: &accountNode{accCopy, root, true}}
+	s := &shortNode{Key: common.CopyBytes(key), Val: &accountNode{accCopy, root, true, accountCode}}
 	// this invocation will take care of the popping given number of items from both hash stack and node stack,
 	// pushing resulting hash to the hash stack, and nil to the node stack
 	if err = hb.accountLeafHashWithKey(key, popped); err != nil {
@@ -520,23 +533,23 @@ func (hb *HashBuilder) hash(hash []byte) error {
 	return nil
 }
 
-func (hb *HashBuilder) code(code []byte) (common.Hash, error) {
+func (hb *HashBuilder) code(code []byte) error {
 	if hb.trace {
 		fmt.Printf("CODE\n")
 	}
 	codeCopy := common.CopyBytes(code)
-	hb.nodeStack = append(hb.nodeStack, nil)
+	hb.nodeStack = append(hb.nodeStack, codeNode(codeCopy))
 	hb.sha.Reset()
 	if _, err := hb.sha.Write(codeCopy); err != nil {
-		return common.Hash{}, err
+		return err
 	}
 	var hash [hashStackStride]byte // RLP representation of hash (or un-hashes value)
 	hash[0] = 0x80 + common.HashLength
 	if _, err := hb.sha.Read(hash[1:]); err != nil {
-		return common.Hash{}, err
+		return err
 	}
 	hb.hashStack = append(hb.hashStack, hash[:]...)
-	return common.BytesToHash(hash[1:]), nil
+	return nil
 }
 
 func (hb *HashBuilder) emptyRoot() {
diff --git a/trie/node.go b/trie/node.go
index 660863d24e..314d264411 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -62,7 +62,10 @@ type (
 		accounts.Account
 		storage     node
 		rootCorrect bool
+		code        codeNode
 	}
+
+	codeNode []byte
 )
 
 // nilValueNode is used when collapsing internal trie nodes for hashing, since
@@ -211,6 +214,7 @@ type nodeRef struct {
 
 func (n hashNode) reference() []byte      { return n }
 func (n valueNode) reference() []byte     { return nil }
+func (n codeNode) reference() []byte      { return nil }
 func (n *fullNode) reference() []byte     { return n.ref.data[0:n.ref.len] }
 func (n *duoNode) reference() []byte      { return n.ref.data[0:n.ref.len] }
 func (n *shortNode) reference() []byte    { return n.ref.data[0:n.ref.len] }
@@ -222,4 +226,5 @@ func (n duoNode) String() string      { return n.fstring("") }
 func (n shortNode) String() string    { return n.fstring("") }
 func (n hashNode) String() string     { return n.fstring("") }
 func (n valueNode) String() string    { return n.fstring("") }
+func (n codeNode) String() string     { return n.fstring("") }
 func (an accountNode) String() string { return an.fstring("") }
diff --git a/trie/resolve_set.go b/trie/resolve_set.go
index 655bf2d6db..2da87c149c 100644
--- a/trie/resolve_set.go
+++ b/trie/resolve_set.go
@@ -19,6 +19,8 @@ package trie
 import (
 	"bytes"
 	"sort"
+
+	"github.com/ledgerwatch/turbo-geth/common"
 )
 
 // ResolveSet encapsulates the set of keys that are required to be fully available, or resolved
@@ -26,20 +28,21 @@ import (
 // pairs
 // DESCRIBED: docs/programmers_guide/guide.md#converting-sequence-of-keys-and-value-into-a-multiproof
 type ResolveSet struct {
-	inited    bool // Whether keys are sorted and "LTE" and "GT" indices set
-	binary    bool // if true, use binary encoding instead of Hex
-	minLength int  // Mininum length of prefixes for which `HashOnly` function can return `true`
-	lteIndex  int  // Index of the "LTE" key in the keys slice. Next one is "GT"
-	hexes     sortable
+	inited      bool // Whether keys are sorted and "LTE" and "GT" indices set
+	binary      bool // if true, use binary encoding instead of Hex
+	minLength   int  // Mininum length of prefixes for which `HashOnly` function can return `true`
+	lteIndex    int  // Index of the "LTE" key in the keys slice. Next one is "GT"
+	hexes       sortable
+	codeTouches map[common.Hash]struct{}
 }
 
 // NewResolveSet creates new ResolveSet
 func NewResolveSet(minLength int) *ResolveSet {
-	return &ResolveSet{minLength: minLength}
+	return &ResolveSet{minLength: minLength, codeTouches: make(map[common.Hash]struct{})}
 }
 
 func NewBinaryResolveSet(minLength int) *ResolveSet {
-	return &ResolveSet{minLength: minLength, binary: true}
+	return &ResolveSet{minLength: minLength, codeTouches: make(map[common.Hash]struct{}), binary: true}
 }
 
 // AddKey adds a new key (in KEY encoding) to the set
@@ -57,6 +60,16 @@ func (rs *ResolveSet) AddHex(hex []byte) {
 	}
 }
 
+// AddCodeTouch adds a new code touch into the resolve set
+func (rs *ResolveSet) AddCodeTouch(codeHash common.Hash) {
+	rs.codeTouches[codeHash] = struct{}{}
+}
+
+func (rs *ResolveSet) IsCodeTouched(codeHash common.Hash) bool {
+	_, ok := rs.codeTouches[codeHash]
+	return ok
+}
+
 func (rs *ResolveSet) ensureInited() {
 	if rs.inited {
 		return
diff --git a/trie/resolve_set_builder.go b/trie/resolve_set_builder.go
index 21f5e8a838..6e30530e89 100644
--- a/trie/resolve_set_builder.go
+++ b/trie/resolve_set_builder.go
@@ -6,17 +6,17 @@ import "github.com/ledgerwatch/turbo-geth/common"
 // the execution of a block. It also tracks the contract codes that were created and used during the execution
 // of a block
 type ResolveSetBuilder struct {
-	touches        [][]byte               // Read/change set of account keys (account hashes)
-	storageTouches [][]byte               // Read/change set of storage keys (account hashes concatenated with storage key hashes)
-	proofCodes     map[common.Hash][]byte // Contract codes that have been accessed
-	createdCodes   map[common.Hash][]byte // Contract codes that were created (deployed)
+	touches        [][]byte                 // Read/change set of account keys (account hashes)
+	storageTouches [][]byte                 // Read/change set of storage keys (account hashes concatenated with storage key hashes)
+	proofCodes     map[common.Hash]struct{} // Contract codes that have been accessed (codeHash)
+	createdCodes   map[common.Hash]struct{} // Contract codes that were created (deployed) (codeHash)
 }
 
 // NewResolveSetBuilder creates new ProofGenerator and initialised its maps
 func NewResolveSetBuilder() *ResolveSetBuilder {
 	return &ResolveSetBuilder{
-		proofCodes:   make(map[common.Hash][]byte),
-		createdCodes: make(map[common.Hash][]byte),
+		proofCodes:   make(map[common.Hash]struct{}),
+		createdCodes: make(map[common.Hash]struct{}),
 	}
 }
 
@@ -39,30 +39,30 @@ func (pg *ResolveSetBuilder) ExtractTouches() ([][]byte, [][]byte) {
 	return touches, storageTouches
 }
 
-// extractCodeMap returns the map of all contract codes that were required during the block's execution
-// but were not created during that same block. It also clears the maps for the next block's execution
-func (pg *ResolveSetBuilder) extractCodeMap() map[common.Hash][]byte {
+// extractCodeTouches returns the set of all contract codes that were required during the block's execution
+// but were not created during that same block. It also clears the set for the next block's execution
+func (pg *ResolveSetBuilder) extractCodeTouches() map[common.Hash]struct{} {
 	proofCodes := pg.proofCodes
-	pg.proofCodes = make(map[common.Hash][]byte)
-	pg.createdCodes = make(map[common.Hash][]byte)
+	pg.proofCodes = make(map[common.Hash]struct{})
+	pg.createdCodes = make(map[common.Hash]struct{})
 	return proofCodes
 }
 
 // ReadCode registers that given contract code has been accessed during current block's execution
-func (pg *ResolveSetBuilder) ReadCode(codeHash common.Hash, code []byte) {
-	if _, ok := pg.createdCodes[codeHash]; !ok {
-		pg.proofCodes[codeHash] = code
+func (pg *ResolveSetBuilder) ReadCode(codeHash common.Hash) {
+	if _, ok := pg.proofCodes[codeHash]; !ok {
+		pg.proofCodes[codeHash] = struct{}{}
 	}
 }
 
 // CreateCode registers that given contract code has been created (deployed) during current block's execution
-func (pg *ResolveSetBuilder) CreateCode(codeHash common.Hash, code []byte) {
+func (pg *ResolveSetBuilder) CreateCode(codeHash common.Hash) {
 	if _, ok := pg.proofCodes[codeHash]; !ok {
-		pg.createdCodes[codeHash] = code
+		pg.createdCodes[codeHash] = struct{}{}
 	}
 }
 
-func (pg *ResolveSetBuilder) Build(isBinary bool) (*ResolveSet, CodeMap) {
+func (pg *ResolveSetBuilder) Build(isBinary bool) *ResolveSet {
 	var rs *ResolveSet
 	if isBinary {
 		rs = NewBinaryResolveSet(0)
@@ -71,6 +71,7 @@ func (pg *ResolveSetBuilder) Build(isBinary bool) (*ResolveSet, CodeMap) {
 	}
 
 	touches, storageTouches := pg.ExtractTouches()
+	codeTouches := pg.extractCodeTouches()
 
 	for _, touch := range touches {
 		rs.AddKey(touch)
@@ -78,6 +79,9 @@ func (pg *ResolveSetBuilder) Build(isBinary bool) (*ResolveSet, CodeMap) {
 	for _, touch := range storageTouches {
 		rs.AddKey(touch)
 	}
-	codeMap := pg.extractCodeMap()
-	return rs, codeMap
+	for codeHash, _ := range codeTouches {
+		rs.AddCodeTouch(codeHash)
+	}
+
+	return rs
 }
diff --git a/trie/resolver.go b/trie/resolver.go
index 1ca198f571..92715e45ef 100644
--- a/trie/resolver.go
+++ b/trie/resolver.go
@@ -13,6 +13,8 @@ import (
 
 var emptyHash [32]byte
 
+type ResolveFunc func(*Resolver) error
+
 func (t *Trie) Rebuild(db ethdb.Database, blockNr uint64) error {
 	if t.root == nil {
 		return nil
@@ -38,15 +40,17 @@ type Resolver struct {
 	blockNr          uint64
 	topLevels        int // How many top levels of the trie to keep (not roll into hashes)
 	requests         []*ResolveRequest
+	codeRequests     []*ResolveRequestForCode
 	witnesses        []*Witness // list of witnesses for resolved subtries, nil if `collectWitnesses` is false
 }
 
 func NewResolver(topLevels int, forAccounts bool, blockNr uint64) *Resolver {
 	tr := Resolver{
-		accounts:  forAccounts,
-		requests:  []*ResolveRequest{},
-		blockNr:   blockNr,
-		topLevels: topLevels,
+		accounts:     forAccounts,
+		requests:     []*ResolveRequest{},
+		codeRequests: []*ResolveRequestForCode{},
+		blockNr:      blockNr,
+		topLevels:    topLevels,
 	}
 	return &tr
 }
@@ -56,6 +60,7 @@ func (tr *Resolver) Reset(topLevels int, forAccounts bool, blockNr uint64) {
 	tr.accounts = forAccounts
 	tr.blockNr = blockNr
 	tr.requests = tr.requests[:0]
+	tr.codeRequests = tr.codeRequests[:0]
 	tr.witnesses = nil
 	tr.collectWitnesses = false
 	tr.historical = false
@@ -76,6 +81,11 @@ func (tr *Resolver) SetHistorical(h bool) {
 	tr.historical = h
 }
 
+// AddCodeRequest add a request for code resolution
+func (tr *Resolver) AddCodeRequest(req *ResolveRequestForCode) {
+	tr.codeRequests = append(tr.codeRequests, req)
+}
+
 // Resolver implements sort.Interface
 // and sorts by resolve requests
 // (more general requests come first)
@@ -152,7 +162,10 @@ func (tr *Resolver) ResolveStateful(db ethdb.Database, blockNr uint64) error {
 	sort.Stable(tr)
 
 	resolver := NewResolverStateful(tr.topLevels, tr.requests, hf)
-	return resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical)
+	if err := resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical); err != nil {
+		return err
+	}
+	return resolver.AttachRequestedCode(db, tr.codeRequests)
 }
 
 func (tr *Resolver) ResolveStatefulCached(db ethdb.Database, blockNr uint64) error {
@@ -166,7 +179,10 @@ func (tr *Resolver) ResolveStatefulCached(db ethdb.Database, blockNr uint64) err
 	sort.Stable(tr)
 
 	resolver := NewResolverStatefulCached(tr.topLevels, tr.requests, hf)
-	return resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical)
+	if err := resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical); err != nil {
+		return err
+	}
+	return resolver.AttachRequestedCode(db, tr.codeRequests)
 }
 
 // ResolveStateless resolves and hooks subtries using a witnesses database instead of
@@ -174,6 +190,7 @@ func (tr *Resolver) ResolveStatefulCached(db ethdb.Database, blockNr uint64) err
 func (tr *Resolver) ResolveStateless(db WitnessStorage, blockNr uint64, trieLimit uint32, startPos int64) (int64, error) {
 	sort.Stable(tr)
 	resolver := NewResolverStateless(tr.requests, hookSubtrie)
+	// we expect CodeNodes to be already attached to the trie in stateless resolution
 	return resolver.RebuildTrie(db, blockNr, trieLimit, startPos)
 }
 
@@ -212,7 +229,7 @@ func (tr *Resolver) extractWitnessAndHookSubtrie(currentReq *ResolveRequest, hbR
 		tr.witnesses = make([]*Witness, 0)
 	}
 
-	witness, err := extractWitnessFromRootNode(hbRoot, tr.blockNr, false /*tr.hb.trace*/, nil, nil)
+	witness, err := extractWitnessFromRootNode(hbRoot, tr.blockNr, false /*tr.hb.trace*/, nil)
 	if err != nil {
 		return fmt.Errorf("error while extracting witness for resolver: %w", err)
 	}
diff --git a/trie/resolver_stateful.go b/trie/resolver_stateful.go
index d282c88d56..f5bd455f18 100644
--- a/trie/resolver_stateful.go
+++ b/trie/resolver_stateful.go
@@ -186,6 +186,20 @@ func (tr *ResolverStateful) RebuildTrie(
 	return nil
 }
 
+func (tr *ResolverStateful) AttachRequestedCode(db ethdb.Getter, requests []*ResolveRequestForCode) error {
+	for _, req := range requests {
+		codeHash := req.codeHash
+		code, err := db.Get(dbutils.CodeBucket, codeHash[:])
+		if err != nil {
+			return err
+		}
+		if err := req.t.UpdateAccountCode(req.addrHash[:], codeNode(code)); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (tr *ResolverStateful) WalkerAccounts(keyIdx int, k []byte, v []byte) error {
 	return tr.Walker(true, keyIdx, k, v)
 }
diff --git a/trie/resolver_stateless.go b/trie/resolver_stateless.go
index 6f43bc8659..5fa5c18031 100644
--- a/trie/resolver_stateless.go
+++ b/trie/resolver_stateless.go
@@ -42,7 +42,7 @@ func (r *ResolverStateless) RebuildTrie(db WitnessStorage, blockNr uint64, trieL
 				return 0, err
 			}
 
-			trie, _, err := BuildTrieFromWitness(witness, false /*is-binary*/, false /*trace*/)
+			trie, err := BuildTrieFromWitness(witness, false /*is-binary*/, false /*trace*/)
 			if err != nil {
 				return 0, err
 			}
diff --git a/trie/resolver_stateless_test.go b/trie/resolver_stateless_test.go
index 0e5e81392c..c04882f43e 100644
--- a/trie/resolver_stateless_test.go
+++ b/trie/resolver_stateless_test.go
@@ -35,17 +35,17 @@ func TestRebuildTrie(t *testing.T) {
 	trie2 := buildTestTrie(10)
 	trie3 := buildTestTrie(100)
 
-	w1, err := extractWitnessFromRootNode(trie1.root, 1, false, nil, nil)
+	w1, err := extractWitnessFromRootNode(trie1.root, 1, false, nil)
 	if err != nil {
 		t.Error(err)
 	}
 
-	w2, err := extractWitnessFromRootNode(trie2.root, 1, false, nil, nil)
+	w2, err := extractWitnessFromRootNode(trie2.root, 1, false, nil)
 	if err != nil {
 		t.Error(err)
 	}
 
-	w3, err := extractWitnessFromRootNode(trie3.root, 1, false, nil, nil)
+	w3, err := extractWitnessFromRootNode(trie3.root, 1, false, nil)
 	if err != nil {
 		t.Error(err)
 	}
diff --git a/trie/trie.go b/trie/trie.go
index 7a1f2c87af..95275e9b44 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -26,7 +26,9 @@ import (
 	"github.com/ledgerwatch/turbo-geth/common/debug"
 	"github.com/ledgerwatch/turbo-geth/core/types/accounts"
 	"github.com/ledgerwatch/turbo-geth/crypto"
+	"github.com/ledgerwatch/turbo-geth/ethdb"
 	"github.com/ledgerwatch/turbo-geth/log"
+	"github.com/pkg/errors"
 )
 
 var (
@@ -131,6 +133,31 @@ func (t *Trie) GetAccount(key []byte) (value *accounts.Account, gotValue bool) {
 	return nil, gotValue
 }
 
+func (t *Trie) GetAccountCode(key []byte) (value []byte, gotValue bool) {
+	if t.root == nil {
+		return nil, true
+	}
+
+	hex := keybytesToHex(key)
+	if t.binary {
+		hex = keyHexToBin(hex)
+	}
+
+	accNode, gotValue := t.getAccount(t.root, hex, 0)
+	if accNode != nil {
+		if bytes.Equal(accNode.Account.CodeHash[:], EmptyCodeHash[:]) {
+			return nil, gotValue
+		}
+
+		if accNode.code == nil {
+			return nil, false
+		}
+
+		return accNode.code, gotValue
+	}
+	return nil, gotValue
+}
+
 func (t *Trie) getAccount(origNode node, key []byte, pos int) (value *accountNode, gotValue bool) {
 	switch n := (origNode).(type) {
 	case nil:
@@ -247,20 +274,59 @@ func (t *Trie) UpdateAccount(key []byte, acc *accounts.Account) {
 	if t.root == nil {
 		var newnode node
 		if value.Root == EmptyRoot || value.Root == (common.Hash{}) {
-			newnode = &shortNode{Key: hex, Val: &accountNode{*value, nil, true}}
+			newnode = &shortNode{Key: hex, Val: &accountNode{*value, nil, true, nil}}
 		} else {
-			newnode = &shortNode{Key: hex, Val: &accountNode{*value, hashNode(value.Root[:]), true}}
+			newnode = &shortNode{Key: hex, Val: &accountNode{*value, hashNode(value.Root[:]), true, nil}}
 		}
 		t.root = newnode
 	} else {
 		if value.Root == EmptyRoot || value.Root == (common.Hash{}) {
-			_, t.root = t.insert(t.root, hex, 0, &accountNode{*value, nil, true})
+			_, t.root = t.insert(t.root, hex, 0, &accountNode{*value, nil, true, nil})
 		} else {
-			_, t.root = t.insert(t.root, hex, 0, &accountNode{*value, hashNode(value.Root[:]), true})
+			_, t.root = t.insert(t.root, hex, 0, &accountNode{*value, hashNode(value.Root[:]), true, nil})
 		}
 	}
 }
 
+// UpdateAccountCode attaches the code node to an account at specified key
+func (t *Trie) UpdateAccountCode(key []byte, code codeNode) error {
+	if t.root == nil {
+		return nil
+	}
+
+	hex := keybytesToHex(key)
+	if t.binary {
+		hex = keyHexToBin(hex)
+	}
+
+	accNode, gotValue := t.getAccount(t.root, hex, 0)
+	if accNode == nil || !gotValue {
+		return errors.Wrapf(ethdb.ErrKeyNotFound, "account not found with key: %x", key)
+	}
+
+	actualCodeHash := crypto.Keccak256(code)
+	if !bytes.Equal(accNode.CodeHash[:], actualCodeHash) {
+		return fmt.Errorf("inserted code mismatch account hash (acc.CodeHash=%x codeHash=%x)", accNode.CodeHash[:], actualCodeHash)
+	}
+
+	accNode.code = code
+
+	_, t.root = t.insert(t.root, hex, 0, accNode)
+	return nil
+}
+
+// ResolveRequestFor Code expresses the need to fetch code from the DB (by its hash) and attach
+// to a specific account leaf in the trie.
+type ResolveRequestForCode struct {
+	t        *Trie
+	addrHash common.Hash // contract address hash
+	codeHash common.Hash
+}
+
+func (rr *ResolveRequestForCode) String() string {
+	return fmt.Sprintf("rr_code{addrHash:%x,codeHash:%x}", rr.addrHash, rr.codeHash)
+}
+
 // ResolveRequest expresses the need to fetch a subtrie from the database. The location of this
 // subtrie is specified by the resolveHex[:resolvePos]. The remaining part of resolveHex (if present)
 // is useful to ensure that specific leaves of the trie are fully expanded (and not rolled into
@@ -283,6 +349,8 @@ type ResolveRequest struct {
 	NodeRLP       []byte   // [OUT] RLP of the resolved node
 }
 
+/* add code hash there, if nil -- don't do anything with the code */
+
 // NewResolveRequest creates a new ResolveRequest.
 // contract must be either address hash + incarnation (32+8 bytes) or nil
 func (t *Trie) NewResolveRequest(contract []byte, hex []byte, pos int, resolveHash []byte) *ResolveRequest {
@@ -293,6 +361,23 @@ func (rr *ResolveRequest) String() string {
 	return fmt.Sprintf("rr{t:%x,resolveHex:%x,resolvePos:%d,resolveHash:%s}", rr.contract, rr.resolveHex, rr.resolvePos, rr.resolveHash)
 }
 
+func (t *Trie) NewResolveRequestForCode(addrHash common.Hash, codeHash common.Hash) *ResolveRequestForCode {
+	return &ResolveRequestForCode{t, addrHash, codeHash}
+}
+
+func (t *Trie) NeedResolutonForCode(addrHash common.Hash, codeHash common.Hash) (bool, *ResolveRequestForCode) {
+	if bytes.Equal(codeHash[:], EmptyCodeHash[:]) {
+		return false, nil
+	}
+
+	_, ok := t.GetAccountCode(addrHash[:])
+	if !ok {
+		return true, t.NewResolveRequestForCode(addrHash, codeHash)
+	}
+
+	return false, nil
+}
+
 // NeedResolution determines whether the trie needs to be extended (resolved) by fetching data
 // from the database, if one were to access the key specified
 // In the case of "Yes", also returns a corresponding ResolveRequest
@@ -380,6 +465,9 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b
 		if origNok && vnok {
 			updated = !origAccN.Equals(&vAccN.Account)
 			if updated {
+				if !bytes.Equal(origAccN.CodeHash[:], vAccN.CodeHash[:]) {
+					origAccN.code = nil
+				}
 				origAccN.Account.Copy(&vAccN.Account)
 				origAccN.rootCorrect = false
 			}
diff --git a/trie/trie_from_witness.go b/trie/trie_from_witness.go
index 184ceec652..47d1f2093c 100644
--- a/trie/trie_from_witness.go
+++ b/trie/trie_from_witness.go
@@ -4,12 +4,10 @@ import (
 	"fmt"
 	"math/big"
 
-	"github.com/ledgerwatch/turbo-geth/common"
 	"github.com/ledgerwatch/turbo-geth/trie/rlphacks"
 )
 
-func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, CodeMap, error) {
-	codeMap := make(map[common.Hash][]byte)
+func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, error) {
 	hb := NewHashBuilder(false)
 	for _, operator := range witness.Operators {
 		switch op := operator.(type) {
@@ -20,38 +18,36 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C
 			keyHex := op.Key
 			val := op.Value
 			if err := hb.leaf(len(op.Key), keyHex, rlphacks.RlpSerializableBytes(val)); err != nil {
-				return nil, nil, err
+				return nil, err
 			}
 		case *OperatorExtension:
 			if trace {
 				fmt.Printf("EXTENSION ")
 			}
 			if err := hb.extension(op.Key); err != nil {
-				return nil, nil, err
+				return nil, err
 			}
 		case *OperatorBranch:
 			if trace {
 				fmt.Printf("BRANCH ")
 			}
 			if err := hb.branch(uint16(op.Mask)); err != nil {
-				return nil, nil, err
+				return nil, err
 			}
 		case *OperatorHash:
 			if trace {
 				fmt.Printf("HASH ")
 			}
 			if err := hb.hash(op.Hash[:]); err != nil {
-				return nil, nil, err
+				return nil, err
 			}
 		case *OperatorCode:
 			if trace {
 				fmt.Printf("CODE ")
 			}
 
-			if codeHash, err := hb.code(op.Code); err == nil {
-				codeMap[codeHash] = op.Code
-			} else {
-				return nil, nil, err
+			if err := hb.code(op.Code); err != nil {
+				return nil, err
 			}
 
 		case *OperatorLeafAccount:
@@ -74,7 +70,7 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C
 			incarnaton := uint64(0)
 
 			if err := hb.accountLeaf(len(op.Key), op.Key, 0, balance, nonce, incarnaton, fieldSet); err != nil {
-				return nil, nil, err
+				return nil, err
 			}
 		case *OperatorEmptyRoot:
 			if trace {
@@ -82,7 +78,7 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C
 			}
 			hb.emptyRoot()
 		default:
-			return nil, nil, fmt.Errorf("unknown operand type: %T", operator)
+			return nil, fmt.Errorf("unknown operand type: %T", operator)
 		}
 	}
 	if trace {
@@ -90,9 +86,9 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C
 	}
 	if !hb.hasRoot() {
 		if isBinary {
-			return NewBinary(EmptyRoot), nil, nil
+			return NewBinary(EmptyRoot), nil
 		}
-		return New(EmptyRoot), nil, nil
+		return New(EmptyRoot), nil
 	}
 	r := hb.root()
 	var tr *Trie
@@ -102,5 +98,5 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C
 		tr = New(hb.rootHash())
 	}
 	tr.root = r
-	return tr, codeMap, nil
+	return tr, nil
 }
diff --git a/trie/trie_test.go b/trie/trie_test.go
index c7f44f01b7..5066b925a0 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -566,3 +566,327 @@ func TestHashMapLeak(t *testing.T) {
 	assert.GreaterOrEqual(t, nHashes, nExpected*7/8)
 	assert.LessOrEqual(t, nHashes, nExpected*9/8)
 }
+
+func genRandomByteArrayOfLen(length uint) []byte {
+	array := make([]byte, length)
+	for i := uint(0); i < length; i++ {
+		array[i] = byte(rand.Intn(256))
+	}
+	return array
+}
+
+func getAddressForIndex(index int) [20]byte {
+	var address [20]byte
+	binary.BigEndian.PutUint32(address[:], uint32(index))
+	return address
+}
+
+func TestCodeNodeValid(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	numberOfAccounts := 20
+
+	addresses := make([][20]byte, numberOfAccounts)
+	for i := 0; i < len(addresses); i++ {
+		addresses[i] = getAddressForIndex(i)
+	}
+	codeValues := make([][]byte, len(addresses))
+	for i := 0; i < len(addresses); i++ {
+		codeValues[i] = genRandomByteArrayOfLen(128)
+		codeHash := common.BytesToHash(crypto.Keccak256(codeValues[i]))
+		balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+		acc := accounts.NewAccount()
+		acc.Nonce = uint64(random.Int63())
+		acc.Balance = *balance
+		acc.Root = EmptyRoot
+		acc.CodeHash = codeHash
+
+		trie.UpdateAccount(crypto.Keccak256(addresses[i][:]), &acc)
+		err := trie.UpdateAccountCode(crypto.Keccak256(addresses[i][:]), codeValues[i])
+		assert.Nil(t, err, "should successfully insert code")
+	}
+
+	for i := 0; i < len(addresses); i++ {
+		value, gotValue := trie.GetAccountCode(crypto.Keccak256(addresses[i][:]))
+		assert.True(t, gotValue, "should receive code value")
+		assert.True(t, bytes.Equal(value, codeValues[i]), "should receive the right code")
+	}
+}
+
+func TestCodeNodeUpdateNotExisting(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+	codeValue := genRandomByteArrayOfLen(128)
+
+	codeHash := common.BytesToHash(crypto.Keccak256(codeValue))
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue)
+	assert.Nil(t, err, "should successfully insert code")
+
+	nonExistingAddress := getAddressForIndex(9999)
+	codeValue2 := genRandomByteArrayOfLen(128)
+
+	err = trie.UpdateAccountCode(crypto.Keccak256(nonExistingAddress[:]), codeValue2)
+	assert.Error(t, err, "should return an error for non existing acc")
+}
+
+func TestCodeNodeGetNotExistingAccount(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+	codeValue := genRandomByteArrayOfLen(128)
+
+	codeHash := common.BytesToHash(crypto.Keccak256(codeValue))
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue)
+	assert.Nil(t, err, "should successfully insert code")
+
+	nonExistingAddress := getAddressForIndex(9999)
+
+	value, gotValue := trie.GetAccountCode(crypto.Keccak256(nonExistingAddress[:]))
+	assert.True(t, gotValue, "should indicate that account doesn't exist at all (not just hashed)")
+	assert.Nil(t, value, "the value should be nil")
+}
+
+func TestCodeNodeGetHashedAccount(t *testing.T) {
+	trie := newEmpty()
+
+	address := getAddressForIndex(0)
+
+	fakeAccount := genRandomByteArrayOfLen(50)
+	fakeAccountHash := common.BytesToHash(crypto.Keccak256(fakeAccount))
+
+	hex := keybytesToHex(crypto.Keccak256(address[:]))
+
+	_, trie.root = trie.insert(trie.root, hex, 0, hashNode(fakeAccountHash[:]))
+
+	value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:]))
+	assert.False(t, gotValue, "should indicate that account exists but hashed")
+	assert.Nil(t, value, "the value should be nil")
+}
+
+func TestCodeNodeGetExistingAccountNoCodeNotEmpty(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+	codeValue := genRandomByteArrayOfLen(128)
+
+	codeHash := common.BytesToHash(crypto.Keccak256(codeValue))
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+
+	value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:]))
+	assert.False(t, gotValue, "should indicate that account exists with code but the code isn't in cache")
+	assert.Nil(t, value, "the value should be nil")
+}
+
+func TestCodeNodeGetExistingAccountEmptyCode(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+
+	codeHash := EmptyCodeHash
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+
+	value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:]))
+	assert.True(t, gotValue, "should indicate that account exists with empty code")
+	assert.Nil(t, value, "the value should be nil")
+}
+
+func TestCodeNodeWrongHash(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+
+	codeValue1 := genRandomByteArrayOfLen(128)
+	codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1))
+
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash1
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+
+	codeValue2 := genRandomByteArrayOfLen(128)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue2)
+	assert.Error(t, err, "should NOT be able to insert code with wrong hash")
+}
+
+func TestCodeNodeUpdateAccountAndCodeValidHash(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+
+	codeValue1 := genRandomByteArrayOfLen(128)
+	codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1))
+
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash1
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1)
+	assert.Nil(t, err, "should successfully insert code")
+
+	codeValue2 := genRandomByteArrayOfLen(128)
+	codeHash2 := common.BytesToHash(crypto.Keccak256(codeValue2))
+
+	acc.CodeHash = codeHash2
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err = trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue2)
+	assert.Nil(t, err, "should successfully insert code")
+}
+
+func TestCodeNodeUpdateAccountAndCodeInvalidHash(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+
+	codeValue1 := genRandomByteArrayOfLen(128)
+	codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1))
+
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash1
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1)
+	assert.Nil(t, err, "should successfully insert code")
+
+	codeValue2 := genRandomByteArrayOfLen(128)
+	codeHash2 := common.BytesToHash(crypto.Keccak256(codeValue2))
+
+	codeValue3 := genRandomByteArrayOfLen(128)
+
+	acc.CodeHash = codeHash2
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err = trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue3)
+	assert.Error(t, err, "should NOT be able to insert code with wrong hash")
+}
+
+func TestCodeNodeUpdateAccountChangeCodeHash(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+
+	codeValue1 := genRandomByteArrayOfLen(128)
+	codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1))
+
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash1
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1)
+	assert.Nil(t, err, "should successfully insert code")
+
+	codeValue2 := genRandomByteArrayOfLen(128)
+	codeHash2 := common.BytesToHash(crypto.Keccak256(codeValue2))
+
+	acc.CodeHash = codeHash2
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:]))
+	assert.Nil(t, value, "the value should reset after the code change happen")
+	assert.False(t, gotValue, "should indicate that the code isn't in the cache")
+}
+
+func TestCodeNodeUpdateAccountNoChangeCodeHash(t *testing.T) {
+	trie := newEmpty()
+
+	random := rand.New(rand.NewSource(0))
+
+	address := getAddressForIndex(0)
+
+	codeValue1 := genRandomByteArrayOfLen(128)
+	codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1))
+
+	balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+
+	acc := accounts.NewAccount()
+	acc.Nonce = uint64(random.Int63())
+	acc.Balance = *balance
+	acc.Root = EmptyRoot
+	acc.CodeHash = codeHash1
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1)
+	assert.Nil(t, err, "should successfully insert code")
+
+	acc.Nonce = uint64(random.Int63())
+	balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
+	acc.Balance = *balance
+
+	trie.UpdateAccount(crypto.Keccak256(address[:]), &acc)
+	value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:]))
+	assert.Equal(t, value, codeValue1, "the value should NOT reset after account's non codehash had changed")
+	assert.True(t, gotValue, "should indicate that the code is still in the cache")
+}
diff --git a/trie/trie_transform.go b/trie/trie_transform.go
index e8d71be077..d89929ea9f 100644
--- a/trie/trie_transform.go
+++ b/trie/trie_transform.go
@@ -18,7 +18,12 @@ func transformSubTrie(nd node, hex []byte, newTrie *Trie, transformFunc keyTrans
 	case *accountNode:
 		accountCopy := accounts.NewAccount()
 		accountCopy.Copy(&n.Account)
-		_, newTrie.root = newTrie.insert(newTrie.root, transformFunc(hex), 0, &accountNode{accountCopy, nil, true})
+		var code []byte = nil
+		if n.code != nil {
+			code = make([]byte, len(n.code))
+			copy(code, n.code)
+		}
+		_, newTrie.root = newTrie.insert(newTrie.root, transformFunc(hex), 0, &accountNode{accountCopy, nil, true, codeNode(code)})
 		aHex := hex
 		if aHex[len(aHex)-1] == 16 {
 			aHex = aHex[:len(aHex)-1]
diff --git a/trie/trie_witness.go b/trie/trie_witness.go
index 43eb2d6fea..520b107f30 100644
--- a/trie/trie_witness.go
+++ b/trie/trie_witness.go
@@ -1,14 +1,14 @@
 package trie
 
-func (t *Trie) ExtractWitness(blockNr uint64, trace bool, rs *ResolveSet, codeMap CodeMap) (*Witness, error) {
-	return extractWitnessFromRootNode(t.root, blockNr, trace, rs, codeMap)
+func (t *Trie) ExtractWitness(blockNr uint64, trace bool, rs *ResolveSet) (*Witness, error) {
+	return extractWitnessFromRootNode(t.root, blockNr, trace, rs)
 }
 
 // extractWitnessFromRootNode extracts a witness for a subtrie starting from the specified root
 // if hashOnly param is nil it will make a witness for the full subtrie,
 // if hashOnly param is set to a ResolveSet instance, it will make a witness for only the accounts/storages that were actually touched; other paths will be hashed.
-func extractWitnessFromRootNode(root node, blockNr uint64, trace bool, hashOnly HashOnly, codeMap CodeMap) (*Witness, error) {
-	builder := NewWitnessBuilder(root, blockNr, trace, codeMap)
+func extractWitnessFromRootNode(root node, blockNr uint64, trace bool, hashOnly HashOnly) (*Witness, error) {
+	builder := NewWitnessBuilder(root, blockNr, trace)
 	var limiter *MerklePathLimiter
 	if hashOnly != nil {
 		hr := newHasher(false)
diff --git a/trie/visual.go b/trie/visual.go
index e08a30a350..862ab2212b 100644
--- a/trie/visual.go
+++ b/trie/visual.go
@@ -30,16 +30,15 @@ import (
 // VisualOpts contains various configuration options fo the Visual function
 // It has been introduced as a replacement for too many arguments with options
 type VisualOpts struct {
-	Highlights     [][]byte               // Collection of keys, in the HEX encoding, that need to be highlighted with digits
-	IndexColors    []string               // Array of colors for representing digits as colored boxes
-	FontColors     []string               // Array of colors, the same length as indexColors, for the textual digits inside the coloured boxes
-	CutTerminals   int                    // Specifies how many digits to cut from the terminal short node keys for a more convinient display
-	Values         bool                   // Whether to display value nodes (as box with rounded corners)
-	CodeCompressed bool                   // Whether to turn the code from a large rectange to a small square for a more convinient display
-	ValCompressed  bool                   // Whether long values (over 10 characters) are shortened using ... in the middle
-	ValHex         bool                   // Whether values should be displayed as hex numbers (otherwise they are displayed as just strings)
-	SameLevel      bool                   // Whether the leaves (and hashes) need to be on the same horizontal level
-	CodeMap        map[common.Hash][]byte // Map that allows looking up bytecode of contracts by the bytecode's hash
+	Highlights     [][]byte // Collection of keys, in the HEX encoding, that need to be highlighted with digits
+	IndexColors    []string // Array of colors for representing digits as colored boxes
+	FontColors     []string // Array of colors, the same length as indexColors, for the textual digits inside the coloured boxes
+	CutTerminals   int      // Specifies how many digits to cut from the terminal short node keys for a more convinient display
+	Values         bool     // Whether to display value nodes (as box with rounded corners)
+	CodeCompressed bool     // Whether to turn the code from a large rectange to a small square for a more convinient display
+	ValCompressed  bool     // Whether long values (over 10 characters) are shortened using ... in the middle
+	ValHex         bool     // Whether values should be displayed as hex numbers (otherwise they are displayed as just strings)
+	SameLevel      bool     // Whether the leaves (and hashes) need to be on the same horizontal level
 }
 
 // Visual creates visualisation of trie with highlighting
@@ -108,7 +107,7 @@ func visualNode(nd node, hex []byte, w io.Writer, highlights [][]byte, opts *Vis
 				`n_%x -> e_%x;
 `, hex, accountHex)
 			if !a.IsEmptyCodeHash() {
-				if code, ok := opts.CodeMap[a.CodeHash]; ok {
+				if code := a.code; code != nil {
 					codeHex := keybytesToHex(code)
 					codeHex = codeHex[:len(codeHex)-1]
 					visual.HexBox(w, fmt.Sprintf("c_%x", accountHex), codeHex, 32, opts.CodeCompressed, false)
diff --git a/trie/witness_builder.go b/trie/witness_builder.go
index af7251b691..5706076733 100644
--- a/trie/witness_builder.go
+++ b/trie/witness_builder.go
@@ -10,6 +10,7 @@ import (
 type HashNodeFunc func(node, bool, []byte) (int, error)
 type HashOnly interface {
 	HashOnly([]byte) bool
+	IsCodeTouched(common.Hash) bool
 	Current() []byte
 }
 
@@ -18,22 +19,18 @@ type MerklePathLimiter struct {
 	HashFunc HashNodeFunc
 }
 
-type CodeMap map[common.Hash][]byte
-
 type WitnessBuilder struct {
 	root     node
 	blockNr  uint64
 	trace    bool
-	codeMap  CodeMap
 	operands []WitnessOperator
 }
 
-func NewWitnessBuilder(root node, blockNr uint64, trace bool, codeMap CodeMap) *WitnessBuilder {
+func NewWitnessBuilder(root node, blockNr uint64, trace bool) *WitnessBuilder {
 	return &WitnessBuilder{
 		root:     root,
 		blockNr:  blockNr,
 		trace:    trace,
-		codeMap:  codeMap,
 		operands: make([]WitnessOperator, 0),
 	}
 }
@@ -158,17 +155,16 @@ func (b *WitnessBuilder) addEmptyRoot() error {
 	return nil
 }
 
-func (b *WitnessBuilder) processAccountCode(n *accountNode) error {
+func (b *WitnessBuilder) processAccountCode(n *accountNode, hashOnly HashOnly) error {
 	if n.IsEmptyRoot() && n.IsEmptyCodeHash() {
 		return nil
 	}
 
-	code, ok := b.codeMap[n.CodeHash]
-	if !ok {
+	if n.code == nil || !hashOnly.IsCodeTouched(n.CodeHash) {
 		return b.addHashOp(hashNode(n.CodeHash[:]))
 	}
 
-	return b.addCodeOp(code)
+	return b.addCodeOp(n.code)
 }
 
 func (b *WitnessBuilder) processAccountStorage(n *accountNode, hex []byte, limiter *MerklePathLimiter) error {
@@ -188,7 +184,7 @@ func (b *WitnessBuilder) makeBlockWitness(
 	nd node, hex []byte, limiter *MerklePathLimiter, force bool) error {
 
 	processAccountNode := func(key []byte, storageKey []byte, n *accountNode) error {
-		if err := b.processAccountCode(n); err != nil {
+		if err := b.processAccountCode(n, limiter.HashOnly); err != nil {
 			return err
 		}
 		if err := b.processAccountStorage(n, storageKey, limiter); err != nil {
diff --git a/trie/witness_builder_test.go b/trie/witness_builder_test.go
index 9aa492cc86..d319b49b70 100644
--- a/trie/witness_builder_test.go
+++ b/trie/witness_builder_test.go
@@ -18,7 +18,7 @@ func TestBlockWitnessBinary(t *testing.T) {
 	rs := NewBinaryResolveSet(2)
 	rs.AddKey([]byte("ABCD0001"))
 
-	bwb := NewWitnessBuilder(trBin.Trie().root, 1, false, nil)
+	bwb := NewWitnessBuilder(trBin.Trie().root, 1, false)
 
 	hr := newHasher(false)
 	defer returnHasherToPool(hr)
@@ -29,7 +29,7 @@ func TestBlockWitnessBinary(t *testing.T) {
 		t.Errorf("Could not make block witness: %v", err)
 	}
 
-	trBin1, _, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/)
+	trBin1, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/)
 	if err != nil {
 		t.Errorf("Could not restore trie from the block witness: %v", err)
 	}
@@ -57,7 +57,7 @@ func TestBlockWitnessBinaryAccount(t *testing.T) {
 	rs := NewBinaryResolveSet(2)
 	rs.AddKey([]byte("ABCD0001"))
 
-	bwb := NewWitnessBuilder(trBin.Trie().root, 1, false, nil)
+	bwb := NewWitnessBuilder(trBin.Trie().root, 1, false)
 
 	hr := newHasher(false)
 	defer returnHasherToPool(hr)
@@ -68,7 +68,7 @@ func TestBlockWitnessBinaryAccount(t *testing.T) {
 		t.Errorf("Could not make block witness: %v", err)
 	}
 
-	trBin1, _, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/)
+	trBin1, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/)
 	if err != nil {
 		t.Errorf("Could not restore trie from the block witness: %v", err)
 	}
-- 
GitLab