From 4871e25f5fe8d58344f5267ef197662dde018d21 Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Thu, 8 Mar 2018 13:48:19 +0100
Subject: [PATCH] core/vm: optimize eq, slt, sgt and iszero + tests (#16047)

* vm: optimize eq, slt, sgt and iszero + tests

* core/vm: fix error in slt/sgt, found by vmtests. Added testcase

* core/vm: make slt/sgt cleaner
---
 core/vm/instructions.go      | 71 ++++++++++++++++++++++++------------
 core/vm/instructions_test.go | 15 +++++++-
 2 files changed, 61 insertions(+), 25 deletions(-)

diff --git a/core/vm/instructions.go b/core/vm/instructions.go
index 6daf4e10d..66e804fb7 100644
--- a/core/vm/instructions.go
+++ b/core/vm/instructions.go
@@ -30,6 +30,8 @@ import (
 
 var (
 	bigZero                  = new(big.Int)
+	tt255                    = math.BigPow(2, 255)
+	tt256                    = math.BigPow(2, 256)
 	errWriteProtection       = errors.New("evm: write protection")
 	errReturnDataOutOfBounds = errors.New("evm: return data out of bounds")
 	errExecutionReverted     = errors.New("evm: execution reverted")
@@ -191,50 +193,71 @@ func opGt(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack
 }
 
 func opSlt(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) {
-	x, y := math.S256(stack.pop()), math.S256(stack.pop())
-	if x.Cmp(math.S256(y)) < 0 {
-		stack.push(evm.interpreter.intPool.get().SetUint64(1))
-	} else {
-		stack.push(new(big.Int))
-	}
+	x, y := stack.pop(), stack.peek()
 
-	evm.interpreter.intPool.put(x, y)
+	xSign := x.Cmp(tt255)
+	ySign := y.Cmp(tt255)
+
+	switch {
+	case xSign >= 0 && ySign < 0:
+		y.SetUint64(1)
+
+	case xSign < 0 && ySign >= 0:
+		y.SetUint64(0)
+
+	default:
+		if x.Cmp(y) < 0 {
+			y.SetUint64(1)
+		} else {
+			y.SetUint64(0)
+		}
+	}
+	evm.interpreter.intPool.put(x)
 	return nil, nil
 }
 
 func opSgt(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) {
-	x, y := math.S256(stack.pop()), math.S256(stack.pop())
-	if x.Cmp(y) > 0 {
-		stack.push(evm.interpreter.intPool.get().SetUint64(1))
-	} else {
-		stack.push(new(big.Int))
-	}
+	x, y := stack.pop(), stack.peek()
 
-	evm.interpreter.intPool.put(x, y)
+	xSign := x.Cmp(tt255)
+	ySign := y.Cmp(tt255)
+
+	switch {
+	case xSign >= 0 && ySign < 0:
+		y.SetUint64(0)
+
+	case xSign < 0 && ySign >= 0:
+		y.SetUint64(1)
+
+	default:
+		if x.Cmp(y) > 0 {
+			y.SetUint64(1)
+		} else {
+			y.SetUint64(0)
+		}
+	}
+	evm.interpreter.intPool.put(x)
 	return nil, nil
 }
 
 func opEq(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) {
-	x, y := stack.pop(), stack.pop()
+	x, y := stack.pop(), stack.peek()
 	if x.Cmp(y) == 0 {
-		stack.push(evm.interpreter.intPool.get().SetUint64(1))
+		y.SetUint64(1)
 	} else {
-		stack.push(new(big.Int))
+		y.SetUint64(0)
 	}
-
-	evm.interpreter.intPool.put(x, y)
+	evm.interpreter.intPool.put(x)
 	return nil, nil
 }
 
 func opIszero(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *Stack) ([]byte, error) {
-	x := stack.pop()
+	x := stack.peek()
 	if x.Sign() > 0 {
-		stack.push(new(big.Int))
+		x.SetUint64(0)
 	} else {
-		stack.push(evm.interpreter.intPool.get().SetUint64(1))
+		x.SetUint64(1)
 	}
-
-	evm.interpreter.intPool.put(x)
 	return nil, nil
 }
 
diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go
index eef4328bd..134363bb7 100644
--- a/core/vm/instructions_test.go
+++ b/core/vm/instructions_test.go
@@ -161,6 +161,7 @@ func TestSAR(t *testing.T) {
 
 func TestSGT(t *testing.T) {
 	tests := []twoOperandTest{
+
 		{"0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"},
 		{"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"},
 		{"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"},
@@ -171,6 +172,8 @@ func TestSGT(t *testing.T) {
 		{"8000000000000000000000000000000000000000000000000000000000000001", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"},
 		{"8000000000000000000000000000000000000000000000000000000000000001", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000001"},
 		{"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"},
+		{"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "0000000000000000000000000000000000000000000000000000000000000001"},
+		{"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb", "0000000000000000000000000000000000000000000000000000000000000000"},
 	}
 	testTwoOperandOp(t, tests, opSgt)
 }
@@ -187,6 +190,8 @@ func TestSLT(t *testing.T) {
 		{"8000000000000000000000000000000000000000000000000000000000000001", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000000"},
 		{"8000000000000000000000000000000000000000000000000000000000000001", "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "0000000000000000000000000000000000000000000000000000000000000000"},
 		{"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", "8000000000000000000000000000000000000000000000000000000000000001", "0000000000000000000000000000000000000000000000000000000000000001"},
+		{"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "0000000000000000000000000000000000000000000000000000000000000000"},
+		{"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffd", "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffb", "0000000000000000000000000000000000000000000000000000000000000001"},
 	}
 	testTwoOperandOp(t, tests, opSlt)
 }
@@ -349,7 +354,11 @@ func BenchmarkOpEq(b *testing.B) {
 
 	opBenchmark(b, opEq, x, y)
 }
-
+func BenchmarkOpEq2(b *testing.B) {
+	x := "FBCDEF090807060504030201ffffffffFBCDEF090807060504030201ffffffff"
+	y := "FBCDEF090807060504030201ffffffffFBCDEF090807060504030201fffffffe"
+	opBenchmark(b, opEq, x, y)
+}
 func BenchmarkOpAnd(b *testing.B) {
 	x := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
 	y := "ABCDEF090807060504030201ffffffffffffffffffffffffffffffffffffffff"
@@ -412,3 +421,7 @@ func BenchmarkOpSAR(b *testing.B) {
 
 	opBenchmark(b, opSAR, x, y)
 }
+func BenchmarkOpIsZero(b *testing.B) {
+	x := "FBCDEF090807060504030201ffffffffFBCDEF090807060504030201ffffffff"
+	opBenchmark(b, opIszero, x)
+}
-- 
GitLab