From abad521c553c4fbde2cb7c4df959fea0e03447bf Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Fri, 25 Aug 2023 16:30:24 -0500
Subject: [PATCH] types

---
 lib/psql/query.go      | 43 ++++++++++++++++++++++++++++++++++++++++++
 lib/psql/query_test.go |  2 +-
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/lib/psql/query.go b/lib/psql/query.go
index 88cdac2d..aa532f7f 100644
--- a/lib/psql/query.go
+++ b/lib/psql/query.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"io"
 	"reflect"
+	"strconv"
 
 	"pggat2/lib/bouncer/backends/v0"
 	"pggat2/lib/zap"
@@ -155,6 +156,12 @@ outer2:
 
 	if row == nil {
 		if result.Kind() == reflect.Pointer {
+			if result.IsNil() {
+				return nil
+			}
+			if !result.CanSet() {
+				return ErrUnexpectedType
+			}
 			result.Set(reflect.Zero(result.Type()))
 			return nil
 		} else {
@@ -179,6 +186,42 @@ outer2:
 	case reflect.String:
 		result.SetString(string(row))
 		return nil
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		x, err := strconv.ParseUint(string(row), 10, 64)
+		if err != nil {
+			return err
+		}
+		result.SetUint(x)
+		return nil
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		x, err := strconv.ParseInt(string(row), 10, 64)
+		if err != nil {
+			return err
+		}
+		result.SetInt(x)
+		return nil
+	case reflect.Float32, reflect.Float64:
+		x, err := strconv.ParseFloat(string(row), 64)
+		if err != nil {
+			return err
+		}
+		result.SetFloat(x)
+		return nil
+	case reflect.Bool:
+		if len(row) != 1 {
+			return ErrUnexpectedType
+		}
+		var x bool
+		switch row[0] {
+		case 'f':
+			x = false
+		case 't':
+			x = true
+		default:
+			return ErrUnexpectedType
+		}
+		result.SetBool(x)
+		return nil
 	default:
 		return ErrUnexpectedType
 	}
diff --git a/lib/psql/query_test.go b/lib/psql/query_test.go
index 271b203e..3e2d454d 100644
--- a/lib/psql/query_test.go
+++ b/lib/psql/query_test.go
@@ -38,7 +38,7 @@ func TestQuery(t *testing.T) {
 
 	var res Result
 
-	err = Query(server, "SELECT usename, passwd FROM pg_shadow WHERE usename='postgres'", &res)
+	err = Query(server, "SELECT 'abc' as usename, 'test' as passwd", &res)
 	if err != nil {
 		t.Error(err)
 		return
-- 
GitLab