Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c18f23f
add saml support (tested with okta saml)
willesq Dec 12, 2025
d6e813c
update example
willesq Dec 12, 2025
3c17a1d
split out callback funcs
willesq Dec 12, 2025
c0ca99f
add saml tests
willesq Dec 12, 2025
442853b
update doc
willesq Dec 12, 2025
0d9e65a
add diff from hugh
willesq Dec 12, 2025
15f0614
update tests
willesq Dec 12, 2025
fb8b27d
Fix: Add missing go.sum entries for test module
willesq Dec 12, 2025
f444814
Address Copilot review comments for SAML PR
willesq Dec 12, 2025
215e737
fix import
willesq Dec 12, 2025
e75fc0f
update swagger docs
willesq Dec 12, 2025
eefb998
Address additional Copilot review comments for SAML PR
willesq Dec 12, 2025
2a6c182
update saml example
willesq Dec 12, 2025
b8300b2
Merge branch 'refs/heads/main' into saml-support
willesq Dec 15, 2025
32e042b
address cursor comment
willesq Dec 15, 2025
53d4142
Add core security infrastructure for SAML authentication
willesq Dec 15, 2025
6e0502f
Fix critical SAML authentication vulnerabilities
willesq Dec 15, 2025
067371a
Prevent session fixation and add CSRF protection to auth callbacks
willesq Dec 15, 2025
5780dcc
Add security components to Server struct
willesq Dec 15, 2025
f31c206
Fix circular import dependency in SAML provider
willesq Dec 15, 2025
ac21745
Initialize SAML security components and middleware
willesq Dec 15, 2025
659aa43
Fix SAML config parsing - add missing allow_idp_initiated and session…
willesq Dec 15, 2025
0bd13f1
Replace custom security implementations with battle-tested packages
willesq Dec 15, 2025
cb9db5e
csrf middleware integration tests
willesq Dec 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/configuration/providers/saml/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ openssl req -new -x509 -key saml.key -out saml.cert -days 365 \

1. **Register Service Provider**: Add your agent as a Service Provider in your IdP
2. **Configure Entity ID**: Use your chosen entity ID (e.g., `https://your-app.example.com/saml/metadata`)
3. **Set Assertion Consumer Service**: Configure ACS URL (e.g., `https://your-app.example.com/saml/acs`)
3. **Set Assertion Consumer Service**: Configure ACS URL (e.g., `https://your-app.example.com/api/v1/auth/callback/{provider-name}`), replacing `{provider-name}` with the key you use for this provider (e.g., `company-saml`)
4. **Upload Certificate**: Upload your public certificate to the IdP

## Example Configurations
Expand Down
51 changes: 49 additions & 2 deletions docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ const docTemplate = `{
"tags": [
"auth"
],
"summary": "Authentication callback",
"summary": "OAuth2 authentication callback",
"parameters": [
{
"type": "string",
Expand Down Expand Up @@ -137,6 +137,53 @@ const docTemplate = `{
}
}
}
},
"post": {
"description": "Handle the SAML POST callback from the provider",
"consumes": [
"application/x-www-form-urlencoded"
],
"produces": [
"application/json"
],
"tags": [
"auth"
],
"summary": "SAML authentication callback",
"parameters": [
{
"type": "string",
"description": "Provider name",
"name": "provider",
"in": "path",
"required": true
},
{
"type": "string",
"description": "SAML RelayState (SP-initiated)",
"name": "RelayState",
"in": "formData"
},
{
"type": "string",
"description": "SAML Response",
"name": "SAMLResponse",
"in": "formData",
"required": true
}
],
"responses": {
"200": {
"description": "Authentication successful"
},
"400": {
"description": "Bad request",
"schema": {
"type": "object",
"additionalProperties": true
}
}
}
}
},
"/auth/logout": {
Expand Down Expand Up @@ -1149,7 +1196,7 @@ const docTemplate = `{
],
"responses": {
"200": {
"description": "Provider roles",
"description": "Provider identities",
"schema": {
"$ref": "#/definitions/models.ProviderIdentitiesResponse"
}
Expand Down
51 changes: 49 additions & 2 deletions docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
"tags": [
"auth"
],
"summary": "Authentication callback",
"summary": "OAuth2 authentication callback",
"parameters": [
{
"type": "string",
Expand Down Expand Up @@ -135,6 +135,53 @@
}
}
}
},
"post": {
"description": "Handle the SAML POST callback from the provider",
"consumes": [
"application/x-www-form-urlencoded"
],
"produces": [
"application/json"
],
"tags": [
"auth"
],
"summary": "SAML authentication callback",
"parameters": [
{
"type": "string",
"description": "Provider name",
"name": "provider",
"in": "path",
"required": true
},
{
"type": "string",
"description": "SAML RelayState (SP-initiated)",
"name": "RelayState",
"in": "formData"
},
{
"type": "string",
"description": "SAML Response",
"name": "SAMLResponse",
"in": "formData",
"required": true
}
],
"responses": {
"200": {
"description": "Authentication successful"
},
"400": {
"description": "Bad request",
"schema": {
"type": "object",
"additionalProperties": true
}
}
}
}
},
"/auth/logout": {
Expand Down Expand Up @@ -1147,7 +1194,7 @@
],
"responses": {
"200": {
"description": "Provider roles",
"description": "Provider identities",
"schema": {
"$ref": "#/definitions/models.ProviderIdentitiesResponse"
}
Expand Down
36 changes: 34 additions & 2 deletions docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,39 @@ paths:
schema:
additionalProperties: true
type: object
summary: Authentication callback
summary: OAuth2 authentication callback
tags:
- auth
post:
consumes:
- application/x-www-form-urlencoded
description: Handle the SAML POST callback from the provider
parameters:
- description: Provider name
in: path
name: provider
required: true
type: string
- description: SAML RelayState (SP-initiated)
in: formData
name: RelayState
type: string
- description: SAML Response
in: formData
name: SAMLResponse
required: true
type: string
produces:
- application/json
responses:
"200":
description: Authentication successful
"400":
description: Bad request
schema:
additionalProperties: true
type: object
summary: SAML authentication callback
tags:
- auth
/auth/logout:
Expand Down Expand Up @@ -2275,7 +2307,7 @@ paths:
- application/json
responses:
"200":
description: Provider roles
description: Provider identities
schema:
$ref: '#/definitions/models.ProviderIdentitiesResponse'
"404":
Expand Down
5 changes: 4 additions & 1 deletion examples/providers/saml.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ providers:
# Required: URL to fetch IdP metadata
idp_metadata_url: "https://your-idp.example.com/saml/metadata"

# Required: Entity ID for this service provider
# Required: Entity ID for this service provider (typically the metadata URL)
Copy link

Copilot AI Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says 'typically the metadata URL' but the example value on line 14 shows 'https://your-app.example.com/saml/metadata' which contradicts the actual implementation. According to the code in main.go line 95, the metadata URL is constructed as '/saml/metadata', but the entity_id is a separate configuration field. The comment should clarify that entity_id is a unique identifier for the SP, not necessarily the metadata URL.

Suggested change
# Required: Entity ID for this service provider (typically the metadata URL)
# Required: Entity ID for this service provider (a unique identifier for your SP; often set to the metadata URL, but can be any URI under your control)

Copilot uses AI. Check for mistakes.
entity_id: "https://your-app.example.com/saml/metadata"

# Required: Root URL of your application
Expand All @@ -24,3 +24,6 @@ providers:

# Optional: Whether to sign SAML requests (default: false)
sign_requests: true

# Optional: Allow IdP-initiated SSO (default: false)
allow_idp_initiated: true
1 change: 1 addition & 0 deletions internal/config/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
_ "github.com/thand-io/agent/internal/providers/oauth2.google"
_ "github.com/thand-io/agent/internal/providers/okta"
_ "github.com/thand-io/agent/internal/providers/salesforce"
_ "github.com/thand-io/agent/internal/providers/saml"
_ "github.com/thand-io/agent/internal/providers/slack"
_ "github.com/thand-io/agent/internal/providers/terraform"
_ "github.com/thand-io/agent/internal/providers/thand"
Expand Down
142 changes: 142 additions & 0 deletions internal/daemon/assertion_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package daemon

import (
"sync"
"time"

"github.com/sirupsen/logrus"
)

// AssertionCache implements in-memory cache for SAML assertion ID replay protection.
// It provides thread-safe tracking of used assertion IDs to prevent replay attacks
// where an attacker captures a valid SAML assertion and attempts to reuse it.
type AssertionCache struct {
cache sync.Map // map[string]*assertionEntry
ttl time.Duration // Time-to-live for cached assertions
cleanupTicker *time.Ticker // Periodic cleanup ticker
stopCleanup chan struct{} // Channel to stop cleanup goroutine
}

// assertionEntry represents a cached assertion with its timing information
type assertionEntry struct {
addedAt time.Time // When the assertion was first cached
expiry time.Time // When the assertion entry expires
}

// NewAssertionCache creates a new assertion cache with the specified TTL and cleanup interval.
// The TTL should match the typical validity window of SAML assertions (usually 5 minutes).
// The cleanup interval determines how often expired entries are removed from memory.
func NewAssertionCache(ttl time.Duration, cleanupInterval time.Duration) *AssertionCache {
if ttl == 0 {
ttl = 5 * time.Minute // Default TTL matches typical SAML assertion validity
}
if cleanupInterval == 0 {
cleanupInterval = 1 * time.Minute // Default cleanup every minute
}

ac := &AssertionCache{
ttl: ttl,
stopCleanup: make(chan struct{}),
}

// Start cleanup goroutine
ac.cleanupTicker = time.NewTicker(cleanupInterval)
go ac.cleanup()

logrus.WithFields(logrus.Fields{
"ttl": ttl,
"cleanup_interval": cleanupInterval,
}).Info("Assertion cache initialized")

return ac
}

// CheckAndAdd atomically checks if an assertion ID exists in the cache and adds it if not.
// This method is the core of replay protection - it ensures that each assertion ID can
// only be used once within the TTL window.
//
// Returns true if the assertion was added (not a replay), false if it already exists (replay detected).
func (ac *AssertionCache) CheckAndAdd(assertionID string) bool {
if assertionID == "" {
logrus.Warn("Empty assertion ID provided to cache")
return false
}

now := time.Now()
entry := &assertionEntry{
addedAt: now,
expiry: now.Add(ac.ttl),
}

// LoadOrStore is atomic - it returns the existing value if present,
// or stores the new value and returns it. The 'loaded' bool indicates
// whether the value was loaded (true) or stored (false).
_, loaded := ac.cache.LoadOrStore(assertionID, entry)

if loaded {
// Assertion ID already exists - this is a replay attack
logrus.WithFields(logrus.Fields{
"assertion_id": assertionID,
"event": "replay_detected",
}).Warn("SAML assertion replay attack detected")
return false
}

// Successfully cached new assertion ID
logrus.WithFields(logrus.Fields{
"assertion_id": assertionID,
"expiry": entry.expiry,
}).Debug("SAML assertion ID cached successfully")

return true
}

// cleanup removes expired assertion entries from the cache.
// This goroutine runs periodically based on the cleanup interval and prevents
// unbounded memory growth by removing entries that have exceeded their TTL.
func (ac *AssertionCache) cleanup() {
for {
select {
case <-ac.cleanupTicker.C:
now := time.Now()
count := 0

// Iterate through all cache entries
ac.cache.Range(func(key, value interface{}) bool {
entry := value.(*assertionEntry)
if now.After(entry.expiry) {
ac.cache.Delete(key)
count++
}
return true // Continue iteration
})

if count > 0 {
logrus.WithField("count", count).Debug("Cleaned up expired assertion cache entries")
}

case <-ac.stopCleanup:
// Graceful shutdown requested
ac.cleanupTicker.Stop()
logrus.Info("Assertion cache cleanup goroutine stopped")
return
}
}
}

// Stop gracefully stops the cleanup goroutine.
// This should be called when the server is shutting down to prevent goroutine leaks.
func (ac *AssertionCache) Stop() {
close(ac.stopCleanup)
}

// Size returns the current number of cached assertions.
// This is useful for monitoring and observability to track cache utilization.
func (ac *AssertionCache) Size() int {
count := 0
ac.cache.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
}
Loading
Loading