Skip to content

Fix handling of record types in validations source generator #61402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterIn
file static class GeneratedServiceCollectionExtensions
{
{{addValidation.GetInterceptsLocationAttributeSyntax()}}
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<ValidationOptions>? configureOptions = null)
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<global::Microsoft.AspNetCore.Http.Validation.ValidationOptions>? configureOptions = null)
{
// Use non-extension method to avoid infinite recursion.
return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options =>
Expand All @@ -133,13 +133,39 @@ private sealed record CacheKey(global::System.Type ContainingType, string Proper
var key = new CacheKey(containingType, propertyName);
return _cache.GetOrAdd(key, static k =>
{
var results = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationAttribute>();

// Get attributes from the property
var property = k.ContainingType.GetProperty(k.PropertyName);
if (property == null)
if (property != null)
{
var propertyAttributes = global::System.Reflection.CustomAttributeExtensions
.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true);

results.AddRange(propertyAttributes);
}

// Check constructors for parameters that match the property name
// to handle record scenarios
foreach (var constructor in k.ContainingType.GetConstructors())
{
return [];
// Look for parameter with matching name (case insensitive)
var parameter = global::System.Linq.Enumerable.FirstOrDefault(
constructor.GetParameters(),
p => string.Equals(p.Name, k.PropertyName, global::System.StringComparison.OrdinalIgnoreCase));

if (parameter != null)
{
var paramAttributes = global::System.Reflection.CustomAttributeExtensions
.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(parameter, inherit: true);

results.AddRange(paramAttributes);

break;
}
}

return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true)];
return results.ToArray();
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
Expand Down Expand Up @@ -101,4 +102,24 @@ internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols require
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream)
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader);
}

internal static IPropertySymbol? FindPropertyIncludingBaseTypes(this INamedTypeSymbol typeSymbol, string propertyName)
{
var property = typeSymbol.GetMembers()
.OfType<IPropertySymbol>()
.FirstOrDefault(p => string.Equals(p.Name, propertyName, System.StringComparison.OrdinalIgnoreCase));

if (property != null)
{
return property;
}

// If not found, recursively search base types
if (typeSymbol.BaseType is INamedTypeSymbol baseType)
{
return FindPropertyIncludingBaseTypes(baseType, propertyName);
}

return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,74 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
{
var members = new List<ValidatableProperty>();
var resolvedRecordProperty = new List<IPropertySymbol>();

// Special handling for record types to extract properties from
// the primary constructor.
if (typeSymbol is INamedTypeSymbol { IsRecord: true } namedType)
{
// Find the primary constructor for the record, account
// for members that are in base types to account for
// record inheritance scenarios
var primaryConstructor = namedType.Constructors
.FirstOrDefault(c => c.Parameters.Length > 0 && c.Parameters.All(p =>
namedType.FindPropertyIncludingBaseTypes(p.Name) != null));

if (primaryConstructor != null)
{
// Process all parameters in constructor order to maintain parameter ordering
foreach (var parameter in primaryConstructor.Parameters)
{
// Find the corresponding property in this type, we ignore
// base types here since that will be handled by the inheritance
// checks in the default ValidatableTypeInfo implementation.
var correspondingProperty = typeSymbol.GetMembers()
.OfType<IPropertySymbol>()
.FirstOrDefault(p => string.Equals(p.Name, parameter.Name, System.StringComparison.OrdinalIgnoreCase));

if (correspondingProperty != null)
{
resolvedRecordProperty.Add(correspondingProperty);

// Check if the property's type is validatable, this resolves
// validatable types in the inheritance hierarchy
var hasValidatableType = TryExtractValidatableType(
correspondingProperty.Type.UnwrapType(requiredSymbols.IEnumerable),
requiredSymbols,
ref validatableTypes,
ref visitedTypes);

members.Add(new ValidatableProperty(
ContainingType: correspondingProperty.ContainingType,
Type: correspondingProperty.Type,
Name: correspondingProperty.Name,
DisplayName: parameter.GetDisplayName(requiredSymbols.DisplayAttribute) ??
correspondingProperty.GetDisplayName(requiredSymbols.DisplayAttribute),
Attributes: []));
}
}
}
}

// Handle properties for classes and any properties not handled by the constructor
foreach (var member in typeSymbol.GetMembers().OfType<IPropertySymbol>())
{
// Skip compiler generated properties and properties already processed via
// the record processing logic above.
if (member.IsImplicitlyDeclared || resolvedRecordProperty.Contains(member, SymbolEqualityComparer.Default))
{
continue;
}

var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired);

// If the member has no validation attributes or validatable types and is not required, skip it.
if (attributes.IsDefaultOrEmpty && !hasValidatableType && !isRequired)
{
continue;
}

members.Add(new ValidatableProperty(
ContainingType: member.ContainingType,
Type: member.Type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ public class ComplexType
public int IntegerWithRangeAndDisplayName { get; set; } = 50;
[Required]
public SubType PropertyWithMemberAttributes { get; set; } = new SubType();
public SubType PropertyWithMemberAttributes { get; set; } = new SubType("some-value", default);
public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType();
public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType("some-value", default);
public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance();
public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance("some-value", default);
public List<SubType> ListOfSubTypes { get; set; } = [];
Expand All @@ -62,16 +62,16 @@ public class DerivedValidationAttribute : ValidationAttribute
public override bool IsValid(object? value) => value is int number && number % 2 == 0;
}
public class SubType
public class SubType(string? requiredProperty, string? stringWithLength)
{
[Required]
public string RequiredProperty { get; set; } = "some-value";
public string RequiredProperty { get; } = requiredProperty;
[StringLength(10)]
public string? StringWithLength { get; set; }
public string? StringWithLength { get; } = stringWithLength;
}
public class SubTypeWithInheritance : SubType
public class SubTypeWithInheritance(string? requiredProperty, string? stringWithLength) : SubType(requiredProperty, stringWithLength)
{
[EmailAddress]
public string? EmailString { get; set; }
Expand Down
Loading
Loading