package websocket

import (
	"bufio"
	"context"
	cryptorand "crypto/rand"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"os"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/xerrors"
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader, Read
// and SetReadLimit.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See the docs on Reader and CloseRead.
//
// Please be sure to call Close on the connection when you
// are finished with it to release the associated resources.
//
// Every error from Read or Reader will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
type Conn struct {
	// subprotocol is the negotiated subprotocol returned by Subprotocol.
	// br and bw buffer reads from and writes to the underlying connection.
	subprotocol string
	br          *bufio.Reader
	bw          *bufio.Writer
	// writeBuf is used for masking; it's the buffer backing bw.
	// Only used by the client for masking the bytes in the buffer.
	// closer is closed by close. client reports whether this end masks
	// its outgoing frames, i.e. whether it is the client side.
	writeBuf []byte
	closer   io.Closer
	client   bool

	// closeOnce guards close. closeErr is the reason the connection was
	// closed; it is assigned before closed is closed, so it is only
	// meaningful once closed has been closed.
	closeOnce sync.Once
	closeErr  error
	closed    chan struct{}

	// writeMsgLock is acquired to write a data message.
	// All locks on Conn are 1-buffered channels: sending acquires,
	// receiving releases.
	writeMsgLock chan struct{}
	// writeFrameLock is acquired to write a single frame.
	// Effectively meaning whoever holds it gets to write to bw.
	// writeHeaderBuf and writeHeader are scratch space reused for every
	// outgoing frame header while writeFrameLock is held.
	writeFrameLock chan struct{}
	writeHeaderBuf []byte
	writeHeader    *header
	// read limit for a message in bytes.
	msgReadLimit int64

	// messageWriter state (writeMsgOpcode, writeMsgCtx).
	// readMsgLeft is messageReader state: how many more bytes of the
	// current message may be read before msgReadLimit is reached.
	writeMsgOpcode opcode
	writeMsgCtx    context.Context
	readMsgLeft    int64

	// Used to ensure the previous reader is read till EOF before allowing
	// a new one.
	previousReader *messageReader
	// readFrameLock is acquired to read from br.
	// readClosed is set to 1 by CloseRead. readHeaderBuf and
	// controlPayloadBuf are scratch buffers for incoming frame headers and
	// control frame payloads.
	readFrameLock     chan struct{}
	readClosed        int64
	readHeaderBuf     []byte
	controlPayloadBuf []byte

	// messageReader state
	readMsgCtx    context.Context
	readMsgHeader header
	readFrameEOF  bool
	readMaskPos   int

	// setReadTimeout and setWriteTimeout hand the context bounding the
	// current read/write to timeoutLoop.
	setReadTimeout  chan context.Context
	setWriteTimeout chan context.Context

	// activePings maps the payload of each outstanding ping to the channel
	// that is closed when the matching pong arrives.
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}

// init allocates the channels, buffers and bookkeeping a Conn needs,
// installs a finalizer that closes a connection that is garbage collected
// without being closed, and starts the goroutine that enforces timeouts.
func (c *Conn) init() {
	// Locks are 1-buffered channels: a send acquires, a receive releases.
	c.readFrameLock = make(chan struct{}, 1)
	c.writeFrameLock = make(chan struct{}, 1)
	c.writeMsgLock = make(chan struct{}, 1)

	// Unbuffered, so each send hands the context straight to timeoutLoop.
	c.setWriteTimeout = make(chan context.Context)
	c.setReadTimeout = make(chan context.Context)

	c.closed = make(chan struct{})
	c.activePings = make(map[string]chan<- struct{})
	c.msgReadLimit = 32768

	c.readHeaderBuf = makeReadHeaderBuf()
	c.writeHeaderBuf = makeWriteHeaderBuf()
	c.writeHeader = &header{}
	c.controlPayloadBuf = make([]byte, maxControlFramePayload)

	runtime.SetFinalizer(c, func(leaked *Conn) {
		leaked.close(xerrors.New("connection garbage collected"))
	})

	go c.timeoutLoop()
}

// Subprotocol reports the subprotocol that was negotiated for this
// connection. The empty string denotes the default protocol.
func (c *Conn) Subprotocol() string {
	negotiated := c.subprotocol
	return negotiated
}

// close tears the connection down exactly once (guarded by closeOnce),
// recording err, wrapped, as closeErr, the error every later operation
// reports.
//
// It removes the finalizer, publishes closeErr, closes c.closed and then
// closes the underlying closer. For client connections it also takes both
// frame locks and never releases them, so the bufio reader/writer can be
// handed back via returnBufioReader/returnBufioWriter. Any caller that holds
// a frame lock must release it before calling close, or close will block
// forever.
func (c *Conn) close(err error) {
	c.closeOnce.Do(func() {
		runtime.SetFinalizer(c, nil)

		c.closeErr = xerrors.Errorf("websocket closed: %w", err)
		close(c.closed)

		// Have to close after c.closed is closed to ensure any goroutine that wakes up
		// from the connection being closed also sees that c.closed is closed and returns
		// closeErr.
		c.closer.Close()

		// See comment in dial.go
		if c.client {
			// By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer
			// and we can safely return them.
			// Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent
			// a deadlock.
			// As of now, this is in writeFrame, readFramePayload and readHeader.
			c.readFrameLock <- struct{}{}
			returnBufioReader(c.br)

			c.writeFrameLock <- struct{}{}
			returnBufioWriter(c.bw)
		}
	})
}

// timeoutLoop runs for the lifetime of the connection and enforces read and
// write timeouts. Readers and writers send the context that bounds their
// current operation on setReadTimeout/setWriteTimeout, and send
// context.Background() once the operation finishes. If the installed context
// is done first, the connection is closed. The loop returns once c.closed is
// closed.
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.setWriteTimeout:
		case readCtx = <-c.setReadTimeout:

		// After a timeout fires, the expired ctx stays ready and may be
		// selected again before c.closed; that is harmless because close
		// only takes effect once.
		case <-readCtx.Done():
			c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err()))
		case <-writeCtx.Done():
			c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err()))
		}
	}
}

// acquireLock takes lock (a 1-buffered channel) by sending into it.
//
// It returns closeErr if the connection is closed first. If ctx is done
// first, the connection is closed with an error naming the kind of lock and
// ctx.Err() is returned. lock must be one of the Conn's own locks; anything
// else is a programming error and panics.
func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
	select {
	case lock <- struct{}{}:
		return nil
	case <-c.closed:
		return c.closeErr
	case <-ctx.Done():
	}

	// ctx expired before the lock became free: treat it as fatal.
	var lockErr error
	switch {
	case lock == c.writeFrameLock || lock == c.writeMsgLock:
		lockErr = xerrors.Errorf("could not acquire write lock: %v", ctx.Err())
	case lock == c.readFrameLock:
		lockErr = xerrors.Errorf("could not acquire read lock: %v", ctx.Err())
	default:
		panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err()))
	}
	c.close(lockErr)
	return ctx.Err()
}

// releaseLock frees lock if it is currently held. Releasing a lock that is
// not held is a no-op, which lets error paths release early even though a
// deferred release will run again later.
func (c *Conn) releaseLock(lock chan struct{}) {
	select {
	default:
		// Not held: nothing to do.
	case <-lock:
	}
}

// readTillMsg reads frame headers, handling any control frames inline, until
// it finds the header of a data frame (binary, text or continuation), which it
// returns. Set RSV bits or unknown opcodes close the connection with
// StatusProtocolError.
func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
	for {
		h, err := c.readFrameHeader(ctx)
		if err != nil {
			return header{}, err
		}

		if h.rsv1 || h.rsv2 || h.rsv3 {
			protoErr := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
			c.Close(StatusProtocolError, protoErr.Error())
			return header{}, protoErr
		}

		switch {
		case h.opcode.controlOp():
			if err := c.handleControl(ctx, h); err != nil {
				return header{}, xerrors.Errorf("failed to handle control frame: %w", err)
			}
			// Control frame consumed; keep looking for data.
		case h.opcode == opBinary, h.opcode == opText, h.opcode == opContinuation:
			return h, nil
		default:
			protoErr := xerrors.Errorf("received unknown opcode %v", h.opcode)
			c.Close(StatusProtocolError, protoErr.Error())
			return header{}, protoErr
		}
	}
}

