Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c18f23f
add saml support (tested with okta saml)
willesq Dec 12, 2025
d6e813c
update example
willesq Dec 12, 2025
3c17a1d
split out callback funcs
willesq Dec 12, 2025
c0ca99f
add saml tests
willesq Dec 12, 2025
442853b
update doc
willesq Dec 12, 2025
0d9e65a
add diff from hugh
willesq Dec 12, 2025
15f0614
update tests
willesq Dec 12, 2025
fb8b27d
Fix: Add missing go.sum entries for test module
willesq Dec 12, 2025
f444814
Address Copilot review comments for SAML PR
willesq Dec 12, 2025
215e737
fix import
willesq Dec 12, 2025
e75fc0f
update swagger docs
willesq Dec 12, 2025
eefb998
Address additional Copilot review comments for SAML PR
willesq Dec 12, 2025
2a6c182
update saml example
willesq Dec 12, 2025
b8300b2
Merge branch 'refs/heads/main' into saml-support
willesq Dec 15, 2025
32e042b
address cursor comment
willesq Dec 15, 2025
53d4142
Add core security infrastructure for SAML authentication
willesq Dec 15, 2025
6e0502f
Fix critical SAML authentication vulnerabilities
willesq Dec 15, 2025
067371a
Prevent session fixation and add CSRF protection to auth callbacks
willesq Dec 15, 2025
5780dcc
Add security components to Server struct
willesq Dec 15, 2025
f31c206
Fix circular import dependency in SAML provider
willesq Dec 15, 2025
ac21745
Initialize SAML security components and middleware
willesq Dec 15, 2025
659aa43
Fix SAML config parsing - add missing allow_idp_initiated and session…
willesq Dec 15, 2025
0bd13f1
Replace custom security implementations with battle-tested packages
willesq Dec 15, 2025
cb9db5e
csrf middleware integration tests
willesq Dec 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/configuration/providers/saml/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ openssl req -new -x509 -key saml.key -out saml.cert -days 365 \

1. **Register Service Provider**: Add your agent as a Service Provider in your IdP
2. **Configure Entity ID**: Use your chosen entity ID (e.g., `https://your-app.example.com/saml/metadata`)
3. **Set Assertion Consumer Service**: Configure ACS URL (e.g., `https://your-app.example.com/saml/acs`)
3. **Set Assertion Consumer Service**: Configure ACS URL (e.g., `https://your-app.example.com/api/v1/auth/callback/{provider-name}`), replacing `{provider-name}` with the key you use for this provider (e.g., `company-saml`)
4. **Upload Certificate**: Upload your public certificate to the IdP

## Example Configurations
Expand Down
2 changes: 1 addition & 1 deletion examples/providers/saml.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ providers:
idp_metadata_url: "https://your-idp.example.com/saml/metadata"

# Required: Entity ID for this service provider
entity_id: "https://your-app.example.com/saml/metadata"
entity_id: "https://your-app.example.com/saml/sp"

# Required: Root URL of your application
root_url: "https://your-app.example.com"
Expand Down
1 change: 1 addition & 0 deletions internal/config/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
_ "github.com/thand-io/agent/internal/providers/oauth2.google"
_ "github.com/thand-io/agent/internal/providers/okta"
_ "github.com/thand-io/agent/internal/providers/salesforce"
_ "github.com/thand-io/agent/internal/providers/saml"
_ "github.com/thand-io/agent/internal/providers/slack"
_ "github.com/thand-io/agent/internal/providers/terraform"
_ "github.com/thand-io/agent/internal/providers/thand"
Expand Down
200 changes: 171 additions & 29 deletions internal/daemon/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ import (
"github.com/thand-io/agent/internal/models"
)

// Authentication Callback Handlers
//
// This file implements two separate callback handlers:
//
// 1. getAuthCallback() - OAuth2 GET callbacks
// - Expects state and code in query parameters
// - Used by OAuth2 providers (GitHub, Google, etc.)
//
// 2. postAuthCallback() - SAML POST callbacks
// - Expects RelayState and SAMLResponse in form parameters
// - Supports SP-initiated (with RelayState) and IdP-initiated (no RelayState)
// - Used by SAML providers (Okta, Azure AD, etc.)
//
// Both handlers delegate to getAuthCallbackPage() for session creation.

