Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c18f23f
add saml support (tested with okta saml)
willesq Dec 12, 2025
d6e813c
update example
willesq Dec 12, 2025
3c17a1d
split out callback funcs
willesq Dec 12, 2025
c0ca99f
add saml tests
willesq Dec 12, 2025
442853b
update doc
willesq Dec 12, 2025
0d9e65a
add diff from hugh
willesq Dec 12, 2025
15f0614
update tests
willesq Dec 12, 2025
fb8b27d
Fix: Add missing go.sum entries for test module
willesq Dec 12, 2025
f444814
Address Copilot review comments for SAML PR
willesq Dec 12, 2025
215e737
fix import
willesq Dec 12, 2025
e75fc0f
update swagger docs
willesq Dec 12, 2025
eefb998
Address additional Copilot review comments for SAML PR
willesq Dec 12, 2025
2a6c182
update saml example
willesq Dec 12, 2025
b8300b2
Merge branch 'refs/heads/main' into saml-support
willesq Dec 15, 2025
32e042b
address cursor comment
willesq Dec 15, 2025
53d4142
Add core security infrastructure for SAML authentication
willesq Dec 15, 2025
6e0502f
Fix critical SAML authentication vulnerabilities
willesq Dec 15, 2025
067371a
Prevent session fixation and add CSRF protection to auth callbacks
willesq Dec 15, 2025
5780dcc
Add security components to Server struct
willesq Dec 15, 2025
f31c206
Fix circular import dependency in SAML provider
willesq Dec 15, 2025
ac21745
Initialize SAML security components and middleware
willesq Dec 15, 2025
659aa43
Fix SAML config parsing - add missing allow_idp_initiated and session…
willesq Dec 15, 2025
0bd13f1
Replace custom security implementations with battle-tested packages
willesq Dec 15, 2025
cb9db5e
csrf middleware integration tests
willesq Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/configuration/providers/saml/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ openssl req -new -x509 -key saml.key -out saml.cert -days 365 \

1. **Register Service Provider**: Add your agent as a Service Provider in your IdP
2. **Configure Entity ID**: Use your chosen entity ID (e.g., `https://your-app.example.com/saml/metadata`)
3. **Set Assertion Consumer Service**: Configure ACS URL (e.g., `https://your-app.example.com/saml/acs`)
3. **Set Assertion Consumer Service**: Configure ACS URL (e.g., `https://your-app.example.com/saml/sp`)
4. **Upload Certificate**: Upload your public certificate to the IdP

## Example Configurations
Expand Down
2 changes: 1 addition & 1 deletion examples/providers/saml.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ providers:
idp_metadata_url: "https://your-idp.example.com/saml/metadata"

# Required: Entity ID for this service provider
entity_id: "https://your-app.example.com/saml/metadata"
entity_id: "https://your-app.example.com/saml/sp"

# Required: Root URL of your application
root_url: "https://your-app.example.com"
Expand Down
1 change: 1 addition & 0 deletions internal/config/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
_ "github.com/thand-io/agent/internal/providers/oauth2.google"
_ "github.com/thand-io/agent/internal/providers/okta"
_ "github.com/thand-io/agent/internal/providers/salesforce"
_ "github.com/thand-io/agent/internal/providers/saml"
_ "github.com/thand-io/agent/internal/providers/slack"
_ "github.com/thand-io/agent/internal/providers/terraform"
_ "github.com/thand-io/agent/internal/providers/thand"
Expand Down
184 changes: 155 additions & 29 deletions internal/daemon/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ import (
"github.com/thand-io/agent/internal/models"
)

// Authentication Callback Handlers
//
// This file implements two separate callback handlers:
//
// 1. getAuthCallback() - OAuth2 GET callbacks
// - Expects state and code in query parameters
// - Used by OAuth2 providers (GitHub, Google, etc.)
//
// 2. postAuthCallback() - SAML POST callbacks
// - Expects RelayState and SAMLResponse in form parameters
// - Supports SP-initiated (with RelayState) and IdP-initiated (no RelayState)
// - Used by SAML providers (Okta, Azure AD, etc.)
//
// Both handlers delegate to getAuthCallbackPage() for session creation.

// getAuthRequest initiates the authentication flow
//
// @Summary Initiate authentication
Expand Down Expand Up @@ -42,9 +57,27 @@ func (s *Server) getAuthRequest(c *gin.Context) {

config := s.GetConfig()

if len(callback) > 0 && strings.Compare(callback, config.GetLoginServerUrl()) == 0 {
s.getErrorPage(c, http.StatusBadRequest, "Callback cannot be the login server")
return
// Validate callback URL to prevent infinite loops
// Only block callbacks that would loop back to the auth request endpoint
if len(callback) > 0 {
callbackURL, callbackErr := url.Parse(callback)
loginServerURL, loginServerErr := url.Parse(config.GetLoginServerUrl())

if callbackErr == nil && loginServerErr == nil {
// Block only if it's the same host and the callback would loop back to /api/v1/auth/request
if callbackURL.Host == loginServerURL.Host &&
strings.HasPrefix(callbackURL.Path, "/api/v1/auth/request") {
s.getErrorPage(c, http.StatusBadRequest, "Callback cannot be the auth request endpoint - this would create an infinite loop")
return
}
} else {
// If we can't parse the URLs, log the error but allow the request to proceed
logrus.WithFields(logrus.Fields{
"callback": callback,
"callbackErr": callbackErr,
"loginServerErr": loginServerErr,
}).Warnln("Failed to parse callback or login server URL for validation")
}
}

logrus.WithFields(logrus.Fields{
Expand All @@ -63,22 +96,29 @@ func (s *Server) getAuthRequest(c *gin.Context) {

client := common.GetClientIdentifier()

encodedState := models.EncodingWrapper{
Type: models.ENCODED_AUTH,
Data: models.NewAuthWrapper(
callback, // where are we returning to
client.String(), // server identifier
provider, // provider name
code, // the code sent by the client
),
}.EncodeAndEncrypt(
s.Config.GetServices().GetEncryption(),
)

logrus.WithFields(logrus.Fields{
"encodedState": encodedState,
"stateLength": len(encodedState),
}).Debugln("Encoded state for auth request")

authResponse, err := providerConfig.GetClient().AuthorizeSession(
context.Background(),
// This creates the state payload for the auth request
&models.AuthorizeUser{
Scopes: []string{"email", "profile"},
State: models.EncodingWrapper{
Type: models.ENCODED_AUTH,
Data: models.NewAuthWrapper(
callback, // where are we returning to
client.String(), // server identifier
provider, // provider name
code, // the code sent by the client
),
}.EncodeAndEncrypt(
s.Config.GetServices().GetEncryption(),
),
Scopes: []string{"email", "profile"},
State: encodedState,
RedirectUri: s.GetConfig().GetAuthCallbackUrl(provider),
},
)
Expand All @@ -94,9 +134,9 @@ func (s *Server) getAuthRequest(c *gin.Context) {
)
}

// getAuthCallback handles the OAuth2 callback
// getAuthCallback handles OAuth2 GET callback requests
//
// @Summary Authentication callback
// @Summary OAuth2 authentication callback
// @Description Handle the OAuth2 callback from the provider
// @Tags auth
// @Accept json
Expand All @@ -108,34 +148,111 @@ func (s *Server) getAuthRequest(c *gin.Context) {
// @Failure 400 {object} map[string]any "Bad request"
// @Router /auth/callback/{provider} [get]
func (s *Server) getAuthCallback(c *gin.Context) {
// OAuth2 flow: state and code come in query parameters (GET)
state := c.Query("state")

// Handle the callback to the CLI to store the users session state
// Debug logging
logrus.WithFields(logrus.Fields{
"method": c.Request.Method,
"state": state,
}).Debugln("OAuth2 callback parameters")

// Check if the callback is a workflow resumption or
// a local callback response
// Validate state parameter is required for OAuth2
if len(state) == 0 {
s.getErrorPage(c, http.StatusBadRequest, "State is required for OAuth2 flow")
return
}

state := c.Query("state")
// Decode and decrypt state
decoded, err := s.decodeState(state)
if err != nil {
s.getErrorPage(c, http.StatusBadRequest, "Invalid state", err)
return
}

if len(state) == 0 {
s.getErrorPage(c, http.StatusBadRequest, "State is required")
// Process decoded state
s.processDecodedState(c, decoded)
}

// postAuthCallback handles SAML POST callback requests
//
// @Summary SAML authentication callback
// @Description Handle the SAML POST callback from the provider
// @Tags auth
// @Accept x-www-form-urlencoded
// @Produce json
// @Param provider path string true "Provider name"
// @Param RelayState formData string false "SAML RelayState (SP-initiated)"
// @Param SAMLResponse formData string true "SAML Response"
// @Success 200 "Authentication successful"
// @Failure 400 {object} map[string]any "Bad request"
// @Router /auth/callback/{provider} [post]
func (s *Server) postAuthCallback(c *gin.Context) {
// SAML flow: RelayState and SAMLResponse come in form parameters (POST)
relayState := c.PostForm("RelayState")
samlResponse := c.PostForm("SAMLResponse")

// Debug logging
logrus.WithFields(logrus.Fields{
"method": c.Request.Method,
"relay_state": relayState,
"has_response": len(samlResponse) > 0,
}).Debugln("SAML callback parameters")

// Handle IdP-initiated SAML flow (no RelayState parameter)
if len(relayState) == 0 {
// Check if this is a SAML callback with SAMLResponse
if len(samlResponse) > 0 {
// IdP-initiated flow: create a default auth wrapper
providerName := c.Param("provider")
logrus.WithFields(logrus.Fields{
"provider": providerName,
}).Info("Handling IdP-initiated SAML flow")

authWrapper := models.AuthWrapper{
Callback: "", // No callback for IdP-initiated
Provider: providerName,
Code: "", // No client code
Client: "", // No client identifier
}
s.getAuthCallbackPage(c, authWrapper)
return
}

// Not a SAML IdP-initiated flow, RelayState is required
s.getErrorPage(c, http.StatusBadRequest, "RelayState is required for SP-initiated SAML flow")
return
}

// SP-initiated flow: decode and decrypt RelayState
decoded, err := s.decodeState(relayState)
if err != nil {
s.getErrorPage(c, http.StatusBadRequest, "Invalid RelayState", err)
return
}

// Process decoded state
s.processDecodedState(c, decoded)
}

// decodeState decodes and decrypts the state parameter
func (s *Server) decodeState(state string) (models.EncodingWrapper, error) {
decoded, err := models.EncodingWrapper{}.DecodeAndDecrypt(
state,
s.Config.GetServices().GetEncryption(),
)

if err != nil {
s.getErrorPage(c, http.StatusBadRequest, "Invalid state", err)
return
return models.EncodingWrapper{}, fmt.Errorf("failed to decode state: %w", err)
}
return *decoded, nil
}

// processDecodedState routes based on decoded state type
func (s *Server) processDecodedState(c *gin.Context, decoded models.EncodingWrapper) {
switch decoded.Type {
case models.ENCODED_WORKFLOW_TASK:
s.getElevateAuthOAuth2(c)
case models.ENCODED_AUTH:

authWrapper := models.AuthWrapper{}
err := common.ConvertMapToInterface(
decoded.Data.(map[string]any), &authWrapper)
Expand All @@ -146,7 +263,6 @@ func (s *Server) getAuthCallback(c *gin.Context) {
}

s.getAuthCallbackPage(c, authWrapper)

default:
s.getErrorPage(c, http.StatusBadRequest, "Invalid state type")
}
Expand Down Expand Up @@ -251,8 +367,17 @@ func (s *Server) getAuthCallbackPage(c *gin.Context, auth models.AuthWrapper) {
return
}

// For OAuth2: state and code come in query parameters (GET)
// For SAML: RelayState and SAMLResponse come in form parameters (POST)
state := c.Query("state")
if len(state) == 0 {
state = c.PostForm("RelayState")
}

code := c.Query("code") // This is the code from the provider - not the client
if len(code) == 0 {
code = c.PostForm("SAMLResponse")
}

session, err := provider.GetClient().CreateSession(c, &models.AuthorizeUser{
State: state,
Expand Down Expand Up @@ -292,7 +417,8 @@ func (s *Server) getAuthCallbackPage(c *gin.Context, auth models.AuthWrapper) {
}

if len(auth.Callback) == 0 {
c.Redirect(http.StatusTemporaryRedirect, "/")
// Use 303 See Other to force a GET redirect from the POST callback
c.Redirect(http.StatusSeeOther, "/")
} else {
s.renderHtml(c, "auth_callback.html", data)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/daemon/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,8 @@ func (s *Server) setupRoutes(router *gin.Engine) {
api.GET("/sync", s.getSync)

api.GET("/auth/request/:provider", s.getAuthRequest)
api.GET("/auth/callback/:provider", s.getAuthCallback)
api.GET("/auth/callback/:provider", s.getAuthCallback) // OAuth2 callbacks
api.POST("/auth/callback/:provider", s.postAuthCallback) // SAML callbacks
api.GET("/auth/logout/:provider", s.getLogoutPage)
api.GET("/auth/logout", s.getLogoutPage)

Expand Down
Loading
Loading