From 46655c8fa8c0977736961189ed59a457c7de1407 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Thu, 12 Oct 2023 15:46:20 -0500 Subject: [PATCH] ready --- test/capturer.go | 37 ------------- test/runner.go | 135 +++++++++++++++++++++++++---------------------- 2 files changed, 71 insertions(+), 101 deletions(-) delete mode 100644 test/capturer.go diff --git a/test/capturer.go b/test/capturer.go deleted file mode 100644 index 3bdc1e3a..00000000 --- a/test/capturer.go +++ /dev/null @@ -1,37 +0,0 @@ -package test - -import ( - "bytes" - "fmt" - - "gfx.cafe/gfx/pggat/lib/fed" - "gfx.cafe/gfx/pggat/lib/gsql" -) - -type Capturer struct { - Packets []fed.Packet -} - -func (T *Capturer) WritePacket(packet fed.Packet) error { - T.Packets = append(T.Packets, bytes.Clone(packet)) - return nil -} - -func (T *Capturer) Check(other *Capturer) error { - if len(T.Packets) != len(other.Packets) { - return fmt.Errorf("wrong number of packets! got %d but expected %d", len(other.Packets), len(T.Packets)) - } - - for i := range T.Packets { - expected := T.Packets[i] - actual := other.Packets[i] - - if !bytes.Equal(expected.Bytes(), actual.Bytes()) { - return fmt.Errorf("mismatched packet! expected %v but got %v", expected.Bytes(), actual.Bytes()) - } - } - - return nil -} - -var _ gsql.ResultWriter = (*Capturer)(nil) diff --git a/test/runner.go b/test/runner.go index d23e94c6..c1f24ccc 100644 --- a/test/runner.go +++ b/test/runner.go @@ -1,12 +1,14 @@ package test import ( + "bytes" "errors" "fmt" "io" "gfx.cafe/gfx/pggat/lib/bouncer/bouncers/v2" "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/fed/middlewares/unterminate" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/gat/pool/recipe" "gfx.cafe/gfx/pggat/lib/gsql" @@ -26,67 +28,91 @@ func MakeRunner(config Config, test Test) Runner { } } -func (T *Runner) prepare(client *gsql.Client, until int) []Capturer { - results := make([]Capturer, until) - +func (T *Runner) prepare(client *fed.Conn, until int) error { for i := 0; i < until; i++ { x := T.test.Instructions[i] switch v := x.(type) { case inst.SimpleQuery: q := packets.Query(v) - client.Do(&results[i], q.IntoPacket(nil)) + if err := client.WritePacket(&q); err != nil { + return err + } case inst.Sync: - client.Do(&results[i], fed.NewPacket(packets.TypeSync)) + if err := client.WritePacket(&packets.Sync{}); err != nil { + return err + } case inst.Parse: p := packets.Parse{ Destination: v.Destination, Query: v.Query, } - client.Do(&results[i], p.IntoPacket(nil)) + if err := client.WritePacket(&p); err != nil { + return err + } case inst.Bind: p := packets.Bind{ Destination: v.Destination, Source: v.Source, } - client.Do(&results[i], p.IntoPacket(nil)) + if err := client.WritePacket(&p); err != nil { + return err + } case inst.DescribePortal: p := packets.Describe{ - Which: 'P', - Target: string(v), + Which: 'P', + Name: string(v), + } + if err := client.WritePacket(&p); err != nil { + return err } - client.Do(&results[i], p.IntoPacket(nil)) case inst.DescribePreparedStatement: p := packets.Describe{ - Which: 'S', - Target: string(v), + Which: 'S', + Name: string(v), + } + if err := client.WritePacket(&p); err != nil { + return err } - client.Do(&results[i], p.IntoPacket(nil)) case inst.Execute: p := packets.Execute{ Target: string(v), } - client.Do(&results[i], p.IntoPacket(nil)) + if err := client.WritePacket(&p); err != nil { + return err + } case inst.ClosePortal: p := packets.Close{ - Which: 'P', - Target: string(v), + Which: 'P', + Name: string(v), + } + if err := client.WritePacket(&p); err != nil { + return err } - client.Do(&results[i], p.IntoPacket(nil)) case inst.ClosePreparedStatement: p := packets.Close{ - Which: 'S', - Target: string(v), + Which: 'S', + Name: string(v), + } + if err := client.WritePacket(&p); err != nil { + return err } - client.Do(&results[i], p.IntoPacket(nil)) case inst.CopyData: p := packets.CopyData(v) - client.Do(&results[i], p.IntoPacket(nil)) + if err := client.WritePacket(&p); err != nil { + return err + } case inst.CopyDone: - client.Do(&results[i], fed.NewPacket(packets.TypeCopyDone)) + if err := client.WritePacket(&packets.CopyDone{}); err != nil { + return err + } } } - return results + if err := client.WritePacket(&packets.Terminate{}); err != nil { + return err + } + + return client.Flush() } func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error { @@ -98,9 +124,11 @@ func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error { _ = server.Close() }() + client.Middleware = append(client.Middleware, unterminate.Unterminate) + for { var p fed.Packet - p, err = client.ReadPacket(true, p) + p, err = client.ReadPacket(true) if err != nil { if errors.Is(err, io.EOF) { break @@ -108,7 +136,7 @@ func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error { return err } - _, clientErr, serverErr := bouncers.Bounce(client, server, p) + clientErr, serverErr := bouncers.Bounce(client, server, p) if clientErr != nil { return clientErr } @@ -120,29 +148,31 @@ func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error { return nil } -func (T *Runner) runModeOnce(dialer recipe.Dialer) ([]Capturer, error) { - var client gsql.Client - results := T.prepare(&client, len(T.test.Instructions)) - if err := client.Close(); err != nil { +func (T *Runner) runModeOnce(dialer recipe.Dialer) ([]byte, error) { + inward, outward := gsql.NewPair() + if err := T.prepare(inward, len(T.test.Instructions)); err != nil { + return nil, err + } + + if err := T.runModeL1(dialer, outward); err != nil { return nil, err } - if err := T.runModeL1(dialer, fed.NewConn(&client)); err != nil { + if err := inward.Close(); err != nil { return nil, err } - return results, nil + return io.ReadAll(inward.NetConn) } func (T *Runner) runModeFail(dialer recipe.Dialer) error { for i := 1; i < len(T.test.Instructions)+1; i++ { - var client gsql.Client - T.prepare(&client, i) - if err := client.Close(); err != nil { + inward, outward := gsql.NewPair() + if err := T.prepare(inward, i); err != nil { return err } - if err := T.runModeL1(dialer, fed.NewConn(&client)); err != nil && !errors.Is(err, io.EOF) { + if err := T.runModeL1(dialer, outward); err != nil && !errors.Is(err, io.EOF) { return err } } @@ -150,7 +180,7 @@ func (T *Runner) runModeFail(dialer recipe.Dialer) error { return nil } -func (T *Runner) runMode(dialer recipe.Dialer) ([]Capturer, error) { +func (T *Runner) runMode(dialer recipe.Dialer) ([]byte, error) { instances := T.config.Stress if instances < 1 || T.test.SideEffects { return T.runModeOnce(dialer) @@ -175,14 +205,8 @@ func (T *Runner) runMode(dialer recipe.Dialer) ([]Capturer, error) { if err != nil { return err } - if len(expected) != len(actual) { - return fmt.Errorf("wrong number of results! expected %d but got %d", len(expected), len(actual)) - } - for i, exp := range expected { - act := actual[i] - if err = exp.Check(&act); err != nil { - return err - } + if !bytes.Equal(expected, actual) { + return fmt.Errorf("mismatched results: expected %v but got %v", expected, actual) } return nil }) @@ -198,7 +222,7 @@ func (T *Runner) runMode(dialer recipe.Dialer) ([]Capturer, error) { func (T *Runner) Run() error { var errs []error - var expected []Capturer + var expected []byte // modes for name, mode := range T.config.Modes { @@ -216,30 +240,13 @@ func (T *Runner) Run() error { continue } - if len(expected) != len(actual) { + if !bytes.Equal(expected, actual) { errs = append(errs, ErrorIn{ Name: name, - Err: fmt.Errorf("wrong number of results! expected %d but got %d", len(expected), len(actual)), + Err: fmt.Errorf("mismatched results: expected %v but got %v", expected, actual), }) continue } - - var modeErrs []error - - for i, exp := range expected { - act := actual[i] - - if err = exp.Check(&act); err != nil { - modeErrs = append(modeErrs, fmt.Errorf("instruction %d: %s", i+1, err.Error())) - } - } - - if len(modeErrs) > 0 { - errs = append(errs, ErrorIn{ - Name: name, - Err: Errors(modeErrs), - }) - } } if len(errs) > 0 { -- GitLab