Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/kagent-dev/kagent/go/adk/pkg/sts"
"github.com/kagent-dev/kagent/go/adk/pkg/tools"
"github.com/kagent-dev/kagent/go/api/adk"
"github.com/kagent-dev/kagent/go/core/pkg/env"
"google.golang.org/adk/agent"
"google.golang.org/adk/agent/llmagent"
adkmodel "google.golang.org/adk/model"
Expand Down Expand Up @@ -50,12 +51,16 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
return nil, nil, fmt.Errorf("agent config is required")
}

propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true"
propagateToken := env.KagentPropagateToken.Get()
tokenPrecedence := mcp.StaticTokenWins
if env.KagentPropagateTokenOverridesStatic.Get() {
tokenPrecedence = mcp.ForwardedTokenWins
}
var dynamicHeaderProvider mcp.DynamicHeaderProvider
if stsPlugin != nil {
dynamicHeaderProvider = stsPlugin.HeaderProvider
}
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, dynamicHeaderProvider)
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, tokenPrecedence, dynamicHeaderProvider)
subagentSessionIDs := make(map[string]string)

var remoteAgentTools []tool.Tool
Expand Down
6 changes: 6 additions & 0 deletions go/adk/pkg/constants/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,10 @@ const (
// A2A call context's NewRequestMeta normalizes header names to lowercase.
// This is why we use "authorization" instead of "Authorization".
AuthorizationHeader = "authorization"

// ActorTokenHeader carries the agent's own workload token alongside a
// forwarded end-user Authorization, so a downstream gateway can run an
// RFC 8693 delegation (subject=user, actor=agent). It is set on the
// outgoing request, so it uses the canonical header form.
ActorTokenHeader = "X-Actor-Token"
)
101 changes: 85 additions & 16 deletions go/adk/pkg/mcp/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/a2aproject/a2a-go/a2asrv"
Expand All @@ -23,6 +24,22 @@ import (
// This is used for dynamic token injection (e.g., STS tokens) per session.
type DynamicHeaderProvider func(ctx context.Context) map[string]string

// TokenPrecedence selects how a static Authorization configured on an MCP
// server relates to a forwarded or STS-exchanged Authorization.
type TokenPrecedence int

const (
// StaticTokenWins keeps a static Authorization at the highest precedence: it
// overrides any forwarded or STS-exchanged Authorization. This is the default.
StaticTokenWins TokenPrecedence = iota

// ForwardedTokenWins lets a forwarded or STS-exchanged Authorization win over
// a static Authorization. The displaced static token is sent as the actor
// token (X-Actor-Token) so a downstream gateway can run an RFC 8693
// delegation with subject=end user and actor=agent.
ForwardedTokenWins
)

const (
// Default timeout matching Python KAGENT_REMOTE_AGENT_TIMEOUT
defaultTimeout = 30 * time.Minute
Expand Down Expand Up @@ -69,6 +86,7 @@ type mcpServerParams struct {
Headers map[string]string
AllowedHeaders []string // header names to forward from incoming request
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
TokenPrecedence TokenPrecedence // how a static Authorization relates to a forwarded/STS Authorization
HeaderProvider DynamicHeaderProvider // optional per-request headers derived from invocation context (e.g., STS exchanged access tokens)
ServerType string // "http" or "sse"
Timeout *float64
Expand All @@ -86,13 +104,17 @@ type mcpServerParams struct {
// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin
// behaviour triggered by KAGENT_PROPAGATE_TOKEN.
//
// tokenPrecedence is a runtime-global policy (KAGENT_PROPAGATE_TOKEN_OVERRIDES_STATIC)
// applied uniformly to every server here; see TokenPrecedence and applyStaticHeaders.
//
// Optional headerProvider can be used to inject per-request headers
// derived from invocation context (e.g., STS exchanged access tokens).
func CreateToolsets(
ctx context.Context,
httpTools []adk.HttpMcpServerConfig,
sseTools []adk.SseMcpServerConfig,
propagateToken bool,
tokenPrecedence TokenPrecedence,
headerProvider DynamicHeaderProvider,
) []tool.Toolset {
log := logr.FromContextOrDiscard(ctx)
Expand All @@ -105,6 +127,7 @@ func CreateToolsets(
Headers: httpTool.Params.Headers,
AllowedHeaders: httpTool.AllowedHeaders,
PropagateToken: propagateToken,
TokenPrecedence: tokenPrecedence,
HeaderProvider: headerProvider,
ServerType: "http",
Timeout: httpTool.Params.Timeout,
Expand All @@ -127,6 +150,7 @@ func CreateToolsets(
Headers: sseTool.Params.Headers,
AllowedHeaders: sseTool.AllowedHeaders,
PropagateToken: propagateToken,
TokenPrecedence: tokenPrecedence,
HeaderProvider: headerProvider,
ServerType: "sse",
Timeout: sseTool.Params.Timeout,
Expand Down Expand Up @@ -224,14 +248,20 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
baseTransport.TLSClientConfig = tlsConfig
}

if params.TokenPrecedence == ForwardedTokenWins &&
params.TLSInsecureSkipVerify != nil && *params.TLSInsecureSkipVerify {
log.Info("WARNING: ForwardedTokenWins sends the static M2M credential as X-Actor-Token, but TLS verification is disabled for this MCP server - the actor token can leak to an unverified endpoint", "url", params.URL)
}

var httpTransport http.RoundTripper = baseTransport
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken || params.HeaderProvider != nil {
httpTransport = &headerRoundTripper{
base: baseTransport,
headers: params.Headers,
allowedHeaders: params.AllowedHeaders,
propagateToken: params.PropagateToken,
headerProvider: params.HeaderProvider,
base: baseTransport,
headers: params.Headers,
allowedHeaders: params.AllowedHeaders,
propagateToken: params.PropagateToken,
tokenPrecedence: params.TokenPrecedence,
headerProvider: params.HeaderProvider,
}
}

Expand All @@ -257,20 +287,22 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
}

// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
// requests. It supports four sources of headers, applied in this order so that
// higher-priority sources win on collision:
// requests. Header sources are applied lowest to highest precedence:
// 1. propagateToken: when true, Authorization is read from the incoming A2A
// CallContext and forwarded unconditionally (independent of allowedHeaders).
// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext.
// 3. headerProvider: runtime headers derived from ADK context, such as STS tokens.
// 4. headers: static key/value pairs configured on the MCP server spec (highest
// priority — always wins).
// 4. headers: static key/value pairs configured on the MCP server spec.
//
// Static headers (4) have the highest precedence; the one exception is the
// Authorization header under ForwardedTokenWins, resolved in applyStaticHeaders.
type headerRoundTripper struct {
base http.RoundTripper
headers map[string]string
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
propagateToken bool // when true, Authorization is forwarded independently
headerProvider DynamicHeaderProvider
base http.RoundTripper
headers map[string]string
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
propagateToken bool // when true, Authorization is forwarded independently
tokenPrecedence TokenPrecedence // resolves static vs forwarded Authorization
headerProvider DynamicHeaderProvider
}

func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
Expand Down Expand Up @@ -300,12 +332,49 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
}
}

// Apply static headers last — they take precedence over all dynamic sources.
rt.applyStaticHeaders(req)

return rt.base.RoundTrip(req)
}

// applyStaticHeaders writes the static headers configured on the MCP server spec
// onto req. Non-Authorization headers always overwrite forwarded values. The
// Authorization header honours tokenPrecedence: StaticTokenWins overwrites any
// forwarded token, while ForwardedTokenWins keeps a forwarded/STS token and, when
// it differs from the static one, carries the displaced static token as the actor
// (X-Actor-Token) for a downstream RFC 8693 delegation. With no forwarded token
// the static Authorization is applied and no actor is added; a forwarded token
// equal to the static one is treated as M2M (no actor); an actor token already
// forwarded via allowedHeaders is left untouched.
func (rt *headerRoundTripper) applyStaticHeaders(req *http.Request) {
// headers is assumed to hold at most one Authorization key; with case-variant
// duplicates map iteration order decides which wins.
staticAuthorization := ""
for key, value := range rt.headers {
if strings.EqualFold(key, constants.AuthorizationHeader) {
staticAuthorization = value
continue
}
req.Header.Set(key, value)
}
Comment on lines +350 to 359

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

headers are sourced from YAML/JSON config; duplicate case-variant Authorization keys cannot arise.


return rt.base.RoundTrip(req)
if staticAuthorization == "" {
return
}

if rt.tokenPrecedence == StaticTokenWins {
req.Header.Set(constants.AuthorizationHeader, staticAuthorization)
return
}

forwardedAuthorization := req.Header.Get(constants.AuthorizationHeader)
if forwardedAuthorization == "" {
req.Header.Set(constants.AuthorizationHeader, staticAuthorization)
return
}
if forwardedAuthorization != staticAuthorization && req.Header.Get(constants.ActorTokenHeader) == "" {
req.Header.Set(constants.ActorTokenHeader, staticAuthorization)
}
}

// initializeToolSet fetches tools from an MCP server using Google ADK's mcptoolset.
Expand Down
Loading
Loading