good morning!!!!

Skip to content
Snippets Groups Projects
Commit bf8a04e3 authored by Garet Halliday's avatar Garet Halliday
Browse files

pretty much working

parent ff267ce2
No related branches found
No related tags found
No related merge requests found
...@@ -262,7 +262,7 @@ func enableSSL(ctx *acceptContext) (bool, error) { ...@@ -262,7 +262,7 @@ func enableSSL(ctx *acceptContext) (bool, error) {
return false, nil return false, nil
} }
if err = ctx.Conn.EnableSSLClient(ctx.Options.SSLConfig); err != nil { if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, true); err != nil {
return false, err return false, err
} }
......
...@@ -47,7 +47,7 @@ func startup0( ...@@ -47,7 +47,7 @@ func startup0(
if err = ctx.Conn.WriteByte('S'); err != nil { if err = ctx.Conn.WriteByte('S'); err != nil {
return return
} }
if err = ctx.Conn.EnableSSLServer(ctx.Options.SSLConfig); err != nil { if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, false); err != nil {
return return
} }
return return
......
...@@ -2,7 +2,9 @@ package fed ...@@ -2,7 +2,9 @@ package fed
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"io" "io"
"net"
"gfx.cafe/gfx/pggat/lib/util/decorator" "gfx.cafe/gfx/pggat/lib/util/decorator"
"gfx.cafe/gfx/pggat/lib/util/strutil" "gfx.cafe/gfx/pggat/lib/util/strutil"
...@@ -11,9 +13,10 @@ import ( ...@@ -11,9 +13,10 @@ import (
type Conn struct { type Conn struct {
noCopy decorator.NoCopy noCopy decorator.NoCopy
encoder Encoder
decoder Decoder
ReadWriter io.ReadWriteCloser ReadWriter io.ReadWriteCloser
Encoder Encoder
Decoder Decoder
Middleware []Middleware Middleware []Middleware
SSL bool SSL bool
...@@ -29,27 +32,27 @@ func NewConn(rw io.ReadWriteCloser) *Conn { ...@@ -29,27 +32,27 @@ func NewConn(rw io.ReadWriteCloser) *Conn {
c := &Conn{ c := &Conn{
ReadWriter: rw, ReadWriter: rw,
} }
c.Encoder.Writer.Reset(rw) c.encoder.Writer.Reset(rw)
c.Decoder.Reader.Reset(rw) c.decoder.Reader.Reset(rw)
return c return c
} }
func (T *Conn) Flush() error { func (T *Conn) flush() error {
return T.Encoder.Flush() return T.encoder.Flush()
} }
func (T *Conn) ReadPacket(typed bool) (Packet, error) { func (T *Conn) ReadPacket(typed bool) (Packet, error) {
if err := T.Flush(); err != nil { if err := T.flush(); err != nil {
return nil, err return nil, err
} }
for { for {
if err := T.Decoder.Next(typed); err != nil { if err := T.decoder.Next(typed); err != nil {
return nil, err return nil, err
} }
var packet Packet var packet Packet
packet = PendingPacket{ packet = PendingPacket{
Decoder: &T.Decoder, Decoder: &T.decoder,
} }
for _, middleware := range T.Middleware { for _, middleware := range T.Middleware {
var err error var err error
...@@ -82,38 +85,59 @@ func (T *Conn) WritePacket(packet Packet) error { ...@@ -82,38 +85,59 @@ func (T *Conn) WritePacket(packet Packet) error {
return nil return nil
} }
err := T.Encoder.Next(packet.Type(), packet.Length()) err := T.encoder.Next(packet.Type(), packet.Length())
if err != nil { if err != nil {
return err return err
} }
return packet.WriteTo(&T.Encoder) return packet.WriteTo(&T.encoder)
} }
func (T *Conn) WriteByte(b byte) error { func (T *Conn) WriteByte(b byte) error {
return T.Encoder.Uint8(b) return T.encoder.WriteByte(b)
} }
func (T *Conn) ReadByte() (byte, error) { func (T *Conn) ReadByte() (byte, error) {
if err := T.Flush(); err != nil { if err := T.flush(); err != nil {
return 0, err return 0, err
} }
return T.Decoder.Uint8() return T.decoder.ReadByte()
} }
func (T *Conn) EnableSSLClient(config *tls.Config) error { func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error {
// TODO(garet) if T.SSL {
panic("TODO") return errors.New("SSL is already enabled")
} }
T.SSL = true
// flush buffers
if err := T.flush(); err != nil {
return err
}
if T.decoder.Reader.Buffered() > 0 {
return errors.New("expected empty read buffer")
}
func (T *Conn) EnableSSLServer(config *tls.Config) error { conn, ok := T.ReadWriter.(net.Conn)
// TODO(garet) if !ok {
panic("TODO") return errors.New("ssl not supported for this read writer")
}
var sslConn *tls.Conn
if isClient {
sslConn = tls.Client(conn, config)
} else {
sslConn = tls.Server(conn, config)
}
T.encoder.Writer.Reset(sslConn)
T.decoder.Reader.Reset(sslConn)
T.ReadWriter = sslConn
return sslConn.Handshake()
} }
func (T *Conn) Close() error { func (T *Conn) Close() error {
if err := T.Encoder.Flush(); err != nil { if err := T.encoder.Flush(); err != nil {
return err return err
} }
......
...@@ -28,7 +28,28 @@ func NewDecoder(r io.Reader) *Decoder { ...@@ -28,7 +28,28 @@ func NewDecoder(r io.Reader) *Decoder {
return d return d
} }
func (T *Decoder) ReadByte() (byte, error) {
if T.pos != T.len {
_, err := T.Reader.Discard(T.len - T.pos)
if err != nil {
return 0, err
}
}
T.typ = 0
T.len = 0
T.pos = 0
return T.Reader.ReadByte()
}
func (T *Decoder) Next(typed bool) error { func (T *Decoder) Next(typed bool) error {
if T.pos != T.len {
_, err := T.Reader.Discard(T.len - T.pos)
if err != nil {
return err
}
}
var err error var err error
if typed { if typed {
_, err = io.ReadFull(&T.Reader, T.buf[:5]) _, err = io.ReadFull(&T.Reader, T.buf[:5])
......
...@@ -31,7 +31,22 @@ func (T *Encoder) Flush() error { ...@@ -31,7 +31,22 @@ func (T *Encoder) Flush() error {
return T.Writer.Flush() return T.Writer.Flush()
} }
func (T *Encoder) WriteByte(b byte) error {
if T.pos != T.len {
panic("wrong number of bytes written")
}
T.typ = 0
T.len = 0
T.pos = 0
return T.Writer.WriteByte(b)
}
func (T *Encoder) Next(typ Type, length int) error { func (T *Encoder) Next(typ Type, length int) error {
if T.pos != T.len {
panic("wrong number of bytes written")
}
if typ != 0 { if typ != 0 {
if err := T.Writer.WriteByte(byte(typ)); err != nil { if err := T.Writer.WriteByte(byte(typ)); err != nil {
return err return err
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment