// Package websocket provides a small callback-driven WebSocket client built
// on golang.org/x/net/websocket.
package websocket

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"sync"
	"time"

	"golang.org/x/net/websocket"
)

// ErrNotConnected is returned by SendJSON when the client has no open
// connection (Connect has not succeeded yet, or Disconnect has been called).
var ErrNotConnected = errors.New("websocket: not connected")

// WebSocketClient is a WebSocket client that reports lifecycle and message
// events through user-supplied callbacks. Each successful Connect starts one
// background goroutine that reads incoming messages.
//
// onOpen runs on the goroutine that called Connect. onMessage and onClose run
// on the reader goroutine. Disconnect waits for the reader goroutine to exit,
// so onMessage and onClose must not call Connect or Disconnect synchronously;
// doing so would deadlock.
type WebSocketClient struct {
	// URL is the parsed endpoint of the most recent Connect call.
	URL *url.URL

	conn   *websocket.Conn
	wg     sync.WaitGroup
	ctx    context.Context
	cancel context.CancelFunc

	onOpen    func(ws *WebSocketClient)
	onClose   func(ws *WebSocketClient)
	onMessage func(ws *WebSocketClient, payload []byte)
}

// NewWebSocketClient returns a client that is not yet connected. Any of the
// callbacks may be nil, in which case the corresponding event is ignored.
func NewWebSocketClient(
	onOpen func(ws *WebSocketClient),
	onClose func(ws *WebSocketClient),
	onMessage func(ws *WebSocketClient, payload []byte),
) *WebSocketClient {
	ctx, cancel := context.WithCancel(context.Background())
	return &WebSocketClient{
		ctx:       ctx,
		cancel:    cancel,
		onOpen:    onOpen,
		onClose:   onClose,
		onMessage: onMessage,
	}
}

// Connect dials webSocketURL and, on success, calls onOpen and starts a
// goroutine that delivers each received message to onMessage.
//
// Any existing connection is closed first via Disconnect. Dialing is tried up
// to attempts times (values below 1 are treated as 1), sleeping interval
// between consecutive failed attempts. The Origin used for the handshake is
// the URL's scheme and host.
func (ws *WebSocketClient) Connect(
	webSocketURL string,
	attempts int,
	interval time.Duration,
) error {
	ws.Disconnect()

	u, err := url.Parse(webSocketURL)
	if err != nil {
		return err
	}
	ws.URL = u

	origin := u.Scheme + "://" + u.Host
	target := u.String()

	var conn *websocket.Conn
	for attempt := range max(attempts, 1) {
		// Wait only between attempts, not after the final failure.
		if attempt > 0 {
			time.Sleep(interval)
		}
		conn, err = websocket.Dial(target, "", origin)
		if err == nil {
			break
		}
	}
	if err != nil {
		return fmt.Errorf("websocket dial failed: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	ws.conn, ws.ctx, ws.cancel = conn, ctx, cancel

	if ws.onOpen != nil {
		ws.onOpen(ws)
	}

	ws.wg.Add(1)
	// Pass the connection and context explicitly so a later Connect that
	// replaces the struct fields cannot affect this reader.
	go ws.readLoop(ctx, conn)
	return nil
}

// readLoop receives messages from conn until Receive fails (for example
// because Disconnect closed conn or the peer went away) or ctx is cancelled,
// forwarding each payload to onMessage. It calls onClose exactly once on
// exit, before marking wg done, so Disconnect returns only after onClose has
// completed.
func (ws *WebSocketClient) readLoop(ctx context.Context, conn *websocket.Conn) {
	defer ws.wg.Done()
	defer func() {
		if ws.onClose != nil {
			ws.onClose(ws)
		}
	}()

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		var payload []byte
		if err := websocket.Message.Receive(conn, &payload); err != nil {
			return
		}
		if ws.onMessage != nil {
			ws.onMessage(ws, payload)
		}
	}
}

// Disconnect closes the current connection, if any, and waits for the reader
// goroutine to finish; that goroutine is what invokes onClose. Calling
// Disconnect when no connection is open does not invoke onClose, so it is
// safe to call repeatedly and on a client that never connected.
func (ws *WebSocketClient) Disconnect() {
	if ws.conn != nil {
		// Best-effort close: it unblocks a pending Receive in readLoop; a close
		// error leaves nothing further to clean up here.
		_ = ws.conn.Close()
	}
	if ws.cancel != nil {
		ws.cancel()
	}
	ws.wg.Wait()
	// Cleared only after Wait so the reader goroutine has fully exited.
	ws.conn = nil
}

// SendJSON marshals v to JSON and sends it as a single text message.
// It returns ErrNotConnected when there is no open connection.
func (ws *WebSocketClient) SendJSON(v any) error {
	if ws.conn == nil {
		return ErrNotConnected
	}
	data, err := json.Marshal(v)
	if err != nil {
		return fmt.Errorf("failed to marshal json: %w", err)
	}
	return websocket.Message.Send(ws.conn, string(data))
}