diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index b5f166da30d2ed90e1c26fef50d97a7a23f181c3..29ba4628a5d55eeb1caaed4db3f8f0f0dd01e6a5 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -5,7 +5,7 @@ import ( "net/http" _ "net/http/pprof" - "pggat2/lib/backend/backends/v0" + "pggat2/lib/bouncer/backends/v0" "pggat2/lib/bouncer/bouncers/v0" "pggat2/lib/bouncer/frontends/v0" "pggat2/lib/middleware/middlewares/unread" @@ -25,15 +25,12 @@ func testServer(r rob.Scheduler) { if err != nil { panic(err) } - server := backends.NewServer(conn) - if server == nil { - panic("failed to connect to server") - } - + server := pnet.MakeIOReadWriter(conn) + backends.Accept(&server) sink := r.NewSink(0) for { j := sink.Read().(job) - bouncers.Bounce(j.client, server) + bouncers.Bounce(j.client, &server) select { case j.done <- struct{}{}: default: diff --git a/lib/backend/backends/v0/server.go b/lib/backend/backends/v0/server.go deleted file mode 100644 index dc979402b053b1e8be86922269ce0e8325b7664a..0000000000000000000000000000000000000000 --- a/lib/backend/backends/v0/server.go +++ /dev/null @@ -1,300 +0,0 @@ -package backends - -import ( - "fmt" - "net" - - "pggat2/lib/auth/md5" - "pggat2/lib/auth/sasl" - "pggat2/lib/backend" - "pggat2/lib/perror" - "pggat2/lib/pnet" - "pggat2/lib/pnet/packet" - "pggat2/lib/pnet/packet/packets/v3.0" -) - -type Server struct { - conn net.Conn - - reader pnet.IOReader - writer pnet.IOWriter - - cancellationKey [8]byte -} - -func NewServer(conn net.Conn) *Server { - server := &Server{ - conn: conn, - reader: pnet.MakeIOReader(conn), - writer: pnet.MakeIOWriter(conn), - } - err := server.accept() - if err != nil { - panic(fmt.Sprint("failed to connect to server: ", err)) - return nil - } - return server -} - -func (T *Server) authenticationSASLChallenge(mechanism sasl.Client) (bool, perror.Error) { - in, err := T.Read() - if err != nil { - return false, perror.Wrap(err) - } - - if in.Type() != packet.Authentication { - return false, pnet.ErrProtocolError - } - - method, ok := in.Int32() - if !ok { - return false, pnet.ErrBadPacketFormat - } - - switch method { - case 11: - // challenge - response, err := mechanism.Continue(in.Remaining()) - if err != nil { - return false, perror.Wrap(err) - } - - out := T.Write() - packets.WriteAuthenticationResponse(out, response) - - err = out.Send() - return false, perror.Wrap(err) - case 12: - // finish - err = mechanism.Final(in.Remaining()) - if err != nil { - return false, perror.Wrap(err) - } - - return true, nil - default: - return false, pnet.ErrProtocolError - } -} - -func (T *Server) authenticationSASL(mechanisms []string, username, password string) perror.Error { - mechanism, err := sasl.NewClient(mechanisms, username, password) - if err != nil { - return perror.Wrap(err) - } - initialResponse := mechanism.InitialResponse() - - out := T.Write() - packets.WriteSASLInitialResponse(out, mechanism.Name(), initialResponse) - err = out.Send() - if err != nil { - return perror.Wrap(err) - } - - // challenge loop - for { - done, err := T.authenticationSASLChallenge(mechanism) - if err != nil { - return err - } - if done { - break - } - } - - return nil -} - -func (T *Server) authenticationMD5(salt [4]byte, username, password string) perror.Error { - out := T.Write() - packets.WritePasswordMessage(out, md5.Encode(username, password, salt)) - return perror.Wrap(out.Send()) -} - -func (T *Server) authenticationCleartext(password string) perror.Error { - out := T.Write() - packets.WritePasswordMessage(out, password) - return perror.Wrap(out.Send()) -} - -func (T *Server) startup0(username, password string) (bool, perror.Error) { - in, err := T.Read() - if err != nil { - return false, perror.Wrap(err) - } - - switch in.Type() { - case packet.ErrorResponse: - perr, ok := packets.ReadErrorResponse(in) - if !ok { - return false, pnet.ErrBadPacketFormat - } - return false, perr - case packet.Authentication: - method, ok := in.Int32() - if !ok { - return false, pnet.ErrBadPacketFormat - } - // they have more authentication methods than there are pokemon - switch method { - case 0: - // we're good to go, that was easy - return true, nil - case 2: - return false, perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "kerberos v5 is not supported", - ) - case 3: - return false, T.authenticationCleartext(password) - case 5: - salt, ok := packets.ReadAuthenticationMD5(in) - if !ok { - return false, pnet.ErrBadPacketFormat - } - return false, T.authenticationMD5(salt, username, password) - case 6: - return false, perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "scm credential is not supported", - ) - case 7: - return false, perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "gss is not supported", - ) - case 9: - return false, perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "sspi is not supported", - ) - case 10: - // read list of mechanisms - mechanisms, ok := packets.ReadAuthenticationSASL(in) - if !ok { - return false, pnet.ErrBadPacketFormat - } - - return false, T.authenticationSASL(mechanisms, username, password) - default: - return false, perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "unknown authentication method", - ) - } - case packet.NegotiateProtocolVersion: - // we only support protocol 3.0 for now - return false, perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "server wanted to negotiate protocol version", - ) - default: - return false, pnet.ErrProtocolError - } -} - -func (T *Server) parameterStatus(in packet.In) perror.Error { - // TODO(garet) do something with parameters - return nil -} - -func (T *Server) startup1() (bool, perror.Error) { - in, err := T.Read() - if err != nil { - return false, perror.Wrap(err) - } - - switch in.Type() { - case packet.BackendKeyData: - ok := in.Bytes(T.cancellationKey[:]) - if !ok { - return false, pnet.ErrBadPacketFormat - } - return false, nil - case packet.ParameterStatus: - err := T.parameterStatus(in) - return false, err - case packet.ReadyForQuery: - return true, nil - case packet.ErrorResponse: - err, ok := packets.ReadErrorResponse(in) - if !ok { - return false, pnet.ErrBadPacketFormat - } - return false, err - case packet.NoticeResponse: - // TODO(garet) do something with notice - return false, nil - default: - return false, pnet.ErrProtocolError - } -} - -func (T *Server) accept() perror.Error { - // we can re-use the memory for this pkt most of the way down because we don't pass this anywhere - out := T.Write() - out.Int16(3) - out.Int16(0) - // TODO(garet) don't hardcode username and password - out.String("user") - out.String("postgres") - out.String("") - - err := out.Send() - if err != nil { - return perror.Wrap(err) - } - - for { - // TODO(garet) don't hardcode username and password - done, err := T.startup0("postgres", "password") - if err != nil { - return err - } - if done { - break - } - } - - for { - done, err := T.startup1() - if err != nil { - return err - } - if done { - break - } - } - - // startup complete, connection is ready for queries - - return nil -} - -func (T *Server) Write() packet.Out { - return T.writer.Write() -} - -func (T *Server) WriteByte(b byte) error { - return T.writer.WriteByte(b) -} - -func (T *Server) Send(typ packet.Type, payload []byte) error { - return T.writer.Send(typ, payload) -} - -func (T *Server) Read() (packet.In, error) { - return T.reader.Read() -} - -func (T *Server) ReadUntyped() (packet.In, error) { - return T.reader.ReadUntyped() -} - -var _ backend.Server = (*Server)(nil) diff --git a/lib/backend/server.go b/lib/backend/server.go deleted file mode 100644 index 5656030dfbc8f51bed082f89098af0c882aba0cc..0000000000000000000000000000000000000000 --- a/lib/backend/server.go +++ /dev/null @@ -1,9 +0,0 @@ -package backend - -import ( - "pggat2/lib/pnet" -) - -type Server interface { - pnet.ReadWriteSender -} diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go new file mode 100644 index 0000000000000000000000000000000000000000..1f5dbf7b3d6d785ab14ff9d00cd17e2d1ac2cae7 --- /dev/null +++ b/lib/bouncer/backends/v0/accept.go @@ -0,0 +1,284 @@ +package backends + +import ( + "errors" + + "pggat2/lib/auth/md5" + "pggat2/lib/auth/sasl" + "pggat2/lib/perror" + "pggat2/lib/pnet" + "pggat2/lib/pnet/packet" + packets "pggat2/lib/pnet/packet/packets/v3.0" +) + +type Status int + +const ( + Fail Status = iota + Ok +) + +var ( + ErrProtocolError = errors.New("protocol error") + ErrBadPacket = errors.New("bad packet") +) + +func fail(server pnet.ReadWriteSender, err error) { + panic(err) +} + +func failpg(server pnet.ReadWriteSender, err perror.Error) { + panic(err) +} + +func authenticationSASLChallenge(server pnet.ReadWriteSender, mechanism sasl.Client) (done bool, status Status) { + in, err := server.Read() + if err != nil { + fail(server, err) + return false, Fail + } + + if in.Type() != packet.Authentication { + fail(server, ErrProtocolError) + return false, Fail + } + + method, ok := in.Int32() + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + + switch method { + case 11: + // challenge + response, err := mechanism.Continue(in.Remaining()) + if err != nil { + fail(server, err) + return false, Fail + } + + out := server.Write() + packets.WriteAuthenticationResponse(out, response) + + err = out.Send() + if err != nil { + fail(server, err) + return false, Fail + } + return false, Ok + case 12: + // finish + err = mechanism.Final(in.Remaining()) + if err != nil { + fail(server, err) + return false, Fail + } + + return true, Ok + default: + fail(server, ErrProtocolError) + return false, Fail + } +} + +func authenticationSASL(server pnet.ReadWriteSender, mechanisms []string, username, password string) Status { + mechanism, err := sasl.NewClient(mechanisms, username, password) + if err != nil { + fail(server, err) + return Fail + } + initialResponse := mechanism.InitialResponse() + + out := server.Write() + packets.WriteSASLInitialResponse(out, mechanism.Name(), initialResponse) + err = out.Send() + if err != nil { + fail(server, err) + return Fail + } + + // challenge loop + for { + done, status := authenticationSASLChallenge(server, mechanism) + if status != Ok { + return status + } + if done { + break + } + } + + return Ok +} + +func authenticationMD5(server pnet.ReadWriteSender, salt [4]byte, username, password string) Status { + out := server.Write() + packets.WritePasswordMessage(out, md5.Encode(username, password, salt)) + err := out.Send() + if err != nil { + fail(server, err) + return Fail + } + return Ok +} + +func authenticationCleartext(server pnet.ReadWriteSender, password string) Status { + out := server.Write() + packets.WritePasswordMessage(out, password) + err := out.Send() + if err != nil { + fail(server, err) + return Fail + } + return Ok +} + +func startup0(server pnet.ReadWriteSender, username, password string) (done bool, status Status) { + in, err := server.Read() + if err != nil { + fail(server, err) + return false, Fail + } + + switch in.Type() { + case packet.ErrorResponse: + perr, ok := packets.ReadErrorResponse(in) + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + failpg(server, perr) + return false, Fail + case packet.Authentication: + method, ok := in.Int32() + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + // they have more authentication methods than there are pokemon + switch method { + case 0: + // we're good to go, that was easy + return true, Ok + case 2: + fail(server, errors.New("kerberos v5 is not supported")) + return false, Fail + case 3: + return false, authenticationCleartext(server, password) + case 5: + salt, ok := packets.ReadAuthenticationMD5(in) + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + return false, authenticationMD5(server, salt, username, password) + case 6: + fail(server, errors.New("scm credential is not supported")) + return false, Fail + case 7: + fail(server, errors.New("gss is not supported")) + return false, Fail + case 9: + fail(server, errors.New("sspi is not supported")) + return false, Fail + case 10: + // read list of mechanisms + mechanisms, ok := packets.ReadAuthenticationSASL(in) + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + + return false, authenticationSASL(server, mechanisms, username, password) + default: + fail(server, errors.New("unknown authentication method")) + return false, Fail + } + case packet.NegotiateProtocolVersion: + // we only support protocol 3.0 for now + fail(server, errors.New("server wanted to negotiate protocol version")) + return false, Fail + default: + fail(server, ErrProtocolError) + return false, Fail + } +} + +func startup1(server pnet.ReadWriteSender) (done bool, status Status) { + in, err := server.Read() + if err != nil { + fail(server, err) + return false, Fail + } + + switch in.Type() { + case packet.BackendKeyData: + var cancellationKey [8]byte + ok := in.Bytes(cancellationKey[:]) + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + // TODO(garet) put cancellation key somewhere + return false, Ok + case packet.ParameterStatus: + return false, Ok + case packet.ReadyForQuery: + return true, Ok + case packet.ErrorResponse: + perr, ok := packets.ReadErrorResponse(in) + if !ok { + fail(server, ErrBadPacket) + return false, Fail + } + failpg(server, perr) + return false, Fail + case packet.NoticeResponse: + // TODO(garet) do something with notice + return false, Ok + default: + fail(server, ErrProtocolError) + return false, Fail + } +} + +func Accept(server pnet.ReadWriteSender) { + // we can re-use the memory for this pkt most of the way down because we don't pass this anywhere + out := server.Write() + out.Int16(3) + out.Int16(0) + // TODO(garet) don't hardcode username and password + out.String("user") + out.String("postgres") + out.String("") + + err := out.Send() + if err != nil { + fail(server, err) + return + } + + for { + // TODO(garet) don't hardcode username and password + done, status := startup0(server, "postgres", "password") + if status != Ok { + return + } + if done { + break + } + } + + for { + done, status := startup1(server) + if status != Ok { + return + } + if done { + break + } + } + + // startup complete, connection is ready for queries +} diff --git a/lib/middleware/middlewares/unread/unread.go b/lib/middleware/middlewares/unread/unread.go index 5624b3fa148956b937027fafcd7982ee0992cf66..0b306998c2b30f22b169dbac8e5ad40a317aa797 100644 --- a/lib/middleware/middlewares/unread/unread.go +++ b/lib/middleware/middlewares/unread/unread.go @@ -6,9 +6,9 @@ import ( ) type Unread struct { - in packet.In - read bool - inner pnet.ReadWriteSender + in packet.In + read bool + pnet.ReadWriteSender } func NewUnread(inner pnet.ReadWriteSender) (*Unread, error) { @@ -17,8 +17,8 @@ func NewUnread(inner pnet.ReadWriteSender) (*Unread, error) { return nil, err } return &Unread{ - in: in, - inner: inner, + in: in, + ReadWriteSender: inner, }, nil } @@ -28,8 +28,8 @@ func NewUnreadUntyped(inner pnet.ReadWriteSender) (*Unread, error) { return nil, err } return &Unread{ - in: in, - inner: inner, + in: in, + ReadWriteSender: inner, }, nil } @@ -38,7 +38,7 @@ func (T *Unread) Read() (packet.In, error) { T.read = true return T.in, nil } - return T.inner.Read() + return T.ReadWriteSender.Read() } func (T *Unread) ReadUntyped() (packet.In, error) { @@ -46,19 +46,7 @@ func (T *Unread) ReadUntyped() (packet.In, error) { T.read = true return T.in, nil } - return T.inner.ReadUntyped() -} - -func (T *Unread) Write() packet.Out { - return T.inner.Write() -} - -func (T *Unread) WriteByte(b byte) error { - return T.inner.WriteByte(b) -} - -func (T *Unread) Send(typ packet.Type, bytes []byte) error { - return T.inner.Send(typ, bytes) + return T.ReadWriteSender.ReadUntyped() } var _ pnet.ReadWriteSender = (*Unread)(nil)