Skip to content

Add header propagation functionality #143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using Microsoft.Agents.Builder;
using Microsoft.Agents.Client.Errors;
using Microsoft.Agents.Core;
using Microsoft.Agents.Core.HeaderPropagation;
using Microsoft.Agents.Core.Models;
using Microsoft.Agents.Core.Serialization;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -255,6 +256,8 @@ private async Task<HttpResponseMessage> SendRequest(IActivity activity, bool use

using var httpClient = _httpClientFactory.CreateClient(nameof(HttpAgentClient));

httpClient.AddHeaderPropagation();

// Add the auth header to the HTTP request.
if (!useAnonymous)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Net.Http;
using System.Threading.Tasks;
using System;
using Microsoft.Agents.Core.HeaderPropagation;

namespace Microsoft.Agents.Connector.RestClients
{
Expand All @@ -28,6 +29,7 @@ public async Task<HttpClient> GetHttpClientAsync()
}

httpClient.AddDefaultUserAgent();
httpClient.AddHeaderPropagation();

return httpClient;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Reflection;

namespace Microsoft.Agents.Core.HeaderPropagation;

/// <summary>
/// Attribute to load headers for header propagation.
/// This attribute should be applied to classes that implement the <see cref="IHeaderPropagationAttribute"/> interface.
/// </summary>
[AttributeUsage(AttributeTargets.Class)]
public class HeaderPropagationAttribute : Attribute
{
internal static void LoadHeaders()
{
// Init newly loaded assemblies
AppDomain.CurrentDomain.AssemblyLoad += (s, o) => LoadHeadersAssembly(o.LoadedAssembly);

// And all the ones we currently have loaded
foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
{
LoadHeadersAssembly(assembly);
}
}

private static void LoadHeadersAssembly(Assembly assembly)
{
foreach (var type in GetLoadHeadersTypes(assembly))
{
#if !NETSTANDARD
if (!typeof(IHeaderPropagationAttribute).IsAssignableFrom(type))
{
throw new InvalidOperationException(
$"Type '{type.FullName}' is marked with [HeaderPropagation] but does not implement IHeaderPropagationAttribute.");
}
#endif

var loadHeaders = type.GetMethod(nameof(LoadHeaders), BindingFlags.Static | BindingFlags.Public);

if (loadHeaders == null)
{
continue;
}

loadHeaders.Invoke(assembly, [HeaderPropagationContext.HeadersToPropagate]);
}
Comment on lines +48 to +49
Copy link
Preview

Copilot AI Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When invoking a static method via reflection, the instance parameter should be 'null' rather than 'assembly'. Use 'loadHeaders.Invoke(null, new object[]{ HeaderPropagationContext.HeadersToPropagate })'.

Suggested change
loadHeaders.Invoke(assembly, [HeaderPropagationContext.HeadersToPropagate]);
}
loadHeaders.Invoke(null, new object[] { HeaderPropagationContext.HeadersToPropagate });

Copilot uses AI. Check for mistakes.

}

private static IEnumerable<Type> GetLoadHeadersTypes(Assembly assembly)
{
foreach (Type type in assembly.GetTypes())
{
if (type.GetCustomAttributes(typeof(HeaderPropagationAttribute), true).Length > 0)
{
yield return type;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Agents.Core.HeaderPropagation;

/// <summary>
/// Shared context to manage request headers that will be used to propagate them in the <see cref="HeaderPropagationExtensions.AddHeaderPropagation"/>.
/// </summary>
public class HeaderPropagationContext()
{
private static readonly AsyncLocal<IDictionary<string, StringValues>> _headersFromRequest = new();
private static HeaderPropagationEntryCollection _headersToPropagate = new();

static HeaderPropagationContext()
{
HeaderPropagationAttribute.LoadHeaders();
}

/// <summary>
/// Gets or sets the request headers that will be propagated based on what's inside the <see cref="HeadersToPropagate"/> property.
/// </summary>
public static IDictionary<string, StringValues> HeadersFromRequest
{
get
{
return _headersFromRequest.Value;
}
set
{
// Create a copy to ensure headers are not modified by the original request.
#if !NETSTANDARD
var headers = value?.ToDictionary(StringComparer.InvariantCultureIgnoreCase);
#else
var headers = value?.ToDictionary(x => x.Key, x => x.Value, StringComparer.InvariantCultureIgnoreCase);
#endif
_headersFromRequest.Value = FilterHeaders(headers);
}
}

/// <summary>
/// Gets or sets the headers to allow during the propagation.
/// </summary>
public static HeaderPropagationEntryCollection HeadersToPropagate
{
get
{
return _headersToPropagate;
}
set
{
_headersToPropagate = value ?? new();
}
}

/// <summary>
/// Filters the request headers based on the keys provided in <see cref="HeadersToPropagate"/>.
/// </summary>
/// <param name="requestHeaders">Headers collection from an Http request.</param>
/// <returns>Filtered headers.</returns>
private static Dictionary<string, StringValues> FilterHeaders(Dictionary<string, StringValues> requestHeaders)
{
var result = new Dictionary<string, StringValues>();

if (requestHeaders == null || requestHeaders.Count == 0)
{
return result;
}

// Ensure the default headers are always set by overriding the LoadHeaders configuration.
_headersToPropagate.Propagate("x-ms-correlation-id");

foreach (var header in HeadersToPropagate.Entries)
{
var headerExists = requestHeaders.TryGetValue(header.Key, out var requestHeader);

switch (header.Action)
{
case HeaderPropagationEntryAction.Add:
#if !NETSTANDARD
result.TryAdd(header.Key, header.Value);
#else
result.Add(header.Key, header.Value);
#endif
break;
case HeaderPropagationEntryAction.Append when headerExists:
StringValues newValue = requestHeader.Concat(header.Value).ToArray();
result[header.Key] = newValue;
break;
case HeaderPropagationEntryAction.Propagate when headerExists:
result[header.Key] = requestHeader;
break;
case HeaderPropagationEntryAction.Override:
result[header.Key] = header.Value;
break;
}
}

return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.Extensions.Primitives;

namespace Microsoft.Agents.Core.HeaderPropagation;

/// <summary>
/// Represents a single header entry used for header propagation.
/// </summary>
public class HeaderPropagationEntry
{
/// <summary>
/// Key of the header entry.
/// </summary>
public string Key { get; set; } = string.Empty;

/// <summary>
/// Value of the header entry.
/// </summary>
public StringValues Value { get; set; } = new StringValues(string.Empty);

/// <summary>
/// Action of the header entry (Add, Append, etc.).
/// </summary>
public HeaderPropagationEntryAction Action;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.Primitives;

namespace Microsoft.Agents.Core.HeaderPropagation;

/// <summary>
/// Represents a collection of all the header entries that are to be propagated to the outgoing request.
/// </summary>
public class HeaderPropagationEntryCollection
{
private readonly Dictionary<string, HeaderPropagationEntry> _entries = [];

/// <summary>
/// Gets the collection of header entries to be propagated to the outgoing request.
/// </summary>
public List<HeaderPropagationEntry> Entries
{
get => [.. _entries.Select(x => x.Value)];
}

/// <summary>
/// Attempts to add a new header entry to the collection.
/// </summary>
/// <remarks>
/// If the key already exists in the incoming request headers collection, it will be ignored.
/// </remarks>
/// <param name="key">The key of the element to add.</param>
/// <param name="value">The value to add for the specified key.</param>
public void Add(string key, StringValues value)
{
_entries[key] = new HeaderPropagationEntry
{
Key = key,
Value = value,
Action = HeaderPropagationEntryAction.Add
};
}

/// <summary>
/// Appends a new header value to an existing key.
/// </summary>
/// <remarks>
/// If the key does not exist in the incoming request headers collection, it will be ignored.
/// </remarks>
/// <param name="key">The key of the element to add.</param>
/// <param name="value">The value to add for the specified key.</param>
public void Append(string key, StringValues value)
{
_entries[key] = new HeaderPropagationEntry
{
Key = key,
Value = value,
Action = HeaderPropagationEntryAction.Append
};
}

/// <summary>
/// Propagates the incoming request header value to the outgoing request.
/// </summary>
/// <remarks>
/// If the key does not exist in the incoming request headers collection, it will be ignored.
/// </remarks>
/// <param name="key">The key of the element to add.</param>
public void Propagate(string key)
{
_entries[key] = new HeaderPropagationEntry
{
Key = key,
Action = HeaderPropagationEntryAction.Propagate
};
}

/// <summary>
/// Overrides the header value of an existing key.
/// </summary>
/// <remarks>
/// If the key does not exist in the incoming request headers collection, it will add it.
/// </remarks>
/// <param name="key">The key of the element to add.</param>
/// <param name="value">The value to add for the specified key.</param>
public void Override(string key, StringValues value)
{
_entries[key] = new HeaderPropagationEntry
{
Key = key,
Value = value,
Action = HeaderPropagationEntryAction.Override
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Runtime.Serialization;

namespace Microsoft.Agents.Core.HeaderPropagation;

/// <summary>
/// Represents the action of the header entry.
/// </summary>
public enum HeaderPropagationEntryAction
{
/// <summary>
/// Adds a new header entry to the outgoing request.
/// </summary>
[EnumMember(Value = "add")]
Add,

/// <summary>
/// Appends a new header value to an existing key in the outgoing request.
/// </summary>
[EnumMember(Value = "append")]
Append,

/// <summary>
/// Propagates the header entry from the incoming request to the outgoing request.
/// </summary>
[EnumMember(Value = "propagate")]
Propagate,

/// <summary>
/// Overrides an existing header entry in the outgoing request.
/// </summary>
[EnumMember(Value = "override")]
Override
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Net.Http;

namespace Microsoft.Agents.Core.HeaderPropagation;

public static class HeaderPropagationExtensions
{
/// <summary>
/// Loads incoming request headers based on a list of headers to propagate into the HttpClient.
/// </summary>
/// <param name="httpClient">The <see cref="HttpClient"/>.</param>
public static void AddHeaderPropagation(this HttpClient httpClient)
{
if (HeaderPropagationContext.HeadersFromRequest == null)
{
return;
}

foreach (var header in HeaderPropagationContext.HeadersFromRequest)
{
httpClient.DefaultRequestHeaders.TryAddWithoutValidation(header.Key, [header.Value]);
Copy link
Preview

Copilot AI Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing '[header.Value]' wraps header.Value in an array, which may not match the expected type. It is likely more appropriate to pass 'header.Value' directly.

Suggested change
httpClient.DefaultRequestHeaders.TryAddWithoutValidation(header.Key, [header.Value]);
httpClient.DefaultRequestHeaders.TryAddWithoutValidation(header.Key, header.Value);

Copilot uses AI. Check for mistakes.

}
}
}
Loading
Loading