good morning!!!!

Skip to content
Snippets Groups Projects
Commit d10e8e2f authored by Garet Halliday's avatar Garet Halliday
Browse files

close subscription when conn closes

parent 62d02fd3
No related branches found
No related tags found
1 merge request!36close subscription when conn closes
Pipeline #33406 passed with stage
in 4 minutes and 8 seconds
......@@ -90,6 +90,17 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an
readErr: make(chan error, 1),
}
go func() {
defer func() {
_ = sub.Unsubscribe()
}()
select {
case <-c.Closed():
case <-ctx.Done():
sub.err(ctx.Err())
}
}()
c.mu.Lock()
c.subs[sub.id] = sub
c.mu.Unlock()
......
......@@ -2,6 +2,7 @@ package subscription
import (
"context"
"encoding/json"
"net/http/httptest"
_ "net/http/pprof"
"strings"
......@@ -15,49 +16,78 @@ import (
"gfx.cafe/open/jrpc/pkg/server"
)
func TestSubscription(t *testing.T) {
const count = 100
func newRouter(t *testing.T) jmux.Router {
engine := NewEngine()
r := jmux.NewRouter()
r.Use(engine.Middleware())
r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
err := w.Send(r.Params, nil)
if err != nil {
t.Error(err)
}
})
// extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying
// to unsubscribe
r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
notifier, ok := NotifierFromContext(r.Context())
if !ok {
_ = w.Send(nil, ErrNotificationsUnsupported)
err := w.Send(nil, ErrNotificationsUnsupported)
if err != nil {
t.Error(err)
}
return
}
var count int
_ = json.Unmarshal(r.Params, &count)
go func() {
time.Sleep(10 * time.Millisecond)
for i := 0; i < count; i++ {
if err := notifier.Notify(i); err != nil {
panic(err)
for idx := 0; count == 0 || idx < count; idx++ {
select {
case <-r.Context().Done():
return
case <-notifier.Err():
return
default:
}
err := notifier.Notify(idx)
if err != nil {
t.Error(err)
}
}
}()
})
return r
}
func newServer(t *testing.T) (Conn, func()) {
r := newRouter(t)
srv := server.NewServer(r)
defer srv.Shutdown(context.Background())
handler := codecs.WebsocketHandler(srv, []string{"*"})
httpSrv := httptest.NewServer(handler)
defer httpSrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
cl, err := UpgradeConn(jrpc.Dial(wsURL))
if err != nil {
t.Error(err)
return
return nil, nil
}
defer func() {
if err = cl.Close(); err != nil {
t.Error(err)
}
}()
return cl, func() {
_ = cl.Close()
httpSrv.Close()
srv.Shutdown(context.Background())
}
}
func TestSubscription(t *testing.T) {
const count = 100
cl, done := newServer(t)
defer done()
ch := make(chan int, count)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
sub, err := cl.Subscribe(context.Background(), "test", ch, count)
defer func() {
if err = sub.Unsubscribe(); err != nil {
t.Error(err)
......@@ -73,46 +103,11 @@ func TestSubscription(t *testing.T) {
}
func TestUnsubscribeNoRead(t *testing.T) {
engine := NewEngine()
r := jmux.NewRouter()
r.Use(engine.Middleware())
r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
notifier, ok := NotifierFromContext(r.Context())
if !ok {
_ = w.Send(nil, ErrNotificationsUnsupported)
return
}
go func() {
time.Sleep(10 * time.Millisecond)
for i := 0; i < 10; i++ {
if err := notifier.Notify(i); err != nil {
panic(err)
}
}
}()
})
srv := server.NewServer(r)
defer srv.Shutdown(context.Background())
handler := codecs.WebsocketHandler(srv, []string{"*"})
httpSrv := httptest.NewServer(handler)
defer httpSrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
cl, err := UpgradeConn(jrpc.Dial(wsURL))
if err != nil {
t.Error(err)
return
}
defer func() {
if err = cl.Close(); err != nil {
t.Error(err)
}
}()
cl, done := newServer(t)
defer done()
ch := make(chan int)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
sub, err := cl.Subscribe(context.Background(), "test", ch, 10)
time.Sleep(time.Second)
if err = sub.Unsubscribe(); err != nil {
t.Error(err)
......@@ -121,66 +116,12 @@ func TestUnsubscribeNoRead(t *testing.T) {
}
func TestWrapClient(t *testing.T) {
engine := NewEngine()
r := jmux.NewRouter()
r.Use(engine.Middleware())
r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
err := w.Send(r.Params, nil)
if err != nil {
t.Error(err)
}
})
// extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying
// to unsubscribe
r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
notifier, ok := NotifierFromContext(r.Context())
if !ok {
err := w.Send(nil, ErrNotificationsUnsupported)
if err != nil {
t.Error(err)
}
return
}
go func() {
time.Sleep(10 * time.Millisecond)
idx := 0
for {
select {
case <-r.Context().Done():
return
case <-notifier.Err():
return
default:
}
err := notifier.Notify(idx)
if err != nil {
t.Error(err)
}
idx += 1
}
}()
})
srv := server.NewServer(r)
defer srv.Shutdown(context.Background())
handler := codecs.WebsocketHandler(srv, []string{"*"})
httpSrv := httptest.NewServer(handler)
defer httpSrv.Close()
wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
cl, err := UpgradeConn(jrpc.Dial(wsURL))
if err != nil {
t.Error(err)
return
}
defer func() {
if err = cl.Close(); err != nil {
t.Error(err)
}
}()
cl, done := newServer(t)
defer done()
for i := 0; i < 10; i++ {
var res string
if err = cl.Do(context.Background(), &res, "echo", "test"); err != nil {
if err := cl.Do(context.Background(), &res, "echo", "test"); err != nil {
t.Error(err)
return
}
......@@ -190,8 +131,7 @@ func TestWrapClient(t *testing.T) {
}
ch := make(chan int, 101)
var sub ClientSubscription
sub, err = cl.Subscribe(context.Background(), "test", ch, nil)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
if err != nil {
t.Error(err)
return
......@@ -220,3 +160,35 @@ func TestWrapClient(t *testing.T) {
}()
}
}
func TestCloseClient(t *testing.T) {
cl, done := newServer(t)
defer done()
ch := make(chan int)
sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
if err != nil {
t.Error(err)
return
}
go func() {
if err := cl.Close(); err != nil {
t.Error(err)
}
}()
for {
select {
case err, ok := <-sub.Err():
if ok {
t.Errorf("sub errored: %v", err)
}
return
case _, ok := <-ch:
if !ok {
return
}
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment