From dcdd57df6282a6cd43a6407e8626a5cdcca60482 Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Wed, 18 Jul 2018 18:41:36 +0800
Subject: [PATCH] core, ethdb: two tiny fixes (#17183)

* ethdb: fix memory database

* core: fix bloombits checking

* core: minor polish
---
 core/bloombits/generator.go | 30 ++++++++++++++++++------------
 ethdb/database_test.go      | 22 ++++++++++++++++++++++
 ethdb/memory_database.go    | 12 ++++++++----
 3 files changed, 48 insertions(+), 16 deletions(-)

diff --git a/core/bloombits/generator.go b/core/bloombits/generator.go
index 540085450..ae07481ad 100644
--- a/core/bloombits/generator.go
+++ b/core/bloombits/generator.go
@@ -22,16 +22,22 @@ import (
 	"github.com/ethereum/go-ethereum/core/types"
 )
 
-// errSectionOutOfBounds is returned if the user tried to add more bloom filters
-// to the batch than available space, or if tries to retrieve above the capacity,
-var errSectionOutOfBounds = errors.New("section out of bounds")
+var (
+	// errSectionOutOfBounds is returned if the user tried to add more bloom filters
+	// to the batch than available space, or if tries to retrieve above the capacity.
+	errSectionOutOfBounds = errors.New("section out of bounds")
+
+	// errBloomBitOutOfBounds is returned if the user tried to retrieve specified
+	// bit bloom above the capacity.
+	errBloomBitOutOfBounds = errors.New("bloom bit out of bounds")
+)
 
 // Generator takes a number of bloom filters and generates the rotated bloom bits
 // to be used for batched filtering.
 type Generator struct {
 	blooms   [types.BloomBitLength][]byte // Rotated blooms for per-bit matching
 	sections uint                         // Number of sections to batch together
-	nextBit  uint                         // Next bit to set when adding a bloom
+	nextSec  uint                         // Next section to set when adding a bloom
 }
 
 // NewGenerator creates a rotated bloom generator that can iteratively fill a
@@ -51,15 +57,15 @@ func NewGenerator(sections uint) (*Generator, error) {
 // in memory accordingly.
 func (b *Generator) AddBloom(index uint, bloom types.Bloom) error {
 	// Make sure we're not adding more bloom filters than our capacity
-	if b.nextBit >= b.sections {
+	if b.nextSec >= b.sections {
 		return errSectionOutOfBounds
 	}
-	if b.nextBit != index {
+	if b.nextSec != index {
 		return errors.New("bloom filter with unexpected index")
 	}
 	// Rotate the bloom and insert into our collection
-	byteIndex := b.nextBit / 8
-	bitMask := byte(1) << byte(7-b.nextBit%8)
+	byteIndex := b.nextSec / 8
+	bitMask := byte(1) << byte(7-b.nextSec%8)
 
 	for i := 0; i < types.BloomBitLength; i++ {
 		bloomByteIndex := types.BloomByteLength - 1 - i/8
@@ -69,7 +75,7 @@ func (b *Generator) AddBloom(index uint, bloom types.Bloom) error {
 			b.blooms[i][byteIndex] |= bitMask
 		}
 	}
-	b.nextBit++
+	b.nextSec++
 
 	return nil
 }
@@ -77,11 +83,11 @@ func (b *Generator) AddBloom(index uint, bloom types.Bloom) error {
 // Bitset returns the bit vector belonging to the given bit index after all
 // blooms have been added.
 func (b *Generator) Bitset(idx uint) ([]byte, error) {
-	if b.nextBit != b.sections {
+	if b.nextSec != b.sections {
 		return nil, errors.New("bloom not fully generated yet")
 	}
-	if idx >= b.sections {
-		return nil, errSectionOutOfBounds
+	if idx >= types.BloomBitLength {
+		return nil, errBloomBitOutOfBounds
 	}
 	return b.blooms[idx], nil
 }
diff --git a/ethdb/database_test.go b/ethdb/database_test.go
index 2deb50988..74675cbe6 100644
--- a/ethdb/database_test.go
+++ b/ethdb/database_test.go
@@ -59,6 +59,28 @@ func TestMemoryDB_PutGet(t *testing.T) {
 func testPutGet(db ethdb.Database, t *testing.T) {
 	t.Parallel()
 
+	for _, k := range test_values {
+		err := db.Put([]byte(k), nil)
+		if err != nil {
+			t.Fatalf("put failed: %v", err)
+		}
+	}
+
+	for _, k := range test_values {
+		data, err := db.Get([]byte(k))
+		if err != nil {
+			t.Fatalf("get failed: %v", err)
+		}
+		if len(data) != 0 {
+			t.Fatalf("get returned wrong result, got %q expected nil", string(data))
+		}
+	}
+
+	_, err := db.Get([]byte("non-exist-key"))
+	if err == nil {
+		t.Fatalf("expect to return a not found error")
+	}
+
 	for _, v := range test_values {
 		err := db.Put([]byte(v), []byte(v))
 		if err != nil {
diff --git a/ethdb/memory_database.go b/ethdb/memory_database.go
index f28ff5481..727f2f7ca 100644
--- a/ethdb/memory_database.go
+++ b/ethdb/memory_database.go
@@ -96,7 +96,10 @@ func (db *MemDatabase) NewBatch() Batch {
 
 func (db *MemDatabase) Len() int { return len(db.db) }
 
-type kv struct{ k, v []byte }
+type kv struct {
+	k, v []byte
+	del  bool
+}
 
 type memBatch struct {
 	db     *MemDatabase
@@ -105,13 +108,14 @@ type memBatch struct {
 }
 
 func (b *memBatch) Put(key, value []byte) error {
-	b.writes = append(b.writes, kv{common.CopyBytes(key), common.CopyBytes(value)})
+	b.writes = append(b.writes, kv{common.CopyBytes(key), common.CopyBytes(value), false})
 	b.size += len(value)
 	return nil
 }
 
 func (b *memBatch) Delete(key []byte) error {
-	b.writes = append(b.writes, kv{common.CopyBytes(key), nil})
+	b.writes = append(b.writes, kv{common.CopyBytes(key), nil, true})
+	b.size += 1
 	return nil
 }
 
@@ -120,7 +124,7 @@ func (b *memBatch) Write() error {
 	defer b.db.lock.Unlock()
 
 	for _, kv := range b.writes {
-		if kv.v == nil {
+		if kv.del {
 			delete(b.db.db, string(kv.k))
 			continue
 		}
-- 
GitLab