good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 0655e58f authored by Artem Tsebrovskiy's avatar Artem Tsebrovskiy Committed by GitHub
Browse files

#2119 - implemented trace_filter intersection mode for (#3167)

* contracts: implemented trace_filter intersection mode for Trace API (#2119)

* fixed formatting

* revisited error check during tx tracing
parent d14c2238
Branches
Tags
No related merge requests found
...@@ -7,14 +7,16 @@ import ( ...@@ -7,14 +7,16 @@ import (
jsoniter "github.com/json-iterator/go" jsoniter "github.com/json-iterator/go"
"github.com/ledgerwatch/erigon-lib/kv/kvcache" "github.com/ledgerwatch/erigon-lib/kv/kvcache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fastjson"
"github.com/ledgerwatch/erigon/cmd/rpcdaemon/cli" "github.com/ledgerwatch/erigon/cmd/rpcdaemon/cli"
"github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/common"
"github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/common/hexutil"
"github.com/ledgerwatch/erigon/core" "github.com/ledgerwatch/erigon/core"
"github.com/ledgerwatch/erigon/turbo/snapshotsync" "github.com/ledgerwatch/erigon/turbo/snapshotsync"
"github.com/ledgerwatch/erigon/turbo/stages" "github.com/ledgerwatch/erigon/turbo/stages"
"github.com/stretchr/testify/assert"
"github.com/valyala/fastjson"
) )
func blockNumbersFromTraces(t *testing.T, b []byte) []int { func blockNumbersFromTraces(t *testing.T, b []byte) []int {
...@@ -174,3 +176,76 @@ func TestFilterNoAddresses(t *testing.T) { ...@@ -174,3 +176,76 @@ func TestFilterNoAddresses(t *testing.T) {
} }
assert.Equal(t, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, blockNumbersFromTraces(t, buf.Bytes())) assert.Equal(t, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, blockNumbersFromTraces(t, buf.Bytes()))
} }
func TestFilterAddressIntersection(t *testing.T) {
m := stages.Mock(t)
defer m.DB.Close()
api := NewTraceAPI(NewBaseApi(nil, kvcache.New(kvcache.DefaultCoherentConfig), snapshotsync.NewBlockReader(), false), m.DB, &cli.Flags{})
toAddress1, fromAddress2, other := common.Address{1}, common.Address{2}, common.Address{3}
chain, err := core.GenerateChain(m.ChainConfig, m.Genesis, m.Engine, m.DB, 15, func(i int, gen *core.BlockGen) {
if i < 5 {
gen.SetCoinbase(toAddress1)
} else if i < 10 {
gen.SetCoinbase(fromAddress2)
} else {
gen.SetCoinbase(other)
}
}, false /* intemediateHashes */)
require.NoError(t, err, "generate chain")
err = m.InsertChain(chain)
require.NoError(t, err, "inserting chain")
fromBlock, toBlock := uint64(1), uint64(15)
t.Run("second", func(t *testing.T) {
stream := jsoniter.ConfigDefault.BorrowStream(nil)
defer jsoniter.ConfigDefault.ReturnStream(stream)
traceReq1 := TraceFilterRequest{
FromBlock: (*hexutil.Uint64)(&fromBlock),
ToBlock: (*hexutil.Uint64)(&toBlock),
FromAddress: []*common.Address{&fromAddress2, &other},
ToAddress: []*common.Address{&fromAddress2, &toAddress1},
Mode: TraceFilterModeIntersection,
}
if err = api.Filter(context.Background(), traceReq1, stream); err != nil {
t.Fatalf("trace_filter failed: %v", err)
}
assert.Equal(t, []int{6, 7, 8, 9, 10}, blockNumbersFromTraces(t, stream.Buffer()))
})
t.Run("first", func(t *testing.T) {
stream := jsoniter.ConfigDefault.BorrowStream(nil)
defer jsoniter.ConfigDefault.ReturnStream(stream)
traceReq1 := TraceFilterRequest{
FromBlock: (*hexutil.Uint64)(&fromBlock),
ToBlock: (*hexutil.Uint64)(&toBlock),
FromAddress: []*common.Address{&toAddress1, &other},
ToAddress: []*common.Address{&fromAddress2, &toAddress1},
Mode: TraceFilterModeIntersection,
}
if err = api.Filter(context.Background(), traceReq1, stream); err != nil {
t.Fatalf("trace_filter failed: %v", err)
}
assert.Equal(t, []int{1, 2, 3, 4, 5}, blockNumbersFromTraces(t, stream.Buffer()))
})
t.Run("empty", func(t *testing.T) {
stream := jsoniter.ConfigDefault.BorrowStream(nil)
defer jsoniter.ConfigDefault.ReturnStream(stream)
traceReq1 := TraceFilterRequest{
FromBlock: (*hexutil.Uint64)(&fromBlock),
ToBlock: (*hexutil.Uint64)(&toBlock),
ToAddress: []*common.Address{&other},
FromAddress: []*common.Address{&fromAddress2, &toAddress1},
Mode: TraceFilterModeIntersection,
}
if err = api.Filter(context.Background(), traceReq1, stream); err != nil {
t.Fatalf("trace_filter failed: %v", err)
}
require.Empty(t, blockNumbersFromTraces(t, stream.Buffer()))
})
}
...@@ -2,16 +2,19 @@ package commands ...@@ -2,16 +2,19 @@ package commands
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/RoaringBitmap/roaring/roaring64" "github.com/RoaringBitmap/roaring/roaring64"
jsoniter "github.com/json-iterator/go" jsoniter "github.com/json-iterator/go"
"github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/common"
"github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/common/hexutil"
"github.com/ledgerwatch/erigon/consensus/ethash" "github.com/ledgerwatch/erigon/consensus/ethash"
"github.com/ledgerwatch/erigon/core/rawdb" "github.com/ledgerwatch/erigon/core/rawdb"
"github.com/ledgerwatch/erigon/core/types" "github.com/ledgerwatch/erigon/core/types"
"github.com/ledgerwatch/erigon/ethdb"
"github.com/ledgerwatch/erigon/ethdb/bitmapdb" "github.com/ledgerwatch/erigon/ethdb/bitmapdb"
"github.com/ledgerwatch/erigon/rpc" "github.com/ledgerwatch/erigon/rpc"
) )
...@@ -230,10 +233,62 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str ...@@ -230,10 +233,62 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str
toAddresses := make(map[common.Address]struct{}, len(req.ToAddress)) toAddresses := make(map[common.Address]struct{}, len(req.ToAddress))
var allBlocks roaring64.Bitmap var allBlocks roaring64.Bitmap
switch req.Mode {
case TraceFilterModeIntersection:
if len(req.FromAddress) == 0 || len(req.ToAddress) == 0 {
return fmt.Errorf("invalid parameters: for intersection mode both fromAddress and toAddress should be not empty")
}
addrIntersection := make(map[common.Address]struct{})
for _, addr := range req.FromAddress {
if addr == nil {
continue
}
addrIntersection[*addr] = struct{}{}
}
for _, addr := range req.ToAddress {
if addr == nil {
continue
}
if _, exist := addrIntersection[*addr]; !exist {
continue
}
fromAddresses[*addr] = struct{}{}
toAddresses[*addr] = struct{}{}
b, err := bitmapdb.Get64(dbtx, kv.CallToIndex, addr.Bytes(), fromBlock, toBlock)
if err != nil && !errors.Is(err, ethdb.ErrKeyNotFound) {
stream.WriteNil()
return err
}
if b != nil {
allBlocks.Or(b)
}
b, err = bitmapdb.Get64(dbtx, kv.CallFromIndex, addr.Bytes(), fromBlock, toBlock)
if err != nil && !errors.Is(err, ethdb.ErrKeyNotFound) {
stream.WriteNil()
return err
}
if b != nil {
allBlocks.Or(b)
}
}
case TraceFilterModeUnion:
fallthrough
default:
for _, addr := range req.FromAddress { for _, addr := range req.FromAddress {
if addr != nil { if addr != nil {
b, err := bitmapdb.Get64(dbtx, kv.CallFromIndex, addr.Bytes(), fromBlock, toBlock) b, err := bitmapdb.Get64(dbtx, kv.CallFromIndex, addr.Bytes(), fromBlock, toBlock)
if err != nil { if err != nil {
if errors.Is(err, ethdb.ErrKeyNotFound) {
continue
}
stream.WriteNil() stream.WriteNil()
return err return err
} }
...@@ -245,6 +300,9 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str ...@@ -245,6 +300,9 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str
if addr != nil { if addr != nil {
b, err := bitmapdb.Get64(dbtx, kv.CallToIndex, addr.Bytes(), fromBlock, toBlock) b, err := bitmapdb.Get64(dbtx, kv.CallToIndex, addr.Bytes(), fromBlock, toBlock)
if err != nil { if err != nil {
if errors.Is(err, ethdb.ErrKeyNotFound) {
continue
}
stream.WriteNil() stream.WriteNil()
return err return err
} }
...@@ -252,6 +310,8 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str ...@@ -252,6 +310,8 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str
toAddresses[*addr] = struct{}{} toAddresses[*addr] = struct{}{}
} }
} }
}
// Special case - if no addresses specified, take all traces // Special case - if no addresses specified, take all traces
if len(req.FromAddress) == 0 && len(req.ToAddress) == 0 { if len(req.FromAddress) == 0 && len(req.ToAddress) == 0 {
allBlocks.AddRange(fromBlock, toBlock+1) allBlocks.AddRange(fromBlock, toBlock+1)
...@@ -473,6 +533,16 @@ type TraceFilterRequest struct { ...@@ -473,6 +533,16 @@ type TraceFilterRequest struct {
ToBlock *hexutil.Uint64 `json:"toBlock"` ToBlock *hexutil.Uint64 `json:"toBlock"`
FromAddress []*common.Address `json:"fromAddress"` FromAddress []*common.Address `json:"fromAddress"`
ToAddress []*common.Address `json:"toAddress"` ToAddress []*common.Address `json:"toAddress"`
Mode TraceFilterMode `json:"mode"`
After *uint64 `json:"after"` After *uint64 `json:"after"`
Count *uint64 `json:"count"` Count *uint64 `json:"count"`
} }
type TraceFilterMode string
const (
// Default mode for TraceFilter. Unions results referred to addresses from FromAddress or ToAddress
TraceFilterModeUnion = "union"
// IntersectionMode retrives results referred to addresses provided both in FromAddress and ToAddress
TraceFilterModeIntersection = "intersection"
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment