From d760ca48c438b5744775843c74dc4a59d925c80d Mon Sep 17 00:00:00 2001 From: tlos Date: Sat, 14 Mar 2026 00:02:25 -0400 Subject: [PATCH 01/20] feat: add platform config structs --- common/config.go | 110 +++++++++++++++++++++++++++++++++------ common/types.go | 126 ++++++++++++++++++++++++++++++++++++--------- exampleconfig.json | 50 ++++++++++++++++++ 3 files changed, 246 insertions(+), 40 deletions(-) diff --git a/common/config.go b/common/config.go index 28b485b..85fbfd3 100644 --- a/common/config.go +++ b/common/config.go @@ -3,27 +3,107 @@ package common // Config holds the configuration for the application, including database and service settings. type Config struct { - Postgres SQLConfig `json:"postgres"` - Clickhouse SQLConfig `json:"clickhouse"` - Twitch TwitchConfig `json:"twitch"` - Redis RedisConfig `json:"redis"` - API APIConfig `json:"api"` - Socket APIConfig `json:"socket"` - Redirects APIConfig `json:"redirects"` - Uploader APIConfig `json:"uploader"` - Prometheus APIConfig `json:"prometheus"` - Haste HasteConfig `json:"haste"` - Nats BoolConfig `json:"nats"` - Loops BoolConfig `json:"loops"` + Postgres SQLConfig `json:"postgres"` + Clickhouse SQLConfig `json:"clickhouse"` + Twitch TwitchConfig `json:"twitch"` + Discord DiscordConfig `json:"discord"` + Spotify SpotifyConfig `json:"spotify"` + Kick KickConfig `json:"kick"` + Anilist AnilistConfig `json:"anilist"` + Trakt TraktConfig `json:"trakt"` + FFZ FFZConfig `json:"ffz"` + STV STVConfig `json:"stv"` + BTTV BTTVConfig `json:"bttv"` + Misc MiscConfig `json:"misc"` + Redis RedisConfig `json:"redis"` + API APIConfig `json:"api"` + Socket APIConfig `json:"socket"` + Redirects APIConfig `json:"redirects"` + Uploader APIConfig `json:"uploader"` + Prometheus APIConfig `json:"prometheus"` + Haste HasteConfig `json:"haste"` + Nats BoolConfig `json:"nats"` + Loops BoolConfig `json:"loops"` } // TwitchConfig holds the configuration for Twitch API integration. type TwitchConfig struct { ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` + ClientSecret string `json:"client_secret"` //nolint:gosec OauthURI string `json:"oauth_uri"` } +// DiscordConfig holds the configuration for Discord OAuth and bot integration. +type DiscordConfig struct { + ID string `json:"id"` + OAuth string `json:"oauth"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// SpotifyConfig holds the configuration for Spotify OAuth integration. +type SpotifyConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// KickConfig holds the configuration for Kick OAuth integration (PKCE flow). +type KickConfig struct { + ID string `json:"id"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` + OAuth string `json:"oauth"` +} + +// AnilistConfig holds the configuration for AniList OAuth integration. +type AnilistConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// TraktConfig holds the configuration for Trakt OAuth integration. +type TraktConfig struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + OAuthURI string `json:"oauth_uri"` +} + +// FFZConfig holds the configuration for FrankerFaceZ OAuth integration. +type FFZConfig struct { + ID string `json:"id"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` //nolint:gosec + Token string `json:"token"` + Refresh string `json:"refresh"` + OAuthURI string `json:"oauth_uri"` +} + +// STVConfig holds the configuration for 7TV bot token. +type STVConfig struct { + Token string `json:"token"` + ID string `json:"id"` +} + +// BTTVConfig holds the configuration for BTTV bot token. +type BTTVConfig struct { + Token string `json:"token"` + ID string `json:"id"` +} + +// SteamConfig holds the Steam Web API key. +type SteamConfig struct { + Key string `json:"key"` +} + +// MiscConfig holds miscellaneous third-party service configurations. +type MiscConfig struct { + Steam SteamConfig `json:"steam"` +} + // BoolConfig holds the configuration for the simple enablable service. type BoolConfig struct { Enabled bool `json:"enabled"` @@ -33,7 +113,7 @@ type BoolConfig struct { type APIConfig struct { Host string `json:"host"` Port string `json:"port"` - AuthKey string `json:"authkey,omitempty"` + AuthKey string `json:"authkey,omitempty"` //nolint:gosec Enabled bool `json:"enabled"` } @@ -51,7 +131,7 @@ type SQLConfig struct { Host string `json:"host"` Port string `json:"port"` User string `json:"user"` - Password string `json:"password"` + Password string `json:"password"` //nolint:gosec Database string `json:"database"` } diff --git a/common/types.go b/common/types.go index 2dfbb42..3a25506 100644 --- a/common/types.go +++ b/common/types.go @@ -15,6 +15,12 @@ const ( DISCORD Platforms = "DISCORD" KICK Platforms = "KICK" STV Platforms = "STV" + SPOTIFY Platforms = "SPOTIFY" + ANILIST Platforms = "ANILIST" + TRAKT Platforms = "TRAKT" + FFZ Platforms = "FFZ" + BTTV Platforms = "BTTV" + STEAM Platforms = "STEAM" ) // PermissionLevel represents the permission level of a user interally with the api and bot. @@ -147,36 +153,47 @@ type AddedByData struct { // ChannelSettings represents the settings for a channel, including bot settings, cooldowns, language, // permission levels, and other configurations. type ChannelSettings struct { - PajBot *string `json:"paj_bot,omitempty"` - UserCooldown *int `json:"user_cooldown,omitempty"` - ChannelCooldown *int `json:"channel_cooldown,omitempty"` - Language string `json:"language"` - Permission string `json:"permission"` - Prefix string `json:"prefix"` - UsersBlacklisted []string `json:"users_blacklisted"` - NoReply bool `json:"no_reply"` - FirstMsgResponses bool `json:"first_msg_responses"` - WhisperOnly bool `json:"whisper_only"` - OfflineOnly bool `json:"offline_only"` - ForceLanguage bool `json:"force_language"` - SilentErrors bool `json:"silent_errors"` - ColorResponses bool `json:"color_responses"` + PajBot *string `json:"paj_bot,omitempty"` + UserCooldown *int `json:"user_cooldown,omitempty"` + ChannelCooldown *int `json:"channel_cooldown,omitempty"` + OnlinePermission *string `json:"online_permission,omitempty"` + EmoteStreakResponse *string `json:"emote_streak_response,omitempty"` + PyramidResponse *string `json:"pyramid_response,omitempty"` + OnlineSilentErrors *bool `json:"online_silent_errors,omitempty"` + OnlineWhisperOnly *bool `json:"online_whisper_only,omitempty"` + AllowBotEmoteTracking *bool `json:"allow_bot_emote_tracking,omitempty"` + IgnoreDropped *bool `json:"ignore_dropped,omitempty"` + NoLinks *bool `json:"no_links,omitempty"` + ForcePyramidNotVerbose *bool `json:"force_potato_not_verbose,omitempty"` + Language string `json:"language"` + Permission string `json:"permission"` + Prefix string `json:"prefix"` + UsersBlacklisted []string `json:"users_blacklisted"` + NoReply bool `json:"no_reply"` + FirstMsgResponses bool `json:"first_msg_responses"` + WhisperOnly bool `json:"whisper_only"` + OfflineOnly bool `json:"offline_only"` + ForceLanguage bool `json:"force_language"` + SilentErrors bool `json:"silent_errors"` + ColorResponses bool `json:"color_responses"` } // CommandSettings represents the settings for a command in a channel, including its permissions, cooldowns, // and usage limits. type CommandSettings struct { - ChannelID string `json:"channel_id"` - Command string `json:"command"` - Permission string `json:"permission"` - UsersBlacklisted []string `json:"users_blacklisted"` - UsersWhitelisted []string `json:"users_whitelisted"` - CustomCooldown int `json:"custom_cooldown"` - ChannelUsage int `json:"channel_usage"` - IsEnabled bool `json:"is_enabled"` - OfflineOnly bool `json:"offline_only"` - SilentErrors bool `json:"silent_errors"` - AllowBots bool `json:"allow_bots"` + ChannelID string `json:"channel_id"` + Command string `json:"command"` + Permission *string `json:"permission,omitempty"` + UsersBlacklisted []string `json:"users_blacklisted,omitempty"` + UsersWhitelisted []string `json:"users_whitelisted,omitempty"` + CustomCooldown *int `json:"custom_cooldown,omitempty"` + ChannelUsage int `json:"channel_usage"` + IsEnabled bool `json:"is_enabled"` + OfflineOnly *bool `json:"offline_only,omitempty"` + SilentErrors bool `json:"silent_errors"` + AllowBots *bool `json:"allow_bots,omitempty"` + Platform string `json:"platform"` + AmbassadorGranted bool `json:"ambassador_granted"` } // PlatformOauth represents the OAuth token and metadata for a user on a specific platform, including. @@ -400,3 +417,62 @@ const ( BotVIP BotCommandRequirements = "VIP" BotMod BotCommandRequirements = "MOD" ) + +// ChannelListItem is a lightweight channel representation used for list endpoints (e.g. nav search). +type ChannelListItem struct { + ChannelID string `json:"channel_id"` + Username string `json:"username"` + Platform Platforms `json:"platform"` + State string `json:"state"` +} + +// EmoteStat represents an aggregated emote usage entry from Clickhouse. +type EmoteStat struct { + EmoteID string `json:"emote_id"` + EmoteName string `json:"emote_name"` + EmoteAlias string `json:"emote_alias"` + Provider string `json:"provider"` + Count int64 `json:"count"` +} + +// EmoteHistoryEntry represents a single per-user emote usage record from Clickhouse. +type EmoteHistoryEntry struct { + EmoteID string `json:"emote_id"` + EmoteName string `json:"emote_name"` + EmoteAlias string `json:"emote_alias"` + Provider string `json:"provider"` + ChannelID string `json:"channel_id"` + UserID string `json:"user_id"` + Count int64 `json:"count"` + UsedAt time.Time `json:"used_at"` +} + +// PageInfo contains cursor-based pagination metadata. +type PageInfo struct { + HasNextPage bool `json:"hasNextPage"` + Cursor string `json:"cursor"` +} + +// EmoteStatsResponse is the paginated response shape for GET /emotes/stats. +type EmoteStatsResponse struct { + Data *[]EmoteStat `json:"data"` + Pagination PageInfo `json:"pagination"` + StatusCode int `json:"statusCode"` + Duration float64 `json:"duration"` +} + +// Reminder represents a scheduled or on-next-seen reminder message. +type Reminder struct { + SetAt time.Time `json:"set_at"` + SentAt *time.Time `json:"sent_at,omitempty"` + ReadyAt *time.Time `json:"ready_at,omitempty"` + UserID string `json:"user_id"` + RecipientID string `json:"recipient_id"` + ChannelID string `json:"channel_id"` + Message string `json:"message"` + Status string `json:"status"` + Platform string `json:"platform"` + Type string `json:"type"` + ReminderID int `json:"reminder_id"` + AfkWithheld bool `json:"afk_withheld"` +} diff --git a/exampleconfig.json b/exampleconfig.json index 6abf843..47319cb 100644 --- a/exampleconfig.json +++ b/exampleconfig.json @@ -4,6 +4,56 @@ "client_secret": "asdf1234", "oauth_uri": "https://api.potat.industries" }, + "discord": { + "id": "", + "oauth": "", + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "spotify": { + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "kick": { + "id": "", + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/", + "oauth": "" + }, + "anilist": { + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "trakt": { + "client_id": "", + "client_secret": "", + "oauth_uri": "https://your-domain.com/" + }, + "ffz": { + "id": "", + "client_id": "", + "client_secret": "", + "token": "", + "refresh": "", + "oauth_uri": "https://your-domain.com/" + }, + "stv": { + "token": "", + "id": "" + }, + "bttv": { + "token": "", + "id": "" + }, + "misc": { + "steam": { + "key": "" + } + }, "loops": { "enabled": true }, From 96d9d29a473545fa2ad67760181f5694c5fcc147 Mon Sep 17 00:00:00 2001 From: tlos Date: Sun, 15 Mar 2026 02:05:09 -0400 Subject: [PATCH 02/20] feat: add data routes and db methods --- api/routes/get/channels.go | 57 +++++++ api/routes/get/command_settings.go | 124 ++++++++++++++ api/routes/get/emotes.go | 263 +++++++++++++++++++++++++++++ common/db/clickhouse.go | 155 +++++++++++++++++ common/db/postgres.go | 251 +++++++++++++++++++++++++++ 5 files changed, 850 insertions(+) create mode 100644 api/routes/get/channels.go create mode 100644 api/routes/get/command_settings.go create mode 100644 api/routes/get/emotes.go diff --git a/api/routes/get/channels.go b/api/routes/get/channels.go new file mode 100644 index 0000000..17f94f3 --- /dev/null +++ b/api/routes/get/channels.go @@ -0,0 +1,57 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +// ChannelsResponse is the response type for the GET /channels endpoint. +type ChannelsResponse = common.GenericResponse[common.ChannelListItem] + +func init() { + api.SetRoute(api.Route{ + Path: "/channels", + Method: http.MethodGet, + Handler: getChannelsHandler, + UseAuth: false, + }) +} + +func getChannelsHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + postgres, ok := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !ok { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, ChannelsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channels, err := postgres.GetAllChannels(request.Context()) + if err != nil { + logger.Error.Printf("Error fetching channels: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, ChannelsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch channels"}}, + }, start) + + return + } + + if channels == nil { + channels = []common.ChannelListItem{} + } + + api.GenericResponse(writer, http.StatusOK, ChannelsResponse{ + Data: &channels, + }, start) +} diff --git a/api/routes/get/command_settings.go b/api/routes/get/command_settings.go new file mode 100644 index 0000000..3b99c58 --- /dev/null +++ b/api/routes/get/command_settings.go @@ -0,0 +1,124 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "net/http" + "slices" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" + "github.com/gorilla/mux" +) + +// CommandSettingsResponse is the response type for GET /channel/command-settings. +type CommandSettingsResponse = common.GenericResponse[common.CommandSettings] + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/command-settings", + Method: http.MethodGet, + Handler: getCommandSettingsHandler, + UseAuth: true, + }) +} + +// isChannelAuthorized returns true if the user is admin, the broadcaster, or a channel ambassador. +func isChannelAuthorized( + request *http.Request, + user *common.User, + channelID string, + postgres *db.PostgresClient, +) bool { + if int(common.ADMIN) <= user.Level { + return true + } + + twitchID := getTwitchPlatformID(user) + + if twitchID == channelID { + return true + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + return false + } + + return slices.Contains(ambassadors, twitchID) +} + +// resolveChannelID returns the channel ID from ?id= or defaults to the authenticated user's Twitch ID. +func resolveChannelID(request *http.Request, user *common.User) string { + if id := request.URL.Query().Get("id"); id != "" { + return id + } + + // Check path variable as fallback (e.g. /channel/:id/...) + if id := mux.Vars(request)["id"]; id != "" { + return id + } + + return getTwitchPlatformID(user) +} + +func getCommandSettingsHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channelID := resolveChannelID(request, user) + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "channel id is required"}}, + }, start) + + return + } + + if !isChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + settings, err := postgres.GetCommandSettings(request.Context(), channelID) + if err != nil { + logger.Error.Printf("Error fetching command settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, CommandSettingsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch command settings"}}, + }, start) + + return + } + + if settings == nil { + settings = []common.CommandSettings{} + } + + api.GenericResponse(writer, http.StatusOK, CommandSettingsResponse{ + Data: &settings, + }, start) +} diff --git a/api/routes/get/emotes.go b/api/routes/get/emotes.go new file mode 100644 index 0000000..c17b79f --- /dev/null +++ b/api/routes/get/emotes.go @@ -0,0 +1,263 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "encoding/base64" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" + "github.com/gorilla/mux" +) + +const ( + defaultEmoteLimit = 100 + maxEmoteLimit = 300 +) + +func init() { + api.SetRoute(api.Route{ + Path: "/emotes/stats", + Method: http.MethodGet, + Handler: getEmoteStats, + UseAuth: false, + }) + api.SetRoute(api.Route{ + Path: "/emotes/history/{login}", + Method: http.MethodGet, + Handler: getEmoteHistory, + UseAuth: false, + }) +} + +// periodToHours converts a period string to a number of hours. 0 means no time filter. +func periodToHours(period string) int { + switch strings.ToLower(period) { + case "hour": + return 1 + case "day": + return 24 + case "week": + return 168 + case "month": + return 720 + default: + return 0 + } +} + +// normaliseProvider maps user-facing provider names to Clickhouse enum values. +func normaliseProvider(provider string) []string { + switch strings.ToUpper(provider) { + case "7TV", "STV": + return []string{"STV"} + case "FFZ": + return []string{"FFZ"} + case "BTTV": + return []string{"BTTV"} + case "TWITCH": + return []string{"TWITCH"} + case "EMOJI": + return []string{"EMOJI"} + case "ALL", "": + return []string{} + default: + return []string{"STV"} + } +} + +func encodeCursor(offset int) string { + return base64.StdEncoding.EncodeToString([]byte(strconv.Itoa(offset))) +} + +func decodeCursor(cursor string) (int, error) { + b, err := base64.StdEncoding.DecodeString(cursor) + if err != nil { + return 0, err + } + + return strconv.Atoi(string(b)) +} + +func getEmoteStats(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + clickhouse, ok := request.Context().Value(middleware.ClickhouseKey).(*db.ClickhouseClient) + if !ok { + logger.Error.Println("Clickhouse client not found in context") + writeEmoteError(writer, http.StatusInternalServerError, start) + + return + } + + query := request.URL.Query() + + channelID := query.Get("id") + login := query.Get("login") + + // If login was provided, resolve to a channel ID via Postgres + if channelID == "" && login != "" { + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + writeEmoteError(writer, http.StatusInternalServerError, start) + + return + } + + ch, err := postgres.GetChannelByName(request.Context(), login, common.TWITCH) + if err != nil { + writeEmoteError(writer, http.StatusNotFound, start) + + return + } + + channelID = ch.ChannelID + } + + if channelID == "" { + writeEmoteError(writer, http.StatusBadRequest, start) + + return + } + + limit := defaultEmoteLimit + if v := query.Get("limit"); v == "" { + if v = query.Get("first"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= maxEmoteLimit { + limit = n + } + } + } else if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= maxEmoteLimit { + limit = n + } + + offset := 0 + if cursor := query.Get("after"); cursor != "" { + if o, err := decodeCursor(cursor); err == nil { + offset = o + } + } + + opts := db.EmoteStatsOptions{ + ChannelID: channelID, + PeriodHours: periodToHours(query.Get("period")), + Providers: normaliseProvider(query.Get("provider")), + Order: query.Get("order"), + Limit: limit + 1, // fetch one extra to determine hasNextPage + Offset: offset, + } + + stats, err := clickhouse.GetEmoteStats(request.Context(), opts) + if err != nil { + logger.Error.Printf("Error fetching emote stats: %v", err) + writeEmoteError(writer, http.StatusInternalServerError, start) + + return + } + + if stats == nil { + stats = []common.EmoteStat{} + } + + hasNextPage := len(stats) > limit + if hasNextPage { + stats = stats[:limit] + } + + var nextCursor string + if hasNextPage { + nextCursor = encodeCursor(offset + limit) + } + + elapsed := time.Since(start).Seconds() + api.GenericResponse(writer, http.StatusOK, common.EmoteStatsResponse{ + Data: &stats, + Pagination: common.PageInfo{ + HasNextPage: hasNextPage, + Cursor: nextCursor, + }, + StatusCode: http.StatusOK, + Duration: elapsed, + }, start) +} + +func getEmoteHistory(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + login := mux.Vars(request)["login"] + if login == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "login is required"}}, + }, start) + + return + } + + clickhouse, ok := request.Context().Value(middleware.ClickhouseKey).(*db.ClickhouseClient) + if !ok { + logger.Error.Println("Clickhouse client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + // Resolve login → channel ID + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + ch, err := postgres.GetChannelByName(request.Context(), login, common.TWITCH) + if err != nil { + api.GenericResponse(writer, http.StatusNotFound, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Channel not found"}}, + }, start) + + return + } + + limit := defaultEmoteLimit + if v := request.URL.Query().Get("limit"); v != "" { + if n, parseErr := strconv.Atoi(v); parseErr == nil && n > 0 && n <= maxEmoteLimit { + limit = n + } + } + + entries, err := clickhouse.GetEmoteHistory(request.Context(), "", ch.ChannelID, limit) + if err != nil { + logger.Error.Printf("Error fetching emote history: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[common.EmoteHistoryEntry]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch emote history"}}, + }, start) + + return + } + + if entries == nil { + entries = []common.EmoteHistoryEntry{} + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[common.EmoteHistoryEntry]{ + Data: &entries, + }, start) +} + +func writeEmoteError(writer http.ResponseWriter, code int, start time.Time) { + api.GenericResponse(writer, code, common.EmoteStatsResponse{ + Data: &[]common.EmoteStat{}, + StatusCode: code, + Duration: time.Since(start).Seconds(), + }, start) +} diff --git a/common/db/clickhouse.go b/common/db/clickhouse.go index 298a14f..3b3ec40 100644 --- a/common/db/clickhouse.go +++ b/common/db/clickhouse.go @@ -2,7 +2,10 @@ package db import ( + "context" "fmt" + "strings" + "time" "github.com/ClickHouse/clickhouse-go/v2" "github.com/ClickHouse/clickhouse-go/v2/lib/driver" @@ -45,3 +48,155 @@ func InitClickhouse(config common.Config) (*ClickhouseClient, error) { return &ClickhouseClient{conn}, nil } + +// EmoteStatsOptions holds the query filters for GetEmoteStats. +type EmoteStatsOptions struct { + // ChannelID filters results to a specific channel (uses emote_usage table when UserID is empty). + ChannelID string + // UserID filters results to a specific user (uses user_emote_usage table). + UserID string + // Order is "ASC" or "DESC". + Order string + // Providers is the normalised list of provider enum values (e.g. "STV", "FFZ"). Empty = all. + Providers []string + // PeriodHours is the time window in hours. 0 means no time filter ("all"). + PeriodHours int + // Limit is the maximum number of rows to return. + Limit int + // Offset is the number of rows to skip (cursor-based pagination). + Offset int +} + +// GetEmoteStats queries aggregated emote usage from Clickhouse with cursor-based pagination. +func (db *ClickhouseClient) GetEmoteStats( //nolint:cyclop + ctx context.Context, + opts EmoteStatsOptions, +) ([]common.EmoteStat, error) { + table := "potatbotat.emote_usage" + if opts.UserID != "" { + table = "potatbotat.user_emote_usage" + } + + var sb strings.Builder + args := make([]any, 0, 8) + + fmt.Fprintf(&sb, + "SELECT emote_id, emote_name, emote_alias, provider, sum(count) AS count FROM %s FINAL WHERE 1=1", + table, + ) + + if opts.ChannelID != "" { + args = append(args, opts.ChannelID) + fmt.Fprintf(&sb, " AND channel_id = $%d", len(args)) + } + + if opts.UserID != "" { + args = append(args, opts.UserID) + fmt.Fprintf(&sb, " AND user_id = $%d", len(args)) + } + + if opts.PeriodHours > 0 { + cutoff := time.Now().Add(-time.Duration(opts.PeriodHours) * time.Hour) + args = append(args, cutoff) + fmt.Fprintf(&sb, " AND used_at >= $%d", len(args)) + } + + if len(opts.Providers) > 0 { + placeholders := make([]string, len(opts.Providers)) + for i, p := range opts.Providers { + args = append(args, p) + placeholders[i] = fmt.Sprintf("$%d", len(args)) + } + fmt.Fprintf(&sb, " AND provider IN (%s)", strings.Join(placeholders, ",")) + } + + sb.WriteString(" GROUP BY emote_id, emote_name, emote_alias, provider") + + order := "DESC" + if strings.EqualFold(opts.Order, "asc") { + order = "ASC" + } + fmt.Fprintf(&sb, " ORDER BY count %s", order) + + limit := opts.Limit + if limit <= 0 || limit > 300 { + limit = 100 + } + fmt.Fprintf(&sb, " LIMIT %d", limit) + + if opts.Offset > 0 { + fmt.Fprintf(&sb, " OFFSET %d", opts.Offset) + } + + rows, err := db.Query(ctx, sb.String(), args...) + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck + + var stats []common.EmoteStat + for rows.Next() { + var s common.EmoteStat + if err := rows.Scan(&s.EmoteID, &s.EmoteName, &s.EmoteAlias, &s.Provider, &s.Count); err != nil { + return nil, err + } + stats = append(stats, s) + } + + return stats, rows.Err() +} + +// GetEmoteHistory queries the most recent per-user emote usage records from Clickhouse. +func (db *ClickhouseClient) GetEmoteHistory( + ctx context.Context, + userID string, + channelID string, + limit int, +) ([]common.EmoteHistoryEntry, error) { + if limit <= 0 || limit > 300 { + limit = 100 + } + + var sb strings.Builder + args := make([]any, 0, 3) + + sb.WriteString(` + SELECT emote_id, emote_name, emote_alias, provider, channel_id, user_id, sum(count) AS count, max(used_at) AS used_at + FROM potatbotat.user_emote_usage FINAL + WHERE 1=1 + `) + + if userID != "" { + args = append(args, userID) + fmt.Fprintf(&sb, " AND user_id = $%d", len(args)) + } + + if channelID != "" { + args = append(args, channelID) + fmt.Fprintf(&sb, " AND channel_id = $%d", len(args)) + } + + sb.WriteString(" GROUP BY emote_id, emote_name, emote_alias, provider, channel_id, user_id") + sb.WriteString(" ORDER BY used_at DESC") + fmt.Fprintf(&sb, " LIMIT %d", limit) + + rows, err := db.Query(ctx, sb.String(), args...) + if err != nil { + return nil, err + } + defer rows.Close() //nolint:errcheck + + var entries []common.EmoteHistoryEntry + for rows.Next() { + var e common.EmoteHistoryEntry + if err := rows.Scan( + &e.EmoteID, &e.EmoteName, &e.EmoteAlias, &e.Provider, + &e.ChannelID, &e.UserID, &e.Count, &e.UsedAt, + ); err != nil { + return nil, err + } + entries = append(entries, e) + } + + return entries, rows.Err() +} diff --git a/common/db/postgres.go b/common/db/postgres.go index da40ff6..8a9d65e 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -399,6 +399,8 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop *channel.Blocks.Users = append(*channel.Blocks.Users, block) case common.CommandBlock: *channel.Blocks.Commands = append(*channel.Blocks.Commands, block) + case common.GlobalBlock: + // global blocks are not categorized into users/commands } } } else { @@ -724,3 +726,252 @@ func (db *PostgresClient) GetUploadCreatedAt( return &createdAt, nil } + +// GetAllChannels retrieves all channels with state = 'JOINED', ordered by username. +func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelListItem, error) { + query := ` + SELECT channel_id, username, platform, state + FROM channels + WHERE state = 'JOINED' + ORDER BY username + ` + + rows, err := db.Pool.Query(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + var channels []common.ChannelListItem + for rows.Next() { + var ch common.ChannelListItem + if err := rows.Scan(&ch.ChannelID, &ch.Username, &ch.Platform, &ch.State); err != nil { + return nil, err + } + channels = append(channels, ch) + } + + return channels, rows.Err() +} + +// UpdateUserSettings replaces the settings JSONB column for a user. +func (db *PostgresClient) UpdateUserSettings(ctx context.Context, userID int, settings common.UserSettings) error { + query := `UPDATE users SET settings = $1 WHERE user_id = $2` + _, err := db.Pool.Exec(ctx, query, settings, userID) + + return err +} + +// UpdateChannelSettings replaces the settings JSONB column for a channel. +func (db *PostgresClient) UpdateChannelSettings( + ctx context.Context, + channelID string, + settings common.ChannelSettings, +) error { + query := `UPDATE channels SET settings = $1 WHERE channel_id = $2` + _, err := db.Pool.Exec(ctx, query, settings, channelID) + + return err +} + +// GetCommandSettings retrieves all command settings rows for a given channel. +func (db *PostgresClient) GetCommandSettings( + ctx context.Context, + channelID string, +) ([]common.CommandSettings, error) { + query := ` + SELECT + channel_id, + command, + permission, + users_blacklisted, + users_whitelisted, + custom_cooldown, + channel_usage, + is_enabled, + offline_only, + silent_errors, + allow_bots, + platform, + ambassador_granted + FROM command_settings + WHERE channel_id = $1 + ORDER BY command + ` + + rows, err := db.Pool.Query(ctx, query, channelID) + if err != nil { + return nil, err + } + defer rows.Close() + + var results []common.CommandSettings + for rows.Next() { + var cs common.CommandSettings + if err := rows.Scan( + &cs.ChannelID, + &cs.Command, + &cs.Permission, + &cs.UsersBlacklisted, + &cs.UsersWhitelisted, + &cs.CustomCooldown, + &cs.ChannelUsage, + &cs.IsEnabled, + &cs.OfflineOnly, + &cs.SilentErrors, + &cs.AllowBots, + &cs.Platform, + &cs.AmbassadorGranted, + ); err != nil { + return nil, err + } + results = append(results, cs) + } + + return results, rows.Err() +} + +// UpsertCommandSettings inserts or updates a single command's settings row. +func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.CommandSettings) error { + query := ` + INSERT INTO command_settings ( + channel_id, command, permission, users_blacklisted, users_whitelisted, + custom_cooldown, is_enabled, offline_only, silent_errors, allow_bots, + platform, ambassador_granted + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (channel_id, command) DO UPDATE SET + permission = EXCLUDED.permission, + users_blacklisted = EXCLUDED.users_blacklisted, + users_whitelisted = EXCLUDED.users_whitelisted, + custom_cooldown = EXCLUDED.custom_cooldown, + is_enabled = EXCLUDED.is_enabled, + offline_only = EXCLUDED.offline_only, + silent_errors = EXCLUDED.silent_errors, + allow_bots = EXCLUDED.allow_bots, + platform = EXCLUDED.platform, + ambassador_granted = EXCLUDED.ambassador_granted + ` + + platform := cs.Platform + if platform == "" { + platform = "TWITCH" + } + + _, err := db.Pool.Exec( + ctx, query, + cs.ChannelID, cs.Command, cs.Permission, + cs.UsersBlacklisted, cs.UsersWhitelisted, + cs.CustomCooldown, cs.IsEnabled, cs.OfflineOnly, + cs.SilentErrors, cs.AllowBots, platform, cs.AmbassadorGranted, + ) + + return err +} + +// ResetCommandSettings resets a single command's overrides back to their defaults. +func (db *PostgresClient) ResetCommandSettings(ctx context.Context, channelID, command string) error { + query := ` + UPDATE command_settings SET + is_enabled = TRUE, + offline_only = NULL, + custom_cooldown = NULL, + silent_errors = FALSE, + users_whitelisted = NULL, + users_blacklisted = NULL, + allow_bots = NULL + WHERE channel_id = $1 AND command = $2 + ` + _, err := db.Pool.Exec(ctx, query, channelID, command) + + return err +} + +// GetChannelAmbassadors returns the ambassadors slice for a channel, used for auth checks. +func (db *PostgresClient) GetChannelAmbassadors( + ctx context.Context, + channelID string, + platform common.Platforms, +) ([]string, error) { + query := `SELECT ambassadors FROM channels WHERE channel_id = $1 AND platform = $2` + + var ambassadors []string + err := db.Pool.QueryRow(ctx, query, channelID, platform).Scan(&ambassadors) + if err != nil { + return nil, err + } + + return ambassadors, nil +} + +// GetUserReminders retrieves all pending reminders for a user on a given platform. +func (db *PostgresClient) GetUserReminders( + ctx context.Context, + userID string, + platform common.Platforms, +) ([]common.Reminder, error) { + query := ` + SELECT + reminder_id, user_id, recipient_id, channel_id, + message, ready_at, set_at, afk_withheld, status, platform, sent_at, type + FROM reminders + WHERE recipient_id = $1 AND platform = $2 + ORDER BY set_at DESC + ` + + rows, err := db.Pool.Query(ctx, query, userID, platform) + if err != nil { + return nil, err + } + defer rows.Close() + + var reminders []common.Reminder + for rows.Next() { + var r common.Reminder + if err := rows.Scan( + &r.ReminderID, &r.UserID, &r.RecipientID, &r.ChannelID, + &r.Message, &r.ReadyAt, &r.SetAt, &r.AfkWithheld, &r.Status, &r.Platform, &r.SentAt, &r.Type, + ); err != nil { + return nil, err + } + reminders = append(reminders, r) + } + + return reminders, rows.Err() +} + +// DeleteReminder hard-deletes a reminder by ID, verifying the owner platform ID. +func (db *PostgresClient) DeleteReminder(ctx context.Context, reminderID int, recipientID string) error { + query := `DELETE FROM reminders WHERE reminder_id = $1 AND recipient_id = $2` + _, err := db.Pool.Exec(ctx, query, reminderID, recipientID) + + return err +} + +// UpsertOAuthToken stores or refreshes a platform OAuth token for a given user. +func (db *PostgresClient) UpsertOAuthToken( + ctx context.Context, + platformID string, + platform common.Platforms, + accessToken string, + refreshToken string, + scope []string, + expiresIn int, +) error { + query := ` + INSERT INTO connection_oauth ( + platform_id, access_token, refresh_token, scope, expires_in, added_at, platform + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (platform_id, platform) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + scope = EXCLUDED.scope, + expires_in = EXCLUDED.expires_in, + added_at = EXCLUDED.added_at + ` + _, err := db.Pool.Exec( + ctx, query, + platformID, accessToken, refreshToken, scope, expiresIn, time.Now(), platform, + ) + + return err +} From a42728ddf00a0202d308d0b7843e2f99befaec83 Mon Sep 17 00:00:00 2001 From: tlos Date: Sun, 15 Mar 2026 22:45:02 -0400 Subject: [PATCH 03/20] feat: add settings and command write routes --- api/routes/del/command_settings.go | 128 +++++++++++++++++++++++ api/routes/patch/settings.go | 161 +++++++++++++++++++++++++++++ api/routes/put/command_settings.go | 133 ++++++++++++++++++++++++ main.go | 3 + 4 files changed, 425 insertions(+) create mode 100644 api/routes/del/command_settings.go create mode 100644 api/routes/patch/settings.go create mode 100644 api/routes/put/command_settings.go diff --git a/api/routes/del/command_settings.go b/api/routes/del/command_settings.go new file mode 100644 index 0000000..785420f --- /dev/null +++ b/api/routes/del/command_settings.go @@ -0,0 +1,128 @@ +// Package del contains routes for http.MethodDelete requests. +package del + +import ( + "net/http" + "slices" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/command-settings", + Method: http.MethodDelete, + Handler: deleteCommandSettings, + UseAuth: true, + }) +} + +func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channelID := request.URL.Query().Get("id") + if channelID == "" { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + channelID = conn.UserID + + break + } + } + } + + command := request.URL.Query().Get("command") + + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "channel id is required"}}, + }, start) + + return + } + + if !isDelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + // Without a specific command name, this is a no-op (not yet implemented upstream). + if command == "" { + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) + + return + } + + if err := postgres.ResetCommandSettings(request.Context(), channelID, command); err != nil { + logger.Error.Printf("Error resetting command settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to reset command settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} + +func isDelAuthorized( + request *http.Request, + user *common.User, + channelID string, + postgres *db.PostgresClient, +) bool { + if user.Level >= int(common.ADMIN) { + return true + } + + var twitchID string + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + twitchID = conn.UserID + + break + } + } + + if twitchID == channelID { + return true + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + return false + } + + return slices.Contains(ambassadors, twitchID) +} diff --git a/api/routes/patch/settings.go b/api/routes/patch/settings.go new file mode 100644 index 0000000..606dd67 --- /dev/null +++ b/api/routes/patch/settings.go @@ -0,0 +1,161 @@ +// Package patch contains routes for http.MethodPatch requests. +package patch + +import ( + "encoding/json" + "net/http" + "slices" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +func init() { + api.SetRoute(api.Route{ + Path: "/users/me/settings", + Method: http.MethodPatch, + Handler: patchUserSettings, + UseAuth: true, + }) + api.SetRoute(api.Route{ + Path: "/channels/me/settings", + Method: http.MethodPatch, + Handler: patchChannelSettings, + UseAuth: true, + }) +} + +func patchUserSettings(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + var input common.UserSettings + if err := json.NewDecoder(request.Body).Decode(&input); err != nil { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + if err := postgres.UpdateUserSettings(request.Context(), user.ID, input); err != nil { + logger.Error.Printf("Error updating user settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to update settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} + +func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + // Resolve channel ID from ?id= or fall back to the user's own Twitch channel ID. + channelID := request.URL.Query().Get("id") + if channelID == "" { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + channelID = conn.UserID + + break + } + } + } + + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "channel id is required"}}, + }, start) + + return + } + + // Only the broadcaster or an ambassador (or admin) may update settings. + var twitchID string + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + twitchID = conn.UserID + + break + } + } + + if twitchID != channelID && user.Level < int(common.ADMIN) { + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil || !slices.Contains(ambassadors, twitchID) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + } + + var input common.ChannelSettings + if err := json.NewDecoder(request.Body).Decode(&input); err != nil { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, + }, start) + + return + } + + if err := postgres.UpdateChannelSettings(request.Context(), channelID, input); err != nil { + logger.Error.Printf("Error updating channel settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to update settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} diff --git a/api/routes/put/command_settings.go b/api/routes/put/command_settings.go new file mode 100644 index 0000000..ea430c2 --- /dev/null +++ b/api/routes/put/command_settings.go @@ -0,0 +1,133 @@ +// Package put contains routes for http.MethodPut requests. +package put + +import ( + "encoding/json" + "net/http" + "slices" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/command-settings", + Method: http.MethodPut, + Handler: putCommandSettings, + UseAuth: true, + }) +} + +func putCommandSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + logger.Error.Println("Postgres client not found in context") + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + var input common.CommandSettings + if err := json.NewDecoder(request.Body).Decode(&input); err != nil { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, + }, start) + + return + } + + channelID := input.ChannelID + if channelID == "" { + channelID = request.URL.Query().Get("id") + } + + if channelID == "" { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + channelID = conn.UserID + + break + } + } + } + + if channelID == "" || input.Command == "" { + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "channel_id and command are required"}}, + }, start) + + return + } + + input.ChannelID = channelID + + if !isCmdAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + if err := postgres.UpsertCommandSettings(request.Context(), input); err != nil { + logger.Error.Printf("Error upserting command settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Failed to update command settings"}}, + }, start) + + return + } + + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ + Data: &[]any{}, + }, start) +} + +func isCmdAuthorized( + request *http.Request, + user *common.User, + channelID string, + postgres *db.PostgresClient, +) bool { + if user.Level >= int(common.ADMIN) { + return true + } + + var twitchID string + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + twitchID = conn.UserID + + break + } + } + + if twitchID == channelID { + return true + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + return false + } + + return slices.Contains(ambassadors, twitchID) +} diff --git a/main.go b/main.go index 9e293a6..36314b7 100644 --- a/main.go +++ b/main.go @@ -15,8 +15,11 @@ import ( "time" "github.com/Potat-Industries/potat-api/api" + _ "github.com/Potat-Industries/potat-api/api/routes/del" _ "github.com/Potat-Industries/potat-api/api/routes/get" + _ "github.com/Potat-Industries/potat-api/api/routes/patch" _ "github.com/Potat-Industries/potat-api/api/routes/post" + _ "github.com/Potat-Industries/potat-api/api/routes/put" "github.com/Potat-Industries/potat-api/common" "github.com/Potat-Industries/potat-api/common/db" "github.com/Potat-Industries/potat-api/common/logger" From 7a334d0ac7948149a7de4ad66b8fc082376c5f8f Mon Sep 17 00:00:00 2001 From: tlos Date: Tue, 17 Mar 2026 00:55:19 -0400 Subject: [PATCH 04/20] feat: add oauth platform auth routes --- api/routes/get/auth.go | 1056 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1056 insertions(+) create mode 100644 api/routes/get/auth.go diff --git a/api/routes/get/auth.go b/api/routes/get/auth.go new file mode 100644 index 0000000..334b8c9 --- /dev/null +++ b/api/routes/get/auth.go @@ -0,0 +1,1056 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" + "github.com/Potat-Industries/potat-api/common/utils" + "github.com/google/uuid" +) + +//nolint:gosec +const ( + discordAuthorizeURL = "https://discord.com/api/oauth2/authorize" + discordTokenURL = "https://discord.com/api/oauth2/token" + discordMeURL = "https://discord.com/api/users/@me" + + spotifyAuthorizeURL = "https://accounts.spotify.com/authorize" + spotifyTokenURL = "https://accounts.spotify.com/api/token" + spotifyMeURL = "https://api.spotify.com/v1/me" + spotifyScopes = "streaming user-read-recently-played user-top-read " + + "user-read-playback-position user-read-playback-state " + + "user-read-currently-playing user-modify-playback-state user-follow-read" + + kickAuthorizeURL = "https://id.kick.com/oauth/authorize" + kickTokenURL = "https://id.kick.com/oauth/token" + kickScopes = "user:read channel:read channel:write chat:write events:subscribe" + + anilistAuthorizeURL = "https://anilist.co/api/v2/oauth/authorize" + anilistTokenURL = "https://anilist.co/api/v2/oauth/token" + anilistGQLURL = "https://graphql.anilist.co" + + traktAuthorizeURL = "https://api.trakt.tv/oauth/authorize" + traktTokenURL = "https://api.trakt.tv/oauth/token" + traktMeURL = "https://api.trakt.tv/users/me" + + ffzAuthorizeURL = "https://api.frankerfacez.com/auth/authorize" + ffzTokenURL = "https://api.frankerfacez.com/auth/token" + ffzScopes = "collection_edit" + + steamOpenIDURL = "https://steamcommunity.com/openid/login" + steamOpenIDNS = "http://specs.openid.net/auth/2.0" + steamFriendOffset = int64(76561197960265728) + + oauthStateTTL = 60 * time.Second + oauthHTTPTimeout = 10 * time.Second +) + +// oauthStates maps state strings to true with a TTL for replay-attack prevention. +var oauthStates sync.Map //nolint:gochecknoglobals + +func newOAuthState(userID int) string { + nonce := uuid.New().String() + state := fmt.Sprintf("%s:%d", nonce, userID) + oauthStates.Store(state, true) + go func(s string) { + time.Sleep(oauthStateTTL) + oauthStates.Delete(s) + }(state) + + return state +} + +// newKickOAuthState creates a state with an embedded PKCE code verifier. +// State format: {nonce}:{userID}:{codeVerifier}. +func newKickOAuthState(userID int) (state, codeVerifier, codeChallenge string) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + codeVerifier = uuid.New().String() + } else { + codeVerifier = hex.EncodeToString(buf) + } + + sum := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(sum[:]) + + nonce := uuid.New().String() + state = fmt.Sprintf("%s:%d:%s", nonce, userID, codeVerifier) + oauthStates.Store(state, true) + go func(s string) { + time.Sleep(oauthStateTTL) + oauthStates.Delete(s) + }(state) + + return state, codeVerifier, codeChallenge +} + +// consumeOAuthState validates a state string, deletes it, and returns the embedded user ID. +func consumeOAuthState(state string) (int, bool) { + if _, ok := oauthStates.LoadAndDelete(state); !ok { + return 0, false + } + + parts := strings.SplitN(state, ":", 3) //nolint:mnd + if len(parts) < 2 { //nolint:mnd + return 0, false + } + + var userID int + if _, err := fmt.Sscanf(parts[1], "%d", &userID); err != nil { + return 0, false + } + + return userID, true +} + +// extractKickCodeVerifier returns the codeVerifier embedded in a Kick state string. +func extractKickCodeVerifier(state string) string { + parts := strings.SplitN(state, ":", 3) //nolint:mnd + if len(parts) < 3 { //nolint:mnd + return "" + } + + return parts[2] +} + +// authUser retrieves the authenticated user from the request context. +func authUser(request *http.Request) (*common.User, bool) { + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + + return user, ok && user != nil +} + +// getTwitchPlatformID finds the Twitch platform ID from a user's connections. +func getTwitchPlatformID(user *common.User) string { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + return conn.UserID + } + } + + return "" +} + +// oauthPostMessage builds the postMessage HTML used to close popups. +func oauthPostMessage(payload map[string]any) string { + data, _ := json.Marshal(payload) //nolint:errchkjson + + return fmt.Sprintf(``, string(data)) +} + +// oauthErrorHTML builds a minimal error postMessage HTML. +func oauthErrorHTML(message string) string { + return oauthPostMessage(map[string]any{"error": message}) +} + +// oauthSuccessHTML builds a minimal success postMessage HTML. +func oauthSuccessHTML(platform string) string { + return oauthPostMessage(map[string]any{"platform": platform, "ok": true}) +} + +// sendHTML writes an HTML response. +func sendHTML(writer http.ResponseWriter, code int, body string) { + writer.Header().Set("Content-Type", "text/html") + writer.WriteHeader(code) + + if _, err := writer.Write([]byte(body)); err != nil { + logger.Warn.Println("Failed to write HTML response:", err) + } +} + +// --------------------------------------------------------------------------- +// init — register all auth routes +// --------------------------------------------------------------------------- + +func init() { //nolint:funlen + // Discord + api.SetRoute(api.Route{ + Path: "/auth/discord/authorize", Method: http.MethodGet, Handler: discordAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/discord", Method: http.MethodGet, Handler: discordCallbackHandler, UseAuth: false}) + api.SetRoute(api.Route{ + Path: "/auth/discord/join", Method: http.MethodGet, Handler: discordJoinHandler, UseAuth: false, + }) + + // Spotify + api.SetRoute(api.Route{ + Path: "/auth/spotify/authorize", Method: http.MethodGet, Handler: spotifyAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/spotify", Method: http.MethodGet, Handler: spotifyCallbackHandler, UseAuth: false}) + + // Kick + api.SetRoute(api.Route{ + Path: "/auth/kick/authorize", Method: http.MethodGet, Handler: kickAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/kick", Method: http.MethodGet, Handler: kickCallbackHandler, UseAuth: false}) + + // Anilist + api.SetRoute(api.Route{ + Path: "/auth/anilist/authorize", Method: http.MethodGet, Handler: anilistAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/anilist", Method: http.MethodGet, Handler: anilistCallbackHandler, UseAuth: false}) + + // Trakt + api.SetRoute(api.Route{ + Path: "/auth/trakt/authorize", Method: http.MethodGet, Handler: traktAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/trakt", Method: http.MethodGet, Handler: traktCallbackHandler, UseAuth: false}) + + // FFZ — public, direct redirect + api.SetRoute(api.Route{ + Path: "/auth/ffz/authorize", Method: http.MethodGet, Handler: ffzAuthorizeHandler, UseAuth: false, + }) + api.SetRoute(api.Route{Path: "/auth/ffz", Method: http.MethodGet, Handler: ffzCallbackHandler, UseAuth: false}) + + // Steam — OpenID 2.0 + api.SetRoute(api.Route{ + Path: "/auth/steam/authorize", Method: http.MethodGet, Handler: steamAuthorizeHandler, UseAuth: true, + }) + api.SetRoute(api.Route{Path: "/auth/steam", Method: http.MethodGet, Handler: steamCallbackHandler, UseAuth: false}) +} + +// --------------------------------------------------------------------------- +// Discord +// --------------------------------------------------------------------------- + +func discordAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Discord.OAuthURI, "/") + "/auth/discord" + + params := url.Values{ + "client_id": {config.Discord.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {"identify"}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", discordAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func discordCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Discord.OAuthURI, "/") + "/auth/discord" + + tok, err := exchangeFormToken(request.Context(), discordTokenURL, url.Values{ + "client_id": {config.Discord.ClientID}, + "client_secret": {config.Discord.ClientSecret}, + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + }) + if err != nil { + logger.Error.Println("Discord token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + discordID, err := fetchDiscordUserID(request.Context(), tok.AccessToken) + if err != nil { + logger.Error.Println("Discord user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Discord user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + _ = userID // stored in state; could be used to link to internal user row + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), discordID, common.DISCORD, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Discord OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("discord")) +} + +func discordJoinHandler(writer http.ResponseWriter, request *http.Request) { + config := utils.LoadConfig() + + params := url.Values{ + "client_id": {config.Discord.ClientID}, + "permissions": {"414464674880"}, + "scope": {"bot identify messages.read guilds.members.read"}, + } + + target := fmt.Sprintf("https://discord.com/oauth2/authorize?%s", params.Encode()) + http.Redirect(writer, request, target, http.StatusFound) +} + +// --------------------------------------------------------------------------- +// Spotify +// --------------------------------------------------------------------------- + +func spotifyAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Spotify.OAuthURI, "/") + "/auth/spotify" + + params := url.Values{ + "client_id": {config.Spotify.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {spotifyScopes}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", spotifyAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func spotifyCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Spotify.OAuthURI, "/") + "/auth/spotify" + + tok, err := exchangeFormToken(request.Context(), spotifyTokenURL, url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + }, withBasicAuth(config.Spotify.ClientID, config.Spotify.ClientSecret)) + if err != nil { + logger.Error.Println("Spotify token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + spotifyID, err := fetchJSONField( + request.Context(), spotifyMeURL, + map[string]string{"Authorization": "Bearer " + tok.AccessToken}, + "id", + ) + if err != nil { + logger.Error.Println("Spotify me fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Spotify user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), spotifyID, common.SPOTIFY, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Spotify OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("spotify")) +} + +// --------------------------------------------------------------------------- +// Kick +// --------------------------------------------------------------------------- + +func kickAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state, _, codeChallenge := newKickOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Kick.OAuthURI, "/") + "/auth/kick" + + params := url.Values{ + "client_id": {config.Kick.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {kickScopes}, + "state": {state}, + "code_challenge": {codeChallenge}, + "code_challenge_method": {"S256"}, + } + + target := fmt.Sprintf("%s?%s", kickAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func kickCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + codeVerifier := extractKickCodeVerifier(state) + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Kick.OAuthURI, "/") + "/auth/kick" + + tok, err := exchangeFormToken(request.Context(), kickTokenURL, url.Values{ + "client_id": {config.Kick.ClientID}, + "client_secret": {config.Kick.ClientSecret}, + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "code_verifier": {codeVerifier}, + }) + if err != nil { + logger.Error.Println("Kick token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + kickID, err := fetchJSONField( + request.Context(), "https://id.kick.com/oauth/user-info", + map[string]string{"Authorization": "Bearer " + tok.AccessToken}, + "sub", + ) + if err != nil { + logger.Error.Println("Kick user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Kick user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), kickID, common.KICK, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Kick OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("kick")) +} + +// --------------------------------------------------------------------------- +// AniList +// --------------------------------------------------------------------------- + +func anilistAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Anilist.OAuthURI, "/") + "/auth/anilist" + + params := url.Values{ + "client_id": {config.Anilist.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", anilistAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func anilistCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Anilist.OAuthURI, "/") + "/auth/anilist" + + tok, err := exchangeJSONToken(request.Context(), anilistTokenURL, map[string]any{ + "grant_type": "authorization_code", + "client_id": config.Anilist.ClientID, + "client_secret": config.Anilist.ClientSecret, + "redirect_uri": redirectURI, + "code": code, + }) + if err != nil { + logger.Error.Println("Anilist token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + anilistID, err := fetchAnilistUserID(request.Context(), tok.AccessToken) + if err != nil { + logger.Error.Println("Anilist user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Anilist user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), anilistID, common.ANILIST, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Anilist OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("anilist")) +} + +// --------------------------------------------------------------------------- +// Trakt +// --------------------------------------------------------------------------- + +func traktAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + state := newOAuthState(user.ID) + redirectURI := strings.TrimRight(config.Trakt.OAuthURI, "/") + "/auth/trakt" + + params := url.Values{ + "client_id": {config.Trakt.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "state": {state}, + } + + target := fmt.Sprintf("%s?%s", traktAuthorizeURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func traktCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + code := query.Get("code") + state := query.Get("state") + + userID, ok := consumeOAuthState(state) + if !ok { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Invalid or expired state")) + + return + } + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.Trakt.OAuthURI, "/") + "/auth/trakt" + + tok, err := exchangeJSONToken(request.Context(), traktTokenURL, map[string]any{ + "code": code, + "client_id": config.Trakt.ClientID, + "client_secret": config.Trakt.ClientSecret, + "redirect_uri": redirectURI, + "grant_type": "authorization_code", + }) + if err != nil { + logger.Error.Println("Trakt token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Token exchange failed")) + + return + } + + traktID, err := fetchJSONField( + request.Context(), traktMeURL, + map[string]string{ + "Authorization": "Bearer " + tok.AccessToken, + "trakt-api-version": "2", + "trakt-api-key": config.Trakt.ClientID, + }, + "ids.slug", + ) + if err != nil { + logger.Error.Println("Trakt user fetch failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Failed to fetch Trakt user")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK || userID == 0 { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), traktID, common.TRAKT, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert Trakt OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("trakt")) +} + +// --------------------------------------------------------------------------- +// FFZ — no JWT, no state, direct redirect +// --------------------------------------------------------------------------- + +func ffzAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.FFZ.OAuthURI, "/") + "/auth/ffz" + + params := url.Values{ + "client_id": {config.FFZ.ClientID}, + "redirect_uri": {redirectURI}, + "response_type": {"code"}, + "scope": {ffzScopes}, + } + + http.Redirect(writer, request, fmt.Sprintf("%s?%s", ffzAuthorizeURL, params.Encode()), http.StatusFound) +} + +func ffzCallbackHandler(writer http.ResponseWriter, request *http.Request) { + code := request.URL.Query().Get("code") + + config := utils.LoadConfig() + redirectURI := strings.TrimRight(config.FFZ.OAuthURI, "/") + "/auth/ffz" + + tok, err := exchangeFormToken(request.Context(), ffzTokenURL, url.Values{ + "client_id": {config.FFZ.ClientID}, + "client_secret": {config.FFZ.ClientSecret}, + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + }) + if err != nil { + logger.Error.Println("FFZ token exchange failed:", err) + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("FFZ token exchange failed")) + + return + } + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + scope := strings.Fields(tok.Scope) + if err := postgres.UpsertOAuthToken( + request.Context(), config.FFZ.ID, common.FFZ, + tok.AccessToken, tok.RefreshToken, scope, tok.ExpiresIn, + ); err != nil { + logger.Warn.Println("Failed to upsert FFZ OAuth token:", err) + } + + sendHTML(writer, http.StatusOK, oauthSuccessHTML("ffz")) +} + +// --------------------------------------------------------------------------- +// Steam — OpenID 2.0 +// --------------------------------------------------------------------------- + +func steamAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := authUser(request) + if !ok { + api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[string]{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + config := utils.LoadConfig() + returnTo := fmt.Sprintf("%sauth/steam?user_id=%d", + strings.TrimRight(config.Anilist.OAuthURI, "/")+"/", + user.ID, + ) + realm := strings.TrimRight(config.Anilist.OAuthURI, "/") + "/" + + params := url.Values{ + "openid.ns": {steamOpenIDNS}, + "openid.mode": {"checkid_setup"}, + "openid.return_to": {returnTo}, + "openid.realm": {realm}, + "openid.identity": {"http://specs.openid.net/auth/2.0/identifier_select"}, + "openid.claimed_id": {"http://specs.openid.net/auth/2.0/identifier_select"}, + } + + target := fmt.Sprintf("%s?%s", steamOpenIDURL, params.Encode()) + api.GenericResponse(writer, http.StatusOK, common.GenericResponse[string]{ + Data: &[]string{target}, + }, start) +} + +func steamCallbackHandler(writer http.ResponseWriter, request *http.Request) { + query := request.URL.Query() + + // Re-verify the OpenID assertion with Steam + verifyParams := url.Values{} + maps.Copy(verifyParams, query) + verifyParams.Set("openid.mode", "check_authentication") + + resp, err := http.PostForm(steamOpenIDURL, verifyParams) //nolint:noctx + if err != nil || resp.StatusCode != http.StatusOK { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam verification failed")) + + return + } + defer resp.Body.Close() //nolint:errcheck + + body, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(body), "is_valid:true") { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam assertion invalid")) + + return + } + + claimedID := query.Get("openid.claimed_id") + // claimed_id ends with the 64-bit SteamID, e.g. https://steamcommunity.com/openid/id/76561198... + parts := strings.Split(claimedID, "/") + steamID64Str := parts[len(parts)-1] + + var steamID64 int64 + if _, err := fmt.Sscanf(steamID64Str, "%d", &steamID64); err != nil { + sendHTML(writer, http.StatusBadRequest, oauthErrorHTML("Invalid Steam ID")) + + return + } + steamFriendCode := fmt.Sprintf("%d", steamID64-steamFriendOffset) + + postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !pgOK { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Internal Server Error")) + + return + } + + if err := postgres.UpsertOAuthToken( + request.Context(), steamID64Str, common.STEAM, + "", "", []string{}, 0, + ); err != nil { + logger.Warn.Println("Failed to upsert Steam token:", err) + } + + sendHTML(writer, http.StatusOK, oauthPostMessage(map[string]any{ + "platform": "steam", + "ok": true, + "steam_id": steamID64Str, + "friend_code": steamFriendCode, + })) +} + +// --------------------------------------------------------------------------- +// OAuth helpers +// --------------------------------------------------------------------------- + +type simpleTokenResponse struct { //nolint:govet + AccessToken string `json:"access_token"` //nolint:gosec + RefreshToken string `json:"refresh_token"` //nolint:gosec + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type exchangeOption func(*http.Request) + +func withBasicAuth(clientID, clientSecret string) exchangeOption { + return func(req *http.Request) { + req.SetBasicAuth(clientID, clientSecret) + } +} + +// exchangeFormToken does a standard application/x-www-form-urlencoded code exchange. +func exchangeFormToken( + ctx context.Context, + tokenURL string, + params url.Values, + opts ...exchangeOption, +) (*simpleTokenResponse, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(params.Encode())) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + for _, opt := range opts { + opt(req) + } + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return nil, err + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode >= 400 { //nolint:mnd + b, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(b)) //nolint:err113 + } + + var tok simpleTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tok); err != nil { + return nil, err + } + + return &tok, nil +} + +// exchangeJSONToken does a code exchange with application/json body. +func exchangeJSONToken(ctx context.Context, tokenURL string, body map[string]any) (*simpleTokenResponse, error) { + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return nil, err + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode >= 400 { //nolint:mnd + b, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("token exchange returned %d: %s", resp.StatusCode, string(b)) //nolint:err113 + } + + var tok simpleTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tok); err != nil { + return nil, err + } + + return &tok, nil +} + +// fetchJSONField makes an authenticated request and returns a top-level string field. +// Nested fields can be accessed with dot notation (e.g. "ids.slug"). +func fetchJSONField( + ctx context.Context, + endpoint string, + headers map[string]string, + field string, +) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return "", err + } + + for k, v := range headers { + req.Header.Set(k, v) + } + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return "", err + } + defer resp.Body.Close() //nolint:errcheck + + var data map[string]any + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return "", err + } + + parts := strings.SplitN(field, ".", 2) //nolint:mnd + val, exists := data[parts[0]] + + if !exists { + return "", fmt.Errorf("field %q not found in response", parts[0]) //nolint:err113 + } + + if len(parts) == 2 { //nolint:mnd + nested, ok := val.(map[string]any) + if !ok { + return "", fmt.Errorf("field %q is not an object", parts[0]) //nolint:err113 + } + + val, exists = nested[parts[1]] + if !exists { + return "", fmt.Errorf("field %q not found in nested object", parts[1]) //nolint:err113 + } + } + + return fmt.Sprintf("%v", val), nil +} + +// fetchDiscordUserID calls /users/@me and returns the Discord user ID. +func fetchDiscordUserID(ctx context.Context, accessToken string) (string, error) { + return fetchJSONField( + ctx, discordMeURL, + map[string]string{"Authorization": "Bearer " + accessToken}, + "id", + ) +} + +// fetchAnilistUserID queries the AniList GraphQL API for the authenticated user's ID. +func fetchAnilistUserID(ctx context.Context, accessToken string) (string, error) { + query := `{"query":"{Viewer{id}}"}` + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, anilistGQLURL, strings.NewReader(query)) + if err != nil { + return "", err + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + + client := &http.Client{Timeout: oauthHTTPTimeout} + + resp, err := client.Do(req) //nolint:gosec + if err != nil { + return "", err + } + defer resp.Body.Close() //nolint:errcheck + + var result struct { + Data struct { + Viewer struct { + ID int `json:"id"` + } `json:"Viewer"` + } `json:"data"` + } + + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + return fmt.Sprintf("%d", result.Data.Viewer.ID), nil +} From 3502d41a26a1d480689535847e0edce54544d8b9 Mon Sep 17 00:00:00 2001 From: tlos Date: Tue, 17 Mar 2026 12:01:36 -0400 Subject: [PATCH 05/20] fix: lint in db package --- common/db/loops.go | 2 +- common/db/redis.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/db/loops.go b/common/db/loops.go index acf0060..ababd1f 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -612,7 +612,7 @@ func backupPostgres( ) //nolint:gosec - cmd := exec.Command("sh", "-c", fmt.Sprintf( + cmd := exec.CommandContext(ctx, "sh", "-c", fmt.Sprintf( "PGPASSWORD=%s pg_dump -d %s -U %s -h %s | zstd -3 --threads=%d > %s", config.Postgres.Password, config.Postgres.Database, diff --git a/common/db/redis.go b/common/db/redis.go index ec49bf9..941ad85 100644 --- a/common/db/redis.go +++ b/common/db/redis.go @@ -21,7 +21,7 @@ var ErrRedisNil = redis.Nil func InitRedis(config common.Config) (*RedisClient, error) { host := config.Redis.Host if host == "" { - host = "localhost" + host = "localhost" //nolint:goconst } port := config.Redis.Port From 0d05f92a5197fef7f32ad1052c3eaff04aaac6f0 Mon Sep 17 00:00:00 2001 From: tlos Date: Tue, 17 Mar 2026 12:57:58 -0400 Subject: [PATCH 06/20] fix: suppress lint errors in existing files --- api/routes/get/login.go | 2 +- api/routes/post/redirects.go | 4 ++-- common/db/postgres.go | 2 +- common/types.go | 14 +++++++------- common/utils/requests.go | 2 +- haste/haste.go | 4 ++-- redirects/redirects_test.go | 4 ++-- uploader/uploader.go | 4 ++-- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/api/routes/get/login.go b/api/routes/get/login.go index 09ff368..a0b3964 100644 --- a/api/routes/get/login.go +++ b/api/routes/get/login.go @@ -117,7 +117,7 @@ func twitchLoginHandler(writer http.ResponseWriter, request *http.Request) { //n } // Excahnge code for access token - tokenResp, err := client.Do(req) + tokenResp, err := client.Do(req) //nolint:gosec if err != nil { http.Error(writer, "Failed to get access token", http.StatusInternalServerError) diff --git a/api/routes/post/redirects.go b/api/routes/post/redirects.go index 3ae6281..1a6b4a0 100644 --- a/api/routes/post/redirects.go +++ b/api/routes/post/redirects.go @@ -48,7 +48,7 @@ func createRedirect(writer http.ResponseWriter, request *http.Request) { //nolin key, err := postgres.GetKeyByRedirect(request.Context(), input.URL) if err == nil && key != "" { response := fmt.Sprintf("https://%s/%s", request.Host, key) - _, err = writer.Write([]byte(response)) + _, err = writer.Write([]byte(response)) //nolint:gosec if err != nil { logger.Error.Printf("Failed to write response: %v", err) } @@ -72,7 +72,7 @@ func createRedirect(writer http.ResponseWriter, request *http.Request) { //nolin } response := fmt.Sprintf("https://%s/%s", request.Host, key) - if _, err = writer.Write([]byte(response)); err != nil { + if _, err = writer.Write([]byte(response)); err != nil { //nolint:gosec logger.Error.Printf("Failed to write response: %v", err) } } diff --git a/common/db/postgres.go b/common/db/postgres.go index 8a9d65e..c781c36 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -49,7 +49,7 @@ func InitPostgres(ctx context.Context, config common.Config) (*PostgresClient, e return &PostgresClient{pool}, nil } -func loadConfig(config common.Config) (*pgxpool.Config, error) { +func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unparam user := config.Postgres.User if user == "" { user = "postgres" diff --git a/common/types.go b/common/types.go index 3a25506..36356ba 100644 --- a/common/types.go +++ b/common/types.go @@ -180,7 +180,7 @@ type ChannelSettings struct { // CommandSettings represents the settings for a command in a channel, including its permissions, cooldowns, // and usage limits. -type CommandSettings struct { +type CommandSettings struct { //nolint:govet ChannelID string `json:"channel_id"` Command string `json:"command"` Permission *string `json:"permission,omitempty"` @@ -200,8 +200,8 @@ type CommandSettings struct { type PlatformOauth struct { AddedAt time.Time `json:"added_at"` PlatformID string `json:"platform_id"` - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` //nolint:gosec + RefreshToken string `json:"refresh_token"` //nolint:gosec Platform Platforms `json:"platform"` Scope []string `json:"scope"` ExpiresIn int `json:"expires_in"` @@ -315,8 +315,8 @@ type TwitchValidation struct { // GenericOAUTHResponse represents a generic OAuth response structure. type GenericOAUTHResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` //nolint:gosec + RefreshToken string `json:"refresh_token"` //nolint:gosec TokenType string `json:"token_type"` Scope []string `json:"scope"` ExpiresIn int `json:"expires_in"` @@ -436,7 +436,7 @@ type EmoteStat struct { } // EmoteHistoryEntry represents a single per-user emote usage record from Clickhouse. -type EmoteHistoryEntry struct { +type EmoteHistoryEntry struct { //nolint:govet EmoteID string `json:"emote_id"` EmoteName string `json:"emote_name"` EmoteAlias string `json:"emote_alias"` @@ -448,7 +448,7 @@ type EmoteHistoryEntry struct { } // PageInfo contains cursor-based pagination metadata. -type PageInfo struct { +type PageInfo struct { //nolint:govet HasNextPage bool `json:"hasNextPage"` Cursor string `json:"cursor"` } diff --git a/common/utils/requests.go b/common/utils/requests.go index 37f9d44..35c5b54 100644 --- a/common/utils/requests.go +++ b/common/utils/requests.go @@ -69,7 +69,7 @@ func makeRequest( Timeout: time.Second * 10, } - res, err := client.Do(req) + res, err := client.Do(req) //nolint:gosec if err != nil { return nil, err } diff --git a/haste/haste.go b/haste/haste.go index 2f01cab..cb9c07d 100644 --- a/haste/haste.go +++ b/haste/haste.go @@ -202,7 +202,7 @@ func (h *hastebin) handleGetRaw(writer http.ResponseWriter, request *http.Reques writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("X-Cache-Hit", "HIT") writer.WriteHeader(http.StatusOK) - _, err = writer.Write([]byte(cache)) + _, err = writer.Write([]byte(cache)) //nolint:gosec if err != nil { logger.Warn.Println("Failed to write document: ", err) } @@ -224,7 +224,7 @@ func (h *hastebin) handleGetRaw(writer http.ResponseWriter, request *http.Reques writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("X-Cache-Hit", "MISS") writer.WriteHeader(http.StatusOK) - _, err = writer.Write([]byte(data)) + _, err = writer.Write([]byte(data)) //nolint:gosec if err != nil { logger.Warn.Println("Failed to write document: ", err) } diff --git a/redirects/redirects_test.go b/redirects/redirects_test.go index a4990d8..8642eda 100644 --- a/redirects/redirects_test.go +++ b/redirects/redirects_test.go @@ -22,10 +22,10 @@ func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { t.Run(tc.input, func(t *testing.T) { cleanedURL := redirector.cleanRedirectProtocolSoLinksActuallyWork(tc.input) if !strings.HasPrefix(cleanedURL, "https://") { - t.Errorf("Expected cleaned URL to start with 'https://', got %q", cleanedURL) + t.Errorf("Expected cleaned URL to start with 'https://', got %q", cleanedURL) //nolint:forbidigo } if cleanedURL != tc.expected { - t.Errorf("Expected %q, got %q", tc.expected, cleanedURL) + t.Errorf("Expected %q, got %q", tc.expected, cleanedURL) //nolint:forbidigo } }) } diff --git a/uploader/uploader.go b/uploader/uploader.go index 0bce66a..86619dc 100644 --- a/uploader/uploader.go +++ b/uploader/uploader.go @@ -233,7 +233,7 @@ func (u *uploader) handleGet(writer http.ResponseWriter, request *http.Request) writer.Header().Set("Content-Type", contentType) writer.Header().Set("X-Cache-Hit", "HIT") writer.WriteHeader(http.StatusOK) - _, err = writer.Write(cache) + _, err = writer.Write(cache) //nolint:gosec if err != nil { logger.Warn.Printf("Failed to write document: %v", err) } @@ -262,7 +262,7 @@ func (u *uploader) handleGet(writer http.ResponseWriter, request *http.Request) writer.Header().Set("Content-Disposition", "inline; filename=\""+*name+"\"") } writer.Header().Set("Content-Type", mimeType) - _, err = writer.Write(data) + _, err = writer.Write(data) //nolint:gosec if err != nil { logger.Error.Printf("Error writing file: %v", err) http.Error(writer, "Failed to write file", http.StatusInternalServerError) From fcb9a535a09791790b8274e96ec6a604589e5e98 Mon Sep 17 00:00:00 2001 From: tlos Date: Wed, 18 Mar 2026 00:25:45 -0400 Subject: [PATCH 07/20] fix: address clanker comments --- api/routes/del/command_settings.go | 6 +- api/routes/get/auth.go | 92 +++++++++++++++++++++--------- api/routes/get/emotes.go | 9 ++- 3 files changed, 76 insertions(+), 31 deletions(-) diff --git a/api/routes/del/command_settings.go b/api/routes/del/command_settings.go index 785420f..c14bbd7 100644 --- a/api/routes/del/command_settings.go +++ b/api/routes/del/command_settings.go @@ -73,10 +73,10 @@ func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { return } - // Without a specific command name, this is a no-op (not yet implemented upstream). + // command is required — this endpoint resets a single command override back to defaults. if command == "" { - api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{ - Data: &[]any{}, + api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Missing required field: command"}}, }, start) return diff --git a/api/routes/get/auth.go b/api/routes/get/auth.go index 334b8c9..feb75c2 100644 --- a/api/routes/get/auth.go +++ b/api/routes/get/auth.go @@ -67,20 +67,23 @@ const ( // oauthStates maps state strings to true with a TTL for replay-attack prevention. var oauthStates sync.Map //nolint:gochecknoglobals +// kickPKCEVerifiers maps OAuth state strings to PKCE code verifiers, with the same TTL as oauthStates. +var kickPKCEVerifiers sync.Map //nolint:gochecknoglobals + func newOAuthState(userID int) string { nonce := uuid.New().String() state := fmt.Sprintf("%s:%d", nonce, userID) oauthStates.Store(state, true) - go func(s string) { - time.Sleep(oauthStateTTL) - oauthStates.Delete(s) - }(state) + time.AfterFunc(oauthStateTTL, func() { + oauthStates.Delete(state) + }) return state } -// newKickOAuthState creates a state with an embedded PKCE code verifier. -// State format: {nonce}:{userID}:{codeVerifier}. +// newKickOAuthState creates an OAuth state and associated PKCE code verifier. +// State format: {nonce}:{userID}. +// The PKCE code verifier is stored server-side and must be retrieved on callback. func newKickOAuthState(userID int) (state, codeVerifier, codeChallenge string) { buf := make([]byte, 32) if _, err := rand.Read(buf); err != nil { @@ -93,12 +96,15 @@ func newKickOAuthState(userID int) (state, codeVerifier, codeChallenge string) { codeChallenge = base64.RawURLEncoding.EncodeToString(sum[:]) nonce := uuid.New().String() - state = fmt.Sprintf("%s:%d:%s", nonce, userID, codeVerifier) + state = fmt.Sprintf("%s:%d", nonce, userID) + + // Store replay-protection state and PKCE verifier with the same TTL. oauthStates.Store(state, true) - go func(s string) { - time.Sleep(oauthStateTTL) - oauthStates.Delete(s) - }(state) + kickPKCEVerifiers.Store(state, codeVerifier) + time.AfterFunc(oauthStateTTL, func() { + oauthStates.Delete(state) + kickPKCEVerifiers.Delete(state) + }) return state, codeVerifier, codeChallenge } @@ -122,14 +128,16 @@ func consumeOAuthState(state string) (int, bool) { return userID, true } -// extractKickCodeVerifier returns the codeVerifier embedded in a Kick state string. +// extractKickCodeVerifier retrieves the server-side PKCE code verifier for a Kick state. func extractKickCodeVerifier(state string) string { - parts := strings.SplitN(state, ":", 3) //nolint:mnd - if len(parts) < 3 { //nolint:mnd + v, ok := kickPKCEVerifiers.Load(state) + if !ok { return "" } - return parts[2] + s, _ := v.(string) + + return s } // authUser retrieves the authenticated user from the request context. @@ -155,10 +163,17 @@ func oauthPostMessage(payload map[string]any) string { data, _ := json.Marshal(payload) //nolint:errchkjson return fmt.Sprintf(``, string(data)) } @@ -782,12 +797,18 @@ func steamAuthorizeHandler(writer http.ResponseWriter, request *http.Request) { return } - config := utils.LoadConfig() - returnTo := fmt.Sprintf("%sauth/steam?user_id=%d", - strings.TrimRight(config.Anilist.OAuthURI, "/")+"/", - user.ID, - ) - realm := strings.TrimRight(config.Anilist.OAuthURI, "/") + "/" + // Derive the base URL from the incoming request to ensure Steam OpenID + // callbacks use the correct host and scheme for this API instance. + scheme := "https" + if forwarded := request.Header.Get("X-Forwarded-Proto"); forwarded != "" { + // Use the first value if multiple are provided. + scheme = strings.Split(forwarded, ",")[0] + } else if request.TLS == nil { + scheme = "http" + } + baseURL := fmt.Sprintf("%s://%s/", scheme, request.Host) + returnTo := fmt.Sprintf("%sauth/steam?user_id=%d", baseURL, user.ID) + realm := baseURL params := url.Values{ "openid.ns": {steamOpenIDNS}, @@ -812,7 +833,26 @@ func steamCallbackHandler(writer http.ResponseWriter, request *http.Request) { maps.Copy(verifyParams, query) verifyParams.Set("openid.mode", "check_authentication") - resp, err := http.PostForm(steamOpenIDURL, verifyParams) //nolint:noctx + ctx := request.Context() + client := &http.Client{ + Timeout: 5 * time.Second, + } + + req, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + steamOpenIDURL, + strings.NewReader(verifyParams.Encode()), + ) + if err != nil { + sendHTML(writer, http.StatusInternalServerError, oauthErrorHTML("Steam verification failed")) + + return + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := client.Do(req) //nolint:gosec if err != nil || resp.StatusCode != http.StatusOK { sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam verification failed")) diff --git a/api/routes/get/emotes.go b/api/routes/get/emotes.go index c17b79f..8bcfd04 100644 --- a/api/routes/get/emotes.go +++ b/api/routes/get/emotes.go @@ -68,7 +68,8 @@ func normaliseProvider(provider string) []string { case "ALL", "": return []string{} default: - return []string{"STV"} + // For unknown/unsupported providers, apply no provider filter. + return []string{} } } @@ -139,8 +140,12 @@ func getEmoteStats(writer http.ResponseWriter, request *http.Request) { //nolint offset := 0 if cursor := query.Get("after"); cursor != "" { - if o, err := decodeCursor(cursor); err == nil { + if o, err := decodeCursor(cursor); err == nil && o >= 0 { offset = o + } else { + writeEmoteError(writer, http.StatusBadRequest, start) + + return } } From 8dd3e3e36abc376919748e28c94dbe23eb313be8 Mon Sep 17 00:00:00 2001 From: tlos Date: Sat, 2 May 2026 12:39:00 -0400 Subject: [PATCH 08/20] fix: stop leaking NATS connections per request --- api/api.go | 3 ++- api/middleware/context.go | 4 ++++ api/routes/get/help.go | 13 ++++++++++++- common/utils/broker.go | 9 ++++----- main.go | 2 +- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/api/api.go b/api/api.go index 63621b7..4847d1e 100644 --- a/api/api.go +++ b/api/api.go @@ -43,6 +43,7 @@ func StartServing( postgres *db.PostgresClient, redis *db.RedisClient, clickhouse *db.ClickhouseClient, + nats *utils.NatsClient, metrics *utils.Metrics, ) error { if config.API.Host == "" || config.API.Port == "" { @@ -54,7 +55,7 @@ func StartServing( } api.router.Use(middleware.LogRequest(metrics)) - api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse)) + api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse, nats)) api.router.Use(middleware.NewRateLimiter(100, 1*time.Minute, redis)) authenticator := middleware.NewAuthenticator(config.Twitch.ClientSecret, GenericResponse) diff --git a/api/middleware/context.go b/api/middleware/context.go index 488e46a..a773997 100644 --- a/api/middleware/context.go +++ b/api/middleware/context.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/utils" ) // ErrMissingContext is returned when a database client is not found in the request context. @@ -18,6 +19,7 @@ const ( PostgresKey contextKey = "postgres" RedisKey contextKey = "redis" ClickhouseKey contextKey = "clickhouse" + NatsKey contextKey = "nats" ) // InjectDatabases returns a middleware that injects DB clients into the request context. @@ -25,12 +27,14 @@ func InjectDatabases( postgres *db.PostgresClient, redis *db.RedisClient, clickhouse *db.ClickhouseClient, + nats *utils.NatsClient, ) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.WithValue(r.Context(), PostgresKey, postgres) ctx = context.WithValue(ctx, RedisKey, redis) ctx = context.WithValue(ctx, ClickhouseKey, clickhouse) + ctx = context.WithValue(ctx, NatsKey, nats) next.ServeHTTP(w, r.WithContext(ctx)) }) diff --git a/api/routes/get/help.go b/api/routes/get/help.go index bea0775..e1047ae 100644 --- a/api/routes/get/help.go +++ b/api/routes/get/help.go @@ -113,7 +113,18 @@ func getCommandsHandler(writer http.ResponseWriter, request *http.Request) { } writer.Header().Set("X-Cache-Hit", "MISS") - response, err := utils.BridgeRequest( + nats, ok := request.Context().Value(middleware.NatsKey).(*utils.NatsClient) + if !ok || nats == nil { + logger.Error.Println("NATS client not found in context") + api.GenericResponse(writer, http.StatusServiceUnavailable, HelpResponse{ + Data: &[]common.Command{}, + Errors: &[]common.ErrorMessage{{Message: "Service unavailable"}}, + }, start) + + return + } + + response, err := nats.BridgeRequest( 5*time.Second, "get-commands", ) diff --git a/common/utils/broker.go b/common/utils/broker.go index 7b4304d..ad3d805 100644 --- a/common/utils/broker.go +++ b/common/utils/broker.go @@ -139,16 +139,15 @@ func (n *NatsClient) handleMessage(message *nats.Msg) { } // BridgeRequest sends a request to the NATS server and waits for a response. -func BridgeRequest( +func (n *NatsClient) BridgeRequest( ttl time.Duration, request string, ) ([]byte, error) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - return nil, fmt.Errorf("failed to publish request: %w", err) + if n == nil || n.Client == nil { + return nil, errNatsNotConnected } - response, err := nc.Request( + response, err := n.Client.Request( "github.com/Potat-Industries/potat-api.job-request", []byte(request), ttl, diff --git a/main.go b/main.go index 36314b7..1d81628 100644 --- a/main.go +++ b/main.go @@ -110,7 +110,7 @@ func main() { //nolint:cyclop apiChan := make(chan error) if config.API.Enabled { go func() { - apiChan <- api.StartServing(*config, postgres, redis, clickhouse, metrics) + apiChan <- api.StartServing(*config, postgres, redis, clickhouse, nats, metrics) }() } From d5a5a721d9d1e2bbb48cbb469b338c04a9521c47 Mon Sep 17 00:00:00 2001 From: tlos Date: Sun, 3 May 2026 14:21:14 -0400 Subject: [PATCH 09/20] refactor: consolidate channel auth into shared middleware helper Extract IsChannelAuthorized and GetTwitchPlatformID into api/middleware/authorize.go, replacing four separate identical implementations across get, del, put, and patch route packages. --- api/middleware/authorize.go | 46 +++++++++++++++++ api/routes/del/command_settings.go | 32 +----------- api/routes/get/ambassadors.go | 82 ++++++++++++++++++++++++++++++ api/routes/get/auth.go | 11 ---- api/routes/get/command_settings.go | 30 +---------- api/routes/patch/settings.go | 24 ++------- api/routes/put/command_settings.go | 32 +----------- 7 files changed, 137 insertions(+), 120 deletions(-) create mode 100644 api/middleware/authorize.go create mode 100644 api/routes/get/ambassadors.go diff --git a/api/middleware/authorize.go b/api/middleware/authorize.go new file mode 100644 index 0000000..64c9506 --- /dev/null +++ b/api/middleware/authorize.go @@ -0,0 +1,46 @@ +package middleware + +import ( + "net/http" + "slices" + + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" +) + +// GetTwitchPlatformID returns the Twitch platform user ID from a user's connections, or empty string if not found. +func GetTwitchPlatformID(user *common.User) string { + for _, conn := range user.Connections { + if conn.Platform == common.TWITCH { + return conn.UserID + } + } + + return "" +} + +// IsChannelAuthorized returns true if the user is an admin, the broadcaster of the given channel, +// or a channel ambassador. It is used as a shared auth check across channel-scoped write routes. +func IsChannelAuthorized( + request *http.Request, + user *common.User, + channelID string, + postgres *db.PostgresClient, +) bool { + if user.Level >= int(common.ADMIN) { + return true + } + + twitchID := GetTwitchPlatformID(user) + + if twitchID == channelID { + return true + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + return false + } + + return slices.Contains(ambassadors, twitchID) +} diff --git a/api/routes/del/command_settings.go b/api/routes/del/command_settings.go index c14bbd7..abe7286 100644 --- a/api/routes/del/command_settings.go +++ b/api/routes/del/command_settings.go @@ -3,7 +3,6 @@ package del import ( "net/http" - "slices" "time" "github.com/Potat-Industries/potat-api/api" @@ -65,7 +64,7 @@ func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { return } - if !isDelAuthorized(request, user, channelID, postgres) { + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, }, start) @@ -96,33 +95,4 @@ func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { }, start) } -func isDelAuthorized( - request *http.Request, - user *common.User, - channelID string, - postgres *db.PostgresClient, -) bool { - if user.Level >= int(common.ADMIN) { - return true - } - - var twitchID string - for _, conn := range user.Connections { - if conn.Platform == common.TWITCH { - twitchID = conn.UserID - - break - } - } - if twitchID == channelID { - return true - } - - ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) - if err != nil { - return false - } - - return slices.Contains(ambassadors, twitchID) -} diff --git a/api/routes/get/ambassadors.go b/api/routes/get/ambassadors.go new file mode 100644 index 0000000..13001bd --- /dev/null +++ b/api/routes/get/ambassadors.go @@ -0,0 +1,82 @@ +// Package get contains routes for http.MethodGet requests. +package get + +import ( + "net/http" + "time" + + "github.com/Potat-Industries/potat-api/api" + "github.com/Potat-Industries/potat-api/api/middleware" + "github.com/Potat-Industries/potat-api/common" + "github.com/Potat-Industries/potat-api/common/db" + "github.com/Potat-Industries/potat-api/common/logger" +) + +// AmbassadorsResponse is the response type for GET /channel/ambassadors. +type AmbassadorsResponse = common.GenericResponse[string] + +func init() { + api.SetRoute(api.Route{ + Path: "/channel/ambassadors", + Method: http.MethodGet, + Handler: getAmbassadorsHandler, + UseAuth: true, + }) +} + +func getAmbassadorsHandler(writer http.ResponseWriter, request *http.Request) { + start := time.Now() + + user, ok := request.Context().Value(middleware.AuthedUser).(*common.User) + if !ok || user == nil { + api.GenericResponse(writer, http.StatusUnauthorized, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}}, + }, start) + + return + } + + postgres, ok := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient) + if !ok { + api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + channelID := resolveChannelID(request, user) + if channelID == "" { + api.GenericResponse(writer, http.StatusBadRequest, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Could not resolve channel ID"}}, + }, start) + + return + } + + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) + + return + } + + ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) + if err != nil { + logger.Error.Printf("Error fetching ambassadors: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{ + Errors: &[]common.ErrorMessage{{Message: "Failed to fetch ambassadors"}}, + }, start) + + return + } + + if ambassadors == nil { + ambassadors = []string{} + } + + api.GenericResponse(writer, http.StatusOK, AmbassadorsResponse{ + Data: &ambassadors, + }, start) +} diff --git a/api/routes/get/auth.go b/api/routes/get/auth.go index feb75c2..be0a152 100644 --- a/api/routes/get/auth.go +++ b/api/routes/get/auth.go @@ -147,17 +147,6 @@ func authUser(request *http.Request) (*common.User, bool) { return user, ok && user != nil } -// getTwitchPlatformID finds the Twitch platform ID from a user's connections. -func getTwitchPlatformID(user *common.User) string { - for _, conn := range user.Connections { - if conn.Platform == common.TWITCH { - return conn.UserID - } - } - - return "" -} - // oauthPostMessage builds the postMessage HTML used to close popups. func oauthPostMessage(payload map[string]any) string { data, _ := json.Marshal(payload) //nolint:errchkjson diff --git a/api/routes/get/command_settings.go b/api/routes/get/command_settings.go index 3b99c58..4253aad 100644 --- a/api/routes/get/command_settings.go +++ b/api/routes/get/command_settings.go @@ -3,7 +3,6 @@ package get import ( "net/http" - "slices" "time" "github.com/Potat-Industries/potat-api/api" @@ -26,31 +25,6 @@ func init() { }) } -// isChannelAuthorized returns true if the user is admin, the broadcaster, or a channel ambassador. -func isChannelAuthorized( - request *http.Request, - user *common.User, - channelID string, - postgres *db.PostgresClient, -) bool { - if int(common.ADMIN) <= user.Level { - return true - } - - twitchID := getTwitchPlatformID(user) - - if twitchID == channelID { - return true - } - - ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) - if err != nil { - return false - } - - return slices.Contains(ambassadors, twitchID) -} - // resolveChannelID returns the channel ID from ?id= or defaults to the authenticated user's Twitch ID. func resolveChannelID(request *http.Request, user *common.User) string { if id := request.URL.Query().Get("id"); id != "" { @@ -62,7 +36,7 @@ func resolveChannelID(request *http.Request, user *common.User) string { return id } - return getTwitchPlatformID(user) + return middleware.GetTwitchPlatformID(user) } func getCommandSettingsHandler(writer http.ResponseWriter, request *http.Request) { @@ -96,7 +70,7 @@ func getCommandSettingsHandler(writer http.ResponseWriter, request *http.Request return } - if !isChannelAuthorized(request, user, channelID, postgres) { + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { api.GenericResponse(writer, http.StatusForbidden, CommandSettingsResponse{ Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, }, start) diff --git a/api/routes/patch/settings.go b/api/routes/patch/settings.go index 606dd67..9292b1a 100644 --- a/api/routes/patch/settings.go +++ b/api/routes/patch/settings.go @@ -4,7 +4,6 @@ package patch import ( "encoding/json" "net/http" - "slices" "time" "github.com/Potat-Industries/potat-api/api" @@ -116,25 +115,12 @@ func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { / return } - // Only the broadcaster or an ambassador (or admin) may update settings. - var twitchID string - for _, conn := range user.Connections { - if conn.Platform == common.TWITCH { - twitchID = conn.UserID - - break - } - } - - if twitchID != channelID && user.Level < int(common.ADMIN) { - ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) - if err != nil || !slices.Contains(ambassadors, twitchID) { - api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ - Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, - }, start) + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { + api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, + }, start) - return - } + return } var input common.ChannelSettings diff --git a/api/routes/put/command_settings.go b/api/routes/put/command_settings.go index ea430c2..9883ddb 100644 --- a/api/routes/put/command_settings.go +++ b/api/routes/put/command_settings.go @@ -4,7 +4,6 @@ package put import ( "encoding/json" "net/http" - "slices" "time" "github.com/Potat-Industries/potat-api/api" @@ -79,7 +78,7 @@ func putCommandSettings(writer http.ResponseWriter, request *http.Request) { //n input.ChannelID = channelID - if !isCmdAuthorized(request, user, channelID, postgres) { + if !middleware.IsChannelAuthorized(request, user, channelID, postgres) { api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{ Errors: &[]common.ErrorMessage{{Message: "Forbidden"}}, }, start) @@ -101,33 +100,4 @@ func putCommandSettings(writer http.ResponseWriter, request *http.Request) { //n }, start) } -func isCmdAuthorized( - request *http.Request, - user *common.User, - channelID string, - postgres *db.PostgresClient, -) bool { - if user.Level >= int(common.ADMIN) { - return true - } - - var twitchID string - for _, conn := range user.Connections { - if conn.Platform == common.TWITCH { - twitchID = conn.UserID - - break - } - } - if twitchID == channelID { - return true - } - - ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH) - if err != nil { - return false - } - - return slices.Contains(ambassadors, twitchID) -} From c3178d0385199cd40fae7a921b9deb24ac353df2 Mon Sep 17 00:00:00 2001 From: tlos Date: Wed, 6 May 2026 22:43:01 -0400 Subject: [PATCH 10/20] fix NATS subject names and add initialized log --- common/db/loops.go | 2 +- common/utils/broker.go | 2 +- main.go | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/common/db/loops.go b/common/db/loops.go index ababd1f..ceddac9 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -674,7 +674,7 @@ func backupPostgres( return } - err = natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", jsonMessage) + err = natsClient.Publish("potat-api.postgres-backup", jsonMessage) if err != nil { logger.Error.Println("Failed to publish to queue:", err) diff --git a/common/utils/broker.go b/common/utils/broker.go index ad3d805..78c07b1 100644 --- a/common/utils/broker.go +++ b/common/utils/broker.go @@ -66,7 +66,7 @@ func (n *NatsClient) subNatsStream(ctx context.Context) error { return err } - err = n.Client.Publish("github.com/Potat-Industries/potat-api.connected", []byte(nil)) + err = n.Client.Publish("potat-api.connected", []byte(nil)) if err != nil { logger.Warn.Printf("Failed to publish connected message: %v", err) } diff --git a/main.go b/main.go index 1d81628..bf82c0e 100644 --- a/main.go +++ b/main.go @@ -225,6 +225,7 @@ func initNats(ctx context.Context) *utils.NatsClient { if err != nil { logger.Error.Panicf("Failed to connect to RabbitMQ: %v", err) } + logger.Info.Println("NATS initialized") return nats } From d4f72c2b992c463aa3f94d41ba5d4c1d241f1d66 Mon Sep 17 00:00:00 2001 From: tlos Date: Wed, 6 May 2026 22:49:35 -0400 Subject: [PATCH 11/20] remove bench flag causing gotestfmt panic --- .github/workflows/test.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 27b1b2d..6118406 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,11 +32,8 @@ jobs: - name: Run test run: | - TEST_BENCH_OPTION="-bench=." - set -euo pipefail go-acc -o cover.out ./... -- \ - ${TEST_BENCH_OPTION} \ -json \ -v -race 2>&1 | grep -v '^go: downloading' | tee /tmp/gotest.log | gotestfmt From 5dd255e6d3c0220b754ddd0486984bd0cb98ef48 Mon Sep 17 00:00:00 2001 From: tlos Date: Wed, 6 May 2026 22:55:49 -0400 Subject: [PATCH 12/20] fix: use native go test --- .github/workflows/test.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6118406..bfb825a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,9 +22,6 @@ jobs: with: go-version: 1.23 - - name: Setup go-acc - run: go install github.com/ory/go-acc@latest - - name: Setup gotestfmt uses: haveyoudebuggedit/gotestfmt-action@v2 with: @@ -33,9 +30,7 @@ jobs: - name: Run test run: | set -euo pipefail - go-acc -o cover.out ./... -- \ - -json \ - -v -race 2>&1 | grep -v '^go: downloading' | tee /tmp/gotest.log | gotestfmt + go test -coverprofile=cover.out -coverpkg=./... -json -v -race ./... | tee /tmp/gotest.log | gotestfmt - name: Upload test log uses: actions/upload-artifact@v4 From a06a16e27fc98b78fa45b1ee8d3c34689ae9d979 Mon Sep 17 00:00:00 2001 From: tlos Date: Wed, 6 May 2026 23:14:11 -0400 Subject: [PATCH 13/20] refactor: revert gotestfmt changes --- .github/workflows/test.yml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bfb825a..27b1b2d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,6 +22,9 @@ jobs: with: go-version: 1.23 + - name: Setup go-acc + run: go install github.com/ory/go-acc@latest + - name: Setup gotestfmt uses: haveyoudebuggedit/gotestfmt-action@v2 with: @@ -29,8 +32,13 @@ jobs: - name: Run test run: | + TEST_BENCH_OPTION="-bench=." + set -euo pipefail - go test -coverprofile=cover.out -coverpkg=./... -json -v -race ./... | tee /tmp/gotest.log | gotestfmt + go-acc -o cover.out ./... -- \ + ${TEST_BENCH_OPTION} \ + -json \ + -v -race 2>&1 | grep -v '^go: downloading' | tee /tmp/gotest.log | gotestfmt - name: Upload test log uses: actions/upload-artifact@v4 From f135244e5fcce28ba16fb07706480f65d75b195e Mon Sep 17 00:00:00 2001 From: tlos Date: Thu, 7 May 2026 22:19:30 -0400 Subject: [PATCH 14/20] fix: address copilot comments --- api/routes/get/auth.go | 8 +++++++- api/routes/get/emotes.go | 3 ++- api/routes/patch/settings.go | 14 ++++++++++++-- common/db/clickhouse.go | 17 ++++++++--------- common/db/loops.go | 10 ++++++---- common/db/postgres.go | 12 +++++++++++- common/types.go | 2 +- common/utils/broker.go | 2 +- main.go | 2 +- 9 files changed, 49 insertions(+), 21 deletions(-) diff --git a/api/routes/get/auth.go b/api/routes/get/auth.go index be0a152..7e165e5 100644 --- a/api/routes/get/auth.go +++ b/api/routes/get/auth.go @@ -842,13 +842,19 @@ func steamCallbackHandler(writer http.ResponseWriter, request *http.Request) { req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := client.Do(req) //nolint:gosec - if err != nil || resp.StatusCode != http.StatusOK { + if err != nil { sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam verification failed")) return } defer resp.Body.Close() //nolint:errcheck + if resp.StatusCode != http.StatusOK { + sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam verification failed")) + + return + } + body, _ := io.ReadAll(resp.Body) if !strings.Contains(string(body), "is_valid:true") { sendHTML(writer, http.StatusForbidden, oauthErrorHTML("Steam assertion invalid")) diff --git a/api/routes/get/emotes.go b/api/routes/get/emotes.go index 8bcfd04..9e7b06a 100644 --- a/api/routes/get/emotes.go +++ b/api/routes/get/emotes.go @@ -19,6 +19,7 @@ import ( const ( defaultEmoteLimit = 100 maxEmoteLimit = 300 + maxEmoteOffset = 10_000 ) func init() { @@ -140,7 +141,7 @@ func getEmoteStats(writer http.ResponseWriter, request *http.Request) { //nolint offset := 0 if cursor := query.Get("after"); cursor != "" { - if o, err := decodeCursor(cursor); err == nil && o >= 0 { + if o, err := decodeCursor(cursor); err == nil && o >= 0 && o <= maxEmoteOffset { offset = o } else { writeEmoteError(writer, http.StatusBadRequest, start) diff --git a/api/routes/patch/settings.go b/api/routes/patch/settings.go index 9292b1a..43fa493 100644 --- a/api/routes/patch/settings.go +++ b/api/routes/patch/settings.go @@ -40,7 +40,7 @@ func patchUserSettings(writer http.ResponseWriter, request *http.Request) { return } - var input common.UserSettings + input := user.Settings if err := json.NewDecoder(request.Body).Decode(&input); err != nil { api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, @@ -123,7 +123,17 @@ func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { / return } - var input common.ChannelSettings + existingSettings, err := postgres.GetChannelSettingsByID(request.Context(), channelID) + if err != nil { + logger.Error.Printf("Error fetching channel settings: %v", err) + api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ + Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}}, + }, start) + + return + } + + input := existingSettings if err := json.NewDecoder(request.Body).Decode(&input); err != nil { api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{ Errors: &[]common.ErrorMessage{{Message: "Invalid request body"}}, diff --git a/common/db/clickhouse.go b/common/db/clickhouse.go index 3b3ec40..1418bbe 100644 --- a/common/db/clickhouse.go +++ b/common/db/clickhouse.go @@ -87,27 +87,26 @@ func (db *ClickhouseClient) GetEmoteStats( //nolint:cyclop if opts.ChannelID != "" { args = append(args, opts.ChannelID) - fmt.Fprintf(&sb, " AND channel_id = $%d", len(args)) + sb.WriteString(" AND channel_id = ?") } if opts.UserID != "" { args = append(args, opts.UserID) - fmt.Fprintf(&sb, " AND user_id = $%d", len(args)) + sb.WriteString(" AND user_id = ?") } if opts.PeriodHours > 0 { cutoff := time.Now().Add(-time.Duration(opts.PeriodHours) * time.Hour) args = append(args, cutoff) - fmt.Fprintf(&sb, " AND used_at >= $%d", len(args)) + sb.WriteString(" AND used_at >= ?") } if len(opts.Providers) > 0 { - placeholders := make([]string, len(opts.Providers)) - for i, p := range opts.Providers { + for _, p := range opts.Providers { args = append(args, p) - placeholders[i] = fmt.Sprintf("$%d", len(args)) } - fmt.Fprintf(&sb, " AND provider IN (%s)", strings.Join(placeholders, ",")) + placeholders := strings.TrimSuffix(strings.Repeat("?,", len(opts.Providers)), ",") + fmt.Fprintf(&sb, " AND provider IN (%s)", placeholders) } sb.WriteString(" GROUP BY emote_id, emote_name, emote_alias, provider") @@ -168,12 +167,12 @@ func (db *ClickhouseClient) GetEmoteHistory( if userID != "" { args = append(args, userID) - fmt.Fprintf(&sb, " AND user_id = $%d", len(args)) + sb.WriteString(" AND user_id = ?") } if channelID != "" { args = append(args, channelID) - fmt.Fprintf(&sb, " AND channel_id = $%d", len(args)) + sb.WriteString(" AND channel_id = ?") } sb.WriteString(" GROUP BY emote_id, emote_name, emote_alias, provider, channel_id, user_id") diff --git a/common/db/loops.go b/common/db/loops.go index ceddac9..add7501 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -674,11 +674,13 @@ func backupPostgres( return } - err = natsClient.Publish("potat-api.postgres-backup", jsonMessage) - if err != nil { - logger.Error.Println("Failed to publish to queue:", err) + if natsClient != nil { + err = natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", jsonMessage) + if err != nil { + logger.Error.Println("Failed to publish to queue:", err) - return + return + } } logger.Info.Println(message) diff --git a/common/db/postgres.go b/common/db/postgres.go index c781c36..e325ccc 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -774,6 +774,14 @@ func (db *PostgresClient) UpdateChannelSettings( return err } +// GetChannelSettingsByID retrieves only the settings column for a channel by its ID. +func (db *PostgresClient) GetChannelSettingsByID(ctx context.Context, channelID string) (common.ChannelSettings, error) { + var settings common.ChannelSettings + err := db.Pool.QueryRow(ctx, `SELECT settings FROM channels WHERE channel_id = $1`, channelID).Scan(&settings) + + return settings, err +} + // GetCommandSettings retrieves all command settings rows for a given channel. func (db *PostgresClient) GetCommandSettings( ctx context.Context, @@ -878,7 +886,9 @@ func (db *PostgresClient) ResetCommandSettings(ctx context.Context, channelID, c silent_errors = FALSE, users_whitelisted = NULL, users_blacklisted = NULL, - allow_bots = NULL + allow_bots = NULL, + permission = NULL, + ambassador_granted = FALSE WHERE channel_id = $1 AND command = $2 ` _, err := db.Pool.Exec(ctx, query, channelID, command) diff --git a/common/types.go b/common/types.go index 36356ba..ce216f1 100644 --- a/common/types.go +++ b/common/types.go @@ -164,7 +164,7 @@ type ChannelSettings struct { AllowBotEmoteTracking *bool `json:"allow_bot_emote_tracking,omitempty"` IgnoreDropped *bool `json:"ignore_dropped,omitempty"` NoLinks *bool `json:"no_links,omitempty"` - ForcePyramidNotVerbose *bool `json:"force_potato_not_verbose,omitempty"` + ForcePyramidNotVerbose *bool `json:"force_pyramid_not_verbose,omitempty"` Language string `json:"language"` Permission string `json:"permission"` Prefix string `json:"prefix"` diff --git a/common/utils/broker.go b/common/utils/broker.go index 78c07b1..ad3d805 100644 --- a/common/utils/broker.go +++ b/common/utils/broker.go @@ -66,7 +66,7 @@ func (n *NatsClient) subNatsStream(ctx context.Context) error { return err } - err = n.Client.Publish("potat-api.connected", []byte(nil)) + err = n.Client.Publish("github.com/Potat-Industries/potat-api.connected", []byte(nil)) if err != nil { logger.Warn.Printf("Failed to publish connected message: %v", err) } diff --git a/main.go b/main.go index bf82c0e..a77fdf0 100644 --- a/main.go +++ b/main.go @@ -223,7 +223,7 @@ func initClickhouse(ctx context.Context, config common.Config) *db.ClickhouseCli func initNats(ctx context.Context) *utils.NatsClient { nats, err := utils.CreateNatsBroker(ctx) if err != nil { - logger.Error.Panicf("Failed to connect to RabbitMQ: %v", err) + logger.Error.Panicf("Failed to connect to NATS: %v", err) } logger.Info.Println("NATS initialized") From b4e1b5625aee181de9d8ec5e385e64dbe14213fb Mon Sep 17 00:00:00 2001 From: tlos Date: Thu, 7 May 2026 22:33:43 -0400 Subject: [PATCH 15/20] fix: lint --- api/routes/del/command_settings.go | 2 -- api/routes/put/command_settings.go | 2 -- common/db/loops.go | 14 ++++++++------ common/db/postgres.go | 5 ++++- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/api/routes/del/command_settings.go b/api/routes/del/command_settings.go index abe7286..ef603fc 100644 --- a/api/routes/del/command_settings.go +++ b/api/routes/del/command_settings.go @@ -94,5 +94,3 @@ func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { Data: &[]any{}, }, start) } - - diff --git a/api/routes/put/command_settings.go b/api/routes/put/command_settings.go index 9883ddb..9afcc99 100644 --- a/api/routes/put/command_settings.go +++ b/api/routes/put/command_settings.go @@ -99,5 +99,3 @@ func putCommandSettings(writer http.ResponseWriter, request *http.Request) { //n Data: &[]any{}, }, start) } - - diff --git a/common/db/loops.go b/common/db/loops.go index add7501..44f4783 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -675,17 +675,19 @@ func backupPostgres( } if natsClient != nil { - err = natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", jsonMessage) - if err != nil { - logger.Error.Println("Failed to publish to queue:", err) - - return - } + publishBackupNotification(natsClient, jsonMessage) } logger.Info.Println(message) } +func publishBackupNotification(natsClient *utils.NatsClient, message []byte) { + err := natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", message) + if err != nil { + logger.Error.Println("Failed to publish to queue:", err) + } +} + func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName string) (string, error) { query := `SELECT pg_size_pretty(pg_database_size($1)) AS size` rows, err := postgres.Query(ctx, query, dbName) diff --git a/common/db/postgres.go b/common/db/postgres.go index e325ccc..a384b52 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -775,7 +775,10 @@ func (db *PostgresClient) UpdateChannelSettings( } // GetChannelSettingsByID retrieves only the settings column for a channel by its ID. -func (db *PostgresClient) GetChannelSettingsByID(ctx context.Context, channelID string) (common.ChannelSettings, error) { +func (db *PostgresClient) GetChannelSettingsByID( + ctx context.Context, + channelID string, +) (common.ChannelSettings, error) { var settings common.ChannelSettings err := db.Pool.QueryRow(ctx, `SELECT settings FROM channels WHERE channel_id = $1`, channelID).Scan(&settings) From 15dde7ef1eaf13f1e46239d095293f52b1f747d2 Mon Sep 17 00:00:00 2001 From: tlos Date: Mon, 18 May 2026 22:27:58 -0400 Subject: [PATCH 16/20] fix: make command/channel settings platform aware --- api/routes/patch/settings.go | 10 +- common/db/clickhouse.go | 5 +- common/db/postgres.go | 10 +- schema.sql | 259 +++++++++++++++++++++++++++++++++++ 4 files changed, 273 insertions(+), 11 deletions(-) create mode 100644 schema.sql diff --git a/api/routes/patch/settings.go b/api/routes/patch/settings.go index 43fa493..0e84643 100644 --- a/api/routes/patch/settings.go +++ b/api/routes/patch/settings.go @@ -95,8 +95,12 @@ func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { / return } - // Resolve channel ID from ?id= or fall back to the user's own Twitch channel ID. + // Resolve channel ID and platform from query params or fall back to the users own Twitch channel. channelID := request.URL.Query().Get("id") + platform := request.URL.Query().Get("platform") + if platform == "" { + platform = string(common.TWITCH) + } if channelID == "" { for _, conn := range user.Connections { if conn.Platform == common.TWITCH { @@ -123,7 +127,7 @@ func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { / return } - existingSettings, err := postgres.GetChannelSettingsByID(request.Context(), channelID) + existingSettings, err := postgres.GetChannelSettingsByID(request.Context(), channelID, platform) if err != nil { logger.Error.Printf("Error fetching channel settings: %v", err) api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ @@ -142,7 +146,7 @@ func patchChannelSettings(writer http.ResponseWriter, request *http.Request) { / return } - if err := postgres.UpdateChannelSettings(request.Context(), channelID, input); err != nil { + if err := postgres.UpdateChannelSettings(request.Context(), channelID, platform, input); err != nil { logger.Error.Printf("Error updating channel settings: %v", err) api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{ Errors: &[]common.ErrorMessage{{Message: "Failed to update settings"}}, diff --git a/common/db/clickhouse.go b/common/db/clickhouse.go index 1418bbe..a135693 100644 --- a/common/db/clickhouse.go +++ b/common/db/clickhouse.go @@ -117,10 +117,7 @@ func (db *ClickhouseClient) GetEmoteStats( //nolint:cyclop } fmt.Fprintf(&sb, " ORDER BY count %s", order) - limit := opts.Limit - if limit <= 0 || limit > 300 { - limit = 100 - } + limit := max(1, min(opts.Limit, 300)) + 1 fmt.Fprintf(&sb, " LIMIT %d", limit) if opts.Offset > 0 { diff --git a/common/db/postgres.go b/common/db/postgres.go index 3920f5c..a5b8da1 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -766,10 +766,11 @@ func (db *PostgresClient) UpdateUserSettings(ctx context.Context, userID int, se func (db *PostgresClient) UpdateChannelSettings( ctx context.Context, channelID string, + platform string, settings common.ChannelSettings, ) error { - query := `UPDATE channels SET settings = $1 WHERE channel_id = $2` - _, err := db.Pool.Exec(ctx, query, settings, channelID) + query := `UPDATE channels SET settings = $1 WHERE channel_id = $2 AND platform = $3` + _, err := db.Pool.Exec(ctx, query, settings, channelID, platform) return err } @@ -778,9 +779,10 @@ func (db *PostgresClient) UpdateChannelSettings( func (db *PostgresClient) GetChannelSettingsByID( ctx context.Context, channelID string, + platform string, ) (common.ChannelSettings, error) { var settings common.ChannelSettings - err := db.Pool.QueryRow(ctx, `SELECT settings FROM channels WHERE channel_id = $1`, channelID).Scan(&settings) + err := db.Pool.QueryRow(ctx, `SELECT settings FROM channels WHERE channel_id = $1 AND platform = $2`, channelID, platform).Scan(&settings) return settings, err } @@ -850,7 +852,7 @@ func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.C custom_cooldown, is_enabled, offline_only, silent_errors, allow_bots, platform, ambassador_granted ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) - ON CONFLICT (channel_id, command) DO UPDATE SET + ON CONFLICT (channel_id, command, platform) DO UPDATE SET permission = EXCLUDED.permission, users_blacklisted = EXCLUDED.users_blacklisted, users_whitelisted = EXCLUDED.users_whitelisted, diff --git a/schema.sql b/schema.sql new file mode 100644 index 0000000..9b87f51 --- /dev/null +++ b/schema.sql @@ -0,0 +1,259 @@ +-- potat-api schema +-- Run against the database before starting the API: +-- docker exec -i potat-api-postgres-1 psql -U potat -d potat < schema.sql +-- +-- NOTE: the haste service requires the pgzstd Postgres extension +-- (https://github.com/grahamedgecombe/pgzstd). The standard postgres:16 +-- Docker image does not include it, so leave haste.enabled=false in +-- config.docker.json unless you build a custom image with the extension. + +-- --------------------------------------------------------------------------- +-- Core user tables (required for /login and all user-facing routes) +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS users ( + user_id SERIAL PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + display TEXT NOT NULL DEFAULT '', + first_seen TIMESTAMPTZ NOT NULL DEFAULT NOW(), + level INT NOT NULL DEFAULT 1, + settings JSONB NOT NULL DEFAULT '{}' +); + +CREATE TABLE IF NOT EXISTS user_connections ( + user_id INT NOT NULL REFERENCES users(user_id) ON DELETE CASCADE, + platform_id TEXT NOT NULL, + platform_username TEXT NOT NULL DEFAULT '', + platform_display TEXT NOT NULL DEFAULT '', + platform_pfp TEXT NOT NULL DEFAULT '', + platform TEXT NOT NULL, + platform_metadata JSONB NOT NULL DEFAULT '{}', + PRIMARY KEY (user_id, platform, platform_id) +); + +CREATE TABLE IF NOT EXISTS connection_oauth ( + platform_id TEXT NOT NULL, + platform TEXT NOT NULL, + access_token TEXT NOT NULL, + refresh_token TEXT NOT NULL DEFAULT '', + scope TEXT[] NOT NULL DEFAULT '{}', + expires_in INT NOT NULL DEFAULT 0, + added_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (platform_id, platform) +); + +-- --------------------------------------------------------------------------- +-- Channel tables +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS channels ( + channel_id TEXT NOT NULL, + username TEXT NOT NULL, + joined_at TIMESTAMPTZ, + added_by JSONB NOT NULL DEFAULT '[]', + platform TEXT NOT NULL DEFAULT 'TWITCH', + settings JSONB NOT NULL DEFAULT '{}', + editors TEXT[] NOT NULL DEFAULT '{}', + ambassadors TEXT[] NOT NULL DEFAULT '{}', + meta JSONB NOT NULL DEFAULT '{}', + state TEXT NOT NULL DEFAULT 'JOINED', + PRIMARY KEY (channel_id, platform) +); + +CREATE TABLE IF NOT EXISTS blocks ( + id SERIAL PRIMARY KEY, + user_id INT NOT NULL, + block_id INT NOT NULL DEFAULT 0, + channel_id TEXT NOT NULL, + block_type TEXT NOT NULL DEFAULT 'USER', + block_data TEXT NOT NULL DEFAULT '' +); + +CREATE TABLE IF NOT EXISTS custom_channel_commands ( + command_id SERIAL PRIMARY KEY, + user_id INT NOT NULL, + channel_id TEXT NOT NULL, + name TEXT, + user_trigger_ids TEXT[] NOT NULL DEFAULT '{}', + user_ignore_ids TEXT[] NOT NULL DEFAULT '{}', + trigger TEXT NOT NULL, + response TEXT NOT NULL DEFAULT '', + run_command TEXT, + active BOOLEAN NOT NULL DEFAULT TRUE, + active_online BOOLEAN NOT NULL DEFAULT TRUE, + active_offline BOOLEAN NOT NULL DEFAULT TRUE, + reply BOOLEAN NOT NULL DEFAULT FALSE, + whisper BOOLEAN NOT NULL DEFAULT FALSE, + announce BOOLEAN NOT NULL DEFAULT FALSE, + cooldown INT NOT NULL DEFAULT 5, + delay INT NOT NULL DEFAULT 0, + use_count INT NOT NULL DEFAULT 0, + created TIMESTAMPTZ NOT NULL DEFAULT NOW(), + modified TIMESTAMPTZ NOT NULL DEFAULT NOW(), + platform TEXT NOT NULL DEFAULT 'TWITCH', + help TEXT +); + +CREATE TABLE IF NOT EXISTS command_settings ( + channel_id TEXT NOT NULL, + command TEXT NOT NULL, + permission TEXT NOT NULL DEFAULT 'NONE', + users_blacklisted TEXT[] NOT NULL DEFAULT '{}', + users_whitelisted TEXT[] NOT NULL DEFAULT '{}', + custom_cooldown INT NOT NULL DEFAULT 0, + channel_usage INT NOT NULL DEFAULT 0, + is_enabled BOOLEAN NOT NULL DEFAULT TRUE, + offline_only BOOLEAN NOT NULL DEFAULT FALSE, + silent_errors BOOLEAN NOT NULL DEFAULT FALSE, + allow_bots BOOLEAN NOT NULL DEFAULT FALSE, + platform TEXT NOT NULL DEFAULT 'TWITCH', + PRIMARY KEY (channel_id, command, platform) +); + +CREATE TABLE IF NOT EXISTS channel_command_usage ( + channel_id TEXT PRIMARY KEY, + channel_usage BIGINT NOT NULL DEFAULT 0 +); + +-- --------------------------------------------------------------------------- +-- Potato game tables +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS potatoes ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + potato_count INT NOT NULL DEFAULT 0, + potato_prestige INT NOT NULL DEFAULT 0, + potato_rank INT NOT NULL DEFAULT 0, + tax_multiplier INT NOT NULL DEFAULT 0, + first_seen TEXT NOT NULL DEFAULT '', + stole_from TEXT, + stole_amount INT, + trampled_by TEXT +); + +CREATE TABLE IF NOT EXISTS potato_analytics ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + average_response_time TEXT NOT NULL DEFAULT '0', + eat_count INT NOT NULL DEFAULT 0, + harvest_count INT NOT NULL DEFAULT 0, + stolen_count INT NOT NULL DEFAULT 0, + theft_count INT NOT NULL DEFAULT 0, + trampled_count INT NOT NULL DEFAULT 0, + trample_count INT NOT NULL DEFAULT 0, + cdr_count INT NOT NULL DEFAULT 0, + quiz_count INT NOT NULL DEFAULT 0, + quiz_complete_count INT NOT NULL DEFAULT 0, + guard_buy_count INT NOT NULL DEFAULT 0, + fertilizer_buy_count INT NOT NULL DEFAULT 0, + cdr_buy_count INT NOT NULL DEFAULT 0, + new_quiz_buy_count INT NOT NULL DEFAULT 0, + gamble_win_count INT NOT NULL DEFAULT 0, + gamble_loss_count INT NOT NULL DEFAULT 0, + gamble_wins_total INT NOT NULL DEFAULT 0, + gamble_losses_total INT NOT NULL DEFAULT 0, + duel_win_count INT NOT NULL DEFAULT 0, + duel_loss_count INT NOT NULL DEFAULT 0, + duel_wins_amount INT NOT NULL DEFAULT 0, + duel_losses_amount INT NOT NULL DEFAULT 0, + duel_caught_losses INT NOT NULL DEFAULT 0, + average_response_count INT NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS potato_settings ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + not_verbose BOOLEAN NOT NULL DEFAULT FALSE +); + +-- --------------------------------------------------------------------------- +-- GPT usage table (reset by hourly/daily/weekly cron loops) +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS gpt_usage ( + user_id INT PRIMARY KEY REFERENCES users(user_id) ON DELETE CASCADE, + hourly_usage INT NOT NULL DEFAULT 0, + daily_usage INT NOT NULL DEFAULT 0, + weekly_usage INT NOT NULL DEFAULT 0 +); + +-- --------------------------------------------------------------------------- +-- Service tables (auto-created on startup when the service is enabled, +-- included here so they can be pre-created for a fresh database) +-- --------------------------------------------------------------------------- + +-- Requires pgzstd extension — omit if not using the haste service. +-- CREATE EXTENSION IF NOT EXISTS pgzstd; +-- CREATE TABLE IF NOT EXISTS haste ( +-- key CHAR(32) UNIQUE NOT NULL, +-- content BYTEA NOT NULL, +-- access_count INT NOT NULL DEFAULT 1, +-- source TEXT NOT NULL DEFAULT 'potatbotat', +-- timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW() +-- ); + +CREATE TABLE IF NOT EXISTS url_redirects ( + key VARCHAR(9) PRIMARY KEY, + url VARCHAR(500) NOT NULL +); + +CREATE TABLE IF NOT EXISTS file_store ( + key VARCHAR(50) PRIMARY KEY, + file BYTEA NOT NULL, + file_name VARCHAR(50), + mime_type VARCHAR(50) NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- --------------------------------------------------------------------------- +-- command_settings — add nullable columns and new fields +-- (ALTER TABLE is idempotent via IF NOT EXISTS) +-- --------------------------------------------------------------------------- + +ALTER TABLE command_settings + ALTER COLUMN permission DROP NOT NULL, + ALTER COLUMN offline_only DROP NOT NULL, + ALTER COLUMN custom_cooldown DROP NOT NULL, + ALTER COLUMN users_blacklisted DROP NOT NULL, + ALTER COLUMN users_whitelisted DROP NOT NULL, + ALTER COLUMN allow_bots DROP NOT NULL; + +ALTER TABLE command_settings + ADD COLUMN IF NOT EXISTS platform TEXT NOT NULL DEFAULT 'TWITCH', + ADD COLUMN IF NOT EXISTS ambassador_granted BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE command_settings DROP CONSTRAINT IF EXISTS command_settings_pkey; +ALTER TABLE command_settings ADD PRIMARY KEY (channel_id, command, platform); + +ALTER TABLE channels DROP CONSTRAINT IF EXISTS channels_pkey; +ALTER TABLE channels ADD PRIMARY KEY (channel_id, platform); + +-- --------------------------------------------------------------------------- +-- Reminders +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS reminders ( + reminder_id SERIAL PRIMARY KEY, + user_id VARCHAR(64) NOT NULL, + recipient_id VARCHAR(64) NOT NULL, + channel_id VARCHAR(64) NOT NULL, + message VARCHAR(500) NOT NULL DEFAULT '', + ready_at TIMESTAMPTZ, + set_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + afk_withheld BOOLEAN NOT NULL DEFAULT FALSE, + status TEXT NOT NULL DEFAULT 'PENDING', + platform TEXT NOT NULL DEFAULT 'TWITCH', + sent_at TIMESTAMPTZ, + type TEXT NOT NULL DEFAULT 'NEXT' +); + +CREATE INDEX IF NOT EXISTS idx_reminders_recipient ON reminders (recipient_id, platform, status); + +-- --------------------------------------------------------------------------- +-- Website JWT sessions +-- --------------------------------------------------------------------------- + +CREATE TABLE IF NOT EXISTS website_jwt ( + token TEXT NOT NULL, + user_id INT NOT NULL REFERENCES users(user_id) ON DELETE CASCADE, + PRIMARY KEY (token) +); From ae630b157948603f0530a43848a9a12761d4fc26 Mon Sep 17 00:00:00 2001 From: tlos Date: Mon, 18 May 2026 22:43:34 -0400 Subject: [PATCH 17/20] fix: lint --- common/db/loops.go | 446 +++++++++++++++++++- common/db/postgres.go | 726 +++++++++++++++++++++++++++++++- common/types.go | 803 ++++++++++++++++++++++++------------ redirects/redirects_test.go | 16 +- 4 files changed, 1730 insertions(+), 261 deletions(-) diff --git a/common/db/loops.go b/common/db/loops.go index 44f4783..4fa0fed 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -1,4 +1,5 @@ // Package db provides database clients and functions to retrieve or update data. + package db import ( @@ -21,734 +22,1177 @@ import ( ) var ( - errNoRows = fmt.Errorf("no rows returned for database size query") + errNoRows = fmt.Errorf("no rows returned for database size query") + errMissingRefreshToken = errors.New("missing refresh token") ) const dumpPath = "./dump" // StartLoops initializes schedules and loops for various tasks. + func StartLoops( + ctx context.Context, + config common.Config, + natsClient *utils.NatsClient, + postgres *PostgresClient, + clickhouse *ClickhouseClient, + redis *RedisClient, + ) { + if !config.Loops.Enabled { + return + } + cronManager := cron.New() var err error + _, err = cronManager.AddFunc("@hourly", func() { + go updateHourlyUsage(ctx, postgres) + go validateTokens(ctx, config, postgres) + }) + if err != nil { + logger.Error.Println("Failed initializing cron updateHourlyUsage", err) return + } + _, err = cronManager.AddFunc("@daily", func() { + updateDailyUsage(ctx, postgres) + }) + if err != nil { + logger.Error.Println("Failed initializing cron updateDailyUsage", err) return + } + _, err = cronManager.AddFunc("@weekly", func() { + updateWeeklyUsage(ctx, postgres) + }) + if err != nil { + logger.Error.Println("Failed initializing cron updateWeeklyUsage", err) return + } + _, err = cronManager.AddFunc("0 */2 * * *", func() { + refreshAllHelixTokens(ctx, config, postgres) + }) + if err != nil { + logger.Error.Println("Failed initializing cron refreshAllHelixTokens", err) return + } + _, err = cronManager.AddFunc("*/30 * * * *", func() { + updateColorView(ctx, clickhouse) + updateActiveBadgeView(ctx, clickhouse) + updateOwnedBadgeView(ctx, clickhouse) + updateUserOwnedBadgeView(ctx, clickhouse) + }) + if err != nil { + logger.Error.Println("Failed initializing cron clickhouse views", err) return + } + _, err = cronManager.AddFunc("0 */12 * * *", func() { + go backupPostgres(ctx, postgres, natsClient, config) + // go optimizeClickhouse(ctx, config, clickhouse) + }) + if err != nil { + logger.Error.Println("Failed initializing cron backupPostgres", err) return + } cronManager.Start() go decrementDuels(ctx, redis) + go deleteOldUploads(ctx, postgres) + go updateAggregateTable(ctx, postgres) + } func decrementDuels(ctx context.Context, redis *RedisClient) { + for { + time.Sleep(30 * time.Minute) + logger.Info.Println("Decrementing duels") keys, err := redis.Scan(ctx, "duelUse:*", 100, 0) + if err != nil { + logger.Error.Println("Failed scanning keys for duels", err) return + } if len(keys) == 0 { + return + } luaScript := ` + local decrementedKeys = 0 + for _, key in ipairs(KEYS) do + local num = redis.call("DECR", key) + decrementedKeys = decrementedKeys + 1 + if num <= 0 then + redis.call("DEL", key) + end + end + + return decrementedKeys + ` value, err := redis.Eval(ctx, luaScript, keys).Result() + if err != nil { + logger.Error.Println("Failed decrementing duels", err) + } logger.Info.Printf("Decremented %d duel keys", value) + } + } func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { + for { + time.Sleep(24 * time.Hour) + logger.Info.Println("Deleting old uploads") query := ` + DELETE FROM file_store + WHERE (created_at < NOW() - INTERVAL '30 days' AND expires_at IS NULL) + OR (expires_at IS NOT NULL AND expires_at < NOW()); + ` _, err := postgres.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error deleting old uploads ", err) + } logger.Debug.Println("Deleted old uploads") + } + } func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { + for { + time.Sleep(5 * time.Minute) query := ` + INSERT INTO channel_command_usage (channel_id, channel_usage) + SELECT channel_id, SUM(channel_usage) AS channel_usage + FROM command_settings + GROUP BY channel_id + ON CONFLICT (channel_id) DO UPDATE + SET channel_usage = EXCLUDED.channel_usage; + ` _, err := postgres.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating aggregate table", err) + } logger.Info.Println("Updated aggregate table") + } + } func updateHourlyUsage(ctx context.Context, postgres *PostgresClient) { + logger.Info.Println("Updating hourly usage") + query := `UPDATE gpt_usage SET hourly_usage = 0;` _, err := postgres.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating hourly usage", err) + } logger.Info.Println("Updated hourly usage") + } func updateDailyUsage(ctx context.Context, postgres *PostgresClient) { + logger.Info.Println("Updating daily usage") + query := `UPDATE gpt_usage SET daily_usage = 0` _, err := postgres.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating daily usage", err) + } logger.Info.Println("Updated daily usage") + } func updateWeeklyUsage(ctx context.Context, postgres *PostgresClient) { + logger.Info.Println("Updating weekly usage") + query := `UPDATE gpt_usage SET weekly_usage = 0` _, err := postgres.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating weekly usage", err) + } logger.Info.Println("Updated weekly usage") + } func updateColorView(ctx context.Context, clickhouse *ClickhouseClient) { + logger.Info.Println("Updating color view") query := ` + INSERT INTO potatbotat.twitch_color_stats + SELECT + color, + COUNT(DISTINCT user_id) AS user_count, + (COUNT(DISTINCT user_id) * 100.0) / ( + SELECT COUNT(user_id) + FROM potatbotat.twitch_colors + ) AS percentage, + ROW_NUMBER() OVER (ORDER BY COUNT(DISTINCT user_id) DESC) AS rank + FROM potatbotat.twitch_colors FINAL + GROUP BY color; + ` err := clickhouse.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating color view ", err) + } + } func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { + logger.Info.Println("Updating active badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_active_badge_stats;`) + if err != nil { + logger.Error.Println("Error truncating badge stats table ", err) return + } query := ` + INSERT INTO potatbotat.twitch_active_badge_stats + SELECT + badge, + count(user_id) AS user_count, + version + FROM potatbotat.twitch_badges + WHERE badge NOT IN ('', 'NOBADGE') + GROUP BY (badge, version); + ` err = clickhouse.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating badge view ", err) + } + } func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { + logger.Info.Println("Updating owned badge view") // Insert owned badges from active table first + // prepare := ` + // INSERT INTO potatbotat.twitch_owned_badges + // SELECT + // badge, + // user_id, + // version + // FROM potatbotat.twitch_badges + // WHERE badge NOT IN ('', 'NOBADGE') + // ` // err := clickhouse.Exec(ctx, prepare) + // if err != nil { + // logger.Error.Println("Error preparing badge view ", err) + // } err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_stats;`) + if err != nil { + logger.Error.Println("Error truncating badge stats table ", err) return + } query := ` + INSERT INTO potatbotat.twitch_owned_badge_stats + SELECT + badge, + count(user_id) AS user_count, + version + FROM potatbotat.twitch_owned_badges + WHERE badge NOT IN ('', 'NOBADGE') + GROUP BY (badge, version); + ` err = clickhouse.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating badge view ", err) + } + } func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { + logger.Info.Println("Updating user owned badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_user_stats;`) + if err != nil { + logger.Error.Println("Error truncating badge stats table ", err) return + } query := ` + INSERT INTO potatbotat.twitch_owned_badge_user_stats + SELECT + user_id, + count(badge) AS badge_count, + groupArrayDistinct(badge) AS badges, + now64(3) + FROM potatbotat.twitch_owned_badges FINAL + WHERE badge NOT IN ('', 'NOBADGE') + GROUP BY user_id + HAVING uniqExact(badge) >= 5 + ORDER BY badge_count DESC; + ` err = clickhouse.Exec(ctx, query) + if err != nil { + logger.Error.Println("Error updating user owned badge view ", err) + } + } func upsertOAuthToken( + ctx context.Context, + postgres *PostgresClient, + oauth *common.GenericOAUTHResponse, + con common.PlatformOauth, + ) error { + query := ` + INSERT INTO connection_oauth ( + platform_id, + access_token, + refresh_token, + scope, + expires_in, + added_at, + platform + ) + VALUES + ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (platform_id, platform) + DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + scope = EXCLUDED.scope, + expires_in = EXCLUDED.expires_in, + added_at = EXCLUDED.added_at; + ` _, err := postgres.Exec( + ctx, + query, + con.PlatformID, + oauth.AccessToken, + oauth.RefreshToken, + oauth.Scope, + oauth.ExpiresIn, + time.Now(), + common.TWITCH, ) return err + } func refreshOrDelete( + ctx context.Context, + config common.Config, + postgres *PostgresClient, + con common.PlatformOauth, + ) (bool, error) { + var err error + if con.RefreshToken == "" { + return false, errMissingRefreshToken + } refreshResult, err := utils.RefreshHelixToken(ctx, config, con.RefreshToken) + if err != nil || refreshResult == nil { + return false, err + } err = upsertOAuthToken(ctx, postgres, refreshResult, con) + if err != nil { + logger.Error.Println( + "Error updating token for user_id", con.PlatformID, ":", err, ) return false, err + } return true, nil + } func validateTokens(ctx context.Context, config common.Config, postgres *PostgresClient) { + logger.Info.Println("Validating Twitch tokens ") query := ` + SELECT + access_token, + platform_id, + refresh_token + FROM connection_oauth + WHERE platform = 'TWITCH'; + ` rows, err := postgres.Query(ctx, query) + if err != nil { + logger.Error.Println("Error getting tokens ", err) return + } + defer rows.Close() validated, deleted := 0, 0 + for rows.Next() { + var con common.PlatformOauth err := rows.Scan(&con.AccessToken, &con.PlatformID, &con.RefreshToken) + if err != nil { + logger.Error.Println("Error scanning token: ", err) continue + } valid, _, err := utils.ValidateHelixToken(ctx, con.AccessToken, false) + if err != nil { + logger.Error.Println("Error validating token ", err) continue + } if !valid { + ok, err := refreshOrDelete(ctx, config, postgres, con) + if err != nil { + logger.Error.Println("Error refreshing token ", err) + deleted++ continue + } if ok { + validated++ + } else { + deleted++ + } continue + } + validated++ time.Sleep(200 * time.Millisecond) + } logger.Info.Printf( + "Validated %d helix tokens, and deleted %d expired tokens", + validated, + deleted, ) + } func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres *PostgresClient) { + logger.Info.Println("Refreshing all Twitch tokens") query := ` + SELECT + platform, + platform_id, + access_token, + refresh_token, + expires_in, + added_at, + scope + FROM connection_oauth + WHERE platform = 'TWITCH'; + ` rows, err := postgres.Query(ctx, query) + if err != nil { + logger.Error.Println("Error getting tokens ", err) return + } + defer rows.Close() refreshed, failed := 0, 0 + for rows.Next() { + var con common.PlatformOauth err := rows.Scan( + &con.Platform, + &con.PlatformID, + &con.AccessToken, + &con.RefreshToken, + &con.ExpiresIn, + &con.AddedAt, + &con.Scope, ) + if err != nil { + logger.Error.Println("Error scanning token: ", err) continue + } ok, err := refreshOrDelete(ctx, config, postgres, con) + if err != nil { + logger.Error.Println("Error refreshing token ", err) + failed++ continue + } if ok { + refreshed++ + } else { + failed++ + } time.Sleep(200 * time.Millisecond) continue + } logger.Info.Printf( + "Refreshed %d helix tokens, %d failed and were expunged", + refreshed, + failed, ) + } func sortFiles(files []string) func(i, j int) bool { + return func(i, j int) bool { + fileI, errI := os.Stat(filepath.Join(dumpPath, files[i])) + if errI != nil { + return false + } fileJ, errJ := os.Stat(filepath.Join(dumpPath, files[j])) + if errJ != nil { + return true + } return fileI.ModTime().Before(fileJ.ModTime()) + } + } func deleteOldDumps(files []string, maxSize int) { + logger.Debug.Printf("Checking for old dump files, current count: %d, max: %d", len(files), maxSize) + if len(files) <= maxSize { + return + } sort.Slice(files, sortFiles(files)) filesToDelete := files[:len(files)-maxSize+1] + for _, file := range filesToDelete { + err := os.Remove(file) + if err != nil { + logger.Error.Println("Failed deleting dump file ", err) + } + } logger.Info.Printf("Deleted %d old dump files", len(filesToDelete)) + } func backupPostgres( + ctx context.Context, + postgres *PostgresClient, + natsClient *utils.NatsClient, + config common.Config, + ) { + logger.Debug.Println("Backing up Postgres") if err := os.MkdirAll(dumpPath, 0o750); err != nil { + logger.Error.Println("Failed to create backup folder:", err) return + } files, err := filepath.Glob(filepath.Join(dumpPath, "*.sql.zst")) + if err != nil { + logger.Error.Println("Failed to list dump files:", err) return + } deleteOldDumps(files, 10) filePath := filepath.Join( + dumpPath, + fmt.Sprintf("data_%d.sql.zst", time.Now().Unix()), ) //nolint:gosec + cmd := exec.CommandContext(ctx, "sh", "-c", fmt.Sprintf( + "PGPASSWORD=%s pg_dump -d %s -U %s -h %s | zstd -3 --threads=%d > %s", + config.Postgres.Password, + config.Postgres.Database, + config.Postgres.User, + config.Postgres.Host, + runtime.NumCPU(), + filePath, )) defer func() { + if err = cmd.Process.Release(); err != nil { + logger.Error.Println("Failed to release pg_dump process:", err) if err = cmd.Process.Kill(); err != nil { + logger.Error.Fatalln("Failed to kill pg_dump process:", err) + } + } + }() var stderr bytes.Buffer + cmd.Stderr = &stderr start := time.Now() + if err = cmd.Run(); err != nil { + logger.Error.Println("Failed to execute pg_dump:", err, stderr.String()) return + } + duration := time.Since(start) stat, err := os.Stat(filePath) + if err != nil { + logger.Error.Println("Failed to get backup file size:", err) return + } backupSize := float64(stat.Size()) / (1024 * 1024 * 1024) dbSize, err := getDatabaseSize(ctx, postgres, config.Postgres.Database) + if err != nil { + logger.Error.Println("Failed to get database size:", err) return + } message := fmt.Sprintf( + "Database back-up successful in %s - DB size: %s - Backup size: %.2f GB", + utils.Humanize(duration, 2), + dbSize, + backupSize, ) jsonMessage, err := json.Marshal(message) + if err != nil { + logger.Error.Println("Failed to JSON stringify message:", err) + logger.Info.Println(message) return + } if natsClient != nil { + publishBackupNotification(natsClient, jsonMessage) + } logger.Info.Println(message) + } func publishBackupNotification(natsClient *utils.NatsClient, message []byte) { + err := natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", message) + if err != nil { + logger.Error.Println("Failed to publish to queue:", err) + } + } func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName string) (string, error) { + query := `SELECT pg_size_pretty(pg_database_size($1)) AS size` + rows, err := postgres.Query(ctx, query, dbName) + if err != nil { + return "", err + } defer rows.Close() if rows.Next() { + var size string + if err := rows.Scan(&size); err != nil { + return "", err + } return size, nil + } return "", errNoRows + } // func optimizeClickhouse(ctx context.Context, config common.Config, clickhouse *ClickhouseClient) { + // // offset any concurrent crons + // time.Sleep(5 * time.Minute) // logger.Info.Println("Optimizing Clickhouse tables") // if config.Clickhouse.Database == "" { + // logger.Error.Println("Clickhouse database is not configured") // return + // } // query := `SELECT table FROM system.tables WHERE database = ?` // rows, err := clickhouse.Query(ctx, query, config.Clickhouse.Database) + // if err != nil { + // logger.Error.Println("Failed to query Clickhouse tables:", err) // return + // } // for rows.Next() { + // var table string + // if err := rows.Scan(&table); err != nil { + // logger.Error.Println("Failed to scan Clickhouse table:", err) // continue + // } // query := fmt.Sprintf("OPTIMIZE TABLE %s.%s FINAL", config.Clickhouse.Database, table) + // if err := clickhouse.Exec(ctx, query); err != nil { + // logger.Error.Println("Failed to optimize Clickhouse table:", err) + // } // logger.Info.Printf( + // "Optimized Clickhouse table %s.%s", + // config.Clickhouse.Database, + // table, + // ) // time.Sleep(5 * time.Second) + // } + // } diff --git a/common/db/postgres.go b/common/db/postgres.go index a5b8da1..dbd7258 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -1,4 +1,5 @@ // Package db provides database clients and functions to retrieve or update data. + package db import ( @@ -16,345 +17,595 @@ import ( ) // PostgresClient is a wrapper around the pgxpool.Pool to manage database connections and queries. + type PostgresClient struct { *pgxpool.Pool } // LoaderKey is used to identify a user or channel in the database. + type LoaderKey struct { - ID *int - UserID *string + ID *int + + UserID *string + Username *string + Platform *string } // ErrPostgresNoRows is an alias for pgx.ErrNoRows to handle cases where no rows are returned from a query. + var ( ErrPostgresNoRows = pgx.ErrNoRows - errInvalidType = fmt.Errorf("invalid channel type") + + errInvalidType = fmt.Errorf("invalid channel type") ) // InitPostgres initializes a new Postgres client with the provided configuration. + func InitPostgres(ctx context.Context, config common.Config) (*PostgresClient, error) { + dbConfig, err := loadConfig(config) + if err != nil { + return nil, err + } pool, err := pgxpool.NewWithConfig(ctx, dbConfig) + if err != nil { + return nil, err + } return &PostgresClient{pool}, nil + } func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unparam + user := config.Postgres.User + if user == "" { + user = "postgres" + } host := config.Postgres.Host + if host == "" { + host = "localhost" //nolint:goconst + } port := config.Postgres.Port + if port == "" { + port = "5432" + } database := config.Postgres.Database + if database == "" { + database = "postgres" + } constring := fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s", + user, + config.Postgres.Password, + host, + port, + database, ) dbConfig, err := pgxpool.ParseConfig(constring) + if err != nil { + logger.Error.Panicln("Error parsing database config", err) return nil, err + } dbConfig.MaxConns = 32 + dbConfig.MinConns = 4 + dbConfig.MaxConnIdleTime = 1 * time.Minute + dbConfig.MaxConnLifetime = 30 * time.Minute + dbConfig.HealthCheckPeriod = 5 * time.Minute + dbConfig.ConnConfig.ConnectTimeout = 10 * time.Second return dbConfig, nil + } // CheckTableExists checks if a table exists in the database and creates it if it doesn't. + func (db *PostgresClient) CheckTableExists(ctx context.Context, createTable string) { + _, err := db.Pool.Exec(ctx, createTable) + if err != nil { + logger.Error.Fatalf("Failed to create table: %v", err) + } + } // Ping checks the connection to the database. + func (db *PostgresClient) Ping(ctx context.Context) error { + return db.Pool.Ping(ctx) + } // GetUserByName retrieves a user by their username from the database. + func (db *PostgresClient) GetUserByName(ctx context.Context, username string) (*common.User, error) { + query := ` + SELECT + users.user_id, + users.username, + users.display, + users.first_seen, + users.level, + users.settings, + json_agg(uc) AS connections + FROM users + LEFT JOIN user_connections uc ON users.user_id = uc.user_id + WHERE users.username = $1 + GROUP BY users.user_id; + ` var user common.User + err := db.Pool.QueryRow(ctx, query, username).Scan( + &user.ID, + &user.Username, + &user.Display, + &user.FirstSeen, + &user.Level, + &user.Settings, + &user.Connections, ) + if err != nil { + return nil, err + } return &user, nil + } // GetUserByInternalID retrieves a user by their internal ID from the database. + func (db *PostgresClient) GetUserByInternalID(ctx context.Context, id int) (*common.User, error) { + query := ` + SELECT + u.user_id, + username, + display, + first_seen, + level, + settings, + json_agg(uc) as connections + FROM users u + JOIN user_connections uc ON u.user_id = uc.user_id + WHERE u.user_id = $1 + GROUP BY u.user_id; + ` var user common.User + err := db.Pool.QueryRow(ctx, query, id).Scan( + &user.ID, + &user.Username, + &user.Display, + &user.FirstSeen, + &user.Level, + &user.Settings, + &user.Connections, ) + if err != nil { + return nil, err + } return &user, nil + } // GetChannelBlocks retrieves all blocks for a given channel from the database. + func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string) *[]common.Block { + query := ` + SELECT + user_id + block_id, + channel_id, + block_type, + block_data + FROM blocks + WHERE channel_id = $1 + ` rows, err := db.Pool.Query(ctx, query, channelID) + if err != nil { + return nil + } defer rows.Close() var blocks []common.Block + for rows.Next() { + var block common.Block + err := rows.Scan( + &block.ID, + &block.BlockedUserID, + &block.ChannelID, + &block.BlockType, + &block.CommandName, ) + if err != nil { + return nil + } blocks = append(blocks, block) + } return &blocks + } // GetChannelCommands retrieves all custom channel commands for a given channel from the database. + func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID string) *[]common.ChannelCommand { + query := ` + SELECT + command_id, + user_id, + channel_id, + name, + user_trigger_ids, + user_ignore_ids, + trigger, + response, + run_command, + active, + active_online, + active_offline, + reply, + whisper, + announce, + cooldown, + delay, + use_count, + created, + modified, + platform, + help + FROM custom_channel_commands + WHERE channel_id = $1 + ` rows, err := db.Pool.Query(ctx, query, channelID) + if err != nil { + return nil + } defer rows.Close() var commands []common.ChannelCommand + for rows.Next() { + var command common.ChannelCommand + err := rows.Scan( + &command.CommandID, + &command.UserID, + &command.ChannelID, + &command.Name, + &command.UserTriggerIDs, + &command.UserIgnoreIDs, + &command.Trigger, + &command.Response, + &command.RunCommand, + &command.Active, + &command.ActiveOnline, + &command.ActiveOffline, + &command.Reply, + &command.Whisper, + &command.Announce, + &command.Cooldown, + &command.Delay, + &command.UseCount, + &command.Created, + &command.Modified, + &command.Platform, + &command.Help, ) + if err != nil { + return nil + } commands = append(commands, command) + } return &commands + } // GetChannelByName retrieves a channel by its username and platform from the database. + func (db *PostgresClient) GetChannelByName( + ctx context.Context, + username string, + platform common.Platforms, + ) (*common.Channel, error) { + return db.getChannelByType(ctx, username, platform, "NAME") + } // GetChannelByID retrieves a channel by its ID and platform from the database. + func (db *PostgresClient) GetChannelByID( + ctx context.Context, + channelID string, + platform common.Platforms, + ) (*common.Channel, error) { + return db.getChannelByType(ctx, channelID, platform, "ID") + } func (db *PostgresClient) getChannelByType( //nolint:cyclop + ctx context.Context, + value string, + platform common.Platforms, + chanType string, + ) (*common.Channel, error) { + query := ` + SELECT + c.channel_id, + c.username, + c.joined_at, + c.added_by, + c.platform, + c.settings, + c.editors, + c.ambassadors, + c.meta, + c.state + FROM channels c + ` switch chanType { + case "ID": + query += `WHERE c.channel_id = $1 ` + case "NAME": + query += `WHERE c.username = $1 ` + default: + return nil, errInvalidType + } + query += `AND platform = $2;` var channel common.Channel + err := db.Pool.QueryRow(ctx, query, value, platform).Scan( + &channel.ChannelID, + &channel.Username, + &channel.JoinedAt, + &channel.AddedBy, + &channel.Platform, + &channel.Settings, + &channel.Editors, + &channel.Ambassadors, + &channel.Meta, + &channel.State, ) + if err != nil { + return nil, err + } var wg sync.WaitGroup @@ -362,631 +613,1096 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop wg.Add(2) var commands *[]common.ChannelCommand + go func() { + defer wg.Done() + cmds := db.GetChannelCommands(ctx, channel.ChannelID) + if cmds != nil { + commands = cmds + } + }() var blocks []common.Block + go func() { + defer wg.Done() + bs := db.GetChannelBlocks(ctx, channel.ChannelID) + if bs != nil { + blocks = *bs + } + }() wg.Wait() if commands != nil { + channel.Commands = commands + } else { + channel.Commands = &[]common.ChannelCommand{} + } if len(blocks) > 0 { + channel.Blocks = common.FilteredBlocks{ - Users: &[]common.Block{}, + + Users: &[]common.Block{}, + Commands: &[]common.Block{}, } for _, block := range blocks { + switch block.BlockType { + case common.UserBlock: + *channel.Blocks.Users = append(*channel.Blocks.Users, block) + case common.CommandBlock: + *channel.Blocks.Commands = append(*channel.Blocks.Commands, block) + case common.GlobalBlock: + continue + } + } + } else { + channel.Blocks = common.FilteredBlocks{} + } return &channel, nil + } // GetPotatoData retrieves potato data for a user from the database. + func (db *PostgresClient) GetPotatoData(ctx context.Context, username string) (*common.PotatoData, error) { + query := ` + SELECT + p.user_id, + p.potato_count, + p.potato_prestige, + p.potato_rank, + p.tax_multiplier, + p.first_seen, + p.stole_from, + p.stole_amount, + p.trampled_by, + a.average_response_time, + a.eat_count, + a.harvest_count, + a.stolen_count, + a.theft_count, + a.trampled_count, + a.trample_count, + a.cdr_count, + a.quiz_count, + a.quiz_complete_count, + a.guard_buy_count, + a.fertilizer_buy_count, + a.cdr_buy_count, + a.new_quiz_buy_count, + a.gamble_win_count, + a.gamble_loss_count, + a.gamble_wins_total, + a.gamble_losses_total, + a.duel_win_count, + a.duel_loss_count, + a.duel_wins_amount, + a.duel_losses_amount, + a.duel_caught_losses, + a.average_response_count, + s.not_verbose + FROM ( SELECT user_id FROM users WHERE username = $1 ) u + INNER JOIN potatoes p ON p.user_id = u.user_id + INNER JOIN potato_analytics a ON u.user_id = a.user_id + INNER JOIN potato_settings s ON u.user_id = s.user_id; + ` var data common.PotatoData err := db.Pool.QueryRow(ctx, query, username).Scan( + &data.ID, + &data.PotatoCount, + &data.PotatoPrestige, + &data.PotatoRank, + &data.TaxMultiplier, + &data.FirstSeen, + &data.StoleFrom, + &data.StoleAmount, + &data.TrampledBy, + &data.AverageResponseTime, + &data.EatCount, + &data.HarvestCount, + &data.StolenCount, + &data.TheftCount, + &data.TrampledCount, + &data.TrampleCount, + &data.CDRCount, + &data.QuizCount, + &data.QuizCompleteCount, + &data.GuardBuyCount, + &data.FertilizerBuyCount, + &data.CDRBuyCount, + &data.NewQuizBuyCount, + &data.GambleWinCount, + &data.GambleLossCount, + &data.GambleWinsTotal, + &data.GambleLossesTotal, + &data.DuelWinCount, + &data.DuelLossCount, + &data.DuelWinsAmount, + &data.DuelLossesAmount, + &data.DuelCaughtLosses, + &data.AverageResponseCount, + &data.NotVerbose, ) + if err != nil { + return nil, err + } return &data, nil + } // BatchUserConections retrieves user connections for a batch of user IDs from the database. + func (db *PostgresClient) BatchUserConections( + ctx context.Context, + ids []int, + ) *map[int][]common.UserConnection { + query := ` + SELECT + user_id, + platform_id, + platform_username, + platform_display, + platform_pfp, + platform, + platform_metadata + FROM user_connections + WHERE user_id = ANY($1::INT[]) + ` rows, err := db.Pool.Query(ctx, query, ids) + if err != nil { + return nil + } defer rows.Close() users := make(map[int][]common.UserConnection) + for rows.Next() { + var connection common.UserConnection + err := rows.Scan( + &connection.ID, + &connection.UserID, + &connection.Username, + &connection.Display, + &connection.PFP, + &connection.Platform, + &connection.Meta, ) + if err != nil { + return nil + } users[connection.ID] = append(users[connection.ID], connection) + } return &users + } // GetRedirectByKey retrieves a URL redirect from the database by its key. + func (db *PostgresClient) GetRedirectByKey(ctx context.Context, key string) (string, error) { + query := `SELECT url FROM url_redirects WHERE key = $1` var url string + err := db.Pool.QueryRow(ctx, query, key).Scan(&url) + if err != nil { + return "", err + } return url, nil + } // GetKeyByRedirect retrieves the key associated with a given URL redirect from the database. + func (db *PostgresClient) GetKeyByRedirect(ctx context.Context, url string) (string, error) { + query := `SELECT key FROM url_redirects WHERE url = $1` var key string + err := db.Pool.QueryRow(ctx, query, url).Scan(&key) + if err != nil { + return "", err + } return key, nil + } // RedirectExists checks if a URL redirect exists in the database by its key. + func (db *PostgresClient) RedirectExists(ctx context.Context, key string) bool { + query := `SELECT EXISTS(SELECT 1 FROM url_redirects WHERE key = $1)` var exists bool err := db.Pool.QueryRow(ctx, query, key).Scan(&exists) + if err != nil { + return false + } return exists + } // NewRedirect inserts a new URL redirect into the database. + func (db *PostgresClient) NewRedirect(ctx context.Context, key, url string) error { + query := `INSERT INTO url_redirects (key, url) VALUES ($1, $2)` _, err := db.Pool.Exec(ctx, query, key, url) return err + } // GetHaste retrieves a hastebin text document from the database by its key. + func (db *PostgresClient) GetHaste(ctx context.Context, key string) (string, error) { + query := ` + UPDATE haste + SET access_count = access_count + 1 + WHERE key = $1 + RETURNING convert_from(zstd_decompress(content::bytea), 'utf-8') AS text; + ` var text string err := db.Pool.QueryRow(ctx, query, encode(key)).Scan(&text) + if err != nil { + return "", err + } return text, nil + } // NewHaste inserts a new compressed hastebin text document into the database. + func (db *PostgresClient) NewHaste( + ctx context.Context, + key string, + text []byte, + source string, + ) error { + query := ` + INSERT INTO haste (key, content, source) + VALUES ($1, zstd_compress($2, null, 8), $3) + ON CONFLICT (key) DO NOTHING; + ` _, err := db.Pool.Exec(ctx, query, encode(key), text, source) return err + } func encode(data string) string { + hash := md5.New() //nolint:gosec + hash.Write([]byte(data)) return hex.EncodeToString(hash.Sum(nil)) + } // NewUpload inserts a new file into the database and returns the creation timestamp. + func (db *PostgresClient) NewUpload( + ctx context.Context, + key string, + file []byte, + name string, + mimeType string, + ) (bool, *time.Time) { + query := ` + INSERT INTO file_store (file, file_name, mime_type, key) + VALUES ($1, $2, $3, $4) + RETURNING created_at; + ` + var createdAt time.Time + err := db.Pool.QueryRow(ctx, query, file, name, mimeType, key).Scan(&createdAt) + if err != nil { + logger.Error.Println("Error scanning upload", err) return false, nil + } return true, &createdAt + } // GetFileByKey retrieves a file from the database by its key. + func (db *PostgresClient) GetFileByKey( + ctx context.Context, + key string, + ) ([]byte, string, *string, *time.Time, error) { + query := ` + SELECT file, mime_type, file_name, created_at + FROM file_store + WHERE key = $1 + ` var content []byte + var mimeType string + var fileName *string + var createdAt time.Time err := db.Pool.QueryRow(ctx, query, key).Scan( + &content, + &mimeType, + &fileName, + &createdAt, ) + if err != nil { + return nil, "", nil, nil, err + } return content, mimeType, fileName, &createdAt, nil + } // DeleteFileByKey deletes a file from the database by its key. + func (db *PostgresClient) DeleteFileByKey( + ctx context.Context, + key string, + ) bool { + query := ` + DELETE FROM file_store + WHERE key = $1 + ` _, err := db.Pool.Exec(ctx, query, key) return err == nil + } // GetUploadCreatedAt retrieves the creation timestamp of an upload by its key. + func (db *PostgresClient) GetUploadCreatedAt( + ctx context.Context, + key string, + ) (*time.Time, error) { + query := ` + SELECT created_at + FROM file_store + WHERE key = $1 + ` var createdAt time.Time + err := db.Pool.QueryRow(ctx, query, key).Scan(&createdAt) + if err != nil { + return nil, err + } return &createdAt, nil + } // GetAllChannels retrieves all channels with state = 'JOINED', ordered by username. + func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelListItem, error) { + query := ` + SELECT channel_id, username, platform, state + FROM channels + WHERE state = 'JOINED' + ORDER BY username + ` rows, err := db.Pool.Query(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() var channels []common.ChannelListItem + for rows.Next() { + var ch common.ChannelListItem + if err := rows.Scan(&ch.ChannelID, &ch.Username, &ch.Platform, &ch.State); err != nil { + return nil, err + } + channels = append(channels, ch) + } return channels, rows.Err() + } // UpdateUserSettings replaces the settings JSONB column for a user. + func (db *PostgresClient) UpdateUserSettings(ctx context.Context, userID int, settings common.UserSettings) error { + query := `UPDATE users SET settings = $1 WHERE user_id = $2` + _, err := db.Pool.Exec(ctx, query, settings, userID) return err + } // UpdateChannelSettings replaces the settings JSONB column for a channel. + func (db *PostgresClient) UpdateChannelSettings( + ctx context.Context, + channelID string, + platform string, + settings common.ChannelSettings, + ) error { + query := `UPDATE channels SET settings = $1 WHERE channel_id = $2 AND platform = $3` + _, err := db.Pool.Exec(ctx, query, settings, channelID, platform) return err + } // GetChannelSettingsByID retrieves only the settings column for a channel by its ID. + func (db *PostgresClient) GetChannelSettingsByID( + ctx context.Context, + channelID string, + platform string, + ) (common.ChannelSettings, error) { + var settings common.ChannelSettings - err := db.Pool.QueryRow(ctx, `SELECT settings FROM channels WHERE channel_id = $1 AND platform = $2`, channelID, platform).Scan(&settings) + + err := db.Pool.QueryRow( + + ctx, + + `SELECT settings FROM channels WHERE channel_id = $1 AND platform = $2`, + + channelID, + + platform, + ).Scan(&settings) return settings, err + } // GetCommandSettings retrieves all command settings rows for a given channel. + func (db *PostgresClient) GetCommandSettings( + ctx context.Context, + channelID string, + ) ([]common.CommandSettings, error) { + query := ` + SELECT + channel_id, + command, + permission, + users_blacklisted, + users_whitelisted, + custom_cooldown, + channel_usage, + is_enabled, + offline_only, + silent_errors, + allow_bots, + platform, + ambassador_granted + FROM command_settings + WHERE channel_id = $1 + ORDER BY command + ` rows, err := db.Pool.Query(ctx, query, channelID) + if err != nil { + return nil, err + } + defer rows.Close() var results []common.CommandSettings + for rows.Next() { + var cs common.CommandSettings + if err := rows.Scan( + &cs.ChannelID, + &cs.Command, + &cs.Permission, + &cs.UsersBlacklisted, + &cs.UsersWhitelisted, + &cs.CustomCooldown, + &cs.ChannelUsage, + &cs.IsEnabled, + &cs.OfflineOnly, + &cs.SilentErrors, + &cs.AllowBots, + &cs.Platform, + &cs.AmbassadorGranted, ); err != nil { + return nil, err + } + results = append(results, cs) + } return results, rows.Err() + } // UpsertCommandSettings inserts or updates a single command's settings row. + func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.CommandSettings) error { + query := ` + INSERT INTO command_settings ( + channel_id, command, permission, users_blacklisted, users_whitelisted, + custom_cooldown, is_enabled, offline_only, silent_errors, allow_bots, + platform, ambassador_granted + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ON CONFLICT (channel_id, command, platform) DO UPDATE SET + permission = EXCLUDED.permission, + users_blacklisted = EXCLUDED.users_blacklisted, + users_whitelisted = EXCLUDED.users_whitelisted, + custom_cooldown = EXCLUDED.custom_cooldown, + is_enabled = EXCLUDED.is_enabled, + offline_only = EXCLUDED.offline_only, + silent_errors = EXCLUDED.silent_errors, + allow_bots = EXCLUDED.allow_bots, + platform = EXCLUDED.platform, + ambassador_granted = EXCLUDED.ambassador_granted + ` platform := cs.Platform + if platform == "" { + platform = "TWITCH" + } _, err := db.Pool.Exec( + ctx, query, + cs.ChannelID, cs.Command, cs.Permission, + cs.UsersBlacklisted, cs.UsersWhitelisted, + cs.CustomCooldown, cs.IsEnabled, cs.OfflineOnly, + cs.SilentErrors, cs.AllowBots, platform, cs.AmbassadorGranted, ) return err + } // ResetCommandSettings resets a single command's overrides back to their defaults. + func (db *PostgresClient) ResetCommandSettings(ctx context.Context, channelID, command string) error { + query := ` + UPDATE command_settings SET + is_enabled = TRUE, + offline_only = NULL, + custom_cooldown = NULL, + silent_errors = FALSE, + users_whitelisted = NULL, + users_blacklisted = NULL, + allow_bots = NULL, + permission = NULL, + ambassador_granted = FALSE + WHERE channel_id = $1 AND command = $2 + ` + _, err := db.Pool.Exec(ctx, query, channelID, command) return err + } // GetChannelAmbassadors returns the ambassadors slice for a channel, used for auth checks. + func (db *PostgresClient) GetChannelAmbassadors( + ctx context.Context, + channelID string, + platform common.Platforms, + ) ([]string, error) { + query := `SELECT ambassadors FROM channels WHERE channel_id = $1 AND platform = $2` var ambassadors []string + err := db.Pool.QueryRow(ctx, query, channelID, platform).Scan(&ambassadors) + if err != nil { + return nil, err + } return ambassadors, nil + } // GetUserReminders retrieves all pending reminders for a user on a given platform. + func (db *PostgresClient) GetUserReminders( + ctx context.Context, + userID string, + platform common.Platforms, + ) ([]common.Reminder, error) { + query := ` + SELECT + reminder_id, user_id, recipient_id, channel_id, + message, ready_at, set_at, afk_withheld, status, platform, sent_at, type + FROM reminders + WHERE recipient_id = $1 AND platform = $2 + ORDER BY set_at DESC + ` rows, err := db.Pool.Query(ctx, query, userID, platform) + if err != nil { + return nil, err + } + defer rows.Close() var reminders []common.Reminder + for rows.Next() { + var r common.Reminder + if err := rows.Scan( + &r.ReminderID, &r.UserID, &r.RecipientID, &r.ChannelID, + &r.Message, &r.ReadyAt, &r.SetAt, &r.AfkWithheld, &r.Status, &r.Platform, &r.SentAt, &r.Type, ); err != nil { + return nil, err + } + reminders = append(reminders, r) + } return reminders, rows.Err() + } // DeleteReminder hard-deletes a reminder by ID, verifying the owner platform ID. + func (db *PostgresClient) DeleteReminder(ctx context.Context, reminderID int, recipientID string) error { + query := `DELETE FROM reminders WHERE reminder_id = $1 AND recipient_id = $2` + _, err := db.Pool.Exec(ctx, query, reminderID, recipientID) return err + } // UpsertOAuthToken stores or refreshes a platform OAuth token for a given user. + func (db *PostgresClient) UpsertOAuthToken( + ctx context.Context, + platformID string, + platform common.Platforms, + accessToken string, + refreshToken string, + scope []string, + expiresIn int, + ) error { + query := ` + INSERT INTO connection_oauth ( + platform_id, access_token, refresh_token, scope, expires_in, added_at, platform + ) VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (platform_id, platform) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + scope = EXCLUDED.scope, + expires_in = EXCLUDED.expires_in, + added_at = EXCLUDED.added_at + ` + _, err := db.Pool.Exec( + ctx, query, + platformID, accessToken, refreshToken, scope, expiresIn, time.Now(), platform, ) return err + } diff --git a/common/types.go b/common/types.go index 13a0462..362a5f9 100644 --- a/common/types.go +++ b/common/types.go @@ -1,4 +1,5 @@ // Package common provides common types and structures used across the application. + package common import ( @@ -7,472 +8,766 @@ import ( ) // Platforms represents the different platforms supported by the bot. + type Platforms string //nolint:revive + const ( - TWITCH Platforms = "TWITCH" + TWITCH Platforms = "TWITCH" + DISCORD Platforms = "DISCORD" - KICK Platforms = "KICK" - STV Platforms = "STV" + + KICK Platforms = "KICK" + + STV Platforms = "STV" + SPOTIFY Platforms = "SPOTIFY" + ANILIST Platforms = "ANILIST" - TRAKT Platforms = "TRAKT" - FFZ Platforms = "FFZ" - BTTV Platforms = "BTTV" - STEAM Platforms = "STEAM" + + TRAKT Platforms = "TRAKT" + + FFZ Platforms = "FFZ" + + BTTV Platforms = "BTTV" + + STEAM Platforms = "STEAM" ) // PermissionLevel represents the permission level of a user interally with the api and bot. + type PermissionLevel uint8 //nolint:revive + const ( - DEVELOPER PermissionLevel = 4 - ADMIN PermissionLevel = 3 - MOD PermissionLevel = 2 - USER PermissionLevel = 1 + DEVELOPER PermissionLevel = 4 + + ADMIN PermissionLevel = 3 + + MOD PermissionLevel = 2 + + USER PermissionLevel = 1 + BLACKLISTED PermissionLevel = 0 ) // User represents a user on a platform, including their first seen date, username, display name,. + type User struct { - FirstSeen time.Time `json:"first_seen"` - Username string `json:"username"` - Display string `json:"display"` - Settings UserSettings `json:"settings"` + FirstSeen time.Time `json:"first_seen"` + + Username string `json:"username"` + + Display string `json:"display"` + + Settings UserSettings `json:"settings"` + Connections []UserConnection `json:"connections,omitempty"` - ID int `json:"user_id"` - Level int `json:"level"` + + ID int `json:"user_id"` + + Level int `json:"level"` } // UserSettings represents the settings for a user on a platform, including language preferences,. + type UserSettings struct { - Language string `json:"language"` - IgnoreDropped bool `json:"ignore_dropped"` - ColorResponses bool `json:"color_responses"` - NoReply bool `json:"no_reply"` - IsBot bool `json:"is_bot"` - IsSelfBot bool `json:"is_selfbot"` + Language string `json:"language"` + + IgnoreDropped bool `json:"ignore_dropped"` + + ColorResponses bool `json:"color_responses"` + + NoReply bool `json:"no_reply"` + + IsBot bool `json:"is_bot"` + + IsSelfBot bool `json:"is_selfbot"` } // UserMeta represents the metadata associated with a user on a platform. + type UserMeta = json.RawMessage // UserConnection represents a connection between a user and a platform. + type UserConnection struct { Platform Platforms `json:"platform"` - Username string `json:"platform_username"` - Display string `json:"platform_display"` - UserID string `json:"platform_id"` - PFP string `json:"platform_pfp"` - Meta UserMeta `json:"platform_metadata"` - ID int `json:"user_id"` + + Username string `json:"platform_username"` + + Display string `json:"platform_display"` + + UserID string `json:"platform_id"` + + PFP string `json:"platform_pfp"` + + Meta UserMeta `json:"platform_metadata"` + + ID int `json:"user_id"` } // KickChannelMeta represents the metadata for a channel on Kick, including the channel ID, chatroom ID, and user ID. + type KickChannelMeta struct { - ChannelID string `json:"channel_id"` + ChannelID string `json:"channel_id"` + ChatroomID string `json:"chatroom_id"` - UserID string `json:"user_id"` + + UserID string `json:"user_id"` } // TwitchChannelMeta represents the metadata for a channel on Twitch, including whether the bot and channel are banned. + type TwitchChannelMeta struct { - BotBanned bool `json:"bot_banned"` + BotBanned bool `json:"bot_banned"` + TwitchBanned bool `json:"twitch_banned"` } // TwitchUserMeta represents the metadata for a user on Twitch, including their color and roles. + type TwitchUserMeta struct { - Color string `json:"color,omitempty"` + Color string `json:"color,omitempty"` + Roles TwitchRoles `json:"roles,omitzero"` } // StvUserMeta represents the metadata for a user on 7TV, including their paint ID and roles. + type StvUserMeta struct { - PaintID string `json:"paint_id,omitempty"` - Roles []string `json:"roles,omitempty"` + PaintID string `json:"paint_id,omitempty"` + + Roles []string `json:"roles,omitempty"` } // TwitchRoles represents the roles a user can have on Twitch, such as staff, partner, or affiliate. + type TwitchRoles struct { - IsStaff bool `json:"isStaff,omitempty"` - IsPartner bool `json:"isPartner,omitempty"` + IsStaff bool `json:"isStaff,omitempty"` + + IsPartner bool `json:"isPartner,omitempty"` + IsAffiliate bool `json:"isAffiliate,omitempty"` } // Channel represents a channel on a platform, including its blocks, settings, commands, and other metadata. + type Channel struct { - Blocks FilteredBlocks `json:"blocks,omitzero"` - JoinedAt *time.Time `json:"joined_at,omitempty"` - Meta map[string]any `json:"meta"` - Commands *[]ChannelCommand `json:"commands,omitempty"` - ChannelID string `json:"channel_id"` - Username string `json:"username"` - Platform Platforms `json:"platform"` - State string `json:"state"` - AddedBy []AddedByData `json:"added_by,omitempty"` - Editors []string `json:"editors"` - Ambassadors []string `json:"ambassadors"` - Settings ChannelSettings `json:"settings"` + Blocks FilteredBlocks `json:"blocks,omitzero"` + + JoinedAt *time.Time `json:"joined_at,omitempty"` + + Meta map[string]any `json:"meta"` + + Commands *[]ChannelCommand `json:"commands,omitempty"` + + ChannelID string `json:"channel_id"` + + Username string `json:"username"` + + Platform Platforms `json:"platform"` + + State string `json:"state"` + + AddedBy []AddedByData `json:"added_by,omitempty"` + + Editors []string `json:"editors"` + + Ambassadors []string `json:"ambassadors"` + + Settings ChannelSettings `json:"settings"` } // FilteredBlocks represents the blocks of users and commands in a channel. + type FilteredBlocks struct { - Users *[]Block `json:"users"` + Users *[]Block `json:"users"` + Commands *[]Block `json:"commands"` } // Block represents a block connection between a user and a command or another user. + type Block struct { - ChannelID string `json:"channel_id"` - BlockType BlockType `json:"block_type"` - CommandName string `json:"command_name,omitempty"` - ID int `json:"user_id"` - BlockedUserID int `json:"blocked_user_id"` + ChannelID string `json:"channel_id"` + + BlockType BlockType `json:"block_type"` + + CommandName string `json:"command_name,omitempty"` + + ID int `json:"user_id"` + + BlockedUserID int `json:"blocked_user_id"` } // BlockType represents the type of block connection that can be made between a user or a command. + type BlockType string //nolint:revive + const ( - UserBlock BlockType = "USER" + UserBlock BlockType = "USER" + CommandBlock BlockType = "COMMAND" - GlobalBlock BlockType = "GLOBAL" + + GlobalBlock BlockType = "GLOBAL" ) // AddedByData represents the data of a user who added a channel. + type AddedByData struct { - AddedAt time.Time `json:"addedAt"` - Username string `json:"username"` - ID string `json:"id"` + AddedAt time.Time `json:"addedAt"` + + Username string `json:"username"` + + ID string `json:"id"` } // ChannelSettings represents the settings for a channel, including bot settings, cooldowns, language, + // permission levels, and other configurations. + type ChannelSettings struct { - PajBot *string `json:"paj_bot,omitempty"` - UserCooldown *int `json:"user_cooldown,omitempty"` - ChannelCooldown *int `json:"channel_cooldown,omitempty"` - OnlinePermission *string `json:"online_permission,omitempty"` - EmoteStreakResponse *string `json:"emote_streak_response,omitempty"` - PyramidResponse *string `json:"pyramid_response,omitempty"` - OnlineSilentErrors *bool `json:"online_silent_errors,omitempty"` - OnlineWhisperOnly *bool `json:"online_whisper_only,omitempty"` - AllowBotEmoteTracking *bool `json:"allow_bot_emote_tracking,omitempty"` - IgnoreDropped *bool `json:"ignore_dropped,omitempty"` - NoLinks *bool `json:"no_links,omitempty"` - ForcePyramidNotVerbose *bool `json:"force_pyramid_not_verbose,omitempty"` - Language string `json:"language"` - Permission string `json:"permission"` - Prefix string `json:"prefix"` - UsersBlacklisted []string `json:"users_blacklisted"` - NoReply bool `json:"no_reply"` - FirstMsgResponses bool `json:"first_msg_responses"` - WhisperOnly bool `json:"whisper_only"` - OfflineOnly bool `json:"offline_only"` - ForceLanguage bool `json:"force_language"` - SilentErrors bool `json:"silent_errors"` - ColorResponses bool `json:"color_responses"` + PajBot *string `json:"paj_bot,omitempty"` + + UserCooldown *int `json:"user_cooldown,omitempty"` + + ChannelCooldown *int `json:"channel_cooldown,omitempty"` + + OnlinePermission *string `json:"online_permission,omitempty"` + + EmoteStreakResponse *string `json:"emote_streak_response,omitempty"` + + PyramidResponse *string `json:"pyramid_response,omitempty"` + + OnlineSilentErrors *bool `json:"online_silent_errors,omitempty"` + + OnlineWhisperOnly *bool `json:"online_whisper_only,omitempty"` + + AllowBotEmoteTracking *bool `json:"allow_bot_emote_tracking,omitempty"` + + IgnoreDropped *bool `json:"ignore_dropped,omitempty"` + + NoLinks *bool `json:"no_links,omitempty"` + + ForcePyramidNotVerbose *bool `json:"force_pyramid_not_verbose,omitempty"` + + Language string `json:"language"` + + Permission string `json:"permission"` + + Prefix string `json:"prefix"` + + UsersBlacklisted []string `json:"users_blacklisted"` + + NoReply bool `json:"no_reply"` + + FirstMsgResponses bool `json:"first_msg_responses"` + + WhisperOnly bool `json:"whisper_only"` + + OfflineOnly bool `json:"offline_only"` + + ForceLanguage bool `json:"force_language"` + + SilentErrors bool `json:"silent_errors"` + + ColorResponses bool `json:"color_responses"` } // CommandSettings represents the settings for a command in a channel, including its permissions, cooldowns, + // and usage limits. + type CommandSettings struct { //nolint:govet - ChannelID string `json:"channel_id"` - Command string `json:"command"` - Permission *string `json:"permission,omitempty"` - UsersBlacklisted []string `json:"users_blacklisted,omitempty"` - UsersWhitelisted []string `json:"users_whitelisted,omitempty"` - CustomCooldown *int `json:"custom_cooldown,omitempty"` - ChannelUsage int `json:"channel_usage"` - IsEnabled bool `json:"is_enabled"` - OfflineOnly *bool `json:"offline_only,omitempty"` - SilentErrors bool `json:"silent_errors"` - AllowBots *bool `json:"allow_bots,omitempty"` - Platform string `json:"platform"` - AmbassadorGranted bool `json:"ambassador_granted"` + + ChannelID string `json:"channel_id"` + + Command string `json:"command"` + + Permission *string `json:"permission,omitempty"` + + UsersBlacklisted []string `json:"users_blacklisted,omitempty"` + + UsersWhitelisted []string `json:"users_whitelisted,omitempty"` + + CustomCooldown *int `json:"custom_cooldown,omitempty"` + + ChannelUsage int `json:"channel_usage"` + + IsEnabled bool `json:"is_enabled"` + + OfflineOnly *bool `json:"offline_only,omitempty"` + + SilentErrors bool `json:"silent_errors"` + + AllowBots *bool `json:"allow_bots,omitempty"` + + Platform string `json:"platform"` + + AmbassadorGranted bool `json:"ambassador_granted"` } // PlatformOauth represents the OAuth token and metadata for a user on a specific platform, including. + type PlatformOauth struct { - AddedAt time.Time `json:"added_at"` - PlatformID string `json:"platform_id"` - AccessToken string `json:"access_token"` //nolint:gosec - RefreshToken string `json:"refresh_token"` //nolint:gosec - Platform Platforms `json:"platform"` - Scope []string `json:"scope"` - ExpiresIn int `json:"expires_in"` + AddedAt time.Time `json:"added_at"` + + PlatformID string `json:"platform_id"` + + AccessToken string `json:"access_token"` //nolint:gosec + + RefreshToken string `json:"refresh_token"` //nolint:gosec + + Platform Platforms `json:"platform"` + + Scope []string `json:"scope"` + + ExpiresIn int `json:"expires_in"` } // ChannelCommand represents a single custom command, including its properties and settings. + type ChannelCommand struct { - Created time.Time `json:"created"` - Modified time.Time `json:"modified"` - Help *string `json:"help"` - Name *string `json:"name"` - RunCommand *string `json:"run_command"` - ChannelID string `json:"channel_id"` - Trigger string `json:"trigger"` - Response string `json:"response"` - Platform string `json:"platform"` - UserTriggerIDs []string `json:"user_trigger_ids"` - UserIgnoreIDs []string `json:"user_ignore_ids"` - Cooldown int `json:"cooldown"` - UserID int `json:"user_id"` - Delay int `json:"delay"` - UseCount int `json:"use_count"` - CommandID int `json:"command_id"` - Reply bool `json:"reply"` - Whisper bool `json:"whisper"` - Announce bool `json:"announce"` - ActiveOffline bool `json:"active_offline"` - ActiveOnline bool `json:"active_online"` - Active bool `json:"active"` + Created time.Time `json:"created"` + + Modified time.Time `json:"modified"` + + Help *string `json:"help"` + + Name *string `json:"name"` + + RunCommand *string `json:"run_command"` + + ChannelID string `json:"channel_id"` + + Trigger string `json:"trigger"` + + Response string `json:"response"` + + Platform string `json:"platform"` + + UserTriggerIDs []string `json:"user_trigger_ids"` + + UserIgnoreIDs []string `json:"user_ignore_ids"` + + Cooldown int `json:"cooldown"` + + UserID int `json:"user_id"` + + Delay int `json:"delay"` + + UseCount int `json:"use_count"` + + CommandID int `json:"command_id"` + + Reply bool `json:"reply"` + + Whisper bool `json:"whisper"` + + Announce bool `json:"announce"` + + ActiveOffline bool `json:"active_offline"` + + ActiveOnline bool `json:"active_online"` + + Active bool `json:"active"` } // Potatoes represents the structure of potato-related data for a user. + type Potatoes struct { - StoleFrom *string `json:"stole_from"` - StoleAmount *int `json:"stole_amount"` - TrampledBy *string `json:"trampled_by"` - FirstSeen string `json:"first_seen"` - ID int `json:"user_id"` - PotatoCount int `json:"potato_count"` - PotatoPrestige int `json:"potato_prestige"` - PotatoRank int `json:"potato_rank"` - TaxMultiplier int `json:"tax_multiplier"` + StoleFrom *string `json:"stole_from"` + + StoleAmount *int `json:"stole_amount"` + + TrampledBy *string `json:"trampled_by"` + + FirstSeen string `json:"first_seen"` + + ID int `json:"user_id"` + + PotatoCount int `json:"potato_count"` + + PotatoPrestige int `json:"potato_prestige"` + + PotatoRank int `json:"potato_rank"` + + TaxMultiplier int `json:"tax_multiplier"` } // PotatoAnalytics contains various statistics related to potato interactions. + type PotatoAnalytics struct { - AverageResponseTime string `json:"average_response_time"` - EatCount int `json:"eat_count"` - HarvestCount int `json:"harvest_count"` - StolenCount int `json:"stolen_count"` - TheftCount int `json:"theft_count"` - TrampledCount int `json:"trampled_count"` - TrampleCount int `json:"trample_count"` - CDRCount int `json:"cdr_count"` - QuizCount int `json:"quiz_count"` - QuizCompleteCount int `json:"quiz_complete_count"` - GuardBuyCount int `json:"guard_buy_count"` - FertilizerBuyCount int `json:"fertilizer_buy_count"` - CDRBuyCount int `json:"cdr_buy_count"` - NewQuizBuyCount int `json:"new_quiz_buy_count"` - GambleWinCount int `json:"gamble_win_count"` - GambleLossCount int `json:"gamble_loss_count"` - GambleWinsTotal int `json:"gamble_wins_total"` - GambleLossesTotal int `json:"gamble_losses_total"` - DuelWinCount int `json:"duel_win_count"` - DuelLossCount int `json:"duel_loss_count"` - DuelWinsAmount int `json:"duel_wins_amount"` - DuelLossesAmount int `json:"duel_losses_amount"` - DuelCaughtLosses int `json:"duel_caught_losses"` - AverageResponseCount int `json:"average_response_count"` + AverageResponseTime string `json:"average_response_time"` + + EatCount int `json:"eat_count"` + + HarvestCount int `json:"harvest_count"` + + StolenCount int `json:"stolen_count"` + + TheftCount int `json:"theft_count"` + + TrampledCount int `json:"trampled_count"` + + TrampleCount int `json:"trample_count"` + + CDRCount int `json:"cdr_count"` + + QuizCount int `json:"quiz_count"` + + QuizCompleteCount int `json:"quiz_complete_count"` + + GuardBuyCount int `json:"guard_buy_count"` + + FertilizerBuyCount int `json:"fertilizer_buy_count"` + + CDRBuyCount int `json:"cdr_buy_count"` + + NewQuizBuyCount int `json:"new_quiz_buy_count"` + + GambleWinCount int `json:"gamble_win_count"` + + GambleLossCount int `json:"gamble_loss_count"` + + GambleWinsTotal int `json:"gamble_wins_total"` + + GambleLossesTotal int `json:"gamble_losses_total"` + + DuelWinCount int `json:"duel_win_count"` + + DuelLossCount int `json:"duel_loss_count"` + + DuelWinsAmount int `json:"duel_wins_amount"` + + DuelLossesAmount int `json:"duel_losses_amount"` + + DuelCaughtLosses int `json:"duel_caught_losses"` + + AverageResponseCount int `json:"average_response_count"` } // PotatoSettings represents the settings related to potato interactions, such as verbosity. + type PotatoSettings struct { NotVerbose bool `json:"not_verbose"` } // PotatoData combines the Potatoes, PotatoAnalytics, and PotatoSettings structures into a single structure. + type PotatoData struct { Potatoes + PotatoAnalytics + PotatoSettings } // Redirect represents a URL redirect structure, typically used for OAuth flows. + type Redirect struct { Key string `json:"key"` + URL string `json:"url"` } // ErrorMessage represents a structure for error messages returned in API responses. + type ErrorMessage struct { Message string `json:"message"` } // GenericResponse represents a generic API response structure, which can include data and errors. + type GenericResponse[T any] struct { - Data *[]T `json:"data"` + Data *[]T `json:"data"` + Errors *[]ErrorMessage `json:"errors,omitempty"` } // TwitchValidation represents the structure of a Twitch OAuth validation response. + type TwitchValidation struct { - ClientID string `json:"client_id"` - Login string `json:"login"` - UserID string `json:"user_id"` - Scopes []string `json:"scopes"` - ExpiresIn int `json:"expires_in"` - StatusCode int `json:"status_code"` + ClientID string `json:"client_id"` + + Login string `json:"login"` + + UserID string `json:"user_id"` + + Scopes []string `json:"scopes"` + + ExpiresIn int `json:"expires_in"` + + StatusCode int `json:"status_code"` } // GenericOAUTHResponse represents a generic OAuth response structure. + type GenericOAUTHResponse struct { - AccessToken string `json:"access_token"` //nolint:gosec - RefreshToken string `json:"refresh_token"` //nolint:gosec - TokenType string `json:"token_type"` - Scope []string `json:"scope"` - ExpiresIn int `json:"expires_in"` + AccessToken string `json:"access_token"` //nolint:gosec + + RefreshToken string `json:"refresh_token"` //nolint:gosec + + TokenType string `json:"token_type"` + + Scope []string `json:"scope"` + + ExpiresIn int `json:"expires_in"` } // Command represents chat command, with its permissions, description, and other details. + type Command struct { - Conditions CommandConditions `json:"conditions"` - BotRequires BotCommandRequirements `json:"botRequires"` - Description string `json:"description"` - DetailedDescription string `json:"detailedDescription,omitempty"` - Title string `json:"title"` - Usage string `json:"usage"` - Category CommandCategories `json:"category"` - Name string `json:"name"` - UserRequires UserRequires `json:"userRequires"` - Aliases []string `json:"aliases"` - Flags []FlagDetails `json:"flags"` - Cooldown int `json:"cooldown"` - Level PermissionLevel `json:"level"` + Conditions CommandConditions `json:"conditions"` + + BotRequires BotCommandRequirements `json:"botRequires"` + + Description string `json:"description"` + + DetailedDescription string `json:"detailedDescription,omitempty"` + + Title string `json:"title"` + + Usage string `json:"usage"` + + Category CommandCategories `json:"category"` + + Name string `json:"name"` + + UserRequires UserRequires `json:"userRequires"` + + Aliases []string `json:"aliases"` + + Flags []FlagDetails `json:"flags"` + + Cooldown int `json:"cooldown"` + + Level PermissionLevel `json:"level"` } // CommandCategories represents the category of a command. + type CommandCategories string //nolint:revive + const ( Development CommandCategories = "development" - Deprecated CommandCategories = "deprecated" - Moderation CommandCategories = "moderation" - Utilities CommandCategories = "utilities" - Unlisted CommandCategories = "unlisted" - Settings CommandCategories = "settings" - Stream CommandCategories = "stream" - Potato CommandCategories = "potato" - Emotes CommandCategories = "emotes" - Anime CommandCategories = "anime" - Music CommandCategories = "music" - Spam CommandCategories = "spam" - Misc CommandCategories = "misc" - Fun CommandCategories = "fun" + + Deprecated CommandCategories = "deprecated" + + Moderation CommandCategories = "moderation" + + Utilities CommandCategories = "utilities" + + Unlisted CommandCategories = "unlisted" + + Settings CommandCategories = "settings" + + Stream CommandCategories = "stream" + + Potato CommandCategories = "potato" + + Emotes CommandCategories = "emotes" + + Anime CommandCategories = "anime" + + Music CommandCategories = "music" + + Spam CommandCategories = "spam" + + Misc CommandCategories = "misc" + + Fun CommandCategories = "fun" ) // FlagDetails represents the details of a command flag, including its requirements, usage, and validation function. + type FlagDetails struct { - UserRequires *UserRequires `json:"user_requires,omitempty"` - Usage *string `json:"usage,omitempty"` - Multi *bool `json:"multi,omitempty"` - Check func(params Flags, flag FlagDetails) (FlagCheckResult, error) `json:"-"` - Name string `json:"name"` - Type string `json:"type"` - Description string `json:"description"` - Aliases []string `json:"aliases,omitempty"` - Level PermissionLevel `json:"level"` - Required bool `json:"required"` + UserRequires *UserRequires `json:"user_requires,omitempty"` + + Usage *string `json:"usage,omitempty"` + + Multi *bool `json:"multi,omitempty"` + + Check func(params Flags, flag FlagDetails) (FlagCheckResult, error) `json:"-"` + + Name string `json:"name"` + + Type string `json:"type"` + + Description string `json:"description"` + + Aliases []string `json:"aliases,omitempty"` + + Level PermissionLevel `json:"level"` + + Required bool `json:"required"` } // FlagCheckResult represents the result of a flag check, indicating whether the flag is valid, + // and any associated error or requirement. + type FlagCheckResult struct { MustBe *string `json:"must_be,omitempty"` - Error *string `json:"error,omitempty"` - Valid bool `json:"valid"` + + Error *string `json:"error,omitempty"` + + Valid bool `json:"valid"` } // UserRequires represents the required user permission level for a command or flag. + type UserRequires string //nolint:revive + const ( - None UserRequires = "NONE" - Subscriber UserRequires = "SUBSCRIBER" - VIP UserRequires = "VIP" - Mod UserRequires = "MOD" - Ambassador UserRequires = "AMBASSADOR" + None UserRequires = "NONE" + + Subscriber UserRequires = "SUBSCRIBER" + + VIP UserRequires = "VIP" + + Mod UserRequires = "MOD" + + Ambassador UserRequires = "AMBASSADOR" + Broadcaster UserRequires = "BROADCASTER" ) // Flags represents a map of flags, where each flag is identified by a string key and can hold any type of value. + type Flags map[string]any // CommandConditions represents the conditions under which a command can be executed. + type CommandConditions struct { - Ryan *bool `json:"ryan,omitempty"` - OfflineOnly *bool `json:"offlineOnly,omitempty"` - Whisperable *bool `json:"whisperable,omitempty"` - IgnoreBots *bool `json:"ignoreBots,omitempty"` - IsBlockable *bool `json:"isBlockable,omitempty"` + Ryan *bool `json:"ryan,omitempty"` + + OfflineOnly *bool `json:"offlineOnly,omitempty"` + + Whisperable *bool `json:"whisperable,omitempty"` + + IgnoreBots *bool `json:"ignoreBots,omitempty"` + + IsBlockable *bool `json:"isBlockable,omitempty"` + IsNotPipable *bool `json:"isNotPipable,omitempty"` } // BotCommandRequirements represents the requirements for a bot to execute a command,. + type BotCommandRequirements string //nolint:revive + const ( BotNone BotCommandRequirements = "NONE" - BotVIP BotCommandRequirements = "VIP" - BotMod BotCommandRequirements = "MOD" + + BotVIP BotCommandRequirements = "VIP" + + BotMod BotCommandRequirements = "MOD" ) // ChannelListItem is a lightweight channel representation used for list endpoints (e.g. nav search). + type ChannelListItem struct { - ChannelID string `json:"channel_id"` - Username string `json:"username"` - Platform Platforms `json:"platform"` - State string `json:"state"` + ChannelID string `json:"channel_id"` + + Username string `json:"username"` + + Platform Platforms `json:"platform"` + + State string `json:"state"` } // EmoteStat represents an aggregated emote usage entry from Clickhouse. + type EmoteStat struct { - EmoteID string `json:"emote_id"` - EmoteName string `json:"emote_name"` + EmoteID string `json:"emote_id"` + + EmoteName string `json:"emote_name"` + EmoteAlias string `json:"emote_alias"` - Provider string `json:"provider"` - Count int64 `json:"count"` + + Provider string `json:"provider"` + + Count int64 `json:"count"` } // EmoteHistoryEntry represents a single per-user emote usage record from Clickhouse. + type EmoteHistoryEntry struct { //nolint:govet - EmoteID string `json:"emote_id"` - EmoteName string `json:"emote_name"` - EmoteAlias string `json:"emote_alias"` - Provider string `json:"provider"` - ChannelID string `json:"channel_id"` - UserID string `json:"user_id"` - Count int64 `json:"count"` - UsedAt time.Time `json:"used_at"` + + EmoteID string `json:"emote_id"` + + EmoteName string `json:"emote_name"` + + EmoteAlias string `json:"emote_alias"` + + Provider string `json:"provider"` + + ChannelID string `json:"channel_id"` + + UserID string `json:"user_id"` + + Count int64 `json:"count"` + + UsedAt time.Time `json:"used_at"` } // PageInfo contains cursor-based pagination metadata. + type PageInfo struct { //nolint:govet - HasNextPage bool `json:"hasNextPage"` - Cursor string `json:"cursor"` + + HasNextPage bool `json:"hasNextPage"` + + Cursor string `json:"cursor"` } // EmoteStatsResponse is the paginated response shape for GET /emotes/stats. + type EmoteStatsResponse struct { - Data *[]EmoteStat `json:"data"` - Pagination PageInfo `json:"pagination"` - StatusCode int `json:"statusCode"` - Duration float64 `json:"duration"` + Data *[]EmoteStat `json:"data"` + + Pagination PageInfo `json:"pagination"` + + StatusCode int `json:"statusCode"` + + Duration float64 `json:"duration"` } // Reminder represents a scheduled or on-next-seen reminder message. + type Reminder struct { - SetAt time.Time `json:"set_at"` - SentAt *time.Time `json:"sent_at,omitempty"` - ReadyAt *time.Time `json:"ready_at,omitempty"` - UserID string `json:"user_id"` - RecipientID string `json:"recipient_id"` - ChannelID string `json:"channel_id"` - Message string `json:"message"` - Status string `json:"status"` - Platform string `json:"platform"` - Type string `json:"type"` - ReminderID int `json:"reminder_id"` - AfkWithheld bool `json:"afk_withheld"` + SetAt time.Time `json:"set_at"` + + SentAt *time.Time `json:"sent_at,omitempty"` + + ReadyAt *time.Time `json:"ready_at,omitempty"` + + UserID string `json:"user_id"` + + RecipientID string `json:"recipient_id"` + + ChannelID string `json:"channel_id"` + + Message string `json:"message"` + + Status string `json:"status"` + + Platform string `json:"platform"` + + Type string `json:"type"` + + ReminderID int `json:"reminder_id"` + + AfkWithheld bool `json:"afk_withheld"` } diff --git a/redirects/redirects_test.go b/redirects/redirects_test.go index c2c40f8..47f2d34 100644 --- a/redirects/redirects_test.go +++ b/redirects/redirects_test.go @@ -8,25 +8,39 @@ import ( ) func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { + redirector := redirects{} tests := []struct { - input string + input string + expected string }{ + {"https://google.com", "https://google.com"}, + {"http://google.com", "https://google.com"}, + {"//google.com", "https://google.com"}, + {"google.com", "https://google.com"}, } for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + cleanedURL := redirector.cleanRedirectProtocolSoLinksActuallyWork(tc.input) + assert.Truef( + t, strings.HasPrefix(cleanedURL, "https://"), "Expected cleaned URL to start with 'https://', got %q", cleanedURL, ) + assert.Equal(t, tc.expected, cleanedURL) + }) + } + } From 70547af210a566401b7028c108671ef0a1f81ce0 Mon Sep 17 00:00:00 2001 From: tlos Date: Mon, 18 May 2026 23:02:18 -0400 Subject: [PATCH 18/20] fix: lint please --- common/db/loops.go | 153 --------------------------- common/db/postgres.go | 202 ------------------------------------ redirects/redirects_test.go | 7 -- 3 files changed, 362 deletions(-) diff --git a/common/db/loops.go b/common/db/loops.go index 4fa0fed..45f0a66 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -32,7 +32,6 @@ const dumpPath = "./dump" // StartLoops initializes schedules and loops for various tasks. func StartLoops( - ctx context.Context, config common.Config, @@ -44,13 +43,9 @@ func StartLoops( clickhouse *ClickhouseClient, redis *RedisClient, - ) { - if !config.Loops.Enabled { - return - } cronManager := cron.New() @@ -58,13 +53,10 @@ func StartLoops( var err error _, err = cronManager.AddFunc("@hourly", func() { - go updateHourlyUsage(ctx, postgres) go validateTokens(ctx, config, postgres) - }) - if err != nil { logger.Error.Println("Failed initializing cron updateHourlyUsage", err) @@ -74,11 +66,8 @@ func StartLoops( } _, err = cronManager.AddFunc("@daily", func() { - updateDailyUsage(ctx, postgres) - }) - if err != nil { logger.Error.Println("Failed initializing cron updateDailyUsage", err) @@ -88,11 +77,8 @@ func StartLoops( } _, err = cronManager.AddFunc("@weekly", func() { - updateWeeklyUsage(ctx, postgres) - }) - if err != nil { logger.Error.Println("Failed initializing cron updateWeeklyUsage", err) @@ -102,11 +88,8 @@ func StartLoops( } _, err = cronManager.AddFunc("0 */2 * * *", func() { - refreshAllHelixTokens(ctx, config, postgres) - }) - if err != nil { logger.Error.Println("Failed initializing cron refreshAllHelixTokens", err) @@ -116,7 +99,6 @@ func StartLoops( } _, err = cronManager.AddFunc("*/30 * * * *", func() { - updateColorView(ctx, clickhouse) updateActiveBadgeView(ctx, clickhouse) @@ -124,9 +106,7 @@ func StartLoops( updateOwnedBadgeView(ctx, clickhouse) updateUserOwnedBadgeView(ctx, clickhouse) - }) - if err != nil { logger.Error.Println("Failed initializing cron clickhouse views", err) @@ -136,13 +116,10 @@ func StartLoops( } _, err = cronManager.AddFunc("0 */12 * * *", func() { - go backupPostgres(ctx, postgres, natsClient, config) // go optimizeClickhouse(ctx, config, clickhouse) - }) - if err != nil { logger.Error.Println("Failed initializing cron backupPostgres", err) @@ -158,11 +135,9 @@ func StartLoops( go deleteOldUploads(ctx, postgres) go updateAggregateTable(ctx, postgres) - } func decrementDuels(ctx context.Context, redis *RedisClient) { - for { time.Sleep(30 * time.Minute) @@ -170,7 +145,6 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { logger.Info.Println("Decrementing duels") keys, err := redis.Scan(ctx, "duelUse:*", 100, 0) - if err != nil { logger.Error.Println("Failed scanning keys for duels", err) @@ -180,9 +154,7 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { } if len(keys) == 0 { - return - } luaScript := ` @@ -210,21 +182,16 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { ` value, err := redis.Eval(ctx, luaScript, keys).Result() - if err != nil { - logger.Error.Println("Failed decrementing duels", err) - } logger.Info.Printf("Decremented %d duel keys", value) } - } func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { - for { time.Sleep(24 * time.Hour) @@ -242,21 +209,16 @@ func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { ` _, err := postgres.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error deleting old uploads ", err) - } logger.Debug.Println("Deleted old uploads") } - } func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { - for { time.Sleep(5 * time.Minute) @@ -278,75 +240,55 @@ func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { ` _, err := postgres.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating aggregate table", err) - } logger.Info.Println("Updated aggregate table") } - } func updateHourlyUsage(ctx context.Context, postgres *PostgresClient) { - logger.Info.Println("Updating hourly usage") query := `UPDATE gpt_usage SET hourly_usage = 0;` _, err := postgres.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating hourly usage", err) - } logger.Info.Println("Updated hourly usage") - } func updateDailyUsage(ctx context.Context, postgres *PostgresClient) { - logger.Info.Println("Updating daily usage") query := `UPDATE gpt_usage SET daily_usage = 0` _, err := postgres.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating daily usage", err) - } logger.Info.Println("Updated daily usage") - } func updateWeeklyUsage(ctx context.Context, postgres *PostgresClient) { - logger.Info.Println("Updating weekly usage") query := `UPDATE gpt_usage SET weekly_usage = 0` _, err := postgres.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating weekly usage", err) - } logger.Info.Println("Updated weekly usage") - } func updateColorView(ctx context.Context, clickhouse *ClickhouseClient) { - logger.Info.Println("Updating color view") query := ` @@ -376,21 +318,15 @@ func updateColorView(ctx context.Context, clickhouse *ClickhouseClient) { ` err := clickhouse.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating color view ", err) - } - } func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { - logger.Info.Println("Updating active badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_active_badge_stats;`) - if err != nil { logger.Error.Println("Error truncating badge stats table ", err) @@ -420,17 +356,12 @@ func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { ` err = clickhouse.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating badge view ", err) - } - } func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { - logger.Info.Println("Updating owned badge view") // Insert owned badges from active table first @@ -462,7 +393,6 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { // } err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_stats;`) - if err != nil { logger.Error.Println("Error truncating badge stats table ", err) @@ -492,21 +422,15 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { ` err = clickhouse.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating badge view ", err) - } - } func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { - logger.Info.Println("Updating user owned badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_user_stats;`) - if err != nil { logger.Error.Println("Error truncating badge stats table ", err) @@ -542,17 +466,12 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) ` err = clickhouse.Exec(ctx, query) - if err != nil { - logger.Error.Println("Error updating user owned badge view ", err) - } - } func upsertOAuthToken( - ctx context.Context, postgres *PostgresClient, @@ -560,9 +479,7 @@ func upsertOAuthToken( oauth *common.GenericOAUTHResponse, con common.PlatformOauth, - ) error { - query := ` INSERT INTO connection_oauth ( @@ -625,11 +542,9 @@ func upsertOAuthToken( ) return err - } func refreshOrDelete( - ctx context.Context, config common.Config, @@ -637,27 +552,20 @@ func refreshOrDelete( postgres *PostgresClient, con common.PlatformOauth, - ) (bool, error) { - var err error if con.RefreshToken == "" { - return false, errMissingRefreshToken - } refreshResult, err := utils.RefreshHelixToken(ctx, config, con.RefreshToken) if err != nil || refreshResult == nil { - return false, err - } err = upsertOAuthToken(ctx, postgres, refreshResult, con) - if err != nil { logger.Error.Println( @@ -670,11 +578,9 @@ func refreshOrDelete( } return true, nil - } func validateTokens(ctx context.Context, config common.Config, postgres *PostgresClient) { - logger.Info.Println("Validating Twitch tokens ") query := ` @@ -694,7 +600,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre ` rows, err := postgres.Query(ctx, query) - if err != nil { logger.Error.Println("Error getting tokens ", err) @@ -712,7 +617,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre var con common.PlatformOauth err := rows.Scan(&con.AccessToken, &con.PlatformID, &con.RefreshToken) - if err != nil { logger.Error.Println("Error scanning token: ", err) @@ -722,7 +626,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre } valid, _, err := utils.ValidateHelixToken(ctx, con.AccessToken, false) - if err != nil { logger.Error.Println("Error validating token ", err) @@ -734,7 +637,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre if !valid { ok, err := refreshOrDelete(ctx, config, postgres, con) - if err != nil { logger.Error.Println("Error refreshing token ", err) @@ -746,13 +648,9 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre } if ok { - validated++ - } else { - deleted++ - } continue @@ -773,11 +671,9 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre deleted, ) - } func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres *PostgresClient) { - logger.Info.Println("Refreshing all Twitch tokens") query := ` @@ -805,7 +701,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * ` rows, err := postgres.Query(ctx, query) - if err != nil { logger.Error.Println("Error getting tokens ", err) @@ -838,7 +733,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * &con.Scope, ) - if err != nil { logger.Error.Println("Error scanning token: ", err) @@ -848,7 +742,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * } ok, err := refreshOrDelete(ctx, config, postgres, con) - if err != nil { logger.Error.Println("Error refreshing token ", err) @@ -860,13 +753,9 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * } if ok { - refreshed++ - } else { - failed++ - } time.Sleep(200 * time.Millisecond) @@ -883,43 +772,31 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * failed, ) - } func sortFiles(files []string) func(i, j int) bool { - return func(i, j int) bool { - fileI, errI := os.Stat(filepath.Join(dumpPath, files[i])) if errI != nil { - return false - } fileJ, errJ := os.Stat(filepath.Join(dumpPath, files[j])) if errJ != nil { - return true - } return fileI.ModTime().Before(fileJ.ModTime()) - } - } func deleteOldDumps(files []string, maxSize int) { - logger.Debug.Printf("Checking for old dump files, current count: %d, max: %d", len(files), maxSize) if len(files) <= maxSize { - return - } sort.Slice(files, sortFiles(files)) @@ -929,21 +806,16 @@ func deleteOldDumps(files []string, maxSize int) { for _, file := range filesToDelete { err := os.Remove(file) - if err != nil { - logger.Error.Println("Failed deleting dump file ", err) - } } logger.Info.Printf("Deleted %d old dump files", len(filesToDelete)) - } func backupPostgres( - ctx context.Context, postgres *PostgresClient, @@ -951,9 +823,7 @@ func backupPostgres( natsClient *utils.NatsClient, config common.Config, - ) { - logger.Debug.Println("Backing up Postgres") if err := os.MkdirAll(dumpPath, 0o750); err != nil { @@ -965,7 +835,6 @@ func backupPostgres( } files, err := filepath.Glob(filepath.Join(dumpPath, "*.sql.zst")) - if err != nil { logger.Error.Println("Failed to list dump files:", err) @@ -1003,19 +872,15 @@ func backupPostgres( )) defer func() { - if err = cmd.Process.Release(); err != nil { logger.Error.Println("Failed to release pg_dump process:", err) if err = cmd.Process.Kill(); err != nil { - logger.Error.Fatalln("Failed to kill pg_dump process:", err) - } } - }() var stderr bytes.Buffer @@ -1035,7 +900,6 @@ func backupPostgres( duration := time.Since(start) stat, err := os.Stat(filePath) - if err != nil { logger.Error.Println("Failed to get backup file size:", err) @@ -1047,7 +911,6 @@ func backupPostgres( backupSize := float64(stat.Size()) / (1024 * 1024 * 1024) dbSize, err := getDatabaseSize(ctx, postgres, config.Postgres.Database) - if err != nil { logger.Error.Println("Failed to get database size:", err) @@ -1068,7 +931,6 @@ func backupPostgres( ) jsonMessage, err := json.Marshal(message) - if err != nil { logger.Error.Println("Failed to JSON stringify message:", err) @@ -1080,37 +942,25 @@ func backupPostgres( } if natsClient != nil { - publishBackupNotification(natsClient, jsonMessage) - } logger.Info.Println(message) - } func publishBackupNotification(natsClient *utils.NatsClient, message []byte) { - err := natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", message) - if err != nil { - logger.Error.Println("Failed to publish to queue:", err) - } - } func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName string) (string, error) { - query := `SELECT pg_size_pretty(pg_database_size($1)) AS size` rows, err := postgres.Query(ctx, query, dbName) - if err != nil { - return "", err - } defer rows.Close() @@ -1120,9 +970,7 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin var size string if err := rows.Scan(&size); err != nil { - return "", err - } return size, nil @@ -1130,7 +978,6 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin } return "", errNoRows - } // func optimizeClickhouse(ctx context.Context, config common.Config, clickhouse *ClickhouseClient) { diff --git a/common/db/postgres.go b/common/db/postgres.go index dbd7258..55c1570 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -45,25 +45,17 @@ var ( // InitPostgres initializes a new Postgres client with the provided configuration. func InitPostgres(ctx context.Context, config common.Config) (*PostgresClient, error) { - dbConfig, err := loadConfig(config) - if err != nil { - return nil, err - } pool, err := pgxpool.NewWithConfig(ctx, dbConfig) - if err != nil { - return nil, err - } return &PostgresClient{pool}, nil - } func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unparam @@ -71,33 +63,25 @@ func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unpara user := config.Postgres.User if user == "" { - user = "postgres" - } host := config.Postgres.Host if host == "" { - host = "localhost" //nolint:goconst - } port := config.Postgres.Port if port == "" { - port = "5432" - } database := config.Postgres.Database if database == "" { - database = "postgres" - } constring := fmt.Sprintf( @@ -116,7 +100,6 @@ func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unpara ) dbConfig, err := pgxpool.ParseConfig(constring) - if err != nil { logger.Error.Panicln("Error parsing database config", err) @@ -138,35 +121,26 @@ func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unpara dbConfig.ConnConfig.ConnectTimeout = 10 * time.Second return dbConfig, nil - } // CheckTableExists checks if a table exists in the database and creates it if it doesn't. func (db *PostgresClient) CheckTableExists(ctx context.Context, createTable string) { - _, err := db.Pool.Exec(ctx, createTable) - if err != nil { - logger.Error.Fatalf("Failed to create table: %v", err) - } - } // Ping checks the connection to the database. func (db *PostgresClient) Ping(ctx context.Context) error { - return db.Pool.Ping(ctx) - } // GetUserByName retrieves a user by their username from the database. func (db *PostgresClient) GetUserByName(ctx context.Context, username string) (*common.User, error) { - query := ` SELECT @@ -213,21 +187,16 @@ func (db *PostgresClient) GetUserByName(ctx context.Context, username string) (* &user.Connections, ) - if err != nil { - return nil, err - } return &user, nil - } // GetUserByInternalID retrieves a user by their internal ID from the database. func (db *PostgresClient) GetUserByInternalID(ctx context.Context, id int) (*common.User, error) { - query := ` SELECT @@ -274,21 +243,16 @@ func (db *PostgresClient) GetUserByInternalID(ctx context.Context, id int) (*com &user.Connections, ) - if err != nil { - return nil, err - } return &user, nil - } // GetChannelBlocks retrieves all blocks for a given channel from the database. func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string) *[]common.Block { - query := ` SELECT @@ -310,11 +274,8 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string ` rows, err := db.Pool.Query(ctx, query, channelID) - if err != nil { - return nil - } defer rows.Close() @@ -337,11 +298,8 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string &block.CommandName, ) - if err != nil { - return nil - } blocks = append(blocks, block) @@ -349,13 +307,11 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string } return &blocks - } // GetChannelCommands retrieves all custom channel commands for a given channel from the database. func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID string) *[]common.ChannelCommand { - query := ` SELECT @@ -411,11 +367,8 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri ` rows, err := db.Pool.Query(ctx, query, channelID) - if err != nil { - return nil - } defer rows.Close() @@ -472,11 +425,8 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri &command.Help, ) - if err != nil { - return nil - } commands = append(commands, command) @@ -484,39 +434,30 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri } return &commands - } // GetChannelByName retrieves a channel by its username and platform from the database. func (db *PostgresClient) GetChannelByName( - ctx context.Context, username string, platform common.Platforms, - ) (*common.Channel, error) { - return db.getChannelByType(ctx, username, platform, "NAME") - } // GetChannelByID retrieves a channel by its ID and platform from the database. func (db *PostgresClient) GetChannelByID( - ctx context.Context, channelID string, platform common.Platforms, - ) (*common.Channel, error) { - return db.getChannelByType(ctx, channelID, platform, "ID") - } func (db *PostgresClient) getChannelByType( //nolint:cyclop @@ -528,9 +469,7 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop platform common.Platforms, chanType string, - ) (*common.Channel, error) { - query := ` SELECT @@ -601,11 +540,8 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop &channel.State, ) - if err != nil { - return nil, err - } var wg sync.WaitGroup @@ -615,58 +551,44 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop var commands *[]common.ChannelCommand go func() { - defer wg.Done() cmds := db.GetChannelCommands(ctx, channel.ChannelID) if cmds != nil { - commands = cmds - } - }() var blocks []common.Block go func() { - defer wg.Done() bs := db.GetChannelBlocks(ctx, channel.ChannelID) if bs != nil { - blocks = *bs - } - }() wg.Wait() if commands != nil { - channel.Commands = commands - } else { - channel.Commands = &[]common.ChannelCommand{} - } if len(blocks) > 0 { channel.Blocks = common.FilteredBlocks{ - Users: &[]common.Block{}, Commands: &[]common.Block{}, } for _, block := range blocks { - switch block.BlockType { case common.UserBlock: @@ -682,23 +604,18 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop continue } - } } else { - channel.Blocks = common.FilteredBlocks{} - } return &channel, nil - } // GetPotatoData retrieves potato data for a user from the database. func (db *PostgresClient) GetPotatoData(ctx context.Context, username string) (*common.PotatoData, error) { - query := ` SELECT @@ -853,27 +770,20 @@ func (db *PostgresClient) GetPotatoData(ctx context.Context, username string) (* &data.NotVerbose, ) - if err != nil { - return nil, err - } return &data, nil - } // BatchUserConections retrieves user connections for a batch of user IDs from the database. func (db *PostgresClient) BatchUserConections( - ctx context.Context, ids []int, - ) *map[int][]common.UserConnection { - query := ` SELECT @@ -899,11 +809,8 @@ func (db *PostgresClient) BatchUserConections( ` rows, err := db.Pool.Query(ctx, query, ids) - if err != nil { - return nil - } defer rows.Close() @@ -930,11 +837,8 @@ func (db *PostgresClient) BatchUserConections( &connection.Meta, ) - if err != nil { - return nil - } users[connection.ID] = append(users[connection.ID], connection) @@ -942,85 +846,66 @@ func (db *PostgresClient) BatchUserConections( } return &users - } // GetRedirectByKey retrieves a URL redirect from the database by its key. func (db *PostgresClient) GetRedirectByKey(ctx context.Context, key string) (string, error) { - query := `SELECT url FROM url_redirects WHERE key = $1` var url string err := db.Pool.QueryRow(ctx, query, key).Scan(&url) - if err != nil { - return "", err - } return url, nil - } // GetKeyByRedirect retrieves the key associated with a given URL redirect from the database. func (db *PostgresClient) GetKeyByRedirect(ctx context.Context, url string) (string, error) { - query := `SELECT key FROM url_redirects WHERE url = $1` var key string err := db.Pool.QueryRow(ctx, query, url).Scan(&key) - if err != nil { - return "", err - } return key, nil - } // RedirectExists checks if a URL redirect exists in the database by its key. func (db *PostgresClient) RedirectExists(ctx context.Context, key string) bool { - query := `SELECT EXISTS(SELECT 1 FROM url_redirects WHERE key = $1)` var exists bool err := db.Pool.QueryRow(ctx, query, key).Scan(&exists) - if err != nil { - return false - } return exists - } // NewRedirect inserts a new URL redirect into the database. func (db *PostgresClient) NewRedirect(ctx context.Context, key, url string) error { - query := `INSERT INTO url_redirects (key, url) VALUES ($1, $2)` _, err := db.Pool.Exec(ctx, query, key, url) return err - } // GetHaste retrieves a hastebin text document from the database by its key. func (db *PostgresClient) GetHaste(ctx context.Context, key string) (string, error) { - query := ` UPDATE haste @@ -1036,21 +921,16 @@ func (db *PostgresClient) GetHaste(ctx context.Context, key string) (string, err var text string err := db.Pool.QueryRow(ctx, query, encode(key)).Scan(&text) - if err != nil { - return "", err - } return text, nil - } // NewHaste inserts a new compressed hastebin text document into the database. func (db *PostgresClient) NewHaste( - ctx context.Context, key string, @@ -1058,9 +938,7 @@ func (db *PostgresClient) NewHaste( text []byte, source string, - ) error { - query := ` INSERT INTO haste (key, content, source) @@ -1074,23 +952,19 @@ func (db *PostgresClient) NewHaste( _, err := db.Pool.Exec(ctx, query, encode(key), text, source) return err - } func encode(data string) string { - hash := md5.New() //nolint:gosec hash.Write([]byte(data)) return hex.EncodeToString(hash.Sum(nil)) - } // NewUpload inserts a new file into the database and returns the creation timestamp. func (db *PostgresClient) NewUpload( - ctx context.Context, key string, @@ -1100,9 +974,7 @@ func (db *PostgresClient) NewUpload( name string, mimeType string, - ) (bool, *time.Time) { - query := ` INSERT INTO file_store (file, file_name, mime_type, key) @@ -1116,7 +988,6 @@ func (db *PostgresClient) NewUpload( var createdAt time.Time err := db.Pool.QueryRow(ctx, query, file, name, mimeType, key).Scan(&createdAt) - if err != nil { logger.Error.Println("Error scanning upload", err) @@ -1126,19 +997,15 @@ func (db *PostgresClient) NewUpload( } return true, &createdAt - } // GetFileByKey retrieves a file from the database by its key. func (db *PostgresClient) GetFileByKey( - ctx context.Context, key string, - ) ([]byte, string, *string, *time.Time, error) { - query := ` SELECT file, mime_type, file_name, created_at @@ -1167,27 +1034,20 @@ func (db *PostgresClient) GetFileByKey( &createdAt, ) - if err != nil { - return nil, "", nil, nil, err - } return content, mimeType, fileName, &createdAt, nil - } // DeleteFileByKey deletes a file from the database by its key. func (db *PostgresClient) DeleteFileByKey( - ctx context.Context, key string, - ) bool { - query := ` DELETE FROM file_store @@ -1199,19 +1059,15 @@ func (db *PostgresClient) DeleteFileByKey( _, err := db.Pool.Exec(ctx, query, key) return err == nil - } // GetUploadCreatedAt retrieves the creation timestamp of an upload by its key. func (db *PostgresClient) GetUploadCreatedAt( - ctx context.Context, key string, - ) (*time.Time, error) { - query := ` SELECT created_at @@ -1225,21 +1081,16 @@ func (db *PostgresClient) GetUploadCreatedAt( var createdAt time.Time err := db.Pool.QueryRow(ctx, query, key).Scan(&createdAt) - if err != nil { - return nil, err - } return &createdAt, nil - } // GetAllChannels retrieves all channels with state = 'JOINED', ordered by username. func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelListItem, error) { - query := ` SELECT channel_id, username, platform, state @@ -1253,11 +1104,8 @@ func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelL ` rows, err := db.Pool.Query(ctx, query) - if err != nil { - return nil, err - } defer rows.Close() @@ -1269,9 +1117,7 @@ func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelL var ch common.ChannelListItem if err := rows.Scan(&ch.ChannelID, &ch.Username, &ch.Platform, &ch.State); err != nil { - return nil, err - } channels = append(channels, ch) @@ -1279,25 +1125,21 @@ func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelL } return channels, rows.Err() - } // UpdateUserSettings replaces the settings JSONB column for a user. func (db *PostgresClient) UpdateUserSettings(ctx context.Context, userID int, settings common.UserSettings) error { - query := `UPDATE users SET settings = $1 WHERE user_id = $2` _, err := db.Pool.Exec(ctx, query, settings, userID) return err - } // UpdateChannelSettings replaces the settings JSONB column for a channel. func (db *PostgresClient) UpdateChannelSettings( - ctx context.Context, channelID string, @@ -1305,29 +1147,23 @@ func (db *PostgresClient) UpdateChannelSettings( platform string, settings common.ChannelSettings, - ) error { - query := `UPDATE channels SET settings = $1 WHERE channel_id = $2 AND platform = $3` _, err := db.Pool.Exec(ctx, query, settings, channelID, platform) return err - } // GetChannelSettingsByID retrieves only the settings column for a channel by its ID. func (db *PostgresClient) GetChannelSettingsByID( - ctx context.Context, channelID string, platform string, - ) (common.ChannelSettings, error) { - var settings common.ChannelSettings err := db.Pool.QueryRow( @@ -1342,19 +1178,15 @@ func (db *PostgresClient) GetChannelSettingsByID( ).Scan(&settings) return settings, err - } // GetCommandSettings retrieves all command settings rows for a given channel. func (db *PostgresClient) GetCommandSettings( - ctx context.Context, channelID string, - ) ([]common.CommandSettings, error) { - query := ` SELECT @@ -1394,11 +1226,8 @@ func (db *PostgresClient) GetCommandSettings( ` rows, err := db.Pool.Query(ctx, query, channelID) - if err != nil { - return nil, err - } defer rows.Close() @@ -1437,9 +1266,7 @@ func (db *PostgresClient) GetCommandSettings( &cs.AmbassadorGranted, ); err != nil { - return nil, err - } results = append(results, cs) @@ -1447,13 +1274,11 @@ func (db *PostgresClient) GetCommandSettings( } return results, rows.Err() - } // UpsertCommandSettings inserts or updates a single command's settings row. func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.CommandSettings) error { - query := ` INSERT INTO command_settings ( @@ -1493,9 +1318,7 @@ func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.C platform := cs.Platform if platform == "" { - platform = "TWITCH" - } _, err := db.Pool.Exec( @@ -1512,13 +1335,11 @@ func (db *PostgresClient) UpsertCommandSettings(ctx context.Context, cs common.C ) return err - } // ResetCommandSettings resets a single command's overrides back to their defaults. func (db *PostgresClient) ResetCommandSettings(ctx context.Context, channelID, command string) error { - query := ` UPDATE command_settings SET @@ -1548,49 +1369,38 @@ func (db *PostgresClient) ResetCommandSettings(ctx context.Context, channelID, c _, err := db.Pool.Exec(ctx, query, channelID, command) return err - } // GetChannelAmbassadors returns the ambassadors slice for a channel, used for auth checks. func (db *PostgresClient) GetChannelAmbassadors( - ctx context.Context, channelID string, platform common.Platforms, - ) ([]string, error) { - query := `SELECT ambassadors FROM channels WHERE channel_id = $1 AND platform = $2` var ambassadors []string err := db.Pool.QueryRow(ctx, query, channelID, platform).Scan(&ambassadors) - if err != nil { - return nil, err - } return ambassadors, nil - } // GetUserReminders retrieves all pending reminders for a user on a given platform. func (db *PostgresClient) GetUserReminders( - ctx context.Context, userID string, platform common.Platforms, - ) ([]common.Reminder, error) { - query := ` SELECT @@ -1608,11 +1418,8 @@ func (db *PostgresClient) GetUserReminders( ` rows, err := db.Pool.Query(ctx, query, userID, platform) - if err != nil { - return nil, err - } defer rows.Close() @@ -1629,9 +1436,7 @@ func (db *PostgresClient) GetUserReminders( &r.Message, &r.ReadyAt, &r.SetAt, &r.AfkWithheld, &r.Status, &r.Platform, &r.SentAt, &r.Type, ); err != nil { - return nil, err - } reminders = append(reminders, r) @@ -1639,25 +1444,21 @@ func (db *PostgresClient) GetUserReminders( } return reminders, rows.Err() - } // DeleteReminder hard-deletes a reminder by ID, verifying the owner platform ID. func (db *PostgresClient) DeleteReminder(ctx context.Context, reminderID int, recipientID string) error { - query := `DELETE FROM reminders WHERE reminder_id = $1 AND recipient_id = $2` _, err := db.Pool.Exec(ctx, query, reminderID, recipientID) return err - } // UpsertOAuthToken stores or refreshes a platform OAuth token for a given user. func (db *PostgresClient) UpsertOAuthToken( - ctx context.Context, platformID string, @@ -1671,9 +1472,7 @@ func (db *PostgresClient) UpsertOAuthToken( scope []string, expiresIn int, - ) error { - query := ` INSERT INTO connection_oauth ( @@ -1704,5 +1503,4 @@ func (db *PostgresClient) UpsertOAuthToken( ) return err - } diff --git a/redirects/redirects_test.go b/redirects/redirects_test.go index 47f2d34..4003c01 100644 --- a/redirects/redirects_test.go +++ b/redirects/redirects_test.go @@ -8,7 +8,6 @@ import ( ) func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { - redirector := redirects{} tests := []struct { @@ -16,7 +15,6 @@ func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { expected string }{ - {"https://google.com", "https://google.com"}, {"http://google.com", "https://google.com"}, @@ -27,9 +25,7 @@ func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { } for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - cleanedURL := redirector.cleanRedirectProtocolSoLinksActuallyWork(tc.input) assert.Truef( @@ -38,9 +34,6 @@ func TestRedirects__CheckProtocolFormatAfterProtocolReformat(t *testing.T) { ) assert.Equal(t, tc.expected, cleanedURL) - }) - } - } From e63f530961a97ef3edf7cf20fa211d1d07e6a7c2 Mon Sep 17 00:00:00 2001 From: tlos Date: Mon, 18 May 2026 23:30:19 -0400 Subject: [PATCH 19/20] fix: remove block-start/end blank lines to satisfy whitespace linter --- common/db/loops.go | 117 ++++++++++++++++-------------------------- common/db/postgres.go | 23 --------- 2 files changed, 43 insertions(+), 97 deletions(-) diff --git a/common/db/loops.go b/common/db/loops.go index 45f0a66..fca9099 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -32,6 +32,7 @@ const dumpPath = "./dump" // StartLoops initializes schedules and loops for various tasks. func StartLoops( + ctx context.Context, config common.Config, @@ -43,6 +44,7 @@ func StartLoops( clickhouse *ClickhouseClient, redis *RedisClient, + ) { if !config.Loops.Enabled { return @@ -57,45 +59,41 @@ func StartLoops( go validateTokens(ctx, config, postgres) }) - if err != nil { + if err != nil { logger.Error.Println("Failed initializing cron updateHourlyUsage", err) return - } _, err = cronManager.AddFunc("@daily", func() { updateDailyUsage(ctx, postgres) }) - if err != nil { + if err != nil { logger.Error.Println("Failed initializing cron updateDailyUsage", err) return - } _, err = cronManager.AddFunc("@weekly", func() { updateWeeklyUsage(ctx, postgres) }) - if err != nil { + if err != nil { logger.Error.Println("Failed initializing cron updateWeeklyUsage", err) return - } _, err = cronManager.AddFunc("0 */2 * * *", func() { refreshAllHelixTokens(ctx, config, postgres) }) - if err != nil { + if err != nil { logger.Error.Println("Failed initializing cron refreshAllHelixTokens", err) return - } _, err = cronManager.AddFunc("*/30 * * * *", func() { @@ -107,12 +105,11 @@ func StartLoops( updateUserOwnedBadgeView(ctx, clickhouse) }) - if err != nil { + if err != nil { logger.Error.Println("Failed initializing cron clickhouse views", err) return - } _, err = cronManager.AddFunc("0 */12 * * *", func() { @@ -120,12 +117,11 @@ func StartLoops( // go optimizeClickhouse(ctx, config, clickhouse) }) - if err != nil { + if err != nil { logger.Error.Println("Failed initializing cron backupPostgres", err) return - } cronManager.Start() @@ -139,18 +135,16 @@ func StartLoops( func decrementDuels(ctx context.Context, redis *RedisClient) { for { - time.Sleep(30 * time.Minute) logger.Info.Println("Decrementing duels") keys, err := redis.Scan(ctx, "duelUse:*", 100, 0) - if err != nil { + if err != nil { logger.Error.Println("Failed scanning keys for duels", err) return - } if len(keys) == 0 { @@ -182,18 +176,17 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { ` value, err := redis.Eval(ctx, luaScript, keys).Result() + if err != nil { logger.Error.Println("Failed decrementing duels", err) } logger.Info.Printf("Decremented %d duel keys", value) - } } func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { for { - time.Sleep(24 * time.Hour) logger.Info.Println("Deleting old uploads") @@ -209,18 +202,17 @@ func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { ` _, err := postgres.Exec(ctx, query) + if err != nil { logger.Error.Println("Error deleting old uploads ", err) } logger.Debug.Println("Deleted old uploads") - } } func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { for { - time.Sleep(5 * time.Minute) query := ` @@ -240,12 +232,12 @@ func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { ` _, err := postgres.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating aggregate table", err) } logger.Info.Println("Updated aggregate table") - } } @@ -255,6 +247,7 @@ func updateHourlyUsage(ctx context.Context, postgres *PostgresClient) { query := `UPDATE gpt_usage SET hourly_usage = 0;` _, err := postgres.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating hourly usage", err) } @@ -268,6 +261,7 @@ func updateDailyUsage(ctx context.Context, postgres *PostgresClient) { query := `UPDATE gpt_usage SET daily_usage = 0` _, err := postgres.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating daily usage", err) } @@ -281,6 +275,7 @@ func updateWeeklyUsage(ctx context.Context, postgres *PostgresClient) { query := `UPDATE gpt_usage SET weekly_usage = 0` _, err := postgres.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating weekly usage", err) } @@ -318,6 +313,7 @@ func updateColorView(ctx context.Context, clickhouse *ClickhouseClient) { ` err := clickhouse.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating color view ", err) } @@ -327,12 +323,11 @@ func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { logger.Info.Println("Updating active badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_active_badge_stats;`) - if err != nil { + if err != nil { logger.Error.Println("Error truncating badge stats table ", err) return - } query := ` @@ -356,6 +351,7 @@ func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { ` err = clickhouse.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating badge view ", err) } @@ -387,18 +383,16 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { // err := clickhouse.Exec(ctx, prepare) // if err != nil { - // logger.Error.Println("Error preparing badge view ", err) // } err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_stats;`) - if err != nil { + if err != nil { logger.Error.Println("Error truncating badge stats table ", err) return - } query := ` @@ -422,6 +416,7 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { ` err = clickhouse.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating badge view ", err) } @@ -431,12 +426,11 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) logger.Info.Println("Updating user owned badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_user_stats;`) - if err != nil { + if err != nil { logger.Error.Println("Error truncating badge stats table ", err) return - } query := ` @@ -466,12 +460,14 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) ` err = clickhouse.Exec(ctx, query) + if err != nil { logger.Error.Println("Error updating user owned badge view ", err) } } func upsertOAuthToken( + ctx context.Context, postgres *PostgresClient, @@ -479,6 +475,7 @@ func upsertOAuthToken( oauth *common.GenericOAUTHResponse, con common.PlatformOauth, + ) error { query := ` @@ -545,6 +542,7 @@ func upsertOAuthToken( } func refreshOrDelete( + ctx context.Context, config common.Config, @@ -552,6 +550,7 @@ func refreshOrDelete( postgres *PostgresClient, con common.PlatformOauth, + ) (bool, error) { var err error @@ -566,15 +565,14 @@ func refreshOrDelete( } err = upsertOAuthToken(ctx, postgres, refreshResult, con) - if err != nil { + if err != nil { logger.Error.Println( "Error updating token for user_id", con.PlatformID, ":", err, ) return false, err - } return true, nil @@ -600,12 +598,11 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre ` rows, err := postgres.Query(ctx, query) - if err != nil { + if err != nil { logger.Error.Println("Error getting tokens ", err) return - } defer rows.Close() @@ -613,38 +610,33 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre validated, deleted := 0, 0 for rows.Next() { - var con common.PlatformOauth err := rows.Scan(&con.AccessToken, &con.PlatformID, &con.RefreshToken) - if err != nil { + if err != nil { logger.Error.Println("Error scanning token: ", err) continue - } valid, _, err := utils.ValidateHelixToken(ctx, con.AccessToken, false) - if err != nil { + if err != nil { logger.Error.Println("Error validating token ", err) continue - } if !valid { - ok, err := refreshOrDelete(ctx, config, postgres, con) - if err != nil { + if err != nil { logger.Error.Println("Error refreshing token ", err) deleted++ continue - } if ok { @@ -654,13 +646,11 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre } continue - } validated++ time.Sleep(200 * time.Millisecond) - } logger.Info.Printf( @@ -701,12 +691,11 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * ` rows, err := postgres.Query(ctx, query) - if err != nil { + if err != nil { logger.Error.Println("Error getting tokens ", err) return - } defer rows.Close() @@ -714,7 +703,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * refreshed, failed := 0, 0 for rows.Next() { - var con common.PlatformOauth err := rows.Scan( @@ -733,23 +721,21 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * &con.Scope, ) - if err != nil { + if err != nil { logger.Error.Println("Error scanning token: ", err) continue - } ok, err := refreshOrDelete(ctx, config, postgres, con) - if err != nil { + if err != nil { logger.Error.Println("Error refreshing token ", err) failed++ continue - } if ok { @@ -761,7 +747,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * time.Sleep(200 * time.Millisecond) continue - } logger.Info.Printf( @@ -804,18 +789,18 @@ func deleteOldDumps(files []string, maxSize int) { filesToDelete := files[:len(files)-maxSize+1] for _, file := range filesToDelete { - err := os.Remove(file) + if err != nil { logger.Error.Println("Failed deleting dump file ", err) } - } logger.Info.Printf("Deleted %d old dump files", len(filesToDelete)) } func backupPostgres( + ctx context.Context, postgres *PostgresClient, @@ -823,24 +808,22 @@ func backupPostgres( natsClient *utils.NatsClient, config common.Config, + ) { logger.Debug.Println("Backing up Postgres") if err := os.MkdirAll(dumpPath, 0o750); err != nil { - logger.Error.Println("Failed to create backup folder:", err) return - } files, err := filepath.Glob(filepath.Join(dumpPath, "*.sql.zst")) - if err != nil { + if err != nil { logger.Error.Println("Failed to list dump files:", err) return - } deleteOldDumps(files, 10) @@ -853,7 +836,6 @@ func backupPostgres( ) //nolint:gosec - cmd := exec.CommandContext(ctx, "sh", "-c", fmt.Sprintf( "PGPASSWORD=%s pg_dump -d %s -U %s -h %s | zstd -3 --threads=%d > %s", @@ -873,13 +855,11 @@ func backupPostgres( defer func() { if err = cmd.Process.Release(); err != nil { - logger.Error.Println("Failed to release pg_dump process:", err) if err = cmd.Process.Kill(); err != nil { logger.Error.Fatalln("Failed to kill pg_dump process:", err) } - } }() @@ -890,33 +870,29 @@ func backupPostgres( start := time.Now() if err = cmd.Run(); err != nil { - logger.Error.Println("Failed to execute pg_dump:", err, stderr.String()) return - } duration := time.Since(start) stat, err := os.Stat(filePath) - if err != nil { + if err != nil { logger.Error.Println("Failed to get backup file size:", err) return - } backupSize := float64(stat.Size()) / (1024 * 1024 * 1024) dbSize, err := getDatabaseSize(ctx, postgres, config.Postgres.Database) - if err != nil { + if err != nil { logger.Error.Println("Failed to get database size:", err) return - } message := fmt.Sprintf( @@ -931,14 +907,13 @@ func backupPostgres( ) jsonMessage, err := json.Marshal(message) - if err != nil { + if err != nil { logger.Error.Println("Failed to JSON stringify message:", err) logger.Info.Println(message) return - } if natsClient != nil { @@ -950,6 +925,7 @@ func backupPostgres( func publishBackupNotification(natsClient *utils.NatsClient, message []byte) { err := natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", message) + if err != nil { logger.Error.Println("Failed to publish to queue:", err) } @@ -959,6 +935,7 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin query := `SELECT pg_size_pretty(pg_database_size($1)) AS size` rows, err := postgres.Query(ctx, query, dbName) + if err != nil { return "", err } @@ -966,7 +943,6 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin defer rows.Close() if rows.Next() { - var size string if err := rows.Scan(&size); err != nil { @@ -974,14 +950,12 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin } return size, nil - } return "", errNoRows } // func optimizeClickhouse(ctx context.Context, config common.Config, clickhouse *ClickhouseClient) { - // // offset any concurrent crons // time.Sleep(5 * time.Minute) @@ -989,7 +963,6 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin // logger.Info.Println("Optimizing Clickhouse tables") // if config.Clickhouse.Database == "" { - // logger.Error.Println("Clickhouse database is not configured") // return @@ -1001,7 +974,6 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin // rows, err := clickhouse.Query(ctx, query, config.Clickhouse.Database) // if err != nil { - // logger.Error.Println("Failed to query Clickhouse tables:", err) // return @@ -1009,11 +981,9 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin // } // for rows.Next() { - // var table string // if err := rows.Scan(&table); err != nil { - // logger.Error.Println("Failed to scan Clickhouse table:", err) // continue @@ -1023,7 +993,6 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin // query := fmt.Sprintf("OPTIMIZE TABLE %s.%s FINAL", config.Clickhouse.Database, table) // if err := clickhouse.Exec(ctx, query); err != nil { - // logger.Error.Println("Failed to optimize Clickhouse table:", err) // } diff --git a/common/db/postgres.go b/common/db/postgres.go index 55c1570..a266fc0 100644 --- a/common/db/postgres.go +++ b/common/db/postgres.go @@ -59,7 +59,6 @@ func InitPostgres(ctx context.Context, config common.Config) (*PostgresClient, e } func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unparam - user := config.Postgres.User if user == "" { @@ -101,11 +100,9 @@ func loadConfig(config common.Config) (*pgxpool.Config, error) { //nolint:unpara dbConfig, err := pgxpool.ParseConfig(constring) if err != nil { - logger.Error.Panicln("Error parsing database config", err) return nil, err - } dbConfig.MaxConns = 32 @@ -283,7 +280,6 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string var blocks []common.Block for rows.Next() { - var block common.Block err := rows.Scan( @@ -303,7 +299,6 @@ func (db *PostgresClient) GetChannelBlocks(ctx context.Context, channelID string } blocks = append(blocks, block) - } return &blocks @@ -376,7 +371,6 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri var commands []common.ChannelCommand for rows.Next() { - var command common.ChannelCommand err := rows.Scan( @@ -430,7 +424,6 @@ func (db *PostgresClient) GetChannelCommands(ctx context.Context, channelID stri } commands = append(commands, command) - } return &commands @@ -499,7 +492,6 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop ` switch chanType { - case "ID": query += `WHERE c.channel_id = $1 ` @@ -511,7 +503,6 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop default: return nil, errInvalidType - } query += `AND platform = $2;` @@ -581,7 +572,6 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop } if len(blocks) > 0 { - channel.Blocks = common.FilteredBlocks{ Users: &[]common.Block{}, @@ -590,7 +580,6 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop for _, block := range blocks { switch block.BlockType { - case common.UserBlock: *channel.Blocks.Users = append(*channel.Blocks.Users, block) @@ -602,10 +591,8 @@ func (db *PostgresClient) getChannelByType( //nolint:cyclop case common.GlobalBlock: continue - } } - } else { channel.Blocks = common.FilteredBlocks{} } @@ -818,7 +805,6 @@ func (db *PostgresClient) BatchUserConections( users := make(map[int][]common.UserConnection) for rows.Next() { - var connection common.UserConnection err := rows.Scan( @@ -842,7 +828,6 @@ func (db *PostgresClient) BatchUserConections( } users[connection.ID] = append(users[connection.ID], connection) - } return &users @@ -989,11 +974,9 @@ func (db *PostgresClient) NewUpload( err := db.Pool.QueryRow(ctx, query, file, name, mimeType, key).Scan(&createdAt) if err != nil { - logger.Error.Println("Error scanning upload", err) return false, nil - } return true, &createdAt @@ -1113,7 +1096,6 @@ func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelL var channels []common.ChannelListItem for rows.Next() { - var ch common.ChannelListItem if err := rows.Scan(&ch.ChannelID, &ch.Username, &ch.Platform, &ch.State); err != nil { @@ -1121,7 +1103,6 @@ func (db *PostgresClient) GetAllChannels(ctx context.Context) ([]common.ChannelL } channels = append(channels, ch) - } return channels, rows.Err() @@ -1235,7 +1216,6 @@ func (db *PostgresClient) GetCommandSettings( var results []common.CommandSettings for rows.Next() { - var cs common.CommandSettings if err := rows.Scan( @@ -1270,7 +1250,6 @@ func (db *PostgresClient) GetCommandSettings( } results = append(results, cs) - } return results, rows.Err() @@ -1427,7 +1406,6 @@ func (db *PostgresClient) GetUserReminders( var reminders []common.Reminder for rows.Next() { - var r common.Reminder if err := rows.Scan( @@ -1440,7 +1418,6 @@ func (db *PostgresClient) GetUserReminders( } reminders = append(reminders, r) - } return reminders, rows.Err() From b12f6250ba1fc9bea934b1c96deb276993bdc60a Mon Sep 17 00:00:00 2001 From: tlos Date: Mon, 18 May 2026 23:46:02 -0400 Subject: [PATCH 20/20] fix: format loops.go with gofumpt v0.9.2 --- common/db/loops.go | 43 ------------------------------------------- 1 file changed, 43 deletions(-) diff --git a/common/db/loops.go b/common/db/loops.go index fca9099..90e03d4 100644 --- a/common/db/loops.go +++ b/common/db/loops.go @@ -32,7 +32,6 @@ const dumpPath = "./dump" // StartLoops initializes schedules and loops for various tasks. func StartLoops( - ctx context.Context, config common.Config, @@ -44,7 +43,6 @@ func StartLoops( clickhouse *ClickhouseClient, redis *RedisClient, - ) { if !config.Loops.Enabled { return @@ -59,7 +57,6 @@ func StartLoops( go validateTokens(ctx, config, postgres) }) - if err != nil { logger.Error.Println("Failed initializing cron updateHourlyUsage", err) @@ -69,7 +66,6 @@ func StartLoops( _, err = cronManager.AddFunc("@daily", func() { updateDailyUsage(ctx, postgres) }) - if err != nil { logger.Error.Println("Failed initializing cron updateDailyUsage", err) @@ -79,7 +75,6 @@ func StartLoops( _, err = cronManager.AddFunc("@weekly", func() { updateWeeklyUsage(ctx, postgres) }) - if err != nil { logger.Error.Println("Failed initializing cron updateWeeklyUsage", err) @@ -89,7 +84,6 @@ func StartLoops( _, err = cronManager.AddFunc("0 */2 * * *", func() { refreshAllHelixTokens(ctx, config, postgres) }) - if err != nil { logger.Error.Println("Failed initializing cron refreshAllHelixTokens", err) @@ -105,7 +99,6 @@ func StartLoops( updateUserOwnedBadgeView(ctx, clickhouse) }) - if err != nil { logger.Error.Println("Failed initializing cron clickhouse views", err) @@ -117,7 +110,6 @@ func StartLoops( // go optimizeClickhouse(ctx, config, clickhouse) }) - if err != nil { logger.Error.Println("Failed initializing cron backupPostgres", err) @@ -140,7 +132,6 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { logger.Info.Println("Decrementing duels") keys, err := redis.Scan(ctx, "duelUse:*", 100, 0) - if err != nil { logger.Error.Println("Failed scanning keys for duels", err) @@ -176,7 +167,6 @@ func decrementDuels(ctx context.Context, redis *RedisClient) { ` value, err := redis.Eval(ctx, luaScript, keys).Result() - if err != nil { logger.Error.Println("Failed decrementing duels", err) } @@ -202,7 +192,6 @@ func deleteOldUploads(ctx context.Context, postgres *PostgresClient) { ` _, err := postgres.Exec(ctx, query) - if err != nil { logger.Error.Println("Error deleting old uploads ", err) } @@ -232,7 +221,6 @@ func updateAggregateTable(ctx context.Context, postgres *PostgresClient) { ` _, err := postgres.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating aggregate table", err) } @@ -247,7 +235,6 @@ func updateHourlyUsage(ctx context.Context, postgres *PostgresClient) { query := `UPDATE gpt_usage SET hourly_usage = 0;` _, err := postgres.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating hourly usage", err) } @@ -261,7 +248,6 @@ func updateDailyUsage(ctx context.Context, postgres *PostgresClient) { query := `UPDATE gpt_usage SET daily_usage = 0` _, err := postgres.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating daily usage", err) } @@ -275,7 +261,6 @@ func updateWeeklyUsage(ctx context.Context, postgres *PostgresClient) { query := `UPDATE gpt_usage SET weekly_usage = 0` _, err := postgres.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating weekly usage", err) } @@ -313,7 +298,6 @@ func updateColorView(ctx context.Context, clickhouse *ClickhouseClient) { ` err := clickhouse.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating color view ", err) } @@ -323,7 +307,6 @@ func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { logger.Info.Println("Updating active badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_active_badge_stats;`) - if err != nil { logger.Error.Println("Error truncating badge stats table ", err) @@ -351,7 +334,6 @@ func updateActiveBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { ` err = clickhouse.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating badge view ", err) } @@ -388,7 +370,6 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { // } err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_stats;`) - if err != nil { logger.Error.Println("Error truncating badge stats table ", err) @@ -416,7 +397,6 @@ func updateOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) { ` err = clickhouse.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating badge view ", err) } @@ -426,7 +406,6 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) logger.Info.Println("Updating user owned badge view") err := clickhouse.Exec(ctx, `TRUNCATE TABLE potatbotat.twitch_owned_badge_user_stats;`) - if err != nil { logger.Error.Println("Error truncating badge stats table ", err) @@ -460,14 +439,12 @@ func updateUserOwnedBadgeView(ctx context.Context, clickhouse *ClickhouseClient) ` err = clickhouse.Exec(ctx, query) - if err != nil { logger.Error.Println("Error updating user owned badge view ", err) } } func upsertOAuthToken( - ctx context.Context, postgres *PostgresClient, @@ -475,7 +452,6 @@ func upsertOAuthToken( oauth *common.GenericOAUTHResponse, con common.PlatformOauth, - ) error { query := ` @@ -542,7 +518,6 @@ func upsertOAuthToken( } func refreshOrDelete( - ctx context.Context, config common.Config, @@ -550,7 +525,6 @@ func refreshOrDelete( postgres *PostgresClient, con common.PlatformOauth, - ) (bool, error) { var err error @@ -565,7 +539,6 @@ func refreshOrDelete( } err = upsertOAuthToken(ctx, postgres, refreshResult, con) - if err != nil { logger.Error.Println( @@ -598,7 +571,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre ` rows, err := postgres.Query(ctx, query) - if err != nil { logger.Error.Println("Error getting tokens ", err) @@ -613,7 +585,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre var con common.PlatformOauth err := rows.Scan(&con.AccessToken, &con.PlatformID, &con.RefreshToken) - if err != nil { logger.Error.Println("Error scanning token: ", err) @@ -621,7 +592,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre } valid, _, err := utils.ValidateHelixToken(ctx, con.AccessToken, false) - if err != nil { logger.Error.Println("Error validating token ", err) @@ -630,7 +600,6 @@ func validateTokens(ctx context.Context, config common.Config, postgres *Postgre if !valid { ok, err := refreshOrDelete(ctx, config, postgres, con) - if err != nil { logger.Error.Println("Error refreshing token ", err) @@ -691,7 +660,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * ` rows, err := postgres.Query(ctx, query) - if err != nil { logger.Error.Println("Error getting tokens ", err) @@ -721,7 +689,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * &con.Scope, ) - if err != nil { logger.Error.Println("Error scanning token: ", err) @@ -729,7 +696,6 @@ func refreshAllHelixTokens(ctx context.Context, config common.Config, postgres * } ok, err := refreshOrDelete(ctx, config, postgres, con) - if err != nil { logger.Error.Println("Error refreshing token ", err) @@ -790,7 +756,6 @@ func deleteOldDumps(files []string, maxSize int) { for _, file := range filesToDelete { err := os.Remove(file) - if err != nil { logger.Error.Println("Failed deleting dump file ", err) } @@ -800,7 +765,6 @@ func deleteOldDumps(files []string, maxSize int) { } func backupPostgres( - ctx context.Context, postgres *PostgresClient, @@ -808,7 +772,6 @@ func backupPostgres( natsClient *utils.NatsClient, config common.Config, - ) { logger.Debug.Println("Backing up Postgres") @@ -819,7 +782,6 @@ func backupPostgres( } files, err := filepath.Glob(filepath.Join(dumpPath, "*.sql.zst")) - if err != nil { logger.Error.Println("Failed to list dump files:", err) @@ -878,7 +840,6 @@ func backupPostgres( duration := time.Since(start) stat, err := os.Stat(filePath) - if err != nil { logger.Error.Println("Failed to get backup file size:", err) @@ -888,7 +849,6 @@ func backupPostgres( backupSize := float64(stat.Size()) / (1024 * 1024 * 1024) dbSize, err := getDatabaseSize(ctx, postgres, config.Postgres.Database) - if err != nil { logger.Error.Println("Failed to get database size:", err) @@ -907,7 +867,6 @@ func backupPostgres( ) jsonMessage, err := json.Marshal(message) - if err != nil { logger.Error.Println("Failed to JSON stringify message:", err) @@ -925,7 +884,6 @@ func backupPostgres( func publishBackupNotification(natsClient *utils.NatsClient, message []byte) { err := natsClient.Publish("github.com/Potat-Industries/potat-api.postgres-backup", message) - if err != nil { logger.Error.Println("Failed to publish to queue:", err) } @@ -935,7 +893,6 @@ func getDatabaseSize(ctx context.Context, postgres *PostgresClient, dbName strin query := `SELECT pg_size_pretty(pg_database_size($1)) AS size` rows, err := postgres.Query(ctx, query, dbName) - if err != nil { return "", err }