// readFrameHeader reads the next frame header from br while holding
// readFrameLock. The wait for the lock is not bounded by ctx, but ctx arms the
// read timeout while the header itself is read. Any read failure closes the
// connection.
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
	if err := c.acquireLock(context.Background(), c.readFrameLock); err != nil {
		return header{}, err
	}
	defer c.releaseLock(c.readFrameLock)

	select {
	case c.setReadTimeout <- ctx:
	case <-c.closed:
		return header{}, c.closeErr
	}

	h, readErr := readHeader(c.readHeaderBuf, c.br)
	if readErr != nil {
		// Prefer the close reason or ctx expiry over the raw I/O error.
		select {
		case <-c.closed:
			return header{}, c.closeErr
		case <-ctx.Done():
			readErr = ctx.Err()
		default:
		}
		wrapped := xerrors.Errorf("failed to read header: %w", readErr)
		// Release before close: close takes readFrameLock for clients.
		c.releaseLock(c.readFrameLock)
		c.close(wrapped)
		return header{}, wrapped
	}

	// Header read; disarm the read timeout.
	select {
	case c.setReadTimeout <- context.Background():
	case <-c.closed:
		return header{}, c.closeErr
	}
	return h, nil
}

// handleControl reads and acts on the payload of the control frame whose
// header h has just been read.
//
// Pings are answered with a pong carrying the same payload, a pong wakes the
// Ping call waiting on that payload (if any), and a close frame is echoed back
// before the connection is closed. Oversized or fragmented control frames are
// protocol errors and close the connection with StatusProtocolError.
func (c *Conn) handleControl(ctx context.Context, h header) error {
	if h.payloadLength > maxControlFramePayload {
		err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength)
		c.Close(StatusProtocolError, err.Error())
		return err
	}

	if !h.fin {
		err := xerrors.Errorf("received fragmented control frame")
		c.Close(StatusProtocolError, err.Error())
		return err
	}

	// Bound reading the control payload to 5 seconds.
	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	b := c.controlPayloadBuf[:h.payloadLength]
	_, err := c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	if h.masked {
		fastXOR(h.maskKey, 0, b)
	}

	switch h.opcode {
	case opPing:
		return c.writePong(b)
	case opPong:
		// Remove the entry while still holding the mutex so at most one pong
		// can ever close a given channel. Otherwise a peer sending the same
		// pong payload twice before Ping's deferred cleanup runs would make
		// us close the channel twice and panic.
		key := string(b)
		c.activePingsMu.Lock()
		pong, ok := c.activePings[key]
		if ok {
			delete(c.activePings, key)
		}
		c.activePingsMu.Unlock()
		if ok {
			close(pong)
		}
		return nil
	case opClose:
		ce, err := parseClosePayload(b)
		if err != nil {
			c.Close(StatusProtocolError, "received invalid close payload")
			return xerrors.Errorf("received invalid close payload: %w", err)
		}
		// Echo the peer's close payload back; on success writeClose closes the
		// connection with ce as the cause. Either way report closeErr.
		c.writeClose(b, xerrors.Errorf("received close frame: %w", ce))
		return c.closeErr
	default:
		panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
	}
}

// Reader blocks until a WebSocket data message is available and returns the
// message type together with an io.Reader for its payload.
// The passed context also bounds reads from the returned reader.
// Read it to EOF, otherwise the connection will hang.
//
// Every returned error causes the connection to be closed, so there is no
// need to send your own close message. The same holds for the Read methods in
// the wsjson/wspb subpackages.
//
// Control frames are only handled while the connection is being read.
// If you do not expect any data messages from the peer, call CloseRead.
//
// At most one Reader may be open at a time.
//
// If you need separate timeouts for the Reader call and for reading the
// message, cancel the passed context early with time.AfterFunc.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
	if closed := atomic.LoadInt64(&c.readClosed); closed == 1 {
		return 0, nil, xerrors.Errorf("websocket connection read closed")
	}

	typ, r, err := c.reader(ctx)
	if err == nil {
		return typ, r, nil
	}
	return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
}

