Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions internal/proxy/cee.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ import (
// CeeSessionKey builds a consistent session identity for cross-request
// exfiltration detection. Exported for use by the session reset admin API.
func CeeSessionKey(agent, clientIP string) string {
if agent != "" && agent != agentAnonymous {
return agent + "|" + clientIP
}
return clientIP
return sessionKeyFor(agent, clientIP)
}

// maxCaptureSessionKeyLen aliases the writer-side ceiling so the
Expand Down
41 changes: 8 additions & 33 deletions internal/proxy/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,7 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
baseAction := config.ActionWarn
effectiveAction := decide.UpgradeAction(baseAction, sr.Level, &cfg.AdaptiveEnforcement)
if effectiveAction == config.ActionBlock {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
p.logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), baseAction, effectiveAction, result.Scanner, clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade(baseAction, effectiveAction, session.EscalationLabel(sr.Level))
p.logger.LogBlockedDetail(targetCtx, result.Scanner, result.Reason+" (escalated)", auditDetailFromResult(result))
Expand All @@ -343,10 +340,7 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
// block_all enforcement: deny ALL traffic (including clean) when the
// session is at an escalation level with block_all=true.
if sr.Level > 0 && decide.UpgradeAction("", sr.Level, &cfg.AdaptiveEnforcement) == config.ActionBlock {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
p.logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), "", config.ActionBlock, "session_deny", clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade("", config.ActionBlock, session.EscalationLabel(sr.Level))
p.metrics.RecordTunnelBlocked(agentLabel)
Expand Down Expand Up @@ -619,11 +613,7 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
// escalation level lookups instead of a stale snapshot from sr.Level.
var interceptRec session.Recorder
if sm := p.sessionMgrPtr.Load(); sm != nil {
interceptSessionKey := clientIP
if agent != "" && agent != agentAnonymous {
interceptSessionKey = agent + "|" + clientIP
}
interceptRec = sm.GetOrCreate(interceptSessionKey)
interceptRec = sm.GetOrCreate(sessionKeyFor(agent, clientIP))
}
if err := interceptTunnel(interceptCtx, interceptConn, &InterceptContext{
TargetHost: host,
Expand Down Expand Up @@ -931,10 +921,7 @@ func (p *Proxy) handleForwardHTTP(w http.ResponseWriter, r *http.Request) {
baseAction := config.ActionWarn
effectiveAction := decide.UpgradeAction(baseAction, sr.Level, &cfg.AdaptiveEnforcement)
if effectiveAction == config.ActionBlock {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
p.logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), baseAction, effectiveAction, result.Scanner, clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade(baseAction, effectiveAction, session.EscalationLabel(sr.Level))
p.logger.LogBlockedDetail(actx, result.Scanner, result.Reason+" (escalated)", auditDetailFromResult(result))
Expand All @@ -958,10 +945,7 @@ func (p *Proxy) handleForwardHTTP(w http.ResponseWriter, r *http.Request) {
// block_all enforcement: deny ALL traffic (including clean) when the
// session is at an escalation level with block_all=true.
if sr.Level > 0 && decide.UpgradeAction("", sr.Level, &cfg.AdaptiveEnforcement) == config.ActionBlock {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
p.logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), "", config.ActionBlock, "session_deny", clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade("", config.ActionBlock, session.EscalationLabel(sr.Level))
p.metrics.RecordBlocked(r.URL.Hostname(), "session_deny", time.Since(start), agentLabel)
Expand Down Expand Up @@ -1228,10 +1212,7 @@ func (p *Proxy) handleForwardHTTP(w http.ResponseWriter, r *http.Request) {
action = decide.UpgradeAction(action, sr.Level, &cfg.AdaptiveEnforcement)
}
if action != originalBodyAction {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
p.logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), originalBodyAction, action, scannerLabel, clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade(originalBodyAction, action, session.EscalationLabel(sr.Level))
}
Expand Down Expand Up @@ -2105,10 +2086,7 @@ func (p *Proxy) handleForwardHTTP(w http.ResponseWriter, r *http.Request) {
if forwardRec != nil && !fwdRespExempt {
action = decide.UpgradeAction(action, forwardRec.EscalationLevel(), &cfg.AdaptiveEnforcement)
if action != originalAction {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
p.logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(forwardRec.EscalationLevel()), originalAction, action, "response_scan", clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade(originalAction, action, session.EscalationLabel(forwardRec.EscalationLevel()))
}
Expand All @@ -2127,10 +2105,7 @@ func (p *Proxy) handleForwardHTTP(w http.ResponseWriter, r *http.Request) {
// Exempt domains skip scoring — findings are logged but don't escalate.
if !fwdRespExempt {
if sm := p.sessionMgrPtr.Load(); sm != nil && cfg.AdaptiveEnforcement.Enabled {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
sess := sm.GetOrCreate(sessionKey)
decide.RecordSignal(sess, session.SignalStrip, decide.EscalationParams{
Threshold: cfg.AdaptiveEnforcement.EscalationThreshold,
Expand Down
35 changes: 7 additions & 28 deletions internal/proxy/intercept.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,7 @@ func interceptRecordSignal(ic *InterceptContext, sig session.SignalType) {
if !ic.Config.AdaptiveEnforcement.Enabled {
return
}
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
var m *metrics.Metrics
if ic.Proxy != nil {
m = ic.Proxy.metrics
Expand Down Expand Up @@ -606,10 +603,7 @@ func newInterceptHandler(
baseAction := config.ActionWarn
effectiveAction := decide.UpgradeAction(baseAction, recEscalationLevel(ic.Recorder), &ic.Config.AdaptiveEnforcement)
if effectiveAction == config.ActionBlock {
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
ic.Logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(recEscalationLevel(ic.Recorder)), baseAction, effectiveAction, urlResult.Scanner, ic.ClientIP, ic.RequestID)
if ic.Proxy != nil {
ic.Proxy.metrics.RecordAdaptiveUpgrade(baseAction, effectiveAction, session.EscalationLabel(recEscalationLevel(ic.Recorder)))
Expand Down Expand Up @@ -826,10 +820,7 @@ func newInterceptHandler(
action = decide.UpgradeAction(action, recEscalationLevel(ic.Recorder), &ic.Config.AdaptiveEnforcement)
}
if action != originalBodyAction {
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
ic.Logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(recEscalationLevel(ic.Recorder)), originalBodyAction, action, scannerLabel, ic.ClientIP, ic.RequestID)
if ic.Proxy != nil {
ic.Proxy.metrics.RecordAdaptiveUpgrade(originalBodyAction, action, session.EscalationLabel(recEscalationLevel(ic.Recorder)))
Expand Down Expand Up @@ -1102,10 +1093,7 @@ func newInterceptHandler(
interceptMetrics = ic.Proxy.metrics
}
if changed, fromLabel, toLabel := trySessionRecovery(ic.Recorder, &ic.Config.AdaptiveEnforcement, interceptMetrics); changed {
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
if ic.Logger != nil {
ic.Logger.LogAdaptiveEscalation(sessionKey, fromLabel, toLabel, ic.ClientIP, ic.RequestID, ic.Recorder.ThreatScore())
}
Expand All @@ -1114,10 +1102,7 @@ func newInterceptHandler(
// block_all enforcement: deny ALL traffic (including clean) when the
// session is at an escalation level with block_all=true.
if ic.Recorder != nil && decide.UpgradeAction("", recEscalationLevel(ic.Recorder), &ic.Config.AdaptiveEnforcement) == config.ActionBlock {
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
ic.Logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(recEscalationLevel(ic.Recorder)), "", config.ActionBlock, "session_deny", ic.ClientIP, ic.RequestID)
if ic.Proxy != nil {
ic.Proxy.metrics.RecordAdaptiveUpgrade("", config.ActionBlock, session.EscalationLabel(recEscalationLevel(ic.Recorder)))
Expand Down Expand Up @@ -1723,10 +1708,7 @@ func newInterceptHandler(
action = decide.UpgradeAction(action, recEscalationLevel(ic.Recorder), &ic.Config.AdaptiveEnforcement)
}
if action != originalAction {
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
ic.Logger.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(recEscalationLevel(ic.Recorder)), originalAction, action, "response_scan", ic.ClientIP, ic.RequestID)
if ic.Proxy != nil {
ic.Proxy.metrics.RecordAdaptiveUpgrade(originalAction, action, session.EscalationLabel(recEscalationLevel(ic.Recorder)))
Expand Down Expand Up @@ -1772,10 +1754,7 @@ func newInterceptHandler(
ceeSM = ic.Proxy.sessionMgrPtr.Load()
}
if ceeSM != nil {
sessionKey := ic.ClientIP
if ic.Agent != "" && ic.Agent != agentAnonymous {
sessionKey = ic.Agent + "|" + ic.ClientIP
}
sessionKey := sessionKeyFor(ic.Agent, ic.ClientIP)
sess := ceeSM.GetOrCreate(sessionKey)
var stripMetrics *metrics.Metrics
if ic.Proxy != nil {
Expand Down
38 changes: 7 additions & 31 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2096,12 +2096,7 @@ func (p *Proxy) recordSessionActivityWithUserAgent(opts sessionActivityOptions)
return SessionResult{}
}

// Build session key: agent|clientIP when agent is known, else just clientIP.
key := clientIP
if agent != "" && agent != agentAnonymous {
key = agent + "|" + clientIP
}

key := sessionKeyFor(agent, clientIP)
sess := sm.GetOrCreate(key)

// On-entry de-escalation: recover sessions stuck at block_all.
Expand Down Expand Up @@ -2487,10 +2482,7 @@ func (p *Proxy) recordShieldIntervention(summary *receipt.ShieldSummary, cfg *co
if sm == nil {
return
}
sessionKey := clientIP
if agent := actx.Agent(); agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(actx.Agent(), clientIP)
sess := sm.GetOrCreate(sessionKey)
for i := 0; i < signals; i++ {
if decide.RecordSignal(sess, session.SignalShieldRewrite, decide.EscalationParams{
Expand Down Expand Up @@ -3134,11 +3126,7 @@ func (p *Proxy) handleFetch(w http.ResponseWriter, r *http.Request) {
// RecordClean at the end when no finding was detected.
var fetchRec session.Recorder
if sm := p.sessionMgrPtr.Load(); sm != nil {
fetchSessionKey := clientIP
if agent != "" && agent != agentAnonymous {
fetchSessionKey = agent + "|" + clientIP
}
fetchRec = sm.GetOrCreate(fetchSessionKey)
fetchRec = sm.GetOrCreate(sessionKeyFor(agent, clientIP))
}
fetchTaint := evaluateHTTPTaint(cfg, fetchRec, http.MethodGet, parsed)

Expand Down Expand Up @@ -3224,10 +3212,7 @@ func (p *Proxy) handleFetch(w http.ResponseWriter, r *http.Request) {
baseAction := config.ActionWarn
effectiveAction := decide.UpgradeAction(baseAction, sr.Level, &cfg.AdaptiveEnforcement)
if effectiveAction == config.ActionBlock {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
log.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), baseAction, effectiveAction, result.Scanner, clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade(baseAction, effectiveAction, session.EscalationLabel(sr.Level))
log.LogBlockedDetail(actx, result.Scanner, result.Reason+" (escalated)", auditDetailFromResult(result))
Expand Down Expand Up @@ -3309,10 +3294,7 @@ func (p *Proxy) handleFetch(w http.ResponseWriter, r *http.Request) {
// session is at an escalation level with block_all=true. UpgradeAction
// with an empty base action returns "block" only when block_all is set.
if sr.Level > 0 && decide.UpgradeAction("", sr.Level, &cfg.AdaptiveEnforcement) == config.ActionBlock {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
log.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sr.Level), "", config.ActionBlock, "session_deny", clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade("", config.ActionBlock, session.EscalationLabel(sr.Level))
p.emitReceipt(receipt.EmitOpts{
Expand Down Expand Up @@ -4252,10 +4234,7 @@ func (p *Proxy) filterAndActOnResponseScan(
action = decide.UpgradeAction(action, sessionLevel, &cfg.AdaptiveEnforcement)
}
if action != originalAction {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
log.LogAdaptiveUpgrade(sessionKey, session.EscalationLabel(sessionLevel), originalAction, action, "response_scan", clientIP, requestID)
p.metrics.RecordAdaptiveUpgrade(originalAction, action, session.EscalationLabel(sessionLevel))
}
Expand All @@ -4268,10 +4247,7 @@ func (p *Proxy) filterAndActOnResponseScan(
return
}
if sm := p.sessionMgrPtr.Load(); sm != nil && cfg.AdaptiveEnforcement.Enabled {
sessionKey := clientIP
if agent != "" && agent != agentAnonymous {
sessionKey = agent + "|" + clientIP
}
sessionKey := sessionKeyFor(agent, clientIP)
sess := sm.GetOrCreate(sessionKey)
decide.RecordSignal(sess, sig, decide.EscalationParams{
Threshold: cfg.AdaptiveEnforcement.EscalationThreshold,
Expand Down
5 changes: 1 addition & 4 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5104,10 +5104,7 @@ func TestProxy_RegisterAndShutdownAgentServers(t *testing.T) {
// recording enough block signals to cross the threshold repeatedly.
// Returns the session key used.
func escalateSession(sm *SessionManager, clientIP, agent string, threshold float64, targetLevel int) string {
key := clientIP
if agent != "" && agent != agentAnonymous {
key = agent + "|" + clientIP
}
key := sessionKeyFor(agent, clientIP)
sess := sm.GetOrCreate(key)
// Each escalation doubles the threshold. We need to accumulate enough
// points to cross the threshold 'targetLevel' times.
Expand Down
5 changes: 1 addition & 4 deletions internal/proxy/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1303,10 +1303,7 @@ func (sm *SessionManager) AdaptiveStatus() AdaptiveStatus {
}

func (sm *SessionManager) AdaptiveWhoami(clientIP, agent string) AdaptiveWhoami {
key := clientIP
if agent != "" && agent != agentAnonymous {
key = agent + "|" + clientIP
}
key := sessionKeyFor(agent, clientIP)
out := AdaptiveWhoami{
ClientIP: clientIP,
Agent: agent,
Expand Down
17 changes: 17 additions & 0 deletions internal/proxy/sessionkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package proxy

// sessionKeyFor builds the per-session key used for adaptive-enforcement
// tracking and audit correlation. A named agent is namespaced ahead of its
// client IP so that two agents sharing one client IP are tracked as distinct
// sessions. An unnamed or anonymous agent keys on the client IP alone.
//
// This is the single source of truth for session-key construction. Every
// transport (fetch, forward, CONNECT, WebSocket, TLS intercept) must build
// the key the same way, otherwise adaptive escalation and de-escalation would
// track different keys for the same logical session.
func sessionKeyFor(agent, clientIP string) string {
if agent == "" || agent == agentAnonymous {
return clientIP
}
return agent + "|" + clientIP
}
56 changes: 56 additions & 0 deletions internal/proxy/sessionkey_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package proxy

import "testing"

func TestSessionKeyFor(t *testing.T) {
tests := []struct {
name string
agent string
clientIP string
want string
}{
{
name: "named agent namespaces ahead of ip",
agent: "agent-a",
clientIP: "10.0.0.1",
want: "agent-a|10.0.0.1",
},
{
name: "empty agent keys on ip alone",
agent: "",
clientIP: "10.0.0.1",
want: "10.0.0.1",
},
{
name: "anonymous agent keys on ip alone",
agent: agentAnonymous,
clientIP: "10.0.0.1",
want: "10.0.0.1",
},
{
name: "two named agents on same ip stay distinct",
agent: "agent-b",
clientIP: "10.0.0.1",
want: "agent-b|10.0.0.1",
},
{
name: "named agent with empty ip",
agent: "agent-a",
clientIP: "",
want: "agent-a|",
},
{
name: "empty agent and empty ip",
agent: "",
clientIP: "",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := sessionKeyFor(tt.agent, tt.clientIP); got != tt.want {
t.Errorf("sessionKeyFor(%q, %q) = %q, want %q", tt.agent, tt.clientIP, got, tt.want)
}
})
}
}
Loading
Loading