diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index d6f690978bcd9a5e2bc74be6ce9dfa307cd55933..66f5aaa05427958dac98ec6305ae2f1e867b40c4 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -56,8 +56,10 @@ type Client struct { pool_name string username string - gatling gat.Gat - conf *config.Global + gatling gat.Gat + statements map[string]*protocol.Parse + portals map[string]*protocol.Bind + conf *config.Global log zlog.Logger } @@ -69,14 +71,16 @@ func NewClient( admin_only bool, ) *Client { c := &Client{ - conn: conn, - r: bufio.NewReader(conn), - wr: conn, - bufwr: bufio.NewWriter(conn), - recv: make(chan protocol.Packet), - addr: conn.RemoteAddr(), - gatling: gatling, - conf: conf, + conn: conn, + r: bufio.NewReader(conn), + wr: conn, + bufwr: bufio.NewWriter(conn), + recv: make(chan protocol.Packet), + addr: conn.RemoteAddr(), + gatling: gatling, + statements: make(map[string]*protocol.Parse), + portals: make(map[string]*protocol.Bind), + conf: conf, } c.log = log.With(). Stringer("clientaddr", c.addr).Logger() @@ -314,6 +318,14 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return false, ctx.Err() } switch cast := rsp.(type) { + case *protocol.Parse: + return true, c.parse(ctx, cast) + case *protocol.Bind: + return true, c.bind(ctx, cast) + case *protocol.Describe: + return true, c.handle_describe(ctx, cast) + case *protocol.Execute: + return true, c.handle_execute(ctx, cast) case *protocol.Query: return true, c.handle_query(ctx, cast) case *protocol.FunctionCall: @@ -325,6 +337,24 @@ func (c *Client) tick(ctx context.Context) (bool, error) { return true, nil } +func (c *Client) parse(ctx context.Context, q *protocol.Parse) error { + c.statements[q.Fields.PreparedStatement] = q + return c.Send(new(protocol.ParseComplete)) +} + +func (c *Client) bind(ctx context.Context, b *protocol.Bind) error { + c.portals[b.Fields.Destination] = b + return c.Send(new(protocol.BindComplete)) +} + +func (c *Client) handle_describe(ctx context.Context, d *protocol.Describe) error { + return c.server.Describe(ctx, c, d) +} + +func (c *Client) handle_execute(ctx context.Context, e *protocol.Execute) error { + return c.server.Execute(ctx, c, e) +} + func (c *Client) handle_query(ctx context.Context, q *protocol.Query) error { parsed, err := parse.Parse(q.Fields.Query) if err != nil { diff --git a/lib/gat/gatling/conn_pool/conn_pool.go b/lib/gat/gatling/conn_pool/conn_pool.go index 7b13241f10dfaf5ced633760c0b8ebcdbc6d9216..6f8fe0552476b6f7143a3deb847327e5462a7b15 100644 --- a/lib/gat/gatling/conn_pool/conn_pool.go +++ b/lib/gat/gatling/conn_pool/conn_pool.go @@ -165,6 +165,14 @@ func (c *ConnectionPool) GetServerInfo() []*protocol.ParameterStatus { return srv.primary.GetServerInfo() } +func (c *ConnectionPool) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { + return (<-c.workerPool).HandleDescribe(ctx, client, d) +} + +func (c *ConnectionPool) Execute(ctx context.Context, client gat.Client, e *protocol.Execute) error { + return (<-c.workerPool).HandleExecute(ctx, client, e) +} + func (c *ConnectionPool) SimpleQuery(ctx context.Context, client gat.Client, q string) error { return (<-c.workerPool).HandleSimpleQuery(ctx, client, q) } diff --git a/lib/gat/gatling/conn_pool/worker.go b/lib/gat/gatling/conn_pool/worker.go index 9965f26560a0f98000c4f32007dad8ac7128338a..e7d97b440de233c3d39c359e24e4b3cfc80f78c1 100644 --- a/lib/gat/gatling/conn_pool/worker.go +++ b/lib/gat/gatling/conn_pool/worker.go @@ -18,35 +18,54 @@ type worker struct { w *ConnectionPool } +// ret urn worker to pool +func (w *worker) ret() { + w.w.workerPool <- w +} + +func (w *worker) HandleDescribe(ctx context.Context, c gat.Client, d *protocol.Describe) error { + defer w.ret() + + errch := make(chan error) + go func() { + defer close(errch) + errch <- w.z_actually_do_describe(ctx, c, d) + }() + + return <-errch +} + +func (w *worker) HandleExecute(ctx context.Context, c gat.Client, e *protocol.Execute) error { + defer w.ret() + + errch := make(chan error) + go func() { + defer close(errch) + errch <- w.z_actually_do_execute(ctx, c, e) + }() + + return <-errch +} + func (w *worker) HandleFunction(ctx context.Context, c gat.Client, fn *protocol.FunctionCall) error { log.Println("worker selected for fn") - defer func() { - // return self to the connection pool after - log.Println("worker returned for fn") - w.w.workerPool <- w - }() + defer w.ret() errch := make(chan error) go func() { - err := w.z_actually_do_fn(ctx, c, fn) - if err != nil { - ctx.Done() - } - errch <- err - close(errch) + defer close(errch) + errch <- w.z_actually_do_fn(ctx, c, fn) }() return <-errch } func (w *worker) HandleSimpleQuery(ctx context.Context, c gat.Client, query string) error { - defer func() { - // return self to the connection pool after - w.w.workerPool <- w - }() + defer w.ret() + errch := make(chan error) go func() { - err := w.z_actually_do_simple_query(ctx, c, query) - errch <- err + defer close(errch) + errch <- w.z_actually_do_simple_query(ctx, c, query) }() // wait until query or close @@ -59,16 +78,14 @@ func (w *worker) HandleSimpleQuery(ctx context.Context, c gat.Client, query stri } func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query string) error { - defer func() { - // return self to the connection pool after - w.w.workerPool <- w - }() + defer w.ret() + errch := make(chan error) go func() { + defer close(errch) //log.Println("performing transaction...") - err := w.z_actually_do_transaction(ctx, c, query) + errch <- w.z_actually_do_transaction(ctx, c, query) //log.Println("done", err) - errch <- err }() // wait until query or close @@ -80,6 +97,12 @@ func (w *worker) HandleTransaction(ctx context.Context, c gat.Client, query stri } } +func (w *worker) z_actually_do_describe(ctx context.Context, client gat.Client, payload *protocol.Describe) error { + return nil +} +func (w *worker) z_actually_do_execute(ctx context.Context, client gat.Client, payload *protocol.Execute) error { + return nil +} func (w *worker) z_actually_do_fn(ctx context.Context, client gat.Client, payload *protocol.FunctionCall) error { c := w.w srv := c.chooseConnections() diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index dddf479ee3dbdbd9217f9b5f2905490d6cd0cf3c..8f2e2249018e5d715db8db57fbb3affe3f5608f9 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -15,6 +15,12 @@ type Client interface { type ConnectionPool interface { GetUser() *config.User GetServerInfo() []*protocol.ParameterStatus + + // extended queries + Describe(ctx context.Context, client Client, describe *protocol.Describe) error + Execute(ctx context.Context, client Client, execute *protocol.Execute) error + + // simple queries SimpleQuery(ctx context.Context, client Client, query string) error Transaction(ctx context.Context, client Client, query string) error CallFunction(ctx context.Context, client Client, payload *protocol.FunctionCall) error diff --git a/lib/gat/protocol/frontend.go b/lib/gat/protocol/frontend.go index e30cb452438b4fa23932c38361cdeead2f532bd6..b22a6f155f55fae8137c1acb6395473d9ae2f5f5 100644 --- a/lib/gat/protocol/frontend.go +++ b/lib/gat/protocol/frontend.go @@ -128,11 +128,11 @@ func (T *FieldsBindParameterValues) Write(writer io.Writer) (length int, err err } type FieldsBind struct { - Destination string - PreparedStatement string - FormatCodes []int16 - ParameterValues []FieldsBindParameterValues - ResultColumnFormatCodes []int16 + Destination string + PreparedStatement string + ParameterFormatCodes []int16 + ParameterValues []FieldsBindParameterValues + ResultFormatCodes []int16 } func (T *FieldsBind) Read(payloadLength int, reader io.Reader) (err error) { @@ -144,17 +144,17 @@ func (T *FieldsBind) Read(payloadLength int, reader io.Reader) (err error) { if err != nil { return } - var FormatCodesLength int16 - FormatCodesLength, err = ReadInt16(reader) + var ParameterFormatCodesLength int16 + ParameterFormatCodesLength, err = ReadInt16(reader) if err != nil { return } - if FormatCodesLength == int16(-1) { - T.FormatCodes = nil + if ParameterFormatCodesLength == int16(-1) { + T.ParameterFormatCodes = nil } else { - T.FormatCodes = make([]int16, int(FormatCodesLength)) - for i := 0; i < int(FormatCodesLength); i++ { - T.FormatCodes[i], err = ReadInt16(reader) + T.ParameterFormatCodes = make([]int16, int(ParameterFormatCodesLength)) + for i := 0; i < int(ParameterFormatCodesLength); i++ { + T.ParameterFormatCodes[i], err = ReadInt16(reader) if err != nil { return } @@ -176,17 +176,17 @@ func (T *FieldsBind) Read(payloadLength int, reader io.Reader) (err error) { } } } - var ResultColumnFormatCodesLength int16 - ResultColumnFormatCodesLength, err = ReadInt16(reader) + var ResultFormatCodesLength int16 + ResultFormatCodesLength, err = ReadInt16(reader) if err != nil { return } - if ResultColumnFormatCodesLength == int16(-1) { - T.ResultColumnFormatCodes = nil + if ResultFormatCodesLength == int16(-1) { + T.ResultFormatCodes = nil } else { - T.ResultColumnFormatCodes = make([]int16, int(ResultColumnFormatCodesLength)) - for i := 0; i < int(ResultColumnFormatCodesLength); i++ { - T.ResultColumnFormatCodes[i], err = ReadInt16(reader) + T.ResultFormatCodes = make([]int16, int(ResultFormatCodesLength)) + for i := 0; i < int(ResultFormatCodesLength); i++ { + T.ResultFormatCodes[i], err = ReadInt16(reader) if err != nil { return } @@ -207,16 +207,16 @@ func (T *FieldsBind) Write(writer io.Writer) (length int, err error) { return } length += temp - if T.FormatCodes == nil { + if T.ParameterFormatCodes == nil { temp, err = WriteInt16(writer, int16(-1)) } else { - temp, err = WriteInt16(writer, int16(len(T.FormatCodes))) + temp, err = WriteInt16(writer, int16(len(T.ParameterFormatCodes))) } if err != nil { return } length += temp - for _, v := range T.FormatCodes { + for _, v := range T.ParameterFormatCodes { temp, err = WriteInt16(writer, v) if err != nil { return @@ -239,16 +239,16 @@ func (T *FieldsBind) Write(writer io.Writer) (length int, err error) { } length += temp } - if T.ResultColumnFormatCodes == nil { + if T.ResultFormatCodes == nil { temp, err = WriteInt16(writer, int16(-1)) } else { - temp, err = WriteInt16(writer, int16(len(T.ResultColumnFormatCodes))) + temp, err = WriteInt16(writer, int16(len(T.ResultFormatCodes))) } if err != nil { return } length += temp - for _, v := range T.ResultColumnFormatCodes { + for _, v := range T.ResultFormatCodes { temp, err = WriteInt16(writer, v) if err != nil { return diff --git a/lib/parse/parse.go b/lib/parse/parse.go index 149af33008300dd6212d17f90448424d886c5886..5becb253c843da76846a9e991557389499ec5f33 100644 --- a/lib/parse/parse.go +++ b/lib/parse/parse.go @@ -3,7 +3,6 @@ package parse import ( "errors" "fmt" - "strings" "unicode" "unicode/utf8" ) @@ -72,39 +71,43 @@ func (r *reader) nextMultiLineComment() error { } func (r *reader) nextIdentifier() (string, error) { - var stack strings.Builder + start := r.p + for { + pre := r.p + c, ok := r.nextRune() if !ok { break } switch { case c == ';': - return stack.String(), EndOfStatement + return r.v[start:pre], EndOfStatement case unicode.IsSpace(c): - if stack.Len() == 0 { + if pre == start { + start = r.p continue } // this identifier is done - return stack.String(), nil + return r.v[start:pre], nil case unicode.IsDigit(c): - if stack.Len() == 0 { + if pre == start { return "", newUnexpectedCharacter(c) } fallthrough case unicode.IsLetter(c), c == '_', c == '$': - stack.WriteRune(c) - case c == '-' && stack.Len() == 0: + case c == '-' && pre == start: if r.nextComment() != nil { return "", newUnexpectedCharacter(c) } + start = r.p default: return "", newUnexpectedCharacter(c) } } - return stack.String(), EndOfSQL + return r.v[start:r.p], EndOfSQL } func (r *reader) nextString(delim string) error { @@ -149,9 +152,6 @@ func (r *reader) nextDollarIdentifier() error { switch { case c == ';': return EndOfStatement - case unicode.IsSpace(c): - // this identifier is done - return NotThisToken case unicode.IsDigit(c): if start == pre { return NotThisToken @@ -187,8 +187,13 @@ func (r *reader) nextArgument() (string, error) { return r.v[start:pre], nil case c == ';': return r.v[start:pre], EndOfStatement - case c == '\'', c == '"': - err := r.nextString(string(c)) + case c == '\'': + err := r.nextString("'") + if err != nil { + return r.v[start:r.p], err + } + case c == '"': + err := r.nextString("\"") if err != nil { return r.v[start:r.p], err } @@ -250,7 +255,7 @@ func (r *reader) nextCommand() (cmd Command, err error) { } if err != nil { - if errors.Is(err, EndOfStatement) { + if err == EndOfStatement { err = nil } return @@ -274,7 +279,7 @@ func Parse(sql string) (cmds []Command, err error) { } if err != nil { - if errors.Is(err, EndOfSQL) { + if err == EndOfSQL { err = nil } return diff --git a/spec/protocol/frontend.yaml b/spec/protocol/frontend.yaml index fcb694e0846a2251c0a427e2d411a408a5fce21f..4188734c8cb0c601cee34abd0609d8966e76a2f1 100644 --- a/spec/protocol/frontend.yaml +++ b/spec/protocol/frontend.yaml @@ -5,7 +5,7 @@ Bind: Type: string - Name: PreparedStatement Type: string - - Name: FormatCodes + - Name: ParameterFormatCodes Type: int16 LengthPrefixed: int16 - Name: ParameterValues @@ -15,7 +15,7 @@ Bind: Type: byte LengthPrefixed: int32 LengthPrefixed: int16 - - Name: ResultColumnFormatCodes + - Name: ResultFormatCodes Type: int16 LengthPrefixed: int16 Close: