You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
2.6 KiB

package discord
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"git.lalonde.me/matth/AltVRBot/pkg/discord/mux"
"git.lalonde.me/matth/AltVRBot/pkg/utils"
"github.com/bwmarrin/discordgo"
)
// Discord is a bot bound to one Discord guild (server) and one channel.
// Build it with New, which validates the channel, registers the command
// router and opens the session.
type Discord struct {
// token is the raw bot token passed to New (without the "Bot " prefix).
token string
// sID is the guild (server) ID; New fetches this guild's custom emojis.
sID string
// cID is the channel ID that New validates against allowedChannelTypes.
cID string
// avatarHash is the SHA-256 digest that UpdateAvatar compares a downloaded
// image against, to avoid re-uploading an unchanged avatar.
avatarHash []byte
// Emojis are the custom emojis of guild sID, fetched once in New.
Emojis []*discordgo.Emoji
// Session is the underlying discordgo session opened by New.
Session *discordgo.Session
// Router dispatches incoming messages to registered commands; its
// OnMessageCreate method is installed as a session handler.
Router *mux.Mux
// User is the bot's own account, as returned for "@me" in New.
User *discordgo.User
}
// allowedChannelTypes lists the channel kinds New accepts for the bot's
// configured channel: guild text channels, direct messages and group DMs.
var allowedChannelTypes = []discordgo.ChannelType{
	discordgo.ChannelTypeGuildText,
	discordgo.ChannelTypeDM,
	discordgo.ChannelTypeGroupDM,
}
// New instantiates a Discord bot for guild sID and channel cID, using the
// bot token for authentication, and opens its websocket connection.
//
// The channel must be one of allowedChannelTypes. On any failure New
// returns a nil *Discord and a descriptive error instead of terminating
// the process. Underlying errors are wrapped with %w, so callers can use
// errors.Is and errors.As on them. An unsupported channel type wraps
// os.ErrInvalid.
//
// Failing to fetch the guild's custom emojis is not fatal: it is logged
// and Emojis is left empty.
func New(token, sID, cID string) (*Discord, error) {
	// discordgo.New receives "Bot " + token, so session.Token can never be
	// empty; the missing-token check must look at the raw argument.
	if token == "" {
		return nil, errors.New("You must provide a Discord authentication token")
	}

	// Create a new Discord session using the provided bot token.
	session, err := discordgo.New("Bot " + token)
	if err != nil {
		return nil, fmt.Errorf("Failed to create Discord session: %w", err)
	}

	dg := &Discord{
		Router:  mux.New(),
		Session: session,
		token:   token,
		sID:     sID,
		cID:     cID,
	}

	c, err := dg.Session.Channel(cID)
	if err != nil {
		return nil, fmt.Errorf("Invalid channel ID %q: %w", cID, err)
	}
	if !utils.SliceContainsChannelType(allowedChannelTypes, c.Type) {
		return nil, fmt.Errorf("Cannot join channel, invalid channel type %v: %w", c.Type, os.ErrInvalid)
	}

	dg.addHandlers()

	// Open a websocket connection to Discord.
	if err := dg.Session.Open(); err != nil {
		return nil, fmt.Errorf("error opening connection to Discord: %w", err)
	}

	// Handlers may rely on knowing the bot's own user, so treat this as fatal
	// and release the already-open connection.
	dg.User, err = dg.Session.User("@me")
	if err != nil {
		_ = dg.Session.Close() // best-effort cleanup; the User error is what matters
		return nil, fmt.Errorf("error fetching the bot user: %w", err)
	}

	// Custom emojis are optional; keep running without them.
	if dg.Emojis, err = dg.Session.GuildEmojis(sID); err != nil {
		log.Printf("could not fetch emojis for guild %s: %v\n", sID, err)
	}

	return dg, nil
}
// Close terminates the Discord session.
//
// It is safe to call on a nil *Discord or one whose Session was never set.
// An error reported by the session while closing is logged rather than
// returned, which keeps the existing no-result signature.
func (dg *Discord) Close() {
	if dg == nil || dg.Session == nil {
		return
	}
	if err := dg.Session.Close(); err != nil {
		log.Printf("error closing Discord session: %v\n", err)
	}
}
// UpdateAvatar downloads the image at url and sets it as the bot user's
// avatar. The upload is skipped when the image's SHA-256 matches the last
// avatar successfully uploaded through this *Discord.
//
// An empty url or an empty response body is a no-op. An error is returned
// when:
//   - the download fails;
//   - the server replies with a non-2xx status;
//   - the body cannot be read;
//   - Discord rejects the avatar update.
func (dg *Discord) UpdateAvatar(url string) error {
	if url == "" {
		return nil
	}

	resp, err := http.Get(url)
	if err != nil {
		return fmt.Errorf("Error retrieving the file, %w", err)
	}
	defer resp.Body.Close()

	// Without this check an HTML error page would be uploaded as the avatar.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("Error retrieving the file, unexpected HTTP status %s", resp.Status)
	}

	img, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("Error reading the response, %w", err)
	}
	if len(img) == 0 {
		return nil // nothing to upload
	}

	h := sha256.Sum256(img)
	if bytes.Equal(h[:], dg.avatarHash) {
		return nil // unchanged since the last successful upload
	}

	avatar := fmt.Sprintf("data:%s;base64,%s",
		http.DetectContentType(img), base64.StdEncoding.EncodeToString(img))
	if _, err := dg.Session.UserUpdate("", "", "", avatar, ""); err != nil {
		return fmt.Errorf("Error updating the avatar, %w", err)
	}

	// Remember what was uploaded so identical future calls are skipped;
	// previously this was never recorded and every call re-uploaded.
	dg.avatarHash = h[:]
	return nil
}
// addHandlers connects the command router to the session and registers
// the built-in "help" command.
//
// The router's OnMessageCreate is installed as a session event handler
// before any route is added. That order is kept as in the original, in case
// the method value captures router state at registration time.
func (dg *Discord) addHandlers() {
	router := dg.Router
	dg.Session.AddHandler(router.OnMessageCreate)
	router.Route("help", "Display this message.", router.Help)
}