From 777882ff52e38200f0754ea43191ab0cbf5b35db Mon Sep 17 00:00:00 2001
From: b00ris <b00ris@mail.ru>
Date: Tue, 18 May 2021 21:32:04 +0300
Subject: [PATCH] Fix corner case for KV_Snapshot.Next (#1957)

* fix next

* comment test
---
 ethdb/kv_snapshot.go | 67 +++++++++++++++++++++++++++-----------------
 1 file changed, 41 insertions(+), 26 deletions(-)

diff --git a/ethdb/kv_snapshot.go b/ethdb/kv_snapshot.go
index 4fa7f12c4f..ec55f58afc 100644
--- a/ethdb/kv_snapshot.go
+++ b/ethdb/kv_snapshot.go
@@ -515,6 +515,7 @@ func (s *snCursor) SeekExact(key []byte) ([]byte, []byte, error) {
 
 func (s *snCursor) iteration(dbNextElement func() ([]byte, []byte, error), sndbNextElement func() ([]byte, []byte, error), cmpFunc func(kdb, ksndb []byte) (int, bool)) ([]byte, []byte, error) {
 	var err error
+	var noDBNext, noSnDBNext bool
 	//current returns error on empty bucket
 	lastDBKey, lastDBVal, err := s.dbCursor.Current()
 	if err != nil {
@@ -523,11 +524,17 @@ func (s *snCursor) iteration(dbNextElement func() ([]byte, []byte, error), sndbN
 		if innerErr != nil {
 			return nil, nil, fmt.Errorf("get current from db %w inner %v", err, innerErr)
 		}
+		noDBNext = true
 	}
 
 	lastSNDBKey, lastSNDBVal, err := s.snCursor.Current()
 	if err != nil {
-		return nil, nil, err
+		var innerErr error
+		lastSNDBKey, lastSNDBVal, innerErr = sndbNextElement()
+		if innerErr != nil {
+			return nil, nil, fmt.Errorf("get current from snapshot %w inner %v", err, innerErr)
+		}
+		noSnDBNext = true
 	}
 
 	cmp, br := cmpFunc(lastDBKey, lastSNDBKey)
@@ -537,40 +544,48 @@ func (s *snCursor) iteration(dbNextElement func() ([]byte, []byte, error), sndbN
 
 	//todo Seek fastpath
 	if cmp > 0 {
-		lastSNDBKey, lastSNDBVal, err = sndbNextElement()
-		if err != nil {
-			return nil, nil, err
-		}
-		//todo
-		if currentKeyCmp, _ := common.KeyCmp(s.currentKey, lastDBKey); len(lastSNDBKey) == 0 && currentKeyCmp >= 0 && len(s.currentKey) > 0 {
-			lastDBKey, lastDBVal, err = dbNextElement()
-		}
-		if err != nil {
-			return nil, nil, err
+		if !noSnDBNext {
+			lastSNDBKey, lastSNDBVal, err = sndbNextElement()
+			if err != nil {
+				return nil, nil, err
+			}
+
+			if currentKeyCmp, _ := common.KeyCmp(s.currentKey, lastDBKey); len(lastSNDBKey) == 0 && currentKeyCmp >= 0 && len(s.currentKey) > 0 {
+				lastDBKey, lastDBVal, err = dbNextElement()
+			}
+			if err != nil {
+				return nil, nil, err
+			}
 		}
 	}
 
 	//current receives last acceptable key. If it is empty
 	if cmp < 0 {
-		lastDBKey, lastDBVal, err = dbNextElement()
-		if err != nil {
-			return nil, nil, err
-		}
-		if currentKeyCmp, _ := common.KeyCmp(s.currentKey, lastSNDBKey); len(lastDBKey) == 0 && currentKeyCmp >= 0 && len(s.currentKey) > 0 {
-			lastSNDBKey, lastSNDBVal, err = sndbNextElement()
-		}
-		if err != nil {
-			return nil, nil, err
+		if !noDBNext {
+			lastDBKey, lastDBVal, err = dbNextElement()
+			if err != nil {
+				return nil, nil, err
+			}
+			if currentKeyCmp, _ := common.KeyCmp(s.currentKey, lastSNDBKey); len(lastDBKey) == 0 && currentKeyCmp >= 0 && len(s.currentKey) > 0 {
+				lastSNDBKey, lastSNDBVal, err = sndbNextElement()
+			}
+			if err != nil {
+				return nil, nil, err
+			}
 		}
 	}
 	if cmp == 0 {
-		lastDBKey, lastDBVal, err = dbNextElement()
-		if err != nil {
-			return nil, nil, err
+		if !noDBNext {
+			lastDBKey, lastDBVal, err = dbNextElement()
+			if err != nil {
+				return nil, nil, err
+			}
 		}
-		lastSNDBKey, lastSNDBVal, err = sndbNextElement()
-		if err != nil {
-			return nil, nil, err
+		if !noSnDBNext {
+			lastSNDBKey, lastSNDBVal, err = sndbNextElement()
+			if err != nil {
+				return nil, nil, err
+			}
 		}
 	}
 
-- 
GitLab