diff --git a/codegen/protocol/packets.tmpl b/codegen/protocol/packets.tmpl index 83537d321e43495bc2ae07e82563a83498f238e1..63c411e8852a3db38d4a0ca5c5c4110e6adf0dab 100644 --- a/codegen/protocol/packets.tmpl +++ b/codegen/protocol/packets.tmpl @@ -10,13 +10,10 @@ import ( var _ bytes.Buffer var _ io.Reader -{{define "fieldType" -}} +{{define "fieldTypeNoArr" -}} {{$parent := (index . 0) -}} {{$field := (index . 1) -}} - {{if or $field.LengthPrefixed $field.ArrayLength -}} - [] - {{- end -}} {{if $field.Struct -}} Fields{{$parent}}{{$field.Name}} {{- else if $field.Type -}} @@ -24,6 +21,16 @@ var _ io.Reader {{- end -}} {{end -}} +{{define "fieldType" -}} + {{$parent := (index . 0) -}} + {{$field := (index . 1) -}} + + {{if or $field.LengthPrefixed $field.ArrayLength $field.While -}} + [] + {{- end -}} + {{template "fieldTypeNoArr" (list $parent $field) -}} +{{end -}} + {{define "declareField" -}} {{$parent := (index . 0) -}} {{$field := (index . 1) -}} @@ -75,6 +82,14 @@ var _ io.Reader for i := 0; i < int({{$field.Name}}Length); i++ { {{template "readFieldL1" (list $parent $field (printf "T.%s[i]" $field.Name))}} } + {{else if $field.While -}} + var P {{template "fieldTypeNoArr" (list $parent $field)}} + for ok := true; ok; ok = {{$field.While}} { + {{template "readFieldL1" (list $parent $field "P")}} + T.{{$field.Name}} = append(T.{{$field.Name}}, P) + var newp {{template "fieldTypeNoArr" (list $parent $field)}} + P = newp + } {{else -}} {{template "readFieldL1" (list $parent $field (printf "T.%s" $field.Name))}} {{end -}} @@ -120,7 +135,7 @@ var _ io.Reader length += temp {{end -}} - {{if or $field.LengthPrefixed $field.ArrayLength -}} + {{if or $field.LengthPrefixed $field.ArrayLength $field.While -}} for _, v := range T.{{$field.Name}} { {{template "writeFieldL1" (list $parent $field "v")}} } @@ -170,7 +185,7 @@ var _ io.Reader {{template "declareFields" (list $name $packet) -}} type {{$name}} struct { - fields Fields{{$name}} + Fields Fields{{$name}} } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -180,13 +195,13 @@ var _ io.Reader if err != nil { return } - return T.fields.Read(int(length - 4), reader) + return T.Fields.Read(int(length - 4), reader) } func (T *{{$name}}) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return diff --git a/lib/gat/client.go b/lib/gat/client.go index 288d3c9e3a41db10b4e695eda889e24e8f987ed2..95b0b7176298e120919677fc6d471c15933a6bca 100644 --- a/lib/gat/client.go +++ b/lib/gat/client.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "errors" "fmt" + "gfx.cafe/gfx/pggat/lib/gat/protocol" "io" "math/big" "net" @@ -101,11 +102,17 @@ func NewClient( } func (c *Client) Accept(ctx context.Context) error { - params, err := ReadStartup(c.r) + startup := new(protocol.StartupMessage) + err := startup.Read(c.r) if err != nil { return err } + params := make(map[string]string) + for _, v := range startup.Fields.Parameters { + params[v.Name] = v.Value + } + var ok bool c.pool_name, ok = params["database"] if !ok { diff --git a/lib/gat/protocol/backend.go b/lib/gat/protocol/backend.go index 618b93b1c22e8fe5dbf2a8bc6fcf5de62013cf91..49911d700cd6a542ee5ec5976e7556bfe34aa5d4 100644 --- a/lib/gat/protocol/backend.go +++ b/lib/gat/protocol/backend.go @@ -40,7 +40,7 @@ func (T *FieldsAuthentication) Write(writer io.Writer) (length int, err error) { } type Authentication struct { - fields FieldsAuthentication + Fields FieldsAuthentication } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -50,13 +50,13 @@ func (T *Authentication) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Authentication) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -110,7 +110,7 @@ func (T *FieldsBackendKeyData) Write(writer io.Writer) (length int, err error) { } type BackendKeyData struct { - fields FieldsBackendKeyData + Fields FieldsBackendKeyData } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -120,13 +120,13 @@ func (T *BackendKeyData) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *BackendKeyData) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -160,7 +160,7 @@ func (T *FieldsBindComplete) Write(writer io.Writer) (length int, err error) { } type BindComplete struct { - fields FieldsBindComplete + Fields FieldsBindComplete } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -170,13 +170,13 @@ func (T *BindComplete) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *BindComplete) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -210,7 +210,7 @@ func (T *FieldsCloseComplete) Write(writer io.Writer) (length int, err error) { } type CloseComplete struct { - fields FieldsCloseComplete + Fields FieldsCloseComplete } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -220,13 +220,13 @@ func (T *CloseComplete) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CloseComplete) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -270,7 +270,7 @@ func (T *FieldsCommandComplete) Write(writer io.Writer) (length int, err error) } type CommandComplete struct { - fields FieldsCommandComplete + Fields FieldsCommandComplete } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -280,13 +280,13 @@ func (T *CommandComplete) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CommandComplete) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -358,7 +358,7 @@ func (T *FieldsCopyBothResponse) Write(writer io.Writer) (length int, err error) } type CopyBothResponse struct { - fields FieldsCopyBothResponse + Fields FieldsCopyBothResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -368,13 +368,13 @@ func (T *CopyBothResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CopyBothResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -446,7 +446,7 @@ func (T *FieldsCopyInResponse) Write(writer io.Writer) (length int, err error) { } type CopyInResponse struct { - fields FieldsCopyInResponse + Fields FieldsCopyInResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -456,13 +456,13 @@ func (T *CopyInResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CopyInResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -534,7 +534,7 @@ func (T *FieldsCopyOutResponse) Write(writer io.Writer) (length int, err error) } type CopyOutResponse struct { - fields FieldsCopyOutResponse + Fields FieldsCopyOutResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -544,13 +544,13 @@ func (T *CopyOutResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CopyOutResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -653,7 +653,7 @@ func (T *FieldsDataRow) Write(writer io.Writer) (length int, err error) { } type DataRow struct { - fields FieldsDataRow + Fields FieldsDataRow } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -663,13 +663,13 @@ func (T *DataRow) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *DataRow) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -703,7 +703,7 @@ func (T *FieldsEmptyQueryResponse) Write(writer io.Writer) (length int, err erro } type EmptyQueryResponse struct { - fields FieldsEmptyQueryResponse + Fields FieldsEmptyQueryResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -713,13 +713,13 @@ func (T *EmptyQueryResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *EmptyQueryResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -739,41 +739,76 @@ func (T *EmptyQueryResponse) Write(writer io.Writer) (length int, err error) { return } -type FieldsErrorResponse struct { +type FieldsErrorResponseResponses struct { Code byte Value string } -func (T *FieldsErrorResponse) Read(payloadLength int, reader io.Reader) (err error) { +func (T *FieldsErrorResponseResponses) Read(payloadLength int, reader io.Reader) (err error) { T.Code, err = ReadByte(reader) if err != nil { return } - T.Value, err = ReadString(reader) - if err != nil { - return + if T.Code != 0 { + T.Value, err = ReadString(reader) + if err != nil { + return + } } return } -func (T *FieldsErrorResponse) Write(writer io.Writer) (length int, err error) { +func (T *FieldsErrorResponseResponses) Write(writer io.Writer) (length int, err error) { var temp int temp, err = WriteByte(writer, T.Code) if err != nil { return } length += temp - temp, err = WriteString(writer, T.Value) - if err != nil { - return + if T.Code != 0 { + temp, err = WriteString(writer, T.Value) + if err != nil { + return + } + length += temp + } + _ = temp + return +} + +type FieldsErrorResponse struct { + Responses []FieldsErrorResponseResponses +} + +func (T *FieldsErrorResponse) Read(payloadLength int, reader io.Reader) (err error) { + var P FieldsErrorResponseResponses + for ok := true; ok; ok = P.Code != 0 { + err = P.Read(payloadLength, reader) + if err != nil { + return + } + T.Responses = append(T.Responses, P) + var newp FieldsErrorResponseResponses + P = newp + } + return +} + +func (T *FieldsErrorResponse) Write(writer io.Writer) (length int, err error) { + var temp int + for _, v := range T.Responses { + temp, err = v.Write(writer) + if err != nil { + return + } + length += temp } - length += temp _ = temp return } type ErrorResponse struct { - fields FieldsErrorResponse + Fields FieldsErrorResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -783,13 +818,13 @@ func (T *ErrorResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *ErrorResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -851,7 +886,7 @@ func (T *FieldsFunctionCallResponse) Write(writer io.Writer) (length int, err er } type FunctionCallResponse struct { - fields FieldsFunctionCallResponse + Fields FieldsFunctionCallResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -861,13 +896,13 @@ func (T *FunctionCallResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *FunctionCallResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -939,7 +974,7 @@ func (T *FieldsNegotiateProtocolVersion) Write(writer io.Writer) (length int, er } type NegotiateProtocolVersion struct { - fields FieldsNegotiateProtocolVersion + Fields FieldsNegotiateProtocolVersion } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -949,13 +984,13 @@ func (T *NegotiateProtocolVersion) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *NegotiateProtocolVersion) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -989,7 +1024,7 @@ func (T *FieldsNoData) Write(writer io.Writer) (length int, err error) { } type NoData struct { - fields FieldsNoData + Fields FieldsNoData } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -999,13 +1034,13 @@ func (T *NoData) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *NoData) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1025,17 +1060,17 @@ func (T *NoData) Write(writer io.Writer) (length int, err error) { return } -type FieldsNoticeResponse struct { - Type byte +type FieldsNoticeResponseResponses struct { + Code byte Value string } -func (T *FieldsNoticeResponse) Read(payloadLength int, reader io.Reader) (err error) { - T.Type, err = ReadByte(reader) +func (T *FieldsNoticeResponseResponses) Read(payloadLength int, reader io.Reader) (err error) { + T.Code, err = ReadByte(reader) if err != nil { return } - if T.Type != 0 { + if T.Code != 0 { T.Value, err = ReadString(reader) if err != nil { return @@ -1044,14 +1079,14 @@ func (T *FieldsNoticeResponse) Read(payloadLength int, reader io.Reader) (err er return } -func (T *FieldsNoticeResponse) Write(writer io.Writer) (length int, err error) { +func (T *FieldsNoticeResponseResponses) Write(writer io.Writer) (length int, err error) { var temp int - temp, err = WriteByte(writer, T.Type) + temp, err = WriteByte(writer, T.Code) if err != nil { return } length += temp - if T.Type != 0 { + if T.Code != 0 { temp, err = WriteString(writer, T.Value) if err != nil { return @@ -1062,8 +1097,39 @@ func (T *FieldsNoticeResponse) Write(writer io.Writer) (length int, err error) { return } +type FieldsNoticeResponse struct { + Responses []FieldsNoticeResponseResponses +} + +func (T *FieldsNoticeResponse) Read(payloadLength int, reader io.Reader) (err error) { + var P FieldsNoticeResponseResponses + for ok := true; ok; ok = P.Code != 0 { + err = P.Read(payloadLength, reader) + if err != nil { + return + } + T.Responses = append(T.Responses, P) + var newp FieldsNoticeResponseResponses + P = newp + } + return +} + +func (T *FieldsNoticeResponse) Write(writer io.Writer) (length int, err error) { + var temp int + for _, v := range T.Responses { + temp, err = v.Write(writer) + if err != nil { + return + } + length += temp + } + _ = temp + return +} + type NoticeResponse struct { - fields FieldsNoticeResponse + Fields FieldsNoticeResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1073,13 +1139,13 @@ func (T *NoticeResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *NoticeResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1143,7 +1209,7 @@ func (T *FieldsNotificationResponse) Write(writer io.Writer) (length int, err er } type NotificationResponse struct { - fields FieldsNotificationResponse + Fields FieldsNotificationResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1153,13 +1219,13 @@ func (T *NotificationResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *NotificationResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1221,7 +1287,7 @@ func (T *FieldsParameterDescription) Write(writer io.Writer) (length int, err er } type ParameterDescription struct { - fields FieldsParameterDescription + Fields FieldsParameterDescription } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1231,13 +1297,13 @@ func (T *ParameterDescription) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *ParameterDescription) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1291,7 +1357,7 @@ func (T *FieldsParameterStatus) Write(writer io.Writer) (length int, err error) } type ParameterStatus struct { - fields FieldsParameterStatus + Fields FieldsParameterStatus } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1301,13 +1367,13 @@ func (T *ParameterStatus) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *ParameterStatus) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1341,7 +1407,7 @@ func (T *FieldsParseComplete) Write(writer io.Writer) (length int, err error) { } type ParseComplete struct { - fields FieldsParseComplete + Fields FieldsParseComplete } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1351,13 +1417,13 @@ func (T *ParseComplete) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *ParseComplete) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1391,7 +1457,7 @@ func (T *FieldsPortalSuspended) Write(writer io.Writer) (length int, err error) } type PortalSuspended struct { - fields FieldsPortalSuspended + Fields FieldsPortalSuspended } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1401,13 +1467,13 @@ func (T *PortalSuspended) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *PortalSuspended) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1441,7 +1507,7 @@ func (T *FieldsReadForQuery) Write(writer io.Writer) (length int, err error) { } type ReadForQuery struct { - fields FieldsReadForQuery + Fields FieldsReadForQuery } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1451,13 +1517,13 @@ func (T *ReadForQuery) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *ReadForQuery) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1602,7 +1668,7 @@ func (T *FieldsRowDescription) Write(writer io.Writer) (length int, err error) { } type RowDescription struct { - fields FieldsRowDescription + Fields FieldsRowDescription } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1612,13 +1678,13 @@ func (T *RowDescription) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *RowDescription) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return diff --git a/lib/gat/protocol/frontend.go b/lib/gat/protocol/frontend.go index dad1d9c0bcb0ccc6e43cce27a8467a0d65028e7b..118dcf0d40287e800b6ff7ca3df8586993996606 100644 --- a/lib/gat/protocol/frontend.go +++ b/lib/gat/protocol/frontend.go @@ -169,7 +169,7 @@ func (T *FieldsBind) Write(writer io.Writer) (length int, err error) { } type Bind struct { - fields FieldsBind + Fields FieldsBind } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -179,13 +179,13 @@ func (T *Bind) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Bind) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -249,7 +249,7 @@ func (T *FieldsCancelRequest) Write(writer io.Writer) (length int, err error) { } type CancelRequest struct { - fields FieldsCancelRequest + Fields FieldsCancelRequest } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -259,13 +259,13 @@ func (T *CancelRequest) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CancelRequest) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -319,7 +319,7 @@ func (T *FieldsClose) Write(writer io.Writer) (length int, err error) { } type Close struct { - fields FieldsClose + Fields FieldsClose } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -329,13 +329,13 @@ func (T *Close) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Close) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -379,7 +379,7 @@ func (T *FieldsCopyFail) Write(writer io.Writer) (length int, err error) { } type CopyFail struct { - fields FieldsCopyFail + Fields FieldsCopyFail } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -389,13 +389,13 @@ func (T *CopyFail) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CopyFail) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -449,7 +449,7 @@ func (T *FieldsDescribe) Write(writer io.Writer) (length int, err error) { } type Describe struct { - fields FieldsDescribe + Fields FieldsDescribe } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -459,13 +459,13 @@ func (T *Describe) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Describe) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -519,7 +519,7 @@ func (T *FieldsExecute) Write(writer io.Writer) (length int, err error) { } type Execute struct { - fields FieldsExecute + Fields FieldsExecute } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -529,13 +529,13 @@ func (T *Execute) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Execute) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -569,7 +569,7 @@ func (T *FieldsFlush) Write(writer io.Writer) (length int, err error) { } type Flush struct { - fields FieldsFlush + Fields FieldsFlush } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -579,13 +579,13 @@ func (T *Flush) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Flush) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -736,7 +736,7 @@ func (T *FieldsFunctionCall) Write(writer io.Writer) (length int, err error) { } type FunctionCall struct { - fields FieldsFunctionCall + Fields FieldsFunctionCall } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -746,13 +746,13 @@ func (T *FunctionCall) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *FunctionCall) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -796,7 +796,7 @@ func (T *FieldsGSSENCRequest) Write(writer io.Writer) (length int, err error) { } type GSSENCRequest struct { - fields FieldsGSSENCRequest + Fields FieldsGSSENCRequest } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -806,13 +806,13 @@ func (T *GSSENCRequest) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *GSSENCRequest) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -857,7 +857,7 @@ func (T *FieldsGSSResponse) Write(writer io.Writer) (length int, err error) { } type GSSResponse struct { - fields FieldsGSSResponse + Fields FieldsGSSResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -867,13 +867,13 @@ func (T *GSSResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *GSSResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -955,7 +955,7 @@ func (T *FieldsParse) Write(writer io.Writer) (length int, err error) { } type Parse struct { - fields FieldsParse + Fields FieldsParse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -965,13 +965,13 @@ func (T *Parse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Parse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1015,7 +1015,7 @@ func (T *FieldsPasswordMessage) Write(writer io.Writer) (length int, err error) } type PasswordMessage struct { - fields FieldsPasswordMessage + Fields FieldsPasswordMessage } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1025,13 +1025,13 @@ func (T *PasswordMessage) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *PasswordMessage) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1075,7 +1075,7 @@ func (T *FieldsQuery) Write(writer io.Writer) (length int, err error) { } type Query struct { - fields FieldsQuery + Fields FieldsQuery } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1085,13 +1085,13 @@ func (T *Query) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Query) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1163,7 +1163,7 @@ func (T *FieldsSASLInitialResponse) Write(writer io.Writer) (length int, err err } type SASLInitialResponse struct { - fields FieldsSASLInitialResponse + Fields FieldsSASLInitialResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1173,13 +1173,13 @@ func (T *SASLInitialResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *SASLInitialResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1229,7 +1229,7 @@ func (T *FieldsSASLResponse) Write(writer io.Writer) (length int, err error) { } type SASLResponse struct { - fields FieldsSASLResponse + Fields FieldsSASLResponse } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1239,13 +1239,13 @@ func (T *SASLResponse) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *SASLResponse) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1289,7 +1289,7 @@ func (T *FieldsSSLRequest) Write(writer io.Writer) (length int, err error) { } type SSLRequest struct { - fields FieldsSSLRequest + Fields FieldsSSLRequest } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1299,13 +1299,13 @@ func (T *SSLRequest) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *SSLRequest) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1320,25 +1320,63 @@ func (T *SSLRequest) Write(writer io.Writer) (length int, err error) { return } -type FieldsStartupMessage struct { - ProtocolVersionNumber int32 - ParameterName string - ParameterValue string +type FieldsStartupMessageParameters struct { + Name string + Value string } -func (T *FieldsStartupMessage) Read(payloadLength int, reader io.Reader) (err error) { - T.ProtocolVersionNumber, err = ReadInt32(reader) +func (T *FieldsStartupMessageParameters) Read(payloadLength int, reader io.Reader) (err error) { + T.Name, err = ReadString(reader) if err != nil { return } - T.ParameterName, err = ReadString(reader) + if T.Name != "" { + T.Value, err = ReadString(reader) + if err != nil { + return + } + } + return +} + +func (T *FieldsStartupMessageParameters) Write(writer io.Writer) (length int, err error) { + var temp int + temp, err = WriteString(writer, T.Name) if err != nil { return } - T.ParameterValue, err = ReadString(reader) + length += temp + if T.Name != "" { + temp, err = WriteString(writer, T.Value) + if err != nil { + return + } + length += temp + } + _ = temp + return +} + +type FieldsStartupMessage struct { + ProtocolVersionNumber int32 + Parameters []FieldsStartupMessageParameters +} + +func (T *FieldsStartupMessage) Read(payloadLength int, reader io.Reader) (err error) { + T.ProtocolVersionNumber, err = ReadInt32(reader) if err != nil { return } + var P FieldsStartupMessageParameters + for ok := true; ok; ok = P.Name != "" { + err = P.Read(payloadLength, reader) + if err != nil { + return + } + T.Parameters = append(T.Parameters, P) + var newp FieldsStartupMessageParameters + P = newp + } return } @@ -1349,22 +1387,19 @@ func (T *FieldsStartupMessage) Write(writer io.Writer) (length int, err error) { return } length += temp - temp, err = WriteString(writer, T.ParameterName) - if err != nil { - return - } - length += temp - temp, err = WriteString(writer, T.ParameterValue) - if err != nil { - return + for _, v := range T.Parameters { + temp, err = v.Write(writer) + if err != nil { + return + } + length += temp } - length += temp _ = temp return } type StartupMessage struct { - fields FieldsStartupMessage + Fields FieldsStartupMessage } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1374,13 +1409,13 @@ func (T *StartupMessage) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *StartupMessage) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1409,7 +1444,7 @@ func (T *FieldsSync) Write(writer io.Writer) (length int, err error) { } type Sync struct { - fields FieldsSync + Fields FieldsSync } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1419,13 +1454,13 @@ func (T *Sync) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Sync) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -1459,7 +1494,7 @@ func (T *FieldsTerminate) Write(writer io.Writer) (length int, err error) { } type Terminate struct { - fields FieldsTerminate + Fields FieldsTerminate } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -1469,13 +1504,13 @@ func (T *Terminate) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *Terminate) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return diff --git a/lib/gat/protocol/io.go b/lib/gat/protocol/io.go index 8254750217adf1751f4b254061ad949231036b5d..46d0fd76f316961f38f370f7574c576b3e241e16 100644 --- a/lib/gat/protocol/io.go +++ b/lib/gat/protocol/io.go @@ -3,6 +3,7 @@ package protocol import ( "encoding/binary" "io" + "strings" ) func ReadByte(reader io.Reader) (byte, error) { @@ -50,14 +51,17 @@ func ReadInt64(reader io.Reader) (int64, error) { } func ReadString(reader io.Reader) (string, error) { - // TODO i actually have no idea how they format strings, but i'm guessing it's UTF-8 with an int32 length prefix - length, err := ReadInt32(reader) - if err != nil { - return "", err + var builder strings.Builder + for { + b, err := ReadByte(reader) + if err != nil { + return "", err + } + if b == 0 { + return builder.String(), nil + } + builder.WriteByte(b) } - b := make([]byte, length) - _, err = reader.Read(b[:]) - return string(b), err } func WriteByte(writer io.Writer, value byte) (int, error) { @@ -101,14 +105,13 @@ func WriteInt64(writer io.Writer, value int64) (int, error) { } func WriteString(writer io.Writer, value string) (int, error) { - // TODO i actually have no idea how they format strings, but i'm guessing it's UTF-8 with an int32 length prefix - length, err := WriteInt32(writer, int32(len(value))) + _, err := writer.Write([]byte(value)) if err != nil { return 0, err } - _, err = writer.Write([]byte(value)) + _, err = WriteByte(writer, 0) if err != nil { return 0, err } - return length + len(value), nil + return len(value) + 1, nil } diff --git a/lib/gat/protocol/shared.go b/lib/gat/protocol/shared.go index 1062db83ca0efdc82e14efc3331b886a0aefae6a..a04184e548d247720887a10e3293224f19a7d98d 100644 --- a/lib/gat/protocol/shared.go +++ b/lib/gat/protocol/shared.go @@ -40,7 +40,7 @@ func (T *FieldsCopyData) Write(writer io.Writer) (length int, err error) { } type CopyData struct { - fields FieldsCopyData + Fields FieldsCopyData } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -50,13 +50,13 @@ func (T *CopyData) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CopyData) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return @@ -90,7 +90,7 @@ func (T *FieldsCopyDone) Write(writer io.Writer) (length int, err error) { } type CopyDone struct { - fields FieldsCopyDone + Fields FieldsCopyDone } // Read reads all but the packet identifier. Be sure to read that beforehand (if it exists) @@ -100,13 +100,13 @@ func (T *CopyDone) Read(reader io.Reader) (err error) { if err != nil { return } - return T.fields.Read(int(length-4), reader) + return T.Fields.Read(int(length-4), reader) } func (T *CopyDone) Write(writer io.Writer) (length int, err error) { // TODO replace with pool var buf bytes.Buffer - length, err = T.fields.Write(&buf) + length, err = T.Fields.Write(&buf) if err != nil { length = 0 return diff --git a/spec/protocol/backend.yaml b/spec/protocol/backend.yaml index 8ae0f06aff33cbeabf856a3c34dd02881e095528..467c2c3526fa984bcbbddbdb0bde77be9e5eeb3d 100644 --- a/spec/protocol/backend.yaml +++ b/spec/protocol/backend.yaml @@ -59,10 +59,15 @@ EmptyQueryResponse: ErrorResponse: Identifier: 'E' Fields: - - Name: Code - Type: byte - - Name: Value - Type: string + - Name: Responses + Struct: + Fields: + - Name: Code + Type: byte + - Name: Value + Type: string + If: 'T.Code != 0' + While: 'P.Code != 0' FunctionCallResponse: Identifier: 'V' Fields: @@ -82,11 +87,15 @@ NoData: NoticeResponse: Identifier: 'N' Fields: - - Name: Type - Type: byte - - Name: Value - Type: string - If: 'T.Type != 0' + - Name: Responses + Struct: + Fields: + - Name: Code + Type: byte + - Name: Value + Type: string + If: 'T.Code != 0' + While: 'P.Code != 0' NotificationResponse: Identifier: 'A' Fields: diff --git a/spec/protocol/frontend.yaml b/spec/protocol/frontend.yaml index b89615fd5ea33dd2e3ac93042506419a63f79fe4..70d2761921e5a4182e6faca598727b2bea0521d0 100644 --- a/spec/protocol/frontend.yaml +++ b/spec/protocol/frontend.yaml @@ -124,10 +124,15 @@ StartupMessage: Fields: - Name: ProtocolVersionNumber Type: int32 - - Name: ParameterName - Type: string - - Name: ParameterValue - Type: string + - Name: Parameters + Struct: + Fields: + - Name: Name + Type: string + - Name: Value + Type: string + If: 'T.Name != ""' + While: 'P.Name != ""' Sync: Identifier: 'S' Terminate: