diff --git a/test/capturer.go b/test/capturer.go new file mode 100644 index 0000000000000000000000000000000000000000..5b6dfeb829ecefc59b4a10db0d88f078b70ca69c --- /dev/null +++ b/test/capturer.go @@ -0,0 +1,37 @@ +package test + +import ( + "bytes" + "fmt" + + "pggat/lib/fed" + "pggat/lib/gsql" +) + +type Capturer struct { + Packets []fed.Packet +} + +func (T *Capturer) WritePacket(packet fed.Packet) error { + T.Packets = append(T.Packets, packet) + return nil +} + +func (T *Capturer) Check(other *Capturer) error { + if len(T.Packets) != len(other.Packets) { + return fmt.Errorf("not enough packets! got %d but expected %d", len(other.Packets), len(T.Packets)) + } + + for i := range T.Packets { + expected := T.Packets[i] + actual := other.Packets[i] + + if !bytes.Equal(expected.Bytes(), actual.Bytes()) { + return fmt.Errorf("mismatched packet! expected %v but got %v", expected.Bytes(), actual.Bytes()) + } + } + + return nil +} + +var _ gsql.ResultWriter = (*Capturer)(nil) diff --git a/test/runner.go b/test/runner.go index f61cb293c74cb6bcee2e7d68981e848b89e86d55..4cd235c567649b01386d5bcb5cdf1d600a44d9c7 100644 --- a/test/runner.go +++ b/test/runner.go @@ -4,8 +4,6 @@ import ( "errors" "io" - "tuxpa.in/a/zlog/log" - "pggat/lib/bouncer/bouncers/v2" "pggat/lib/fed" packets "pggat/lib/fed/packets/v3.0" @@ -53,20 +51,13 @@ func (T *Runner) setup() error { return nil } -type logWriter struct{} - -func (logWriter) WritePacket(pkt fed.Packet) error { - log.Print("got packet ", pkt) - return nil -} - func (T *Runner) run(pkts ...fed.Packet) error { // expected - { - log.Print("expected packets") + var expected Capturer + { var client gsql.Client - client.Do(logWriter{}, pkts...) + client.Do(&expected, pkts...) if err := client.Close(); err != nil { return err } @@ -97,11 +88,10 @@ func (T *Runner) run(pkts ...fed.Packet) error { // actual for name, p := range T.pools { - log.Print() - log.Print("pool ", name) + var result Capturer var client gsql.Client - client.Do(logWriter{}, pkts...) + client.Do(&result, pkts...) if err := client.Close(); err != nil { return err } @@ -110,6 +100,9 @@ func (T *Runner) run(pkts ...fed.Packet) error { return err } + if err := expected.Check(&result); err != nil { + return err + } _ = name } diff --git a/test/tester_test.go b/test/tester_test.go index 9f74ebbffe9ec1c1ef3ec1898655edcde53e93ff..809190fcf2d16b4250dfbd89c75815b858bcaefc 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -33,7 +33,9 @@ func TestTester(t *testing.T) { }, }, }) - if err := tester.Run(tests.SimpleQuery); err != nil { + if err := tester.Run( + tests.SimpleQuery, + ); err != nil { t.Error(err) } }