diff --git a/lib/zap/conn.go b/lib/zap/conn.go index 6ef64daf686f9bb602cf1e046a4cfb1c377a3500..da3040fdfc0d1e004b48f122e0685a656301362d 100644 --- a/lib/zap/conn.go +++ b/lib/zap/conn.go @@ -9,6 +9,8 @@ import ( type Conn struct { conn net.Conn + w io.Writer + r io.Reader buffers net.Buffers @@ -18,18 +20,24 @@ type Conn struct { func WrapNetConn(conn net.Conn) *Conn { return &Conn{ conn: conn, + w: conn, + r: conn, } } func (T *Conn) EnableSSLClient(config *tls.Config) error { sslConn := tls.Client(T.conn, config) T.conn = sslConn + T.w = sslConn + T.r = sslConn return sslConn.Handshake() } func (T *Conn) EnableSSLServer(config *tls.Config) error { sslConn := tls.Server(T.conn, config) T.conn = sslConn + T.w = sslConn + T.r = sslConn return sslConn.Handshake() } @@ -38,7 +46,7 @@ func (T *Conn) flush() error { return nil } - _, err := T.buffers.WriteTo(T.conn) + _, err := T.buffers.WriteTo(T.w) T.buffers = T.buffers[0:] return err } @@ -47,7 +55,7 @@ func (T *Conn) ReadByte() (byte, error) { if err := T.flush(); err != nil { return 0, err } - _, err := io.ReadFull(T.conn, T.byteBuf[:]) + _, err := io.ReadFull(T.r, T.byteBuf[:]) if err != nil { return 0, err } @@ -60,12 +68,12 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } packet := NewPacket(0) if typed { - _, err := io.ReadFull(T.conn, packet) + _, err := io.ReadFull(T.r, packet) if err != nil { return nil, err } } else { - _, err := io.ReadFull(T.conn, packet[1:]) + _, err := io.ReadFull(T.r, packet[1:]) if err != nil { return nil, err } @@ -73,7 +81,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { length := binary.BigEndian.Uint32(packet[1:]) packet = packet.Grow(int(length) - 4) - _, err := io.ReadFull(T.conn, packet.Payload()) + _, err := io.ReadFull(T.r, packet.Payload()) if err != nil { return nil, err }