// reader implements Reader without the CloseRead check, so CloseRead can use
// it too. It waits for the header of the next data message, finishing the
// previous message first if only its empty final frame was still pending, and
// resets the per-message read state before returning a fresh messageReader.
// The returned MessageType is the message's opcode.
//
// A continuation frame that does not continue an unfinished message is a
// protocol error (RFC 6455 §5.4) and closes the connection with
// StatusProtocolError. That covers both the first header read and the header
// read after finishing the previous message.
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
	if c.previousReader != nil && !c.readFrameEOF {
		// The only way we know for sure the previous reader is not yet complete is
		// if there is an active frame not yet fully read.
		// Otherwise, a user may have read the last byte but not the EOF if the EOF
		// is in the next frame so we check for that below.
		return 0, nil, xerrors.Errorf("previous message not read to completion")
	}

	h, err := c.readTillMsg(ctx)
	if err != nil {
		return 0, nil, err
	}

	if c.previousReader != nil && !c.previousReader.eof {
		// The previous reader consumed every byte of its last frame but not
		// the EOF, so the only acceptable frame here is an empty, final
		// continuation frame that completes that message.
		if h.opcode != opContinuation {
			err := xerrors.Errorf("received new data message without finishing the previous message")
			c.Close(StatusProtocolError, err.Error())
			return 0, nil, err
		}

		// NOTE(review): h has already been consumed from the stream here, so
		// this error leaves the connection inside a frame no reader will
		// finish. Confirm callers treat it as fatal.
		if !h.fin || h.payloadLength > 0 {
			return 0, nil, xerrors.Errorf("previous message not read to completion")
		}

		c.previousReader.eof = true

		// Now wait for the header of the actual next message.
		h, err = c.readTillMsg(ctx)
		if err != nil {
			return 0, nil, err
		}
	}

	// Whichever path produced h, a new message must not start with a
	// continuation frame.
	if h.opcode == opContinuation {
		err := xerrors.Errorf("received continuation frame not after data or text frame")
		c.Close(StatusProtocolError, err.Error())
		return 0, nil, err
	}

	c.readMsgCtx = ctx
	c.readMsgHeader = h
	c.readFrameEOF = false
	c.readMaskPos = 0
	c.readMsgLeft = c.msgReadLimit

	r := &messageReader{
		c: c,
	}
	c.previousReader = r
	return MessageType(h.opcode), r, nil
}

// CloseRead starts a goroutine that keeps reading from the connection until it
// is closed or a data message arrives. A data message closes the connection
// with StatusPolicyViolation. Because the goroutine reads, ping, pong and close
// frames are still handled. After calling CloseRead you can no longer read
// data messages from the connection.
// The returned context is cancelled when that goroutine finishes, i.e. once
// the connection is closed.
//
// Use this when you no longer want to read data messages but still want to
// write messages to the connection.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
	atomic.StoreInt64(&c.readClosed, 1)

	readCtx, cancel := context.WithCancel(ctx)
	go func() {
		defer cancel()
		// The unexported reader skips Reader's read-closed check. Whatever
		// it returns (a data message or an error) ends the read loop.
		_, _, _ = c.reader(readCtx)
		c.Close(StatusPolicyViolation, "unexpected data message")
	}()
	return readCtx
}

// messageReader enables reading a data frame from the WebSocket connection.
// It is handed out by Conn.reader; the per-message state it reads through
// (readMsgHeader, readMaskPos, readMsgLeft, ...) lives on the Conn in c.
// eof is set once the final frame of the message has been consumed, after
// which the reader can no longer be used.
type messageReader struct {
	c   *Conn
	eof bool
}

// Read reads as many bytes of the current message as possible into p.
// At the end of the message it returns io.EOF unwrapped, so that stdlib
// helpers comparing against io.EOF keep working. Every other error is wrapped.
func (r *messageReader) Read(p []byte) (int, error) {
	n, err := r.read(p)
	switch {
	case err == nil:
		return n, nil
	case xerrors.Is(err, io.EOF):
		return n, io.EOF
	default:
		return n, xerrors.Errorf("failed to read: %w", err)
	}
}

// read implements Read. It copies the next chunk of the current message's
// payload into p and unmasks it if needed. When the current frame is
// exhausted it moves on to the next continuation frame. It returns io.EOF
// once the final (fin) frame has been fully consumed.
//
// The message read limit is only enforced when payload bytes remain to be
// read. A message whose size equals msgReadLimit is therefore still accepted
// when it is terminated by an empty fin frame, which is what
// messageWriter.Close emits.
func (r *messageReader) read(p []byte) (int, error) {
	if r.eof {
		return 0, xerrors.Errorf("cannot use EOFed reader")
	}

	if r.c.readFrameEOF {
		// The current frame is exhausted but the message is not finished,
		// so the next data frame must be a continuation.
		h, err := r.c.readTillMsg(r.c.readMsgCtx)
		if err != nil {
			return 0, err
		}

		if h.opcode != opContinuation {
			err := xerrors.Errorf("received new data message without finishing the previous message")
			r.c.Close(StatusProtocolError, err.Error())
			return 0, err
		}

		r.c.readMsgHeader = h
		r.c.readFrameEOF = false
		r.c.readMaskPos = 0
	}

	h := r.c.readMsgHeader
	if h.payloadLength > 0 {
		// More payload is on the wire; it must fit in the remaining budget.
		if r.c.readMsgLeft <= 0 {
			err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit)
			r.c.Close(StatusMessageTooBig, err.Error())
			return 0, err
		}
		if int64(len(p)) > r.c.readMsgLeft {
			p = p[:r.c.readMsgLeft]
		}
	}
	if int64(len(p)) > h.payloadLength {
		p = p[:h.payloadLength]
	}

	n, err := r.c.readFramePayload(r.c.readMsgCtx, p)

	h.payloadLength -= int64(n)
	r.c.readMsgLeft -= int64(n)
	if h.masked {
		// Only unmask what was actually read so readMaskPos stays in sync.
		r.c.readMaskPos = fastXOR(h.maskKey, r.c.readMaskPos, p[:n])
	}
	r.c.readMsgHeader = h

	if err != nil {
		return n, err
	}

	if h.payloadLength == 0 {
		r.c.readFrameEOF = true

		if h.fin {
			r.eof = true
			return n, io.EOF
		}
	}

	return n, nil
}

// readFramePayload fills p completely from br while holding readFrameLock.
// ctx bounds both the wait for the lock and, via the read timeout, the read
// itself. Any read failure closes the connection.
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
	if err := c.acquireLock(ctx, c.readFrameLock); err != nil {
		return 0, err
	}
	defer c.releaseLock(c.readFrameLock)

	select {
	case c.setReadTimeout <- ctx:
	case <-c.closed:
		return 0, c.closeErr
	}

	n, readErr := io.ReadFull(c.br, p)
	if readErr != nil {
		// Prefer the close reason or ctx expiry over the raw I/O error.
		select {
		case <-c.closed:
			return n, c.closeErr
		case <-ctx.Done():
			readErr = ctx.Err()
		default:
		}
		wrapped := xerrors.Errorf("failed to read frame payload: %w", readErr)
		// Release before close: close takes readFrameLock for clients.
		c.releaseLock(c.readFrameLock)
		c.close(wrapped)
		return n, wrapped
	}

	// Payload read; disarm the read timeout.
	select {
	case c.setReadTimeout <- context.Background():
	case <-c.closed:
		return n, c.closeErr
	}
	return n, nil
}

