good morning!!!!

Skip to content
Snippets Groups Projects
Commit 1d72aaa0 authored by Viktor Trón's avatar Viktor Trón
Browse files

simplify account unlocking

parent 9f6016e8
No related branches found
No related tags found
No related merge requests found
...@@ -49,11 +49,6 @@ var ( ...@@ -49,11 +49,6 @@ var (
ErrNoKeys = errors.New("no keys in store") ErrNoKeys = errors.New("no keys in store")
) )
const (
// Default unlock duration (in seconds) when an account is unlocked from the console
DefaultAccountUnlockDuration = 300
)
type Account struct { type Account struct {
Address common.Address Address common.Address
} }
...@@ -114,28 +109,58 @@ func (am *Manager) Sign(a Account, toSign []byte) (signature []byte, err error) ...@@ -114,28 +109,58 @@ func (am *Manager) Sign(a Account, toSign []byte) (signature []byte, err error)
return signature, err return signature, err
} }
// TimedUnlock unlocks the account with the given address. // unlock indefinitely
// When timeout has passed, the account will be locked again. func (am *Manager) Unlock(addr common.Address, keyAuth string) error {
return am.TimedUnlock(addr, keyAuth, 0)
}
// Unlock unlocks the account with the given address. The account
// stays unlocked for the duration of timeout
// it timeout is 0 the account is unlocked for the entire session
func (am *Manager) TimedUnlock(addr common.Address, keyAuth string, timeout time.Duration) error { func (am *Manager) TimedUnlock(addr common.Address, keyAuth string, timeout time.Duration) error {
key, err := am.keyStore.GetKey(addr, keyAuth) key, err := am.keyStore.GetKey(addr, keyAuth)
if err != nil { if err != nil {
return err return err
} }
u := am.addUnlocked(addr, key) var u *unlocked
go am.dropLater(addr, u, timeout) am.mutex.Lock()
defer am.mutex.Unlock()
var found bool
u, found = am.unlocked[addr]
if found {
// terminate dropLater for this key to avoid unexpected drops.
if u.abort != nil {
close(u.abort)
}
}
if timeout > 0 {
u = &unlocked{Key: key, abort: make(chan struct{})}
go am.expire(addr, u, timeout)
} else {
u = &unlocked{Key: key}
}
am.unlocked[addr] = u
return nil return nil
} }
// Unlock unlocks the account with the given address. The account func (am *Manager) expire(addr common.Address, u *unlocked, timeout time.Duration) {
// stays unlocked until the program exits or until a TimedUnlock t := time.NewTimer(timeout)
// timeout (started after the call to Unlock) expires. defer t.Stop()
func (am *Manager) Unlock(addr common.Address, keyAuth string) error { select {
key, err := am.keyStore.GetKey(addr, keyAuth) case <-u.abort:
if err != nil { // just quit
return err case <-t.C:
am.mutex.Lock()
// only drop if it's still the same key instance that dropLater
// was launched with. we can check that using pointer equality
// because the map stores a new pointer every time the key is
// unlocked.
if am.unlocked[addr] == u {
zeroKey(u.PrivateKey)
delete(am.unlocked, addr)
}
am.mutex.Unlock()
} }
am.addUnlocked(addr, key)
return nil
} }
func (am *Manager) NewAccount(auth string) (Account, error) { func (am *Manager) NewAccount(auth string) (Account, error) {
...@@ -162,43 +187,6 @@ func (am *Manager) Accounts() ([]Account, error) { ...@@ -162,43 +187,6 @@ func (am *Manager) Accounts() ([]Account, error) {
return accounts, err return accounts, err
} }
func (am *Manager) addUnlocked(addr common.Address, key *crypto.Key) *unlocked {
u := &unlocked{Key: key, abort: make(chan struct{})}
am.mutex.Lock()
prev, found := am.unlocked[addr]
if found {
// terminate dropLater for this key to avoid unexpected drops.
close(prev.abort)
// the key is zeroed here instead of in dropLater because
// there might not actually be a dropLater running for this
// key, i.e. when Unlock was used.
zeroKey(prev.PrivateKey)
}
am.unlocked[addr] = u
am.mutex.Unlock()
return u
}
func (am *Manager) dropLater(addr common.Address, u *unlocked, timeout time.Duration) {
t := time.NewTimer(timeout)
defer t.Stop()
select {
case <-u.abort:
// just quit
case <-t.C:
am.mutex.Lock()
// only drop if it's still the same key instance that dropLater
// was launched with. we can check that using pointer equality
// because the map stores a new pointer every time the key is
// unlocked.
if am.unlocked[addr] == u {
zeroKey(u.PrivateKey)
delete(am.unlocked, addr)
}
am.mutex.Unlock()
}
}
// zeroKey zeroes a private key in memory. // zeroKey zeroes a private key in memory.
func zeroKey(k *ecdsa.PrivateKey) { func zeroKey(k *ecdsa.PrivateKey) {
b := k.D.Bits() b := k.D.Bits()
......
...@@ -18,7 +18,7 @@ func TestSign(t *testing.T) { ...@@ -18,7 +18,7 @@ func TestSign(t *testing.T) {
pass := "" // not used but required by API pass := "" // not used but required by API
a1, err := am.NewAccount(pass) a1, err := am.NewAccount(pass)
toSign := randentropy.GetEntropyCSPRNG(32) toSign := randentropy.GetEntropyCSPRNG(32)
am.Unlock(a1.Address, "") am.Unlock(a1.Address, "", 0)
_, err = am.Sign(a1, toSign) _, err = am.Sign(a1, toSign)
if err != nil { if err != nil {
...@@ -58,6 +58,47 @@ func TestTimedUnlock(t *testing.T) { ...@@ -58,6 +58,47 @@ func TestTimedUnlock(t *testing.T) {
if err != ErrLocked { if err != ErrLocked {
t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err) t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
} }
}
func TestOverrideUnlock(t *testing.T) {
dir, ks := tmpKeyStore(t, crypto.NewKeyStorePassphrase)
defer os.RemoveAll(dir)
am := NewManager(ks)
pass := "foo"
a1, err := am.NewAccount(pass)
toSign := randentropy.GetEntropyCSPRNG(32)
// Unlock indefinitely
if err = am.Unlock(a1.Address, pass); err != nil {
t.Fatal(err)
}
// Signing without passphrase works because account is temp unlocked
_, err = am.Sign(a1, toSign)
if err != nil {
t.Fatal("Signing shouldn't return an error after unlocking, got ", err)
}
// reset unlock to a shorter period, invalidates the previous unlock
if err = am.TimedUnlock(a1.Address, pass, 100*time.Millisecond); err != nil {
t.Fatal(err)
}
// Signing without passphrase still works because account is temp unlocked
_, err = am.Sign(a1, toSign)
if err != nil {
t.Fatal("Signing shouldn't return an error after unlocking, got ", err)
}
// Signing fails again after automatic locking
time.Sleep(150 * time.Millisecond)
_, err = am.Sign(a1, toSign)
if err != ErrLocked {
t.Fatal("Signing should've failed with ErrLocked timeout expired, got ", err)
}
} }
func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore2) (string, crypto.KeyStore2) { func tmpKeyStore(t *testing.T, new func(string) crypto.KeyStore2) (string, crypto.KeyStore2) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment