Skip to content

Commit 5a977fd

Browse files
authored
Merge branch 'master' into scim
2 parents 3149241 + a89a0b0 commit 5a977fd

27 files changed

+1268
-230
lines changed

Makefile

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
.PHONY: all build deps dev-deps image migrate test vet sec format unused
1+
.PHONY: all build deps image migrate test vet sec format unused
2+
.PHONY: check-exhaustive check-gosec check-oapi-codegen check-staticcheck
23
CHECK_FILES?=./...
34

45
ifdef RELEASE_VERSION
@@ -33,13 +34,6 @@ build-strip: deps ## Build a stripped binary, for which the version file needs t
3334
CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build \
3435
$(FLAGS) -ldflags "-s -w" -o auth-arm64-strip
3536

36-
dev-deps: ## Install developer dependencies
37-
@go install github.com/gobuffalo/pop/soda@latest
38-
@go install github.com/securego/gosec/v2/cmd/gosec@latest
39-
@go install honnef.co/go/tools/cmd/staticcheck@latest
40-
@go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@latest
41-
@go install github.com/nishanths/exhaustive/cmd/exhaustive@latest
42-
4337
deps: ## Install dependencies.
4438
@go mod download
4539
@go mod verify
@@ -57,26 +51,40 @@ test: build ## Run tests.
5751
vet: # Vet the code
5852
go vet $(CHECK_FILES)
5953

60-
sec: dev-deps # Check for security vulnerabilities
54+
sec: check-gosec # Check for security vulnerabilities
6155
gosec -quiet -exclude-generated $(CHECK_FILES)
6256
gosec -quiet -tests -exclude-generated -exclude=G104 $(CHECK_FILES)
6357

64-
unused: dev-deps # Look for unused code
58+
check-gosec:
59+
@command -v gosec >/dev/null 2>&1 \
60+
|| @go install github.com/securego/gosec/v2/cmd/gosec@latest
61+
62+
unused: | check-staticcheck # Look for unused code
6563
@echo "Unused code:"
6664
staticcheck -checks U1000 $(CHECK_FILES)
67-
6865
@echo
69-
7066
@echo "Code used only in _test.go (do move it in those files):"
7167
staticcheck -checks U1000 -tests=false $(CHECK_FILES)
7268

73-
static: dev-deps
69+
static: | check-staticcheck check-exhaustive
7470
staticcheck ./...
7571
exhaustive ./...
7672

77-
generate: dev-deps
73+
check-staticcheck:
74+
@command -v staticcheck >/dev/null 2>&1 \
75+
|| @go install honnef.co/go/tools/cmd/staticcheck@latest
76+
77+
check-exhaustive:
78+
@command -v exhaustive >/dev/null 2>&1 \
79+
|| @go install github.com/nishanths/exhaustive/cmd/exhaustive@latest
80+
81+
generate: | check-oapi-codegen
7882
go generate ./...
7983

84+
check-oapi-codegen:
85+
@command -v oapi-codegen >/dev/null 2>&1 \
86+
|| go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@latest
87+
8088
dev: ## Run the development containers
8189
${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up
8290

internal/api/api.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,6 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
9393
version: version,
9494
}
9595

96-
// Only initialize OAuth server if enabled
97-
if globalConfig.OAuthServer.Enabled {
98-
api.oauthServer = oauthserver.NewServer(globalConfig, db)
99-
}
100-
10196
for _, o := range opt {
10297
o.apply(api)
10398
}
@@ -123,6 +118,11 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
123118
// Connect token service to API's time function (supports test overrides)
124119
api.tokenService.SetTimeFunc(api.Now)
125120