// SetReadLimit caps the number of bytes that may be read for a single
// message. The cap applies to both Reader and Read.
//
// The default limit is 32768 bytes.
//
// When a message hits the limit, the connection is closed with
// StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
	limit := n
	c.msgReadLimit = limit
}

// Read is a convenience wrapper around Reader that reads one whole message
// into memory.
//
// Use Reader instead to reuse buffers or to stream a message.
// The documentation on Reader applies to this method as well.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
	typ, r, err := c.Reader(ctx)
	if err != nil {
		return 0, nil, err
	}

	payload, readErr := ioutil.ReadAll(r)
	return typ, payload, readErr
}

// Writer returns an io.WriteCloser, bounded by ctx, that writes a single
// WebSocket message of type typ to the connection.
//
// The writer must be closed once the entire message has been written.
//
// Only one writer can be open at a time; further calls block until the
// previous writer is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	w, err := c.writer(ctx, typ)
	if err == nil {
		return w, nil
	}
	return nil, xerrors.Errorf("failed to get writer: %w", err)
}

// writer takes writeMsgLock, which messageWriter.close releases, and records
// the context and opcode the returned messageWriter will write with.
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
	if err := c.acquireLock(ctx, c.writeMsgLock); err != nil {
		return nil, err
	}

	c.writeMsgOpcode = opcode(typ)
	c.writeMsgCtx = ctx
	mw := &messageWriter{c: c}
	return mw, nil
}

// Write sends p to the peer as one complete message of type typ.
//
// Use Writer instead to stream a message. The concurrency notes on Writer
// apply to this method as well.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
	if _, err := c.write(ctx, typ, p); err != nil {
		return xerrors.Errorf("failed to write msg: %w", err)
	}
	return nil
}

// write sends p as a single final frame while holding writeMsgLock, so it
// cannot interleave with a message being streamed through a messageWriter.
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
	if err := c.acquireLock(ctx, c.writeMsgLock); err != nil {
		return 0, err
	}
	defer c.releaseLock(c.writeMsgLock)

	return c.writeFrame(ctx, true, opcode(typ), p)
}

// messageWriter enables writing to a WebSocket connection.
// It is returned by Conn.writer while writeMsgLock is held; the message's
// context and current opcode live on the Conn in c. closed is set by close,
// after which the writer can no longer be used.
type messageWriter struct {
	c      *Conn
	closed bool
}

// Write sends p to the WebSocket connection as one more frame of the current
// message.
func (w *messageWriter) Write(p []byte) (int, error) {
	n, err := w.write(p)
	if err == nil {
		return n, nil
	}
	return n, xerrors.Errorf("failed to write: %w", err)
}

// write emits p as a non-final frame. The first frame uses the message's
// opcode; every later frame must be a continuation.
func (w *messageWriter) write(p []byte) (int, error) {
	if w.closed {
		return 0, xerrors.Errorf("cannot use closed writer")
	}

	c := w.c
	n, err := c.writeFrame(c.writeMsgCtx, false, c.writeMsgOpcode, p)
	if err != nil {
		return n, xerrors.Errorf("failed to write data frame: %w", err)
	}
	c.writeMsgOpcode = opContinuation
	return n, nil
}

// Close finishes the message by flushing its final frame to the connection.
// It must be called for every messageWriter.
func (w *messageWriter) Close() error {
	if err := w.close(); err != nil {
		return xerrors.Errorf("failed to close writer: %w", err)
	}
	return nil
}

// close writes the empty final frame that ends the message and then releases
// writeMsgLock so another message can be written. Closing twice is an error.
// If the final frame cannot be written, the lock is left held.
func (w *messageWriter) close() error {
	if w.closed {
		return xerrors.Errorf("cannot use closed writer")
	}
	w.closed = true

	c := w.c
	if _, err := c.writeFrame(c.writeMsgCtx, true, c.writeMsgOpcode, nil); err != nil {
		return xerrors.Errorf("failed to write fin frame: %w", err)
	}
	c.releaseLock(c.writeMsgLock)
	return nil
}

