Skip to content

Commit

Permalink
Add support for System.Text.Json converter
Browse files Browse the repository at this point in the history
This adds native support for asp.net core minimal API parameter binding and output rendering.

We simplify the template-based approach since it's highly reusable as of now: both parsable and STJ generation is identical except for the text templates themselves.
  • Loading branch information
kzu committed Nov 24, 2024
1 parent 3f16349 commit 981d2fd
Show file tree
Hide file tree
Showing 14 changed files with 215 additions and 109 deletions.
7 changes: 5 additions & 2 deletions src/Sample/Program.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using Microsoft.AspNetCore.Mvc;
using Sample;

var builder = WebApplication.CreateBuilder(args);
var app = builder.Build();

app.MapGet("/{id}", (UserId id) => id);
app.MapGet("/{id}", (UserId id) => new User(id, "kzu"));

app.Run();

readonly partial record struct UserId : IStructId;
readonly partial record struct UserId : IStructId<int>;

record User(UserId id, string Alias);
2 changes: 1 addition & 1 deletion src/Sample/Sample.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="StructId" Version="42.253.1273" />
<PackageReference Include="StructId" Version="42.254.378" />
</ItemGroup>

</Project>
9 changes: 9 additions & 0 deletions src/StructId.Analyzer/JsonConverterGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Microsoft.CodeAnalysis;

namespace StructId;

[Generator(LanguageNames.CSharp)]
public class JsonConverterGenerator() : TemplateGenerator(
"System.IParsable`1",
ThisAssembly.Resources.Templates.SJsonConverter.Text,
ThisAssembly.Resources.Templates.TJsonConverter.Text);
94 changes: 4 additions & 90 deletions src/StructId.Analyzer/ParsableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,93 +9,7 @@
namespace StructId;

[Generator(LanguageNames.CSharp)]
public class ParsableGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Locate the IParseable<T> type
var parseable = context.CompilationProvider
.Select((x, _) => x.GetTypeByMetadataName("System.IParsable`1"));

var ids = context.CompilationProvider
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>())
.Where(x => x.IsStructId())
.Where(x => x.IsPartial());

var combined = ids.Combine(parseable)
.Where(x =>
{
var (id, parseable) = x;

// NOTE: we never generate for compilations that don't have the IParsable<T> type (i.e. .NET6)
if (parseable == null)
return false;

var type = id.AllInterfaces
.First(x => x.Name == "IStructId")
.TypeArguments.FirstOrDefault();

// If we don't have a generic type of IStructId, then it's the string-based one
// which we can always parse
if (type == null)
return true;

return type.Is(parseable);
})
.Select((x, _) => x.Left);

context.RegisterImplementationSourceOutput(combined, GenerateCode);
}

void GenerateCode(SourceProductionContext context, INamedTypeSymbol symbol)
{
var ns = symbol.ContainingNamespace.Equals(symbol.ContainingModule.GlobalNamespace, SymbolEqualityComparer.Default)
? null
: symbol.ContainingNamespace.ToDisplayString();

// Generic IStructId<T> -> T, otherwise string
var type = symbol.AllInterfaces.First(x => x.Name == "IStructId").TypeArguments.
Select(x => x.GetTypeName(ns)).FirstOrDefault() ?? "string";

var template = type == "string"
? ThisAssembly.Resources.Templates.SParseable.Text
: ThisAssembly.Resources.Templates.TParseable.Text;

// parse template into a C# compilation unit
var parseable = CSharpSyntaxTree.ParseText(template).GetCompilationUnitRoot();

// if we got a ns, move all members after a file-scoped namespace declaration
if (ns != null)
{
var members = parseable.Members;
var fsns = FileScopedNamespaceDeclaration(ParseName(ns).WithLeadingTrivia(Whitespace(" ")))
.WithLeadingTrivia(LineFeed)
.WithTrailingTrivia(LineFeed)
.WithMembers(members);
parseable = parseable.WithMembers(SingletonList<MemberDeclarationSyntax>(fsns));
}

// replace all nodes with the identifier TStruct/SStruct with symbol.Name
var structIds = parseable.DescendantNodes()
.OfType<IdentifierNameSyntax>()
.Where(x => x.Identifier.Text == "TStruct" || x.Identifier.Text == "SStruct");
parseable = parseable.ReplaceNodes(structIds, (node, _) => IdentifierName(symbol.Name)
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(node.GetTrailingTrivia()));

var structTokens = parseable.DescendantTokens()
.OfType<SyntaxToken>()
.Where(x => x.IsKind(SyntaxKind.IdentifierToken))
.Where(x => x.Text == "TStruct" || x.Text == "SStruct");
// replace with a new identifier with symbol.name
parseable = parseable.ReplaceTokens(structTokens, (token, _) => Identifier(symbol.Name)
.WithLeadingTrivia(token.LeadingTrivia)
.WithTrailingTrivia(token.TrailingTrivia));

// replace all nodes with the identifier TValue with actual type
var placeholder = parseable.DescendantNodes().OfType<IdentifierNameSyntax>().Where(x => x.Identifier.Text == "TValue");
parseable = parseable.ReplaceNodes(placeholder, (_, _) => IdentifierName(type));

context.AddSource($"{symbol.ToFileName()}.parsable.cs", SourceText.From(parseable.ToFullString(), Encoding.UTF8));
}
}
public class ParsableGenerator() : TemplateGenerator(
"System.IParsable`1",
ThisAssembly.Resources.Templates.SParsable.Text,
ThisAssembly.Resources.Templates.TParsable.Text);
87 changes: 87 additions & 0 deletions src/StructId.Analyzer/TemplateGenerator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using System;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace StructId;

