From b797dd07d2f2cf0868d4fe79e120d5cf0b8fdc0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jano=C5=A1=20Gulja=C5=A1?= <janos@users.noreply.github.com>
Date: Sat, 2 Mar 2019 08:44:22 +0100
Subject: [PATCH] swarm/shed, swarm/storage/localstore: add
 LastPullSubscriptionChunk (#19190)

* swarm/shed, swarm/storage/localstore: add LastPullSubscriptionChunk

* swarm/shed: fix comments

* swarm/shed: fix TestIncByteSlice test

* swarm/storage/localstore: fix TestDB_LastPullSubscriptionChunk
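
The new public surface, summarized here for reviewers (signatures copied from
the diffs below):

    // swarm/shed
    func (f Index) First(prefix []byte) (i Item, err error)
    func (f Index) Last(prefix []byte) (i Item, err error)

    // swarm/storage/localstore
    func (db *DB) LastPullSubscriptionChunk(bin uint8) (c *ChunkDescriptor, err error)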
---
 swarm/shed/db_test.go                         |   8 +-
 swarm/shed/index.go                           |  99 ++++++++++--
 swarm/shed/index_test.go                      | 146 ++++++++++++++++++
 swarm/storage/localstore/subscription_pull.go |  18 +++
 .../localstore/subscription_pull_test.go      |  72 +++++++++
 5 files changed, 328 insertions(+), 15 deletions(-)

diff --git a/swarm/shed/db_test.go b/swarm/shed/db_test.go
index 65fdac4a6..4e8276f74 100644
--- a/swarm/shed/db_test.go
+++ b/swarm/shed/db_test.go
@@ -100,11 +100,13 @@ func newTestDB(t *testing.T) (db *DB, cleanupFunc func()) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	cleanupFunc = func() { os.RemoveAll(dir) }
 	db, err = NewDB(dir, "")
 	if err != nil {
-		cleanupFunc()
+		os.RemoveAll(dir)
 		t.Fatal(err)
 	}
-	return db, cleanupFunc
+	return db, func() {
+		db.Close()
+		os.RemoveAll(dir)
+	}
 }
diff --git a/swarm/shed/index.go b/swarm/shed/index.go
index df88b1b62..d02bf1a00 100644
--- a/swarm/shed/index.go
+++ b/swarm/shed/index.go
@@ -20,6 +20,7 @@ import (
 	"bytes"
 
 	"github.com/syndtr/goleveldb/leveldb"
+	"github.com/syndtr/goleveldb/leveldb/iterator"
 )
 
 // Item holds fields relevant to Swarm Chunk data and metadata.
@@ -245,21 +246,14 @@ func (f Index) Iterate(fn IndexIterFunc, options *IterateOptions) (err error) {
 		ok = it.Next()
 	}
 	for ; ok; ok = it.Next() {
-		key := it.Key()
-		if !bytes.HasPrefix(key, prefix) {
-			break
-		}
-		// create a copy of key byte slice not to share leveldb underlaying slice array
-		keyItem, err := f.decodeKeyFunc(append([]byte(nil), key...))
-		if err != nil {
-			return err
-		}
-		// create a copy of value byte slice not to share leveldb underlaying slice array
-		valueItem, err := f.decodeValueFunc(keyItem, append([]byte(nil), it.Value()...))
+		item, err := f.itemFromIterator(it, prefix)
 		if err != nil {
+			if err == leveldb.ErrNotFound {
+				break
+			}
 			return err
 		}
-		stop, err := fn(keyItem.Merge(valueItem))
+		stop, err := fn(item)
 		if err != nil {
 			return err
 		}
@@ -270,6 +264,87 @@ func (f Index) Iterate(fn IndexIterFunc, options *IterateOptions) (err error) {
 	return it.Error()
 }
 
+// First returns the first item in the Index whose encoded key starts with the prefix.
+// If the prefix is nil, the first element of the whole index is returned.
+// If no element with the prefix exists, a leveldb.ErrNotFound error is returned.
+func (f Index) First(prefix []byte) (i Item, err error) {
+	it := f.db.NewIterator()
+	defer it.Release()
+
+	totalPrefix := append(f.prefix, prefix...)
+	it.Seek(totalPrefix)
+
+	return f.itemFromIterator(it, totalPrefix)
+}
+
+// itemFromIterator returns the Item from the current iterator position.
+// If the complete encoded key does not start with totalPrefix,
+// leveldb.ErrNotFound is returned. The totalPrefix value must
+// start with the Index prefix.
+func (f Index) itemFromIterator(it iterator.Iterator, totalPrefix []byte) (i Item, err error) {
+	key := it.Key()
+	if !bytes.HasPrefix(key, totalPrefix) {
+		return i, leveldb.ErrNotFound
+	}
+	// create a copy of the key byte slice so as not to share the underlying leveldb slice array
+	keyItem, err := f.decodeKeyFunc(append([]byte(nil), key...))
+	if err != nil {
+		return i, err
+	}
+	// create a copy of the value byte slice so as not to share the underlying leveldb slice array
+	valueItem, err := f.decodeValueFunc(keyItem, append([]byte(nil), it.Value()...))
+	if err != nil {
+		return i, err
+	}
+	return keyItem.Merge(valueItem), it.Error()
+}
+
+// Last returns the last item in the Index whose encoded key starts with the prefix.
+// If the prefix is nil, the last element of the whole index is returned.
+// If no element with the prefix exists, a leveldb.ErrNotFound error is returned.
+func (f Index) Last(prefix []byte) (i Item, err error) {
+	it := f.db.NewIterator()
+	defer it.Release()
+
+	// seek to the next prefix in line, since the leveldb
+	// iterator Seek positions at the next key when the key
+	// it seeks to is not found; stepping back to the
+	// previous key then yields the last key for the
+	// actual prefix
+	nextPrefix := incByteSlice(prefix)
+	l := len(prefix)
+
+	if l > 0 && nextPrefix != nil {
+		it.Seek(append(f.prefix, nextPrefix...))
+		it.Prev()
+	} else {
+		it.Last()
+	}
+
+	totalPrefix := append(f.prefix, prefix...)
+	return f.itemFromIterator(it, totalPrefix)
+}
+
+// incByteSlice returns a byte slice of the same size
+// as the provided one, with its total value incremented
+// by one. If all bytes in the provided slice are equal
+// to 255, a nil slice is returned, indicating that the
+// increment cannot be represented in the same length.
+func incByteSlice(b []byte) (next []byte) {
+	l := len(b)
+	next = make([]byte, l)
+	copy(next, b)
+	for i := l - 1; i >= 0; i-- {
+		if b[i] == 255 {
+			next[i] = 0
+		} else {
+			next[i] = b[i] + 1
+			return next
+		}
+	}
+	return nil
+}
+
 // Count returns the number of items in index.
 func (f Index) Count() (count int, err error) {
 	it := f.db.NewIterator()
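
As a usage illustration of the new First and Last methods, here is a minimal
caller-side sketch. The package and helper names are hypothetical and the
shed.Index is assumed to be already constructed and populated; only First,
Last, Item.Address and leveldb.ErrNotFound come from the code above.

    package shedexample

    import (
        "github.com/ethereum/go-ethereum/swarm/shed"
        "github.com/syndtr/goleveldb/leveldb"
    )

    // boundsForPrefix returns the addresses of the first and the last item in
    // the index whose encoded key starts with the given prefix, or nil values
    // when no item with that prefix exists.
    func boundsForPrefix(index shed.Index, prefix []byte) (first, last []byte, err error) {
        f, err := index.First(prefix)
        if err != nil {
            if err == leveldb.ErrNotFound {
                // no item with this prefix is stored in the index
                return nil, nil, nil
            }
            return nil, nil, err
        }
        l, err := index.Last(prefix)
        if err != nil {
            return nil, nil, err
        }
        return f.Address, l.Address, nil
    }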
diff --git a/swarm/shed/index_test.go b/swarm/shed/index_test.go
index 97d7c91f4..489f001bd 100644
--- a/swarm/shed/index_test.go
+++ b/swarm/shed/index_test.go
@@ -779,3 +779,149 @@ func checkItem(t *testing.T, got, want Item) {
 		t.Errorf("got access timestamp %v, expected %v", got.AccessTimestamp, want.AccessTimestamp)
 	}
 }
+
+// TestIndex_firstAndLast validates that the Index First and Last methods
+// return the expected results for the provided prefix.
+func TestIndex_firstAndLast(t *testing.T) {
+	db, cleanupFunc := newTestDB(t)
+	defer cleanupFunc()
+
+	index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	addrs := [][]byte{
+		{0, 0, 0, 0, 0},
+		{0, 1},
+		{0, 1, 0, 0, 0},
+		{0, 1, 0, 0, 1},
+		{0, 1, 0, 0, 2},
+		{0, 2, 0, 0, 1},
+		{0, 4, 0, 0, 0},
+		{0, 10, 0, 0, 10},
+		{0, 10, 0, 0, 11},
+		{0, 10, 0, 0, 20},
+		{1, 32, 255, 0, 1},
+		{1, 32, 255, 0, 2},
+		{1, 32, 255, 0, 3},
+		{255, 255, 255, 255, 32},
+		{255, 255, 255, 255, 64},
+		{255, 255, 255, 255, 255},
+	}
+
+	// ensure that the addresses are sorted for
+	// validation of the nil prefix case
+	sort.Slice(addrs, func(i, j int) (less bool) {
+		return bytes.Compare(addrs[i], addrs[j]) == -1
+	})
+
+	batch := new(leveldb.Batch)
+	for _, addr := range addrs {
+		index.PutInBatch(batch, Item{
+			Address: addr,
+		})
+	}
+	err = db.WriteBatch(batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, tc := range []struct {
+		prefix []byte
+		first  []byte
+		last   []byte
+		err    error
+	}{
+		{
+			prefix: nil,
+			first:  addrs[0],
+			last:   addrs[len(addrs)-1],
+		},
+		{
+			prefix: []byte{0, 0},
+			first:  []byte{0, 0, 0, 0, 0},
+			last:   []byte{0, 0, 0, 0, 0},
+		},
+		{
+			prefix: []byte{0},
+			first:  []byte{0, 0, 0, 0, 0},
+			last:   []byte{0, 10, 0, 0, 20},
+		},
+		{
+			prefix: []byte{0, 1},
+			first:  []byte{0, 1},
+			last:   []byte{0, 1, 0, 0, 2},
+		},
+		{
+			prefix: []byte{0, 10},
+			first:  []byte{0, 10, 0, 0, 10},
+			last:   []byte{0, 10, 0, 0, 20},
+		},
+		{
+			prefix: []byte{1, 32, 255},
+			first:  []byte{1, 32, 255, 0, 1},
+			last:   []byte{1, 32, 255, 0, 3},
+		},
+		{
+			prefix: []byte{255},
+			first:  []byte{255, 255, 255, 255, 32},
+			last:   []byte{255, 255, 255, 255, 255},
+		},
+		{
+			prefix: []byte{255, 255, 255, 255, 255},
+			first:  []byte{255, 255, 255, 255, 255},
+			last:   []byte{255, 255, 255, 255, 255},
+		},
+		{
+			prefix: []byte{0, 3},
+			err:    leveldb.ErrNotFound,
+		},
+		{
+			prefix: []byte{222},
+			err:    leveldb.ErrNotFound,
+		},
+	} {
+		got, err := index.Last(tc.prefix)
+		if tc.err != err {
+			t.Errorf("got error %v for Last with prefix %v, want %v", err, tc.prefix, tc.err)
+		} else {
+			if !bytes.Equal(got.Address, tc.last) {
+				t.Errorf("got %v for Last with prefix %v, want %v", got.Address, tc.prefix, tc.last)
+			}
+		}
+
+		got, err = index.First(tc.prefix)
+		if tc.err != err {
+			t.Errorf("got error %v for First with prefix %v, want %v", err, tc.prefix, tc.err)
+		} else {
+			if !bytes.Equal(got.Address, tc.first) {
+				t.Errorf("got %v for First with prefix %v, want %v", got.Address, tc.prefix, tc.first)
+			}
+		}
+	}
+}
+
+// TestIncByteSlice validates returned values of incByteSlice function.
+func TestIncByteSlice(t *testing.T) {
+	for _, tc := range []struct {
+		b    []byte
+		want []byte
+	}{
+		{b: nil, want: nil},
+		{b: []byte{}, want: nil},
+		{b: []byte{0}, want: []byte{1}},
+		{b: []byte{42}, want: []byte{43}},
+		{b: []byte{255}, want: nil},
+		{b: []byte{0, 0}, want: []byte{0, 1}},
+		{b: []byte{1, 0}, want: []byte{1, 1}},
+		{b: []byte{1, 255}, want: []byte{2, 0}},
+		{b: []byte{255, 255}, want: nil},
+		{b: []byte{32, 0, 255}, want: []byte{32, 1, 0}},
+	} {
+		got := incByteSlice(tc.b)
+		if !bytes.Equal(got, tc.want) {
+			t.Errorf("got %v, want %v", got, tc.want)
+		}
+	}
+}
diff --git a/swarm/storage/localstore/subscription_pull.go b/swarm/storage/localstore/subscription_pull.go
index 0830eee70..0b96102e3 100644
--- a/swarm/storage/localstore/subscription_pull.go
+++ b/swarm/storage/localstore/subscription_pull.go
@@ -26,6 +26,7 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/swarm/chunk"
 	"github.com/ethereum/go-ethereum/swarm/shed"
+	"github.com/syndtr/goleveldb/leveldb"
 )
 
 // SubscribePull returns a channel that provides chunk addresses and stored times from pull syncing index.
@@ -158,6 +159,23 @@ func (db *DB) SubscribePull(ctx context.Context, bin uint8, since, until *ChunkD
 	return chunkDescriptors, stop
 }
 
+// LastPullSubscriptionChunk returns the ChunkDescriptor of the latest Chunk
+// in the pull syncing index for the provided bin. If there are no chunks in
+// that bin, chunk.ErrChunkNotFound is returned.
+func (db *DB) LastPullSubscriptionChunk(bin uint8) (c *ChunkDescriptor, err error) {
+	item, err := db.pullIndex.Last([]byte{bin})
+	if err != nil {
+		if err == leveldb.ErrNotFound {
+			return nil, chunk.ErrChunkNotFound
+		}
+		return nil, err
+	}
+	return &ChunkDescriptor{
+		Address:        item.Address,
+		StoreTimestamp: item.StoreTimestamp,
+	}, nil
+}
+
 // ChunkDescriptor holds information required for Pull syncing. This struct
 // is provided by subscribing to pull index.
 type ChunkDescriptor struct {
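
To illustrate the intended use of LastPullSubscriptionChunk, here is a minimal
sketch that collects a per-bin cursor. The package and helper names are
hypothetical; DB, ChunkDescriptor, chunk.MaxPO and chunk.ErrChunkNotFound are
the identifiers used in this patch.

    package localstoreexample

    import (
        "github.com/ethereum/go-ethereum/swarm/chunk"
        "github.com/ethereum/go-ethereum/swarm/storage/localstore"
    )

    // binCursors collects, for every proximity order bin, the descriptor of
    // the chunk most recently stored in the pull index. Bins that hold no
    // chunks yet are skipped.
    func binCursors(db *localstore.DB) (map[uint8]localstore.ChunkDescriptor, error) {
        cursors := make(map[uint8]localstore.ChunkDescriptor)
        for bin := uint8(0); bin <= uint8(chunk.MaxPO); bin++ {
            c, err := db.LastPullSubscriptionChunk(bin)
            if err != nil {
                if err == chunk.ErrChunkNotFound {
                    // no chunks are stored in this bin yet
                    continue
                }
                return nil, err
            }
            cursors[bin] = *c
        }
        return cursors, nil
    }

A descriptor obtained this way can, for example, serve as the until bound for
SubscribePull on the same bin.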
diff --git a/swarm/storage/localstore/subscription_pull_test.go b/swarm/storage/localstore/subscription_pull_test.go
index 130f0c9fe..d5ddae02b 100644
--- a/swarm/storage/localstore/subscription_pull_test.go
+++ b/swarm/storage/localstore/subscription_pull_test.go
@@ -485,3 +485,75 @@ func checkErrChan(ctx context.Context, t *testing.T, errChan chan error, wantedC
 		}
 	}
 }
+
+// TestDB_LastPullSubscriptionChunk validates that LastPullSubscriptionChunk
+// returns the last chunk descriptor for each proximity order bin
+// across a few rounds of chunk uploads.
+func TestDB_LastPullSubscriptionChunk(t *testing.T) {
+	db, cleanupFunc := newTestDB(t, nil)
+	defer cleanupFunc()
+
+	uploader := db.NewPutter(ModePutUpload)
+
+	addrs := make(map[uint8][]chunk.Address)
+
+	lastTimestamp := time.Now().UTC().UnixNano()
+	var lastTimestampMu sync.RWMutex
+	defer setNow(func() (t int64) {
+		lastTimestampMu.Lock()
+		defer lastTimestampMu.Unlock()
+		lastTimestamp++
+		return lastTimestamp
+	})()
+
+	last := make(map[uint8]ChunkDescriptor)
+
+	// do a few rounds of uploads and check if
+	// last pull subscription chunk is correct
+	for _, count := range []int{1, 3, 10, 11, 100, 120} {
+
+		// upload
+		for i := 0; i < count; i++ {
+			ch := generateTestRandomChunk()
+
+			err := uploader.Put(ch)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			bin := db.po(ch.Address())
+
+			if _, ok := addrs[bin]; !ok {
+				addrs[bin] = make([]chunk.Address, 0)
+			}
+			addrs[bin] = append(addrs[bin], ch.Address())
+
+			lastTimestampMu.RLock()
+			storeTimestamp := lastTimestamp
+			lastTimestampMu.RUnlock()
+
+			last[bin] = ChunkDescriptor{
+				Address:        ch.Address(),
+				StoreTimestamp: storeTimestamp,
+			}
+		}
+
+		// check
+		for bin := uint8(0); bin <= uint8(chunk.MaxPO); bin++ {
+			want, ok := last[bin]
+			got, err := db.LastPullSubscriptionChunk(bin)
+			if ok {
+				if err != nil {
+					t.Errorf("got unexpected error value %v", err)
+				}
+				if !bytes.Equal(got.Address, want.Address) {
+					t.Errorf("got last address %s, want %s", got.Address.Hex(), want.Address.Hex())
+				}
+			} else {
+				if err != chunk.ErrChunkNotFound {
+					t.Errorf("got unexpected error value %v, want %v", err, chunk.ErrChunkNotFound)
+				}
+			}
+		}
+	}
+}
-- 
GitLab