// writeControl writes a single final frame with the given control opcode and
// payload p, bounded by ctx.
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
	if _, err := c.writeFrame(ctx, true, opcode, p); err != nil {
		return xerrors.Errorf("failed to write control frame: %w", err)
	}
	return nil
}

// writeFrame handles all writes to the connection. It serialises frame writes
// with writeFrameLock and arms the write timeout with ctx while writing. It
// fills in the shared writeHeader, drawing a fresh mask key from crypto/rand
// for clients, and delegates the actual I/O to realWriteFrame.
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
	if err := c.acquireLock(ctx, c.writeFrameLock); err != nil {
		return 0, err
	}
	defer c.releaseLock(c.writeFrameLock)

	select {
	case c.setWriteTimeout <- ctx:
	case <-c.closed:
		return 0, c.closeErr
	}

	hdr := c.writeHeader
	hdr.fin = fin
	hdr.opcode = opcode
	hdr.masked = c.client
	hdr.payloadLength = int64(len(p))

	if c.client {
		if _, err := io.ReadFull(cryptorand.Reader, hdr.maskKey[:]); err != nil {
			return 0, xerrors.Errorf("failed to generate masking key: %w", err)
		}
	}

	n, err := c.realWriteFrame(ctx, *hdr, p)
	if err != nil {
		return n, err
	}

	// The frame is fully written; disarm the timeout so an expiring ctx can
	// no longer close the connection.
	select {
	case c.setWriteTimeout <- context.Background():
	case <-c.closed:
		return n, c.closeErr
	}
	return n, nil
}

// realWriteFrame writes header h followed by payload p into c.bw and flushes
// bw when h.fin is set. It must be called with writeFrameLock held.
//
// For clients, the payload is masked in place inside bw's buffer (c.writeBuf)
// after bw has copied it there, so the caller's slice p is never modified.
//
// On any error, the deferred handler replaces the error with closeErr or
// ctx.Err() when either applies, wraps it, releases writeFrameLock and closes
// the connection.
func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) {
	defer func() {
		if err != nil {
			select {
			case <-c.closed:
				err = c.closeErr
			case <-ctx.Done():
				err = ctx.Err()
			default:
			}

			err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err)
			// We need to release the lock first before closing the connection to ensure
			// the lock can be acquired inside close to ensure no one can access c.bw.
			c.releaseLock(c.writeFrameLock)
			c.close(err)
		}
	}()

	headerBytes := writeHeader(c.writeHeaderBuf, h)
	_, err = c.bw.Write(headerBytes)
	if err != nil {
		return 0, err
	}

	if c.client {
		// keypos carries the mask position across chunks.
		var keypos int
		for len(p) > 0 {
			// Make room so the next chunk is copied into bw's buffer.
			if c.bw.Available() == 0 {
				err = c.bw.Flush()
				if err != nil {
					return n, err
				}
			}

			// Start of next write in the buffer.
			i := c.bw.Buffered()

			// Never hand bw more than fits in its buffer. Otherwise
			// bufio.Writer may bypass the buffer and send p unmasked
			// directly to the underlying writer.
			p2 := p
			if len(p) > c.bw.Available() {
				p2 = p[:c.bw.Available()]
			}

			n2, err := c.bw.Write(p2)
			if err != nil {
				return n, err
			}

			// Mask the copy that now sits in bw's buffer.
			keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])

			p = p[n2:]
			n += n2
		}
	} else {
		n, err = c.bw.Write(p)
		if err != nil {
			return n, err
		}
	}

	if h.fin {
		err = c.bw.Flush()
		if err != nil {
			return n, err
		}
	}

	return n, nil
}

// writePong answers a ping by writing a pong frame carrying payload p. The
// write is bounded by its own 5 second timeout, independent of any caller
// context.
func (c *Conn) writePong(p []byte) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	return c.writeControl(ctx, opPong, p)
}

