diff --git a/codegen/protocol/main.go b/codegen/protocol/main.go index e1a72dcc82780fd5a0b7a5f69e1947f2bd7a938c..6e8b0db76f453eb33302f6ad44e0a760e978ac30 100644 --- a/codegen/protocol/main.go +++ b/codegen/protocol/main.go @@ -37,7 +37,8 @@ func main() { if err != nil { panic(err) } - all := make(map[string]any) + backend := make(map[string]any) + frontend := make(map[string]any) var out bytes.Buffer for _, e := range f { var b []byte @@ -50,9 +51,20 @@ func main() { if err != nil { panic(err) } - - for k, v := range packets { - all[k] = v + switch e.Name() { + case "backend.yaml": + for k, v := range packets { + backend[k] = v + } + case "frontend.yaml": + for k, v := range packets { + frontend[k] = v + } + default: + for k, v := range packets { + backend[k] = v + frontend[k] = v + } } err = t.Execute(&out, packets) @@ -76,7 +88,10 @@ func main() { } t = template.Must(template.New("mod.tmpl").Funcs(funcs).ParseFiles(filepath.Join(CODEGEN, "mod.tmpl"))) - err = t.Execute(&out, all) + err = t.Execute(&out, map[string]any{ + "BackEnd": backend, + "FrontEnd": frontend, + }) if err != nil { panic(err) } diff --git a/codegen/protocol/mod.tmpl b/codegen/protocol/mod.tmpl index 959d0e8d4cf9bedec53fdfded967777bdf89cdd7..27404df78443121d0faef41dbfc6700ad1a9556a 100644 --- a/codegen/protocol/mod.tmpl +++ b/codegen/protocol/mod.tmpl @@ -7,9 +7,9 @@ type Packet interface { Write(writer io.Writer) (int, error) } -// Read switches on the identifier and returns the matching packet +// ReadFrontend switches on frontend packet identifiers and returns the matching packet // DO NOT call this function if the packet in queue does not have an identifier -func Read(reader io.Reader) (packet Packet, err error) { +func ReadFrontend(reader io.Reader) (packet Packet, err error) { var identifier byte identifier, err = ReadByte(reader) if err != nil { @@ -17,7 +17,33 @@ func Read(reader io.Reader) (packet Packet, err error) { } switch identifier { - {{range $name, $packet := . -}} + {{range $name, $packet := .FrontEnd -}} + {{if $packet.Identifier -}} + case byte('{{$packet.Identifier}}'): + packet = new({{$name}}) + {{end -}} + {{end -}} + } + + err = packet.Read(reader) + if err != nil { + return + } + + return +} + +// ReadBackend switches on backend packet identifier and returns the matching packet +// DO NOT call this function if the packet in queue does not have an identifier +func ReadBackend(reader io.Reader) (packet Packet, err error) { + var identifier byte + identifier, err = ReadByte(reader) + if err != nil { + return + } + + switch identifier { + {{range $name, $packet := .BackEnd -}} {{if $packet.Identifier -}} case byte('{{$packet.Identifier}}'): packet = new({{$name}}) diff --git a/lib/gat/protocol/frontend.go b/lib/gat/protocol/frontend.go index 59b97a21ecdd51a6465207559bacf207ba51cd5f..45cea8682f8539023d9f10b74dbee03d1039a130 100644 --- a/lib/gat/protocol/frontend.go +++ b/lib/gat/protocol/frontend.go @@ -10,6 +10,75 @@ import ( var _ bytes.Buffer var _ io.Reader +type FieldsAuthenticationResponse struct { + Data []byte +} + +func (T *FieldsAuthenticationResponse) Read(payloadLength int, reader io.Reader) (err error) { + DataLength := payloadLength + T.Data = make([]byte, int(DataLength)) + for i := 0; i < int(DataLength); i++ { + T.Data[i], err = ReadByte(reader) + if err != nil { + return + } + } + return +} + +func (T *FieldsAuthenticationResponse) Write(writer io.Writer) (length int, err error) { + var temp int + for _, v := range T.Data { + temp, err = WriteByte(writer, v) + if err != nil { + return + } + length += temp + } + _ = temp + return +} + +type AuthenticationResponse struct { + Fields FieldsAuthenticationResponse +} + +// Read reads all but the packet identifier +// WARNING: This packet DOES have an identifier. Call protocol.Read or trim the identifier first! +func (T *AuthenticationResponse) Read(reader io.Reader) (err error) { + var length int32 + length, err = ReadInt32(reader) + if err != nil { + return + } + return T.Fields.Read(int(length-4), reader) +} + +func (T *AuthenticationResponse) Write(writer io.Writer) (length int, err error) { + // TODO replace with pool + var buf bytes.Buffer + length, err = T.Fields.Write(&buf) + if err != nil { + length = 0 + return + } + _, err = WriteByte(writer, byte('p')) + if err != nil { + length = 1 + return + } + _, err = WriteInt32(writer, int32(length)) + if err != nil { + length = 5 + return + } + length += 5 + _, err = writer.Write(buf.Bytes()) + return +} + +var _ Packet = (*AuthenticationResponse)(nil) + type FieldsBindParameterValues struct { Value []byte } @@ -256,7 +325,6 @@ type CancelRequest struct { } // Read reads all but the packet identifier -// WARNING: This packet DOES have an identifier. Call protocol.Read or trim the identifier first! func (T *CancelRequest) Read(reader io.Reader) (err error) { var length int32 length, err = ReadInt32(reader) @@ -274,11 +342,6 @@ func (T *CancelRequest) Write(writer io.Writer) (length int, err error) { length = 0 return } - _, err = WriteByte(writer, byte('F')) - if err != nil { - length = 1 - return - } _, err = WriteInt32(writer, int32(length)) if err != nil { length = 5 @@ -853,75 +916,6 @@ func (T *GSSENCRequest) Write(writer io.Writer) (length int, err error) { var _ Packet = (*GSSENCRequest)(nil) -type FieldsGSSResponse struct { - Data []byte -} - -func (T *FieldsGSSResponse) Read(payloadLength int, reader io.Reader) (err error) { - DataLength := payloadLength - T.Data = make([]byte, int(DataLength)) - for i := 0; i < int(DataLength); i++ { - T.Data[i], err = ReadByte(reader) - if err != nil { - return - } - } - return -} - -func (T *FieldsGSSResponse) Write(writer io.Writer) (length int, err error) { - var temp int - for _, v := range T.Data { - temp, err = WriteByte(writer, v) - if err != nil { - return - } - length += temp - } - _ = temp - return -} - -type GSSResponse struct { - Fields FieldsGSSResponse -} - -// Read reads all but the packet identifier -// WARNING: This packet DOES have an identifier. Call protocol.Read or trim the identifier first! -func (T *GSSResponse) Read(reader io.Reader) (err error) { - var length int32 - length, err = ReadInt32(reader) - if err != nil { - return - } - return T.Fields.Read(int(length-4), reader) -} - -func (T *GSSResponse) Write(writer io.Writer) (length int, err error) { - // TODO replace with pool - var buf bytes.Buffer - length, err = T.Fields.Write(&buf) - if err != nil { - length = 0 - return - } - _, err = WriteByte(writer, byte('p')) - if err != nil { - length = 1 - return - } - _, err = WriteInt32(writer, int32(length)) - if err != nil { - length = 5 - return - } - length += 5 - _, err = writer.Write(buf.Bytes()) - return -} - -var _ Packet = (*GSSResponse)(nil) - type FieldsParse struct { PreparedStatement string Query string @@ -1023,69 +1017,6 @@ func (T *Parse) Write(writer io.Writer) (length int, err error) { var _ Packet = (*Parse)(nil) -type FieldsPasswordMessage struct { - Password string -} - -func (T *FieldsPasswordMessage) Read(payloadLength int, reader io.Reader) (err error) { - T.Password, err = ReadString(reader) - if err != nil { - return - } - return -} - -func (T *FieldsPasswordMessage) Write(writer io.Writer) (length int, err error) { - var temp int - temp, err = WriteString(writer, T.Password) - if err != nil { - return - } - length += temp - _ = temp - return -} - -type PasswordMessage struct { - Fields FieldsPasswordMessage -} - -// Read reads all but the packet identifier -// WARNING: This packet DOES have an identifier. Call protocol.Read or trim the identifier first! -func (T *PasswordMessage) Read(reader io.Reader) (err error) { - var length int32 - length, err = ReadInt32(reader) - if err != nil { - return - } - return T.Fields.Read(int(length-4), reader) -} - -func (T *PasswordMessage) Write(writer io.Writer) (length int, err error) { - // TODO replace with pool - var buf bytes.Buffer - length, err = T.Fields.Write(&buf) - if err != nil { - length = 0 - return - } - _, err = WriteByte(writer, byte('p')) - if err != nil { - length = 1 - return - } - _, err = WriteInt32(writer, int32(length)) - if err != nil { - length = 5 - return - } - length += 5 - _, err = writer.Write(buf.Bytes()) - return -} - -var _ Packet = (*PasswordMessage)(nil) - type FieldsQuery struct { Query string } @@ -1149,166 +1080,6 @@ func (T *Query) Write(writer io.Writer) (length int, err error) { var _ Packet = (*Query)(nil) -type FieldsSASLInitialResponse struct { - Mechanism string - InitialResponse []byte -} - -func (T *FieldsSASLInitialResponse) Read(payloadLength int, reader io.Reader) (err error) { - T.Mechanism, err = ReadString(reader) - if err != nil { - return - } - var InitialResponseLength int32 - InitialResponseLength, err = ReadInt32(reader) - if err != nil { - return - } - if InitialResponseLength == int32(-1) { - InitialResponseLength = 0 - } - T.InitialResponse = make([]byte, int(InitialResponseLength)) - for i := 0; i < int(InitialResponseLength); i++ { - T.InitialResponse[i], err = ReadByte(reader) - if err != nil { - return - } - } - return -} - -func (T *FieldsSASLInitialResponse) Write(writer io.Writer) (length int, err error) { - var temp int - temp, err = WriteString(writer, T.Mechanism) - if err != nil { - return - } - length += temp - temp, err = WriteInt32(writer, int32(len(T.InitialResponse))) - if err != nil { - return - } - length += temp - for _, v := range T.InitialResponse { - temp, err = WriteByte(writer, v) - if err != nil { - return - } - length += temp - } - _ = temp - return -} - -type SASLInitialResponse struct { - Fields FieldsSASLInitialResponse -} - -// Read reads all but the packet identifier -// WARNING: This packet DOES have an identifier. Call protocol.Read or trim the identifier first! -func (T *SASLInitialResponse) Read(reader io.Reader) (err error) { - var length int32 - length, err = ReadInt32(reader) - if err != nil { - return - } - return T.Fields.Read(int(length-4), reader) -} - -func (T *SASLInitialResponse) Write(writer io.Writer) (length int, err error) { - // TODO replace with pool - var buf bytes.Buffer - length, err = T.Fields.Write(&buf) - if err != nil { - length = 0 - return - } - _, err = WriteByte(writer, byte('p')) - if err != nil { - length = 1 - return - } - _, err = WriteInt32(writer, int32(length)) - if err != nil { - length = 5 - return - } - length += 5 - _, err = writer.Write(buf.Bytes()) - return -} - -var _ Packet = (*SASLInitialResponse)(nil) - -type FieldsSASLResponse struct { - Data []byte -} - -func (T *FieldsSASLResponse) Read(payloadLength int, reader io.Reader) (err error) { - DataLength := payloadLength - T.Data = make([]byte, int(DataLength)) - for i := 0; i < int(DataLength); i++ { - T.Data[i], err = ReadByte(reader) - if err != nil { - return - } - } - return -} - -func (T *FieldsSASLResponse) Write(writer io.Writer) (length int, err error) { - var temp int - for _, v := range T.Data { - temp, err = WriteByte(writer, v) - if err != nil { - return - } - length += temp - } - _ = temp - return -} - -type SASLResponse struct { - Fields FieldsSASLResponse -} - -// Read reads all but the packet identifier -// WARNING: This packet DOES have an identifier. Call protocol.Read or trim the identifier first! -func (T *SASLResponse) Read(reader io.Reader) (err error) { - var length int32 - length, err = ReadInt32(reader) - if err != nil { - return - } - return T.Fields.Read(int(length-4), reader) -} - -func (T *SASLResponse) Write(writer io.Writer) (length int, err error) { - // TODO replace with pool - var buf bytes.Buffer - length, err = T.Fields.Write(&buf) - if err != nil { - length = 0 - return - } - _, err = WriteByte(writer, byte('p')) - if err != nil { - length = 1 - return - } - _, err = WriteInt32(writer, int32(length)) - if err != nil { - length = 5 - return - } - length += 5 - _, err = writer.Write(buf.Bytes()) - return -} - -var _ Packet = (*SASLResponse)(nil) - type FieldsSSLRequest struct { SSLRequestCode int32 } diff --git a/lib/gat/protocol/mod.go b/lib/gat/protocol/mod.go index a62b287d85cb1cfa3c4e64a3ebe0ba634bbad3a6..2a57b158640a5e789649f02ca730fc13f265d785 100644 --- a/lib/gat/protocol/mod.go +++ b/lib/gat/protocol/mod.go @@ -7,9 +7,57 @@ type Packet interface { Write(writer io.Writer) (int, error) } -// Read switches on the identifier and returns the matching packet +// ReadFrontend switches on frontend packet identifiers and returns the matching packet // DO NOT call this function if the packet in queue does not have an identifier -func Read(reader io.Reader) (packet Packet, err error) { +func ReadFrontend(reader io.Reader) (packet Packet, err error) { + var identifier byte + identifier, err = ReadByte(reader) + if err != nil { + return + } + + switch identifier { + case byte('p'): + packet = new(AuthenticationResponse) + case byte('B'): + packet = new(Bind) + case byte('C'): + packet = new(Close) + case byte('d'): + packet = new(CopyData) + case byte('c'): + packet = new(CopyDone) + case byte('f'): + packet = new(CopyFail) + case byte('D'): + packet = new(Describe) + case byte('E'): + packet = new(Execute) + case byte('H'): + packet = new(Flush) + case byte('F'): + packet = new(FunctionCall) + case byte('P'): + packet = new(Parse) + case byte('Q'): + packet = new(Query) + case byte('S'): + packet = new(Sync) + case byte('X'): + packet = new(Terminate) + } + + err = packet.Read(reader) + if err != nil { + return + } + + return +} + +// ReadBackend switches on backend packet identifier and returns the matching packet +// DO NOT call this function if the packet in queue does not have an identifier +func ReadBackend(reader io.Reader) (packet Packet, err error) { var identifier byte identifier, err = ReadByte(reader) if err != nil { @@ -21,14 +69,8 @@ func Read(reader io.Reader) (packet Packet, err error) { packet = new(Authentication) case byte('K'): packet = new(BackendKeyData) - case byte('B'): - packet = new(Bind) case byte('2'): packet = new(BindComplete) - case byte('F'): - packet = new(CancelRequest) - case byte('C'): - packet = new(Close) case byte('3'): packet = new(CloseComplete) case byte('C'): @@ -39,30 +81,18 @@ func Read(reader io.Reader) (packet Packet, err error) { packet = new(CopyData) case byte('c'): packet = new(CopyDone) - case byte('f'): - packet = new(CopyFail) case byte('G'): packet = new(CopyInResponse) case byte('H'): packet = new(CopyOutResponse) case byte('D'): packet = new(DataRow) - case byte('D'): - packet = new(Describe) case byte('I'): packet = new(EmptyQueryResponse) case byte('E'): packet = new(ErrorResponse) - case byte('E'): - packet = new(Execute) - case byte('H'): - packet = new(Flush) - case byte('F'): - packet = new(FunctionCall) case byte('V'): packet = new(FunctionCallResponse) - case byte('p'): - packet = new(GSSResponse) case byte('v'): packet = new(NegotiateProtocolVersion) case byte('n'): @@ -75,28 +105,14 @@ func Read(reader io.Reader) (packet Packet, err error) { packet = new(ParameterDescription) case byte('S'): packet = new(ParameterStatus) - case byte('P'): - packet = new(Parse) case byte('1'): packet = new(ParseComplete) - case byte('p'): - packet = new(PasswordMessage) case byte('s'): packet = new(PortalSuspended) - case byte('Q'): - packet = new(Query) case byte('Z'): packet = new(ReadyForQuery) case byte('T'): packet = new(RowDescription) - case byte('p'): - packet = new(SASLInitialResponse) - case byte('p'): - packet = new(SASLResponse) - case byte('S'): - packet = new(Sync) - case byte('X'): - packet = new(Terminate) } err = packet.Read(reader) diff --git a/spec/protocol/frontend.yaml b/spec/protocol/frontend.yaml index 70d2761921e5a4182e6faca598727b2bea0521d0..6ca2d85f97eb6a1eac07c784723a509e482f75fc 100644 --- a/spec/protocol/frontend.yaml +++ b/spec/protocol/frontend.yaml @@ -19,7 +19,6 @@ Bind: Type: int16 LengthPrefixed: int16 CancelRequest: - Identifier: 'F' Fields: - Name: RequestCode Type: int32 @@ -76,7 +75,7 @@ GSSENCRequest: Fields: - Name: EncryptionRequestCode Type: int32 -GSSResponse: +AuthenticationResponse: Identifier: 'p' Fields: - Name: Data @@ -92,30 +91,11 @@ Parse: - Name: ParameterDataTypes Type: int32 LengthPrefixed: int32 -PasswordMessage: - Identifier: 'p' - Fields: - - Name: Password - Type: string Query: Identifier: 'Q' Fields: - Name: Query Type: string -SASLInitialResponse: - Identifier: 'p' - Fields: - - Name: Mechanism - Type: string - - Name: InitialResponse - Type: byte - LengthPrefixed: int32 -SASLResponse: - Identifier: 'p' - Fields: - - Name: Data - Type: byte - ArrayLength: payloadLength SSLRequest: Fields: - Name: SSLRequestCode