Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func StartServing(
postgres *db.PostgresClient,
redis *db.RedisClient,
clickhouse *db.ClickhouseClient,
nats *utils.NatsClient,
metrics *utils.Metrics,
) error {
if config.API.Host == "" || config.API.Port == "" {
Expand All @@ -54,7 +55,7 @@ func StartServing(
}

api.router.Use(middleware.LogRequest(metrics))
api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse))
api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse, nats))
api.router.Use(middleware.NewRateLimiter(100, 1*time.Minute, redis))

authenticator := middleware.NewAuthenticator(config.Twitch.ClientSecret, GenericResponse)
Expand Down
46 changes: 46 additions & 0 deletions api/middleware/authorize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package middleware

import (
"net/http"
"slices"

"github.com/Potat-Industries/potat-api/common"
"github.com/Potat-Industries/potat-api/common/db"
)

// GetTwitchPlatformID returns the Twitch platform user ID from a user's connections, or empty string if not found.
func GetTwitchPlatformID(user *common.User) string {
for _, conn := range user.Connections {
if conn.Platform == common.TWITCH {
return conn.UserID
}
}

return ""
}

// IsChannelAuthorized returns true if the user is an admin, the broadcaster of the given channel,
// or a channel ambassador. It is used as a shared auth check across channel-scoped write routes.
func IsChannelAuthorized(
request *http.Request,
user *common.User,
channelID string,
postgres *db.PostgresClient,
) bool {
if user.Level >= int(common.ADMIN) {
return true
}

twitchID := GetTwitchPlatformID(user)

if twitchID == channelID {
return true
}

ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH)
if err != nil {
return false
}

return slices.Contains(ambassadors, twitchID)
}
4 changes: 4 additions & 0 deletions api/middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"

"github.com/Potat-Industries/potat-api/common/db"
"github.com/Potat-Industries/potat-api/common/utils"
)

// ErrMissingContext is returned when a database client is not found in the request context.
Expand All @@ -18,19 +19,22 @@ const (
PostgresKey contextKey = "postgres"
RedisKey contextKey = "redis"
ClickhouseKey contextKey = "clickhouse"
NatsKey contextKey = "nats"
)

// InjectDatabases returns a middleware that injects DB clients into the request context.
func InjectDatabases(
postgres *db.PostgresClient,
redis *db.RedisClient,
clickhouse *db.ClickhouseClient,
nats *utils.NatsClient,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), PostgresKey, postgres)
ctx = context.WithValue(ctx, RedisKey, redis)
ctx = context.WithValue(ctx, ClickhouseKey, clickhouse)
ctx = context.WithValue(ctx, NatsKey, nats)

next.ServeHTTP(w, r.WithContext(ctx))
})
Expand Down
96 changes: 96 additions & 0 deletions api/routes/del/command_settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Package del contains routes for http.MethodDelete requests.
package del

import (
"net/http"
"time"

"github.com/Potat-Industries/potat-api/api"
"github.com/Potat-Industries/potat-api/api/middleware"
"github.com/Potat-Industries/potat-api/common"
"github.com/Potat-Industries/potat-api/common/db"
"github.com/Potat-Industries/potat-api/common/logger"
)

func init() {
api.SetRoute(api.Route{
Path: "/channel/command-settings",
Method: http.MethodDelete,
Handler: deleteCommandSettings,
UseAuth: true,
})
}

func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop
start := time.Now()

user, ok := request.Context().Value(middleware.AuthedUser).(*common.User)
if !ok || user == nil {
api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}},
}, start)

return
}

postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient)
if !pgOK {
logger.Error.Println("Postgres client not found in context")
api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}},
}, start)

return
}

channelID := request.URL.Query().Get("id")
if channelID == "" {
for _, conn := range user.Connections {
if conn.Platform == common.TWITCH {
channelID = conn.UserID

break
}
}
}

command := request.URL.Query().Get("command")

if channelID == "" {
api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "channel id is required"}},
}, start)

return
}

if !middleware.IsChannelAuthorized(request, user, channelID, postgres) {
api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Forbidden"}},
}, start)

return
}

// command is required — this endpoint resets a single command override back to defaults.
if command == "" {
api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Missing required field: command"}},
}, start)

return
}

if err := postgres.ResetCommandSettings(request.Context(), channelID, command); err != nil {
logger.Error.Printf("Error resetting command settings: %v", err)
api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Failed to reset command settings"}},
}, start)

return
}

api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{
Data: &[]any{},
}, start)
}
82 changes: 82 additions & 0 deletions api/routes/get/ambassadors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Package get contains routes for http.MethodGet requests.
package get

import (
"net/http"
"time"

"github.com/Potat-Industries/potat-api/api"
"github.com/Potat-Industries/potat-api/api/middleware"
"github.com/Potat-Industries/potat-api/common"
"github.com/Potat-Industries/potat-api/common/db"
"github.com/Potat-Industries/potat-api/common/logger"
)

// AmbassadorsResponse is the response type for GET /channel/ambassadors.
type AmbassadorsResponse = common.GenericResponse[string]

func init() {
api.SetRoute(api.Route{
Path: "/channel/ambassadors",
Method: http.MethodGet,
Handler: getAmbassadorsHandler,
UseAuth: true,
})
}

func getAmbassadorsHandler(writer http.ResponseWriter, request *http.Request) {
start := time.Now()

user, ok := request.Context().Value(middleware.AuthedUser).(*common.User)
if !ok || user == nil {
api.GenericResponse(writer, http.StatusUnauthorized, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}},
}, start)

return
}

postgres, ok := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient)
if !ok {
api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}},
}, start)

return
}

channelID := resolveChannelID(request, user)
if channelID == "" {
api.GenericResponse(writer, http.StatusBadRequest, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Could not resolve channel ID"}},
}, start)

return
}

if !middleware.IsChannelAuthorized(request, user, channelID, postgres) {
api.GenericResponse(writer, http.StatusForbidden, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Forbidden"}},
}, start)

return
}

ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH)
if err != nil {
logger.Error.Printf("Error fetching ambassadors: %v", err)
api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Failed to fetch ambassadors"}},
}, start)

return
}

if ambassadors == nil {
ambassadors = []string{}
}

api.GenericResponse(writer, http.StatusOK, AmbassadorsResponse{
Data: &ambassadors,
}, start)
}
Loading
Loading