// Close closes the WebSocket connection with the given status code and reason.
//
// A WebSocket close frame is written with a timeout of 5 seconds.
// The connection can only be closed once; additional calls to Close
// are no-ops.
//
// reason must be at most 125 bytes long, otherwise an internal error is
// sent to the peer instead. For this reason, avoid sending a dynamic
// reason.
//
// Close unblocks all goroutines interacting with the connection.
func (c *Conn) Close(code StatusCode, reason string) error {
	if err := c.exportedClose(code, reason); err != nil {
		return xerrors.Errorf("failed to close connection: %w", err)
	}
	return nil
}

// exportedClose implements Close. It writes a close frame carrying code and
// reason and then closes the connection. If code/reason cannot be encoded,
// the failure is logged to stderr and StatusInternalError is sent instead.
//
// It returns nil when the connection ends up closed because of this call,
// and the existing close error when it had been closed for another reason.
func (c *Conn) exportedClose(code StatusCode, reason string) error {
	ce := CloseError{Code: code, Reason: reason}

	// Unlike what the RFC asks for, this does not wait for the peer's close
	// frame in reply.
	p, err := ce.bytes()
	if err != nil {
		fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err)
		ce = CloseError{Code: StatusInternalError}
		// A bare status code without a reason is not expected to fail encoding.
		p, _ = ce.bytes()
	}

	if err := c.writeClose(p, xerrors.Errorf("sent close frame: %w", ce)); err != nil {
		return err
	}

	if xerrors.Is(c.closeErr, ce) {
		return nil
	}
	return c.closeErr
}

// writeClose writes a close frame with payload p under a 5 second timeout.
// Only if that write succeeds does it close the connection with cerr as the
// cause.
func (c *Conn) writeClose(p []byte, cerr error) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// A failed write means the connection has already died.
	if err := c.writeControl(ctx, opClose, p); err != nil {
		return err
	}

	c.close(cerr)
	return nil
}

// init seeds math/rand, which generates Ping payloads, from the current time.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}

// Ping sends a ping to the peer and blocks until the matching pong arrives.
// Use it to measure latency or to check that the peer is responsive.
// Ping must run concurrently with Reader: it does not read from the
// connection itself and relies on Reader to deliver the pong.
//
// TCP keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
	if err := c.ping(ctx); err != nil {
		return xerrors.Errorf("failed to ping: %w", err)
	}
	return nil
}

// ping writes a ping frame whose payload is a random decimal number and waits
// for a pong with the same payload. If ctx expires first, the connection is
// closed.
func (c *Conn) ping(ctx context.Context) error {
	payload := strconv.FormatUint(rand.Uint64(), 10)
	pong := make(chan struct{})

	c.activePingsMu.Lock()
	c.activePings[payload] = pong
	c.activePingsMu.Unlock()
	defer func() {
		c.activePingsMu.Lock()
		delete(c.activePings, payload)
		c.activePingsMu.Unlock()
	}()

	if err := c.writeControl(ctx, opPing, []byte(payload)); err != nil {
		return err
	}

	select {
	case <-pong:
		return nil
	case <-c.closed:
		return c.closeErr
	case <-ctx.Done():
		err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(err)
		return err
	}
}

// writerFunc adapts a plain function to the io.Writer interface.
type writerFunc func(p []byte) (int, error)

// Write forwards p to the wrapped function and returns its results unchanged.
func (f writerFunc) Write(p []byte) (int, error) {
	n, err := f(p)
	return n, err
}

// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and stores it in c.writeBuf.
//
// It temporarily points c.bw at a writer that records the slice passed to it
// by Flush. It then forces one Flush by buffering a single byte, re-slices the
// recorded slice to its full capacity, and finally resets c.bw to write to w.
// This relies on bufio.Writer.Flush handing a slice of its internal buffer to
// the underlying writer. The WriteByte/Flush errors are ignored because the
// recording writer never fails.
func (c *Conn) extractBufioWriterBuf(w io.Writer) {
	c.bw.Reset(writerFunc(func(p2 []byte) (int, error) {
		c.writeBuf = p2[:cap(p2)]
		return len(p2), nil
	}))

	c.bw.WriteByte(0)
	c.bw.Flush()

	c.bw.Reset(w)
}