From 12224c7f5924720767d73f06ed4571dc3ce2f092 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Tue, 27 Jan 2015 14:33:26 +0100
Subject: [PATCH] p2p/discover: new package implementing the Node Discovery
 Protocol

---
 p2p/discover/table.go      | 447 +++++++++++++++++++++++++++++++++++++
 p2p/discover/table_test.go | 403 +++++++++++++++++++++++++++++++++
 p2p/discover/udp.go        | 422 ++++++++++++++++++++++++++++++++++
 p2p/discover/udp_test.go   | 156 +++++++++++++
 4 files changed, 1428 insertions(+)
 create mode 100644 p2p/discover/table.go
 create mode 100644 p2p/discover/table_test.go
 create mode 100644 p2p/discover/udp.go
 create mode 100644 p2p/discover/udp_test.go

diff --git a/p2p/discover/table.go b/p2p/discover/table.go
new file mode 100644
index 000000000..26526330b
--- /dev/null
+++ b/p2p/discover/table.go
@@ -0,0 +1,447 @@
+// Package discover implements the Node Discovery Protocol.
+//
+// The Node Discovery protocol provides a way to find RLPx nodes that
+// can be connected to. It uses a Kademlia-like protocol to maintain a
+// distributed database of the IDs and endpoints of all listening
+// nodes.
+package discover
+
+import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"encoding/hex"
+	"fmt"
+	"io"
+	"math/rand"
+	"net"
+	"sort"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/crypto/secp256k1"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+const (
+	alpha      = 3                   // Kademlia concurrency factor
+	bucketSize = 16                  // Kademlia bucket size
+	nBuckets   = len(NodeID{})*8 + 1 // Number of buckets
+)
+
+type Table struct {
+	mutex   sync.Mutex        // protects buckets, their content, and nursery
+	buckets [nBuckets]*bucket // index of known nodes by distance
+	nursery []*Node           // bootstrap nodes
+
+	net  transport
+	self *Node // metadata of the local node
+}
+
+// transport is implemented by the UDP transport.
+// it is an interface so we can test without opening lots of UDP
+// sockets and without generating a private key.
+type transport interface {
+	ping(*Node) error
+	findnode(e *Node, target NodeID) ([]*Node, error)
+	close()
+}
+
+// bucket contains nodes, ordered by their last activity.
+type bucket struct {
+	lastLookup time.Time
+	entries    []*Node
+}
+
+// Node represents node metadata that is stored in the table.
+type Node struct {
+	Addr *net.UDPAddr
+	ID   NodeID
+
+	active time.Time
+}
+
+type rpcNode struct {
+	IP   string
+	Port uint16
+	ID   NodeID
+}
+
+func (n Node) EncodeRLP(w io.Writer) error {
+	return rlp.Encode(w, rpcNode{IP: n.Addr.IP.String(), Port: uint16(n.Addr.Port), ID: n.ID})
+}
+func (n *Node) DecodeRLP(s *rlp.Stream) (err error) {
+	var ext rpcNode
+	if err = s.Decode(&ext); err == nil {
+		n.Addr = &net.UDPAddr{IP: net.ParseIP(ext.IP), Port: int(ext.Port)}
+		n.ID = ext.ID
+	}
+	return err
+}
+
+func newTable(t transport, ourID NodeID, ourAddr *net.UDPAddr) *Table {
+	tab := &Table{net: t, self: &Node{ID: ourID, Addr: ourAddr}}
+	for i := range tab.buckets {
+		tab.buckets[i] = &bucket{}
+	}
+	return tab
+}
+
+// Bootstrap sets the bootstrap nodes. These nodes are used to connect
+// to the network if the table is empty. Bootstrap will also attempt to
+// fill the table by performing random lookup operations on the
+// network.
+func (tab *Table) Bootstrap(nodes []Node) {
+	tab.mutex.Lock()
+	// TODO: maybe filter nodes with bad fields (nil, etc.) to avoid strange crashes
+	tab.nursery = make([]*Node, 0, len(nodes))
+	for _, n := range nodes {
+		cpy := n
+		tab.nursery = append(tab.nursery, &cpy)
+	}
+	tab.mutex.Unlock()
+	tab.refresh()
+}
+
+// Lookup performs a network search for nodes close
+// to the given target. It approaches the target by querying
+// nodes that are closer to it on each iteration.
+func (tab *Table) Lookup(target NodeID) []*Node {
+	var (
+		asked          = make(map[NodeID]bool)
+		seen           = make(map[NodeID]bool)
+		reply          = make(chan []*Node, alpha)
+		pendingQueries = 0
+	)
+	// don't query further if we hit the target.
+	// unlikely to happen often in practice.
+	asked[target] = true
+
+	tab.mutex.Lock()
+	// update last lookup stamp (for refresh logic)
+	tab.buckets[logdist(tab.self.ID, target)].lastLookup = time.Now()
+	// generate initial result set
+	result := tab.closest(target, bucketSize)
+	tab.mutex.Unlock()
+
+	for {
+		// ask the closest nodes that we haven't asked yet
+		for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
+			n := result.entries[i]
+			if !asked[n.ID] {
+				asked[n.ID] = true
+				pendingQueries++
+				go func() {
+					result, _ := tab.net.findnode(n, target)
+					reply <- result
+				}()
+			}
+		}
+		if pendingQueries == 0 {
+			// we have asked all closest nodes, stop the search
+			break
+		}
+
+		// wait for the next reply
+		for _, n := range <-reply {
+			cn := n
+			if !seen[n.ID] {
+				seen[n.ID] = true
+				result.push(cn, bucketSize)
+			}
+		}
+		pendingQueries--
+	}
+	return result.entries
+}
+
+// refresh performs a lookup for a random target to keep buckets full.
+func (tab *Table) refresh() {
+	ld := -1 // logdist of chosen bucket
+	tab.mutex.Lock()
+	for i, b := range tab.buckets {
+		if i > 0 && b.lastLookup.Before(time.Now().Add(-1*time.Hour)) {
+			ld = i
+			break
+		}
+	}
+	tab.mutex.Unlock()
+
+	result := tab.Lookup(randomID(tab.self.ID, ld))
+	if len(result) == 0 {
+		// bootstrap the table with a self lookup
+		tab.mutex.Lock()
+		tab.add(tab.nursery)
+		tab.mutex.Unlock()
+		tab.Lookup(tab.self.ID)
+		// TODO: the Kademlia paper says that we're supposed to perform
+		// random lookups in all buckets further away than our closest neighbor.
+	}
+}
+
+// closest returns the n nodes in the table that are closest to the
+// given id. The caller must hold tab.mutex.
+func (tab *Table) closest(target NodeID, nresults int) *nodesByDistance {
+	// This is a very wasteful way to find the closest nodes but
+	// obviously correct. I believe that tree-based buckets would make
+	// this easier to implement efficiently.
+	close := &nodesByDistance{target: target}
+	for _, b := range tab.buckets {
+		for _, n := range b.entries {
+			close.push(n, nresults)
+		}
+	}
+	return close
+}
+
+func (tab *Table) len() (n int) {
+	for _, b := range tab.buckets {
+		n += len(b.entries)
+	}
+	return n
+}
+
+// bumpOrAdd updates the activity timestamp for the given node and
+// attempts to insert the node into a bucket. The returned Node might
+// not be part of the table. The caller must hold tab.mutex.
+func (tab *Table) bumpOrAdd(node NodeID, from *net.UDPAddr) (n *Node) {
+	b := tab.buckets[logdist(tab.self.ID, node)]
+	if n = b.bump(node); n == nil {
+		n = &Node{ID: node, Addr: from, active: time.Now()}
+		if len(b.entries) == bucketSize {
+			tab.pingReplace(n, b)
+		} else {
+			b.entries = append(b.entries, n)
+		}
+	}
+	return n
+}
+
+func (tab *Table) pingReplace(n *Node, b *bucket) {
+	old := b.entries[bucketSize-1]
+	go func() {
+		if err := tab.net.ping(old); err == nil {
+			// it responded, we don't need to replace it.
+			return
+		}
+		// it didn't respond, replace the node if it is still the oldest node.
+		tab.mutex.Lock()
+		if len(b.entries) > 0 && b.entries[len(b.entries)-1] == old {
+			// slide down other entries and put the new one in front.
+			copy(b.entries[1:], b.entries)
+			b.entries[0] = n
+		}
+		tab.mutex.Unlock()
+	}()
+}
+
+// bump updates the activity timestamp for the given node.
+// The caller must hold tab.mutex.
+func (tab *Table) bump(node NodeID) {
+	tab.buckets[logdist(tab.self.ID, node)].bump(node)
+}
+
+// add puts the entries into the table if their corresponding
+// bucket is not full. The caller must hold tab.mutex.
+func (tab *Table) add(entries []*Node) {
+outer:
+	for _, n := range entries {
+		if n == nil || n.ID == tab.self.ID {
+			// skip bad entries. The RLP decoder returns nil for empty
+			// input lists.
+			continue
+		}
+		bucket := tab.buckets[logdist(tab.self.ID, n.ID)]
+		for i := range bucket.entries {
+			if bucket.entries[i].ID == n.ID {
+				// already in bucket
+				continue outer
+			}
+		}
+		if len(bucket.entries) < bucketSize {
+			bucket.entries = append(bucket.entries, n)
+		}
+	}
+}
+
+func (b *bucket) bump(id NodeID) *Node {
+	for i, n := range b.entries {
+		if n.ID == id {
+			n.active = time.Now()
+			// move it to the front
+			copy(b.entries[1:], b.entries[:i+1])
+			b.entries[0] = n
+			return n
+		}
+	}
+	return nil
+}
+
+// nodesByDistance is a list of nodes, ordered by
+// distance to target.
+type nodesByDistance struct {
+	entries []*Node
+	target  NodeID
+}
+
+// push adds the given node to the list, keeping the total size below maxElems.
+func (h *nodesByDistance) push(n *Node, maxElems int) {
+	ix := sort.Search(len(h.entries), func(i int) bool {
+		return distcmp(h.target, h.entries[i].ID, n.ID) > 0
+	})
+	if len(h.entries) < maxElems {
+		h.entries = append(h.entries, n)
+	}
+	if ix == len(h.entries) {
+		// farther away than all nodes we already have.
+		// if there was room for it, the node is now the last element.
+	} else {
+		// slide existing entries down to make room
+		// this will overwrite the entry we just appended.
+		copy(h.entries[ix+1:], h.entries[ix:])
+		h.entries[ix] = n
+	}
+}
+
+// NodeID is a unique identifier for each node.
+// The node identifier is a marshaled elliptic curve public key.
+type NodeID [512 / 8]byte
+
+// NodeID prints as a long hexadecimal number.
+func (n NodeID) String() string {
+	return fmt.Sprintf("%#x", n[:])
+}
+
+// The Go syntax representation of a NodeID is a call to HexID.
+func (n NodeID) GoString() string {
+	return fmt.Sprintf("HexID(\"%#x\")", n[:])
+}
+
+// HexID converts a hex string to a NodeID.
+// The string may be prefixed with 0x.
+func HexID(in string) NodeID {
+	if strings.HasPrefix(in, "0x") {
+		in = in[2:]
+	}
+	var id NodeID
+	b, err := hex.DecodeString(in)
+	if err != nil {
+		panic(err)
+	} else if len(b) != len(id) {
+		panic("wrong length")
+	}
+	copy(id[:], b)
+	return id
+}
+
+func newNodeID(priv *ecdsa.PrivateKey) (id NodeID) {
+	pubkey := elliptic.Marshal(priv.Curve, priv.X, priv.Y)
+	if len(pubkey)-1 != len(id) {
+		panic(fmt.Errorf("invalid key: need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pubkey)))
+	}
+	copy(id[:], pubkey[1:])
+	return id
+}
+
+// recoverNodeID computes the public key used to sign the
+// given hash from the signature.
+func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
+	pubkey, err := secp256k1.RecoverPubkey(hash, sig)
+	if err != nil {
+		return id, err
+	}
+	if len(pubkey)-1 != len(id) {
+		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
+	}
+	for i := range id {
+		id[i] = pubkey[i+1]
+	}
+	return id, nil
+}
+
+// distcmp compares the distances a->target and b->target.
+// Returns -1 if a is closer to target, 1 if b is closer to target
+// and 0 if they are equal.
+func distcmp(target, a, b NodeID) int {
+	for i := range target {
+		da := a[i] ^ target[i]
+		db := b[i] ^ target[i]
+		if da > db {
+			return 1
+		} else if da < db {
+			return -1
+		}
+	}
+	return 0
+}
+
+// table of leading zero counts for bytes [0..255]
+var lzcount = [256]int{
+	8, 7, 6, 6, 5, 5, 5, 5,
+	4, 4, 4, 4, 4, 4, 4, 4,
+	3, 3, 3, 3, 3, 3, 3, 3,
+	3, 3, 3, 3, 3, 3, 3, 3,
+	2, 2, 2, 2, 2, 2, 2, 2,
+	2, 2, 2, 2, 2, 2, 2, 2,
+	2, 2, 2, 2, 2, 2, 2, 2,
+	2, 2, 2, 2, 2, 2, 2, 2,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	1, 1, 1, 1, 1, 1, 1, 1,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+	0, 0, 0, 0, 0, 0, 0, 0,
+}
+
+// logdist returns the logarithmic distance between a and b, log2(a ^ b).
+func logdist(a, b NodeID) int {
+	lz := 0
+	for i := range a {
+		x := a[i] ^ b[i]
+		if x == 0 {
+			lz += 8
+		} else {
+			lz += lzcount[x]
+			break
+		}
+	}
+	return len(a)*8 - lz
+}
+
+// randomID returns a random NodeID such that logdist(a, b) == n
+func randomID(a NodeID, n int) (b NodeID) {
+	if n == 0 {
+		return a
+	}
+	// flip bit at position n, fill the rest with random bits
+	b = a
+	pos := len(a) - n/8 - 1
+	bit := byte(0x01) << (byte(n%8) - 1)
+	if bit == 0 {
+		pos++
+		bit = 0x80
+	}
+	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
+	for i := pos + 1; i < len(a); i++ {
+		b[i] = byte(rand.Intn(255))
+	}
+	return b
+}
diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go
new file mode 100644
index 000000000..88563fe65
--- /dev/null
+++ b/p2p/discover/table_test.go
@@ -0,0 +1,403 @@
+package discover
+
+import (
+	"crypto/ecdsa"
+	"errors"
+	"fmt"
+	"math/big"
+	"math/rand"
+	"net"
+	"reflect"
+	"testing"
+	"testing/quick"
+	"time"
+
+	"github.com/ethereum/go-ethereum/crypto"
+)
+
+var (
+	quickrand = rand.New(rand.NewSource(time.Now().Unix()))
+	quickcfg  = &quick.Config{MaxCount: 5000, Rand: quickrand}
+)
+
+func TestHexID(t *testing.T) {
+	ref := NodeID{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188}
+	id1 := HexID("0x000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+	id2 := HexID("000000000000000000000000000000000000000000000000000000000000000000000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc")
+
+	if id1 != ref {
+		t.Errorf("wrong id1\ngot  %v\nwant %v", id1[:], ref[:])
+	}
+	if id2 != ref {
+		t.Errorf("wrong id2\ngot  %v\nwant %v", id2[:], ref[:])
+	}
+}
+
+func TestNodeID_recover(t *testing.T) {
+	prv := newkey()
+	hash := make([]byte, 32)
+	sig, err := crypto.Sign(hash, prv)
+	if err != nil {
+		t.Fatalf("signing error: %v", err)
+	}
+
+	pub := newNodeID(prv)
+	recpub, err := recoverNodeID(hash, sig)
+	if err != nil {
+		t.Fatalf("recovery error: %v", err)
+	}
+	if pub != recpub {
+		t.Errorf("recovered wrong pubkey:\ngot:  %v\nwant: %v", recpub, pub)
+	}
+}
+
+func TestNodeID_distcmp(t *testing.T) {
+	distcmpBig := func(target, a, b NodeID) int {
+		tbig := new(big.Int).SetBytes(target[:])
+		abig := new(big.Int).SetBytes(a[:])
+		bbig := new(big.Int).SetBytes(b[:])
+		return new(big.Int).Xor(tbig, abig).Cmp(new(big.Int).Xor(tbig, bbig))
+	}
+	if err := quick.CheckEqual(distcmp, distcmpBig, quickcfg); err != nil {
+		t.Error(err)
+	}
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_distcmpEqual(t *testing.T) {
+	base := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+	x := NodeID{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
+	if distcmp(base, x, x) != 0 {
+		t.Errorf("distcmp(base, x, x) != 0")
+	}
+}
+
+func TestNodeID_logdist(t *testing.T) {
+	logdistBig := func(a, b NodeID) int {
+		abig, bbig := new(big.Int).SetBytes(a[:]), new(big.Int).SetBytes(b[:])
+		return new(big.Int).Xor(abig, bbig).BitLen()
+	}
+	if err := quick.CheckEqual(logdist, logdistBig, quickcfg); err != nil {
+		t.Error(err)
+	}
+}
+
+// the random tests is likely to miss the case where they're equal.
+func TestNodeID_logdistEqual(t *testing.T) {
+	x := NodeID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
+	if logdist(x, x) != 0 {
+		t.Errorf("logdist(x, x) != 0")
+	}
+}
+
+func TestNodeID_randomID(t *testing.T) {
+	// we don't use quick.Check here because its output isn't
+	// very helpful when the test fails.
+	for i := 0; i < quickcfg.MaxCount; i++ {
+		a := gen(NodeID{}, quickrand).(NodeID)
+		dist := quickrand.Intn(len(NodeID{}) * 8)
+		result := randomID(a, dist)
+		actualdist := logdist(result, a)
+
+		if dist != actualdist {
+			t.Log("a:     ", a)
+			t.Log("result:", result)
+			t.Fatalf("#%d: distance of result is %d, want %d", i, actualdist, dist)
+		}
+	}
+}
+
+func (NodeID) Generate(rand *rand.Rand, size int) reflect.Value {
+	var id NodeID
+	m := rand.Intn(len(id))
+	for i := len(id) - 1; i > m; i-- {
+		id[i] = byte(rand.Uint32())
+	}
+	return reflect.ValueOf(id)
+}
+
+func TestTable_bumpOrAddPingReplace(t *testing.T) {
+	pingC := make(pingC)
+	tab := newTable(pingC, NodeID{}, &net.UDPAddr{})
+	last := fillBucket(tab, 200)
+
+	// this bumpOrAdd should not replace the last node
+	// because the node replies to ping.
+	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), nil)
+
+	pinged := <-pingC
+	if pinged != last.ID {
+		t.Fatalf("pinged wrong node: %v\nwant %v", pinged, last.ID)
+	}
+
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+	if l := len(tab.buckets[200].entries); l != bucketSize {
+		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
+	}
+	if !contains(tab.buckets[200].entries, last.ID) {
+		t.Error("last entry was removed")
+	}
+	if contains(tab.buckets[200].entries, new.ID) {
+		t.Error("new entry was added")
+	}
+}
+
+func TestTable_bumpOrAddPingTimeout(t *testing.T) {
+	tab := newTable(pingC(nil), NodeID{}, &net.UDPAddr{})
+	last := fillBucket(tab, 200)
+
+	// this bumpOrAdd should replace the last node
+	// because the node does not reply to ping.
+	new := tab.bumpOrAdd(randomID(tab.self.ID, 200), nil)
+
+	// wait for async bucket update. damn. this needs to go away.
+	time.Sleep(2 * time.Millisecond)
+
+	tab.mutex.Lock()
+	defer tab.mutex.Unlock()
+	if l := len(tab.buckets[200].entries); l != bucketSize {
+		t.Errorf("wrong bucket size after bumpOrAdd: got %d, want %d", bucketSize, l)
+	}
+	if contains(tab.buckets[200].entries, last.ID) {
+		t.Error("last entry was not removed")
+	}
+	if !contains(tab.buckets[200].entries, new.ID) {
+		t.Error("new entry was not added")
+	}
+}
+
+func fillBucket(tab *Table, ld int) (last *Node) {
+	b := tab.buckets[ld]
+	for len(b.entries) < bucketSize {
+		b.entries = append(b.entries, &Node{ID: randomID(tab.self.ID, ld)})
+	}
+	return b.entries[bucketSize-1]
+}
+
+type pingC chan NodeID
+
+func (t pingC) findnode(n *Node, target NodeID) ([]*Node, error) {
+	panic("findnode called on pingRecorder")
+}
+func (t pingC) close() {
+	panic("close called on pingRecorder")
+}
+func (t pingC) ping(n *Node) error {
+	if t == nil {
+		return errTimeout
+	}
+	t <- n.ID
+	return nil
+}
+
+func TestTable_bump(t *testing.T) {
+	tab := newTable(nil, NodeID{}, &net.UDPAddr{})
+
+	// add an old entry and two recent ones
+	oldactive := time.Now().Add(-2 * time.Minute)
+	old := &Node{ID: randomID(tab.self.ID, 200), active: oldactive}
+	others := []*Node{
+		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+		&Node{ID: randomID(tab.self.ID, 200), active: time.Now()},
+	}
+	tab.add(append(others, old))
+	if tab.buckets[200].entries[0] == old {
+		t.Fatal("old entry is at front of bucket")
+	}
+
+	// bumping the old entry should move it to the front
+	tab.bump(old.ID)
+	if old.active == oldactive {
+		t.Error("activity timestamp not updated")
+	}
+	if tab.buckets[200].entries[0] != old {
+		t.Errorf("bumped entry did not move to the front of bucket")
+	}
+}
+
+func TestTable_closest(t *testing.T) {
+	t.Parallel()
+
+	test := func(test *closeTest) bool {
+		// for any node table, Target and N
+		tab := newTable(nil, test.Self, &net.UDPAddr{})
+		tab.add(test.All)
+
+		// check that doClosest(Target, N) returns nodes
+		result := tab.closest(test.Target, test.N).entries
+		if hasDuplicates(result) {
+			t.Errorf("result contains duplicates")
+			return false
+		}
+		if !sortedByDistanceTo(test.Target, result) {
+			t.Errorf("result is not sorted by distance to target")
+			return false
+		}
+
+		// check that the number of results is min(N, tablen)
+		wantN := test.N
+		if tlen := tab.len(); tlen < test.N {
+			wantN = tlen
+		}
+		if len(result) != wantN {
+			t.Errorf("wrong number of nodes: got %d, want %d", len(result), wantN)
+			return false
+		} else if len(result) == 0 {
+			return true // no need to check distance
+		}
+
+		// check that the result nodes have minimum distance to target.
+		for _, b := range tab.buckets {
+			for _, n := range b.entries {
+				if contains(result, n.ID) {
+					continue // don't run the check below for nodes in result
+				}
+				farthestResult := result[len(result)-1].ID
+				if distcmp(test.Target, n.ID, farthestResult) < 0 {
+					t.Errorf("table contains node that is closer to target but it's not in result")
+					t.Logf("  Target:          %v", test.Target)
+					t.Logf("  Farthest Result: %v", farthestResult)
+					t.Logf("  ID:              %v", n.ID)
+					return false
+				}
+			}
+		}
+		return true
+	}
+	if err := quick.Check(test, quickcfg); err != nil {
+		t.Error(err)
+	}
+}
+
+type closeTest struct {
+	Self   NodeID
+	Target NodeID
+	All    []*Node
+	N      int
+}
+
+func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value {
+	t := &closeTest{
+		Self:   gen(NodeID{}, rand).(NodeID),
+		Target: gen(NodeID{}, rand).(NodeID),
+		N:      rand.Intn(bucketSize),
+	}
+	for _, id := range gen([]NodeID{}, rand).([]NodeID) {
+		t.All = append(t.All, &Node{ID: id})
+	}
+	return reflect.ValueOf(t)
+}
+
+func TestTable_Lookup(t *testing.T) {
+	self := gen(NodeID{}, quickrand).(NodeID)
+	target := randomID(self, 200)
+	transport := findnodeOracle{t, target}
+	tab := newTable(transport, self, &net.UDPAddr{})
+
+	// lookup on empty table returns no nodes
+	if results := tab.Lookup(target); len(results) > 0 {
+		t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results)
+	}
+	// seed table with initial node (otherwise lookup will terminate immediately)
+	tab.bumpOrAdd(randomID(target, 200), &net.UDPAddr{Port: 200})
+
+	results := tab.Lookup(target)
+	t.Logf("results:")
+	for _, e := range results {
+		t.Logf("  ld=%d, %v", logdist(target, e.ID), e.ID)
+	}
+	if len(results) != bucketSize {
+		t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize)
+	}
+	if hasDuplicates(results) {
+		t.Errorf("result set contains duplicate entries")
+	}
+	if !sortedByDistanceTo(target, results) {
+		t.Errorf("result set not sorted by distance to target")
+	}
+	if !contains(results, target) {
+		t.Errorf("result set does not contain target")
+	}
+}
+
+// findnode on this transport always returns at least one node
+// that is one bucket closer to the target.
+type findnodeOracle struct {
+	t      *testing.T
+	target NodeID
+}
+
+func (t findnodeOracle) findnode(n *Node, target NodeID) ([]*Node, error) {
+	t.t.Logf("findnode query at dist %d", n.Addr.Port)
+	// current log distance is encoded in port number
+	var result []*Node
+	switch port := n.Addr.Port; port {
+	case 0:
+		panic("query to node at distance 0")
+	case 1:
+		result = append(result, &Node{ID: t.target, Addr: &net.UDPAddr{Port: 0}})
+	default:
+		// TODO: add more randomness to distances
+		port--
+		for i := 0; i < bucketSize; i++ {
+			result = append(result, &Node{ID: randomID(t.target, port), Addr: &net.UDPAddr{Port: port}})
+		}
+	}
+	return result, nil
+}
+
+func (t findnodeOracle) close() {}
+
+func (t findnodeOracle) ping(n *Node) error {
+	return errors.New("ping is not supported by this transport")
+}
+
+func hasDuplicates(slice []*Node) bool {
+	seen := make(map[NodeID]bool)
+	for _, e := range slice {
+		if seen[e.ID] {
+			return true
+		}
+		seen[e.ID] = true
+	}
+	return false
+}
+
+func sortedByDistanceTo(distbase NodeID, slice []*Node) bool {
+	var last NodeID
+	for i, e := range slice {
+		if i > 0 && distcmp(distbase, e.ID, last) < 0 {
+			return false
+		}
+		last = e.ID
+	}
+	return true
+}
+
+func contains(ns []*Node, id NodeID) bool {
+	for _, n := range ns {
+		if n.ID == id {
+			return true
+		}
+	}
+	return false
+}
+
+// gen wraps quick.Value so it's easier to use.
+// it generates a random value of the given value's type.
+func gen(typ interface{}, rand *rand.Rand) interface{} {
+	v, ok := quick.Value(reflect.TypeOf(typ), rand)
+	if !ok {
+		panic(fmt.Sprintf("couldn't generate random value of type %T", typ))
+	}
+	return v.Interface()
+}
+
+func newkey() *ecdsa.PrivateKey {
+	key, err := crypto.GenerateKey()
+	if err != nil {
+		panic("couldn't generate key: " + err.Error())
+	}
+	return key
+}
diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go
new file mode 100644
index 000000000..ec1f62dac
--- /dev/null
+++ b/p2p/discover/udp.go
@@ -0,0 +1,422 @@
+package discover
+
+import (
+	"bytes"
+	"crypto/ecdsa"
+	"errors"
+	"fmt"
+	"net"
+	"time"
+
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/logger"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+var log = logger.NewLogger("P2P Discovery")
+
+// Errors
+var (
+	errPacketTooSmall = errors.New("too small")
+	errBadHash        = errors.New("bad hash")
+	errExpired        = errors.New("expired")
+	errTimeout        = errors.New("RPC timeout")
+	errClosed         = errors.New("socket closed")
+)
+
+// Timeouts
+const (
+	respTimeout = 300 * time.Millisecond
+	sendTimeout = 300 * time.Millisecond
+	expiration  = 3 * time.Second
+
+	refreshInterval = 1 * time.Hour
+)
+
+// RPC packet types
+const (
+	pingPacket = iota + 1 // zero is 'reserved'
+	pongPacket
+	findnodePacket
+	neighborsPacket
+)
+
+// RPC request structures
+type (
+	ping struct {
+		IP         string // our IP
+		Port       uint16 // our port
+		Expiration uint64
+	}
+
+	// reply to Ping
+	pong struct {
+		ReplyTok   []byte
+		Expiration uint64
+	}
+
+	findnode struct {
+		// Id to look up. The responding node will send back nodes
+		// closest to the target.
+		Target     NodeID
+		Expiration uint64
+	}
+
+	// reply to findnode
+	neighbors struct {
+		Nodes      []*Node
+		Expiration uint64
+	}
+)
+
+// udp implements the RPC protocol.
+type udp struct {
+	conn       *net.UDPConn
+	priv       *ecdsa.PrivateKey
+	addpending chan *pending
+	replies    chan reply
+	closing    chan struct{}
+
+	*Table
+}
+
+// pending represents a pending reply.
+//
+// some implementations of the protocol wish to send more than one
+// reply packet to findnode. in general, any neighbors packet cannot
+// be matched up with a specific findnode packet.
+//
+// our implementation handles this by storing a callback function for
+// each pending reply. incoming packets from a node are dispatched
+// to all the callback functions for that node.
+type pending struct {
+	// these fields must match in the reply.
+	from  NodeID
+	ptype byte
+
+	// time when the request must complete
+	deadline time.Time
+
+	// callback is called when a matching reply arrives. if it returns
+	// true, the callback is removed from the pending reply queue.
+	// if it returns false, the reply is considered incomplete and
+	// the callback will be invoked again for the next matching reply.
+	callback func(resp interface{}) (done bool)
+
+	// errc receives nil when the callback indicates completion or an
+	// error if no further reply is received within the timeout.
+	errc chan<- error
+}
+
+type reply struct {
+	from  NodeID
+	ptype byte
+	data  interface{}
+}
+
+// ListenUDP returns a new table that listens for UDP packets on laddr.
+func ListenUDP(priv *ecdsa.PrivateKey, laddr string) (*Table, error) {
+	net, realaddr, err := listen(priv, laddr)
+	if err != nil {
+		return nil, err
+	}
+	net.Table = newTable(net, newNodeID(priv), realaddr)
+	log.DebugDetailf("Listening on %v, my ID %x\n", realaddr, net.self.ID[:])
+	return net.Table, nil
+}
+
+func listen(priv *ecdsa.PrivateKey, laddr string) (*udp, *net.UDPAddr, error) {
+	addr, err := net.ResolveUDPAddr("udp", laddr)
+	if err != nil {
+		return nil, nil, err
+	}
+	conn, err := net.ListenUDP("udp", addr)
+	if err != nil {
+		return nil, nil, err
+	}
+	realaddr := conn.LocalAddr().(*net.UDPAddr)
+
+	udp := &udp{
+		conn:       conn,
+		priv:       priv,
+		closing:    make(chan struct{}),
+		addpending: make(chan *pending),
+		replies:    make(chan reply),
+	}
+	go udp.loop()
+	go udp.readLoop()
+	return udp, realaddr, nil
+}
+
+func (t *udp) close() {
+	close(t.closing)
+	t.conn.Close()
+	// TODO: wait for the loops to end.
+}
+
+// ping sends a ping message to the given node and waits for a reply.
+func (t *udp) ping(e *Node) error {
+	// TODO: maybe check for ReplyTo field in callback to measure RTT
+	errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
+	t.send(e, pingPacket, ping{
+		IP:         t.self.Addr.String(),
+		Port:       uint16(t.self.Addr.Port),
+		Expiration: uint64(time.Now().Add(expiration).Unix()),
+	})
+	return <-errc
+}
+
+// findnode sends a findnode request to the given node and waits until
+// the node has sent up to k neighbors.
+func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
+	nodes := make([]*Node, 0, bucketSize)
+	nreceived := 0
+	errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
+		reply := r.(*neighbors)
+		for i := 0; i < len(reply.Nodes); i++ {
+			nreceived++
+			n := reply.Nodes[i]
+			if validAddr(n.Addr) && n.ID != t.self.ID {
+				nodes = append(nodes, n)
+			}
+		}
+		return nreceived == bucketSize
+	})
+
+	t.send(to, findnodePacket, findnode{
+		Target:     target,
+		Expiration: uint64(time.Now().Add(expiration).Unix()),
+	})
+	err := <-errc
+	return nodes, err
+}
+
+func validAddr(a *net.UDPAddr) bool {
+	return !a.IP.IsMulticast() && !a.IP.IsUnspecified() && a.Port != 0
+}
+
+// pending adds a reply callback to the pending reply queue.
+// see the documentation of type pending for a detailed explanation.
+func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
+	ch := make(chan error, 1)
+	p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
+	select {
+	case t.addpending <- p:
+		// loop will handle it
+	case <-t.closing:
+		ch <- errClosed
+	}
+	return ch
+}
+
+// loop runs in its own goroutin. it keeps track of
+// the refresh timer and the pending reply queue.
+func (t *udp) loop() {
+	var (
+		pending      []*pending
+		nextDeadline time.Time
+		timeout      = time.NewTimer(0)
+		refresh      = time.NewTicker(refreshInterval)
+	)
+	<-timeout.C // ignore first timeout
+	defer refresh.Stop()
+	defer timeout.Stop()
+
+	rearmTimeout := func() {
+		if len(pending) == 0 || nextDeadline == pending[0].deadline {
+			return
+		}
+		nextDeadline = pending[0].deadline
+		timeout.Reset(nextDeadline.Sub(time.Now()))
+	}
+
+	for {
+		select {
+		case <-refresh.C:
+			go t.refresh()
+
+		case <-t.closing:
+			for _, p := range pending {
+				p.errc <- errClosed
+			}
+			return
+
+		case p := <-t.addpending:
+			p.deadline = time.Now().Add(respTimeout)
+			pending = append(pending, p)
+			rearmTimeout()
+
+		case reply := <-t.replies:
+			// run matching callbacks, remove if they return false.
+			for i, p := range pending {
+				if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
+					p.errc <- nil
+					copy(pending[i:], pending[i+1:])
+					pending = pending[:len(pending)-1]
+					i--
+				}
+			}
+			rearmTimeout()
+
+		case now := <-timeout.C:
+			// notify and remove callbacks whose deadline is in the past.
+			i := 0
+			for ; i < len(pending) && now.After(pending[i].deadline); i++ {
+				pending[i].errc <- errTimeout
+			}
+			if i > 0 {
+				copy(pending, pending[i:])
+				pending = pending[:len(pending)-i]
+			}
+			rearmTimeout()
+		}
+	}
+}
+
+const (
+	macSize  = 256 / 8
+	sigSize  = 520 / 8
+	headSize = macSize + sigSize // space of packet frame data
+)
+
+var headSpace = make([]byte, headSize)
+
+func (t *udp) send(to *Node, ptype byte, req interface{}) error {
+	b := new(bytes.Buffer)
+	b.Write(headSpace)
+	b.WriteByte(ptype)
+	if err := rlp.Encode(b, req); err != nil {
+		log.Errorln("error encoding packet:", err)
+		return err
+	}
+
+	packet := b.Bytes()
+	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
+	if err != nil {
+		log.Errorln("could not sign packet:", err)
+		return err
+	}
+	copy(packet[macSize:], sig)
+	// add the hash to the front. Note: this doesn't protect the
+	// packet in any way. Our public key will be part of this hash in
+	// the future.
+	copy(packet, crypto.Sha3(packet[macSize:]))
+
+	log.DebugDetailf(">>> %v %T %v\n", to.Addr, req, req)
+	if _, err = t.conn.WriteToUDP(packet, to.Addr); err != nil {
+		log.DebugDetailln("UDP send failed:", err)
+	}
+	return err
+}
+
+// readLoop runs in its own goroutine. it handles incoming UDP packets.
+func (t *udp) readLoop() {
+	defer t.conn.Close()
+	buf := make([]byte, 4096) // TODO: good buffer size
+	for {
+		nbytes, from, err := t.conn.ReadFromUDP(buf)
+		if err != nil {
+			return
+		}
+		if err := t.packetIn(from, buf[:nbytes]); err != nil {
+			log.Debugf("Bad packet from %v: %v\n", from, err)
+		}
+	}
+}
+
+func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
+	if len(buf) < headSize+1 {
+		return errPacketTooSmall
+	}
+	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
+	shouldhash := crypto.Sha3(buf[macSize:])
+	if !bytes.Equal(hash, shouldhash) {
+		return errBadHash
+	}
+	fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
+	if err != nil {
+		return err
+	}
+
+	var req interface {
+		handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
+	}
+	switch ptype := sigdata[0]; ptype {
+	case pingPacket:
+		req = new(ping)
+	case pongPacket:
+		req = new(pong)
+	case findnodePacket:
+		req = new(findnode)
+	case neighborsPacket:
+		req = new(neighbors)
+	default:
+		return fmt.Errorf("unknown type: %d", ptype)
+	}
+	if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
+		return err
+	}
+	log.DebugDetailf("<<< %v %T %v\n", from, req, req)
+	return req.handle(t, from, fromID, hash)
+}
+
+func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+	if expired(req.Expiration) {
+		return errExpired
+	}
+	t.mutex.Lock()
+	// Note: we're ignoring the provided IP/Port right now.
+	e := t.bumpOrAdd(fromID, from)
+	t.mutex.Unlock()
+
+	t.send(e, pongPacket, pong{
+		ReplyTok:   mac,
+		Expiration: uint64(time.Now().Add(expiration).Unix()),
+	})
+	return nil
+}
+
+func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+	if expired(req.Expiration) {
+		return errExpired
+	}
+	t.mutex.Lock()
+	t.bump(fromID)
+	t.mutex.Unlock()
+
+	t.replies <- reply{fromID, pongPacket, req}
+	return nil
+}
+
+func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+	if expired(req.Expiration) {
+		return errExpired
+	}
+	t.mutex.Lock()
+	e := t.bumpOrAdd(fromID, from)
+	closest := t.closest(req.Target, bucketSize).entries
+	t.mutex.Unlock()
+
+	t.send(e, neighborsPacket, neighbors{
+		Nodes:      closest,
+		Expiration: uint64(time.Now().Add(expiration).Unix()),
+	})
+	return nil
+}
+
+func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
+	if expired(req.Expiration) {
+		return errExpired
+	}
+	t.mutex.Lock()
+	t.bump(fromID)
+	t.add(req.Nodes)
+	t.mutex.Unlock()
+
+	t.replies <- reply{fromID, neighborsPacket, req}
+	return nil
+}
+
+func expired(ts uint64) bool {
+	return time.Unix(int64(ts), 0).Before(time.Now())
+}
diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go
new file mode 100644
index 000000000..f2ab2b661
--- /dev/null
+++ b/p2p/discover/udp_test.go
@@ -0,0 +1,156 @@
+package discover
+
+import (
+	logpkg "log"
+	"net"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/ethereum/go-ethereum/logger"
+)
+
+func init() {
+	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, logpkg.LstdFlags, logger.DebugLevel))
+}
+
+func TestUDP_ping(t *testing.T) {
+	t.Parallel()
+
+	n1, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	n2, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	defer n1.net.close()
+	defer n2.net.close()
+
+	if err := n1.net.ping(n2.self); err != nil {
+		t.Fatalf("ping error: %v", err)
+	}
+	if find(n2, n1.self.ID) == nil {
+		t.Errorf("node 2 does not contain id of node 1")
+	}
+	if e := find(n1, n2.self.ID); e != nil {
+		t.Errorf("node 1 does contains id of node 2: %v", e)
+	}
+}
+
+func find(tab *Table, id NodeID) *Node {
+	for _, b := range tab.buckets {
+		for _, e := range b.entries {
+			if e.ID == id {
+				return e
+			}
+		}
+	}
+	return nil
+}
+
+func TestUDP_findnode(t *testing.T) {
+	t.Parallel()
+
+	n1, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	n2, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	defer n1.net.close()
+	defer n2.net.close()
+
+	entry := &Node{ID: NodeID{1}, Addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 15}}
+	n2.add([]*Node{entry})
+
+	target := randomID(n1.self.ID, 100)
+	result, _ := n1.net.findnode(n2.self, target)
+	if len(result) != 1 {
+		t.Fatalf("wrong number of results: got %d, want 1", len(result))
+	}
+	if result[0].ID != entry.ID {
+		t.Errorf("wrong result: got %v, want %v", result[0], entry)
+	}
+}
+
+func TestUDP_replytimeout(t *testing.T) {
+	t.Parallel()
+
+	// reserve a port so we don't talk to an existing service by accident
+	addr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
+	fd, err := net.ListenUDP("udp", addr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer fd.Close()
+
+	n1, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	defer n1.net.close()
+	n2 := n1.bumpOrAdd(randomID(n1.self.ID, 10), fd.LocalAddr().(*net.UDPAddr))
+
+	if err := n1.net.ping(n2); err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	}
+
+	if result, err := n1.net.findnode(n2, n1.self.ID); err != errTimeout {
+		t.Error("expected timeout error, got", err)
+	} else if len(result) > 0 {
+		t.Error("expected empty result, got", result)
+	}
+}
+
+func TestUDP_findnodeMultiReply(t *testing.T) {
+	t.Parallel()
+
+	n1, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	n2, _ := ListenUDP(newkey(), "127.0.0.1:0")
+	udp2 := n2.net.(*udp)
+	defer n1.net.close()
+	defer n2.net.close()
+
+	nodes := make([]*Node, bucketSize)
+	for i := range nodes {
+		nodes[i] = &Node{
+			Addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: i + 1},
+			ID:   randomID(n2.self.ID, i+1),
+		}
+	}
+
+	// ask N2 for neighbors. it will send an empty reply back.
+	// the request will wait for up to bucketSize replies.
+	resultC := make(chan []*Node)
+	go func() {
+		ns, err := n1.net.findnode(n2.self, n1.self.ID)
+		if err != nil {
+			t.Error("findnode error:", err)
+		}
+		resultC <- ns
+	}()
+
+	// send a few more neighbors packets to N1.
+	// it should collect those.
+	for end := 0; end < len(nodes); {
+		off := end
+		if end = end + 5; end > len(nodes) {
+			end = len(nodes)
+		}
+		udp2.send(n1.self, neighborsPacket, neighbors{
+			Nodes:      nodes[off:end],
+			Expiration: uint64(time.Now().Add(10 * time.Second).Unix()),
+		})
+	}
+
+	// check that they are all returned. we cannot just check for
+	// equality because they might not be returned in the order they
+	// were sent.
+	result := <-resultC
+	if hasDuplicates(result) {
+		t.Error("result slice contains duplicates")
+	}
+	if len(result) != len(nodes) {
+		t.Errorf("wrong number of nodes returned: got %d, want %d", len(result), len(nodes))
+	}
+	matched := make(map[NodeID]bool)
+	for _, n := range result {
+		for _, expn := range nodes {
+			if n.ID == expn.ID { // && bytes.Equal(n.Addr.IP, expn.Addr.IP) && n.Addr.Port == expn.Addr.Port {
+				matched[n.ID] = true
+			}
+		}
+	}
+	if len(matched) != len(nodes) {
+		t.Errorf("wrong number of matching nodes: got %d, want %d", len(matched), len(nodes))
+	}
+}
-- 
GitLab