public abstract class TemplateGenerator(string valueInterface, string stringTemplate, string typeTemplate) : IIncrementalGenerator
{
record struct TemplateArgs(string TargetNamespace, INamedTypeSymbol StructId, INamedTypeSymbol ValueType, INamedTypeSymbol InterfaceType, INamedTypeSymbol StringType);

public void Initialize(IncrementalGeneratorInitializationContext context)
{
var targetNamespace = context.AnalyzerConfigOptionsProvider
.Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId");

// Locate the required types
var types = context.CompilationProvider
.Select((x, _) => (InterfaceType: x.GetTypeByMetadataName(valueInterface), StringType: x.GetTypeByMetadataName("System.String")));

var ids = context.CompilationProvider
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>())
.Where(x => x.IsStructId())
.Where(x => x.IsPartial());

var combined = ids.Combine(types)
// NOTE: we never generate for compilations that don't have the specified value interface type
.Where(x => x.Right.InterfaceType != null || x.Right.StringType == null)
.Combine(targetNamespace)
.Select((x, _) =>
{
var ((structId, (interfaceType, stringType)), targetNamespace) = x;

// The value type is either a generic type argument for IStructId<T>, or the string type
// for the non-generic IStructId
var valueType = structId.AllInterfaces
.First(x => x.Name == "IStructId")
.TypeArguments.OfType<INamedTypeSymbol>().FirstOrDefault() ??
stringType!;

return new TemplateArgs(targetNamespace, structId, valueType, interfaceType!, stringType!);
})
.Where(x => x.ValueType.Is(x.InterfaceType));

context.RegisterImplementationSourceOutput(combined, GenerateCode);
}

void GenerateCode(SourceProductionContext context, TemplateArgs args)
{
var ns = args.StructId.ContainingNamespace.Equals(args.StructId.ContainingModule.GlobalNamespace, SymbolEqualityComparer.Default)
? null
: args.StructId.ContainingNamespace.ToDisplayString();

var template = args.ValueType.Equals(args.StringType, SymbolEqualityComparer.Default)
? stringTemplate : typeTemplate;

// replace tokens in the template
template = template
// Adjust to current target namespace
.Replace("namespace StructId;", $"namespace {args.TargetNamespace};")
.Replace("using StructId;", $"using {args.TargetNamespace};")
// Simple names suffices since we emit a partial in the same namespace
.Replace("TStruct", args.StructId.Name)
.Replace("SStruct", args.StructId.Name)
.Replace("TValue", args.ValueType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));

// parse template into a C# compilation unit
var parseable = CSharpSyntaxTree.ParseText(template).GetCompilationUnitRoot();

// if we got a ns, move all members after a file-scoped namespace declaration
if (ns != null)
{
var members = parseable.Members;
var fsns = FileScopedNamespaceDeclaration(ParseName(ns).WithLeadingTrivia(Whitespace(" ")))
.WithLeadingTrivia(LineFeed)
.WithTrailingTrivia(LineFeed)
.WithMembers(members);
parseable = parseable.WithMembers(SingletonList<MemberDeclarationSyntax>(fsns));
}

context.AddSource($"{args.StructId.ToFileName()}.cs", SourceText.From(parseable.ToFullString(), Encoding.UTF8));
}
}
1 change: 0 additions & 1 deletion src/StructId.Package/StructId.props
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

<ItemGroup>
<CompilerVisibleProperty Include="StructIdNamespace" />
<CompilerVisibleProperty Include="RootNamespace" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace StructId;

public class ParseableGeneratorTests
public class ParsableGeneratorTests
{
[Fact]
public async Task GenerateParseable()
Expand All @@ -26,8 +26,8 @@ public async Task GenerateParseable()
},
GeneratedSources =
{
(typeof(ParsableGenerator), @"UserId.parsable.cs",
ThisAssembly.Resources.StructId.Templates.TParseable.Text.Replace("TStruct", "UserId").Replace("TValue", "System.Int32"),
(typeof(ParsableGenerator), "UserId.cs",
ThisAssembly.Resources.StructId.Templates.TParsable.Text.Replace("TStruct", "UserId").Replace("TValue", "int"),
Encoding.UTF8)
},
},
Expand All @@ -54,8 +54,8 @@ public async Task GenerateStringParseable()
},
GeneratedSources =
{
(typeof(ParsableGenerator), @"UserId.parsable.cs",
ThisAssembly.Resources.StructId.Templates.SParseable.Text.Replace("SStruct", "UserId"),
(typeof(ParsableGenerator), "UserId.cs",
ThisAssembly.Resources.StructId.Templates.SParsable.Text.Replace("SStruct", "UserId"),
Encoding.UTF8)
},
},
Expand Down
64 changes: 64 additions & 0 deletions src/StructId/StructIdConverters.JsonConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// <auto-generated/>

