From cbd3ae6906ece36b1b3e5e7af4d7cb55e784818a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Thu, 21 May 2015 19:41:46 +0300
Subject: [PATCH] p2p/discover: fix #838, evacuate self entries from the node
 db

---
 p2p/discover/database.go      | 47 +++++++++++++++-------
 p2p/discover/database_test.go | 73 +++++++++++++++++++++++++++++++----
 p2p/discover/table.go         |  4 +-
 3 files changed, 99 insertions(+), 25 deletions(-)

diff --git a/p2p/discover/database.go b/p2p/discover/database.go
index 2b9da0423..3a3f1254b 100644
--- a/p2p/discover/database.go
+++ b/p2p/discover/database.go
@@ -33,6 +33,8 @@ type nodeDB struct {
 	lvl    *leveldb.DB       // Interface to the database itself
 	seeder iterator.Iterator // Iterator for fetching possible seed nodes
 
+	self NodeID // Own node id to prevent adding it into the database
+
 	runner sync.Once     // Ensures we can start at most one expirer
 	quit   chan struct{} // Channel to signal the expiring thread to stop
 }
@@ -50,29 +52,30 @@ var (
 // newNodeDB creates a new node database for storing and retrieving infos about
 // known peers in the network. If no path is given, an in-memory, temporary
 // database is constructed.
-func newNodeDB(path string, version int) (*nodeDB, error) {
+func newNodeDB(path string, version int, self NodeID) (*nodeDB, error) {
 	if path == "" {
-		return newMemoryNodeDB()
+		return newMemoryNodeDB(self)
 	}
-	return newPersistentNodeDB(path, version)
+	return newPersistentNodeDB(path, version, self)
 }
 
 // newMemoryNodeDB creates a new in-memory node database without a persistent
 // backend.
-func newMemoryNodeDB() (*nodeDB, error) {
+func newMemoryNodeDB(self NodeID) (*nodeDB, error) {
 	db, err := leveldb.Open(storage.NewMemStorage(), nil)
 	if err != nil {
 		return nil, err
 	}
 	return &nodeDB{
 		lvl:  db,
+		self: self,
 		quit: make(chan struct{}),
 	}, nil
 }
 
 // newPersistentNodeDB creates/opens a leveldb backed persistent node database,
 // also flushing its contents in case of a version mismatch.
-func newPersistentNodeDB(path string, version int) (*nodeDB, error) {
+func newPersistentNodeDB(path string, version int, self NodeID) (*nodeDB, error) {
 	opts := &opt.Options{OpenFilesCacheCapacity: 5}
 	db, err := leveldb.OpenFile(path, opts)
 	if _, iscorrupted := err.(*errors.ErrCorrupted); iscorrupted {
@@ -102,11 +105,12 @@ func newPersistentNodeDB(path string, version int) (*nodeDB, error) {
 			if err = os.RemoveAll(path); err != nil {
 				return nil, err
 			}
-			return newPersistentNodeDB(path, version)
+			return newPersistentNodeDB(path, version, self)
 		}
 	}
 	return &nodeDB{
 		lvl:  db,
+		self: self,
 		quit: make(chan struct{}),
 	}, nil
 }
@@ -182,6 +186,17 @@ func (db *nodeDB) updateNode(node *Node) error {
 	return db.lvl.Put(makeKey(node.ID, nodeDBDiscoverRoot), blob, nil)
 }
 
+// deleteNode deletes all information/keys associated with a node.
+func (db *nodeDB) deleteNode(id NodeID) error {
+	deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil)
+	for deleter.Next() {
+		if err := db.lvl.Delete(deleter.Key(), nil); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // ensureExpirer is a small helper method ensuring that the data expiration
 // mechanism is running. If the expiration goroutine is already running, this
 // method simply returns.
@@ -227,17 +242,14 @@ func (db *nodeDB) expireNodes() error {
 		if field != nodeDBDiscoverRoot {
 			continue
 		}
-		// Skip the node if not expired yet
-		if seen := db.lastPong(id); seen.After(threshold) {
-			continue
-		}
-		// Otherwise delete all associated information
-		deleter := db.lvl.NewIterator(util.BytesPrefix(makeKey(id, "")), nil)
-		for deleter.Next() {
-			if err := db.lvl.Delete(deleter.Key(), nil); err != nil {
-				return err
+		// Skip the node if not expired yet (and not self)
+		if bytes.Compare(id[:], db.self[:]) != 0 {
+			if seen := db.lastPong(id); seen.After(threshold) {
+				continue
 			}
 		}
+		// Otherwise delete all associated information
+		db.deleteNode(id)
 	}
 	return nil
 }
@@ -286,6 +298,11 @@ func (db *nodeDB) querySeeds(n int) []*Node {
 		if field != nodeDBDiscoverRoot {
 			continue
 		}
+		// Dump it if its a self reference
+		if bytes.Compare(id[:], db.self[:]) == 0 {
+			db.deleteNode(id)
+			continue
+		}
 		// Load it as a potential seed
 		if node := db.node(id); node != nil {
 			nodes = append(nodes, node)
diff --git a/p2p/discover/database_test.go b/p2p/discover/database_test.go
index 9c543cd5f..88f5d2155 100644
--- a/p2p/discover/database_test.go
+++ b/p2p/discover/database_test.go
@@ -63,7 +63,7 @@ var nodeDBInt64Tests = []struct {
 }
 
 func TestNodeDBInt64(t *testing.T) {
-	db, _ := newNodeDB("", Version)
+	db, _ := newNodeDB("", Version, NodeID{})
 	defer db.close()
 
 	tests := nodeDBInt64Tests
@@ -94,7 +94,7 @@ func TestNodeDBFetchStore(t *testing.T) {
 	)
 	inst := time.Now()
 
-	db, _ := newNodeDB("", Version)
+	db, _ := newNodeDB("", Version, NodeID{})
 	defer db.close()
 
 	// Check fetch/store operations on a node ping object
@@ -165,7 +165,7 @@ var nodeDBSeedQueryNodes = []struct {
 }
 
 func TestNodeDBSeedQuery(t *testing.T) {
-	db, _ := newNodeDB("", Version)
+	db, _ := newNodeDB("", Version, NodeID{})
 	defer db.close()
 
 	// Insert a batch of nodes for querying
@@ -205,7 +205,7 @@ func TestNodeDBSeedQuery(t *testing.T) {
 }
 
 func TestNodeDBSeedQueryContinuation(t *testing.T) {
-	db, _ := newNodeDB("", Version)
+	db, _ := newNodeDB("", Version, NodeID{})
 	defer db.close()
 
 	// Insert a batch of nodes for querying
@@ -230,6 +230,32 @@ func TestNodeDBSeedQueryContinuation(t *testing.T) {
 	}
 }
 
+func TestNodeDBSelfSeedQuery(t *testing.T) {
+	// Assign a node as self to verify evacuation
+	self := nodeDBSeedQueryNodes[0].node.ID
+	db, _ := newNodeDB("", Version, self)
+	defer db.close()
+
+	// Insert a batch of nodes for querying
+	for i, seed := range nodeDBSeedQueryNodes {
+		if err := db.updateNode(seed.node); err != nil {
+			t.Fatalf("node %d: failed to insert: %v", i, err)
+		}
+	}
+	// Retrieve the entire batch and check that self was evacuated
+	seeds := db.querySeeds(2 * len(nodeDBSeedQueryNodes))
+	if len(seeds) != len(nodeDBSeedQueryNodes)-1 {
+		t.Errorf("seed count mismatch: have %v, want %v", len(seeds), len(nodeDBSeedQueryNodes)-1)
+	}
+	have := make(map[NodeID]struct{})
+	for _, seed := range seeds {
+		have[seed.ID] = struct{}{}
+	}
+	if _, ok := have[self]; ok {
+		t.Errorf("self not evacuated")
+	}
+}
+
 func TestNodeDBPersistency(t *testing.T) {
 	root, err := ioutil.TempDir("", "nodedb-")
 	if err != nil {
@@ -243,7 +269,7 @@ func TestNodeDBPersistency(t *testing.T) {
 	)
 
 	// Create a persistent database and store some values
-	db, err := newNodeDB(filepath.Join(root, "database"), Version)
+	db, err := newNodeDB(filepath.Join(root, "database"), Version, NodeID{})
 	if err != nil {
 		t.Fatalf("failed to create persistent database: %v", err)
 	}
@@ -253,7 +279,7 @@ func TestNodeDBPersistency(t *testing.T) {
 	db.close()
 
 	// Reopen the database and check the value
-	db, err = newNodeDB(filepath.Join(root, "database"), Version)
+	db, err = newNodeDB(filepath.Join(root, "database"), Version, NodeID{})
 	if err != nil {
 		t.Fatalf("failed to open persistent database: %v", err)
 	}
@@ -263,7 +289,7 @@ func TestNodeDBPersistency(t *testing.T) {
 	db.close()
 
 	// Change the database version and check flush
-	db, err = newNodeDB(filepath.Join(root, "database"), Version+1)
+	db, err = newNodeDB(filepath.Join(root, "database"), Version+1, NodeID{})
 	if err != nil {
 		t.Fatalf("failed to open persistent database: %v", err)
 	}
@@ -300,7 +326,7 @@ var nodeDBExpirationNodes = []struct {
 }
 
 func TestNodeDBExpiration(t *testing.T) {
-	db, _ := newNodeDB("", Version)
+	db, _ := newNodeDB("", Version, NodeID{})
 	defer db.close()
 
 	// Add all the test nodes and set their last pong time
@@ -323,3 +349,34 @@ func TestNodeDBExpiration(t *testing.T) {
 		}
 	}
 }
+
+func TestNodeDBSelfExpiration(t *testing.T) {
+	// Find a node in the tests that shouldn't expire, and assign it as self
+	var self NodeID
+	for _, node := range nodeDBExpirationNodes {
+		if !node.exp {
+			self = node.node.ID
+			break
+		}
+	}
+	db, _ := newNodeDB("", Version, self)
+	defer db.close()
+
+	// Add all the test nodes and set their last pong time
+	for i, seed := range nodeDBExpirationNodes {
+		if err := db.updateNode(seed.node); err != nil {
+			t.Fatalf("node %d: failed to insert: %v", i, err)
+		}
+		if err := db.updateLastPong(seed.node.ID, seed.pong); err != nil {
+			t.Fatalf("node %d: failed to update pong: %v", i, err)
+		}
+	}
+	// Expire the nodes and make sure self has been evacuated too
+	if err := db.expireNodes(); err != nil {
+		t.Fatalf("failed to expire nodes: %v", err)
+	}
+	node := db.node(self)
+	if node != nil {
+		t.Errorf("self not evacuated")
+	}
+}
diff --git a/p2p/discover/table.go b/p2p/discover/table.go
index 5e6dd8d0d..91d617f05 100644
--- a/p2p/discover/table.go
+++ b/p2p/discover/table.go
@@ -68,10 +68,10 @@ type bucket struct {
 
 func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr, nodeDBPath string) *Table {
 	// If no node database was given, use an in-memory one
-	db, err := newNodeDB(nodeDBPath, Version)
+	db, err := newNodeDB(nodeDBPath, Version, ourID)
 	if err != nil {
 		glog.V(logger.Warn).Infoln("Failed to open node database:", err)
-		db, _ = newNodeDB("", Version)
+		db, _ = newNodeDB("", Version, ourID)
 	}
 	tab := &Table{
 		net:       t,
-- 
GitLab