good morning!!!!

Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
header.go 2.93 KiB
package websocket

import (
	"encoding/binary"
	"fmt"
	"io"
	"math"

	"golang.org/x/xerrors"
)

// First byte contains fin, rsv1, rsv2, rsv3.
// Second byte contains mask flag and payload len.
// Next 8 bytes are the maximum extended payload length.
// Last 4 bytes are the mask key.
// https://tools.ietf.org/html/rfc6455#section-5.2
const maxHeaderSize = 1 + 1 + 8 + 4

// header represents a WebSocket frame header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
type header struct {
	fin    bool
	rsv1   bool
	rsv2   bool
	rsv3   bool
	opcode opcode

	payloadLength int64

	masked  bool
	maskKey [4]byte
}

func makeWriteHeaderBuf() []byte {
	return make([]byte, maxHeaderSize)
}

// bytes returns the bytes of the header.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func writeHeader(b []byte, h header) []byte {
	if b == nil {
		b = makeWriteHeaderBuf()
	}

	b = b[:2]
	b[0] = 0

	if h.fin {
		b[0] |= 1 << 7
	}
	if h.rsv1 {
		b[0] |= 1 << 6
	}
	if h.rsv2 {
		b[0] |= 1 << 5
	}
	if h.rsv3 {
		b[0] |= 1 << 4
	}

	b[0] |= byte(h.opcode)

	switch {
	case h.payloadLength < 0:
		panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength))
	case h.payloadLength <= 125:
		b[1] = byte(h.payloadLength)
	case h.payloadLength <= math.MaxUint16:
		b[1] = 126
		b = b[:len(b)+2]
		binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength))
	default:
		b[1] = 127
		b = b[:len(b)+8]
		binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength))
	}

	if h.masked {
		b[1] |= 1 << 7
		b = b[:len(b)+4]
		copy(b[len(b)-4:], h.maskKey[:])
	}

	return b
}

func makeReadHeaderBuf() []byte {
	return make([]byte, maxHeaderSize-2)
}

// readHeader reads a header from the reader.
// See https://tools.ietf.org/html/rfc6455#section-5.2
func readHeader(b []byte, r io.Reader) (header, error) {
	if b == nil {
		b = makeReadHeaderBuf()
	}

	// We read the first two bytes first so that we know
	// exactly how long the header is.
	b = b[:2]
	_, err := io.ReadFull(r, b)
	if err != nil {
		return header{}, err
	}

	var h header
	h.fin = b[0]&(1<<7) != 0
	h.rsv1 = b[0]&(1<<6) != 0
	h.rsv2 = b[0]&(1<<5) != 0
	h.rsv3 = b[0]&(1<<4) != 0

	h.opcode = opcode(b[0] & 0xf)

	var extra int

	h.masked = b[1]&(1<<7) != 0
	if h.masked {
		extra += 4
	}

	payloadLength := b[1] &^ (1 << 7)
	switch {
	case payloadLength < 126:
		h.payloadLength = int64(payloadLength)
	case payloadLength == 126:
		extra += 2
	case payloadLength == 127:
		extra += 8
	}

	if extra == 0 {
		return h, nil
	}

	b = b[:extra]
	_, err = io.ReadFull(r, b)
	if err != nil {
		return header{}, err
	}
	switch {
	case payloadLength == 126:
		h.payloadLength = int64(binary.BigEndian.Uint16(b))
		b = b[2:]
	case payloadLength == 127:
		h.payloadLength = int64(binary.BigEndian.Uint64(b))
		if h.payloadLength < 0 {
			return header{}, xerrors.Errorf("header with negative payload length: %v", h.payloadLength)
		}
		b = b[8:]
	}

	if h.masked {
		copy(h.maskKey[:], b)
	}

	return h, nil
}