using System;
using System.Globalization;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace StructId;

public static partial class StructIdConverters
{
#if NET7_0_OR_GREATER
public class SystemTextJsonConverter<TStruct, TValue> : JsonConverter<TStruct>
where TStruct : IStructId<TValue>, IParsable<TStruct>
where TValue: struct
{
public override TStruct Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
return TStruct.Parse(reader.GetString(), CultureInfo.InvariantCulture);
}

public override void Write(Utf8JsonWriter writer, TStruct value, JsonSerializerOptions options)
{
switch (value.Value)
{
case Guid guid:
writer.WriteStringValue(guid);
break;
case TValue inner:
writer.WriteRawValue(inner.ToString());
break;
default:
throw new InvalidOperationException("Unsupported value type.");
}
}

public override TStruct ReadAsPropertyName(ref global::System.Text.Json.Utf8JsonReader reader, global::System.Type typeToConvert, global::System.Text.Json.JsonSerializerOptions options)
=> TStruct.Parse(reader.GetString() ?? throw new FormatException("Unsupported null value for struct id."), CultureInfo.InvariantCulture);

public override void WriteAsPropertyName(global::System.Text.Json.Utf8JsonWriter writer, TStruct value, global::System.Text.Json.JsonSerializerOptions options)
=> writer.WritePropertyName(value.Value.ToString());
}

public class SystemTextJsonConverter<TStruct> : JsonConverter<TStruct>
where TStruct : IStructId, IParsable<TStruct>
{
public override TStruct Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
return TStruct.Parse(reader.GetString() ?? throw new FormatException("Unsupported null value for struct id."), CultureInfo.InvariantCulture);
}

public override void Write(Utf8JsonWriter writer, TStruct value, JsonSerializerOptions options)
{
writer.WriteStringValue(value.Value);
}

public override TStruct ReadAsPropertyName(ref global::System.Text.Json.Utf8JsonReader reader, global::System.Type typeToConvert, global::System.Text.Json.JsonSerializerOptions options)
=> TStruct.Parse(reader.GetString() ?? throw new FormatException("Unsupported null value for struct id."), CultureInfo.InvariantCulture);

public override void WriteAsPropertyName(global::System.Text.Json.Utf8JsonWriter writer, TStruct value, global::System.Text.Json.JsonSerializerOptions options)
=> writer.WritePropertyName(value.Value);
}
#endif
}
2 changes: 1 addition & 1 deletion src/StructId/StructIdConverters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace StructId;

/// <summary>
/// Type converter for <see cref="IStructId{T}"/> and <see cref="IStructId"/>.
/// Type converters for <see cref="IStructId{T}"/> and <see cref="IStructId"/>.
/// </summary>
public static partial class StructIdConverters
{
Expand Down
12 changes: 12 additions & 0 deletions src/StructId/Templates/SJsonConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// <auto-generated/>
#nullable enable

using System.Text.Json.Serialization;
using StructId;

#if NET7_0_OR_GREATER
[JsonConverter(typeof(StructIdConverters.SystemTextJsonConverter<SStruct>))]
#endif
readonly partial record struct SStruct
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

readonly partial record struct SStruct : IParsable<SStruct>
{
public static SStruct Parse(string s, IFormatProvider? provider) => new(s);
public static SStruct Parse(string s, IFormatProvider? provider)
=> s is null ? throw new ArgumentNullException(nameof(s)) : new(s);

public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out SStruct result)
{
if (s is not null)
Expand Down
12 changes: 12 additions & 0 deletions src/StructId/Templates/TJsonConverter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// <auto-generated/>
#nullable enable

using System.Text.Json.Serialization;
using StructId;

#if NET7_0_OR_GREATER
[JsonConverter(typeof(StructIdConverters.SystemTextJsonConverter<TStruct, TValue>))]
#endif
readonly partial record struct TStruct
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

readonly partial record struct TStruct : IParsable<TStruct>
{
public static TStruct Parse(string s, IFormatProvider? provider) => new(TValue.Parse(s));
public static TStruct Parse(string s, IFormatProvider? provider) => new(TValue.Parse(s, provider));

public static bool TryParse([NotNullWhen(true)] string? s, IFormatProvider? provider, [MaybeNullWhen(false)] out TStruct result)
{
if (TValue.TryParse(s, out var value))
if (TValue.TryParse(s, provider, out var value))
{
result = new TStruct(value);
return true;
Expand Down
Loading

0 comments on commit 981d2fd

Please sign in to comment.