diff --git a/example.go b/example.go index 790718f..d1bc9dc 100644 --- a/example.go +++ b/example.go @@ -16,21 +16,13 @@ const ( func main() { ws := websocket.NewWebSocketClient( // onOpen - func(ws *websocket.WebSocketClient, isReconnecting bool) { - if isReconnecting { - log.Println("Reconnected") - } else { - log.Println("Connected") - } + func(ws *websocket.WebSocketClient) { + log.Println("Connected") }, // onClose - func(ws *websocket.WebSocketClient, isReconnecting bool) { - if isReconnecting { - log.Println("Reconnecting...") - } else { - log.Println("Disconnected") - } + func(ws *websocket.WebSocketClient) { + log.Println("Disconnected") }, // onMessage @@ -40,7 +32,7 @@ func main() { ) // Connect to server - if err := ws.Connect(WEBSOCKET_URL, ATTEMPTS, INTERVAL, false); err != nil { + if err := ws.Connect(WEBSOCKET_URL, ATTEMPTS, INTERVAL); err != nil { log.Println("Failed to connect:", err) } @@ -56,11 +48,11 @@ func main() { time.Sleep(2 * time.Second) // Reconnecting - ws.Reconnect(WEBSOCKET_URL, ATTEMPTS, INTERVAL) + ws.Connect(WEBSOCKET_URL, ATTEMPTS, INTERVAL) sendMessage() time.Sleep(2 * time.Second) - ws.Disconnect(false) + ws.Disconnect() log.Println("Done") } diff --git a/websocket/client.go b/websocket/client.go index a31e33e..d5cf2d8 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -17,14 +17,14 @@ type WebSocketClient struct { wg sync.WaitGroup ctx context.Context cancel context.CancelFunc - onOpen func(ws *WebSocketClient, isReconnecting bool) - onClose func(ws *WebSocketClient, isReconnecting bool) + onOpen func(ws *WebSocketClient) + onClose func(ws *WebSocketClient) onMessage func(ws *WebSocketClient, payload []byte) } func NewWebSocketClient( - onOpen func(ws *WebSocketClient, isReconnecting bool), - onClose func(ws *WebSocketClient, isReconnecting bool), + onOpen func(ws *WebSocketClient), + onClose func(ws *WebSocketClient), onMessage func(ws *WebSocketClient, payload []byte), ) *WebSocketClient { ctx, cancel := context.WithCancel(context.Background()) @@ -41,8 +41,10 @@ func (ws *WebSocketClient) Connect( webSocketUrl string, attempts int, interval time.Duration, - isReconnecting bool, ) (err error) { + ws.Disconnect() + ws.ctx, ws.cancel = context.WithCancel(context.Background()) + ws.URL, err = url.Parse(webSocketUrl) if err != nil { return err @@ -62,7 +64,7 @@ func (ws *WebSocketClient) Connect( } if ws.onOpen != nil { - ws.onOpen(ws, isReconnecting) + ws.onOpen(ws) } // Message Handler @@ -72,14 +74,12 @@ func (ws *WebSocketClient) Connect( for { select { case <-ws.ctx.Done(): - // log.Println("websocket receive cancel") return default: var payload []byte if err := websocket.Message.Receive(ws.conn, &payload); err != nil { - // log.Println("receive error", err) if ws.onClose != nil { - ws.onClose(ws, isReconnecting) + ws.onClose(ws) } return } @@ -91,7 +91,7 @@ func (ws *WebSocketClient) Connect( return nil } -func (ws *WebSocketClient) Disconnect(isReconnecting bool) { +func (ws *WebSocketClient) Disconnect() { if ws.conn != nil { ws.conn.Close() } @@ -100,16 +100,10 @@ func (ws *WebSocketClient) Disconnect(isReconnecting bool) { ws.wg.Wait() if ws.onClose != nil { - ws.onClose(ws, isReconnecting) + ws.onClose(ws) } } -func (ws *WebSocketClient) Reconnect(url string, attempts int, interval time.Duration) error { - ws.Disconnect(true) - ws.ctx, ws.cancel = context.WithCancel(context.Background()) - return ws.Connect(url, attempts, interval, true) -} - func (ws *WebSocketClient) SendJSON(v any) error { bytes, err := json.Marshal(v) if err != nil {