diff --git a/cmd/rpcdaemon/commands/call_traces_test.go b/cmd/rpcdaemon/commands/call_traces_test.go index a7a63cf7f83e2efdf3631608d14ad2858d9f4608..a658c26339bf9d9f2852cd1a8adde6f967d63b70 100644 --- a/cmd/rpcdaemon/commands/call_traces_test.go +++ b/cmd/rpcdaemon/commands/call_traces_test.go @@ -7,14 +7,16 @@ import ( jsoniter "github.com/json-iterator/go" "github.com/ledgerwatch/erigon-lib/kv/kvcache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" + "github.com/ledgerwatch/erigon/cmd/rpcdaemon/cli" "github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/core" "github.com/ledgerwatch/erigon/turbo/snapshotsync" "github.com/ledgerwatch/erigon/turbo/stages" - "github.com/stretchr/testify/assert" - "github.com/valyala/fastjson" ) func blockNumbersFromTraces(t *testing.T, b []byte) []int { @@ -174,3 +176,76 @@ func TestFilterNoAddresses(t *testing.T) { } assert.Equal(t, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, blockNumbersFromTraces(t, buf.Bytes())) } + +func TestFilterAddressIntersection(t *testing.T) { + m := stages.Mock(t) + defer m.DB.Close() + + api := NewTraceAPI(NewBaseApi(nil, kvcache.New(kvcache.DefaultCoherentConfig), snapshotsync.NewBlockReader(), false), m.DB, &cli.Flags{}) + + toAddress1, fromAddress2, other := common.Address{1}, common.Address{2}, common.Address{3} + chain, err := core.GenerateChain(m.ChainConfig, m.Genesis, m.Engine, m.DB, 15, func(i int, gen *core.BlockGen) { + if i < 5 { + gen.SetCoinbase(toAddress1) + } else if i < 10 { + gen.SetCoinbase(fromAddress2) + } else { + gen.SetCoinbase(other) + } + + }, false /* intemediateHashes */) + require.NoError(t, err, "generate chain") + + err = m.InsertChain(chain) + require.NoError(t, err, "inserting chain") + + fromBlock, toBlock := uint64(1), uint64(15) + t.Run("second", func(t *testing.T) { + stream := jsoniter.ConfigDefault.BorrowStream(nil) + defer jsoniter.ConfigDefault.ReturnStream(stream) + + traceReq1 := TraceFilterRequest{ + FromBlock: (*hexutil.Uint64)(&fromBlock), + ToBlock: (*hexutil.Uint64)(&toBlock), + FromAddress: []*common.Address{&fromAddress2, &other}, + ToAddress: []*common.Address{&fromAddress2, &toAddress1}, + Mode: TraceFilterModeIntersection, + } + if err = api.Filter(context.Background(), traceReq1, stream); err != nil { + t.Fatalf("trace_filter failed: %v", err) + } + assert.Equal(t, []int{6, 7, 8, 9, 10}, blockNumbersFromTraces(t, stream.Buffer())) + }) + t.Run("first", func(t *testing.T) { + stream := jsoniter.ConfigDefault.BorrowStream(nil) + defer jsoniter.ConfigDefault.ReturnStream(stream) + + traceReq1 := TraceFilterRequest{ + FromBlock: (*hexutil.Uint64)(&fromBlock), + ToBlock: (*hexutil.Uint64)(&toBlock), + FromAddress: []*common.Address{&toAddress1, &other}, + ToAddress: []*common.Address{&fromAddress2, &toAddress1}, + Mode: TraceFilterModeIntersection, + } + if err = api.Filter(context.Background(), traceReq1, stream); err != nil { + t.Fatalf("trace_filter failed: %v", err) + } + assert.Equal(t, []int{1, 2, 3, 4, 5}, blockNumbersFromTraces(t, stream.Buffer())) + }) + t.Run("empty", func(t *testing.T) { + stream := jsoniter.ConfigDefault.BorrowStream(nil) + defer jsoniter.ConfigDefault.ReturnStream(stream) + + traceReq1 := TraceFilterRequest{ + FromBlock: (*hexutil.Uint64)(&fromBlock), + ToBlock: (*hexutil.Uint64)(&toBlock), + ToAddress: []*common.Address{&other}, + FromAddress: []*common.Address{&fromAddress2, &toAddress1}, + Mode: TraceFilterModeIntersection, + } + if err = api.Filter(context.Background(), traceReq1, stream); err != nil { + t.Fatalf("trace_filter failed: %v", err) + } + require.Empty(t, blockNumbersFromTraces(t, stream.Buffer())) + }) +} diff --git a/cmd/rpcdaemon/commands/trace_filtering.go b/cmd/rpcdaemon/commands/trace_filtering.go index 4825e74daa3334a359dd62c2d04ef314a56f91ee..2fc63b07decc22f61cf7e43d506240fc1b76a64a 100644 --- a/cmd/rpcdaemon/commands/trace_filtering.go +++ b/cmd/rpcdaemon/commands/trace_filtering.go @@ -2,16 +2,19 @@ package commands import ( "context" + "errors" "fmt" "github.com/RoaringBitmap/roaring/roaring64" jsoniter "github.com/json-iterator/go" "github.com/ledgerwatch/erigon-lib/kv" + "github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/consensus/ethash" "github.com/ledgerwatch/erigon/core/rawdb" "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/ethdb" "github.com/ledgerwatch/erigon/ethdb/bitmapdb" "github.com/ledgerwatch/erigon/rpc" ) @@ -230,28 +233,85 @@ func (api *TraceAPIImpl) Filter(ctx context.Context, req TraceFilterRequest, str toAddresses := make(map[common.Address]struct{}, len(req.ToAddress)) var allBlocks roaring64.Bitmap - for _, addr := range req.FromAddress { - if addr != nil { - b, err := bitmapdb.Get64(dbtx, kv.CallFromIndex, addr.Bytes(), fromBlock, toBlock) - if err != nil { - stream.WriteNil() - return err + switch req.Mode { + case TraceFilterModeIntersection: + if len(req.FromAddress) == 0 || len(req.ToAddress) == 0 { + return fmt.Errorf("invalid parameters: for intersection mode both fromAddress and toAddress should be not empty") + } + + addrIntersection := make(map[common.Address]struct{}) + for _, addr := range req.FromAddress { + if addr == nil { + continue } - allBlocks.Or(b) - fromAddresses[*addr] = struct{}{} + addrIntersection[*addr] = struct{}{} } - } - for _, addr := range req.ToAddress { - if addr != nil { + + for _, addr := range req.ToAddress { + if addr == nil { + continue + } + + if _, exist := addrIntersection[*addr]; !exist { + continue + } + + fromAddresses[*addr] = struct{}{} + toAddresses[*addr] = struct{}{} + b, err := bitmapdb.Get64(dbtx, kv.CallToIndex, addr.Bytes(), fromBlock, toBlock) - if err != nil { + if err != nil && !errors.Is(err, ethdb.ErrKeyNotFound) { stream.WriteNil() return err } - allBlocks.Or(b) - toAddresses[*addr] = struct{}{} + + if b != nil { + allBlocks.Or(b) + } + + b, err = bitmapdb.Get64(dbtx, kv.CallFromIndex, addr.Bytes(), fromBlock, toBlock) + if err != nil && !errors.Is(err, ethdb.ErrKeyNotFound) { + stream.WriteNil() + return err + } + + if b != nil { + allBlocks.Or(b) + } + } + case TraceFilterModeUnion: + fallthrough + default: + for _, addr := range req.FromAddress { + if addr != nil { + b, err := bitmapdb.Get64(dbtx, kv.CallFromIndex, addr.Bytes(), fromBlock, toBlock) + if err != nil { + if errors.Is(err, ethdb.ErrKeyNotFound) { + continue + } + stream.WriteNil() + return err + } + allBlocks.Or(b) + fromAddresses[*addr] = struct{}{} + } + } + for _, addr := range req.ToAddress { + if addr != nil { + b, err := bitmapdb.Get64(dbtx, kv.CallToIndex, addr.Bytes(), fromBlock, toBlock) + if err != nil { + if errors.Is(err, ethdb.ErrKeyNotFound) { + continue + } + stream.WriteNil() + return err + } + allBlocks.Or(b) + toAddresses[*addr] = struct{}{} + } } } + // Special case - if no addresses specified, take all traces if len(req.FromAddress) == 0 && len(req.ToAddress) == 0 { allBlocks.AddRange(fromBlock, toBlock+1) @@ -473,6 +533,16 @@ type TraceFilterRequest struct { ToBlock *hexutil.Uint64 `json:"toBlock"` FromAddress []*common.Address `json:"fromAddress"` ToAddress []*common.Address `json:"toAddress"` + Mode TraceFilterMode `json:"mode"` After *uint64 `json:"after"` Count *uint64 `json:"count"` } + +type TraceFilterMode string + +const ( + // Default mode for TraceFilter. Unions results referred to addresses from FromAddress or ToAddress + TraceFilterModeUnion = "union" + // IntersectionMode retrives results referred to addresses provided both in FromAddress and ToAddress + TraceFilterModeIntersection = "intersection" +)