// getAuthRequest initiates the authentication flow
//
// @Summary Initiate authentication
Expand Down Expand Up @@ -42,9 +57,28 @@ func (s *Server) getAuthRequest(c *gin.Context) {

config := s.GetConfig()

if len(callback) > 0 && strings.Compare(callback, config.GetLoginServerUrl()) == 0 {
s.getErrorPage(c, http.StatusBadRequest, "Callback cannot be the login server")
return
// Validate callback URL to prevent infinite loops
// Only block callbacks that would loop back to the auth request endpoint
if len(callback) > 0 {
callbackURL, callbackErr := url.Parse(callback)
loginServerURL, loginServerErr := url.Parse(config.GetLoginServerUrl())

if callbackErr == nil && loginServerErr == nil {
// Block if it's the same host and the callback would loop back to auth endpoints
if callbackURL.Host == loginServerURL.Host &&
(strings.HasPrefix(callbackURL.Path, "/api/v1/auth/request") ||
strings.HasPrefix(callbackURL.Path, "/api/v1/auth/callback")) {
s.getErrorPage(c, http.StatusBadRequest, "Callback cannot be the auth request or callback endpoint - this would create an infinite loop")
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callback URL validation logic was changed to check specific path prefixes instead of just comparing domains, but it doesn't account for all possible loop scenarios. For instance, if someone sets a callback to /api/v1/auth/request/provider with different query parameters, it would still be blocked. Consider documenting this behavior or making the validation more precise by checking both path AND ensuring it's not just adding query parameters to the exact same endpoint.

Suggested change
// Only block callbacks that would loop back to the auth request endpoint
if len(callback) > 0 {
callbackURL, callbackErr := url.Parse(callback)
loginServerURL, loginServerErr := url.Parse(config.GetLoginServerUrl())
if callbackErr == nil && loginServerErr == nil {
// Block if it's the same host and the callback would loop back to auth endpoints
if callbackURL.Host == loginServerURL.Host &&
(strings.HasPrefix(callbackURL.Path, "/api/v1/auth/request") ||
strings.HasPrefix(callbackURL.Path, "/api/v1/auth/callback")) {
s.getErrorPage(c, http.StatusBadRequest, "Callback cannot be the auth request or callback endpoint - this would create an infinite loop")
// Only block callbacks that would loop back to the exact same auth request endpoint (path and query)
if len(callback) > 0 {
callbackURL, callbackErr := url.Parse(callback)
loginServerURL, loginServerErr := url.Parse(config.GetLoginServerUrl())
if callbackErr == nil && loginServerErr == nil {
// Block if it's the same host, path, and query as the current request (i.e., would cause a loop)
currentPath := c.Request.URL.Path
currentRawQuery := c.Request.URL.RawQuery
// Remove the "callback" parameter from the current query for comparison
currentQueryVals := c.Request.URL.Query()
currentQueryVals.Del("callback")
currentQuery := currentQueryVals.Encode()
callbackPath := callbackURL.Path
callbackRawQuery := callbackURL.RawQuery
callbackQueryVals := callbackURL.Query()
callbackQueryVals.Del("callback")
callbackQuery := callbackQueryVals.Encode()
if callbackURL.Host == loginServerURL.Host &&
callbackPath == currentPath &&
callbackQuery == currentQuery {
s.getErrorPage(c, http.StatusBadRequest, "Callback cannot be the same as the current auth request endpoint - this would create an infinite loop")

Copilot uses AI. Check for mistakes.
return
}
} else {
// If we can't parse the URLs, log the error but allow the request to proceed
logrus.WithFields(logrus.Fields{
"callback": callback,
"callbackErr": callbackErr,
"loginServerErr": loginServerErr,
}).Warnln("Failed to parse callback or login server URL for validation")
}
}

logrus.WithFields(logrus.Fields{
Expand All @@ -63,22 +97,29 @@ func (s *Server) getAuthRequest(c *gin.Context) {

client := common.GetClientIdentifier()

encodedState := models.EncodingWrapper{
Type: models.ENCODED_AUTH,
Data: models.NewAuthWrapper(
callback, // where are we returning to
client.String(), // server identifier
provider, // provider name
code, // the code sent by the client
),
}.EncodeAndEncrypt(
s.Config.GetServices().GetEncryption(),
)

logrus.WithFields(logrus.Fields{
"encodedState": encodedState,
"stateLength": len(encodedState),
}).Debugln("Encoded state for auth request")

authResponse, err := providerConfig.GetClient().AuthorizeSession(
context.Background(),
// This creates the state payload for the auth request
&models.AuthorizeUser{
Scopes: []string{"email", "profile"},
State: models.EncodingWrapper{
Type: models.ENCODED_AUTH,
Data: models.NewAuthWrapper(
callback, // where are we returning to
client.String(), // server identifier
provider, // provider name
code, // the code sent by the client
),
}.EncodeAndEncrypt(
s.Config.GetServices().GetEncryption(),
),
Scopes: []string{"email", "profile"},
State: encodedState,
RedirectUri: s.GetConfig().GetAuthCallbackUrl(provider),
},
)
Expand All @@ -94,9 +135,9 @@ func (s *Server) getAuthRequest(c *gin.Context) {
)
}

// getAuthCallback handles the OAuth2 callback
// getAuthCallback handles OAuth2 GET callback requests
//
// @Summary Authentication callback
// @Summary OAuth2 authentication callback
// @Description Handle the OAuth2 callback from the provider
// @Tags auth
// @Accept json
Expand All @@ -108,34 +149,126 @@ func (s *Server) getAuthRequest(c *gin.Context) {
// @Failure 400 {object} map[string]any "Bad request"
// @Router /auth/callback/{provider} [get]
func (s *Server) getAuthCallback(c *gin.Context) {
// OAuth2 flow: state and code come in query parameters (GET)
state := c.Query("state")

// Handle the callback to the CLI to store the users session state
// Debug logging
logrus.WithFields(logrus.Fields{
"method": c.Request.Method,
"state": state,
}).Debugln("OAuth2 callback parameters")

// Check if the callback is a workflow resumption or
// a local callback response
// Validate state parameter is required for OAuth2
if len(state) == 0 {
s.getErrorPage(c, http.StatusBadRequest, "State is required for OAuth2 flow")
return
}

state := c.Query("state")
// Decode and decrypt state
decoded, err := s.decodeState(state)
if err != nil {
s.getErrorPage(c, http.StatusBadRequest, "Invalid state", err)
return
}

if len(state) == 0 {
s.getErrorPage(c, http.StatusBadRequest, "State is required")
// Process decoded state
s.processDecodedState(c, decoded)
}

// postAuthCallback handles SAML POST callback requests
//
// @Summary SAML authentication callback
// @Description Handle the SAML POST callback from the provider
// @Tags auth
// @Accept x-www-form-urlencoded
// @Produce json
// @Param provider path string true "Provider name"
// @Param RelayState formData string false "SAML RelayState (SP-initiated)"
// @Param SAMLResponse formData string true "SAML Response"
// @Success 200 "Authentication successful"
// @Failure 400 {object} map[string]any "Bad request"
// @Router /auth/callback/{provider} [post]
func (s *Server) postAuthCallback(c *gin.Context) {
// SAML flow: RelayState and SAMLResponse come in form parameters (POST)
relayState := c.PostForm("RelayState")
samlResponse := c.PostForm("SAMLResponse")

// Debug logging
logrus.WithFields(logrus.Fields{
"method": c.Request.Method,
"relay_state": relayState,
"has_response": len(samlResponse) > 0,
}).Debugln("SAML callback parameters")

// Handle IdP-initiated SAML flow (no RelayState parameter)
if len(relayState) == 0 {
// Check if this is a SAML callback with SAMLResponse
if len(samlResponse) > 0 {
// IdP-initiated flow: validate provider allows this
providerName := c.Param("provider")

// Get provider config to verify IdP-initiated is allowed
_, err := s.GetConfig().GetProviderByName(providerName)
if err != nil {
logrus.WithFields(logrus.Fields{
"provider": providerName,
"error": err,
}).Warn("Provider not found for IdP-initiated SAML flow")
s.getErrorPage(c, http.StatusBadRequest, "Provider not configured")
return
}

// Security logging for IdP-initiated flows (for audit/monitoring)
logrus.WithFields(logrus.Fields{
"provider": providerName,
"source_ip": c.ClientIP(),
"user_agent": c.Request.UserAgent(),
}).Info("Processing IdP-initiated SAML authentication - verify this is expected")

authWrapper := models.AuthWrapper{
Callback: "", // No callback for IdP-initiated
Provider: providerName,
Code: "", // No client code
Client: "", // No client identifier
}
s.getAuthCallbackPage(c, authWrapper)
return
}
Comment on lines 214 to 268
Copy link

Copilot AI Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IdP-initiated SAML flow accepts authentication without validating a RelayState, which could potentially be exploited for session fixation or CSRF attacks. While the code includes security logging, consider adding additional validation such as checking for expected SAML assertion conditions (e.g., validating the Destination field matches the ACS URL, checking for replay attacks by tracking assertion IDs). Additionally, consider making IdP-initiated flows opt-in via configuration rather than always allowing them.

Copilot uses AI. Check for mistakes.

// Not a SAML IdP-initiated flow, RelayState is required
s.getErrorPage(c, http.StatusBadRequest, "RelayState is required for SP-initiated SAML flow")
return
}

// SP-initiated flow: decode and decrypt RelayState
decoded, err := s.decodeState(relayState)
if err != nil {
s.getErrorPage(c, http.StatusBadRequest, "Invalid RelayState", err)
return
}

// Process decoded state
s.processDecodedState(c, decoded)
}

// decodeState decodes and decrypts the state parameter
func (s *Server) decodeState(state string) (models.EncodingWrapper, error) {
decoded, err := models.EncodingWrapper{}.DecodeAndDecrypt(
state,
s.Config.GetServices().GetEncryption(),
)

if err != nil {
s.getErrorPage(c, http.StatusBadRequest, "Invalid state", err)
return
return models.EncodingWrapper{}, fmt.Errorf("failed to decode state: %w", err)
}
return *decoded, nil
}

// processDecodedState routes based on decoded state type
func (s *Server) processDecodedState(c *gin.Context, decoded models.EncodingWrapper) {
switch decoded.Type {
case models.ENCODED_WORKFLOW_TASK:
s.getElevateAuthOAuth2(c)
case models.ENCODED_AUTH:

authWrapper := models.AuthWrapper{}
err := common.ConvertMapToInterface(
decoded.Data.(map[string]any), &authWrapper)
Expand All @@ -146,7 +279,6 @@ func (s *Server) getAuthCallback(c *gin.Context) {
}

s.getAuthCallbackPage(c, authWrapper)

default:
s.getErrorPage(c, http.StatusBadRequest, "Invalid state type")
}
Expand Down Expand Up @@ -251,8 +383,17 @@ func (s *Server) getAuthCallbackPage(c *gin.Context, auth models.AuthWrapper) {
return
}

// For OAuth2: state and code come in query parameters (GET)
// For SAML: RelayState and SAMLResponse come in form parameters (POST)
state := c.Query("state")
if len(state) == 0 {
state = c.PostForm("RelayState")
}

code := c.Query("code") // This is the code from the provider - not the client
if len(code) == 0 {
code = c.PostForm("SAMLResponse")
}

session, err := provider.GetClient().CreateSession(c, &models.AuthorizeUser{
State: state,
Expand Down Expand Up @@ -292,7 +433,8 @@ func (s *Server) getAuthCallbackPage(c *gin.Context, auth models.AuthWrapper) {
}

if len(auth.Callback) == 0 {
c.Redirect(http.StatusTemporaryRedirect, "/")
// Use 303 See Other to force a GET redirect from the POST callback
c.Redirect(http.StatusSeeOther, "/")
} else {
s.renderHtml(c, "auth_callback.html", data)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ func (s *Server) setupRoutes(router *gin.Engine) {
api.GET("/sync", s.getSync)

api.GET("/auth/request/:provider", s.getAuthRequest)
api.GET("/auth/callback/:provider", s.getAuthCallback)
api.GET("/auth/callback/:provider", s.getAuthCallback) // OAuth2 callbacks
api.POST("/auth/callback/:provider", s.postAuthCallback) // SAML callbacks
api.GET("/auth/logout/:provider", s.getLogoutPage)
api.GET("/auth/logout", s.getLogoutPage)

Expand Down
Loading