diff --git a/core/blockchain.go b/core/blockchain.go
index 4ca618c5bf3a6adf802c94a4094f48b09c547230..9fa5b09f95dd658cfacec2acaa4ea362015780bc 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -1543,8 +1543,16 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
 	for _, tx := range types.TxDifference(deletedTxs, addedTxs) {
 		rawdb.DeleteTxLookupEntry(batch, tx.Hash())
 	}
+	// Delete any canonical number assignments above the new head
+	number := bc.CurrentBlock().NumberU64()
+	for i := number + 1; ; i++ {
+		hash := rawdb.ReadCanonicalHash(bc.db, i)
+		if hash == (common.Hash{}) {
+			break
+		}
+		rawdb.DeleteCanonicalHash(batch, i)
+	}
 	batch.Write()
-
 	// If any logs need to be fired, do it now. In theory we could avoid creating
 	// this goroutine if there are no events to fire, but realistcally that only
 	// ever happens if we're reorging empty blocks, which will only happen on idle
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index 5ee1d9f8eb3577d17024220f9c77da901ef1fde1..80a949d904bc19940f63901b09a466d73ac954e8 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -17,6 +17,7 @@
 package core
 
 import (
+	"fmt"
 	"math/big"
 	"math/rand"
 	"sync"
@@ -1810,3 +1811,123 @@ func TestPrunedImportSide(t *testing.T) {
 	testSideImport(t, 1, 10)
 	testSideImport(t, 1, -10)
 }
+
+// getLongAndShortChains returns two chains,
+// A is longer, B is heavier
+func getLongAndShortChains() (*BlockChain, []*types.Block, []*types.Block, error) {
+	// Generate a canonical chain to act as the main dataset
+	engine := ethash.NewFaker()
+	db := rawdb.NewMemoryDatabase()
+	genesis := new(Genesis).MustCommit(db)
+
+	// Generate and import the canonical chain,
+	// Offset the time, to keep the difficulty low
+	longChain, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 80, func(i int, b *BlockGen) {
+		b.SetCoinbase(common.Address{1})
+	})
+	diskdb := rawdb.NewMemoryDatabase()
+	new(Genesis).MustCommit(diskdb)
+
+	chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{}, nil)
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("failed to create tester chain: %v", err)
+	}
+
+	// Generate fork chain, make it shorter than canon, with common ancestor pretty early
+	parentIndex := 3
+	parent := longChain[parentIndex]
+	heavyChain, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 75, func(i int, b *BlockGen) {
+		b.SetCoinbase(common.Address{2})
+		b.OffsetTime(-9)
+	})
+	// Verify that the test is sane
+	var (
+		longerTd  = new(big.Int)
+		shorterTd = new(big.Int)
+	)
+	for index, b := range longChain {
+		longerTd.Add(longerTd, b.Difficulty())
+		if index <= parentIndex {
+			shorterTd.Add(shorterTd, b.Difficulty())
+		}
+	}
+	for _, b := range heavyChain {
+		shorterTd.Add(shorterTd, b.Difficulty())
+	}
+	if shorterTd.Cmp(longerTd) <= 0 {
+		return nil, nil, nil, fmt.Errorf("Test is moot, heavyChain td (%v) must be larger than canon td (%v)", shorterTd, longerTd)
+	}
+	longerNum := longChain[len(longChain)-1].NumberU64()
+	shorterNum := heavyChain[len(heavyChain)-1].NumberU64()
+	if shorterNum >= longerNum {
+		return nil, nil, nil, fmt.Errorf("Test is moot, heavyChain num (%v) must be lower than canon num (%v)", shorterNum, longerNum)
+	}
+	return chain, longChain, heavyChain, nil
+}
+
+// TestReorgToShorterRemovesCanonMapping tests that if we
+// 1. Have a chain [0 ... N .. X]
+// 2. Reorg to shorter but heavier chain [0 ... N ... Y]
+// 3. Then there should be no canon mapping for the block at height X
+func TestReorgToShorterRemovesCanonMapping(t *testing.T) {
+	chain, canonblocks, sideblocks, err := getLongAndShortChains()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n, err := chain.InsertChain(canonblocks); err != nil {
+		t.Fatalf("block %d: failed to insert into chain: %v", n, err)
+	}
+	canonNum := chain.CurrentBlock().NumberU64()
+	_, err = chain.InsertChain(sideblocks)
+	if err != nil {
+		t.Errorf("Got error, %v", err)
+	}
+	head := chain.CurrentBlock()
+	if got := sideblocks[len(sideblocks)-1].Hash(); got != head.Hash() {
+		t.Fatalf("head wrong, expected %x got %x", head.Hash(), got)
+	}
+	// We have now inserted a sidechain.
+	if blockByNum := chain.GetBlockByNumber(canonNum); blockByNum != nil {
+		t.Errorf("expected block to be gone: %v", blockByNum.NumberU64())
+	}
+	if headerByNum := chain.GetHeaderByNumber(canonNum); headerByNum != nil {
+		t.Errorf("expected header to be gone: %v", headerByNum.Number.Uint64())
+	}
+}
+
+// TestReorgToShorterRemovesCanonMappingHeaderChain is the same scenario
+// as TestReorgToShorterRemovesCanonMapping, but applied on headerchain
+// imports -- that is, for fast sync
+func TestReorgToShorterRemovesCanonMappingHeaderChain(t *testing.T) {
+	chain, canonblocks, sideblocks, err := getLongAndShortChains()
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Convert into headers
+	canonHeaders := make([]*types.Header, len(canonblocks))
+	for i, block := range canonblocks {
+		canonHeaders[i] = block.Header()
+	}
+	if n, err := chain.InsertHeaderChain(canonHeaders, 0); err != nil {
+		t.Fatalf("header %d: failed to insert into chain: %v", n, err)
+	}
+	canonNum := chain.CurrentHeader().Number.Uint64()
+	sideHeaders := make([]*types.Header, len(sideblocks))
+	for i, block := range sideblocks {
+		sideHeaders[i] = block.Header()
+	}
+	if n, err := chain.InsertHeaderChain(sideHeaders, 0); err != nil {
+		t.Fatalf("header %d: failed to insert into chain: %v", n, err)
+	}
+	head := chain.CurrentHeader()
+	if got := sideblocks[len(sideblocks)-1].Hash(); got != head.Hash() {
+		t.Fatalf("head wrong, expected %x got %x", head.Hash(), got)
+	}
+	// We have now inserted a sidechain.
+	if blockByNum := chain.GetBlockByNumber(canonNum); blockByNum != nil {
+		t.Errorf("expected block to be gone: %v", blockByNum.NumberU64())
+	}
+	if headerByNum := chain.GetHeaderByNumber(canonNum); headerByNum != nil {
+		t.Errorf("expected header to be gone: %v", headerByNum.Number.Uint64())
+	}
+}