skybase/main.go

265 lines
5.7 KiB
Go

package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"unicode"
"unicode/utf8"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/session"
"github.com/gorilla/websocket"
"github.com/joho/godotenv"
)
// writeWait bounds how long a single websocket frame write may take before
// the connection is treated as dead.
const writeWait = 10 * time.Second

// pingPeriod is the interval at which handleChat sends keep-alive pings.
const pingPeriod = 50 * time.Second
// upgrader turns /chat HTTP requests into websocket connections. The zero
// value uses gorilla/websocket's defaults, including its default origin check.
var upgrader websocket.Upgrader
// mtag identifies which payload field of a message is populated.
type mtag string

const (
	// mtagOnDiscordChat tags a Discord chat message pushed to websocket clients.
	mtagOnDiscordChat mtag = "on_discord_chat"
	// mtagOnMinecraftChat tags a Minecraft chat line received from a websocket
	// client and relayed to Discord.
	mtagOnMinecraftChat mtag = "on_minecraft_chat"
	// mtagNothing carries no payload; the Discord forwarder ignores it.
	mtagNothing mtag = "nothing"
)

// message is the envelope exchanged with websocket clients as JSON. It has no
// struct tags, so encoding/json uses the Go field names ("Tag",
// "OnDiscordChat", "OnMinecraftChat") as keys. Only the payload that matches
// Tag is expected to be non-nil.
type message struct {
	Tag             mtag
	OnDiscordChat   *onDiscordChatMessage
	OnMinecraftChat *onMinecraftChatMessage
}

// onDiscordChatMessage describes a chat message posted in Discord.
type onDiscordChatMessage struct {
	SenderNickname string // author's display name
	SenderUsername string // author's account username
	Content        string
}

// onMinecraftChatMessage describes a chat line sent by a Minecraft player.
type onMinecraftChatMessage struct {
	Sender  string
	Content string
}

// skyBase is the hub between the Discord session and the websocket clients.
type skyBase struct {
	mu          sync.RWMutex   // guards luaOutboxes
	luaOutboxes []chan message // one outbox per connected websocket client
	goInbox     chan message   // messages from websocket clients, bound for Discord
}

// luaDispatch delivers m to every registered websocket outbox.
//
// Sends are non-blocking. If a client's outbox is full, m is dropped for that
// client and logged. A blocking send here would hold mu.RLock while waiting
// on a client whose handler may already be exiting and waiting for mu.Lock to
// unregister itself. That deadlocks both goroutines, and every later
// dispatch or registration then queues behind the pending writer.
func (sb *skyBase) luaDispatch(m message) {
	sb.mu.RLock()
	defer sb.mu.RUnlock()
	for _, ch := range sb.luaOutboxes {
		select {
		case ch <- m:
		default:
			log.Printf("dropping %s message: websocket outbox full", m.Tag)
		}
	}
}

