From 363c19e6efc53e445bc4c4a27c717c6526b04f64 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 10 Oct 2023 18:14:44 -0500
Subject: [PATCH] just need to fix Authentication packet

---
 hack/packetgen/main.go                |   6 ++
 hack/packetgen/protocol.yaml          |  65 +++++++++-----
 hack/packetgen/templates/decode.tmpl  | 118 +++++++++++++++++++++++++-
 hack/packetgen/templates/encode.tmpl  |  94 +++++++++++++++++++-
 hack/packetgen/templates/packets.tmpl |  41 ++++++++-
 hack/packetgen/templates/preType.tmpl |  20 ++---
 6 files changed, 307 insertions(+), 37 deletions(-)

diff --git a/hack/packetgen/main.go b/hack/packetgen/main.go
index f536d8ef..c6d4f14e 100644
--- a/hack/packetgen/main.go
+++ b/hack/packetgen/main.go
@@ -2,15 +2,21 @@ package main
 
 import (
 	_ "embed"
+	"fmt"
 
 	"gfx.cafe/util/temple"
 	"gfx.cafe/util/temple/lib/prayer"
 )
 
 func main() {
+	var idx int
 	var obj any
 	temple.RegisterTemplateDir("templates")
 	temple.ReadObjectFile(&obj, "protocol.yaml")
+	temple.RegisterFunc("temp", func() string {
+		idx++
+		return fmt.Sprintf("temp%d", idx)
+	})
 	temple.Prepare(&prayer.Go{
 		Input:   "packets",
 		Obj:     obj,
diff --git a/hack/packetgen/protocol.yaml b/hack/packetgen/protocol.yaml
index ebb29859..63d13816 100644
--- a/hack/packetgen/protocol.yaml
+++ b/hack/packetgen/protocol.yaml
@@ -1,4 +1,12 @@
 Packets:
+  StartupMessage:
+    Struct:
+      Name: Payload
+      Fields:
+        - Name: MajorVersion
+          Basic: uint16
+        - Name: MinorVersion
+          Basic: uint16
   Authentication:
     Type: 'R'
     Struct:
@@ -7,7 +15,8 @@ Packets:
         - Name: Mode
           Map:
             Name: Mode
-            Prefix: int32
+            Prefix:
+              Basic: int32
             Items:
               Ok:
                 Type: 0
@@ -56,7 +65,8 @@ Packets:
           Basic: string
         - Name: InitialClientResponse
           NullableLengthPrefixedSlice:
-            Prefix: int32
+            Prefix:
+              Basic: int32
             Basic: uint8
   SASLResponse:
     Type: 'p'
@@ -82,17 +92,21 @@ Packets:
           Basic: string
         - Name: FormatCodes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int16
         - Name: Parameters
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             NullableLengthPrefixedSlice:
-              Prefix: int32
+              Prefix:
+                Basic: int32
               Basic: uint8
         - Name: ResultFormatCodes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int16
   BindComplete:
     Type: '2'
@@ -128,7 +142,8 @@ Packets:
           Basic: int8
         - Name: ColumnFormatCodes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int16
   CopyOutResponse:
     Type: 'H'
@@ -139,7 +154,8 @@ Packets:
           Basic: int8
         - Name: ColumnFormatCodes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int16
   CopyBothResponse:
     Type: 'W'
@@ -150,14 +166,17 @@ Packets:
           Basic: int8
         - Name: ColumnFormatCodes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int16
   DataRow:
     Type: 'D'
     LengthPrefixedSlice:
-      Prefix: uint16
+      Prefix:
+        Basic: uint16
       NullableLengthPrefixedSlice:
-        Prefix: int32
+        Prefix:
+          Basic: int32
         Basic: uint8
   Describe:
     Type: 'D'
@@ -198,20 +217,24 @@ Packets:
           Basic: int32
         - Name: ArgumentFormatCodes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int16
         - Name: Arguments
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             NullableLengthPrefixedSlice:
-              Prefix: int32
+              Prefix:
+                Basic: int32
               Basic: uint8
         - Name: ResultFormatCode
           Basic: int16
   FunctionCallResponse:
     Type: 'V'
     NullableLengthPrefixedSlice:
-      Prefix: int32
+      Prefix:
+        Basic: int32
       Basic: uint8
   NegotiateProtocolVersion:
     Type: 'v'
@@ -222,7 +245,8 @@ Packets:
           Basic: int32
         - Name: UnrecognizedProtocolOptions
           LengthPrefixedSlice:
-            Prefix: uint32
+            Prefix:
+              Basic: uint32
             Basic: string
   NoData:
     Type: 'n'
@@ -248,7 +272,8 @@ Packets:
   ParameterDescription:
     Type: 't'
     LengthPrefixedSlice:
-      Prefix: uint16
+      Prefix:
+        Basic: uint16
       Basic: int32
   ParameterStatus:
     Type: 'S'
@@ -270,7 +295,8 @@ Packets:
           Basic: string
         - Name: ParameterDataTypes
           LengthPrefixedSlice:
-            Prefix: uint16
+            Prefix:
+              Basic: uint16
             Basic: int32
   ParseComplete:
     Type: '1'
@@ -285,7 +311,8 @@ Packets:
   RowDescription:
     Type: 'T'
     LengthPrefixedSlice:
-      Prefix: uint16
+      Prefix:
+        Basic: uint16
       Struct:
         Name: Row
         Fields:
diff --git a/hack/packetgen/templates/decode.tmpl b/hack/packetgen/templates/decode.tmpl
index 96e3f78d..b06b0f55 100644
--- a/hack/packetgen/templates/decode.tmpl
+++ b/hack/packetgen/templates/decode.tmpl
@@ -1,6 +1,118 @@
 {{$name := index . 0 -}}
 {{$value := index . 1 -}}
 
-func (T *{{$name}}) ReadFrom(decoder *fed.Decoder) error {
-	panic("TODO") // TODO
-}
+{{if some $value.Map -}}
+	err = {{$name}}.ReadFrom(decoder)
+	if err != nil {
+		return
+	}
+{{else if some $value.Remaining -}}
+	{{$name}} = {{$name}}[:0]
+
+	for {
+		if decoder.Position() >= decoder.Length() {
+			break
+		}
+
+		{{$name}} = slices.Resize({{$name}}, len({{$name}})+1)
+
+		{{$targetName := printf "%s[len(%s)-1]" $name $name -}}
+
+		{{template "decode" (list $targetName $value.Remaining)}}
+	}
+{{else if some $value.Basic -}}
+	*(*{{$value.Basic}})(&({{$name}})), err = decoder.{{upperCamel $value.Basic}}()
+	if err != nil {
+		return
+	}
+{{else if some $value.Array -}}
+	{{$indexName := temp -}}
+
+	for {{$indexName}} := 0; {{$indexName}} < {{$value.Array.Length}}; {{$indexName}}++ {
+		{{$targetName := printf "%s[%s]" $name $indexName -}}
+
+		{{template "decode" (list $targetName $value.Array)}}
+	}
+{{else if some $value.Struct -}}
+	{{range $field := $value.Struct.Fields -}}
+        {{$fieldName := printf "%s.%s" $name $field.Name -}}
+
+		{{template "decode" (list $fieldName $field) -}}
+	{{end -}}
+{{else if some $value.LengthPrefixedSlice -}}
+	{{$lengthName := temp -}}
+
+	var {{$lengthName}} {{template "type" (list "" $value.LengthPrefixedSlice.Prefix)}}
+	{{template "decode" (list $lengthName $value.LengthPrefixedSlice.Prefix)}}
+
+	{{$name}} = slices.Resize({{$name}}, int({{$lengthName}}))
+
+	{{$indexName := temp -}}
+
+	for {{$indexName}} := 0; {{$indexName}} < int({{$lengthName}}); {{$indexName}}++ {
+		{{$targetName := printf "%s[%s]" $name $indexName -}}
+
+		{{template "decode" (list $targetName $value.LengthPrefixedSlice)}}
+	}
+
+{{else if some $value.NullableLengthPrefixedSlice -}}
+    {{$lengthName := temp -}}
+
+	var {{$lengthName}} {{template "type" (list "" $value.NullableLengthPrefixedSlice.Prefix)}}
+    {{template "decode" (list $lengthName $value.NullableLengthPrefixedSlice.Prefix)}}
+
+	if {{$lengthName}} == -1 {
+		{{$name}} = nil
+	} else {
+		{{$name}} = slices.Resize({{$name}}, int({{$lengthName}}))
+
+		{{$indexName := temp -}}
+
+		for {{$indexName}} := 0; {{$indexName}} < int({{$lengthName}}); {{$indexName}}++ {
+			{{$targetName := printf "%s[%s]" $name $indexName -}}
+
+			{{template "decode" (list $targetName $value.NullableLengthPrefixedSlice)}}
+		}
+	}
+
+{{else if some $value.ZeroByteTerminatedSlice -}}
+	{{$name}} = {{$name}}[:0]
+
+	for {
+	    {{$name}} = slices.Resize({{$name}}, len({{$name}})+1)
+
+		{{$targetName := printf "%s[len(%s)-1]" $name $name -}}
+
+		{{$keyName := printf "%s.%s" $targetName $value.ZeroByteTerminatedSlice.KeyName -}}
+
+		{{$keyName}}, err = decoder.Uint8()
+		if err != nil {
+			return
+		}
+		if {{$keyName}} == 0 {
+			{{$name}} = {{$name}}[:len({{$name}})-1]
+			break
+		}
+
+		{{range $field := $value.ZeroByteTerminatedSlice.Fields -}}
+			{{$fieldName := printf "%s.%s" $targetName $field.Name -}}
+
+			{{template "decode" (list $fieldName $field) -}}
+		{{end -}}
+	}
+{{else if some $value.ZeroTerminatedSlice -}}
+	{{$name}} = {{$name}}[:0]
+
+	for {
+		{{$name}} = slices.Resize({{$name}}, len({{$name}})+1)
+
+		{{$targetName := printf "%s[len(%s)-1]" $name $name -}}
+
+		{{template "decode" (list $targetName $value.ZeroTerminatedSlice)}}
+
+		if {{$targetName}} == *new({{template "type" (list "" $value.ZeroTerminatedSlice)}}) {
+    		{{$name}} = {{$name}}[:len({{$name}})-1]
+			break
+		}
+	}
+{{end -}}
diff --git a/hack/packetgen/templates/encode.tmpl b/hack/packetgen/templates/encode.tmpl
index e3a0189a..50006c36 100644
--- a/hack/packetgen/templates/encode.tmpl
+++ b/hack/packetgen/templates/encode.tmpl
@@ -1,6 +1,94 @@
 {{$name := index . 0 -}}
 {{$value := index . 1 -}}
 
-func (T *{{$name}}) WriteTo(encoder *fed.Encoder) error {
-	panic("TODO") // TODO
-}
+{{if some $value.Map -}}
+	err = {{$name}}.WriteTo(encoder)
+	if err != nil {
+		return
+	}
+{{else if some $value.Remaining -}}
+	{{$itemName := temp -}}
+
+	for _, {{$itemName}} := range {{$name}} {
+		{{template "encode" (list $itemName $value.Remaining)}}
+	}
+{{else if some $value.Basic -}}
+	err = encoder.{{upperCamel $value.Basic}}({{$value.Basic}}({{$name}}))
+	if err != nil {
+		return
+	}
+{{else if some $value.Array -}}
+	{{$itemName := temp -}}
+
+	for _, {{$itemName}} := range {{$name}} {
+		{{template "encode" (list $itemName $value.Array)}}
+	}
+{{else if some $value.Struct -}}
+	{{range $field := $value.Struct.Fields -}}
+		{{$fieldName := printf "%s.%s" $name $field.Name -}}
+
+		{{template "encode" (list $fieldName $field)}}
+	{{end -}}
+{{else if some $value.LengthPrefixedSlice -}}
+	{{$lengthName := temp -}}
+
+	{{$lengthName}} := {{template "type" (list "" $value.LengthPrefixedSlice.Prefix)}}(len({{$name}}))
+
+	{{template "encode" (list $lengthName $value.LengthPrefixedSlice.Prefix)}}
+
+	{{$itemName := temp -}}
+
+	for _, {{$itemName}} := range {{$name}} {
+		{{template "encode" (list $itemName $value.LengthPrefixedSlice)}}
+	}
+{{else if some $value.NullableLengthPrefixedSlice -}}
+    {{$lengthName := temp -}}
+
+    {{$lengthName}} := {{template "type" (list "" $value.NullableLengthPrefixedSlice.Prefix)}}(len({{$name}}))
+
+	if {{$name}} == nil {
+		{{$lengthName}} = -1
+	}
+
+    {{template "encode" (list $lengthName $value.NullableLengthPrefixedSlice.Prefix)}}
+
+    {{$itemName := temp -}}
+
+	for _, {{$itemName}} := range {{$name}} {
+    	{{template "encode" (list $itemName $value.NullableLengthPrefixedSlice)}}
+	}
+{{else if some $value.ZeroByteTerminatedSlice -}}
+	{{$itemName := temp -}}
+
+	for _, {{$itemName}} := range {{$name}} {
+		{{$keyName := printf "%s.%s" $itemName $value.ZeroByteTerminatedSlice.KeyName -}}
+
+		err = encoder.Uint8({{$keyName}})
+		if err != nil {
+			return
+		}
+
+		{{range $field := $value.ZeroByteTerminatedSlice.Fields -}}
+			{{$fieldName := printf "%s.%s" $itemName $field.Name -}}
+
+			{{template "encode" (list $fieldName $field)}}
+		{{end -}}
+	}
+
+	err = encoder.Uint8(0)
+	if err != nil {
+		return
+	}
+{{else if some $value.ZeroTerminatedSlice -}}
+	{{$itemName := temp -}}
+
+	for _, {{$itemName}} := range {{$name}} {
+    	{{template "encode" (list $itemName $value.ZeroTerminatedSlice)}}
+	}
+
+	{{$doneName := temp -}}
+
+	var {{$doneName}} {{template "type" (list "" $value.ZeroTerminatedSlice)}}
+
+	{{template "encode" (list $doneName $value.ZeroTerminatedSlice)}}
+{{end -}}
diff --git a/hack/packetgen/templates/packets.tmpl b/hack/packetgen/templates/packets.tmpl
index 59d455a9..9d46347b 100644
--- a/hack/packetgen/templates/packets.tmpl
+++ b/hack/packetgen/templates/packets.tmpl
@@ -1,12 +1,49 @@
 // automatically generated. do not edit
 
+import (
+	"gfx.cafe/gfx/pggat/lib/fed"
+	"gfx.cafe/gfx/pggat/lib/util/slices"
+
+	"errors"
+)
+
+var (
+	ErrUnexpectedPacket = errors.New("unexpected packet")
+)
+
+const (
+	{{range $name, $packet := .Packets -}}
+		{{if some $packet.Type -}}
+			{{$name}}Type = '{{$packet.Type}}'
+		{{end -}}
+	{{end -}}
+)
+
 {{range $name, $packet := .Packets -}}
 	{{template "preType" (list $name $packet)}}
 
 	type {{$name}} {{template "type" (list $name $packet)}}
 
-    {{template "decode" (list $name $packet)}}
+	func (T *{{$name}}) ReadFrom(decoder *fed.Decoder) (err error) {
+		{{if some $packet.Type -}}
+			if decoder.Type() != {{$name}}Type {
+				return ErrUnexpectedPacket
+			}
+		{{else -}}
+			if decoder.Type() != 0 {
+				return ErrUnexpectedPacket
+			}
+		{{end -}}
+
+		{{template "decode" (list "(*T)" $packet)}}
+
+		return
+	}
+
+	func (T *{{$name}}) WriteTo(encoder *fed.Encoder) (err error) {
+		{{template "encode" (list "(*T)" $packet)}}
 
-	{{template "encode" (list $name $packet)}}
+		return
+	}
 
 {{end}}
diff --git a/hack/packetgen/templates/preType.tmpl b/hack/packetgen/templates/preType.tmpl
index b932ecff..a98bff85 100644
--- a/hack/packetgen/templates/preType.tmpl
+++ b/hack/packetgen/templates/preType.tmpl
@@ -13,9 +13,17 @@
 
 		func (*{{$itemName}}) {{$ifaceName}}() {}
 
-        {{template "decode" (list $itemName $item)}}
+		func (T *{{$itemName}}) ReadFrom(decoder *fed.Decoder) (err error) {
+			{{template "decode" (list "(*T)" $item)}}
 
-        {{template "encode" (list $itemName $item)}}
+			return
+		}
+
+		func (T *{{$itemName}}) WriteTo(encoder *fed.Encoder) (err error) {
+			{{template "encode" (list "(*T)" $item)}}
+
+			return
+		}
 
 	{{end -}}
 
@@ -38,10 +46,6 @@
 		{{end -}}
 	}
 
-    {{template "decode" (list $structName $value)}}
-
-    {{template "encode" (list $structName $value)}}
-
 {{else if some $value.LengthPrefixedSlice -}}
 	{{template "preType" (list $name $value.LengthPrefixedSlice) -}}
 {{else if some $value.NullableLengthPrefixedSlice -}}
@@ -61,8 +65,4 @@
 		{{end -}}
 	}
 
-    {{template "decode" (list $structName $value)}}
-
-    {{template "encode" (list $structName $value)}}
-
 {{end -}}
\ No newline at end of file
-- 
GitLab