Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func StartServing(
postgres *db.PostgresClient,
redis *db.RedisClient,
clickhouse *db.ClickhouseClient,
nats *utils.NatsClient,
metrics *utils.Metrics,
) error {
if config.API.Host == "" || config.API.Port == "" {
Expand All @@ -54,7 +55,7 @@ func StartServing(
}

api.router.Use(middleware.LogRequest(metrics))
api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse))
api.router.Use(middleware.InjectDatabases(postgres, redis, clickhouse, nats))
api.router.Use(middleware.NewRateLimiter(100, 1*time.Minute, redis))

authenticator := middleware.NewAuthenticator(config.Twitch.ClientSecret, GenericResponse)
Expand Down
46 changes: 46 additions & 0 deletions api/middleware/authorize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package middleware

import (
"net/http"
"slices"

"github.com/Potat-Industries/potat-api/common"
"github.com/Potat-Industries/potat-api/common/db"
)

// GetTwitchPlatformID returns the Twitch platform user ID from a user's connections, or empty string if not found.
func GetTwitchPlatformID(user *common.User) string {
for _, conn := range user.Connections {
if conn.Platform == common.TWITCH {
return conn.UserID
}
}

return ""
}

// IsChannelAuthorized returns true if the user is an admin, the broadcaster of the given channel,
// or a channel ambassador. It is used as a shared auth check across channel-scoped write routes.
func IsChannelAuthorized(
request *http.Request,
user *common.User,
channelID string,
postgres *db.PostgresClient,
) bool {
if user.Level >= int(common.ADMIN) {
return true
}

twitchID := GetTwitchPlatformID(user)

if twitchID == channelID {
return true
}

ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH)
if err != nil {
return false
}

return slices.Contains(ambassadors, twitchID)
}
4 changes: 4 additions & 0 deletions api/middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"

"github.com/Potat-Industries/potat-api/common/db"
"github.com/Potat-Industries/potat-api/common/utils"
)

// ErrMissingContext is returned when a database client is not found in the request context.
Expand All @@ -18,19 +19,22 @@ const (
PostgresKey contextKey = "postgres"
RedisKey contextKey = "redis"
ClickhouseKey contextKey = "clickhouse"
NatsKey contextKey = "nats"
)

// InjectDatabases returns a middleware that injects DB clients into the request context.
func InjectDatabases(
postgres *db.PostgresClient,
redis *db.RedisClient,
clickhouse *db.ClickhouseClient,
nats *utils.NatsClient,
) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), PostgresKey, postgres)
ctx = context.WithValue(ctx, RedisKey, redis)
ctx = context.WithValue(ctx, ClickhouseKey, clickhouse)
ctx = context.WithValue(ctx, NatsKey, nats)

next.ServeHTTP(w, r.WithContext(ctx))
})
Expand Down
98 changes: 98 additions & 0 deletions api/routes/del/command_settings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Package del contains routes for http.MethodDelete requests.
package del

import (
"net/http"
"time"

"github.com/Potat-Industries/potat-api/api"
"github.com/Potat-Industries/potat-api/api/middleware"
"github.com/Potat-Industries/potat-api/common"
"github.com/Potat-Industries/potat-api/common/db"
"github.com/Potat-Industries/potat-api/common/logger"
)

func init() {
api.SetRoute(api.Route{
Path: "/channel/command-settings",
Method: http.MethodDelete,
Handler: deleteCommandSettings,
UseAuth: true,
})
}

func deleteCommandSettings(writer http.ResponseWriter, request *http.Request) { //nolint:cyclop
start := time.Now()

user, ok := request.Context().Value(middleware.AuthedUser).(*common.User)
if !ok || user == nil {
api.GenericResponse(writer, http.StatusUnauthorized, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}},
}, start)

return
}

postgres, pgOK := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient)
if !pgOK {
logger.Error.Println("Postgres client not found in context")
api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}},
}, start)

return
}

channelID := request.URL.Query().Get("id")
if channelID == "" {
for _, conn := range user.Connections {
if conn.Platform == common.TWITCH {
channelID = conn.UserID

break
}
}
}

command := request.URL.Query().Get("command")

if channelID == "" {
api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "channel id is required"}},
}, start)

return
}

if !middleware.IsChannelAuthorized(request, user, channelID, postgres) {
api.GenericResponse(writer, http.StatusForbidden, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Forbidden"}},
}, start)

return
}

// command is required — this endpoint resets a single command override back to defaults.
if command == "" {
api.GenericResponse(writer, http.StatusBadRequest, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Missing required field: command"}},
}, start)

return
}

if err := postgres.ResetCommandSettings(request.Context(), channelID, command); err != nil {
logger.Error.Printf("Error resetting command settings: %v", err)
api.GenericResponse(writer, http.StatusInternalServerError, common.GenericResponse[any]{
Errors: &[]common.ErrorMessage{{Message: "Failed to reset command settings"}},
}, start)

return
}

api.GenericResponse(writer, http.StatusOK, common.GenericResponse[any]{
Data: &[]any{},
}, start)
}


82 changes: 82 additions & 0 deletions api/routes/get/ambassadors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Package get contains routes for http.MethodGet requests.
package get

import (
"net/http"
"time"

"github.com/Potat-Industries/potat-api/api"
"github.com/Potat-Industries/potat-api/api/middleware"
"github.com/Potat-Industries/potat-api/common"
"github.com/Potat-Industries/potat-api/common/db"
"github.com/Potat-Industries/potat-api/common/logger"
)

// AmbassadorsResponse is the response type for GET /channel/ambassadors.
type AmbassadorsResponse = common.GenericResponse[string]

func init() {
api.SetRoute(api.Route{
Path: "/channel/ambassadors",
Method: http.MethodGet,
Handler: getAmbassadorsHandler,
UseAuth: true,
})
}

func getAmbassadorsHandler(writer http.ResponseWriter, request *http.Request) {
start := time.Now()

user, ok := request.Context().Value(middleware.AuthedUser).(*common.User)
if !ok || user == nil {
api.GenericResponse(writer, http.StatusUnauthorized, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Unauthorized"}},
}, start)

return
}

postgres, ok := request.Context().Value(middleware.PostgresKey).(*db.PostgresClient)
if !ok {
api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Internal Server Error"}},
}, start)

return
}

channelID := resolveChannelID(request, user)
if channelID == "" {
api.GenericResponse(writer, http.StatusBadRequest, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Could not resolve channel ID"}},
}, start)

return
}

if !middleware.IsChannelAuthorized(request, user, channelID, postgres) {
api.GenericResponse(writer, http.StatusForbidden, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Forbidden"}},
}, start)

return
}

ambassadors, err := postgres.GetChannelAmbassadors(request.Context(), channelID, common.TWITCH)
if err != nil {
logger.Error.Printf("Error fetching ambassadors: %v", err)
api.GenericResponse(writer, http.StatusInternalServerError, AmbassadorsResponse{
Errors: &[]common.ErrorMessage{{Message: "Failed to fetch ambassadors"}},
}, start)

return
}

if ambassadors == nil {
ambassadors = []string{}
}

api.GenericResponse(writer, http.StatusOK, AmbassadorsResponse{
Data: &ambassadors,
}, start)
}
Loading
Loading