Skip to content

Allow to register non-essential wrapper types. #8134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions src/HotChocolate/Core/src/Types/Internal/ExtendedType.Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ namespace HotChocolate.Internal;

internal sealed partial class ExtendedType
{
internal static ImmutableArray<Type> NonEssentialWrapperTypes { get; set; } =
[typeof(ValueTask<>), typeof(Task<>), typeof(NativeType<>), typeof(Optional<>)];

private static class Helper
{
internal static bool IsSchemaType(Type type)
Expand All @@ -23,9 +26,8 @@ internal static bool IsSchemaType(Type type)
if (type.IsGenericType)
{
var definition = type.GetGenericTypeDefinition();
if (typeof(ListType<>) == definition
|| typeof(NonNullType<>) == definition
|| typeof(NativeType<>) == definition)
var nonEssentialWrapperTypes = NonEssentialWrapperTypes;
if (nonEssentialWrapperTypes.Contains(definition))
{
return IsSchemaType(type.GetGenericArguments()[0]);
}
Expand All @@ -36,25 +38,28 @@ internal static bool IsSchemaType(Type type)

internal static Type RemoveNonEssentialTypes(Type type)
{
if (type.IsGenericType
&& (type.GetGenericTypeDefinition() == typeof(NativeType<>)
|| type.GetGenericTypeDefinition() == typeof(ValueTask<>)
|| type.GetGenericTypeDefinition() == typeof(Task<>)))
if (type.IsGenericType)
{
return RemoveNonEssentialTypes(type.GetGenericArguments()[0]);
var definition = type.GetGenericTypeDefinition();
var nonEssentialWrapperTypes = NonEssentialWrapperTypes;
if (nonEssentialWrapperTypes.Contains(definition))
{
return RemoveNonEssentialTypes(type.GetGenericArguments()[0]);
}
}

return type;
}

internal static IExtendedType RemoveNonEssentialTypes(IExtendedType type)
{
if (type.IsGeneric
&& (type.Definition == typeof(NativeType<>)
|| type.Definition == typeof(ValueTask<>)
|| type.Definition == typeof(Task<>)))
if (type.IsGeneric)
{
return RemoveNonEssentialTypes(type.TypeArguments[0]);
var nonEssentialWrapperTypes = NonEssentialWrapperTypes;
if (nonEssentialWrapperTypes.Contains(type.Definition))
{
return RemoveNonEssentialTypes(type.TypeArguments[0]);
}
}

return type;
Expand Down Expand Up @@ -244,7 +249,7 @@ internal static ExtendedTypeId CreateIdentifier(IExtendedType type)
{
var position = 0;
Span<bool> nullability = stackalloc bool[32];
CollectNullability(type, nullability, ref position);
CollectNullability(type, ref nullability, ref position);

return CreateIdentifier(
type.Source,
Expand All @@ -258,7 +263,7 @@ internal static ExtendedTypeId CreateIdentifier(
{
var position = 0;
Span<bool> nullability = stackalloc bool[32];
CollectNullability(type, nullability, ref position);
CollectNullability(type, ref nullability, ref position);
nullability = nullability.Slice(0, position);

var length = nullability.Length < nullabilityChange.Length
Expand Down Expand Up @@ -293,7 +298,7 @@ private static ExtendedTypeId CreateIdentifier(

internal static void CollectNullability(
IExtendedType type,
Span<bool> nullability,
ref Span<bool> nullability,
ref int position)
{
if (position >= 32)
Expand All @@ -306,7 +311,7 @@ internal static void CollectNullability(

foreach (var typeArgument in type.TypeArguments)
{
CollectNullability(typeArgument, nullability, ref position);
CollectNullability(typeArgument, ref nullability, ref position);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ internal static IExtendedType ChangeNullability(
{
var length = 0;
Span<bool> buffer = stackalloc bool[32];
Helper.CollectNullability(type, buffer, ref length);
Helper.CollectNullability(type, ref buffer, ref length);
buffer = buffer.Slice(0, length);

var nullability = new bool?[buffer.Length];
Expand All @@ -128,7 +128,7 @@ internal static bool CollectNullability(
{
var length = 0;
Span<bool> buffer = stackalloc bool[32];
Helper.CollectNullability(type, buffer, ref length);
Helper.CollectNullability(type, ref buffer, ref length);
buffer = buffer.Slice(0, length);

if (nullability.Length < buffer.Length)
Expand Down
22 changes: 22 additions & 0 deletions src/HotChocolate/Core/src/Types/Internal/ExtendedType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -220,4 +220,26 @@ public static ExtendedMethodInfo FromMethod(MethodInfo method, TypeCache cache)

return Members.FromMethod(method, cache);
}

public static void RegisterNonEssentialWrapperTypes(Type type)
{
if (type is null)
{
throw new ArgumentNullException(nameof(type));
}

if(!type.IsGenericTypeDefinition)
{
throw new ArgumentException(
"The type must be a generic type definition.",
nameof(type));
}

if(NonEssentialWrapperTypes.Contains(type))
{
return;
}

NonEssentialWrapperTypes = NonEssentialWrapperTypes.Add(type);
}
}
1 change: 1 addition & 0 deletions src/HotChocolate/Core/src/Types/Internal/IExtendedType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public interface IExtendedType : IEquatable<IExtendedType>
/// <summary>
/// Defines that this type is a generic type.
/// </summary>
[MemberNotNullWhen(true, nameof(Definition))]
bool IsGeneric { get; }

/// <summary>
Expand Down
33 changes: 16 additions & 17 deletions src/HotChocolate/Core/src/Types/Internal/TypeInfo.RuntimeType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,7 @@ private static IExtendedType RemoveNonEssentialParts(IExtendedType type)
short i = 0;
var current = type;

while (IsWrapperType(current) ||
IsTaskType(current) ||
IsOptional(current) ||
IsFieldResult(current))
while (IsNonEssentialPart(current))
{
current = type.TypeArguments[0];

Expand All @@ -109,20 +106,22 @@ private static IExtendedType RemoveNonEssentialParts(IExtendedType type)
return current;
}

private static bool IsWrapperType(IExtendedType type) =>
type.IsGeneric &&
typeof(NativeType<>) == type.Definition;

private static bool IsTaskType(IExtendedType type) =>
type.IsGeneric &&
(typeof(Task<>) == type.Definition ||
typeof(ValueTask<>) == type.Definition);
public static bool IsNonEssentialPart(IExtendedType type)
{
if (type.IsGeneric)
{
if (ExtendedType.NonEssentialWrapperTypes.Contains(type.Definition))
{
return true;
}

private static bool IsOptional(IExtendedType type) =>
type.IsGeneric &&
typeof(Optional<>) == type.Definition;
if (typeof(IFieldResult).IsAssignableFrom(type))
{
return true;
}
}

private static bool IsFieldResult(IExtendedType type) =>
type.IsGeneric && typeof(IFieldResult).IsAssignableFrom(type);
return false;
}
}
}
Loading