// +build !js

package websocket

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
)

// MessageType identifies the kind of a WebSocket data message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int

// MessageType constants. Values start at 1, so the zero MessageType
// matches neither constant.
const (
	// MessageText is for UTF-8 encoded text messages like JSON.
	MessageText MessageType = 1
	// MessageBinary is for binary messages like protobufs.
	MessageBinary MessageType = 2
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
type Conn struct {
	// Negotiated parameters and transport, copied from connConfig by newConn.
	// copts is nil when compression is disabled (see flate). rwc is closed by
	// close; for clients, br and bw are handed back via putBufioReader /
	// putBufioWriter when the connection closes (presumably a pool — confirm).
	subprotocol string
	rwc         io.ReadWriteCloser
	client      bool
	copts       *compressionOptions
	br          *bufio.Reader
	bw          *bufio.Writer

	// Unbuffered channels over which contexts are handed to timeoutLoop.
	// When the current read context is done, the read is treated as timed
	// out; when the current write context is done, the connection is closed.
	readTimeout  chan context.Context
	writeTimeout chan context.Context

	// Read state.
	// readMu guards reading (it is also taken by close's cleanup before
	// returning br). readControlBuf is presumably scratch space for
	// control-frame payloads — TODO confirm. readCloseFrameErr is set
	// elsewhere; presumably the error from reading the peer's close frame.
	readMu            *mu
	readControlBuf    [maxControlPayload]byte
	msgReader         *msgReader
	readCloseFrameErr error

	// Write state.
	// writeFrameMu guards frame writes. For clients, writeBuf is extracted
	// from bw by newConn. writeHeader is presumably a reusable frame header —
	// TODO confirm.
	msgWriter    *msgWriter
	writeFrameMu *mu
	writeBuf     []byte
	writeHeader  header

	// Close state.
	// closed is closed exactly once, by close. closeMu serializes close and
	// guards updates to closeErr, the reason returned by operations that
	// observe closed. wroteClose is managed elsewhere; presumably it records
	// that a close frame has been written — verify.
	closed     chan struct{}
	closeMu    sync.Mutex
	closeErr   error
	wroteClose bool

	// Ping state.
	// pingCounter is atomically incremented to generate ping payloads.
	// activePings maps each outstanding ping payload to the channel its
	// ping call waits on for the pong; it is guarded by activePingsMu.
	pingCounter   int32
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}

// connConfig carries the parameters newConn copies into a new Conn.
type connConfig struct {
	// subprotocol is the negotiated subprotocol ("" for the default).
	subprotocol string
	// rwc is the underlying transport.
	rwc io.ReadWriteCloser
	// client is true when this side is the client.
	client bool
	// copts is nil when compression is disabled.
	copts *compressionOptions

	// br and bw are the buffered reader and writer the Conn uses.
	br *bufio.Reader
	bw *bufio.Writer
}

// newConn builds a Conn from cfg, creates its locks and message
// reader/writer, installs a finalizer that closes it, and starts the
// timeoutLoop goroutine.
func newConn(cfg connConfig) *Conn {
	c := new(Conn)

	// Negotiated parameters and buffered I/O come straight from the config.
	c.subprotocol, c.rwc, c.client, c.copts = cfg.subprotocol, cfg.rwc, cfg.client, cfg.copts
	c.br, c.bw = cfg.br, cfg.bw

	// Unbuffered: each send hands a context directly to timeoutLoop.
	c.readTimeout = make(chan context.Context)
	c.writeTimeout = make(chan context.Context)

	c.closed = make(chan struct{})
	c.activePings = make(map[string]chan<- struct{})

	// Locks are created before the message reader and writer, which are
	// given c.
	c.readMu = newMu(c)
	c.writeFrameMu = newMu(c)
	c.msgReader = newMsgReader(c)
	c.msgWriter = newMsgWriter(c)

	if c.client {
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	// NOTE(review): the timeoutLoop goroutine started below references c
	// until c.closed is closed or a write times out, so c likely cannot
	// become unreachable while it runs and this finalizer may never fire.
	runtime.SetFinalizer(c, func(c *Conn) {
		c.close(errors.New("connection garbage collected"))
	})

	go c.timeoutLoop()
	return c
}

// Subprotocol reports the subprotocol negotiated with the peer.
// It is the empty string when the default protocol is in use.
func (c *Conn) Subprotocol() string {
	proto := c.subprotocol
	return proto
}

// close shuts the connection down, recording err as the close reason.
// Only the first call has an effect; later calls return immediately.
//
// It records the error, closes c.closed, clears the finalizer, and closes the
// underlying transport. For clients, the cleanup goroutine then returns the
// buffered writer and reader via putBufioWriter / putBufioReader. Finally it
// closes the message writer and reader.
func (c *Conn) close(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return
	}
	// Record the reason BEFORE closing c.closed. Goroutines blocked on
	// <-c.closed (mu.Lock, ping) read c.closeErr as soon as they wake, without
	// holding closeMu. Closing the channel first would let them observe a
	// nil or concurrently written closeErr.
	c.setCloseErrLocked(err)
	close(c.closed)
	runtime.SetFinalizer(c, nil)

	// Have to close after c.closed is closed to ensure any goroutine that wakes up
	// from the connection being closed also sees that c.closed is closed and returns
	// closeErr. Best effort: there is no caller to report a Close error to.
	c.rwc.Close()

	go func() {
		if c.client {
			// Acquire with a raw blocking send rather than mu.Lock. Now that
			// c.closed is closed, mu.Lock's select may take the closed case
			// and return without holding the lock. That would hand c.bw back
			// while a writer is still using it. The lock is never released, so
			// later frame writers cannot acquire it.
			c.writeFrameMu.ch <- struct{}{}
			putBufioWriter(c.bw)
		}
		c.msgWriter.close()

		if c.client {
			// Same reasoning as above: wait for any in-flight reader to finish.
			// NOTE(review): unlike writeFrameMu this lock is released again,
			// as before. Confirm whether a later reader can still reach the
			// returned c.br.
			c.readMu.ch <- struct{}{}
			putBufioReader(c.br)
			c.readMu.Unlock()
		}
		c.msgReader.close()
	}()
}

// timeoutLoop enforces the deadlines carried by the contexts that arrive on
// c.readTimeout and c.writeTimeout. It runs in its own goroutine until c.closed
// is closed or a write context is done.
//
// When the current read context is done, the close reason is recorded and a
// policy-violation error is written in the background. When the current write
// context is done, the connection is closed immediately and the loop exits.
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		case <-readCtx.Done():
			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
			// A Done channel stays closed once closed. Without resetting
			// readCtx, this case would be ready on every iteration. The loop
			// would spin and spawn an unbounded number of writeError
			// goroutines until the connection finished closing.
			readCtx = context.Background()
			go c.writeError(StatusPolicyViolation, errors.New("timed out"))
		case <-writeCtx.Done():
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
			return
		}
	}
}

// flate reports whether compression options are configured for this
// connection, i.e. whether c.copts is non-nil.
func (c *Conn) flate() bool {
	enabled := c.copts != nil
	return enabled
}

// Ping sends a ping to the peer and waits for the matching pong.
// Use it to measure latency or to check that the peer is responsive.
//
// Ping does not read from the connection itself. The pong is picked up by a
// concurrent Reader call, so Ping must be called while another goroutine is
// reading.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
	// Each ping gets a distinct payload taken from an atomic counter.
	id := strconv.Itoa(int(atomic.AddInt32(&c.pingCounter, 1)))

	if err := c.ping(ctx, id); err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}

// ping registers payload p in c.activePings, sends a ping control frame
// carrying p, and waits until the registered channel is signaled, ctx is
// done, or the connection closes. The registration is always removed on
// return. If ctx is done while waiting, the whole connection is closed.
func (c *Conn) ping(ctx context.Context, p string) error {
	pong := make(chan struct{})

	unregister := func() {
		c.activePingsMu.Lock()
		delete(c.activePings, p)
		c.activePingsMu.Unlock()
	}

	// Register before writing so the pong cannot arrive before it is expected.
	c.activePingsMu.Lock()
	c.activePings[p] = pong
	c.activePingsMu.Unlock()
	defer unregister()

	if err := c.writeControl(ctx, opPing, []byte(p)); err != nil {
		return err
	}

	select {
	case <-pong:
		return nil
	case <-ctx.Done():
		waitErr := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(waitErr)
		return waitErr
	case <-c.closed:
		return c.closeErr
	}
}

// mu is a mutual-exclusion lock built on a channel with a buffer of one
// (see newMu). Holding the lock means the channel's single slot is full.
// Unlike sync.Mutex, Lock can give up when a context is done or the owning
// connection closes, and TryLock never blocks.
type mu struct {
	// c is the connection whose closed channel and closeErr Lock consults.
	c  *Conn
	// ch holds one token while the lock is held.
	ch chan struct{}
}

// newMu returns an unlocked mu tied to connection c.
func newMu(c *Conn) *mu {
	m := new(mu)
	m.c = c
	// One slot: the lock is held exactly when the slot is occupied.
	m.ch = make(chan struct{}, 1)
	return m
}

// Lock acquires m, blocking until the lock is free, ctx is done, or the
// connection is closed. It returns nil only when the caller holds the lock.
//
// If ctx is done first, the connection is closed and the wrapped context
// error is returned. If the connection is closed, its closeErr is returned
// and the lock is not held.
func (m *mu) Lock(ctx context.Context) error {
	select {
	case <-m.c.closed:
		return m.c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
		m.c.close(err)
		return err
	case m.ch <- struct{}{}:
		// select picks uniformly among ready cases, so the send may have won
		// even though the connection was already closed. Re-check so callers
		// never get a successful lock on a closed connection.
		select {
		case <-m.c.closed:
			m.Unlock()
			return m.c.closeErr
		default:
			return nil
		}
	}
}

// TryLock attempts to acquire m without blocking and reports whether it
// succeeded. It does not consult the connection's closed state.
func (m *mu) TryLock() bool {
	acquired := false
	select {
	case m.ch <- struct{}{}:
		acquired = true
	default:
		// Slot already occupied: someone else holds the lock.
	}
	return acquired
}

// Unlock releases m. Calling it when m is not held is a no-op; it never
// blocks.
func (m *mu) Unlock() {
	select {
	case <-m.ch:
		// Token drained: the lock is free again.
		return
	default:
		// Nothing held, nothing to release.
		return
	}
}