121+
// Initialize OAuth server (only if enabled)
122+
if globalConfig.OAuthServer.Enabled {
123+
api.oauthServer = oauthserver.NewServer(globalConfig, db, api.tokenService)
124+
}
125+
126126
if api.config.Password.HIBP.Enabled {
127127
httpClient := &http.Client{
128128
// all HIBP API requests should finish quickly to avoid
@@ -238,7 +238,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
238238
With(api.verifyCaptcha).Post("/otp", api.Otp)
239239

240240
// rate limiting applied in handler
241-
r.With(api.verifyCaptcha).With(api.oauthClientAuth).Post("/token", api.Token)
241+
r.With(api.verifyCaptcha).Post("/token", api.Token)
242242

243243
r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) {
244244
r.Get("/", api.Verify)
@@ -379,6 +379,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
379379
r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)).
380380
Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister)
381381

382+
// OAuth Token endpoint (public, with client authentication)
383+
r.With(api.requireOAuthClientAuth).Post("/token", api.oauthServer.OAuthToken)
384+
382385
// OAuth 2.1 Authorization endpoints
383386
// `/authorize` to initiate OAuth2 authorization code flow where Supabase Auth is the OAuth2 provider
384387
r.Get("/authorize", api.oauthServer.OAuthServerAuthorize)

internal/api/middleware.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ import (
1414
"time"
1515

1616
chimiddleware "github.com/go-chi/chi/v5/middleware"
17+
"github.com/gofrs/uuid"
1718
"github.com/sirupsen/logrus"
1819
"github.com/supabase/auth/internal/api/apierrors"
1920
"github.com/supabase/auth/internal/api/oauthserver"
21+
"github.com/supabase/auth/internal/api/shared"
2022
"github.com/supabase/auth/internal/models"
2123
"github.com/supabase/auth/internal/observability"
2224
"github.com/supabase/auth/internal/security"
@@ -84,9 +86,9 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
8486
}
8587
}
8688

87-
// oauthClientAuth optionally authenticates an OAuth client as middleware
88-
// This doesn't fail if no client credentials are provided, but validates them if present
89-
func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.Context, error) {
89+
// requireOAuthClientAuth authenticates an OAuth client as middleware
90+
// Requires client_id to be present and validates client credentials
91+
func (a *API) requireOAuthClientAuth(w http.ResponseWriter, r *http.Request) (context.Context, error) {
9092
ctx := r.Context()
9193

9294
clientID, clientSecret, err := oauthserver.ExtractClientCredentials(r)
@@ -99,23 +101,29 @@ func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.C
99101
return ctx, nil
100102
}
101103

104+
// Parse client_id as UUID
105+
clientUUID, err := uuid.FromString(clientID)
106+
if err != nil {
107+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client_id format")
108+
}
109+
102110
// Validate client credentials
103111
db := a.db.WithContext(ctx)
104-
client, err := models.FindOAuthServerClientByClientID(db, clientID)
112+
client, err := models.FindOAuthServerClientByID(db, clientUUID)
105113
if err != nil {
106114
if models.IsNotFoundError(err) {
107115
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")
108116
}
109117
return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err)
110118
}
111119

112-
// Validate client secret
113-
if !oauthserver.ValidateClientSecret(clientSecret, client.ClientSecretHash) {
114-
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")
120+
// Validate authentication using centralized logic
121+
if err := oauthserver.ValidateClientAuthentication(client, clientSecret); err != nil {
122+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, err.Error())
115123
}
116124

117125
// Add authenticated client to context
118-
ctx = oauthserver.WithOAuthServerClient(ctx, client)
126+
ctx = shared.WithOAuthServerClient(ctx, client)
119127
return ctx, nil
120128
}
121129

internal/api/oauthserver/auth.go

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
package oauthserver
22

33
import (
4+
"bytes"
45
"encoding/base64"
6+
"encoding/json"
57
"errors"
8+
"io"
69
"net/http"
710
"strings"
811
)
912

1013
// ExtractClientCredentials extracts OAuth client credentials from the request
11-
// Supports both Basic auth header and form body parameters
14+
// Supports Basic auth header, form body parameters, and JSON body parameters
1215
func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, err error) {
1316
// First, try Basic auth header: Authorization: Basic base64(client_id:client_secret)
1417
authHeader := r.Header.Get("Authorization")
@@ -28,22 +31,40 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e
2831
return parts[0], parts[1], nil
2932
}
3033

31-
// Fall back to form parameters
32-
if err := r.ParseForm(); err != nil {
33-
return "", "", errors.New("failed to parse form")
34-
}
34+
// Check Content-Type to determine how to parse body parameters
35+
contentType := r.Header.Get("Content-Type")
36+
if strings.Contains(contentType, "application/json") {
37+
// Parse JSON body
38+
body, err := io.ReadAll(r.Body)
39+
if err != nil {
40+
return "", "", errors.New("failed to read request body")
41+
}
42+
// Restore the body so other handlers can read it
43+
r.Body = io.NopCloser(bytes.NewBuffer(body))
3544

36-
clientID = r.FormValue("client_id")
37-
clientSecret = r.FormValue("client_secret")
45+
var jsonData struct {
46+
ClientID string `json:"client_id"`
47+
ClientSecret string `json:"client_secret"`
48+
}
49+
if err := json.Unmarshal(body, &jsonData); err != nil {
50+
return "", "", errors.New("failed to parse JSON body")
51+
}
52+
53+
clientID = jsonData.ClientID
54+
clientSecret = jsonData.ClientSecret
55+
} else {
56+
// Fall back to form parameters
57+
if err := r.ParseForm(); err != nil {
58+
return "", "", errors.New("failed to parse form")
59+
}
3860

39-
// Return empty credentials if both are empty (no client auth attempted)
40-
if clientID == "" && clientSecret == "" {
41-
return "", "", nil
61+
clientID = r.FormValue("client_id")
62+
clientSecret = r.FormValue("client_secret")
4263
}
4364

44-
// If only one is provided, it's an error
45-
if clientID == "" || clientSecret == "" {
46-
return "", "", errors.New("both client_id and client_secret must be provided")
65+
// return error if client_id is not provided
66+
if clientID == "" {
67+
return "", "", errors.New("client_id is required")
4768
}
4869

4970
return clientID, clientSecret, nil

internal/api/oauthserver/authorize.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010

1111
"github.com/go-chi/chi/v5"
12+
"github.com/gofrs/uuid"
1213
"github.com/supabase/auth/internal/api/apierrors"
1314
"github.com/supabase/auth/internal/api/shared"
1415
"github.com/supabase/auth/internal/models"
@@ -103,7 +104,13 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
103104
return err
104105
}
105106

106-
client, err := s.getOAuthServerClient(ctx, params.ClientID)
107+
// Parse client_id as UUID
108+
clientID, err := uuid.FromString(params.ClientID)
109+
if err != nil {
110+
return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id format")
111+
}
112+
113+
client, err := s.getOAuthServerClient(ctx, clientID)
107114
if err != nil {
108115
if models.IsNotFoundError(err) {
109116
return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id")
@@ -144,7 +151,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
144151
}
145152

146153
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
147-
observability.LogEntrySetField(r, "client_id", client.ClientID)
154+
observability.LogEntrySetField(r, "client_id", client.ID.String())
148155

149156
// Redirect to authorization path with authorization_id
150157
if config.OAuthServer.AuthorizationPath == "" {
@@ -228,7 +235,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
228235
response := AuthorizationDetailsResponse{
229236
AuthorizationID: authorization.AuthorizationID,
230237
Client: ClientDetailsResponse{
231-
ClientID: authorization.Client.ClientID,
238+
ClientID: authorization.Client.ID.String(),
232239
ClientName: utilities.StringValue(authorization.Client.ClientName),
233240
ClientURI: utilities.StringValue(authorization.Client.ClientURI),
234241
LogoURI: utilities.StringValue(authorization.Client.LogoURI),
@@ -241,7 +248,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
241248
}
242249

243250
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
244-
observability.LogEntrySetField(r, "client_id", authorization.Client.ClientID)
251+
observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String())
245252

246253
return shared.SendJSON(w, http.StatusOK, response)
247254
}

0 commit comments

Comments
 (0)