// base is the process-wide hub. main assigns it before registering any
// handler or starting the HTTP server.
var base *skyBase
// handleIndex serves every path not claimed by a more specific route with a
// fixed greeting. A failed write is deliberately ignored.
func handleIndex(w http.ResponseWriter, r *http.Request) {
	const greeting = "hello from skybase"
	_, _ = w.Write([]byte(greeting))
}
// handleChat upgrades the request to a websocket and bridges it with the hub.
//
// JSON text frames read from the socket are decoded into a message and queued
// on base.goInbox. Messages published through base.luaDispatch arrive on this
// connection's outbox and are written back as JSON text frames. A ping goes
// out every pingPeriod. The handler returns on the first read, decode, or
// write error, then unregisters and closes its outbox.
func handleChat(w http.ResponseWriter, r *http.Request) {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		return
	}

	outbox := make(chan message, 10)
	base.mu.Lock()
	base.luaOutboxes = append(base.luaOutboxes, outbox)
	base.mu.Unlock()
	defer func() {
		ws.Close()
		base.mu.Lock()
		defer base.mu.Unlock()
		for i, ch := range base.luaOutboxes {
			if ch != outbox {
				continue
			}
			// Swap-remove; outbox order does not matter.
			last := len(base.luaOutboxes) - 1
			base.luaOutboxes[i] = base.luaOutboxes[last]
			base.luaOutboxes[last] = nil // drop the stale reference in the backing array
			base.luaOutboxes = base.luaOutboxes[:last]
			// Safe: senders only send while holding the read lock, and we hold
			// the write lock, so nothing can be sending on ch right now.
			close(ch)
			break
		}
	}()

	pingTicker := time.NewTicker(pingPeriod)
	defer pingTicker.Stop()

	// ctx is cancelled when the reader goroutine stops OR when this handler
	// returns, so each side notices when the other has gone away.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		defer cancel()
		for {
			_, data, err := ws.ReadMessage()
			if err != nil {
				return
			}
			var msg message
			if err := json.Unmarshal(data, &msg); err != nil {
				return
			}
			// Without the ctx case, a full inbox would leave this goroutine
			// blocked forever after the connection is gone.
			select {
			case base.goInbox <- msg:
			case <-ctx.Done():
				return
			}
		}
	}()

	for {
		select {
		case <-pingTicker.C:
			ws.SetWriteDeadline(time.Now().Add(writeWait))
			if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
				return
			}
		case msg := <-outbox:
			data, err := json.Marshal(msg)
			if err != nil {
				return
			}
			ws.SetWriteDeadline(time.Now().Add(writeWait))
			if err := ws.WriteMessage(websocket.TextMessage, data); err != nil {
				return
			}
		case <-ctx.Done():
			return
		}
	}
}
func main() {
godotenv.Load()
// load env
token := os.Getenv("BOT_TOKEN")
if token == "" {
log.Fatal("missing bot token")
}
channelID, err := strconv.ParseUint(os.Getenv("CHANNEL_ID"), 10, 64)
if err != nil {
log.Fatal("missing channel ID")
}
httpListen := os.Getenv("HTTP_LISTEN")
if httpListen == "" {
log.Fatal("missing http listen address")
}
// init base
base = &skyBase{
luaOutboxes: []chan message{},
goInbox: make(chan message, 10),
}
// init session
id := gateway.DefaultIdentifier("Bot " + token)
id.Presence = &gateway.UpdatePresenceCommand{
Status: discord.DoNotDisturbStatus,
Activities: []discord.Activity{{
Type: discord.CustomActivity,
Name: "whatever",
State: "PQJDfhkjldalsljkfhasd",
}},
}
s := session.NewWithIdentifier(id)
s.AddHandler(func(c *gateway.ReadyEvent) {
u, err := s.Me()
if err != nil {
log.Fatal(err)
}
log.Println("started as", u.Username)
channelID := discord.ChannelID(channelID)
for msg := range base.goInbox {
switch msg.Tag {
case mtagOnMinecraftChat:
p := msg.OnMinecraftChat
if p == nil {
continue
}
if p.Content == "" {
continue
}
filtered := strings.ReplaceAll(p.Content, "`", "\\`")
filtered = strings.ReplaceAll(filtered, "|", "\\|")
filtered = strings.ReplaceAll(filtered, "*", "\\*")
filtered = strings.ReplaceAll(filtered, "_", "\\_")
filtered = strings.ReplaceAll(filtered, "#", "\\#")
filtered = strings.ReplaceAll(filtered, "[", "\\[")
filtered = strings.ReplaceAll(filtered, "-", "\\-")
filtered = strings.ReplaceAll(filtered, ".", "\\.")
filtered = strings.ReplaceAll(filtered, ">", "\\>")
log.Printf("<%s> %s\n", p.Sender, p.Content)
r, _ := utf8.DecodeRuneInString(p.Content)
isAlphanumeric := unicode.IsLetter(r) || unicode.IsNumber(r)
if !isAlphanumeric {
continue
}
s.SendMessageComplex(channelID, api.SendMessageData{
Content: fmt.Sprintf("<%s> %s", p.Sender, filtered),
AllowedMentions: &api.AllowedMentions{},
Flags: discord.SuppressEmbeds | discord.SuppressNotifications,
})
case mtagNothing:
default:
return
}
}
})
s.AddHandler(func(c *gateway.MessageCreateEvent) {
if c.Author.Bot {
return
}
log.Printf("[@%s] %s\n", c.Author.DisplayName, c.Content)
base.luaDispatch(message{
Tag: mtagOnDiscordChat,
OnDiscordChat: &onDiscordChatMessage{
SenderUsername: c.Author.Username,
SenderNickname: c.Author.DisplayName,
Content: c.Content,
},
})
})
s.AddIntents(gateway.IntentGuildMessages)
// init http
http.HandleFunc("/", handleIndex)
http.HandleFunc("/chat", handleChat)
// begin
go func() {
if err := s.Connect(context.Background()); err != nil {
log.Fatal(err)
}
defer s.Close()
}()
go func() {
http.ListenAndServe(httpListen, nil)
}()
// wait
select {}
}