137 lines
2.4 KiB
Go
137 lines
2.4 KiB
Go
package websocket
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
"github.com/coder/websocket/wsjson"
|
|
)
|
|
|
|
type WebSocketClient struct {
|
|
URL *url.URL
|
|
conn *websocket.Conn
|
|
wg sync.WaitGroup
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
onOpen func(ws *WebSocketClient)
|
|
onClose func(ws *WebSocketClient)
|
|
onMessage func(ws *WebSocketClient, payload []byte)
|
|
}
|
|
|
|
func NewWebSocketClient(
|
|
ctx context.Context,
|
|
onOpen func(ws *WebSocketClient),
|
|
onClose func(ws *WebSocketClient),
|
|
onMessage func(ws *WebSocketClient, payload []byte),
|
|
) *WebSocketClient {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
return &WebSocketClient{
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
onOpen: onOpen,
|
|
onClose: onClose,
|
|
onMessage: onMessage,
|
|
}
|
|
}
|
|
|
|
func (ws *WebSocketClient) Connect(
|
|
webSocketUrl string,
|
|
attempts int,
|
|
interval time.Duration,
|
|
) (err error) {
|
|
if ws.conn != nil {
|
|
return errors.New("connection already open")
|
|
}
|
|
|
|
ws.URL, err = url.Parse(webSocketUrl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// attempt retry
|
|
for range attempts {
|
|
origin := ws.URL.Scheme + "://" + ws.URL.Host
|
|
header := make(http.Header)
|
|
header.Add("Origin", origin)
|
|
ws.conn, _, err = websocket.Dial(ws.ctx, ws.URL.String(), &websocket.DialOptions{HTTPHeader: header})
|
|
if err == nil {
|
|
break
|
|
}
|
|
time.Sleep(interval)
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("websocket dial falied: %v", err)
|
|
}
|
|
|
|
if ws.onOpen != nil {
|
|
ws.onOpen(ws)
|
|
}
|
|
|
|
// Message Handler
|
|
ws.wg.Add(1)
|
|
go func() {
|
|
defer ws.wg.Done()
|
|
for {
|
|
select {
|
|
case <-ws.ctx.Done():
|
|
return
|
|
default:
|
|
var payload []byte
|
|
_, payload, err := ws.conn.Read(ws.ctx)
|
|
if err != nil {
|
|
ws.cancel()
|
|
if ws.onClose != nil {
|
|
ws.onClose(ws)
|
|
}
|
|
return
|
|
}
|
|
ws.onMessage(ws, payload)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Ping Sender
|
|
ws.wg.Add(1)
|
|
go func() {
|
|
defer ws.wg.Done()
|
|
pingTicker := time.NewTicker(time.Second * 10)
|
|
for {
|
|
select {
|
|
case <-ws.ctx.Done():
|
|
return
|
|
case <-pingTicker.C:
|
|
ws.conn.Ping(ws.ctx)
|
|
}
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ws *WebSocketClient) Disconnect() {
|
|
if ws.conn != nil {
|
|
ws.conn.CloseNow()
|
|
}
|
|
|
|
ws.cancel()
|
|
ws.wg.Wait()
|
|
|
|
if ws.onClose != nil {
|
|
ws.onClose(ws)
|
|
}
|
|
}
|
|
|
|
func (ws *WebSocketClient) SendJSON(v any) error {
|
|
return wsjson.Write(ws.ctx, ws.conn, v)
|
|
}
|
|
|
|
func (ws *WebSocketClient) SendPing() error {
|
|
return ws.conn.Ping(ws.ctx)
|
|
}
|