diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 0880d0fa8b8c97f1b80a00198c28c66292b0c9b5..3a0d542cf99a8e5f7361ee3f1c0a4e1611269aec 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -2,8 +2,9 @@ package main import ( "io" + "net" - "pggat2/lib/frontend/frontends/v0" + "pggat2/lib/backend/backends/v0" "pggat2/lib/pnet" "pggat2/lib/pnet/packet" ) @@ -65,56 +66,56 @@ func (T *LogWriter) Write() packet.Out { var _ pnet.Writer = (*LogWriter)(nil) func main() { - frontend, err := frontends.NewListener() - if err != nil { - panic(err) - } - err = frontend.Listen() - if err != nil { - panic(err) - } /* - conn, err := net.Dial("tcp", "localhost:5432") + frontend, err := frontends.NewListener() if err != nil { panic(err) } - server, err := backends.NewServer(conn) + err = frontend.Listen() if err != nil { panic(err) } - readWriter := pnet.JoinedReadWriter{ - Reader: &TestReader{ - packets: []testPacket{ - { - typ: packet.Query, - bytes: []byte("select 1\x00"), - }, - { - typ: packet.Query, - bytes: []byte("set TimeZone = \"America/Denver\"\x00"), - }, - { - typ: packet.Query, - bytes: []byte("reset all\x00"), - }, + */ + conn, err := net.Dial("tcp", "localhost:5432") + if err != nil { + panic(err) + } + server := backends.NewServer(conn) + if server == nil { + panic("failed to connect to server") + } + readWriter := pnet.JoinedReadWriter{ + Reader: &TestReader{ + packets: []testPacket{ + { + typ: packet.Query, + bytes: []byte("select 1\x00"), + }, + { + typ: packet.Query, + bytes: []byte("set TimeZone = \"America/Denver\"\x00"), + }, + { + typ: packet.Query, + bytes: []byte("reset all\x00"), }, }, - Writer: &LogWriter{}, - } - err = server.Transaction(readWriter) - if err != nil { - panic(err) - } - err = server.Transaction(readWriter) - if err != nil { - panic(err) - } - err = server.Transaction(readWriter) - if err != nil { - panic(err) - } - // log.Println(server) - _ = server - _ = conn.Close() - */ + }, + Writer: &LogWriter{}, + } + perr := server.Transaction(readWriter) + if perr != nil { + panic(perr) + } + perr = server.Transaction(readWriter) + if perr != nil { + panic(perr) + } + perr = server.Transaction(readWriter) + if perr != nil { + panic(perr) + } + // log.Println(server) + _ = server + _ = conn.Close() } diff --git a/lib/backend/backends/v0/server.go b/lib/backend/backends/v0/server.go index b25a4f0f424eed97567c8659640d79b3ca889245..4b55aaa1f834fc9ee1d64c602bafefa72c576a2f 100644 --- a/lib/backend/backends/v0/server.go +++ b/lib/backend/backends/v0/server.go @@ -1,20 +1,19 @@ package backends import ( - "errors" + "fmt" "net" "pggat2/lib/auth/md5" "pggat2/lib/auth/sasl" "pggat2/lib/backend" + "pggat2/lib/perror" "pggat2/lib/pnet" "pggat2/lib/pnet/packet" + "pggat2/lib/pnet/packet/packets/v3.0" "pggat2/lib/util/decorator" ) -var ErrBadPacketFormat = errors.New("bad packet format") -var ErrProtocolError = errors.New("server sent unexpected packet") - type Server struct { noCopy decorator.NoCopy @@ -27,7 +26,7 @@ type Server struct { parameters map[string]string } -func NewServer(conn net.Conn) (*Server, error) { +func NewServer(conn net.Conn) *Server { server := &Server{ conn: conn, IOReader: pnet.MakeIOReader(conn), @@ -36,24 +35,25 @@ func NewServer(conn net.Conn) (*Server, error) { } err := server.accept() if err != nil { - return nil, err + panic(fmt.Sprint("failed to connect to server: ", err)) + return nil } - return server, nil + return server } -func (T *Server) authenticationSASLChallenge(mechanism sasl.Client) (bool, error) { +func (T *Server) authenticationSASLChallenge(mechanism sasl.Client) (bool, perror.Error) { in, err := T.Read() if err != nil { - return false, err + return false, perror.Wrap(err) } if in.Type() != packet.Authentication { - return false, ErrProtocolError + return false, pnet.ErrProtocolError } method, ok := in.Int32() if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } switch method { @@ -61,47 +61,39 @@ func (T *Server) authenticationSASLChallenge(mechanism sasl.Client) (bool, error // challenge response, err := mechanism.Continue(in.Remaining()) if err != nil { - return false, err + return false, perror.Wrap(err) } out := T.Write() - out.Type(packet.AuthenticationResponse) - out.Bytes(response) + packets.WriteAuthenticationResponse(out, response) err = out.Send() - return false, err + return false, perror.Wrap(err) case 12: // finish err = mechanism.Final(in.Remaining()) if err != nil { - return false, err + return false, perror.Wrap(err) } return true, nil default: - return false, ErrProtocolError + return false, pnet.ErrProtocolError } } -func (T *Server) authenticationSASL(mechanisms []string, username, password string) error { +func (T *Server) authenticationSASL(mechanisms []string, username, password string) perror.Error { mechanism, err := sasl.NewClient(mechanisms, username, password) if err != nil { - return err + return perror.Wrap(err) } initialResponse := mechanism.InitialResponse() out := T.Write() - out.Type(packet.AuthenticationResponse) - out.String(mechanism.Name()) - if initialResponse == nil { - out.Int32(-1) - } else { - out.Int32(int32(len(initialResponse))) - out.Bytes(initialResponse) - } + packets.WriteSASLInitialResponse(out, mechanism.Name(), initialResponse) err = out.Send() if err != nil { - return err + return perror.Wrap(err) } // challenge loop @@ -118,33 +110,35 @@ func (T *Server) authenticationSASL(mechanisms []string, username, password stri return nil } -func (T *Server) authenticationMD5(salt [4]byte, username, password string) error { +func (T *Server) authenticationMD5(salt [4]byte, username, password string) perror.Error { out := T.Write() - out.Type(packet.AuthenticationResponse) - out.String(md5.Encode(username, password, salt)) - return out.Send() + packets.WritePasswordMessage(out, md5.Encode(username, password, salt)) + return perror.Wrap(out.Send()) } -func (T *Server) authenticationCleartext(password string) error { +func (T *Server) authenticationCleartext(password string) perror.Error { out := T.Write() - out.Type(packet.AuthenticationResponse) - out.String(password) - return out.Send() + packets.WritePasswordMessage(out, password) + return perror.Wrap(out.Send()) } -func (T *Server) startup0(username, password string) (bool, error) { +func (T *Server) startup0(username, password string) (bool, perror.Error) { in, err := T.Read() if err != nil { - return false, err + return false, perror.Wrap(err) } switch in.Type() { case packet.ErrorResponse: - return false, errors.New("received error response") + perr, ok := packets.ReadErrorResponse(in) + if !ok { + return false, pnet.ErrBadPacketFormat + } + return false, perr case packet.Authentication: method, ok := in.Int32() if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } // they have more authentication methods than there are pokemon switch method { @@ -152,90 +146,106 @@ func (T *Server) startup0(username, password string) (bool, error) { // we're good to go, that was easy return true, nil case 2: - return false, errors.New("kerberos v5 is not supported") + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "kerberos v5 is not supported", + ) case 3: return false, T.authenticationCleartext(password) case 5: - var salt [4]byte - ok = in.Bytes(salt[:]) + salt, ok := packets.ReadAuthenticationMD5(in) if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } return false, T.authenticationMD5(salt, username, password) case 6: - return false, errors.New("scm credential is not supported") + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "scm credential is not supported", + ) case 7: - return false, errors.New("gss is not supported") + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "gss is not supported", + ) case 9: - return false, errors.New("sspi is not supported") + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "sspi is not supported", + ) case 10: // read list of mechanisms - var mechanisms []string - for { - mechanism, ok := in.String() - if !ok { - return false, ErrBadPacketFormat - } - if mechanism == "" { - break - } - mechanisms = append(mechanisms, mechanism) + mechanisms, ok := packets.ReadAuthenticationSASL(in) + if !ok { + return false, pnet.ErrBadPacketFormat } return false, T.authenticationSASL(mechanisms, username, password) default: - return false, errors.New("unknown authentication method") + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "unknown authentication method", + ) } case packet.NegotiateProtocolVersion: // we only support protocol 3.0 for now - return false, errors.New("server wanted to negotiate protocol version") + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "server wanted to negotiate protocol version", + ) default: - return false, ErrProtocolError + return false, pnet.ErrProtocolError } } -func (T *Server) parameterStatus(in packet.In) error { - parameter, ok := in.String() +func (T *Server) parameterStatus(in packet.In) perror.Error { + key, value, ok := packets.ReadParameterStatus(in) if !ok { - return ErrBadPacketFormat + return pnet.ErrBadPacketFormat } - value, ok := in.String() - if !ok { - return ErrBadPacketFormat - } - T.parameters[parameter] = value + T.parameters[key] = value return nil } -func (T *Server) startup1() (bool, error) { +func (T *Server) startup1() (bool, perror.Error) { in, err := T.Read() if err != nil { - return false, err + return false, perror.Wrap(err) } switch in.Type() { case packet.BackendKeyData: ok := in.Bytes(T.cancellationKey[:]) if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } return false, nil case packet.ParameterStatus: - err = T.parameterStatus(in) + err := T.parameterStatus(in) return false, err case packet.ReadyForQuery: return true, nil case packet.ErrorResponse: - return false, errors.New("received error response") + err, ok := packets.ReadErrorResponse(in) + if !ok { + return false, pnet.ErrBadPacketFormat + } + return false, err case packet.NoticeResponse: // TODO(garet) do something with notice return false, nil default: - return false, ErrProtocolError + return false, pnet.ErrProtocolError } } -func (T *Server) accept() error { +func (T *Server) accept() perror.Error { // we can re-use the memory for this pkt most of the way down because we don't pass this anywhere out := T.Write() out.Int16(3) @@ -247,7 +257,7 @@ func (T *Server) accept() error { err := out.Send() if err != nil { - return err + return perror.Wrap(err) } for { @@ -276,24 +286,10 @@ func (T *Server) accept() error { return nil } -func (T *Server) proxyIn(in packet.In) error { - out := T.Write() - out.Type(in.Type()) - out.Bytes(in.Full()) - return out.Send() -} - -func (T *Server) proxyOut(peer pnet.Writer, in packet.In) error { - out := peer.Write() - out.Type(in.Type()) - out.Bytes(in.Full()) - return out.Send() -} - -func (T *Server) query0(peer pnet.ReadWriter) (bool, error) { +func (T *Server) query0(peer pnet.ReadWriter) (bool, perror.Error) { in, err := T.Read() if err != nil { - return false, err + return false, perror.Wrap(err) } switch in.Type() { case packet.CommandComplete, @@ -302,29 +298,48 @@ func (T *Server) query0(peer pnet.ReadWriter) (bool, error) { packet.EmptyQueryResponse, packet.ErrorResponse, packet.NoticeResponse: - return false, T.proxyOut(peer, in) + out := peer.Write() + packet.Proxy(out, in) + err := out.Send() + return false, perror.Wrap(err) case packet.CopyInResponse: - return false, errors.New("not implemented") // TODO(garet) + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "not implemented", + ) // TODO(garet) case packet.CopyOutResponse: - return false, errors.New("not implemented") // TODO(garet) + return false, perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "not implemented", + ) // TODO(garet) case packet.ReadyForQuery: - return true, T.proxyOut(peer, in) + out := peer.Write() + packet.Proxy(out, in) + err := out.Send() + return true, perror.Wrap(err) case packet.ParameterStatus: - err = T.parameterStatus(in) + err := T.parameterStatus(in) if err != nil { return false, err } - return false, T.proxyOut(peer, in) + out := peer.Write() + packet.Proxy(out, in) + err = perror.Wrap(out.Send()) + return false, err default: - return false, ErrProtocolError + return false, pnet.ErrProtocolError } } -func (T *Server) query(peer pnet.ReadWriter, in packet.In) error { +func (T *Server) query(peer pnet.ReadWriter, in packet.In) perror.Error { // send in (initial query) to server - err := T.proxyIn(in) + out := T.Write() + packet.Proxy(out, in) + err := out.Send() if err != nil { - return err + return perror.Wrap(err) } for { @@ -340,16 +355,20 @@ func (T *Server) query(peer pnet.ReadWriter, in packet.In) error { } // Transaction handles a transaction from peer, returning when the transaction is complete -func (T *Server) Transaction(peer pnet.ReadWriter) error { +func (T *Server) Transaction(peer pnet.ReadWriter) perror.Error { in, err := peer.Read() if err != nil { - return err + return perror.Wrap(err) } switch in.Type() { case packet.Query: return T.query(peer, in) default: - return errors.New("unsupported operation") + return perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "unsupported operation", + ) } } diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go index d58b378821c144e4966a95f1242ee50fba89945d..f3bacf91cf98680f42f84fdf477f912e1670d0f7 100644 --- a/lib/frontend/frontends/v0/client.go +++ b/lib/frontend/frontends/v0/client.go @@ -11,21 +11,10 @@ import ( "pggat2/lib/perror" "pggat2/lib/pnet" "pggat2/lib/pnet/packet" + packets "pggat2/lib/pnet/packet/packets/v3.0" "pggat2/lib/util/decorator" ) -var ErrBadPacketFormat = perror.New( - perror.FATAL, - perror.ProtocolViolation, - "Bad packet format", -) - -var ErrProtocolError = perror.New( - perror.FATAL, - perror.ProtocolViolation, - "Expected a different packet", -) - type Client struct { noCopy decorator.NoCopy @@ -56,28 +45,19 @@ func NewClient(conn net.Conn) *Client { return client } -func negotiateProtocolVersionPacket(pkt packet.Out, unsupportedOptions []string) { - pkt.Type(packet.NegotiateProtocolVersion) - pkt.Int32(0) - pkt.Int32(int32(len(unsupportedOptions))) - for _, v := range unsupportedOptions { - pkt.String(v) - } -} - func (T *Client) startup0() (bool, perror.Error) { pkt, err := T.ReadUntyped() if err != nil { - return false, perror.WrapError(err) + return false, perror.Wrap(err) } majorVersion, ok := pkt.Uint16() if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } minorVersion, ok := pkt.Uint16() if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } if majorVersion == 1234 { @@ -93,11 +73,11 @@ func (T *Client) startup0() (bool, perror.Error) { case 5679: // SSL is not supported yet err = T.WriteByte('N') - return false, perror.WrapError(err) + return false, perror.Wrap(err) case 5680: // GSSAPI is not supported yet err = T.WriteByte('N') - return false, perror.WrapError(err) + return false, perror.Wrap(err) default: return false, perror.New( perror.FATAL, @@ -120,7 +100,7 @@ func (T *Client) startup0() (bool, perror.Error) { for { key, ok := pkt.String() if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } if key == "" { break @@ -128,7 +108,7 @@ func (T *Client) startup0() (bool, perror.Error) { value, ok := pkt.String() if !ok { - return false, ErrBadPacketFormat + return false, pnet.ErrBadPacketFormat } switch key { @@ -161,11 +141,11 @@ func (T *Client) startup0() (bool, perror.Error) { if minorVersion != 0 || len(unsupportedOptions) > 0 { // negotiate protocol out := T.Write() - negotiateProtocolVersionPacket(out, unsupportedOptions) + packets.WriteNegotiateProtocolVersion(out, 0, unsupportedOptions) err = out.Send() if err != nil { - return false, perror.WrapError(err) + return false, perror.Wrap(err) } } @@ -183,53 +163,25 @@ func (T *Client) startup0() (bool, perror.Error) { return true, nil } -func authenticationSASLPacket(pkt packet.Out) { - pkt.Type(packet.Authentication) - pkt.Int32(10) - for _, mechanism := range sasl.Mechanisms { - pkt.String(mechanism) - } - pkt.String("") -} - -func authenticationSASLContinuePacket(pkt packet.Out, resp []byte) { - pkt.Type(packet.Authentication) - pkt.Int32(11) - pkt.Bytes(resp) -} - -func authenticationSASLFinalPacket(pkt packet.Out, resp []byte) { - pkt.Type(packet.Authentication) - pkt.Int32(12) - pkt.Bytes(resp) -} - func (T *Client) authenticationSASLInitial(username, password string) (sasl.Server, []byte, bool, perror.Error) { // check which authentication method the client wants in, err := T.Read() if err != nil { - return nil, nil, false, perror.WrapError(err) - } - if in.Type() != packet.AuthenticationResponse { - return nil, nil, false, ErrBadPacketFormat + return nil, nil, false, perror.Wrap(err) } - - mechanism, ok := in.String() + mechanism, initialResponse, ok := packets.ReadSASLInitialResponse(in) if !ok { - return nil, nil, false, ErrBadPacketFormat + return nil, nil, false, pnet.ErrBadPacketFormat } + tool, err := sasl.NewServer(mechanism, username, password) if err != nil { - return nil, nil, false, perror.WrapError(err) - } - _, ok = in.Int32() - if !ok { - return nil, nil, false, ErrBadPacketFormat + return nil, nil, false, perror.Wrap(err) } - resp, done, err := tool.InitialResponse(in.Remaining()) + resp, done, err := tool.InitialResponse(initialResponse) if err != nil { - return nil, nil, false, perror.WrapError(err) + return nil, nil, false, perror.Wrap(err) } return tool, resp, done, nil } @@ -237,25 +189,26 @@ func (T *Client) authenticationSASLInitial(username, password string) (sasl.Serv func (T *Client) authenticationSASLContinue(tool sasl.Server) ([]byte, bool, perror.Error) { in, err := T.Read() if err != nil { - return nil, false, perror.WrapError(err) + return nil, false, perror.Wrap(err) } - if in.Type() != packet.AuthenticationResponse { - return nil, false, ErrProtocolError + clientResp, ok := packets.ReadAuthenticationResponse(in) + if !ok { + return nil, false, pnet.ErrProtocolError } - resp, done, err := tool.Continue(in.Full()) + resp, done, err := tool.Continue(clientResp) if err != nil { - return nil, false, perror.WrapError(err) + return nil, false, perror.Wrap(err) } return resp, done, nil } func (T *Client) authenticationSASL(username, password string) perror.Error { out := T.Write() - authenticationSASLPacket(out) + packets.WriteAuthenticationSASL(out, sasl.Mechanisms) err := out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } tool, resp, done, perr := T.authenticationSASLInitial(username, password) @@ -266,18 +219,18 @@ func (T *Client) authenticationSASL(username, password string) perror.Error { } if done { out = T.Write() - authenticationSASLFinalPacket(out, resp) + packets.WriteAuthenticationSASLFinal(out, resp) err = out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } break } else { out = T.Write() - authenticationSASLContinuePacket(out, resp) + packets.WriteAuthenticationSASLContinue(out, resp) err = out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } } @@ -287,32 +240,26 @@ func (T *Client) authenticationSASL(username, password string) perror.Error { return nil } -func authenticationMD5Packet(pkt packet.Out, salt [4]byte) { - pkt.Type(packet.Authentication) - pkt.Uint32(5) - pkt.Bytes(salt[:]) -} - func (T *Client) authenticationMD5(username, password string) perror.Error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } // password time out := T.Write() - authenticationMD5Packet(out, salt) + packets.WriteAuthenticationMD5(out, salt) err = out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } // read password in, err := T.Read() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } if in.Type() != packet.AuthenticationResponse { @@ -325,7 +272,7 @@ func (T *Client) authenticationMD5(username, password string) perror.Error { pw, ok := in.String() if !ok { - return ErrBadPacketFormat + return pnet.ErrBadPacketFormat } if !md5.Check(username, password, salt, pw) { @@ -339,24 +286,19 @@ func (T *Client) authenticationMD5(username, password string) perror.Error { return nil } -func authenticationCleartextPacket(pkt packet.Out) { - pkt.Type(packet.Authentication) - pkt.Uint32(3) -} - func (T *Client) authenticationCleartext(password string) perror.Error { out := T.Write() - authenticationCleartextPacket(out) + packets.WriteAuthenticationCleartext(out) err := out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } // read password in, err := T.Read() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } if in.Type() != packet.AuthenticationResponse { @@ -369,7 +311,7 @@ func (T *Client) authenticationCleartext(password string) perror.Error { pw, ok := in.String() if !ok { - return ErrBadPacketFormat + return pnet.ErrBadPacketFormat } if pw != password { @@ -383,21 +325,6 @@ func (T *Client) authenticationCleartext(password string) perror.Error { return nil } -func authenticationOkPacket(pkt packet.Out) { - pkt.Type(packet.Authentication) - pkt.Uint32(0) -} - -func backendKeyDataPacket(pkt packet.Out, cancellationKey [8]byte) { - pkt.Type(packet.BackendKeyData) - pkt.Bytes(cancellationKey[:]) -} - -func readyForQueryPacket(pkt packet.Out, state byte) { - pkt.Type(packet.ReadyForQuery) - pkt.Uint8(state) -} - func (T *Client) accept() perror.Error { for { done, err := T.startup0() @@ -417,33 +344,33 @@ func (T *Client) accept() perror.Error { // send auth ok out := T.Write() - authenticationOkPacket(out) + packets.WriteAuthenticationOk(out) err := out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } // send backend key data _, err = rand.Read(T.cancellationKey[:]) if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } out = T.Write() - backendKeyDataPacket(out, T.cancellationKey) + packets.WriteBackendKeyData(out, T.cancellationKey) err = out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } // send ready for query out = T.Write() - readyForQueryPacket(out, 'I') + packets.WriteReadyForQuery(out, 'I') err = out.Send() if err != nil { - return perror.WrapError(err) + return perror.Wrap(err) } return nil @@ -452,24 +379,7 @@ func (T *Client) accept() perror.Error { func (T *Client) Close(err perror.Error) { if err != nil { out := T.Write() - out.Type(packet.ErrorResponse) - - out.Uint8('S') - out.String(string(err.Severity())) - - out.Uint8('C') - out.String(string(err.Code())) - - out.Uint8('M') - out.String(err.Message()) - - for _, field := range err.Extra() { - out.Uint8(uint8(field.Type)) - out.String(field.Value) - } - - out.Uint8(0) - + packets.WriteErrorResponse(out, err) _ = out.Send() } _ = T.conn.Close() diff --git a/lib/perror/wrap.go b/lib/perror/wrap.go index 53479d92e6e9ce6af032a2100cdba1c7ee087b9f..7a652d7b50c569f42fe02571ef8f4215a58c48ad 100644 --- a/lib/perror/wrap.go +++ b/lib/perror/wrap.go @@ -1,6 +1,6 @@ package perror -func WrapError(err error) Error { +func Wrap(err error) Error { if err == nil { return nil } diff --git a/lib/pnet/errors.go b/lib/pnet/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..d094bd7ffaf36afa9648908e0f7203184e6fa12b --- /dev/null +++ b/lib/pnet/errors.go @@ -0,0 +1,15 @@ +package pnet + +import "pggat2/lib/perror" + +var ErrBadPacketFormat = perror.New( + perror.FATAL, + perror.ProtocolViolation, + "Bad packet format", +) + +var ErrProtocolError = perror.New( + perror.FATAL, + perror.ProtocolViolation, + "Unexpected packet", +) diff --git a/lib/pnet/packet/in.go b/lib/pnet/packet/in.go index 491933ba6cee2697427802eb4779215af888c0cd..bc98448c31b2a95a51f23bd7ecc601f3828efc84 100644 --- a/lib/pnet/packet/in.go +++ b/lib/pnet/packet/in.go @@ -66,6 +66,13 @@ func (T In) Remaining() []byte { return full[T.buf.pos:] } +func (T In) Reset() { + if T.done() { + panic("Read after Send") + } + T.buf.pos = 0 +} + func (T In) Int8() (int8, bool) { v, ok := T.Uint8() return int8(v), ok diff --git a/lib/pnet/packet/out.go b/lib/pnet/packet/out.go index baff707be28c4e3bee27e343f97f7ca69aaf5e2d..b014ad117d4546ecdf282840eb9e3649307ba4ae 100644 --- a/lib/pnet/packet/out.go +++ b/lib/pnet/packet/out.go @@ -54,6 +54,13 @@ func (T Out) Type(typ Type) { T.buf.typ = typ } +func (T Out) Reset() { + if T.done() { + panic("Write after Send") + } + T.buf.buf = T.buf.buf[:0] +} + func (T Out) Int8(v int8) { T.Uint8(uint8(v)) } diff --git a/lib/pnet/packet/packets/v3.0/authenticationcleartext.go b/lib/pnet/packet/packets/v3.0/authenticationcleartext.go new file mode 100644 index 0000000000000000000000000000000000000000..2464c5c60872f0309be025fb34e117de1f685c87 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationcleartext.go @@ -0,0 +1,24 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationCleartext(in packet.In) bool { + in.Reset() + if in.Type() != packet.Authentication { + return false + } + method, ok := in.Int32() + if !ok { + return false + } + if method != 3 { + return false + } + return true +} + +func WriteAuthenticationCleartext(out packet.Out) { + out.Reset() + out.Type(packet.Authentication) + out.Int32(3) +} diff --git a/lib/pnet/packet/packets/v3.0/authenticationmd5.go b/lib/pnet/packet/packets/v3.0/authenticationmd5.go new file mode 100644 index 0000000000000000000000000000000000000000..bde2d1968d22ed9ce9aa108bb04dc641801f8a89 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationmd5.go @@ -0,0 +1,30 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationMD5(in packet.In) ([4]byte, bool) { + in.Reset() + if in.Type() != packet.Authentication { + return [4]byte{}, false + } + method, ok := in.Int32() + if !ok { + return [4]byte{}, false + } + if method != 5 { + return [4]byte{}, false + } + var salt [4]byte + ok = in.Bytes(salt[:]) + if !ok { + return salt, false + } + return salt, true +} + +func WriteAuthenticationMD5(out packet.Out, salt [4]byte) { + out.Reset() + out.Type(packet.Authentication) + out.Uint32(5) + out.Bytes(salt[:]) +} diff --git a/lib/pnet/packet/packets/v3.0/authenticationok.go b/lib/pnet/packet/packets/v3.0/authenticationok.go new file mode 100644 index 0000000000000000000000000000000000000000..26d6f8261c6196310a855c544fdb1fbe71d216fd --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationok.go @@ -0,0 +1,24 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationOk(in packet.In) bool { + in.Reset() + if in.Type() != packet.Authentication { + return false + } + method, ok := in.Int32() + if !ok { + return false + } + if method != 0 { + return false + } + return true +} + +func WriteAuthenticationOk(out packet.Out) { + out.Reset() + out.Type(packet.Authentication) + out.Int32(0) +} diff --git a/lib/pnet/packet/packets/v3.0/authenticationresponse.go b/lib/pnet/packet/packets/v3.0/authenticationresponse.go new file mode 100644 index 0000000000000000000000000000000000000000..221cf19b7496ed8bf083a70e1f751e3dd14758a2 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationresponse.go @@ -0,0 +1,17 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationResponse(in packet.In) ([]byte, bool) { + in.Reset() + if in.Type() != packet.AuthenticationResponse { + return nil, false + } + return in.Full(), true +} + +func WriteAuthenticationResponse(out packet.Out, resp []byte) { + out.Reset() + out.Type(packet.AuthenticationResponse) + out.Bytes(resp) +} diff --git a/lib/pnet/packet/packets/v3.0/authenticationsasl.go b/lib/pnet/packet/packets/v3.0/authenticationsasl.go new file mode 100644 index 0000000000000000000000000000000000000000..7253194f7a5af9b4b75a4e4090a274d194ebb297 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationsasl.go @@ -0,0 +1,43 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationSASL(in packet.In) ([]string, bool) { + in.Reset() + if in.Type() != packet.Authentication { + return nil, false + } + + method, ok := in.Int32() + if !ok { + return nil, false + } + + if method != 10 { + return nil, false + } + + var mechanisms []string + for { + mechanism, ok := in.String() + if !ok { + return nil, false + } + if mechanism == "" { + break + } + mechanisms = append(mechanisms, mechanism) + } + + return mechanisms, true +} + +func WriteAuthenticationSASL(out packet.Out, mechanisms []string) { + out.Reset() + out.Type(packet.Authentication) + out.Int32(10) + for _, mechanism := range mechanisms { + out.String(mechanism) + } + out.Uint8(0) +} diff --git a/lib/pnet/packet/packets/v3.0/authenticationsaslcontinue.go b/lib/pnet/packet/packets/v3.0/authenticationsaslcontinue.go new file mode 100644 index 0000000000000000000000000000000000000000..69691e9c1a938188f5a0738ab1474eedd7201642 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationsaslcontinue.go @@ -0,0 +1,25 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationSASLContinue(in packet.In) ([]byte, bool) { + in.Reset() + if in.Type() != packet.Authentication { + return nil, false + } + method, ok := in.Int32() + if !ok { + return nil, false + } + if method != 11 { + return nil, false + } + return in.Full(), true +} + +func WriteAuthenticationSASLContinue(out packet.Out, resp []byte) { + out.Reset() + out.Type(packet.Authentication) + out.Int32(11) + out.Bytes(resp) +} diff --git a/lib/pnet/packet/packets/v3.0/authenticationsaslfinal.go b/lib/pnet/packet/packets/v3.0/authenticationsaslfinal.go new file mode 100644 index 0000000000000000000000000000000000000000..2a7d3594fe3e85dab9cd974090b48f71e47991f1 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/authenticationsaslfinal.go @@ -0,0 +1,25 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadAuthenticationSASLFinal(in packet.In) ([]byte, bool) { + in.Reset() + if in.Type() != packet.Authentication { + return nil, false + } + method, ok := in.Int32() + if !ok { + return nil, false + } + if method != 12 { + return nil, false + } + return in.Full(), true +} + +func WriteAuthenticationSASLFinal(out packet.Out, resp []byte) { + out.Reset() + out.Type(packet.Authentication) + out.Int32(12) + out.Bytes(resp) +} diff --git a/lib/pnet/packet/packets/v3.0/backendkeydata.go b/lib/pnet/packet/packets/v3.0/backendkeydata.go new file mode 100644 index 0000000000000000000000000000000000000000..c41ebd139031b08875840bf4e7efc65fca044f3a --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/backendkeydata.go @@ -0,0 +1,22 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadBackendKeyData(in packet.In) ([8]byte, bool) { + in.Reset() + if in.Type() != packet.BackendKeyData { + return [8]byte{}, false + } + var cancellationKey [8]byte + ok := in.Bytes(cancellationKey[:]) + if !ok { + return cancellationKey, false + } + return cancellationKey, true +} + +func WriteBackendKeyData(out packet.Out, cancellationKey [8]byte) { + out.Reset() + out.Type(packet.BackendKeyData) + out.Bytes(cancellationKey[:]) +} diff --git a/lib/pnet/packet/packets/v3.0/errorresponse.go b/lib/pnet/packet/packets/v3.0/errorresponse.go new file mode 100644 index 0000000000000000000000000000000000000000..60ee0ca58507ddf24c975981b842848a074bf3d9 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/errorresponse.go @@ -0,0 +1,76 @@ +package packets + +import ( + "pggat2/lib/perror" + "pggat2/lib/pnet/packet" +) + +func ReadErrorResponse(in packet.In) (perror.Error, bool) { + in.Reset() + if in.Type() != packet.ErrorResponse { + return nil, false + } + + var severity perror.Severity + var code perror.Code + var message string + var extra []perror.ExtraField + + for { + typ, ok := in.Uint8() + if !ok { + return nil, false + } + + if typ == 0 { + break + } + + value, ok := in.String() + if !ok { + return nil, false + } + + switch typ { + case 'S': + severity = perror.Severity(value) + case 'C': + code = perror.Code(value) + case 'M': + message = value + default: + extra = append(extra, perror.ExtraField{ + Type: perror.Extra(typ), + Value: value, + }) + } + } + + return perror.New( + severity, + code, + message, + extra..., + ), true +} + +func WriteErrorResponse(out packet.Out, err perror.Error) { + out.Reset() + out.Type(packet.ErrorResponse) + + out.Uint8('S') + out.String(string(err.Severity())) + + out.Uint8('C') + out.String(string(err.Code())) + + out.Uint8('M') + out.String(err.Message()) + + for _, field := range err.Extra() { + out.Uint8(uint8(field.Type)) + out.String(field.Value) + } + + out.Uint8(0) +} diff --git a/lib/pnet/packet/packets/v3.0/negotiateprotocolversion.go b/lib/pnet/packet/packets/v3.0/negotiateprotocolversion.go new file mode 100644 index 0000000000000000000000000000000000000000..4ba56adc24013796b2f44f2346d50f971913c2b8 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/negotiateprotocolversion.go @@ -0,0 +1,40 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadNegotiateProtocolVersion(in packet.In) (minorProtocolVersion int32, unrecognizedOptions []string, ok bool) { + in.Reset() + if in.Type() != packet.NegotiateProtocolVersion { + return + } + minorProtocolVersion, ok = in.Int32() + if !ok { + return + } + var numUnrecognizedOptions int32 + numUnrecognizedOptions, ok = in.Int32() + if !ok { + return + } + unrecognizedOptions = make([]string, 0, numUnrecognizedOptions) + for i := 0; i < int(numUnrecognizedOptions); i++ { + var unrecognizedOption string + unrecognizedOption, ok = in.String() + if !ok { + return + } + unrecognizedOptions = append(unrecognizedOptions, unrecognizedOption) + } + ok = true + return +} + +func WriteNegotiateProtocolVersion(out packet.Out, minorProtocolVersion int32, unrecognizedOptions []string) { + out.Reset() + out.Type(packet.NegotiateProtocolVersion) + out.Int32(minorProtocolVersion) + out.Int32(int32(len(unrecognizedOptions))) + for _, option := range unrecognizedOptions { + out.String(option) + } +} diff --git a/lib/pnet/packet/packets/v3.0/parameterstatus.go b/lib/pnet/packet/packets/v3.0/parameterstatus.go new file mode 100644 index 0000000000000000000000000000000000000000..bac3411b64acd2664128ce47a409aa6067827d0b --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/parameterstatus.go @@ -0,0 +1,26 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadParameterStatus(in packet.In) (key, value string, ok bool) { + in.Reset() + if in.Type() != packet.ParameterStatus { + return + } + key, ok = in.String() + if !ok { + return + } + value, ok = in.String() + if !ok { + return + } + return +} + +func WriteParameterStatus(out packet.Out, key, value string) { + out.Reset() + out.Type(packet.ParameterStatus) + out.String(key) + out.String(value) +} diff --git a/lib/pnet/packet/packets/v3.0/passwordmessage.go b/lib/pnet/packet/packets/v3.0/passwordmessage.go new file mode 100644 index 0000000000000000000000000000000000000000..ac01b9ad5a439eb6d48081363d8da57c67e83816 --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/passwordmessage.go @@ -0,0 +1,21 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadPasswordMessage(in packet.In) (string, bool) { + in.Reset() + if in.Type() != packet.AuthenticationResponse { + return "", false + } + password, ok := in.String() + if !ok { + return "", false + } + return password, true +} + +func WritePasswordMessage(out packet.Out, password string) { + out.Reset() + out.Type(packet.AuthenticationResponse) + out.String(password) +} diff --git a/lib/pnet/packet/packets/v3.0/readyforquery.go b/lib/pnet/packet/packets/v3.0/readyforquery.go new file mode 100644 index 0000000000000000000000000000000000000000..4035428bc7b13259041ef73bcfdaff3e3fbff3ba --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/readyforquery.go @@ -0,0 +1,21 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadReadyForQuery(in packet.In) (byte, bool) { + in.Reset() + if in.Type() != packet.ReadyForQuery { + return 0, false + } + state, ok := in.Uint8() + if !ok { + return 0, false + } + return state, true +} + +func WriteReadyForQuery(out packet.Out, state uint8) { + out.Reset() + out.Type(packet.ReadyForQuery) + out.Uint8(state) +} diff --git a/lib/pnet/packet/packets/v3.0/saslinitialresponse.go b/lib/pnet/packet/packets/v3.0/saslinitialresponse.go new file mode 100644 index 0000000000000000000000000000000000000000..68815d52de4aeb563c1ee733d267eb6ce5616b3e --- /dev/null +++ b/lib/pnet/packet/packets/v3.0/saslinitialresponse.go @@ -0,0 +1,43 @@ +package packets + +import "pggat2/lib/pnet/packet" + +func ReadSASLInitialResponse(in packet.In) (mechanism string, initialResponse []byte, ok bool) { + in.Reset() + if in.Type() != packet.AuthenticationResponse { + return + } + + mechanism, ok = in.String() + if !ok { + return + } + + var initialResponseSize int32 + initialResponseSize, ok = in.Int32() + if !ok { + return + } + if initialResponseSize == -1 { + return + } + + initialResponse = make([]byte, int(initialResponseSize)) + ok = in.Bytes(initialResponse[:]) + if !ok { + return + } + return +} + +func WriteSASLInitialResponse(out packet.Out, mechanism string, initialResponse []byte) { + out.Reset() + out.Type(packet.AuthenticationResponse) + out.String(mechanism) + if initialResponse == nil { + out.Int32(-1) + } else { + out.Int32(int32(len(initialResponse))) + out.Bytes(initialResponse) + } +} diff --git a/lib/pnet/packet/proxy.go b/lib/pnet/packet/proxy.go new file mode 100644 index 0000000000000000000000000000000000000000..706b9163a5726ba942c2b2e0f2bd51d9ee058cdf --- /dev/null +++ b/lib/pnet/packet/proxy.go @@ -0,0 +1,6 @@ +package packet + +func Proxy(out Out, in In) { + out.Type(in.Type()) + out.Bytes(in.Full()) +}