diff --git a/api/api.go b/api/api.go index 63621b7..4847d1e 100644 --- a/api/api.go +++ b/api/api.go @@ -43,6 +43,7 @@ func StartServing( postgres *db.PostgresClient, redis *db.RedisClient, clickhouse *db.ClickhouseClient, + nats *utils.NatsClient, metrics *utils.Metrics, ) error { if config.API.Host == "" || config.API.Port == "" { @@ -54,7 +55,7 @@ func StartServing( } api.router.Use(middleware.LogRequest(metrics)) - api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse)) + api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse, nats)) api.router.Use(middleware.NewRateLimiter(100, 1*time.Minute, redis)) authenticator := middleware.NewAuthenticator(config.Twitch.ClientSecret, GenericResponse) diff --git a/api/middleware/authorize.go b/api/middleware/authorize.go new file mode 100644 index 0000000..64c9506 --- /dev/null +++ b/api/middleware/authorize.go @@ -0,0 +1,46 @@ +package middleware + +import ( + "net/http" + "slices" + + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" +) + +// GetTwitchPlatformID returns the Twitch platform user ID from a user's connections, or empty string if not found. +func GetTwitchPlatformID(user *common.User) string { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + return conn.UserID + } + } + + return "" +} + +// IsChannelAuthorized returns true if the user is an admin, the broadcaster of the given channel, +// or a channel ambassador. It is used as a shared auth check across channel-scoped write routes. +func IsChannelAuthorized( + request *http.Request, + user *common.User, + channelID string, + postgres *db.PostgresClient, +) bool { + if user.Level >= int(common.ADMIN) { + return true + } + + twitchID := GetTwitchPlatformID(user) + + if twitchID == channelID { + return true + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + return false + } + + return slices.Contains(ambassadors, twitchID) +} diff --git a/api/middleware/context.go b/api/middleware/context.go index 488e46a..a773997 100644 --- a/api/middleware/context.go +++ b/api/middleware/context.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/utils" ) // ErrMissingContext is returned when a database client is not found in the request context. @@ -18,6 +19,7 @@ const ( PostgresKey contextKey = "postgres" RedisKey contextKey = "redis" ClickhouseKey contextKey = "clickhouse" + NatsKey contextKey = "nats" ) // InjectDatabases returns a middleware that injects DB clients into the request context. @@ -25,12 +27,14 @@ func InjectDatabases( postgres *db.PostgresClient, redis *db.RedisClient, clickhouse *db.ClickhouseClient, + nats *utils.NatsClient, ) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), PostgresKey, postgres) ctx = context.WithValue(ctx, RedisKey, redis) ctx = context.WithValue(ctx, ClickhouseKey, clickhouse) + ctx = context.WithValue(ctx, NatsKey, nats) next.ServeHTTP(w, r.WithContext(ctx)) }) diff --git a/api/routes/del/command_settings.go b/api/routes/del/command_settings.go new file mode 100644 index 0000000..ef603fc --- /dev/null +++ b/api/routes/del/command_settings.go @@ -0,0 +1,96 @@ +// Package del contains routes for http.MethodDelete requests. +package del + +import ( + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/command-settings", + Method: http.MethodDelete, + Handler: deleteCommandSettings, + UseAuth: true, + }) +} + +func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channelID := request.URL.Query().Get("id") + if channelID == "" { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + channelID = conn.UserID + + break + } + } + } + + command := request.URL.Query().Get("command") + + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "channel id is required"}}, + }, start) + + return + } + + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + // command is required — this endpoint resets a single command override back to defaults. + if command == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Missing required field: command"}}, + }, start) + + return + } + + if err := postgres.ResetCommandSettings(request.Context(), channelID, command); err != nil { + logger.Error.Printf("Error resetting command settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to reset command settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} diff --git a/api/routes/get/ambassadors.go b/api/routes/get/ambassadors.go new file mode 100644 index 0000000..13001bd --- /dev/null +++ b/api/routes/get/ambassadors.go @@ -0,0 +1,82 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +// AmbassadorsResponse is the response type for GET /channel/ambassadors. +type AmbassadorsResponse = common.GenericResponse[string] + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/ambassadors", + Method: http.MethodGet, + Handler: getAmbassadorsHandler, + UseAuth: true, + }) +} + +func getAmbassadorsHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, ok := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !ok { + api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channelID := resolveChannelID(request, user) + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Could not resolve channel ID"}}, + }, start) + + return + } + + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + logger.Error.Printf("Error fetching ambassadors: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch ambassadors"}}, + }, start) + + return + } + + if ambassadors == nil { + ambassadors = []string{} + } + + api.GenericResponse(writer, http.StatusOK, AmbassadorsResponse{ + Data: &ambassadors, + }, start) +} diff --git a/api/routes/get/auth.go b/api/routes/get/auth.go new file mode 100644 index 0000000..7e165e5 --- /dev/null +++ b/api/routes/get/auth.go @@ -0,0 +1,1091 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" + "github.com/Potat-Industries/potat-api/common/utils" + "github.com/google/uuid" +) + +//nolint:gosec +const ( + discordAuthorizeURL = "https://discord.com/api/oauth2/authorize" + discordTokenURL = "https://discord.com/api/oauth2/token" + discordMeURL = "https://discord.com/api/users/@me" + + spotifyAuthorizeURL = "https://accounts.spotify.com/authorize" + spotifyTokenURL = "https://accounts.spotify.com/api/token" + spotifyMeURL = "https://api.spotify.com/v1/me" + spotifyScopes = "streaming user-read-recently-played user-top-read " + + "user-read-playback-position user-read-playback-state " + + "user-read-currently-playing user-modify-playback-state user-follow-read" + + kickAuthorizeURL = "https://id.kick.com/oauth/authorize" + kickTokenURL = "https://id.kick.com/oauth/token" + kickScopes = "user:read channel:read channel:write chat:write events:subscribe" + + anilistAuthorizeURL = "https://anilist.co/api/v2/oauth/authorize" + anilistTokenURL = "https://anilist.co/api/v2/oauth/token" + anilistGQLURL = "https://graphql.anilist.co" + + traktAuthorizeURL = "https://api.trakt.tv/oauth/authorize" + traktTokenURL = "https://api.trakt.tv/oauth/token" + traktMeURL = "https://api.trakt.tv/users/me" + + ffzAuthorizeURL = "https://api.frankerfacez.com/auth/authorize" + ffzTokenURL = "https://api.frankerfacez.com/auth/token" + ffzScopes = "collection_edit" + + steamOpenIDURL = "https://steamcommunity.com/openid/login" + steamOpenIDNS = "http://specs.openid.net/auth/2.0" + steamFriendOffset = int64(76561197960265728) + + oauthStateTTL = 60 * time.Second + oauthHTTPTimeout = 10 * time.Second +) + +// oauthStates maps state strings to true with a TTL for replay-attack prevention. +var oauthStates sync.Map //nolint:gochecknoglobals + +// kickPKCEVerifiers maps OAuth state strings to PKCE code verifiers, with the same TTL as oauthStates. +var kickPKCEVerifiers sync.Map //nolint:gochecknoglobals + +func newOAuthState(userID int) string { + nonce := uuid.New().String() + state := fmt.Sprintf("%s:%d", nonce, userID) + oauthStates.Store(state, true) + time.AfterFunc(oauthStateTTL, func() { + oauthStates.Delete(state) + }) + + return state +} + +// newKickOAuthState creates an OAuth state and associated PKCE code verifier. +// State format: {nonce}:{userID}. +// The PKCE code verifier is stored server-side and must be retrieved on callback. +func newKickOAuthState(userID int) (state, codeVerifier, codeChallenge string) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + codeVerifier = uuid.New().String() + } else { + codeVerifier = hex.EncodeToString(buf) + } + + sum := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(sum[:]) + + nonce := uuid.New().String() + state = fmt.Sprintf("%s:%d", nonce, userID) + + // Store replay-protection state and PKCE verifier with the same TTL. + oauthStates.Store(state, true) + kickPKCEVerifiers.Store(state, codeVerifier) + time.AfterFunc(oauthStateTTL, func() { + oauthStates.Delete(state) + kickPKCEVerifiers.Delete(state) + }) + + return state, codeVerifier, codeChallenge +} + +// consumeOAuthState validates a state string, deletes it, and returns the embedded user ID. +func consumeOAuthState(state string) (int, bool) { + if _, ok := oauthStates.LoadAndDelete(state); !ok { + return 0, false + } + + parts := strings.SplitN(state, ":", 3) //nolint:mnd + if len(parts) < 2 { //nolint:mnd + return 0, false + } + + var userID int + if _, err := fmt.Sscanf(parts[1], "%d", &userID); err != nil { + return 0, false + } + + return userID, true +} + +// extractKickCodeVerifier retrieves the server-side PKCE code verifier for a Kick state. +func extractKickCodeVerifier(state string) string { + v, ok := kickPKCEVerifiers.Load(state) + if !ok { + return "" + } + + s, _ := v.(string) + + return s +} + +// authUser retrieves the authenticated user from the request context. +func authUser(request *http.Request) (*common.User, bool) { + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + + return user, ok && user != nil +} + +// oauthPostMessage builds the postMessage HTML used to close popups. +func oauthPostMessage(payload map[string]any) string { + data, _ := json.Marshal(payload) //nolint:errchkjson + + return fmt.Sprintf(``, string(data)) +} + +// oauthErrorHTML builds a minimal error postMessage HTML. +func oauthErrorHTML(message string) string { + return oauthPostMessage(map[string]any{"error": message}) +} + +// oauthSuccessHTML builds a minimal success postMessage HTML. +func oauthSuccessHTML(platform string) string { + return oauthPostMessage(map[string]any{"platform": platform, "ok": true}) +} + +// sendHTML writes an HTML response. +func sendHTML(writer http.ResponseWriter, code int, body string) { + writer.Header().Set("Content-Type", "text/html") + writer.WriteHeader(code) + + if _, err := writer.Write([]byte(body)); err != nil { + logger.Warn.Println("Failed to write HTML response:", err) + } +} + +// --------------------------------------------------------------------------- +// init — register all auth routes +// --------------------------------------------------------------------------- + +func init() { //nolint:funlen + // Discord + api.SetRoute(api.Route{ + Path: "/auth/discord/authorize", Method: http.MethodGet, Handler: discordAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/discord", Method: http.MethodGet, Handler: discordCallbackHandler, UseAuth: false}) + api.SetRoute(api.Route{ + Path: "/auth/discord/join", Method: http.MethodGet, Handler: discordJoinHandler, UseAuth: false, + }) + + // Spotify + api.SetRoute(api.Route{ + Path: "/auth/spotify/authorize", Method: http.MethodGet, Handler: spotifyAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/spotify", Method: http.MethodGet, Handler: spotifyCallbackHandler, UseAuth: false}) + + // Kick + api.SetRoute(api.Route{ + Path: "/auth/kick/authorize", Method: http.MethodGet, Handler: kickAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/kick", Method: http.MethodGet, Handler: kickCallbackHandler, UseAuth: false}) + + // Anilist + api.SetRoute(api.Route{ + Path: "/auth/anilist/authorize", Method: http.MethodGet, Handler: anilistAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/anilist", Method: http.MethodGet, Handler: anilistCallbackHandler, UseAuth: false}) + + // Trakt + api.SetRoute(api.Route{ + Path: "/auth/trakt/authorize", Method: http.MethodGet, Handler: traktAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/trakt", Method: http.MethodGet, Handler: traktCallbackHandler, UseAuth: false}) + + // FFZ — public, direct redirect + api.SetRoute(api.Route{ + Path: "/auth/ffz/authorize", Method: http.MethodGet, Handler: ffzAuthorizeHandler, UseAuth: false, + }) + api.SetRoute(api.Route{Path: "/auth/ffz", Method: http.MethodGet, Handler: ffzCallbackHandler, UseAuth: false}) + + // Steam — OpenID 2.0 + api.SetRoute(api.Route{ + Path: "/auth/steam/authorize", Method: http.MethodGet, Handler: steamAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/steam", Method: http.MethodGet, Handler: steamCallbackHandler, UseAuth: false}) +} + +// --------------------------------------------------------------------------- +// Discord +// --------------------------------------------------------------------------- + +func discordAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Discord.OAuthURI, "/") + "/auth/discord" + + params := url.Values{ + "client_id": {config.Discord.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {"identify"}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", discordAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func discordCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Discord.OAuthURI, "/") + "/auth/discord" + + tok, err := exchangeFormToken(request.Context(), discordTokenURL, url.Values{ + "client_id": {config.Discord.ClientID}, + "client_secret": {config.Discord.ClientSecret}, + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + }) + if err != nil { + logger.Error.Println("Discord token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + discordID, err := fetchDiscordUserID(request.Context(), tok.AccessToken) + if err != nil { + logger.Error.Println("Discord user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Discord user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + _ = userID // stored in state; could be used to link to internal user row + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), discordID, common.DISCORD, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Discord OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("discord")) +} + +func discordJoinHandler(writer http.ResponseWriter, request *http.Request) { + config := utils.LoadConfig() + + params := url.Values{ + "client_id": {config.Discord.ClientID}, + "permissions": {"414464674880"}, + "scope": {"bot identify messages.read guilds.members.read"}, + } + + target := fmt.Sprintf("https://discord.com/oauth2/authorize?%s", params.Encode()) + http.Redirect(writer, request, target, http.StatusFound) +} + +// --------------------------------------------------------------------------- +// Spotify +// --------------------------------------------------------------------------- + +func spotifyAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Spotify.OAuthURI, "/") + "/auth/spotify" + + params := url.Values{ + "client_id": {config.Spotify.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {spotifyScopes}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", spotifyAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func spotifyCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Spotify.OAuthURI, "/") + "/auth/spotify" + + tok, err := exchangeFormToken(request.Context(), spotifyTokenURL, url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + }, withBasicAuth(config.Spotify.ClientID, config.Spotify.ClientSecret)) + if err != nil { + logger.Error.Println("Spotify token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + spotifyID, err := fetchJSONField( + request.Context(), spotifyMeURL, + map[string]string{"Authorization": "Bearer " + tok.AccessToken}, + "id", + ) + if err != nil { + logger.Error.Println("Spotify me fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Spotify user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), spotifyID, common.SPOTIFY, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Spotify OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("spotify")) +} + +// --------------------------------------------------------------------------- +// Kick +// --------------------------------------------------------------------------- + +func kickAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state, _, codeChallenge := newKickOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Kick.OAuthURI, "/") + "/auth/kick" + + params := url.Values{ + "client_id": {config.Kick.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {kickScopes}, + "state": {state}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + } + + target := fmt.Sprintf("%s?%s", kickAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func kickCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + codeVerifier := extractKickCodeVerifier(state) + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Kick.OAuthURI, "/") + "/auth/kick" + + tok, err := exchangeFormToken(request.Context(), kickTokenURL, url.Values{ + "client_id": {config.Kick.ClientID}, + "client_secret": {config.Kick.ClientSecret}, + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "code_verifier": {codeVerifier}, + }) + if err != nil { + logger.Error.Println("Kick token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + kickID, err := fetchJSONField( + request.Context(), "https://id.kick.com/oauth/user-info", + map[string]string{"Authorization": "Bearer " + tok.AccessToken}, + "sub", + ) + if err != nil { + logger.Error.Println("Kick user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Kick user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), kickID, common.KICK, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Kick OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("kick")) +} + +// --------------------------------------------------------------------------- +// AniList +// --------------------------------------------------------------------------- + +func anilistAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Anilist.OAuthURI, "/") + "/auth/anilist" + + params := url.Values{ + "client_id": {config.Anilist.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", anilistAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func anilistCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Anilist.OAuthURI, "/") + "/auth/anilist" + + tok, err := exchangeJSONToken(request.Context(), anilistTokenURL, map[string]any{ + "grant_type": "authorization_code", + "client_id": config.Anilist.ClientID, + "client_secret": config.Anilist.ClientSecret, + "redirect_uri": redirectURI, + "code": code, + }) + if err != nil { + logger.Error.Println("Anilist token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + anilistID, err := fetchAnilistUserID(request.Context(), tok.AccessToken) + if err != nil { + logger.Error.Println("Anilist user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Anilist user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), anilistID, common.ANILIST, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Anilist OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("anilist")) +} + +// --------------------------------------------------------------------------- +// Trakt +// --------------------------------------------------------------------------- + +func traktAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Trakt.OAuthURI, "/") + "/auth/trakt" + + params := url.Values{ + "client_id": {config.Trakt.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", traktAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func traktCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Trakt.OAuthURI, "/") + "/auth/trakt" + + tok, err := exchangeJSONToken(request.Context(), traktTokenURL, map[string]any{ + "code": code, + "client_id": config.Trakt.ClientID, + "client_secret": config.Trakt.ClientSecret, + "redirect_uri": redirectURI, + "grant_type": "authorization_code", + }) + if err != nil { + logger.Error.Println("Trakt token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + traktID, err := fetchJSONField( + request.Context(), traktMeURL, + map[string]string{ + "Authorization": "Bearer " + tok.AccessToken, + "trakt-api-version": "2", + "trakt-api-key": config.Trakt.ClientID, + }, + "ids.slug", + ) + if err != nil { + logger.Error.Println("Trakt user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Trakt user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), traktID, common.TRAKT, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Trakt OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("trakt")) +} + +// --------------------------------------------------------------------------- +// FFZ — no JWT, no state, direct redirect +// --------------------------------------------------------------------------- + +func ffzAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.FFZ.OAuthURI, "/") + "/auth/ffz" + + params := url.Values{ + "client_id": {config.FFZ.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {ffzScopes}, + } + + http.Redirect(writer, request, fmt.Sprintf("%s?%s", ffzAuthorizeURL, params.Encode()), http.StatusFound) +} + +func ffzCallbackHandler(writer http.ResponseWriter, request *http.Request) { + code := request.URL.Query().Get("code") + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.FFZ.OAuthURI, "/") + "/auth/ffz" + + tok, err := exchangeFormToken(request.Context(), ffzTokenURL, url.Values{ + "client_id": {config.FFZ.ClientID}, + "client_secret": {config.FFZ.ClientSecret}, + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + }) + if err != nil { + logger.Error.Println("FFZ token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("FFZ token exchange failed")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), config.FFZ.ID, common.FFZ, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert FFZ OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("ffz")) +} + +// --------------------------------------------------------------------------- +// Steam — OpenID 2.0 +// --------------------------------------------------------------------------- + +func steamAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + // Derive the base URL from the incoming request to ensure Steam OpenID + // callbacks use the correct host and scheme for this API instance. + scheme := "https" + if forwarded := request.Header.Get("X-Forwarded-Proto"); forwarded != "" { + // Use the first value if multiple are provided. + scheme = strings.Split(forwarded, ",")[0] + } else if request.TLS == nil { + scheme = "http" + } + baseURL := fmt.Sprintf("%s://%s/", scheme, request.Host) + returnTo := fmt.Sprintf("%sauth/steam?user_id=%d", baseURL, user.ID) + realm := baseURL + + params := url.Values{ + "openid.ns": {steamOpenIDNS}, + "openid.mode": {"checkid_setup"}, + "openid.return_to": {returnTo}, + "openid.realm": {realm}, + "openid.identity": {"http://specs.openid.net/auth/2.0/identifier_select"}, + "openid.claimed_id": {"http://specs.openid.net/auth/2.0/identifier_select"}, + } + + target := fmt.Sprintf("%s?%s", steamOpenIDURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func steamCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + + // Re-verify the OpenID assertion with Steam + verifyParams := url.Values{} + maps.Copy(verifyParams, query) + verifyParams.Set("openid.mode", "check_authentication") + + ctx := request.Context() + client := &http.Client{ + Timeout: 5 * time.Second, + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + steamOpenIDURL, + strings.NewReader(verifyParams.Encode()), + ) + if err != nil { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Steam verification failed")) + + return + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam verification failed")) + + return + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam verification failed")) + + return + } + + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "is_valid:true") { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam assertion invalid")) + + return + } + + claimedID := query.Get("openid.claimed_id") + // claimed_id ends with the 64-bit SteamID, e.g. https://steamcommunity.com/openid/id/76561198... + parts := strings.Split(claimedID, "/") + steamID64Str := parts[len(parts)-1] + + var steamID64 int64 + if _, err := fmt.Sscanf(steamID64Str, "%d", &steamID64); err != nil { + sendHTML(writer, http.StatusBadRequest, oauthErrorHTML("Invalid Steam ID")) + + return + } + steamFriendCode := fmt.Sprintf("%d", steamID64-steamFriendOffset) + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + if err := postgres.UpsertOAuthToken( + request.Context(), steamID64Str, common.STEAM, + "", "", []string{}, 0, + ); err != nil { + logger.Warn.Println("Failed to upsert Steam token:", err) + } + + sendHTML(writer, http.StatusOK, oauthPostMessage(map[string]any{ + "platform": "steam", + "ok": true, + "steam_id": steamID64Str, + "friend_code": steamFriendCode, + })) +} + +// --------------------------------------------------------------------------- +// OAuth helpers +// --------------------------------------------------------------------------- + +type simpleTokenResponse struct { //nolint:govet + AccessToken string `json:"access_token"` //nolint:gosec + RefreshToken string `json:"refresh_token"` //nolint:gosec + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type exchangeOption func(*http.Request) + +func withBasicAuth(clientID, clientSecret string) exchangeOption { + return func(req *http.Request) { + req.SetBasicAuth(clientID, clientSecret) + } +} + +// exchangeFormToken does a standard application/x-www-form-urlencoded code exchange. +func exchangeFormToken( + ctx context.Context, + tokenURL string, + params url.Values, + opts ...exchangeOption, +) (*simpleTokenResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(params.Encode())) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + for _, opt := range opts { + opt(req) + } + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return nil, err + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode >= 400 { //nolint:mnd + b, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(b)) //nolint:err113 + } + + var tok simpleTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tok); err != nil { + return nil, err + } + + return &tok, nil +} + +// exchangeJSONToken does a code exchange with application/json body. +func exchangeJSONToken(ctx context.Context, tokenURL string, body map[string]any) (*simpleTokenResponse, error) { + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return nil, err + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode >= 400 { //nolint:mnd + b, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(b)) //nolint:err113 + } + + var tok simpleTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tok); err != nil { + return nil, err + } + + return &tok, nil +} + +// fetchJSONField makes an authenticated request and returns a top-level string field. +// Nested fields can be accessed with dot notation (e.g. "ids.slug"). +func fetchJSONField( + ctx context.Context, + endpoint string, + headers map[string]string, + field string, +) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return "", err + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return "", err + } + defer resp.Body.Close() //nolint:errcheck + + var data map[string]any + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return "", err + } + + parts := strings.SplitN(field, ".", 2) //nolint:mnd + val, exists := data[parts[0]] + + if !exists { + return "", fmt.Errorf("field %q not found in response", parts[0]) //nolint:err113 + } + + if len(parts) == 2 { //nolint:mnd + nested, ok := val.(map[string]any) + if !ok { + return "", fmt.Errorf("field %q is not an object", parts[0]) //nolint:err113 + } + + val, exists = nested[parts[1]] + if !exists { + return "", fmt.Errorf("field %q not found in nested object", parts[1]) //nolint:err113 + } + } + + return fmt.Sprintf("%v", val), nil +} + +// fetchDiscordUserID calls /users/@me and returns the Discord user ID. +func fetchDiscordUserID(ctx context.Context, accessToken string) (string, error) { + return fetchJSONField( + ctx, discordMeURL, + map[string]string{"Authorization": "Bearer " + accessToken}, + "id", + ) +} + +// fetchAnilistUserID queries the AniList GraphQL API for the authenticated user's ID. +func fetchAnilistUserID(ctx context.Context, accessToken string) (string, error) { + query := `{"query":"{Viewer{id}}"}` + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, anilistGQLURL, strings.NewReader(query)) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return "", err + } + defer resp.Body.Close() //nolint:errcheck + + var result struct { + Data struct { + Viewer struct { + ID int `json:"id"` + } `json:"Viewer"` + } `json:"data"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + return fmt.Sprintf("%d", result.Data.Viewer.ID), nil +} diff --git a/api/routes/get/channels.go b/api/routes/get/channels.go new file mode 100644 index 0000000..17f94f3 --- /dev/null +++ b/api/routes/get/channels.go @@ -0,0 +1,57 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +// ChannelsResponse is the response type for the GET /channels endpoint. +type ChannelsResponse = common.GenericResponse[common.ChannelListItem] + +func init() { + api.SetRoute(api.Route{ + Path: "/channels", + Method: http.MethodGet, + Handler: getChannelsHandler, + UseAuth: false, + }) +} + +func getChannelsHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + postgres, ok := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !ok { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, ChannelsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channels, err := postgres.GetAllChannels(request.Context()) + if err != nil { + logger.Error.Printf("Error fetching channels: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, ChannelsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch channels"}}, + }, start) + + return + } + + if channels == nil { + channels = []common.ChannelListItem{} + } + + api.GenericResponse(writer, http.StatusOK, ChannelsResponse{ + Data: &channels, + }, start) +} diff --git a/api/routes/get/command_settings.go b/api/routes/get/command_settings.go new file mode 100644 index 0000000..4253aad --- /dev/null +++ b/api/routes/get/command_settings.go @@ -0,0 +1,98 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" + "github.com/gorilla/mux" +) + +// CommandSettingsResponse is the response type for GET /channel/command-settings. +type CommandSettingsResponse = common.GenericResponse[common.CommandSettings] + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/command-settings", + Method: http.MethodGet, + Handler: getCommandSettingsHandler, + UseAuth: true, + }) +} + +// resolveChannelID returns the channel ID from ?id= or defaults to the authenticated user's Twitch ID. +func resolveChannelID(request *http.Request, user *common.User) string { + if id := request.URL.Query().Get("id"); id != "" { + return id + } + + // Check path variable as fallback (e.g. /channel/:id/...) + if id := mux.Vars(request)["id"]; id != "" { + return id + } + + return middleware.GetTwitchPlatformID(user) +} + +func getCommandSettingsHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channelID := resolveChannelID(request, user) + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "channel id is required"}}, + }, start) + + return + } + + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + settings, err := postgres.GetCommandSettings(request.Context(), channelID) + if err != nil { + logger.Error.Printf("Error fetching command settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch command settings"}}, + }, start) + + return + } + + if settings == nil { + settings = []common.CommandSettings{} + } + + api.GenericResponse(writer, http.StatusOK, CommandSettingsResponse{ + Data: &settings, + }, start) +} diff --git a/api/routes/get/emotes.go b/api/routes/get/emotes.go new file mode 100644 index 0000000..9e7b06a --- /dev/null +++ b/api/routes/get/emotes.go @@ -0,0 +1,269 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "encoding/base64" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" + "github.com/gorilla/mux" +) + +const ( + defaultEmoteLimit = 100 + maxEmoteLimit = 300 + maxEmoteOffset = 10_000 +) + +func init() { + api.SetRoute(api.Route{ + Path: "/emotes/stats", + Method: http.MethodGet, + Handler: getEmoteStats, + UseAuth: false, + }) + api.SetRoute(api.Route{ + Path: "/emotes/history/{login}", + Method: http.MethodGet, + Handler: getEmoteHistory, + UseAuth: false, + }) +} + +// periodToHours converts a period string to a number of hours. 0 means no time filter. +func periodToHours(period string) int { + switch strings.ToLower(period) { + case "hour": + return 1 + case "day": + return 24 + case "week": + return 168 + case "month": + return 720 + default: + return 0 + } +} + +// normaliseProvider maps user-facing provider names to Clickhouse enum values. +func normaliseProvider(provider string) []string { + switch strings.ToUpper(provider) { + case "7TV", "STV": + return []string{"STV"} + case "FFZ": + return []string{"FFZ"} + case "BTTV": + return []string{"BTTV"} + case "TWITCH": + return []string{"TWITCH"} + case "EMOJI": + return []string{"EMOJI"} + case "ALL", "": + return []string{} + default: + // For unknown/unsupported providers, apply no provider filter. + return []string{} + } +} + +func encodeCursor(offset int) string { + return base64.StdEncoding.EncodeToString([]byte(strconv.Itoa(offset))) +} + +func decodeCursor(cursor string) (int, error) { + b, err := base64.StdEncoding.DecodeString(cursor) + if err != nil { + return 0, err + } + + return strconv.Atoi(string(b)) +} + +func getEmoteStats(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + clickhouse, ok := request.Context().Value(middleware.ClickhouseKey).(*db.ClickhouseClient) + if !ok { + logger.Error.Println("Clickhouse client not found in context") + writeEmoteError(writer, http.StatusInternalServerError, start) + + return + } + + query := request.URL.Query() + + channelID := query.Get("id") + login := query.Get("login") + + // If login was provided, resolve to a channel ID via Postgres + if channelID == "" && login != "" { + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + writeEmoteError(writer, http.StatusInternalServerError, start) + + return + } + + ch, err := postgres.GetChannelByName(request.Context(), login, common.TWITCH) + if err != nil { + writeEmoteError(writer, http.StatusNotFound, start) + + return + } + + channelID = ch.ChannelID + } + + if channelID == "" { + writeEmoteError(writer, http.StatusBadRequest, start) + + return + } + + limit := defaultEmoteLimit + if v := query.Get("limit"); v == "" { + if v = query.Get("first"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= maxEmoteLimit { + limit = n + } + } + } else if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= maxEmoteLimit { + limit = n + } + + offset := 0 + if cursor := query.Get("after"); cursor != "" { + if o, err := decodeCursor(cursor); err == nil && o >= 0 && o <= maxEmoteOffset { + offset = o + } else { + writeEmoteError(writer, http.StatusBadRequest, start) + + return + } + } + + opts := db.EmoteStatsOptions{ + ChannelID: channelID, + PeriodHours: periodToHours(query.Get("period")), + Providers: normaliseProvider(query.Get("provider")), + Order: query.Get("order"), + Limit: limit + 1, // fetch one extra to determine hasNextPage + Offset: offset, + } + + stats, err := clickhouse.GetEmoteStats(request.Context(), opts) + if err != nil { + logger.Error.Printf("Error fetching emote stats: %v", err) + writeEmoteError(writer, http.StatusInternalServerError, start) + + return + } + + if stats == nil { + stats = []common.EmoteStat{} + } + + hasNextPage := len(stats) > limit + if hasNextPage { + stats = stats[:limit] + } + + var nextCursor string + if hasNextPage { + nextCursor = encodeCursor(offset + limit) + } + + elapsed := time.Since(start).Seconds() + api.GenericResponse(writer, http.StatusOK, common.EmoteStatsResponse{ + Data: &stats, + Pagination: common.PageInfo{ + HasNextPage: hasNextPage, + Cursor: nextCursor, + }, + StatusCode: http.StatusOK, + Duration: elapsed, + }, start) +} + +func getEmoteHistory(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + login := mux.Vars(request)["login"] + if login == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "login is required"}}, + }, start) + + return + } + + clickhouse, ok := request.Context().Value(middleware.ClickhouseKey).(*db.ClickhouseClient) + if !ok { + logger.Error.Println("Clickhouse client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + // Resolve login → channel ID + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + ch, err := postgres.GetChannelByName(request.Context(), login, common.TWITCH) + if err != nil { + api.GenericResponse(writer, http.StatusNotFound, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Channel not found"}}, + }, start) + + return + } + + limit := defaultEmoteLimit + if v := request.URL.Query().Get("limit"); v != "" { + if n, parseErr := strconv.Atoi(v); parseErr == nil && n > 0 && n <= maxEmoteLimit { + limit = n + } + } + + entries, err := clickhouse.GetEmoteHistory(request.Context(), "", ch.ChannelID, limit) + if err != nil { + logger.Error.Printf("Error fetching emote history: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch emote history"}}, + }, start) + + return + } + + if entries == nil { + entries = []common.EmoteHistoryEntry{} + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[common.EmoteHistoryEntry]{ + Data: &entries, + }, start) +} + +func writeEmoteError(writer http.ResponseWriter, code int, start time.Time) { + api.GenericResponse(writer, code, common.EmoteStatsResponse{ + Data: &[]common.EmoteStat{}, + StatusCode: code, + Duration: time.Since(start).Seconds(), + }, start) +} diff --git a/api/routes/get/help.go b/api/routes/get/help.go index bea0775..e1047ae 100644 --- a/api/routes/get/help.go +++ b/api/routes/get/help.go @@ -113,7 +113,18 @@ func getCommandsHandler(writer http.ResponseWriter, request *http.Request) { } writer.Header().Set("X-Cache-Hit", "MISS") - response, err := utils.BridgeRequest( + nats, ok := request.Context().Value(middleware.NatsKey).(*utils.NatsClient) + if !ok || nats == nil { + logger.Error.Println("NATS client not found in context") + api.GenericResponse(writer, http.StatusServiceUnavailable, HelpResponse{ + Data: &[]common.Command{}, + Errors: &[]common.ErrorMessage{{Message: "Service unavailable"}}, + }, start) + + return + } + + response, err := nats.BridgeRequest( 5*time.Second, "get-commands", ) diff --git a/api/routes/patch/settings.go b/api/routes/patch/settings.go new file mode 100644 index 0000000..0e84643 --- /dev/null +++ b/api/routes/patch/settings.go @@ -0,0 +1,161 @@ +// Package patch contains routes for http.MethodPatch requests. +package patch + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +func init() { + api.SetRoute(api.Route{ + Path: "/users/me/settings", + Method: http.MethodPatch, + Handler: patchUserSettings, + UseAuth: true, + }) + api.SetRoute(api.Route{ + Path: "/channels/me/settings", + Method: http.MethodPatch, + Handler: patchChannelSettings, + UseAuth: true, + }) +} + +func patchUserSettings(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + input := user.Settings + if err := json.NewDecoder(request.Body).Decode(&input); err != nil { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + if err := postgres.UpdateUserSettings(request.Context(), user.ID, input); err != nil { + logger.Error.Printf("Error updating user settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to update settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} + +func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + // Resolve channel ID and platform from query params or fall back to the users own Twitch channel. + channelID := request.URL.Query().Get("id") + platform := request.URL.Query().Get("platform") + if platform == "" { + platform = string(common.TWITCH) + } + if channelID == "" { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + channelID = conn.UserID + + break + } + } + } + + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "channel id is required"}}, + }, start) + + return + } + + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + existingSettings, err := postgres.GetChannelSettingsByID(request.Context(), channelID, platform) + if err != nil { + logger.Error.Printf("Error fetching channel settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + input := existingSettings + if err := json.NewDecoder(request.Body).Decode(&input); err != nil { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, + }, start) + + return + } + + if err := postgres.UpdateChannelSettings(request.Context(), channelID, platform, input); err != nil { + logger.Error.Printf("Error updating channel settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to update settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} diff --git a/api/routes/put/command_settings.go b/api/routes/put/command_settings.go new file mode 100644 index 0000000..9afcc99 --- /dev/null +++ b/api/routes/put/command_settings.go @@ -0,0 +1,101 @@ +// Package put contains routes for http.MethodPut requests. +package put + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/command-settings", + Method: http.MethodPut, + Handler: putCommandSettings, + UseAuth: true, + }) +} + +func putCommandSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + var input common.CommandSettings + if err := json.NewDecoder(request.Body).Decode(&input); err != nil { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, + }, start) + + return + } + + channelID := input.ChannelID + if channelID == "" { + channelID = request.URL.Query().Get("id") + } + + if channelID == "" { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + channelID = conn.UserID + + break + } + } + } + + if channelID == "" || input.Command == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "channel_id and command are required"}}, + }, start) + + return + } + + input.ChannelID = channelID + + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + if err := postgres.UpsertCommandSettings(request.Context(), input); err != nil { + logger.Error.Printf("Error upserting command settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to update command settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} diff --git a/common/config.go b/common/config.go index 2d6a012..85fbfd3 100644 --- a/common/config.go +++ b/common/config.go @@ -3,18 +3,27 @@ package common // Config holds the configuration for the application, including database and service settings. type Config struct { - Postgres SQLConfig `json:"postgres"` - Clickhouse SQLConfig `json:"clickhouse"` - Twitch TwitchConfig `json:"twitch"` - Redis RedisConfig `json:"redis"` - API APIConfig `json:"api"` - Socket APIConfig `json:"socket"` - Redirects APIConfig `json:"redirects"` - Uploader APIConfig `json:"uploader"` - Prometheus APIConfig `json:"prometheus"` - Haste HasteConfig `json:"haste"` - Nats BoolConfig `json:"nats"` - Loops BoolConfig `json:"loops"` + Postgres SQLConfig `json:"postgres"` + Clickhouse SQLConfig `json:"clickhouse"` + Twitch TwitchConfig `json:"twitch"` + Discord DiscordConfig `json:"discord"` + Spotify SpotifyConfig `json:"spotify"` + Kick KickConfig `json:"kick"` + Anilist AnilistConfig `json:"anilist"` + Trakt TraktConfig `json:"trakt"` + FFZ FFZConfig `json:"ffz"` + STV STVConfig `json:"stv"` + BTTV BTTVConfig `json:"bttv"` + Misc MiscConfig `json:"misc"` + Redis RedisConfig `json:"redis"` + API APIConfig `json:"api"` + Socket APIConfig `json:"socket"` + Redirects APIConfig `json:"redirects"` + Uploader APIConfig `json:"uploader"` + Prometheus APIConfig `json:"prometheus"` + Haste HasteConfig `json:"haste"` + Nats BoolConfig `json:"nats"` + Loops BoolConfig `json:"loops"` } // TwitchConfig holds the configuration for Twitch API integration. @@ -24,6 +33,77 @@ type TwitchConfig struct { OauthURI string `json:"oauth_uri"` } +// DiscordConfig holds the configuration for Discord OAuth and bot integration. +type DiscordConfig struct { + ID string `json:"id"` + OAuth string `json:"oauth"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// SpotifyConfig holds the configuration for Spotify OAuth integration. +type SpotifyConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// KickConfig holds the configuration for Kick OAuth integration (PKCE flow). +type KickConfig struct { + ID string `json:"id"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` + OAuth string `json:"oauth"` +} + +// AnilistConfig holds the configuration for AniList OAuth integration. +type AnilistConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// TraktConfig holds the configuration for Trakt OAuth integration. +type TraktConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// FFZConfig holds the configuration for FrankerFaceZ OAuth integration. +type FFZConfig struct { + ID string `json:"id"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + Token string `json:"token"` + Refresh string `json:"refresh"` + OAuthURI string `json:"oauth_uri"` +} + +// STVConfig holds the configuration for 7TV bot token. +type STVConfig struct { + Token string `json:"token"` + ID string `json:"id"` +} + +// BTTVConfig holds the configuration for BTTV bot token. +type BTTVConfig struct { + Token string `json:"token"` + ID string `json:"id"` +} + +// SteamConfig holds the Steam Web API key. +type SteamConfig struct { + Key string `json:"key"` +} + +// MiscConfig holds miscellaneous third-party service configurations. +type MiscConfig struct { + Steam SteamConfig `json:"steam"` +} + // BoolConfig holds the configuration for the simple enablable service. type BoolConfig struct { Enabled bool `json:"enabled"` diff --git a/common/db/clickhouse.go b/common/db/clickhouse.go index 298a14f..a135693 100644 --- a/common/db/clickhouse.go +++ b/common/db/clickhouse.go @@ -2,7 +2,10 @@ package db import ( + "context" "fmt" + "strings" + "time" "github.com/ClickHouse/clickhouse-go/v2" "github.com/ClickHouse/clickhouse-go/v2/lib/driver" @@ -45,3 +48,151 @@ func InitClickhouse(config common.Config) (*ClickhouseClient, error) { return &ClickhouseClient{conn}, nil } + +// EmoteStatsOptions holds the query filters for GetEmoteStats. +type EmoteStatsOptions struct { + // ChannelID filters results to a specific channel (uses emote_usage table when UserID is empty). + ChannelID string + // UserID filters results to a specific user (uses user_emote_usage table). + UserID string + // Order is "ASC" or "DESC". + Order string + // Providers is the normalised list of provider enum values (e.g. "STV", "FFZ"). Empty = all. + Providers []string + // PeriodHours is the time window in hours. 0 means no time filter ("all"). + PeriodHours int + // Limit is the maximum number of rows to return. + Limit int + // Offset is the number of rows to skip (cursor-based pagination). + Offset int +} + +// GetEmoteStats queries aggregated emote usage from Clickhouse with cursor-based pagination. +func (db *ClickhouseClient) GetEmoteStats( //nolint:cyclop + ctx context.Context, + opts EmoteStatsOptions, +) ([]common.EmoteStat, error) { + table := "potatbotat.emote_usage" + if opts.UserID != "" { + table = "potatbotat.user_emote_usage" + } + + var sb strings.Builder + args := make([]any, 0, 8) + + fmt.Fprintf(&sb, + "SELECT emote_id, emote_name, emote_alias, provider, sum(count) AS count FROM %s FINAL WHERE 1=1", + table, + ) + + if opts.ChannelID != "" { + args = append(args, opts.ChannelID) + sb.WriteString(" AND channel_id = ?") + } + + if opts.UserID != "" { + args = append(args, opts.UserID) + sb.WriteString(" AND user_id = ?") + } + + if opts.PeriodHours > 0 { + cutoff := time.Now().Add(-time.Duration(opts.PeriodHours) * time.Hour) + args = append(args, cutoff) + sb.WriteString(" AND used_at >= ?") + } + + if len(opts.Providers) > 0 { + for _, p := range opts.Providers { + args = append(args, p) + } + placeholders := strings.TrimSuffix(strings.Repeat("?,", len(opts.Providers)), ",") + fmt.Fprintf(&sb, " AND provider IN (%s)", placeholders) + } + + sb.WriteString(" GROUP BY emote_id, emote_name, emote_alias, provider") + + order := "DESC" + if strings.EqualFold(opts.Order, "asc") { + order = "ASC" + } + fmt.Fprintf(&sb, " ORDER BY count %s", order) + + limit := max(1, min(opts.Limit, 300)) + 1 + fmt.Fprintf(&sb, " LIMIT %d", limit) + + if opts.Offset > 0 { + fmt.Fprintf(&sb, " OFFSET %d", opts.Offset) + } + + rows, err := db.Query(ctx, sb.String(), args...) + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck + + var stats []common.EmoteStat + for rows.Next() { + var s common.EmoteStat + if err := rows.Scan(&s.EmoteID, &s.EmoteName, &s.EmoteAlias, &s.Provider, &s.Count); err != nil { + return nil, err + } + stats = append(stats, s) + } + + return stats, rows.Err() +} + +// GetEmoteHistory queries the most recent per-user emote usage records from Clickhouse. +func (db *ClickhouseClient) GetEmoteHistory( + ctx context.Context, + userID string, + channelID string, + limit int, +) ([]common.EmoteHistoryEntry, error) { + if limit <= 0 || limit > 300 { + limit = 100 + } + + var sb strings.Builder + args := make([]any, 0, 3) + + sb.WriteString(` + SELECT emote_id, emote_name, emote_alias, provider, channel_id, user_id, sum(count) AS count, max(used_at) AS used_at + FROM potatbotat.user_emote_usage FINAL + WHERE 1=1 + `) + + if userID != "" { + args = append(args, userID) + sb.WriteString(" AND user_id = ?") + } + + if channelID != "" { + args = append(args, channelID) + sb.WriteString(" AND channel_id = ?") + } + + sb.WriteString(" GROUP BY emote_id, emote_name, emote_alias, provider, channel_id, user_id") + sb.WriteString(" ORDER BY used_at DESC") + fmt.Fprintf(&sb, " LIMIT %d", limit) + + rows, err := db.Query(ctx, sb.String(), args...) + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck + + var entries []common.EmoteHistoryEntry + for rows.Next() { + var e common.EmoteHistoryEntry + if err := rows.Scan( + &e.EmoteID, &e.EmoteName, &e.EmoteAlias, &e.Provider, + &e.ChannelID, &e.UserID, &e.Count, &e.UsedAt, + ); err != nil { + return nil, err + } + entries = append(entries, e) + } + + return entries, rows.Err() +} diff --git a/common/db/loops.go b/common/db/loops.go index 712ad4d..90e03d4 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -1,4 +1,5 @@ // Package db provides database clients and functions to retrieve or update data. + package db import ( @@ -21,29 +22,39 @@ import ( ) var ( - errNoRows = fmt.Errorf("no rows returned for database size query") + errNoRows = fmt.Errorf("no rows returned for database size query") + errMissingRefreshToken = errors.New("missing refresh token") ) const dumpPath = "./dump" // StartLoops initializes schedules and loops for various tasks. + func StartLoops( ctx context.Context, + config common.Config, + natsClient *utils.NatsClient, + postgres *PostgresClient, + clickhouse *ClickhouseClient, + redis *RedisClient, ) { if !config.Loops.Enabled { return } + cronManager := cron.New() var err error + _, err = cronManager.AddFunc("@hourly", func() { go updateHourlyUsage(ctx, postgres) + go validateTokens(ctx, config, postgres) }) if err != nil { @@ -51,6 +62,7 @@ func StartLoops( return } + _, err = cronManager.AddFunc("@daily", func() { updateDailyUsage(ctx, postgres) }) @@ -59,6 +71,7 @@ func StartLoops( return } + _, err = cronManager.AddFunc("@weekly", func() { updateWeeklyUsage(ctx, postgres) }) @@ -67,6 +80,7 @@ func StartLoops( return } + _, err = cronManager.AddFunc("0 */2 * * *", func() { refreshAllHelixTokens(ctx, config, postgres) }) @@ -75,10 +89,14 @@ func StartLoops( return } + _, err = cronManager.AddFunc("*/30 * * * *", func() { updateColorView(ctx, clickhouse) + updateActiveBadgeView(ctx, clickhouse) + updateOwnedBadgeView(ctx, clickhouse) + updateUserOwnedBadgeView(ctx, clickhouse) }) if err != nil { @@ -86,8 +104,10 @@ func StartLoops( return } + _, err = cronManager.AddFunc("0 */12 * * *", func() { go backupPostgres(ctx, postgres, natsClient, config) + // go optimizeClickhouse(ctx, config, clickhouse) }) if err != nil { @@ -99,13 +119,16 @@ func StartLoops( cronManager.Start() go decrementDuels(ctx, redis) + go deleteOldUploads(ctx, postgres) + go updateAggregateTable(ctx, postgres) } func decrementDuels(ctx context.Context, redis *RedisClient) { for { time.Sleep(30 * time.Minute) + logger.Info.Println("Decrementing duels") keys, err := redis.Scan(ctx, "duelUse:*", 100, 0) @@ -120,16 +143,27 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { } luaScript := ` + local decrementedKeys = 0 + for _, key in ipairs(KEYS) do + local num = redis.call("DECR", key) + decrementedKeys = decrementedKeys + 1 + if num <= 0 then + redis.call("DEL", key) + end + end + + return decrementedKeys + ` value, err := redis.Eval(ctx, luaScript, keys).Result() @@ -144,12 +178,17 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { for { time.Sleep(24 * time.Hour) + logger.Info.Println("Deleting old uploads") query := ` + DELETE FROM file_store + WHERE (created_at < NOW() - INTERVAL '30 days' AND expires_at IS NULL) + OR (expires_at IS NOT NULL AND expires_at < NOW()); + ` _, err := postgres.Exec(ctx, query) @@ -166,12 +205,19 @@ func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { time.Sleep(5 * time.Minute) query := ` + INSERT INTO channel_command_usage (channel_id, channel_usage) + SELECT channel_id, SUM(channel_usage) AS channel_usage + FROM command_settings + GROUP BY channel_id + ON CONFLICT (channel_id) DO UPDATE + SET channel_usage = EXCLUDED.channel_usage; + ` _, err := postgres.Exec(ctx, query) @@ -185,6 +231,7 @@ func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { func updateHourlyUsage(ctx context.Context, postgres *PostgresClient) { logger.Info.Println("Updating hourly usage") + query := `UPDATE gpt_usage SET hourly_usage = 0;` _, err := postgres.Exec(ctx, query) @@ -197,6 +244,7 @@ func updateHourlyUsage(ctx context.Context, postgres *PostgresClient) { func updateDailyUsage(ctx context.Context, postgres *PostgresClient) { logger.Info.Println("Updating daily usage") + query := `UPDATE gpt_usage SET daily_usage = 0` _, err := postgres.Exec(ctx, query) @@ -209,6 +257,7 @@ func updateDailyUsage(ctx context.Context, postgres *PostgresClient) { func updateWeeklyUsage(ctx context.Context, postgres *PostgresClient) { logger.Info.Println("Updating weekly usage") + query := `UPDATE gpt_usage SET weekly_usage = 0` _, err := postgres.Exec(ctx, query) @@ -223,17 +272,29 @@ func updateColorView(ctx context.Context, clickhouse *ClickhouseClient) { logger.Info.Println("Updating color view") query := ` + INSERT INTO potatbotat.twitch_color_stats + SELECT + color, + COUNT(DISTINCT user_id) AS user_count, + (COUNT(DISTINCT user_id) * 100.0) / ( + SELECT COUNT(user_id) + FROM potatbotat.twitch_colors + ) AS percentage, + ROW_NUMBER() OVER (ORDER BY COUNT(DISTINCT user_id) DESC) AS rank + FROM potatbotat.twitch_colors FINAL + GROUP BY color; + ` err := clickhouse.Exec(ctx, query) @@ -253,14 +314,23 @@ func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { } query := ` + INSERT INTO potatbotat.twitch_active_badge_stats + SELECT + badge, + count(user_id) AS user_count, + version + FROM potatbotat.twitch_badges + WHERE badge NOT IN ('', 'NOBADGE') + GROUP BY (badge, version); + ` err = clickhouse.Exec(ctx, query) @@ -273,19 +343,30 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { logger.Info.Println("Updating owned badge view") // Insert owned badges from active table first + // prepare := ` + // INSERT INTO potatbotat.twitch_owned_badges + // SELECT + // badge, + // user_id, + // version + // FROM potatbotat.twitch_badges + // WHERE badge NOT IN ('', 'NOBADGE') + // ` // err := clickhouse.Exec(ctx, prepare) + // if err != nil { // logger.Error.Println("Error preparing badge view ", err) + // } err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_stats;`) @@ -296,14 +377,23 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { } query := ` + INSERT INTO potatbotat.twitch_owned_badge_stats + SELECT + badge, + count(user_id) AS user_count, + version + FROM potatbotat.twitch_owned_badges + WHERE badge NOT IN ('', 'NOBADGE') + GROUP BY (badge, version); + ` err = clickhouse.Exec(ctx, query) @@ -323,17 +413,29 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) } query := ` + INSERT INTO potatbotat.twitch_owned_badge_user_stats + SELECT + user_id, + count(badge) AS badge_count, + groupArrayDistinct(badge) AS badges, + now64(3) + FROM potatbotat.twitch_owned_badges FINAL + WHERE badge NOT IN ('', 'NOBADGE') + GROUP BY user_id + HAVING uniqExact(badge) >= 5 + ORDER BY badge_count DESC; + ` err = clickhouse.Exec(ctx, query) @@ -344,40 +446,71 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) func upsertOAuthToken( ctx context.Context, + postgres *PostgresClient, + oauth *common.GenericOAUTHResponse, + con common.PlatformOauth, ) error { query := ` + INSERT INTO connection_oauth ( + platform_id, + access_token, + refresh_token, + scope, + expires_in, + added_at, + platform + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (platform_id, platform) + DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + scope = EXCLUDED.scope, + expires_in = EXCLUDED.expires_in, + added_at = EXCLUDED.added_at; + ` _, err := postgres.Exec( + ctx, + query, + con.PlatformID, + oauth.AccessToken, + oauth.RefreshToken, + oauth.Scope, + oauth.ExpiresIn, + time.Now(), + common.TWITCH, ) @@ -386,16 +519,21 @@ func upsertOAuthToken( func refreshOrDelete( ctx context.Context, + config common.Config, + postgres *PostgresClient, + con common.PlatformOauth, ) (bool, error) { var err error + if con.RefreshToken == "" { return false, errMissingRefreshToken } refreshResult, err := utils.RefreshHelixToken(ctx, config, con.RefreshToken) + if err != nil || refreshResult == nil { return false, err } @@ -403,6 +541,7 @@ func refreshOrDelete( err = upsertOAuthToken(ctx, postgres, refreshResult, con) if err != nil { logger.Error.Println( + "Error updating token for user_id", con.PlatformID, ":", err, ) @@ -416,12 +555,19 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre logger.Info.Println("Validating Twitch tokens ") query := ` + SELECT + access_token, + platform_id, + refresh_token + FROM connection_oauth + WHERE platform = 'TWITCH'; + ` rows, err := postgres.Query(ctx, query) @@ -430,9 +576,11 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre return } + defer rows.Close() validated, deleted := 0, 0 + for rows.Next() { var con common.PlatformOauth @@ -454,6 +602,7 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre ok, err := refreshOrDelete(ctx, config, postgres, con) if err != nil { logger.Error.Println("Error refreshing token ", err) + deleted++ continue @@ -467,14 +616,18 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre continue } + validated++ time.Sleep(200 * time.Millisecond) } logger.Info.Printf( + "Validated %d helix tokens, and deleted %d expired tokens", + validated, + deleted, ) } @@ -483,16 +636,27 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * logger.Info.Println("Refreshing all Twitch tokens") query := ` + SELECT + platform, + platform_id, + access_token, + refresh_token, + expires_in, + added_at, + scope + FROM connection_oauth + WHERE platform = 'TWITCH'; + ` rows, err := postgres.Query(ctx, query) @@ -501,19 +665,28 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * return } + defer rows.Close() refreshed, failed := 0, 0 + for rows.Next() { var con common.PlatformOauth err := rows.Scan( + &con.Platform, + &con.PlatformID, + &con.AccessToken, + &con.RefreshToken, + &con.ExpiresIn, + &con.AddedAt, + &con.Scope, ) if err != nil { @@ -525,6 +698,7 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * ok, err := refreshOrDelete(ctx, config, postgres, con) if err != nil { logger.Error.Println("Error refreshing token ", err) + failed++ continue @@ -542,8 +716,11 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * } logger.Info.Printf( + "Refreshed %d helix tokens, %d failed and were expunged", + refreshed, + failed, ) } @@ -551,11 +728,13 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * func sortFiles(files []string) func(i, j int) bool { return func(i, j int) bool { fileI, errI := os.Stat(filepath.Join(dumpPath, files[i])) + if errI != nil { return false } fileJ, errJ := os.Stat(filepath.Join(dumpPath, files[j])) + if errJ != nil { return true } @@ -566,6 +745,7 @@ func sortFiles(files []string) func(i, j int) bool { func deleteOldDumps(files []string, maxSize int) { logger.Debug.Printf("Checking for old dump files, current count: %d, max: %d", len(files), maxSize) + if len(files) <= maxSize { return } @@ -573,6 +753,7 @@ func deleteOldDumps(files []string, maxSize int) { sort.Slice(files, sortFiles(files)) filesToDelete := files[:len(files)-maxSize+1] + for _, file := range filesToDelete { err := os.Remove(file) if err != nil { @@ -585,8 +766,11 @@ func deleteOldDumps(files []string, maxSize int) { func backupPostgres( ctx context.Context, + postgres *PostgresClient, + natsClient *utils.NatsClient, + config common.Config, ) { logger.Debug.Println("Backing up Postgres") @@ -607,18 +791,27 @@ func backupPostgres( deleteOldDumps(files, 10) filePath := filepath.Join( + dumpPath, + fmt.Sprintf("data_%d.sql.zst", time.Now().Unix()), ) - //nolint:gosec,noctx - cmd := exec.Command("sh", "-c", fmt.Sprintf( + //nolint:gosec + cmd := exec.CommandContext(ctx, "sh", "-c", fmt.Sprintf( + "PGPASSWORD=%s pg_dump -d %s -U %s -h %s | zstd -3 --threads=%d > %s", + config.Postgres.Password, + config.Postgres.Database, + config.Postgres.User, + config.Postgres.Host, + runtime.NumCPU(), + filePath, )) @@ -633,14 +826,17 @@ func backupPostgres( }() var stderr bytes.Buffer + cmd.Stderr = &stderr start := time.Now() + if err = cmd.Run(); err != nil { logger.Error.Println("Failed to execute pg_dump:", err, stderr.String()) return } + duration := time.Since(start) stat, err := os.Stat(filePath) @@ -660,32 +856,42 @@ func backupPostgres( } message := fmt.Sprintf( + "Database back-up successful in %s - DB size: %s - Backup size: %.2f GB", + utils.Humanize(duration, 2), + dbSize, + backupSize, ) jsonMessage, err := json.Marshal(message) if err != nil { logger.Error.Println("Failed to JSON stringify message:", err) + logger.Info.Println(message) return } - err = natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", jsonMessage) - if err != nil { - logger.Error.Println("Failed to publish to queue:", err) - - return + if natsClient != nil { + publishBackupNotification(natsClient, jsonMessage) } logger.Info.Println(message) } +func publishBackupNotification(natsClient *utils.NatsClient, message []byte) { + err := natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", message) + if err != nil { + logger.Error.Println("Failed to publish to queue:", err) + } +} + func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName string) (string, error) { query := `SELECT pg_size_pretty(pg_database_size($1)) AS size` + rows, err := postgres.Query(ctx, query, dbName) if err != nil { return "", err @@ -695,6 +901,7 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin if rows.Next() { var size string + if err := rows.Scan(&size); err != nil { return "", err } @@ -707,6 +914,7 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin // func optimizeClickhouse(ctx context.Context, config common.Config, clickhouse *ClickhouseClient) { // // offset any concurrent crons + // time.Sleep(5 * time.Minute) // logger.Info.Println("Optimizing Clickhouse tables") @@ -715,36 +923,49 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin // logger.Error.Println("Clickhouse database is not configured") // return + // } // query := `SELECT table FROM system.tables WHERE database = ?` // rows, err := clickhouse.Query(ctx, query, config.Clickhouse.Database) + // if err != nil { // logger.Error.Println("Failed to query Clickhouse tables:", err) // return + // } // for rows.Next() { // var table string + // if err := rows.Scan(&table); err != nil { // logger.Error.Println("Failed to scan Clickhouse table:", err) // continue + // } // query := fmt.Sprintf("OPTIMIZE TABLE %s.%s FINAL", config.Clickhouse.Database, table) + // if err := clickhouse.Exec(ctx, query); err != nil { // logger.Error.Println("Failed to optimize Clickhouse table:", err) + // } // logger.Info.Printf( + // "Optimized Clickhouse table %s.%s", + // config.Clickhouse.Database, + // table, + // ) // time.Sleep(5 * time.Second) + // } + // } diff --git a/common/db/postgres.go b/common/db/postgres.go index af55871..a266fc0 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -1,4 +1,5 @@ // Package db provides database clients and functions to retrieve or update data. + package db import ( @@ -16,25 +17,33 @@ import ( ) // PostgresClient is a wrapper around the pgxpool.Pool to manage database connections and queries. + type PostgresClient struct { *pgxpool.Pool } // LoaderKey is used to identify a user or channel in the database. + type LoaderKey struct { - ID *int - UserID *string + ID *int + + UserID *string + Username *string + Platform *string } // ErrPostgresNoRows is an alias for pgx.ErrNoRows to handle cases where no rows are returned from a query. + var ( ErrPostgresNoRows = pgx.ErrNoRows - errInvalidType = fmt.Errorf("invalid channel type") + + errInvalidType = fmt.Errorf("invalid channel type") ) // InitPostgres initializes a new Postgres client with the provided configuration. + func InitPostgres(ctx context.Context, config common.Config) (*PostgresClient, error) { dbConfig, err := loadConfig(config) if err != nil { @@ -51,31 +60,41 @@ func InitPostgres(ctx context.Context, config common.Config) (*PostgresClient, e func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unparam user := config.Postgres.User + if user == "" { user = "postgres" } host := config.Postgres.Host + if host == "" { host = "localhost" //nolint:goconst } port := config.Postgres.Port + if port == "" { port = "5432" } database := config.Postgres.Database + if database == "" { database = "postgres" } constring := fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s", + user, + config.Postgres.Password, + host, + port, + database, ) @@ -87,16 +106,22 @@ func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unpara } dbConfig.MaxConns = 32 + dbConfig.MinConns = 4 + dbConfig.MaxConnIdleTime = 1 * time.Minute + dbConfig.MaxConnLifetime = 30 * time.Minute + dbConfig.HealthCheckPeriod = 5 * time.Minute + dbConfig.ConnConfig.ConnectTimeout = 10 * time.Second return dbConfig, nil } // CheckTableExists checks if a table exists in the database and creates it if it doesn't. + func (db *PostgresClient) CheckTableExists(ctx context.Context, createTable string) { _, err := db.Pool.Exec(ctx, createTable) if err != nil { @@ -105,35 +130,58 @@ func (db *PostgresClient) CheckTableExists(ctx context.Context, createTable stri } // Ping checks the connection to the database. + func (db *PostgresClient) Ping(ctx context.Context) error { return db.Pool.Ping(ctx) } // GetUserByName retrieves a user by their username from the database. + func (db *PostgresClient) GetUserByName(ctx context.Context, username string) (*common.User, error) { query := ` + SELECT + users.user_id, + users.username, + users.display, + users.first_seen, + users.level, + users.settings, + json_agg(uc) AS connections + FROM users + LEFT JOIN user_connections uc ON users.user_id = uc.user_id + WHERE users.username = $1 + GROUP BY users.user_id; + ` var user common.User + err := db.Pool.QueryRow(ctx, query, username).Scan( + &user.ID, + &user.Username, + &user.Display, + &user.FirstSeen, + &user.Level, + &user.Settings, + &user.Connections, ) if err != nil { @@ -144,30 +192,52 @@ func (db *PostgresClient) GetUserByName(ctx context.Context, username string) (* } // GetUserByInternalID retrieves a user by their internal ID from the database. + func (db *PostgresClient) GetUserByInternalID(ctx context.Context, id int) (*common.User, error) { query := ` + SELECT + u.user_id, + username, + display, + first_seen, + level, + settings, + json_agg(uc) as connections + FROM users u + JOIN user_connections uc ON u.user_id = uc.user_id + WHERE u.user_id = $1 + GROUP BY u.user_id; + ` var user common.User + err := db.Pool.QueryRow(ctx, query, id).Scan( + &user.ID, + &user.Username, + &user.Display, + &user.FirstSeen, + &user.Level, + &user.Settings, + &user.Connections, ) if err != nil { @@ -178,16 +248,26 @@ func (db *PostgresClient) GetUserByInternalID(ctx context.Context, id int) (*com } // GetChannelBlocks retrieves all blocks for a given channel from the database. + func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string) *[]common.Block { query := ` + SELECT + user_id + block_id, + channel_id, + block_type, + block_data + FROM blocks + WHERE channel_id = $1 + ` rows, err := db.Pool.Query(ctx, query, channelID) @@ -198,13 +278,20 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string defer rows.Close() var blocks []common.Block + for rows.Next() { var block common.Block + err := rows.Scan( + &block.ID, + &block.BlockedUserID, + &block.ChannelID, + &block.BlockType, + &block.CommandName, ) if err != nil { @@ -218,33 +305,60 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string } // GetChannelCommands retrieves all custom channel commands for a given channel from the database. + func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID string) *[]common.ChannelCommand { query := ` + SELECT + command_id, + user_id, + channel_id, + name, + user_trigger_ids, + user_ignore_ids, + trigger, + response, + run_command, + active, + active_online, + active_offline, + reply, + whisper, + announce, + cooldown, + delay, + use_count, + created, + modified, + platform, + help + FROM custom_channel_commands + WHERE channel_id = $1 + ` rows, err := db.Pool.Query(ctx, query, channelID) @@ -255,30 +369,54 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri defer rows.Close() var commands []common.ChannelCommand + for rows.Next() { var command common.ChannelCommand + err := rows.Scan( + &command.CommandID, + &command.UserID, + &command.ChannelID, + &command.Name, + &command.UserTriggerIDs, + &command.UserIgnoreIDs, + &command.Trigger, + &command.Response, + &command.RunCommand, + &command.Active, + &command.ActiveOnline, + &command.ActiveOffline, + &command.Reply, + &command.Whisper, + &command.Announce, + &command.Cooldown, + &command.Delay, + &command.UseCount, + &command.Created, + &command.Modified, + &command.Platform, + &command.Help, ) if err != nil { @@ -292,65 +430,105 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri } // GetChannelByName retrieves a channel by its username and platform from the database. + func (db *PostgresClient) GetChannelByName( ctx context.Context, + username string, + platform common.Platforms, ) (*common.Channel, error) { return db.getChannelByType(ctx, username, platform, "NAME") } // GetChannelByID retrieves a channel by its ID and platform from the database. + func (db *PostgresClient) GetChannelByID( ctx context.Context, + channelID string, + platform common.Platforms, ) (*common.Channel, error) { return db.getChannelByType(ctx, channelID, platform, "ID") } func (db *PostgresClient) getChannelByType( //nolint:cyclop + ctx context.Context, + value string, + platform common.Platforms, + chanType string, ) (*common.Channel, error) { query := ` + SELECT + c.channel_id, + c.username, + c.joined_at, + c.added_by, + c.platform, + c.settings, + c.editors, + c.ambassadors, + c.meta, + c.state + FROM channels c + ` switch chanType { case "ID": + query += `WHERE c.channel_id = $1 ` + case "NAME": + query += `WHERE c.username = $1 ` + default: + return nil, errInvalidType } + query += `AND platform = $2;` var channel common.Channel + err := db.Pool.QueryRow(ctx, query, value, platform).Scan( + &channel.ChannelID, + &channel.Username, + &channel.JoinedAt, + &channel.AddedBy, + &channel.Platform, + &channel.Settings, + &channel.Editors, + &channel.Ambassadors, + &channel.Meta, + &channel.State, ) if err != nil { @@ -362,18 +540,24 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop wg.Add(2) var commands *[]common.ChannelCommand + go func() { defer wg.Done() + cmds := db.GetChannelCommands(ctx, channel.ChannelID) + if cmds != nil { commands = cmds } }() var blocks []common.Block + go func() { defer wg.Done() + bs := db.GetChannelBlocks(ctx, channel.ChannelID) + if bs != nil { blocks = *bs } @@ -389,17 +573,23 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop if len(blocks) > 0 { channel.Blocks = common.FilteredBlocks{ - Users: &[]common.Block{}, + Users: &[]common.Block{}, + Commands: &[]common.Block{}, } for _, block := range blocks { switch block.BlockType { case common.UserBlock: + *channel.Blocks.Users = append(*channel.Blocks.Users, block) + case common.CommandBlock: + *channel.Blocks.Commands = append(*channel.Blocks.Commands, block) + case common.GlobalBlock: + continue } } @@ -411,85 +601,160 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop } // GetPotatoData retrieves potato data for a user from the database. + func (db *PostgresClient) GetPotatoData(ctx context.Context, username string) (*common.PotatoData, error) { query := ` + SELECT + p.user_id, + p.potato_count, + p.potato_prestige, + p.potato_rank, + p.tax_multiplier, + p.first_seen, + p.stole_from, + p.stole_amount, + p.trampled_by, + a.average_response_time, + a.eat_count, + a.harvest_count, + a.stolen_count, + a.theft_count, + a.trampled_count, + a.trample_count, + a.cdr_count, + a.quiz_count, + a.quiz_complete_count, + a.guard_buy_count, + a.fertilizer_buy_count, + a.cdr_buy_count, + a.new_quiz_buy_count, + a.gamble_win_count, + a.gamble_loss_count, + a.gamble_wins_total, + a.gamble_losses_total, + a.duel_win_count, + a.duel_loss_count, + a.duel_wins_amount, + a.duel_losses_amount, + a.duel_caught_losses, + a.average_response_count, + s.not_verbose + FROM ( SELECT user_id FROM users WHERE username = $1 ) u + INNER JOIN potatoes p ON p.user_id = u.user_id + INNER JOIN potato_analytics a ON u.user_id = a.user_id + INNER JOIN potato_settings s ON u.user_id = s.user_id; + ` var data common.PotatoData err := db.Pool.QueryRow(ctx, query, username).Scan( + &data.ID, + &data.PotatoCount, + &data.PotatoPrestige, + &data.PotatoRank, + &data.TaxMultiplier, + &data.FirstSeen, + &data.StoleFrom, + &data.StoleAmount, + &data.TrampledBy, + &data.AverageResponseTime, + &data.EatCount, + &data.HarvestCount, + &data.StolenCount, + &data.TheftCount, + &data.TrampledCount, + &data.TrampleCount, + &data.CDRCount, + &data.QuizCount, + &data.QuizCompleteCount, + &data.GuardBuyCount, + &data.FertilizerBuyCount, + &data.CDRBuyCount, + &data.NewQuizBuyCount, + &data.GambleWinCount, + &data.GambleLossCount, + &data.GambleWinsTotal, + &data.GambleLossesTotal, + &data.DuelWinCount, + &data.DuelLossCount, + &data.DuelWinsAmount, + &data.DuelLossesAmount, + &data.DuelCaughtLosses, + &data.AverageResponseCount, + &data.NotVerbose, ) if err != nil { @@ -500,21 +765,34 @@ func (db *PostgresClient) GetPotatoData(ctx context.Context, username string) (* } // BatchUserConections retrieves user connections for a batch of user IDs from the database. + func (db *PostgresClient) BatchUserConections( ctx context.Context, + ids []int, ) *map[int][]common.UserConnection { query := ` + SELECT + user_id, + platform_id, + platform_username, + platform_display, + platform_pfp, + platform, + platform_metadata + FROM user_connections + WHERE user_id = ANY($1::INT[]) + ` rows, err := db.Pool.Query(ctx, query, ids) @@ -525,15 +803,24 @@ func (db *PostgresClient) BatchUserConections( defer rows.Close() users := make(map[int][]common.UserConnection) + for rows.Next() { var connection common.UserConnection + err := rows.Scan( + &connection.ID, + &connection.UserID, + &connection.Username, + &connection.Display, + &connection.PFP, + &connection.Platform, + &connection.Meta, ) if err != nil { @@ -547,10 +834,12 @@ func (db *PostgresClient) BatchUserConections( } // GetRedirectByKey retrieves a URL redirect from the database by its key. + func (db *PostgresClient) GetRedirectByKey(ctx context.Context, key string) (string, error) { query := `SELECT url FROM url_redirects WHERE key = $1` var url string + err := db.Pool.QueryRow(ctx, query, key).Scan(&url) if err != nil { return "", err @@ -560,10 +849,12 @@ func (db *PostgresClient) GetRedirectByKey(ctx context.Context, key string) (str } // GetKeyByRedirect retrieves the key associated with a given URL redirect from the database. + func (db *PostgresClient) GetKeyByRedirect(ctx context.Context, url string) (string, error) { query := `SELECT key FROM url_redirects WHERE url = $1` var key string + err := db.Pool.QueryRow(ctx, query, url).Scan(&key) if err != nil { return "", err @@ -573,6 +864,7 @@ func (db *PostgresClient) GetKeyByRedirect(ctx context.Context, url string) (str } // RedirectExists checks if a URL redirect exists in the database by its key. + func (db *PostgresClient) RedirectExists(ctx context.Context, key string) bool { query := `SELECT EXISTS(SELECT 1 FROM url_redirects WHERE key = $1)` @@ -587,6 +879,7 @@ func (db *PostgresClient) RedirectExists(ctx context.Context, key string) bool { } // NewRedirect inserts a new URL redirect into the database. + func (db *PostgresClient) NewRedirect(ctx context.Context, key, url string) error { query := `INSERT INTO url_redirects (key, url) VALUES ($1, $2)` @@ -596,12 +889,18 @@ func (db *PostgresClient) NewRedirect(ctx context.Context, key, url string) erro } // GetHaste retrieves a hastebin text document from the database by its key. + func (db *PostgresClient) GetHaste(ctx context.Context, key string) (string, error) { query := ` + UPDATE haste + SET access_count = access_count + 1 + WHERE key = $1 + RETURNING convert_from(zstd_decompress(content::bytea), 'utf-8') AS text; + ` var text string @@ -615,16 +914,24 @@ func (db *PostgresClient) GetHaste(ctx context.Context, key string) (string, err } // NewHaste inserts a new compressed hastebin text document into the database. + func (db *PostgresClient) NewHaste( ctx context.Context, + key string, + text []byte, + source string, ) error { query := ` + INSERT INTO haste (key, content, source) + VALUES ($1, zstd_compress($2, null, 8), $3) + ON CONFLICT (key) DO NOTHING; + ` _, err := db.Pool.Exec(ctx, query, encode(key), text, source) @@ -634,25 +941,37 @@ func (db *PostgresClient) NewHaste( func encode(data string) string { hash := md5.New() //nolint:gosec + hash.Write([]byte(data)) return hex.EncodeToString(hash.Sum(nil)) } // NewUpload inserts a new file into the database and returns the creation timestamp. + func (db *PostgresClient) NewUpload( ctx context.Context, + key string, + file []byte, + name string, + mimeType string, ) (bool, *time.Time) { query := ` + INSERT INTO file_store (file, file_name, mime_type, key) + VALUES ($1, $2, $3, $4) + RETURNING created_at; + ` + var createdAt time.Time + err := db.Pool.QueryRow(ctx, query, file, name, mimeType, key).Scan(&createdAt) if err != nil { logger.Error.Println("Error scanning upload", err) @@ -664,25 +983,38 @@ func (db *PostgresClient) NewUpload( } // GetFileByKey retrieves a file from the database by its key. + func (db *PostgresClient) GetFileByKey( ctx context.Context, + key string, ) ([]byte, string, *string, *time.Time, error) { query := ` + SELECT file, mime_type, file_name, created_at + FROM file_store + WHERE key = $1 + ` var content []byte + var mimeType string + var fileName *string + var createdAt time.Time err := db.Pool.QueryRow(ctx, query, key).Scan( + &content, + &mimeType, + &fileName, + &createdAt, ) if err != nil { @@ -693,13 +1025,18 @@ func (db *PostgresClient) GetFileByKey( } // DeleteFileByKey deletes a file from the database by its key. + func (db *PostgresClient) DeleteFileByKey( ctx context.Context, + key string, ) bool { query := ` + DELETE FROM file_store + WHERE key = $1 + ` _, err := db.Pool.Exec(ctx, query, key) @@ -708,17 +1045,24 @@ func (db *PostgresClient) DeleteFileByKey( } // GetUploadCreatedAt retrieves the creation timestamp of an upload by its key. + func (db *PostgresClient) GetUploadCreatedAt( ctx context.Context, + key string, ) (*time.Time, error) { query := ` + SELECT created_at + FROM file_store + WHERE key = $1 + ` var createdAt time.Time + err := db.Pool.QueryRow(ctx, query, key).Scan(&createdAt) if err != nil { return nil, err @@ -726,3 +1070,414 @@ func (db *PostgresClient) GetUploadCreatedAt( return &createdAt, nil } + +// GetAllChannels retrieves all channels with state = 'JOINED', ordered by username. + +func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelListItem, error) { + query := ` + + SELECT channel_id, username, platform, state + + FROM channels + + WHERE state = 'JOINED' + + ORDER BY username + + ` + + rows, err := db.Pool.Query(ctx, query) + if err != nil { + return nil, err + } + + defer rows.Close() + + var channels []common.ChannelListItem + + for rows.Next() { + var ch common.ChannelListItem + + if err := rows.Scan(&ch.ChannelID, &ch.Username, &ch.Platform, &ch.State); err != nil { + return nil, err + } + + channels = append(channels, ch) + } + + return channels, rows.Err() +} + +// UpdateUserSettings replaces the settings JSONB column for a user. + +func (db *PostgresClient) UpdateUserSettings(ctx context.Context, userID int, settings common.UserSettings) error { + query := `UPDATE users SET settings = $1 WHERE user_id = $2` + + _, err := db.Pool.Exec(ctx, query, settings, userID) + + return err +} + +// UpdateChannelSettings replaces the settings JSONB column for a channel. + +func (db *PostgresClient) UpdateChannelSettings( + ctx context.Context, + + channelID string, + + platform string, + + settings common.ChannelSettings, +) error { + query := `UPDATE channels SET settings = $1 WHERE channel_id = $2 AND platform = $3` + + _, err := db.Pool.Exec(ctx, query, settings, channelID, platform) + + return err +} + +// GetChannelSettingsByID retrieves only the settings column for a channel by its ID. + +func (db *PostgresClient) GetChannelSettingsByID( + ctx context.Context, + + channelID string, + + platform string, +) (common.ChannelSettings, error) { + var settings common.ChannelSettings + + err := db.Pool.QueryRow( + + ctx, + + `SELECT settings FROM channels WHERE channel_id = $1 AND platform = $2`, + + channelID, + + platform, + ).Scan(&settings) + + return settings, err +} + +// GetCommandSettings retrieves all command settings rows for a given channel. + +func (db *PostgresClient) GetCommandSettings( + ctx context.Context, + + channelID string, +) ([]common.CommandSettings, error) { + query := ` + + SELECT + + channel_id, + + command, + + permission, + + users_blacklisted, + + users_whitelisted, + + custom_cooldown, + + channel_usage, + + is_enabled, + + offline_only, + + silent_errors, + + allow_bots, + + platform, + + ambassador_granted + + FROM command_settings + + WHERE channel_id = $1 + + ORDER BY command + + ` + + rows, err := db.Pool.Query(ctx, query, channelID) + if err != nil { + return nil, err + } + + defer rows.Close() + + var results []common.CommandSettings + + for rows.Next() { + var cs common.CommandSettings + + if err := rows.Scan( + + &cs.ChannelID, + + &cs.Command, + + &cs.Permission, + + &cs.UsersBlacklisted, + + &cs.UsersWhitelisted, + + &cs.CustomCooldown, + + &cs.ChannelUsage, + + &cs.IsEnabled, + + &cs.OfflineOnly, + + &cs.SilentErrors, + + &cs.AllowBots, + + &cs.Platform, + + &cs.AmbassadorGranted, + ); err != nil { + return nil, err + } + + results = append(results, cs) + } + + return results, rows.Err() +} + +// UpsertCommandSettings inserts or updates a single command's settings row. + +func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.CommandSettings) error { + query := ` + + INSERT INTO command_settings ( + + channel_id, command, permission, users_blacklisted, users_whitelisted, + + custom_cooldown, is_enabled, offline_only, silent_errors, allow_bots, + + platform, ambassador_granted + + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + + ON CONFLICT (channel_id, command, platform) DO UPDATE SET + + permission = EXCLUDED.permission, + + users_blacklisted = EXCLUDED.users_blacklisted, + + users_whitelisted = EXCLUDED.users_whitelisted, + + custom_cooldown = EXCLUDED.custom_cooldown, + + is_enabled = EXCLUDED.is_enabled, + + offline_only = EXCLUDED.offline_only, + + silent_errors = EXCLUDED.silent_errors, + + allow_bots = EXCLUDED.allow_bots, + + platform = EXCLUDED.platform, + + ambassador_granted = EXCLUDED.ambassador_granted + + ` + + platform := cs.Platform + + if platform == "" { + platform = "TWITCH" + } + + _, err := db.Pool.Exec( + + ctx, query, + + cs.ChannelID, cs.Command, cs.Permission, + + cs.UsersBlacklisted, cs.UsersWhitelisted, + + cs.CustomCooldown, cs.IsEnabled, cs.OfflineOnly, + + cs.SilentErrors, cs.AllowBots, platform, cs.AmbassadorGranted, + ) + + return err +} + +// ResetCommandSettings resets a single command's overrides back to their defaults. + +func (db *PostgresClient) ResetCommandSettings(ctx context.Context, channelID, command string) error { + query := ` + + UPDATE command_settings SET + + is_enabled = TRUE, + + offline_only = NULL, + + custom_cooldown = NULL, + + silent_errors = FALSE, + + users_whitelisted = NULL, + + users_blacklisted = NULL, + + allow_bots = NULL, + + permission = NULL, + + ambassador_granted = FALSE + + WHERE channel_id = $1 AND command = $2 + + ` + + _, err := db.Pool.Exec(ctx, query, channelID, command) + + return err +} + +// GetChannelAmbassadors returns the ambassadors slice for a channel, used for auth checks. + +func (db *PostgresClient) GetChannelAmbassadors( + ctx context.Context, + + channelID string, + + platform common.Platforms, +) ([]string, error) { + query := `SELECT ambassadors FROM channels WHERE channel_id = $1 AND platform = $2` + + var ambassadors []string + + err := db.Pool.QueryRow(ctx, query, channelID, platform).Scan(&ambassadors) + if err != nil { + return nil, err + } + + return ambassadors, nil +} + +// GetUserReminders retrieves all pending reminders for a user on a given platform. + +func (db *PostgresClient) GetUserReminders( + ctx context.Context, + + userID string, + + platform common.Platforms, +) ([]common.Reminder, error) { + query := ` + + SELECT + + reminder_id, user_id, recipient_id, channel_id, + + message, ready_at, set_at, afk_withheld, status, platform, sent_at, type + + FROM reminders + + WHERE recipient_id = $1 AND platform = $2 + + ORDER BY set_at DESC + + ` + + rows, err := db.Pool.Query(ctx, query, userID, platform) + if err != nil { + return nil, err + } + + defer rows.Close() + + var reminders []common.Reminder + + for rows.Next() { + var r common.Reminder + + if err := rows.Scan( + + &r.ReminderID, &r.UserID, &r.RecipientID, &r.ChannelID, + + &r.Message, &r.ReadyAt, &r.SetAt, &r.AfkWithheld, &r.Status, &r.Platform, &r.SentAt, &r.Type, + ); err != nil { + return nil, err + } + + reminders = append(reminders, r) + } + + return reminders, rows.Err() +} + +// DeleteReminder hard-deletes a reminder by ID, verifying the owner platform ID. + +func (db *PostgresClient) DeleteReminder(ctx context.Context, reminderID int, recipientID string) error { + query := `DELETE FROM reminders WHERE reminder_id = $1 AND recipient_id = $2` + + _, err := db.Pool.Exec(ctx, query, reminderID, recipientID) + + return err +} + +// UpsertOAuthToken stores or refreshes a platform OAuth token for a given user. + +func (db *PostgresClient) UpsertOAuthToken( + ctx context.Context, + + platformID string, + + platform common.Platforms, + + accessToken string, + + refreshToken string, + + scope []string, + + expiresIn int, +) error { + query := ` + + INSERT INTO connection_oauth ( + + platform_id, access_token, refresh_token, scope, expires_in, added_at, platform + + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + + ON CONFLICT (platform_id, platform) DO UPDATE SET + + access_token = EXCLUDED.access_token, + + refresh_token = EXCLUDED.refresh_token, + + scope = EXCLUDED.scope, + + expires_in = EXCLUDED.expires_in, + + added_at = EXCLUDED.added_at + + ` + + _, err := db.Pool.Exec( + + ctx, query, + + platformID, accessToken, refreshToken, scope, expiresIn, time.Now(), platform, + ) + + return err +} diff --git a/common/types.go b/common/types.go index caab010..362a5f9 100644 --- a/common/types.go +++ b/common/types.go @@ -1,4 +1,5 @@ // Package common provides common types and structures used across the application. + package common import ( @@ -7,396 +8,766 @@ import ( ) // Platforms represents the different platforms supported by the bot. + type Platforms string //nolint:revive + const ( - TWITCH Platforms = "TWITCH" + TWITCH Platforms = "TWITCH" + DISCORD Platforms = "DISCORD" - KICK Platforms = "KICK" - STV Platforms = "STV" + + KICK Platforms = "KICK" + + STV Platforms = "STV" + + SPOTIFY Platforms = "SPOTIFY" + + ANILIST Platforms = "ANILIST" + + TRAKT Platforms = "TRAKT" + + FFZ Platforms = "FFZ" + + BTTV Platforms = "BTTV" + + STEAM Platforms = "STEAM" ) // PermissionLevel represents the permission level of a user interally with the api and bot. + type PermissionLevel uint8 //nolint:revive + const ( - DEVELOPER PermissionLevel = 4 - ADMIN PermissionLevel = 3 - MOD PermissionLevel = 2 - USER PermissionLevel = 1 + DEVELOPER PermissionLevel = 4 + + ADMIN PermissionLevel = 3 + + MOD PermissionLevel = 2 + + USER PermissionLevel = 1 + BLACKLISTED PermissionLevel = 0 ) // User represents a user on a platform, including their first seen date, username, display name,. + type User struct { - FirstSeen time.Time `json:"first_seen"` - Username string `json:"username"` - Display string `json:"display"` - Settings UserSettings `json:"settings"` + FirstSeen time.Time `json:"first_seen"` + + Username string `json:"username"` + + Display string `json:"display"` + + Settings UserSettings `json:"settings"` + Connections []UserConnection `json:"connections,omitempty"` - ID int `json:"user_id"` - Level int `json:"level"` + + ID int `json:"user_id"` + + Level int `json:"level"` } // UserSettings represents the settings for a user on a platform, including language preferences,. + type UserSettings struct { - Language string `json:"language"` - IgnoreDropped bool `json:"ignore_dropped"` - ColorResponses bool `json:"color_responses"` - NoReply bool `json:"no_reply"` - IsBot bool `json:"is_bot"` - IsSelfBot bool `json:"is_selfbot"` + Language string `json:"language"` + + IgnoreDropped bool `json:"ignore_dropped"` + + ColorResponses bool `json:"color_responses"` + + NoReply bool `json:"no_reply"` + + IsBot bool `json:"is_bot"` + + IsSelfBot bool `json:"is_selfbot"` } // UserMeta represents the metadata associated with a user on a platform. + type UserMeta = json.RawMessage // UserConnection represents a connection between a user and a platform. + type UserConnection struct { Platform Platforms `json:"platform"` - Username string `json:"platform_username"` - Display string `json:"platform_display"` - UserID string `json:"platform_id"` - PFP string `json:"platform_pfp"` - Meta UserMeta `json:"platform_metadata"` - ID int `json:"user_id"` + + Username string `json:"platform_username"` + + Display string `json:"platform_display"` + + UserID string `json:"platform_id"` + + PFP string `json:"platform_pfp"` + + Meta UserMeta `json:"platform_metadata"` + + ID int `json:"user_id"` } // KickChannelMeta represents the metadata for a channel on Kick, including the channel ID, chatroom ID, and user ID. + type KickChannelMeta struct { - ChannelID string `json:"channel_id"` + ChannelID string `json:"channel_id"` + ChatroomID string `json:"chatroom_id"` - UserID string `json:"user_id"` + + UserID string `json:"user_id"` } // TwitchChannelMeta represents the metadata for a channel on Twitch, including whether the bot and channel are banned. + type TwitchChannelMeta struct { - BotBanned bool `json:"bot_banned"` + BotBanned bool `json:"bot_banned"` + TwitchBanned bool `json:"twitch_banned"` } // TwitchUserMeta represents the metadata for a user on Twitch, including their color and roles. + type TwitchUserMeta struct { - Color string `json:"color,omitempty"` + Color string `json:"color,omitempty"` + Roles TwitchRoles `json:"roles,omitzero"` } // StvUserMeta represents the metadata for a user on 7TV, including their paint ID and roles. + type StvUserMeta struct { - PaintID string `json:"paint_id,omitempty"` - Roles []string `json:"roles,omitempty"` + PaintID string `json:"paint_id,omitempty"` + + Roles []string `json:"roles,omitempty"` } // TwitchRoles represents the roles a user can have on Twitch, such as staff, partner, or affiliate. + type TwitchRoles struct { - IsStaff bool `json:"isStaff,omitempty"` - IsPartner bool `json:"isPartner,omitempty"` + IsStaff bool `json:"isStaff,omitempty"` + + IsPartner bool `json:"isPartner,omitempty"` + IsAffiliate bool `json:"isAffiliate,omitempty"` } // Channel represents a channel on a platform, including its blocks, settings, commands, and other metadata. + type Channel struct { - Blocks FilteredBlocks `json:"blocks,omitzero"` - JoinedAt *time.Time `json:"joined_at,omitempty"` - Meta map[string]any `json:"meta"` - Commands *[]ChannelCommand `json:"commands,omitempty"` - ChannelID string `json:"channel_id"` - Username string `json:"username"` - Platform Platforms `json:"platform"` - State string `json:"state"` - AddedBy []AddedByData `json:"added_by,omitempty"` - Editors []string `json:"editors"` - Ambassadors []string `json:"ambassadors"` - Settings ChannelSettings `json:"settings"` + Blocks FilteredBlocks `json:"blocks,omitzero"` + + JoinedAt *time.Time `json:"joined_at,omitempty"` + + Meta map[string]any `json:"meta"` + + Commands *[]ChannelCommand `json:"commands,omitempty"` + + ChannelID string `json:"channel_id"` + + Username string `json:"username"` + + Platform Platforms `json:"platform"` + + State string `json:"state"` + + AddedBy []AddedByData `json:"added_by,omitempty"` + + Editors []string `json:"editors"` + + Ambassadors []string `json:"ambassadors"` + + Settings ChannelSettings `json:"settings"` } // FilteredBlocks represents the blocks of users and commands in a channel. + type FilteredBlocks struct { - Users *[]Block `json:"users"` + Users *[]Block `json:"users"` + Commands *[]Block `json:"commands"` } // Block represents a block connection between a user and a command or another user. + type Block struct { - ChannelID string `json:"channel_id"` - BlockType BlockType `json:"block_type"` - CommandName string `json:"command_name,omitempty"` - ID int `json:"user_id"` - BlockedUserID int `json:"blocked_user_id"` + ChannelID string `json:"channel_id"` + + BlockType BlockType `json:"block_type"` + + CommandName string `json:"command_name,omitempty"` + + ID int `json:"user_id"` + + BlockedUserID int `json:"blocked_user_id"` } // BlockType represents the type of block connection that can be made between a user or a command. + type BlockType string //nolint:revive + const ( - UserBlock BlockType = "USER" + UserBlock BlockType = "USER" + CommandBlock BlockType = "COMMAND" - GlobalBlock BlockType = "GLOBAL" + + GlobalBlock BlockType = "GLOBAL" ) // AddedByData represents the data of a user who added a channel. + type AddedByData struct { - AddedAt time.Time `json:"addedAt"` - Username string `json:"username"` - ID string `json:"id"` + AddedAt time.Time `json:"addedAt"` + + Username string `json:"username"` + + ID string `json:"id"` } // ChannelSettings represents the settings for a channel, including bot settings, cooldowns, language, + // permission levels, and other configurations. + type ChannelSettings struct { - PajBot *string `json:"paj_bot,omitempty"` - UserCooldown *int `json:"user_cooldown,omitempty"` - ChannelCooldown *int `json:"channel_cooldown,omitempty"` - Language string `json:"language"` - Permission string `json:"permission"` - Prefix string `json:"prefix"` - UsersBlacklisted []string `json:"users_blacklisted"` - NoReply bool `json:"no_reply"` - FirstMsgResponses bool `json:"first_msg_responses"` - WhisperOnly bool `json:"whisper_only"` - OfflineOnly bool `json:"offline_only"` - ForceLanguage bool `json:"force_language"` - SilentErrors bool `json:"silent_errors"` - ColorResponses bool `json:"color_responses"` + PajBot *string `json:"paj_bot,omitempty"` + + UserCooldown *int `json:"user_cooldown,omitempty"` + + ChannelCooldown *int `json:"channel_cooldown,omitempty"` + + OnlinePermission *string `json:"online_permission,omitempty"` + + EmoteStreakResponse *string `json:"emote_streak_response,omitempty"` + + PyramidResponse *string `json:"pyramid_response,omitempty"` + + OnlineSilentErrors *bool `json:"online_silent_errors,omitempty"` + + OnlineWhisperOnly *bool `json:"online_whisper_only,omitempty"` + + AllowBotEmoteTracking *bool `json:"allow_bot_emote_tracking,omitempty"` + + IgnoreDropped *bool `json:"ignore_dropped,omitempty"` + + NoLinks *bool `json:"no_links,omitempty"` + + ForcePyramidNotVerbose *bool `json:"force_pyramid_not_verbose,omitempty"` + + Language string `json:"language"` + + Permission string `json:"permission"` + + Prefix string `json:"prefix"` + + UsersBlacklisted []string `json:"users_blacklisted"` + + NoReply bool `json:"no_reply"` + + FirstMsgResponses bool `json:"first_msg_responses"` + + WhisperOnly bool `json:"whisper_only"` + + OfflineOnly bool `json:"offline_only"` + + ForceLanguage bool `json:"force_language"` + + SilentErrors bool `json:"silent_errors"` + + ColorResponses bool `json:"color_responses"` } // CommandSettings represents the settings for a command in a channel, including its permissions, cooldowns, + // and usage limits. -type CommandSettings struct { - ChannelID string `json:"channel_id"` - Command string `json:"command"` - Permission string `json:"permission"` - UsersBlacklisted []string `json:"users_blacklisted"` - UsersWhitelisted []string `json:"users_whitelisted"` - CustomCooldown int `json:"custom_cooldown"` - ChannelUsage int `json:"channel_usage"` - IsEnabled bool `json:"is_enabled"` - OfflineOnly bool `json:"offline_only"` - SilentErrors bool `json:"silent_errors"` - AllowBots bool `json:"allow_bots"` + +type CommandSettings struct { //nolint:govet + + ChannelID string `json:"channel_id"` + + Command string `json:"command"` + + Permission *string `json:"permission,omitempty"` + + UsersBlacklisted []string `json:"users_blacklisted,omitempty"` + + UsersWhitelisted []string `json:"users_whitelisted,omitempty"` + + CustomCooldown *int `json:"custom_cooldown,omitempty"` + + ChannelUsage int `json:"channel_usage"` + + IsEnabled bool `json:"is_enabled"` + + OfflineOnly *bool `json:"offline_only,omitempty"` + + SilentErrors bool `json:"silent_errors"` + + AllowBots *bool `json:"allow_bots,omitempty"` + + Platform string `json:"platform"` + + AmbassadorGranted bool `json:"ambassador_granted"` } // PlatformOauth represents the OAuth token and metadata for a user on a specific platform, including. + type PlatformOauth struct { - AddedAt time.Time `json:"added_at"` - PlatformID string `json:"platform_id"` - AccessToken string `json:"access_token"` //nolint:gosec - RefreshToken string `json:"refresh_token"` //nolint:gosec - Platform Platforms `json:"platform"` - Scope []string `json:"scope"` - ExpiresIn int `json:"expires_in"` + AddedAt time.Time `json:"added_at"` + + PlatformID string `json:"platform_id"` + + AccessToken string `json:"access_token"` //nolint:gosec + + RefreshToken string `json:"refresh_token"` //nolint:gosec + + Platform Platforms `json:"platform"` + + Scope []string `json:"scope"` + + ExpiresIn int `json:"expires_in"` } // ChannelCommand represents a single custom command, including its properties and settings. + type ChannelCommand struct { - Created time.Time `json:"created"` - Modified time.Time `json:"modified"` - Help *string `json:"help"` - Name *string `json:"name"` - RunCommand *string `json:"run_command"` - ChannelID string `json:"channel_id"` - Trigger string `json:"trigger"` - Response string `json:"response"` - Platform string `json:"platform"` - UserTriggerIDs []string `json:"user_trigger_ids"` - UserIgnoreIDs []string `json:"user_ignore_ids"` - Cooldown int `json:"cooldown"` - UserID int `json:"user_id"` - Delay int `json:"delay"` - UseCount int `json:"use_count"` - CommandID int `json:"command_id"` - Reply bool `json:"reply"` - Whisper bool `json:"whisper"` - Announce bool `json:"announce"` - ActiveOffline bool `json:"active_offline"` - ActiveOnline bool `json:"active_online"` - Active bool `json:"active"` + Created time.Time `json:"created"` + + Modified time.Time `json:"modified"` + + Help *string `json:"help"` + + Name *string `json:"name"` + + RunCommand *string `json:"run_command"` + + ChannelID string `json:"channel_id"` + + Trigger string `json:"trigger"` + + Response string `json:"response"` + + Platform string `json:"platform"` + + UserTriggerIDs []string `json:"user_trigger_ids"` + + UserIgnoreIDs []string `json:"user_ignore_ids"` + + Cooldown int `json:"cooldown"` + + UserID int `json:"user_id"` + + Delay int `json:"delay"` + + UseCount int `json:"use_count"` + + CommandID int `json:"command_id"` + + Reply bool `json:"reply"` + + Whisper bool `json:"whisper"` + + Announce bool `json:"announce"` + + ActiveOffline bool `json:"active_offline"` + + ActiveOnline bool `json:"active_online"` + + Active bool `json:"active"` } // Potatoes represents the structure of potato-related data for a user. + type Potatoes struct { - StoleFrom *string `json:"stole_from"` - StoleAmount *int `json:"stole_amount"` - TrampledBy *string `json:"trampled_by"` - FirstSeen string `json:"first_seen"` - ID int `json:"user_id"` - PotatoCount int `json:"potato_count"` - PotatoPrestige int `json:"potato_prestige"` - PotatoRank int `json:"potato_rank"` - TaxMultiplier int `json:"tax_multiplier"` + StoleFrom *string `json:"stole_from"` + + StoleAmount *int `json:"stole_amount"` + + TrampledBy *string `json:"trampled_by"` + + FirstSeen string `json:"first_seen"` + + ID int `json:"user_id"` + + PotatoCount int `json:"potato_count"` + + PotatoPrestige int `json:"potato_prestige"` + + PotatoRank int `json:"potato_rank"` + + TaxMultiplier int `json:"tax_multiplier"` } // PotatoAnalytics contains various statistics related to potato interactions. + type PotatoAnalytics struct { - AverageResponseTime string `json:"average_response_time"` - EatCount int `json:"eat_count"` - HarvestCount int `json:"harvest_count"` - StolenCount int `json:"stolen_count"` - TheftCount int `json:"theft_count"` - TrampledCount int `json:"trampled_count"` - TrampleCount int `json:"trample_count"` - CDRCount int `json:"cdr_count"` - QuizCount int `json:"quiz_count"` - QuizCompleteCount int `json:"quiz_complete_count"` - GuardBuyCount int `json:"guard_buy_count"` - FertilizerBuyCount int `json:"fertilizer_buy_count"` - CDRBuyCount int `json:"cdr_buy_count"` - NewQuizBuyCount int `json:"new_quiz_buy_count"` - GambleWinCount int `json:"gamble_win_count"` - GambleLossCount int `json:"gamble_loss_count"` - GambleWinsTotal int `json:"gamble_wins_total"` - GambleLossesTotal int `json:"gamble_losses_total"` - DuelWinCount int `json:"duel_win_count"` - DuelLossCount int `json:"duel_loss_count"` - DuelWinsAmount int `json:"duel_wins_amount"` - DuelLossesAmount int `json:"duel_losses_amount"` - DuelCaughtLosses int `json:"duel_caught_losses"` - AverageResponseCount int `json:"average_response_count"` + AverageResponseTime string `json:"average_response_time"` + + EatCount int `json:"eat_count"` + + HarvestCount int `json:"harvest_count"` + + StolenCount int `json:"stolen_count"` + + TheftCount int `json:"theft_count"` + + TrampledCount int `json:"trampled_count"` + + TrampleCount int `json:"trample_count"` + + CDRCount int `json:"cdr_count"` + + QuizCount int `json:"quiz_count"` + + QuizCompleteCount int `json:"quiz_complete_count"` + + GuardBuyCount int `json:"guard_buy_count"` + + FertilizerBuyCount int `json:"fertilizer_buy_count"` + + CDRBuyCount int `json:"cdr_buy_count"` + + NewQuizBuyCount int `json:"new_quiz_buy_count"` + + GambleWinCount int `json:"gamble_win_count"` + + GambleLossCount int `json:"gamble_loss_count"` + + GambleWinsTotal int `json:"gamble_wins_total"` + + GambleLossesTotal int `json:"gamble_losses_total"` + + DuelWinCount int `json:"duel_win_count"` + + DuelLossCount int `json:"duel_loss_count"` + + DuelWinsAmount int `json:"duel_wins_amount"` + + DuelLossesAmount int `json:"duel_losses_amount"` + + DuelCaughtLosses int `json:"duel_caught_losses"` + + AverageResponseCount int `json:"average_response_count"` } // PotatoSettings represents the settings related to potato interactions, such as verbosity. + type PotatoSettings struct { NotVerbose bool `json:"not_verbose"` } // PotatoData combines the Potatoes, PotatoAnalytics, and PotatoSettings structures into a single structure. + type PotatoData struct { Potatoes + PotatoAnalytics + PotatoSettings } // Redirect represents a URL redirect structure, typically used for OAuth flows. + type Redirect struct { Key string `json:"key"` + URL string `json:"url"` } // ErrorMessage represents a structure for error messages returned in API responses. + type ErrorMessage struct { Message string `json:"message"` } // GenericResponse represents a generic API response structure, which can include data and errors. + type GenericResponse[T any] struct { - Data *[]T `json:"data"` + Data *[]T `json:"data"` + Errors *[]ErrorMessage `json:"errors,omitempty"` } // TwitchValidation represents the structure of a Twitch OAuth validation response. + type TwitchValidation struct { - ClientID string `json:"client_id"` - Login string `json:"login"` - UserID string `json:"user_id"` - Scopes []string `json:"scopes"` - ExpiresIn int `json:"expires_in"` - StatusCode int `json:"status_code"` + ClientID string `json:"client_id"` + + Login string `json:"login"` + + UserID string `json:"user_id"` + + Scopes []string `json:"scopes"` + + ExpiresIn int `json:"expires_in"` + + StatusCode int `json:"status_code"` } // GenericOAUTHResponse represents a generic OAuth response structure. + type GenericOAUTHResponse struct { - AccessToken string `json:"access_token"` //nolint:gosec - RefreshToken string `json:"refresh_token"` //nolint:gosec - TokenType string `json:"token_type"` - Scope []string `json:"scope"` - ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` //nolint:gosec + + RefreshToken string `json:"refresh_token"` //nolint:gosec + + TokenType string `json:"token_type"` + + Scope []string `json:"scope"` + + ExpiresIn int `json:"expires_in"` } // Command represents chat command, with its permissions, description, and other details. + type Command struct { - Conditions CommandConditions `json:"conditions"` - BotRequires BotCommandRequirements `json:"botRequires"` - Description string `json:"description"` - DetailedDescription string `json:"detailedDescription,omitempty"` - Title string `json:"title"` - Usage string `json:"usage"` - Category CommandCategories `json:"category"` - Name string `json:"name"` - UserRequires UserRequires `json:"userRequires"` - Aliases []string `json:"aliases"` - Flags []FlagDetails `json:"flags"` - Cooldown int `json:"cooldown"` - Level PermissionLevel `json:"level"` + Conditions CommandConditions `json:"conditions"` + + BotRequires BotCommandRequirements `json:"botRequires"` + + Description string `json:"description"` + + DetailedDescription string `json:"detailedDescription,omitempty"` + + Title string `json:"title"` + + Usage string `json:"usage"` + + Category CommandCategories `json:"category"` + + Name string `json:"name"` + + UserRequires UserRequires `json:"userRequires"` + + Aliases []string `json:"aliases"` + + Flags []FlagDetails `json:"flags"` + + Cooldown int `json:"cooldown"` + + Level PermissionLevel `json:"level"` } // CommandCategories represents the category of a command. + type CommandCategories string //nolint:revive + const ( Development CommandCategories = "development" - Deprecated CommandCategories = "deprecated" - Moderation CommandCategories = "moderation" - Utilities CommandCategories = "utilities" - Unlisted CommandCategories = "unlisted" - Settings CommandCategories = "settings" - Stream CommandCategories = "stream" - Potato CommandCategories = "potato" - Emotes CommandCategories = "emotes" - Anime CommandCategories = "anime" - Music CommandCategories = "music" - Spam CommandCategories = "spam" - Misc CommandCategories = "misc" - Fun CommandCategories = "fun" + + Deprecated CommandCategories = "deprecated" + + Moderation CommandCategories = "moderation" + + Utilities CommandCategories = "utilities" + + Unlisted CommandCategories = "unlisted" + + Settings CommandCategories = "settings" + + Stream CommandCategories = "stream" + + Potato CommandCategories = "potato" + + Emotes CommandCategories = "emotes" + + Anime CommandCategories = "anime" + + Music CommandCategories = "music" + + Spam CommandCategories = "spam" + + Misc CommandCategories = "misc" + + Fun CommandCategories = "fun" ) // FlagDetails represents the details of a command flag, including its requirements, usage, and validation function. + type FlagDetails struct { - UserRequires *UserRequires `json:"user_requires,omitempty"` - Usage *string `json:"usage,omitempty"` - Multi *bool `json:"multi,omitempty"` - Check func(params Flags, flag FlagDetails) (FlagCheckResult, error) `json:"-"` - Name string `json:"name"` - Type string `json:"type"` - Description string `json:"description"` - Aliases []string `json:"aliases,omitempty"` - Level PermissionLevel `json:"level"` - Required bool `json:"required"` + UserRequires *UserRequires `json:"user_requires,omitempty"` + + Usage *string `json:"usage,omitempty"` + + Multi *bool `json:"multi,omitempty"` + + Check func(params Flags, flag FlagDetails) (FlagCheckResult, error) `json:"-"` + + Name string `json:"name"` + + Type string `json:"type"` + + Description string `json:"description"` + + Aliases []string `json:"aliases,omitempty"` + + Level PermissionLevel `json:"level"` + + Required bool `json:"required"` } // FlagCheckResult represents the result of a flag check, indicating whether the flag is valid, + // and any associated error or requirement. + type FlagCheckResult struct { MustBe *string `json:"must_be,omitempty"` - Error *string `json:"error,omitempty"` - Valid bool `json:"valid"` + + Error *string `json:"error,omitempty"` + + Valid bool `json:"valid"` } // UserRequires represents the required user permission level for a command or flag. + type UserRequires string //nolint:revive + const ( - None UserRequires = "NONE" - Subscriber UserRequires = "SUBSCRIBER" - VIP UserRequires = "VIP" - Mod UserRequires = "MOD" - Ambassador UserRequires = "AMBASSADOR" + None UserRequires = "NONE" + + Subscriber UserRequires = "SUBSCRIBER" + + VIP UserRequires = "VIP" + + Mod UserRequires = "MOD" + + Ambassador UserRequires = "AMBASSADOR" + Broadcaster UserRequires = "BROADCASTER" ) // Flags represents a map of flags, where each flag is identified by a string key and can hold any type of value. + type Flags map[string]any // CommandConditions represents the conditions under which a command can be executed. + type CommandConditions struct { - Ryan *bool `json:"ryan,omitempty"` - OfflineOnly *bool `json:"offlineOnly,omitempty"` - Whisperable *bool `json:"whisperable,omitempty"` - IgnoreBots *bool `json:"ignoreBots,omitempty"` - IsBlockable *bool `json:"isBlockable,omitempty"` + Ryan *bool `json:"ryan,omitempty"` + + OfflineOnly *bool `json:"offlineOnly,omitempty"` + + Whisperable *bool `json:"whisperable,omitempty"` + + IgnoreBots *bool `json:"ignoreBots,omitempty"` + + IsBlockable *bool `json:"isBlockable,omitempty"` + IsNotPipable *bool `json:"isNotPipable,omitempty"` } // BotCommandRequirements represents the requirements for a bot to execute a command,. + type BotCommandRequirements string //nolint:revive + const ( BotNone BotCommandRequirements = "NONE" - BotVIP BotCommandRequirements = "VIP" - BotMod BotCommandRequirements = "MOD" + + BotVIP BotCommandRequirements = "VIP" + + BotMod BotCommandRequirements = "MOD" ) + +// ChannelListItem is a lightweight channel representation used for list endpoints (e.g. nav search). + +type ChannelListItem struct { + ChannelID string `json:"channel_id"` + + Username string `json:"username"` + + Platform Platforms `json:"platform"` + + State string `json:"state"` +} + +// EmoteStat represents an aggregated emote usage entry from Clickhouse. + +type EmoteStat struct { + EmoteID string `json:"emote_id"` + + EmoteName string `json:"emote_name"` + + EmoteAlias string `json:"emote_alias"` + + Provider string `json:"provider"` + + Count int64 `json:"count"` +} + +// EmoteHistoryEntry represents a single per-user emote usage record from Clickhouse. + +type EmoteHistoryEntry struct { //nolint:govet + + EmoteID string `json:"emote_id"` + + EmoteName string `json:"emote_name"` + + EmoteAlias string `json:"emote_alias"` + + Provider string `json:"provider"` + + ChannelID string `json:"channel_id"` + + UserID string `json:"user_id"` + + Count int64 `json:"count"` + + UsedAt time.Time `json:"used_at"` +} + +// PageInfo contains cursor-based pagination metadata. + +type PageInfo struct { //nolint:govet + + HasNextPage bool `json:"hasNextPage"` + + Cursor string `json:"cursor"` +} + +// EmoteStatsResponse is the paginated response shape for GET /emotes/stats. + +type EmoteStatsResponse struct { + Data *[]EmoteStat `json:"data"` + + Pagination PageInfo `json:"pagination"` + + StatusCode int `json:"statusCode"` + + Duration float64 `json:"duration"` +} + +// Reminder represents a scheduled or on-next-seen reminder message. + +type Reminder struct { + SetAt time.Time `json:"set_at"` + + SentAt *time.Time `json:"sent_at,omitempty"` + + ReadyAt *time.Time `json:"ready_at,omitempty"` + + UserID string `json:"user_id"` + + RecipientID string `json:"recipient_id"` + + ChannelID string `json:"channel_id"` + + Message string `json:"message"` + + Status string `json:"status"` + + Platform string `json:"platform"` + + Type string `json:"type"` + + ReminderID int `json:"reminder_id"` + + AfkWithheld bool `json:"afk_withheld"` +} diff --git a/common/utils/broker.go b/common/utils/broker.go index 7b4304d..ad3d805 100644 --- a/common/utils/broker.go +++ b/common/utils/broker.go @@ -139,16 +139,15 @@ func (n *NatsClient) handleMessage(message *nats.Msg) { } // BridgeRequest sends a request to the NATS server and waits for a response. -func BridgeRequest( +func (n *NatsClient) BridgeRequest( ttl time.Duration, request string, ) ([]byte, error) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - return nil, fmt.Errorf("failed to publish request: %w", err) + if n == nil || n.Client == nil { + return nil, errNatsNotConnected } - response, err := nc.Request( + response, err := n.Client.Request( "github.com/Potat-Industries/potat-api.job-request", []byte(request), ttl, diff --git a/exampleconfig.json b/exampleconfig.json index 6abf843..47319cb 100644 --- a/exampleconfig.json +++ b/exampleconfig.json @@ -4,6 +4,56 @@ "client_secret": "asdf1234", "oauth_uri": "https://api.potat.industries" }, + "discord": { + "id": "", + "oauth": "", + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "spotify": { + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "kick": { + "id": "", + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/", + "oauth": "" + }, + "anilist": { + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "trakt": { + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "ffz": { + "id": "", + "client_id": "", + "client_secret": "", + "token": "", + "refresh": "", + "oauth_uri": "https://your-domain.com/" + }, + "stv": { + "token": "", + "id": "" + }, + "bttv": { + "token": "", + "id": "" + }, + "misc": { + "steam": { + "key": "" + } + }, "loops": { "enabled": true }, diff --git a/main.go b/main.go index 9e293a6..a77fdf0 100644 --- a/main.go +++ b/main.go @@ -15,8 +15,11 @@ import ( "time" "github.com/Potat-Industries/potat-api/api" + _ "github.com/Potat-Industries/potat-api/api/routes/del" _ "github.com/Potat-Industries/potat-api/api/routes/get" + _ "github.com/Potat-Industries/potat-api/api/routes/patch" _ "github.com/Potat-Industries/potat-api/api/routes/post" + _ "github.com/Potat-Industries/potat-api/api/routes/put" "github.com/Potat-Industries/potat-api/common" "github.com/Potat-Industries/potat-api/common/db" "github.com/Potat-Industries/potat-api/common/logger" @@ -107,7 +110,7 @@ func main() { //nolint:cyclop apiChan := make(chan error) if config.API.Enabled { go func() { - apiChan <- api.StartServing(*config, postgres, redis, clickhouse, metrics) + apiChan <- api.StartServing(*config, postgres, redis, clickhouse, nats, metrics) }() } @@ -220,8 +223,9 @@ func initClickhouse(ctx context.Context, config common.Config) *db.ClickhouseCli func initNats(ctx context.Context) *utils.NatsClient { nats, err := utils.CreateNatsBroker(ctx) if err != nil { - logger.Error.Panicf("Failed to connect to RabbitMQ: %v", err) + logger.Error.Panicf("Failed to connect to NATS: %v", err) } + logger.Info.Println("NATS initialized") return nats } diff --git a/redirects/redirects_test.go b/redirects/redirects_test.go index c2c40f8..4003c01 100644 --- a/redirects/redirects_test.go +++ b/redirects/redirects_test.go @@ -11,21 +11,28 @@ func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { redirector := redirects{} tests := []struct { - input string + input string + expected string }{ {"https://google.com", "https://google.com"}, + {"http://google.com", "https://google.com"}, + {"//google.com", "https://google.com"}, + {"google.com", "https://google.com"}, } for _, tc := range tests { t.Run(tc.input, func(t *testing.T) { cleanedURL := redirector.cleanRedirectProtocolSoLinksActuallyWork(tc.input) + assert.Truef( + t, strings.HasPrefix(cleanedURL, "https://"), "Expected cleaned URL to start with 'https://', got %q", cleanedURL, ) + assert.Equal(t, tc.expected, cleanedURL) }) } diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..9b87f51 --- /dev/null +++ b/schema.sql @@ -0,0 +1,259 @@ +-- potat-api schema +-- Run against the database before starting the API: +-- docker exec -i potat-api-postgres-1 psql -U potat -d potat < schema.sql +-- +-- NOTE: the haste service requires the pgzstd Postgres extension +-- (https://github.com/grahamedgecombe/pgzstd). The standard postgres:16 +-- Docker image does not include it, so leave haste.enabled=false in +-- config.docker.json unless you build a custom image with the extension. + +-- --------------------------------------------------------------------------- +-- Core user tables (required for /login and all user-facing routes) +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS users ( + user_id SERIAL PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + display TEXT NOT NULL DEFAULT '', + first_seen TIMESTAMPTZ NOT NULL DEFAULT NOW(), + level INT NOT NULL DEFAULT 1, + settings JSONB NOT NULL DEFAULT '{}' +); + +CREATE TABLE IF NOT EXISTS user_connections ( + user_id INT NOT NULL REFERENCES users(user_id) ON DELETE CASCADE, + platform_id TEXT NOT NULL, + platform_username TEXT NOT NULL DEFAULT '', + platform_display TEXT NOT NULL DEFAULT '', + platform_pfp TEXT NOT NULL DEFAULT '', + platform TEXT NOT NULL, + platform_metadata JSONB NOT NULL DEFAULT '{}', + PRIMARY KEY (user_id, platform, platform_id) +); + +CREATE TABLE IF NOT EXISTS connection_oauth ( + platform_id TEXT NOT NULL, + platform TEXT NOT NULL, + access_token TEXT NOT NULL, + refresh_token TEXT NOT NULL DEFAULT '', + scope TEXT[] NOT NULL DEFAULT '{}', + expires_in INT NOT NULL DEFAULT 0, + added_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (platform_id, platform) +); + +-- --------------------------------------------------------------------------- +-- Channel tables +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS channels ( + channel_id TEXT NOT NULL, + username TEXT NOT NULL, + joined_at TIMESTAMPTZ, + added_by JSONB NOT NULL DEFAULT '[]', + platform TEXT NOT NULL DEFAULT 'TWITCH', + settings JSONB NOT NULL DEFAULT '{}', + editors TEXT[] NOT NULL DEFAULT '{}', + ambassadors TEXT[] NOT NULL DEFAULT '{}', + meta JSONB NOT NULL DEFAULT '{}', + state TEXT NOT NULL DEFAULT 'JOINED', + PRIMARY KEY (channel_id, platform) +); + +CREATE TABLE IF NOT EXISTS blocks ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL, + block_id INT NOT NULL DEFAULT 0, + channel_id TEXT NOT NULL, + block_type TEXT NOT NULL DEFAULT 'USER', + block_data TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE IF NOT EXISTS custom_channel_commands ( + command_id SERIAL PRIMARY KEY, + user_id INT NOT NULL, + channel_id TEXT NOT NULL, + name TEXT, + user_trigger_ids TEXT[] NOT NULL DEFAULT '{}', + user_ignore_ids TEXT[] NOT NULL DEFAULT '{}', + trigger TEXT NOT NULL, + response TEXT NOT NULL DEFAULT '', + run_command TEXT, + active BOOLEAN NOT NULL DEFAULT TRUE, + active_online BOOLEAN NOT NULL DEFAULT TRUE, + active_offline BOOLEAN NOT NULL DEFAULT TRUE, + reply BOOLEAN NOT NULL DEFAULT FALSE, + whisper BOOLEAN NOT NULL DEFAULT FALSE, + announce BOOLEAN NOT NULL DEFAULT FALSE, + cooldown INT NOT NULL DEFAULT 5, + delay INT NOT NULL DEFAULT 0, + use_count INT NOT NULL DEFAULT 0, + created TIMESTAMPTZ NOT NULL DEFAULT NOW(), + modified TIMESTAMPTZ NOT NULL DEFAULT NOW(), + platform TEXT NOT NULL DEFAULT 'TWITCH', + help TEXT +); + +CREATE TABLE IF NOT EXISTS command_settings ( + channel_id TEXT NOT NULL, + command TEXT NOT NULL, + permission TEXT NOT NULL DEFAULT 'NONE', + users_blacklisted TEXT[] NOT NULL DEFAULT '{}', + users_whitelisted TEXT[] NOT NULL DEFAULT '{}', + custom_cooldown INT NOT NULL DEFAULT 0, + channel_usage INT NOT NULL DEFAULT 0, + is_enabled BOOLEAN NOT NULL DEFAULT TRUE, + offline_only BOOLEAN NOT NULL DEFAULT FALSE, + silent_errors BOOLEAN NOT NULL DEFAULT FALSE, + allow_bots BOOLEAN NOT NULL DEFAULT FALSE, + platform TEXT NOT NULL DEFAULT 'TWITCH', + PRIMARY KEY (channel_id, command, platform) +); + +CREATE TABLE IF NOT EXISTS channel_command_usage ( + channel_id TEXT PRIMARY KEY, + channel_usage BIGINT NOT NULL DEFAULT 0 +); + +-- --------------------------------------------------------------------------- +-- Potato game tables +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS potatoes ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + potato_count INT NOT NULL DEFAULT 0, + potato_prestige INT NOT NULL DEFAULT 0, + potato_rank INT NOT NULL DEFAULT 0, + tax_multiplier INT NOT NULL DEFAULT 0, + first_seen TEXT NOT NULL DEFAULT '', + stole_from TEXT, + stole_amount INT, + trampled_by TEXT +); + +CREATE TABLE IF NOT EXISTS potato_analytics ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + average_response_time TEXT NOT NULL DEFAULT '0', + eat_count INT NOT NULL DEFAULT 0, + harvest_count INT NOT NULL DEFAULT 0, + stolen_count INT NOT NULL DEFAULT 0, + theft_count INT NOT NULL DEFAULT 0, + trampled_count INT NOT NULL DEFAULT 0, + trample_count INT NOT NULL DEFAULT 0, + cdr_count INT NOT NULL DEFAULT 0, + quiz_count INT NOT NULL DEFAULT 0, + quiz_complete_count INT NOT NULL DEFAULT 0, + guard_buy_count INT NOT NULL DEFAULT 0, + fertilizer_buy_count INT NOT NULL DEFAULT 0, + cdr_buy_count INT NOT NULL DEFAULT 0, + new_quiz_buy_count INT NOT NULL DEFAULT 0, + gamble_win_count INT NOT NULL DEFAULT 0, + gamble_loss_count INT NOT NULL DEFAULT 0, + gamble_wins_total INT NOT NULL DEFAULT 0, + gamble_losses_total INT NOT NULL DEFAULT 0, + duel_win_count INT NOT NULL DEFAULT 0, + duel_loss_count INT NOT NULL DEFAULT 0, + duel_wins_amount INT NOT NULL DEFAULT 0, + duel_losses_amount INT NOT NULL DEFAULT 0, + duel_caught_losses INT NOT NULL DEFAULT 0, + average_response_count INT NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS potato_settings ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + not_verbose BOOLEAN NOT NULL DEFAULT FALSE +); + +-- --------------------------------------------------------------------------- +-- GPT usage table (reset by hourly/daily/weekly cron loops) +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS gpt_usage ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + hourly_usage INT NOT NULL DEFAULT 0, + daily_usage INT NOT NULL DEFAULT 0, + weekly_usage INT NOT NULL DEFAULT 0 +); + +-- --------------------------------------------------------------------------- +-- Service tables (auto-created on startup when the service is enabled, +-- included here so they can be pre-created for a fresh database) +-- --------------------------------------------------------------------------- + +-- Requires pgzstd extension — omit if not using the haste service. +-- CREATE EXTENSION IF NOT EXISTS pgzstd; +-- CREATE TABLE IF NOT EXISTS haste ( +-- key CHAR(32) UNIQUE NOT NULL, +-- content BYTEA NOT NULL, +-- access_count INT NOT NULL DEFAULT 1, +-- source TEXT NOT NULL DEFAULT 'potatbotat', +-- timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW() +-- ); + +CREATE TABLE IF NOT EXISTS url_redirects ( + key VARCHAR(9) PRIMARY KEY, + url VARCHAR(500) NOT NULL +); + +CREATE TABLE IF NOT EXISTS file_store ( + key VARCHAR(50) PRIMARY KEY, + file BYTEA NOT NULL, + file_name VARCHAR(50), + mime_type VARCHAR(50) NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- --------------------------------------------------------------------------- +-- command_settings — add nullable columns and new fields +-- (ALTER TABLE is idempotent via IF NOT EXISTS) +-- --------------------------------------------------------------------------- + +ALTER TABLE command_settings + ALTER COLUMN permission DROP NOT NULL, + ALTER COLUMN offline_only DROP NOT NULL, + ALTER COLUMN custom_cooldown DROP NOT NULL, + ALTER COLUMN users_blacklisted DROP NOT NULL, + ALTER COLUMN users_whitelisted DROP NOT NULL, + ALTER COLUMN allow_bots DROP NOT NULL; + +ALTER TABLE command_settings + ADD COLUMN IF NOT EXISTS platform TEXT NOT NULL DEFAULT 'TWITCH', + ADD COLUMN IF NOT EXISTS ambassador_granted BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE command_settings DROP CONSTRAINT IF EXISTS command_settings_pkey; +ALTER TABLE command_settings ADD PRIMARY KEY (channel_id, command, platform); + +ALTER TABLE channels DROP CONSTRAINT IF EXISTS channels_pkey; +ALTER TABLE channels ADD PRIMARY KEY (channel_id, platform); + +-- --------------------------------------------------------------------------- +-- Reminders +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS reminders ( + reminder_id SERIAL PRIMARY KEY, + user_id VARCHAR(64) NOT NULL, + recipient_id VARCHAR(64) NOT NULL, + channel_id VARCHAR(64) NOT NULL, + message VARCHAR(500) NOT NULL DEFAULT '', + ready_at TIMESTAMPTZ, + set_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + afk_withheld BOOLEAN NOT NULL DEFAULT FALSE, + status TEXT NOT NULL DEFAULT 'PENDING', + platform TEXT NOT NULL DEFAULT 'TWITCH', + sent_at TIMESTAMPTZ, + type TEXT NOT NULL DEFAULT 'NEXT' +); + +CREATE INDEX IF NOT EXISTS idx_reminders_recipient ON reminders (recipient_id, platform, status); + +-- --------------------------------------------------------------------------- +-- Website JWT sessions +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS website_jwt ( + token TEXT NOT NULL, + user_id INT NOT NULL REFERENCES users(user_id) ON DELETE CASCADE, + PRIMARY KEY (token) +);