From 50b2d29ff37cc193698963fc9e4da470b7218815 Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Fri, 13 Nov 2020 09:27:57 +0100
Subject: [PATCH] crypto/bn256: improve bn256 fuzzer (#21815)

* crypto/cloudflare: fix nil deref in random G1/G2 reading

* crypto/bn256: improve fuzzer

* crypto/bn256: fix some flaws in fuzzer
---
 crypto/bn256/bn256_fuzz.go       | 117 ++++++++++++++-----------------
 crypto/bn256/cloudflare/bn256.go |   2 +-
 2 files changed, 53 insertions(+), 66 deletions(-)

diff --git a/crypto/bn256/bn256_fuzz.go b/crypto/bn256/bn256_fuzz.go
index 733d7ce27f..29fe7aab84 100644
--- a/crypto/bn256/bn256_fuzz.go
+++ b/crypto/bn256/bn256_fuzz.go
@@ -8,42 +8,52 @@ package bn256
 
 import (
 	"bytes"
+	"fmt"
+	"io"
 	"math/big"
 
 	cloudflare "github.com/ledgerwatch/turbo-geth/crypto/bn256/cloudflare"
 	google "github.com/ledgerwatch/turbo-geth/crypto/bn256/google"
 )
 
-// FuzzAdd fuzzez bn256 addition between the Google and Cloudflare libraries.
-func FuzzAdd(data []byte) int {
-	// Ensure we have enough data in the first place
-	if len(data) != 128 {
-		return 0
+func getG1Points(input io.Reader) (*cloudflare.G1, *google.G1) {
+	_, xc, err := cloudflare.RandomG1(input)
+	if err != nil {
+		// insufficient input
+		return nil, nil
 	}
-	// Ensure both libs can parse the first curve point
-	xc := new(cloudflare.G1)
-	_, errc := xc.Unmarshal(data[:64])
-
 	xg := new(google.G1)
-	_, errg := xg.Unmarshal(data[:64])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
-		return 0
+	if _, err := xg.Unmarshal(xc.Marshal()); err != nil {
+		panic(fmt.Sprintf("Could not marshal cloudflare -> google:", err))
 	}
-	// Ensure both libs can parse the second curve point
-	yc := new(cloudflare.G1)
-	_, errc = yc.Unmarshal(data[64:])
+	return xc, xg
+}
 
-	yg := new(google.G1)
-	_, errg = yg.Unmarshal(data[64:])
+func getG2Points(input io.Reader) (*cloudflare.G2, *google.G2) {
+	_, xc, err := cloudflare.RandomG2(input)
+	if err != nil {
+		// insufficient input
+		return nil, nil
+	}
+	xg := new(google.G2)
+	if _, err := xg.Unmarshal(xc.Marshal()); err != nil {
+		panic(fmt.Sprintf("Could not marshal cloudflare -> google:", err))
+	}
+	return xc, xg
+}
 
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
+// FuzzAdd fuzzez bn256 addition between the Google and Cloudflare libraries.
+func FuzzAdd(data []byte) int {
+	input := bytes.NewReader(data)
+	xc, xg := getG1Points(input)
+	if xc == nil {
 		return 0
 	}
+	yc, yg := getG1Points(input)
+	if yc == nil {
+		return 0
+	}
+	// Ensure both libs can parse the second curve point
 	// Add the two points and ensure they result in the same output
 	rc := new(cloudflare.G1)
 	rc.Add(xc, yc)
@@ -54,73 +64,50 @@ func FuzzAdd(data []byte) int {
 	if !bytes.Equal(rc.Marshal(), rg.Marshal()) {
 		panic("add mismatch")
 	}
-	return 0
+	return 1
 }
 
 // FuzzMul fuzzez bn256 scalar multiplication between the Google and Cloudflare
 // libraries.
 func FuzzMul(data []byte) int {
-	// Ensure we have enough data in the first place
-	if len(data) != 96 {
+	input := bytes.NewReader(data)
+	pc, pg := getG1Points(input)
+	if pc == nil {
 		return 0
 	}
-	// Ensure both libs can parse the curve point
-	pc := new(cloudflare.G1)
-	_, errc := pc.Unmarshal(data[:64])
-
-	pg := new(google.G1)
-	_, errg := pg.Unmarshal(data[:64])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
+	// Add the two points and ensure they result in the same output
+	remaining := input.Len()
+	if remaining == 0 {
 		return 0
 	}
-	// Add the two points and ensure they result in the same output
+	buf := make([]byte, remaining)
+	input.Read(buf)
+
 	rc := new(cloudflare.G1)
-	rc.ScalarMult(pc, new(big.Int).SetBytes(data[64:]))
+	rc.ScalarMult(pc, new(big.Int).SetBytes(buf))
 
 	rg := new(google.G1)
-	rg.ScalarMult(pg, new(big.Int).SetBytes(data[64:]))
+	rg.ScalarMult(pg, new(big.Int).SetBytes(buf))
 
 	if !bytes.Equal(rc.Marshal(), rg.Marshal()) {
 		panic("scalar mul mismatch")
 	}
-	return 0
+	return 1
 }
 
 func FuzzPair(data []byte) int {
-	// Ensure we have enough data in the first place
-	if len(data) != 192 {
+	input := bytes.NewReader(data)
+	pc, pg := getG1Points(input)
+	if pc == nil {
 		return 0
 	}
-	// Ensure both libs can parse the curve point
-	pc := new(cloudflare.G1)
-	_, errc := pc.Unmarshal(data[:64])
-
-	pg := new(google.G1)
-	_, errg := pg.Unmarshal(data[:64])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
-		return 0
-	}
-	// Ensure both libs can parse the twist point
-	tc := new(cloudflare.G2)
-	_, errc = tc.Unmarshal(data[64:])
-
-	tg := new(google.G2)
-	_, errg = tg.Unmarshal(data[64:])
-
-	if (errc == nil) != (errg == nil) {
-		panic("parse mismatch")
-	} else if errc != nil {
+	tc, tg := getG2Points(input)
+	if tc == nil {
 		return 0
 	}
 	// Pair the two points and ensure thet result in the same output
 	if cloudflare.PairingCheck([]*cloudflare.G1{pc}, []*cloudflare.G2{tc}) != google.PairingCheck([]*google.G1{pg}, []*google.G2{tg}) {
 		panic("pair mismatch")
 	}
-	return 0
+	return 1
 }
diff --git a/crypto/bn256/cloudflare/bn256.go b/crypto/bn256/cloudflare/bn256.go
index 38822a76bf..a6dd972ba8 100644
--- a/crypto/bn256/cloudflare/bn256.go
+++ b/crypto/bn256/cloudflare/bn256.go
@@ -23,7 +23,7 @@ import (
 func randomK(r io.Reader) (k *big.Int, err error) {
 	for {
 		k, err = rand.Int(r, Order)
-		if k.Sign() > 0 || err != nil {
+		if err != nil || k.Sign() > 0 {
 			return
 		}
 	}
-- 
GitLab