good morning!!!!

Skip to content
Snippets Groups Projects
Commit d10e8e2f authored by Garet Halliday's avatar Garet Halliday
Browse files

close subscription when conn closes

parent 62d02fd3
Branches
Tags
1 merge request!36close subscription when conn closes
Pipeline #33406 passed
...@@ -90,6 +90,17 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an ...@@ -90,6 +90,17 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an
readErr: make(chan error, 1), readErr: make(chan error, 1),
} }
go func() {
defer func() {
_ = sub.Unsubscribe()
}()
select {
case <-c.Closed():
case <-ctx.Done():
sub.err(ctx.Err())
}
}()
c.mu.Lock() c.mu.Lock()
c.subs[sub.id] = sub c.subs[sub.id] = sub
c.mu.Unlock() c.mu.Unlock()
......
...@@ -2,6 +2,7 @@ package subscription ...@@ -2,6 +2,7 @@ package subscription
import ( import (
"context" "context"
"encoding/json"
"net/http/httptest" "net/http/httptest"
_ "net/http/pprof" _ "net/http/pprof"
"strings" "strings"
...@@ -15,49 +16,78 @@ import ( ...@@ -15,49 +16,78 @@ import (
"gfx.cafe/open/jrpc/pkg/server" "gfx.cafe/open/jrpc/pkg/server"
) )
func TestSubscription(t *testing.T) { func newRouter(t *testing.T) jmux.Router {
const count = 100
engine := NewEngine() engine := NewEngine()
r := jmux.NewRouter() r := jmux.NewRouter()
r.Use(engine.Middleware()) r.Use(engine.Middleware())
r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
err := w.Send(r.Params, nil)
if err != nil {
t.Error(err)
}
})
// extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying
// to unsubscribe
r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
notifier, ok := NotifierFromContext(r.Context()) notifier, ok := NotifierFromContext(r.Context())
if !ok { if !ok {
_ = w.Send(nil, ErrNotificationsUnsupported) err := w.Send(nil, ErrNotificationsUnsupported)
if err != nil {
t.Error(err)
}
return return
} }
var count int
_ = json.Unmarshal(r.Params, &count)
go func() { go func() {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
for i := 0; i < count; i++ { for idx := 0; count == 0 || idx < count; idx++ {
if err := notifier.Notify(i); err != nil { select {
panic(err) case <-r.Context().Done():
return
case <-notifier.Err():
return
default:
}
err := notifier.Notify(idx)
if err != nil {
t.Error(err)
} }
} }
}() }()
}) })
return r
}
func newServer(t *testing.T) (Conn, func()) {
r := newRouter(t)
srv := server.NewServer(r) srv := server.NewServer(r)
defer srv.Shutdown(context.Background())
handler := codecs.WebsocketHandler(srv, []string{"*"}) handler := codecs.WebsocketHandler(srv, []string{"*"})
httpSrv := httptest.NewServer(handler) httpSrv := httptest.NewServer(handler)
defer httpSrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
cl, err := UpgradeConn(jrpc.Dial(wsURL)) cl, err := UpgradeConn(jrpc.Dial(wsURL))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return nil, nil
} }
defer func() {
if err = cl.Close(); err != nil { return cl, func() {
t.Error(err) _ = cl.Close()
httpSrv.Close()
srv.Shutdown(context.Background())
} }
}() }
func TestSubscription(t *testing.T) {
const count = 100
cl, done := newServer(t)
defer done()
ch := make(chan int, count) ch := make(chan int, count)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil) sub, err := cl.Subscribe(context.Background(), "test", ch, count)
defer func() { defer func() {
if err = sub.Unsubscribe(); err != nil { if err = sub.Unsubscribe(); err != nil {
t.Error(err) t.Error(err)
...@@ -73,46 +103,11 @@ func TestSubscription(t *testing.T) { ...@@ -73,46 +103,11 @@ func TestSubscription(t *testing.T) {
} }
func TestUnsubscribeNoRead(t *testing.T) { func TestUnsubscribeNoRead(t *testing.T) {
engine := NewEngine() cl, done := newServer(t)
r := jmux.NewRouter() defer done()
r.Use(engine.Middleware())
r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
notifier, ok := NotifierFromContext(r.Context())
if !ok {
_ = w.Send(nil, ErrNotificationsUnsupported)
return
}
go func() {
time.Sleep(10 * time.Millisecond)
for i := 0; i < 10; i++ {
if err := notifier.Notify(i); err != nil {
panic(err)
}
}
}()
})
srv := server.NewServer(r)
defer srv.Shutdown(context.Background())
handler := codecs.WebsocketHandler(srv, []string{"*"})
httpSrv := httptest.NewServer(handler)
defer httpSrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
cl, err := UpgradeConn(jrpc.Dial(wsURL))
if err != nil {
t.Error(err)
return
}
defer func() {
if err = cl.Close(); err != nil {
t.Error(err)
}
}()
ch := make(chan int) ch := make(chan int)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil) sub, err := cl.Subscribe(context.Background(), "test", ch, 10)
time.Sleep(time.Second) time.Sleep(time.Second)
if err = sub.Unsubscribe(); err != nil { if err = sub.Unsubscribe(); err != nil {
t.Error(err) t.Error(err)
...@@ -121,66 +116,12 @@ func TestUnsubscribeNoRead(t *testing.T) { ...@@ -121,66 +116,12 @@ func TestUnsubscribeNoRead(t *testing.T) {
} }
func TestWrapClient(t *testing.T) { func TestWrapClient(t *testing.T) {
engine := NewEngine() cl, done := newServer(t)
r := jmux.NewRouter() defer done()
r.Use(engine.Middleware())
r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
err := w.Send(r.Params, nil)
if err != nil {
t.Error(err)
}
})
// extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying
// to unsubscribe
r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
notifier, ok := NotifierFromContext(r.Context())
if !ok {
err := w.Send(nil, ErrNotificationsUnsupported)
if err != nil {
t.Error(err)
}
return
}
go func() {
time.Sleep(10 * time.Millisecond)
idx := 0
for {
select {
case <-r.Context().Done():
return
case <-notifier.Err():
return
default:
}
err := notifier.Notify(idx)
if err != nil {
t.Error(err)
}
idx += 1
}
}()
})
srv := server.NewServer(r)
defer srv.Shutdown(context.Background())
handler := codecs.WebsocketHandler(srv, []string{"*"})
httpSrv := httptest.NewServer(handler)
defer httpSrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
cl, err := UpgradeConn(jrpc.Dial(wsURL))
if err != nil {
t.Error(err)
return
}
defer func() {
if err = cl.Close(); err != nil {
t.Error(err)
}
}()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
var res string var res string
if err = cl.Do(context.Background(), &res, "echo", "test"); err != nil { if err := cl.Do(context.Background(), &res, "echo", "test"); err != nil {
t.Error(err) t.Error(err)
return return
} }
...@@ -190,8 +131,7 @@ func TestWrapClient(t *testing.T) { ...@@ -190,8 +131,7 @@ func TestWrapClient(t *testing.T) {
} }
ch := make(chan int, 101) ch := make(chan int, 101)
var sub ClientSubscription sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
sub, err = cl.Subscribe(context.Background(), "test", ch, nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
...@@ -220,3 +160,35 @@ func TestWrapClient(t *testing.T) { ...@@ -220,3 +160,35 @@ func TestWrapClient(t *testing.T) {
}() }()
} }
} }
func TestCloseClient(t *testing.T) {
cl, done := newServer(t)
defer done()
ch := make(chan int)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
if err != nil {
t.Error(err)
return
}
go func() {
if err := cl.Close(); err != nil {
t.Error(err)
}
}()
for {
select {
case err, ok := <-sub.Err():
if ok {
t.Errorf("sub errored: %v", err)
}
return
case _, ok := <-ch:
if !ok {
return
}
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment