diff --git a/websocket.go b/websocket.go index 24347ab..ee0ba8d 100644 --- a/websocket.go +++ b/websocket.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "reflect" - "strings" "time" "github.com/cenkalti/backoff/v4" @@ -58,28 +57,25 @@ func NewWebSocket(baseURL string) (*WSClient, error) { if err != nil { out.errCh <- err - // gorilla/websocket eats the error type, so I have to check - // the string itself - if strings.HasSuffix(err.Error(), "connection timed out") { - err = backoff.RetryNotify( - func() error { - conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil) - if err != nil { - out.errCh <- err - return err - } - out.conn = conn - out.recHandler(out) - return nil - }, - backoff.NewExponentialBackOff(), - func(err error, d time.Duration) { - out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err) - }, - ) - if err != nil { - out.errCh <- err - } + conn.Close() + err = backoff.RetryNotify( + func() error { + conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil) + if err != nil { + out.errCh <- err + return err + } + out.conn = conn + out.recHandler(out) + return nil + }, + backoff.NewExponentialBackOff(), + func(err error, d time.Duration) { + out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err) + }, + ) + if err != nil { + out.errCh <- err } continue @@ -201,7 +197,7 @@ func keepaliveDialer() *websocket.Dialer { return nil, err } - err = conn.SetKeepAlivePeriod(time.Second) + err = conn.SetKeepAlivePeriod(10 * time.Second) if err != nil { return nil, err }