package test

import (
	"bytes"
	"errors"
	"fmt"
	"io"

	"gfx.cafe/gfx/pggat/lib/bouncer/bouncers/v2"
	"gfx.cafe/gfx/pggat/lib/fed"
	"gfx.cafe/gfx/pggat/lib/fed/middlewares/unterminate"
	packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
	"gfx.cafe/gfx/pggat/lib/gat/pool/recipe"
	"gfx.cafe/gfx/pggat/lib/gsql"
	"gfx.cafe/gfx/pggat/lib/util/flip"
	"gfx.cafe/gfx/pggat/test/inst"
)

type Runner struct {
	config Config
	test   Test
}

func MakeRunner(config Config, test Test) Runner {
	return Runner{
		config: config,
		test:   test,
	}
}

func (T *Runner) prepare(client *fed.Conn, until int) error {
	for i := 0; i < until; i++ {
		x := T.test.Instructions[i]
		switch v := x.(type) {
		case inst.SimpleQuery:
			q := packets.Query(v)
			if err := client.WritePacket(&q); err != nil {
				return err
			}
		case inst.Sync:
			if err := client.WritePacket(&packets.Sync{}); err != nil {
				return err
			}
		case inst.Parse:
			p := packets.Parse{
				Destination: v.Destination,
				Query:       v.Query,
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.Bind:
			p := packets.Bind{
				Destination: v.Destination,
				Source:      v.Source,
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.DescribePortal:
			p := packets.Describe{
				Which: 'P',
				Name:  string(v),
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.DescribePreparedStatement:
			p := packets.Describe{
				Which: 'S',
				Name:  string(v),
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.Execute:
			p := packets.Execute{
				Target: string(v),
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.ClosePortal:
			p := packets.Close{
				Which: 'P',
				Name:  string(v),
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.ClosePreparedStatement:
			p := packets.Close{
				Which: 'S',
				Name:  string(v),
			}
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.CopyData:
			p := packets.CopyData(v)
			if err := client.WritePacket(&p); err != nil {
				return err
			}
		case inst.CopyDone:
			if err := client.WritePacket(&packets.CopyDone{}); err != nil {
				return err
			}
		}
	}

	if err := client.WritePacket(&packets.Terminate{}); err != nil {
		return err
	}

	return client.Flush()
}

func (T *Runner) runModeL1(dialer recipe.Dialer, client *fed.Conn) error {
	server, err := dialer.Dial()
	if err != nil {
		return err
	}
	defer func() {
		_ = server.Close()
	}()

	client.Middleware = append(client.Middleware, unterminate.Unterminate)

	for {
		var p fed.Packet
		p, err = client.ReadPacket(true)
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return err
		}

		clientErr, serverErr := bouncers.Bounce(client, server, p)
		if clientErr != nil {
			return clientErr
		}
		if serverErr != nil {
			return serverErr
		}
	}

	return nil
}

func (T *Runner) runModeOnce(dialer recipe.Dialer) ([]byte, error) {
	inward, outward := gsql.NewPair()
	if err := T.prepare(inward, len(T.test.Instructions)); err != nil {
		return nil, err
	}

	if err := T.runModeL1(dialer, outward); err != nil {
		return nil, err
	}

	if err := inward.Close(); err != nil {
		return nil, err
	}

	return io.ReadAll(inward.NetConn)
}

func (T *Runner) runModeFail(dialer recipe.Dialer) error {
	for i := 1; i < len(T.test.Instructions)+1; i++ {
		inward, outward := gsql.NewPair()
		if err := T.prepare(inward, i); err != nil {
			return err
		}

		if err := T.runModeL1(dialer, outward); err != nil && !errors.Is(err, io.EOF) {
			return err
		}
	}

	return nil
}

func (T *Runner) runMode(dialer recipe.Dialer) ([]byte, error) {
	instances := T.config.Stress
	if instances < 1 || T.test.SideEffects {
		return T.runModeOnce(dialer)
	}

	expected, err := T.runModeOnce(dialer)
	if err != nil {
		return nil, err
	}

	// fail testing
	if err = T.runModeFail(dialer); err != nil {
		return nil, err
	}

	// stress test
	var b flip.Bank

	for i := 0; i < instances-1; i++ {
		b.Queue(func() error {
			actual, err := T.runModeOnce(dialer)
			if err != nil {
				return err
			}
			if !bytes.Equal(expected, actual) {
				return fmt.Errorf("mismatched results: expected %v but got %v", expected, actual)
			}
			return nil
		})
	}

	if err = b.Wait(); err != nil {
		return nil, err
	}

	return expected, nil
}

func (T *Runner) Run() error {
	var errs []error

	var expected []byte

	// modes
	for name, mode := range T.config.Modes {
		actual, err := T.runMode(mode)
		if err != nil {
			errs = append(errs, ErrorIn{
				Name: name,
				Err:  err,
			})
			continue
		}

		if expected == nil {
			expected = actual
			continue
		}

		if !bytes.Equal(expected, actual) {
			errs = append(errs, ErrorIn{
				Name: name,
				Err:  fmt.Errorf("mismatched results: expected %v but got %v", expected, actual),
			})
			continue
		}
	}

	if len(errs) > 0 {
		return Errors(errs)
	}
	return nil
}