diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 71a97445ea3d728e45f699befb433df3412f56b0..f88e9e273ba123b1774d9418dc96af2b380a54ab 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -1,8 +1,6 @@ package backends import ( - "errors" - "pggat2/lib/auth/md5" "pggat2/lib/auth/sasl" "pggat2/lib/perror" @@ -10,84 +8,57 @@ import ( packets "pggat2/lib/zap/packets/v3.0" ) -type Status int - -const ( - Fail Status = iota - Ok -) - -var ( - ErrProtocolError = errors.New("protocol error") - ErrBadPacket = errors.New("bad packet") -) - -func fail(server zap.ReadWriter, err error) { - panic(err) -} - -func failpg(server zap.ReadWriter, err perror.Error) { - panic(err) -} - -func authenticationSASLChallenge(server zap.ReadWriter, mechanism sasl.Client) (done bool, status Status) { +func authenticationSASLChallenge(server zap.ReadWriter, mechanism sasl.Client) (done bool, err perror.Error) { packet := zap.NewPacket() defer packet.Done() - err := server.Read(packet) + err = perror.Wrap(server.Read(packet)) if err != nil { - fail(server, err) - return false, Fail + return } read := packet.Read() if read.ReadType() != packets.Authentication { - fail(server, ErrProtocolError) - return false, Fail + err = packets.ErrUnexpectedPacket + return } method, ok := read.ReadInt32() if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat + return } switch method { case 11: // challenge - response, err := mechanism.Continue(read.ReadUnsafeRemaining()) - if err != nil { - fail(server, err) - return false, Fail + response, err2 := mechanism.Continue(read.ReadUnsafeRemaining()) + if err2 != nil { + err = perror.Wrap(err2) + return } packets.WriteAuthenticationResponse(packet, response) - err = server.Write(packet) - if err != nil { - fail(server, err) - return false, Fail - } - return false, Ok + err = perror.Wrap(server.Write(packet)) + return case 12: // finish - err = mechanism.Final(read.ReadUnsafeRemaining()) + err = perror.Wrap(mechanism.Final(read.ReadUnsafeRemaining())) if err != nil { - fail(server, err) - return false, Fail + return } - return true, Ok + return true, nil default: - fail(server, ErrProtocolError) - return false, Fail + err = packets.ErrUnexpectedPacket + return } } -func authenticationSASL(server zap.ReadWriter, mechanisms []string, username, password string) Status { +func authenticationSASL(server zap.ReadWriter, mechanisms []string, username, password string) perror.Error { mechanism, err := sasl.NewClient(mechanisms, username, password) if err != nil { - fail(server, err) - return Fail + return perror.Wrap(err) } initialResponse := mechanism.InitialResponse() @@ -96,130 +67,148 @@ func authenticationSASL(server zap.ReadWriter, mechanisms []string, username, pa packets.WriteSASLInitialResponse(packet, mechanism.Name(), initialResponse) err = server.Write(packet) if err != nil { - fail(server, err) - return Fail + return perror.Wrap(err) } // challenge loop for { - done, status := authenticationSASLChallenge(server, mechanism) - if status != Ok { - return status + done, err := authenticationSASLChallenge(server, mechanism) + if err != nil { + return err } if done { break } } - return Ok + return nil } -func authenticationMD5(server zap.ReadWriter, salt [4]byte, username, password string) Status { +func authenticationMD5(server zap.ReadWriter, salt [4]byte, username, password string) perror.Error { packet := zap.NewPacket() defer packet.Done() packets.WritePasswordMessage(packet, md5.Encode(username, password, salt)) err := server.Write(packet) if err != nil { - fail(server, err) - return Fail + return perror.Wrap(err) } - return Ok + return nil } -func authenticationCleartext(server zap.ReadWriter, password string) Status { +func authenticationCleartext(server zap.ReadWriter, password string) perror.Error { packet := zap.NewPacket() defer packet.Done() packets.WritePasswordMessage(packet, password) err := server.Write(packet) if err != nil { - fail(server, err) - return Fail + return perror.Wrap(err) } - return Ok + return nil } -func startup0(server zap.ReadWriter, username, password string) (done bool, status Status) { +func startup0(server zap.ReadWriter, username, password string) (done bool, err perror.Error) { packet := zap.NewPacket() defer packet.Done() - err := server.Read(packet) + err = perror.Wrap(server.Read(packet)) if err != nil { - fail(server, err) - return false, Fail + return } read := packet.Read() switch read.ReadType() { case packets.ErrorResponse: - perr, ok := packets.ReadErrorResponse(&read) + var ok bool + err, ok = packets.ReadErrorResponse(&read) if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat } - failpg(server, perr) - return false, Fail + return case packets.Authentication: read2 := read method, ok := read2.ReadInt32() if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat + return } // they have more authentication methods than there are pokemon switch method { case 0: // we're good to go, that was easy - return true, Ok + return true, nil case 2: - fail(server, errors.New("kerberos v5 is not supported")) - return false, Fail + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "kerberos v5 is not supported", + ) + return case 3: return false, authenticationCleartext(server, password) case 5: salt, ok := packets.ReadAuthenticationMD5(&read) if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat + return } return false, authenticationMD5(server, salt, username, password) case 6: - fail(server, errors.New("scm credential is not supported")) - return false, Fail + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "scm credential is not supported", + ) + return case 7: - fail(server, errors.New("gss is not supported")) - return false, Fail + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "gss is not supported", + ) + return case 9: - fail(server, errors.New("sspi is not supported")) - return false, Fail + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "sspi is not supported", + ) + return case 10: // read list of mechanisms mechanisms, ok := packets.ReadAuthenticationSASL(&read) if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat + return } return false, authenticationSASL(server, mechanisms, username, password) default: - fail(server, errors.New("unknown authentication method")) - return false, Fail + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "unknown authentication method", + ) + return } case packets.NegotiateProtocolVersion: // we only support protocol 3.0 for now - fail(server, errors.New("server wanted to negotiate protocol version")) - return false, Fail + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "server wanted to negotiate protocol version", + ) + return default: - fail(server, ErrProtocolError) - return false, Fail + err = packets.ErrUnexpectedPacket + return } } -func startup1(server zap.ReadWriter) (done bool, status Status) { +func startup1(server zap.ReadWriter) (done bool, err perror.Error) { packet := zap.NewPacket() defer packet.Done() - err := server.Read(packet) + err = perror.Wrap(server.Read(packet)) if err != nil { - fail(server, err) - return false, Fail + return } read := packet.Read() @@ -228,33 +217,32 @@ func startup1(server zap.ReadWriter) (done bool, status Status) { var cancellationKey [8]byte ok := read.ReadBytes(cancellationKey[:]) if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat + return } // TODO(garet) put cancellation key somewhere - return false, Ok + return false, nil case packets.ParameterStatus: - return false, Ok + return false, nil case packets.ReadyForQuery: - return true, Ok + return true, nil case packets.ErrorResponse: - perr, ok := packets.ReadErrorResponse(&read) + var ok bool + err, ok = packets.ReadErrorResponse(&read) if !ok { - fail(server, ErrBadPacket) - return false, Fail + err = packets.ErrBadFormat } - failpg(server, perr) - return false, Fail + return case packets.NoticeResponse: // TODO(garet) do something with notice - return false, Ok + return false, nil default: - fail(server, ErrProtocolError) - return false, Fail + err = packets.ErrUnexpectedPacket + return false, err } } -func Accept(server zap.ReadWriter, username, password, database string) { +func Accept(server zap.ReadWriter, username, password, database string) perror.Error { if database == "" { database = username } @@ -269,16 +257,15 @@ func Accept(server zap.ReadWriter, username, password, database string) { packet.WriteString(database) packet.WriteString("") - err := server.WriteUntyped(packet) + err := perror.Wrap(server.WriteUntyped(packet)) if err != nil { - fail(server, err) - return + return err } for { - done, status := startup0(server, username, password) - if status != Ok { - return + done, err := startup0(server, username, password) + if err != nil { + return err } if done { break @@ -286,9 +273,9 @@ func Accept(server zap.ReadWriter, username, password, database string) { } for { - done, status := startup1(server) - if status != Ok { - return + done, err := startup1(server) + if err != nil { + return err } if done { break @@ -296,4 +283,5 @@ func Accept(server zap.ReadWriter, username, password, database string) { } // startup complete, connection is ready for queries + return nil } diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index 86137ae8083ecc88b08153cb01906ac1d7823113..121cd64fd53f0491b211d7a432806d5e587753d9 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -308,36 +308,29 @@ func transaction(ctx *bctx.Context) berr.Error { } } -func clientError(ctx *bctx.Context, err error) { +func clientFail(ctx *bctx.Context, err perror.Error) { // send fatal error to client packet := zap.NewPacket() - packets.WriteErrorResponse(packet, perror.New( - perror.FATAL, - perror.ProtocolViolation, - err.Error(), - )) + packets.WriteErrorResponse(packet, err) _ = ctx.ClientWrite(packet) } -func serverError(ctx *bctx.Context, err error) { - panic("server error: " + err.Error()) -} - -func Bounce(client, server zap.ReadWriter) { +func Bounce(client, server zap.ReadWriter) (clientError error, serverError error) { ctx := bctx.MakeContext(client, server) err := transaction(&ctx) if err != nil { switch e := err.(type) { case berr.Client: - clientError(&ctx, e) - if err2 := rserver.Recover(&ctx); err2 != nil { - serverError(&ctx, err2) - } + clientError = e + serverError = rserver.Recover(&ctx) + clientFail(&ctx, perror.Wrap(clientError)) case berr.Server: - serverError(&ctx, e) - clientError(&ctx, e) + serverError = e + clientFail(&ctx, perror.Wrap(serverError)) default: panic("unreachable") } } + + return } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 340295ae1f9a82878de51bb4f8f18cff43b1f1b8..79d5d091d9965163841014c0b440381eeadd2833 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -10,38 +10,23 @@ import ( "pggat2/lib/zap/packets/v3.0" ) -type Status int - -const ( - Fail Status = iota - Ok -) - -func fail(client zap.ReadWriter, err perror.Error) { - packet := zap.NewPacket() - defer packet.Done() - packets.WriteErrorResponse(packet, err) - _ = client.Write(packet) -} - -func startup0(client zap.ReadWriter) (user, database string, done bool, status Status) { +func startup0(client zap.ReadWriter) (user, database string, done bool, err perror.Error) { packet := zap.NewUntypedPacket() defer packet.Done() - err := client.ReadUntyped(packet) + err = perror.Wrap(client.ReadUntyped(packet)) if err != nil { - fail(client, perror.Wrap(err)) return } read := packet.Read() majorVersion, ok := read.ReadUint16() if !ok { - fail(client, packets.ErrBadFormat) + err = packets.ErrBadFormat return } minorVersion, ok := read.ReadUint16() if !ok { - fail(client, packets.ErrBadFormat) + err = packets.ErrBadFormat return } @@ -50,46 +35,37 @@ func startup0(client zap.ReadWriter) (user, database string, done bool, status S switch minorVersion { case 5678: // Cancel - fail(client, perror.New( + err = perror.New( perror.FATAL, perror.FeatureNotSupported, "Cancel is not supported yet", - )) + ) return case 5679: // SSL is not supported yet - err = client.WriteByte('N') - if err != nil { - fail(client, perror.Wrap(err)) - return - } - status = Ok + err = perror.Wrap(client.WriteByte('N')) return case 5680: // GSSAPI is not supported yet - err = client.WriteByte('N') - if err != nil { - fail(client, perror.Wrap(err)) - return - } - status = Ok + err = perror.Wrap(client.WriteByte('N')) return default: - fail(client, perror.New( + err = perror.New( perror.FATAL, perror.ProtocolViolation, "Unknown request code", - )) + ) return } } if majorVersion != 3 { - fail(client, perror.New( + err = perror.New( perror.FATAL, perror.ProtocolViolation, "Unsupported protocol version", - )) + ) + return } var unsupportedOptions []string @@ -97,7 +73,7 @@ func startup0(client zap.ReadWriter) (user, database string, done bool, status S for { key, ok := read.ReadString() if !ok { - fail(client, packets.ErrBadFormat) + err = packets.ErrBadFormat return } if key == "" { @@ -106,7 +82,7 @@ func startup0(client zap.ReadWriter) (user, database string, done bool, status S value, ok := read.ReadString() if !ok { - fail(client, packets.ErrBadFormat) + err = packets.ErrBadFormat return } @@ -116,18 +92,18 @@ func startup0(client zap.ReadWriter) (user, database string, done bool, status S case "database": database = value case "options": - fail(client, perror.New( + err = perror.New( perror.FATAL, perror.FeatureNotSupported, "Startup options are not supported yet", - )) + ) return case "replication": - fail(client, perror.New( + err = perror.New( perror.FATAL, perror.FeatureNotSupported, "Replication mode is not supported yet", - )) + ) return default: if strings.HasPrefix(key, "_pq_.") { @@ -145,136 +121,132 @@ func startup0(client zap.ReadWriter) (user, database string, done bool, status S defer packet.Done() packets.WriteNegotiateProtocolVersion(packet, 0, unsupportedOptions) - err = client.Write(packet) + err = perror.Wrap(client.Write(packet)) if err != nil { - fail(client, perror.Wrap(err)) return } } if user == "" { - fail(client, perror.New( + err = perror.New( perror.FATAL, perror.InvalidAuthorizationSpecification, "User is required", - )) + ) return } if database == "" { database = user } - status = Ok done = true return } -func authenticationSASLInitial(client zap.ReadWriter, username, password string) (server sasl.Server, resp []byte, done bool, status Status) { +func authenticationSASLInitial(client zap.ReadWriter, username, password string) (tool sasl.Server, resp []byte, done bool, err perror.Error) { // check which authentication method the client wants packet := zap.NewPacket() defer packet.Done() - err := client.Read(packet) + err = perror.Wrap(client.Read(packet)) if err != nil { - fail(client, perror.Wrap(err)) - return nil, nil, false, Fail + return } read := packet.Read() mechanism, initialResponse, ok := packets.ReadSASLInitialResponse(&read) if !ok { - fail(client, packets.ErrBadFormat) - return nil, nil, false, Fail + err = packets.ErrBadFormat + return } - tool, err := sasl.NewServer(mechanism, username, password) - if err != nil { - fail(client, perror.Wrap(err)) - return nil, nil, false, Fail + var err2 error + tool, err2 = sasl.NewServer(mechanism, username, password) + if err2 != nil { + err = perror.Wrap(err2) + return } - resp, done, err = tool.InitialResponse(initialResponse) - if err != nil { - fail(client, perror.Wrap(err)) - return nil, nil, false, Fail + resp, done, err2 = tool.InitialResponse(initialResponse) + if err2 != nil { + err = perror.Wrap(err2) + return } - return tool, resp, done, Ok + return } -func authenticationSASLContinue(client zap.ReadWriter, tool sasl.Server) (resp []byte, done bool, status Status) { +func authenticationSASLContinue(client zap.ReadWriter, tool sasl.Server) (resp []byte, done bool, err perror.Error) { packet := zap.NewPacket() defer packet.Done() - err := client.Read(packet) + err = perror.Wrap(client.Read(packet)) if err != nil { - fail(client, perror.Wrap(err)) - return nil, false, Fail + return } read := packet.Read() clientResp, ok := packets.ReadAuthenticationResponse(&read) if !ok { - fail(client, packets.ErrBadFormat) - return nil, false, Fail + err = packets.ErrBadFormat + return } - resp, done, err = tool.Continue(clientResp) - if err != nil { - fail(client, perror.Wrap(err)) - return nil, false, Fail + var err2 error + resp, done, err2 = tool.Continue(clientResp) + if err2 != nil { + err = perror.Wrap(err2) + return } - return resp, done, Ok + return } -func authenticationSASL(client zap.ReadWriter, username, password string) Status { +func authenticationSASL(client zap.ReadWriter, username, password string) perror.Error { packet := zap.NewPacket() defer packet.Done() packets.WriteAuthenticationSASL(packet, sasl.Mechanisms) - err := client.Write(packet) + err := perror.Wrap(client.Write(packet)) if err != nil { - fail(client, perror.Wrap(err)) - return Fail + return err } - tool, resp, done, status := authenticationSASLInitial(client, username, password) + tool, resp, done, err := authenticationSASLInitial(client, username, password) + if err != nil { + return err + } for { - if status != Ok { - return status - } if done { packets.WriteAuthenticationSASLFinal(packet, resp) - err = client.Write(packet) + err = perror.Wrap(client.Write(packet)) if err != nil { - fail(client, perror.Wrap(err)) - return Fail + return err } break } else { packets.WriteAuthenticationSASLContinue(packet, resp) - err = client.Write(packet) + err = perror.Wrap(client.Write(packet)) if err != nil { - fail(client, perror.Wrap(err)) - return Fail + return err } } - resp, done, status = authenticationSASLContinue(client, tool) + resp, done, err = authenticationSASLContinue(client, tool) + if err != nil { + return err + } } - return Ok + return nil } -func updateParameter(pkts *zap.Packets, name, value string) Status { +func updateParameter(pkts *zap.Packets, name, value string) { packet := zap.NewPacket() defer packet.Done() packets.WriteParameterStatus(packet, name, value) pkts.Append(packet) - return Ok } -func Accept(client zap.ReadWriter, getPassword func(user string) string, initialParameterStatus map[string]string) (user string, database string, ok bool) { +func accept(client zap.ReadWriter, getPassword func(user string) (string, bool), initialParameterStatus map[string]string) (user string, database string, err perror.Error) { for { var done bool - var status Status - user, database, done, status = startup0(client) - if status != Ok { + user, database, done, err = startup0(client) + if err != nil { return } if done { @@ -282,8 +254,18 @@ func Accept(client zap.ReadWriter, getPassword func(user string) string, initial } } - status := authenticationSASL(client, user, getPassword(user)) - if status != Ok { + password, ok := getPassword(user) + if !ok { + err = perror.New( + perror.FATAL, + perror.InvalidPassword, + "User not found", + ) + return + } + + err = authenticationSASL(client, user, password) + if err != nil { return } @@ -296,17 +278,14 @@ func Accept(client zap.ReadWriter, getPassword func(user string) string, initial pkts.Append(packet) for name, value := range initialParameterStatus { - status = updateParameter(pkts, name, value) - if status != Ok { - return - } + updateParameter(pkts, name, value) } // send backend key data var cancellationKey [8]byte - _, err := rand.Read(cancellationKey[:]) - if err != nil { - fail(client, perror.Wrap(err)) + _, err2 := rand.Read(cancellationKey[:]) + if err2 != nil { + err = perror.Wrap(err2) return } @@ -319,12 +298,25 @@ func Accept(client zap.ReadWriter, getPassword func(user string) string, initial packets.WriteReadyForQuery(packet, 'I') pkts.Append(packet) - err = client.WriteV(pkts) + err = perror.Wrap(client.WriteV(pkts)) if err != nil { - fail(client, perror.Wrap(err)) return } - ok = true + return +} + +func fail(client zap.ReadWriter, err perror.Error) { + packet := zap.NewPacket() + defer packet.Done() + packets.WriteErrorResponse(packet, err) + _ = client.Write(packet) +} + +func Accept(client zap.ReadWriter, getPassword func(user string) (string, bool), initialParameterStatus map[string]string) (user, database string, err perror.Error) { + user, database, err = accept(client, getPassword, initialParameterStatus) + if err != nil { + fail(client, err) + } return } diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index 6cc1d0d8db392d645a37f36759b041195e6b5b44..561f70deccf19c0b476bde2cb1a60d35d496e1ad 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -54,14 +54,14 @@ func (T *Pooler) Serve(client zap.ReadWriter) { unterminate.Unterminate, ) - username, database, ok := frontends.Accept(client, func(username string) string { + username, database, err := frontends.Accept(client, func(username string) (string, bool) { user := T.GetUser(username) if user == nil { - return "" + return "", false } - return user.GetPassword() + return user.GetPassword(), true }, DefaultParameterStatus) - if !ok { + if err != nil { return } @@ -89,9 +89,6 @@ func (T *Pooler) ListenAndServe(address string) error { if err != nil { return err } - go T.Serve(zap.CombinedReadWriter{ - Reader: zap.IOReader{Reader: conn}, - Writer: zap.IOWriter{Writer: conn}, - }) + go T.Serve(zap.WrapIOReadWriter(conn)) } } diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index fa7d3a48720fee53153996a2667bad347194d3f9..811229994aee93e0ff1e7e7ee3150d7230525592 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -51,10 +51,14 @@ func (T *Pool) release(server zap.ReadWriter) { func (T *Pool) Serve(client zap.ReadWriter) { server := T.acquire() - defer T.release(server) for { - // TODO(garet) test if client has disconnected - bouncers.Bounce(client, server) + clientErr, serverErr := bouncers.Bounce(client, server) + if clientErr != nil || serverErr != nil { + if serverErr == nil { + T.release(server) + } + break + } } } @@ -65,11 +69,12 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) { // TODO(garet) do something here continue } - rw := zap.CombinedReadWriter{ - Reader: zap.IOReader{Reader: conn}, - Writer: zap.IOWriter{Writer: conn}, + rw := zap.WrapIOReadWriter(conn) + err2 := backends.Accept(rw, recipe.User, recipe.Password, recipe.Database) + if err2 != nil { + // TODO(garet) do something here + continue } - backends.Accept(rw, recipe.User, recipe.Password, recipe.Database) T.release(rw) } } diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go index 33b8ca1627e4382d2bcc032dd71052fe9f4bde4a..43c4bd83c0955502ae26febeb68bb02eb24098c9 100644 --- a/lib/gat/pools/transaction/conn.go +++ b/lib/gat/pools/transaction/conn.go @@ -18,7 +18,10 @@ func (T Conn) Do(_ rob.Constraints, work any) { job := work.(Work) job.ps.SetServer(T.ps) T.eqp.SetClient(job.eqp) - bouncers.Bounce(job.rw, T.rw) + _, backendError := bouncers.Bounce(job.rw, T.rw) + if backendError != nil { + // TODO(garet) remove from pool + } return } diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go index f97a27f6b5adc8c69dcb2a06edb3d74233c4a132..eec49aeac1c1354fc3847bb1996798ee570264b8 100644 --- a/lib/gat/pools/transaction/pool.go +++ b/lib/gat/pools/transaction/pool.go @@ -32,10 +32,7 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) { // TODO(garet) do something here continue } - rw := zap.CombinedReadWriter{ - Reader: zap.IOReader{Reader: conn}, - Writer: zap.IOWriter{Writer: conn}, - } + rw := zap.WrapIOReadWriter(conn) eqps := eqp.NewServer() pss := ps.NewServer() mw := interceptor.NewInterceptor( @@ -43,7 +40,11 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) { eqps, pss, ) - backends.Accept(mw, recipe.User, recipe.Password, recipe.Database) + err2 := backends.Accept(mw, recipe.User, recipe.Password, recipe.Database) + if err2 != nil { + // TODO(garet) do something here + continue + } T.s.AddSink(0, Conn{ rw: mw, eqp: eqps, diff --git a/lib/zap/reader.go b/lib/zap/reader.go index 61bc8de86c31ebec23263472c27205ed09bdc0fd..2ba2d2b52f8fb6bcc52ba39a57369f8e6143c355 100644 --- a/lib/zap/reader.go +++ b/lib/zap/reader.go @@ -8,27 +8,35 @@ type Reader interface { ReadUntyped(*UntypedPacket) error } -type IOReader struct { - Reader io.Reader +func WrapIOReader(readCloser io.ReadCloser) Reader { + return ioReader{ + reader: readCloser, + closer: readCloser, + } +} + +type ioReader struct { + reader io.Reader + closer io.Closer } -func (T IOReader) ReadByte() (byte, error) { +func (T ioReader) ReadByte() (byte, error) { var res = []byte{0} - _, err := io.ReadFull(T.Reader, res) + _, err := io.ReadFull(T.reader, res) if err != nil { return 0, err } return res[0], err } -func (T IOReader) Read(packet *Packet) error { - _, err := packet.ReadFrom(T.Reader) +func (T ioReader) Read(packet *Packet) error { + _, err := packet.ReadFrom(T.reader) return err } -func (T IOReader) ReadUntyped(packet *UntypedPacket) error { - _, err := packet.ReadFrom(T.Reader) +func (T ioReader) ReadUntyped(packet *UntypedPacket) error { + _, err := packet.ReadFrom(T.reader) return err } -var _ Reader = IOReader{} +var _ Reader = ioReader{} diff --git a/lib/zap/readwriter.go b/lib/zap/readwriter.go index f98d5197cee6aeafd6769c74025f517c0c09a6ae..0bf3e27ca22af822c775e43b71cfd01dde637c1c 100644 --- a/lib/zap/readwriter.go +++ b/lib/zap/readwriter.go @@ -1,5 +1,7 @@ package zap +import "io" + type ReadWriter interface { Reader Writer @@ -9,3 +11,10 @@ type CombinedReadWriter struct { Reader Writer } + +func WrapIOReadWriter(readWriteCloser io.ReadWriteCloser) ReadWriter { + return CombinedReadWriter{ + Reader: WrapIOReader(readWriteCloser), + Writer: WrapIOWriter(readWriteCloser), + } +} diff --git a/lib/zap/writer.go b/lib/zap/writer.go index 26f21b705b8d2ae5ed2ecf338ac6251e76eeef34..24e571bc1b98fb31d3b339faec4428a6b464d8df 100644 --- a/lib/zap/writer.go +++ b/lib/zap/writer.go @@ -11,28 +11,36 @@ type Writer interface { WriteV(*Packets) error } -type IOWriter struct { - Writer io.Writer +func WrapIOWriter(writeCloser io.WriteCloser) Writer { + return ioWriter{ + writer: writeCloser, + closer: writeCloser, + } } -func (T IOWriter) WriteByte(b byte) error { - _, err := T.Writer.Write([]byte{b}) +type ioWriter struct { + writer io.Writer + closer io.Closer +} + +func (T ioWriter) WriteByte(b byte) error { + _, err := T.writer.Write([]byte{b}) return err } -func (T IOWriter) Write(packet *Packet) error { - _, err := packet.WriteTo(T.Writer) +func (T ioWriter) Write(packet *Packet) error { + _, err := packet.WriteTo(T.writer) return err } -func (T IOWriter) WriteUntyped(packet *UntypedPacket) error { - _, err := packet.WriteTo(T.Writer) +func (T ioWriter) WriteUntyped(packet *UntypedPacket) error { + _, err := packet.WriteTo(T.writer) return err } -func (T IOWriter) WriteV(packets *Packets) error { - _, err := packets.WriteTo(T.Writer) +func (T ioWriter) WriteV(packets *Packets) error { + _, err := packets.WriteTo(T.writer) return err } -var _ Writer = IOWriter{} +var _ Writer = ioWriter{}