diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 29ba4628a5d55eeb1caaed4db3f8f0f0dd01e6a5..6a36d596002387f037a496c34bd1bbdf7822f7ff 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -16,7 +16,7 @@ import ( ) type job struct { - client pnet.ReadWriteSender + client pnet.ReadWriter done chan<- struct{} } diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 1f5dbf7b3d6d785ab14ff9d00cd17e2d1ac2cae7..7687d1da1de99e9da20a379cdce3394b46462bf8 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -23,15 +23,15 @@ var ( ErrBadPacket = errors.New("bad packet") ) -func fail(server pnet.ReadWriteSender, err error) { +func fail(server pnet.ReadWriter, err error) { panic(err) } -func failpg(server pnet.ReadWriteSender, err perror.Error) { +func failpg(server pnet.ReadWriter, err perror.Error) { panic(err) } -func authenticationSASLChallenge(server pnet.ReadWriteSender, mechanism sasl.Client) (done bool, status Status) { +func authenticationSASLChallenge(server pnet.ReadWriter, mechanism sasl.Client) (done bool, status Status) { in, err := server.Read() if err != nil { fail(server, err) @@ -61,7 +61,7 @@ func authenticationSASLChallenge(server pnet.ReadWriteSender, mechanism sasl.Cli out := server.Write() packets.WriteAuthenticationResponse(out, response) - err = out.Send() + err = server.Send(out.Finish()) if err != nil { fail(server, err) return false, Fail @@ -82,7 +82,7 @@ func authenticationSASLChallenge(server pnet.ReadWriteSender, mechanism sasl.Cli } } -func authenticationSASL(server pnet.ReadWriteSender, mechanisms []string, username, password string) Status { +func authenticationSASL(server pnet.ReadWriter, mechanisms []string, username, password string) Status { mechanism, err := sasl.NewClient(mechanisms, username, password) if err != nil { fail(server, err) @@ -92,7 +92,7 @@ func authenticationSASL(server pnet.ReadWriteSender, mechanisms []string, userna out := server.Write() packets.WriteSASLInitialResponse(out, mechanism.Name(), initialResponse) - err = out.Send() + err = server.Send(out.Finish()) if err != nil { fail(server, err) return Fail @@ -112,10 +112,10 @@ func authenticationSASL(server pnet.ReadWriteSender, mechanisms []string, userna return Ok } -func authenticationMD5(server pnet.ReadWriteSender, salt [4]byte, username, password string) Status { +func authenticationMD5(server pnet.ReadWriter, salt [4]byte, username, password string) Status { out := server.Write() packets.WritePasswordMessage(out, md5.Encode(username, password, salt)) - err := out.Send() + err := server.Send(out.Finish()) if err != nil { fail(server, err) return Fail @@ -123,10 +123,10 @@ func authenticationMD5(server pnet.ReadWriteSender, salt [4]byte, username, pass return Ok } -func authenticationCleartext(server pnet.ReadWriteSender, password string) Status { +func authenticationCleartext(server pnet.ReadWriter, password string) Status { out := server.Write() packets.WritePasswordMessage(out, password) - err := out.Send() + err := server.Send(out.Finish()) if err != nil { fail(server, err) return Fail @@ -134,7 +134,7 @@ func authenticationCleartext(server pnet.ReadWriteSender, password string) Statu return Ok } -func startup0(server pnet.ReadWriteSender, username, password string) (done bool, status Status) { +func startup0(server pnet.ReadWriter, username, password string) (done bool, status Status) { in, err := server.Read() if err != nil { fail(server, err) @@ -205,7 +205,7 @@ func startup0(server pnet.ReadWriteSender, username, password string) (done bool } } -func startup1(server pnet.ReadWriteSender) (done bool, status Status) { +func startup1(server pnet.ReadWriter) (done bool, status Status) { in, err := server.Read() if err != nil { fail(server, err) @@ -243,7 +243,7 @@ func startup1(server pnet.ReadWriteSender) (done bool, status Status) { } } -func Accept(server pnet.ReadWriteSender) { +func Accept(server pnet.ReadWriter) { // we can re-use the memory for this pkt most of the way down because we don't pass this anywhere out := server.Write() out.Int16(3) @@ -253,7 +253,7 @@ func Accept(server pnet.ReadWriteSender) { out.String("postgres") out.String("") - err := out.Send() + err := server.Send(out.Finish()) if err != nil { fail(server, err) return diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 3f21c9bb9bb38e72bfce7b5ae35bfa1e881873f0..a1e051e5f9fb8e5a020f7bdb2cad3982c7e50323 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -7,7 +7,7 @@ import ( "pggat2/lib/auth/sasl" "pggat2/lib/perror" "pggat2/lib/pnet" - packets "pggat2/lib/pnet/packet/packets/v3.0" + "pggat2/lib/pnet/packet/packets/v3.0" ) type Status int @@ -17,14 +17,14 @@ const ( Ok ) -func fail(client pnet.ReadWriteSender, err perror.Error) { +func fail(client pnet.ReadWriter, err perror.Error) { out := client.Write() packets.WriteErrorResponse(out, err) - _ = out.Send() + _ = client.Send(out.Finish()) panic(err) } -func startup0(client pnet.ReadWriteSender) (done bool, status Status) { +func startup0(client pnet.ReadWriter) (done bool, status Status) { in, err := client.ReadUntyped() if err != nil { fail(client, perror.Wrap(err)) @@ -142,7 +142,7 @@ func startup0(client pnet.ReadWriteSender) (done bool, status Status) { out := client.Write() packets.WriteNegotiateProtocolVersion(out, 0, unsupportedOptions) - err = out.Send() + err = client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return false, Fail @@ -164,7 +164,7 @@ func startup0(client pnet.ReadWriteSender) (done bool, status Status) { return true, Ok } -func authenticationSASLInitial(client pnet.ReadWriteSender, username, password string) (server sasl.Server, resp []byte, done bool, status Status) { +func authenticationSASLInitial(client pnet.ReadWriter, username, password string) (server sasl.Server, resp []byte, done bool, status Status) { // check which authentication method the client wants in, err := client.Read() if err != nil { @@ -191,7 +191,7 @@ func authenticationSASLInitial(client pnet.ReadWriteSender, username, password s return tool, resp, done, Ok } -func authenticationSASLContinue(client pnet.ReadWriteSender, tool sasl.Server) (resp []byte, done bool, status Status) { +func authenticationSASLContinue(client pnet.ReadWriter, tool sasl.Server) (resp []byte, done bool, status Status) { in, err := client.Read() if err != nil { fail(client, perror.Wrap(err)) @@ -211,10 +211,10 @@ func authenticationSASLContinue(client pnet.ReadWriteSender, tool sasl.Server) ( return resp, done, Ok } -func authenticationSASL(client pnet.ReadWriteSender, username, password string) Status { +func authenticationSASL(client pnet.ReadWriter, username, password string) Status { out := client.Write() packets.WriteAuthenticationSASL(out, sasl.Mechanisms) - err := out.Send() + err := client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return Fail @@ -229,7 +229,7 @@ func authenticationSASL(client pnet.ReadWriteSender, username, password string) if done { out = client.Write() packets.WriteAuthenticationSASLFinal(out, resp) - err = out.Send() + err = client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return Fail @@ -238,7 +238,7 @@ func authenticationSASL(client pnet.ReadWriteSender, username, password string) } else { out = client.Write() packets.WriteAuthenticationSASLContinue(out, resp) - err = out.Send() + err = client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return Fail @@ -251,10 +251,10 @@ func authenticationSASL(client pnet.ReadWriteSender, username, password string) return Ok } -func updateParameter(client pnet.ReadWriteSender, name, value string) Status { +func updateParameter(client pnet.ReadWriter, name, value string) Status { out := client.Write() packets.WriteParameterStatus(out, name, value) - err := out.Send() + err := client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return Fail @@ -262,7 +262,7 @@ func updateParameter(client pnet.ReadWriteSender, name, value string) Status { return Ok } -func Accept(client pnet.ReadWriteSender) { +func Accept(client pnet.ReadWriter) { for { done, status := startup0(client) if status != Ok { @@ -281,7 +281,7 @@ func Accept(client pnet.ReadWriteSender) { // send auth Ok out := client.Write() packets.WriteAuthenticationOk(out) - err := out.Send() + err := client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return @@ -349,7 +349,7 @@ func Accept(client pnet.ReadWriteSender) { } out = client.Write() packets.WriteBackendKeyData(out, cancellationKey) - err = out.Send() + err = client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return @@ -358,7 +358,7 @@ func Accept(client pnet.ReadWriteSender) { // send ready for query out = client.Write() packets.WriteReadyForQuery(out, 'I') - err = out.Send() + err = client.Send(out.Finish()) if err != nil { fail(client, perror.Wrap(err)) return diff --git a/lib/middleware/middlewares/eqp/consumer.go b/lib/middleware/middlewares/eqp/consumer.go index 31bff0f47e7833367692f50ec372f20e84cc4b66..b8db773e8cdd65a407793f84e1b4b91a344fb386 100644 --- a/lib/middleware/middlewares/eqp/consumer.go +++ b/lib/middleware/middlewares/eqp/consumer.go @@ -9,10 +9,10 @@ import ( type Consumer struct { preparedStatements map[string]PreparedStatement portals map[string]Portal - inner pnet.ReadWriteSender + inner pnet.ReadWriter } -func MakeConsumer(inner pnet.ReadWriteSender) Consumer { +func MakeConsumer(inner pnet.ReadWriter) Consumer { return Consumer{ preparedStatements: make(map[string]PreparedStatement), portals: make(map[string]Portal), @@ -29,7 +29,7 @@ func (T Consumer) ReadUntyped() (packet.In, error) { } func (T Consumer) Write() packet.Out { - return T.inner.Write().WithSender(T) + return T.inner.Write() } func (T Consumer) WriteByte(b byte) error { @@ -91,4 +91,4 @@ func (T Consumer) Send(typ packet.Type, bytes []byte) error { return T.inner.Send(typ, bytes) } -var _ pnet.ReadWriteSender = Consumer{} +var _ pnet.ReadWriter = Consumer{} diff --git a/lib/middleware/middlewares/eqp/creator.go b/lib/middleware/middlewares/eqp/creator.go index 71a6d34439597ce6561afe0347235f789ad4d4a6..fa92e44603ab63d530f1f364ca4aab622f110d4c 100644 --- a/lib/middleware/middlewares/eqp/creator.go +++ b/lib/middleware/middlewares/eqp/creator.go @@ -9,10 +9,10 @@ import ( type Creator struct { preparedStatements map[string]PreparedStatement portals map[string]Portal - inner pnet.ReadWriteSender + inner pnet.ReadWriter } -func MakeCreator(inner pnet.ReadWriteSender) Creator { +func MakeCreator(inner pnet.ReadWriter) Creator { return Creator{ preparedStatements: make(map[string]PreparedStatement), portals: make(map[string]Portal), @@ -103,4 +103,4 @@ func (T Creator) Send(typ packet.Type, payload []byte) error { return T.inner.Send(typ, payload) } -var _ pnet.ReadWriteSender = Creator{} +var _ pnet.ReadWriter = Creator{} diff --git a/lib/middleware/middlewares/eqp/stealer.go b/lib/middleware/middlewares/eqp/stealer.go index 15521826893664408d35053cab6b04256780094d..ee25753042a9b58a79662e68b71b2500c58b1106 100644 --- a/lib/middleware/middlewares/eqp/stealer.go +++ b/lib/middleware/middlewares/eqp/stealer.go @@ -33,7 +33,7 @@ func (T *Stealer) ReadUntyped() (packet.In, error) { } func (T *Stealer) Write() packet.Out { - return T.consumer.Write().WithSender(T) + return T.consumer.Write() } func (T *Stealer) WriteByte(b byte) error { @@ -42,16 +42,16 @@ func (T *Stealer) WriteByte(b byte) error { func (T *Stealer) bindPreparedStatement(target string, preparedStatement PreparedStatement) error { T.buf.Reset() - out := packet.MakeOut(&T.buf, T.consumer) + out := packet.MakeOut(&T.buf) packets.WriteParse(out, target, preparedStatement.Query, preparedStatement.ParameterDataTypes) - return out.Send() + return T.consumer.Send(out.Finish()) } func (T *Stealer) bindPortal(target string, portal Portal) error { T.buf.Reset() - out := packet.MakeOut(&T.buf, T.consumer) + out := packet.MakeOut(&T.buf) packets.WriteBind(out, target, portal.Source, portal.ParameterFormatCodes, portal.ParameterValues, portal.ResultFormatCodes) - return out.Send() + return T.consumer.Send(out.Finish()) } func (T *Stealer) syncPreparedStatement(target string) error { @@ -112,4 +112,4 @@ func (T *Stealer) Send(typ packet.Type, bytes []byte) error { return T.consumer.Send(typ, bytes) } -var _ pnet.ReadWriteSender = (*Stealer)(nil) +var _ pnet.ReadWriter = (*Stealer)(nil) diff --git a/lib/middleware/middlewares/unread/unread.go b/lib/middleware/middlewares/unread/unread.go index 0b306998c2b30f22b169dbac8e5ad40a317aa797..4b41bb2f82074b8c0fb753f2dbab7c8322031d09 100644 --- a/lib/middleware/middlewares/unread/unread.go +++ b/lib/middleware/middlewares/unread/unread.go @@ -8,28 +8,28 @@ import ( type Unread struct { in packet.In read bool - pnet.ReadWriteSender + pnet.ReadWriter } -func NewUnread(inner pnet.ReadWriteSender) (*Unread, error) { +func NewUnread(inner pnet.ReadWriter) (*Unread, error) { in, err := inner.Read() if err != nil { return nil, err } return &Unread{ - in: in, - ReadWriteSender: inner, + in: in, + ReadWriter: inner, }, nil } -func NewUnreadUntyped(inner pnet.ReadWriteSender) (*Unread, error) { +func NewUnreadUntyped(inner pnet.ReadWriter) (*Unread, error) { in, err := inner.ReadUntyped() if err != nil { return nil, err } return &Unread{ - in: in, - ReadWriteSender: inner, + in: in, + ReadWriter: inner, }, nil } @@ -38,7 +38,7 @@ func (T *Unread) Read() (packet.In, error) { T.read = true return T.in, nil } - return T.ReadWriteSender.Read() + return T.ReadWriter.Read() } func (T *Unread) ReadUntyped() (packet.In, error) { @@ -46,7 +46,7 @@ func (T *Unread) ReadUntyped() (packet.In, error) { T.read = true return T.in, nil } - return T.ReadWriteSender.ReadUntyped() + return T.ReadWriter.ReadUntyped() } -var _ pnet.ReadWriteSender = (*Unread)(nil) +var _ pnet.ReadWriter = (*Unread)(nil) diff --git a/lib/middleware/middlewares/unterminate/unterminate.go b/lib/middleware/middlewares/unterminate/unterminate.go index dcf6b8419882d0b67b88c559209d25695172c628..b08aa27468fae8019ca68a5d5c3376c3aa168f1e 100644 --- a/lib/middleware/middlewares/unterminate/unterminate.go +++ b/lib/middleware/middlewares/unterminate/unterminate.go @@ -8,17 +8,17 @@ import ( ) type Unterminate struct { - pnet.ReadWriteSender + pnet.ReadWriter } -func MakeUnterminate(inner pnet.ReadWriteSender) Unterminate { +func MakeUnterminate(inner pnet.ReadWriter) Unterminate { return Unterminate{ - ReadWriteSender: inner, + ReadWriter: inner, } } func (T Unterminate) Read() (packet.In, error) { - in, err := T.ReadWriteSender.Read() + in, err := T.ReadWriter.Read() if err != nil { return packet.In{}, err } @@ -28,4 +28,4 @@ func (T Unterminate) Read() (packet.In, error) { return in, nil } -var _ pnet.ReadWriteSender = Unterminate{} +var _ pnet.ReadWriter = Unterminate{} diff --git a/lib/pnet/iowriter.go b/lib/pnet/iowriter.go index f262754c33ff41f15c8caba1b34037343e86e040..d8e64de2a07ac01a20492a75964835e9118c7bac 100644 --- a/lib/pnet/iowriter.go +++ b/lib/pnet/iowriter.go @@ -36,7 +36,6 @@ func (T *IOWriter) Write() packet.Out { return packet.MakeOut( &T.buf, - T, ) } diff --git a/lib/pnet/packet/out.go b/lib/pnet/packet/out.go index 088e03d7dc017fd7d27bdcd48dfd590e5dc589b2..84b7bda2ec500a11c6582aeb840b45fec70fc421 100644 --- a/lib/pnet/packet/out.go +++ b/lib/pnet/packet/out.go @@ -21,35 +21,16 @@ func (T *OutBuf) Reset() { } type Out struct { - buf *OutBuf - rev int - sender Sender + buf *OutBuf + rev int } func MakeOut( buf *OutBuf, - sender Sender, ) Out { - if sender == nil { - panic("sender cannot be nil") - } return Out{ - buf: buf, - rev: buf.rev, - sender: sender, - } -} - -func (T Out) WithSender( - sender Sender, -) Out { - if sender == nil { - panic("sender cannot be nil") - } - return Out{ - buf: T.buf, - rev: T.rev, - sender: sender, + buf: buf, + rev: buf.rev, } } @@ -57,6 +38,13 @@ func (T Out) done() bool { return T.rev != T.buf.rev } +func (T Out) Finish() (Type, []byte) { + if T.done() { + panic("Write after Send") + } + return T.buf.typ, T.buf.buf +} + func (T Out) Type(typ Type) { if T.done() { panic("Write after Send") @@ -137,11 +125,3 @@ func (T Out) Bytes(v []byte) { } T.buf.buf = append(T.buf.buf, v...) } - -func (T Out) Send() error { - if T.done() { - panic("Send called twice") - } - T.buf.rev++ - return T.sender.Send(T.buf.typ, T.buf.buf) -} diff --git a/lib/pnet/packet/sender.go b/lib/pnet/packet/sender.go deleted file mode 100644 index 45ec8b90d16756ba434f4c4346883de86bcefee6..0000000000000000000000000000000000000000 --- a/lib/pnet/packet/sender.go +++ /dev/null @@ -1,5 +0,0 @@ -package packet - -type Sender interface { - Send(Type, []byte) error -} diff --git a/lib/pnet/proxy.go b/lib/pnet/proxy.go index c26a46eb0533fcb37bd98f28724ab58509268062..bebee64ba0ae906199c3778b6895aecbc4735e27 100644 --- a/lib/pnet/proxy.go +++ b/lib/pnet/proxy.go @@ -5,7 +5,7 @@ import "pggat2/lib/pnet/packet" func ProxyPacket(writer Writer, in packet.In) error { out := writer.Write() packet.Proxy(out, in) - return out.Send() + return writer.Send(out.Finish()) } func Proxy(writer Writer, reader Reader) error { diff --git a/lib/pnet/readwriter.go b/lib/pnet/readwriter.go index 38f92d65a364dc53d82b97ab21940c346cb0e43e..81d125660a867ab52c11808e17355925925ed053 100644 --- a/lib/pnet/readwriter.go +++ b/lib/pnet/readwriter.go @@ -1,13 +1,6 @@ package pnet -import "pggat2/lib/pnet/packet" - type ReadWriter interface { Reader Writer } - -type ReadWriteSender interface { - ReadWriter - packet.Sender -} diff --git a/lib/pnet/writer.go b/lib/pnet/writer.go index 93a48fa11bee40214ec3cd4b73fd588ab876e72a..bd92c1f566931c54f7e6ae642acace146d3052c0 100644 --- a/lib/pnet/writer.go +++ b/lib/pnet/writer.go @@ -10,4 +10,5 @@ type Writer interface { io.ByteWriter Write() packet.Out + Send(packet.Type, []byte) error }