diff --git a/docs/readme.md b/docs/readme.md index 6ee2f2fb..9182b72b 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -9,6 +9,7 @@ Index: - [SQL Syntax](/sqlsyntax) - [Generated Code](/generatedcode) - [Bulk Copy](/bulkcopy) +- [Type Handlers](/typehandlers) - [Frequently Asked Questions](/faq) Packages: diff --git a/docs/rules/DAP048.md b/docs/rules/DAP048.md new file mode 100644 index 00000000..6e29c162 --- /dev/null +++ b/docs/rules/DAP048.md @@ -0,0 +1,20 @@ +# DAP048 + +Duplicate classes have been registered as type handlers for the same type, +meaning it's not possible to determine which to use when handling the type. +Note type handlers can be registered at the assembly and module level, so +ensure the type used for the `TValue` parameter in the attribute is only +specified once. + +Error: + +``` c# +[module: TypeHandler] +[module: TypeHandler] +``` + +Good: + +``` c# +[module: TypeHandler] +``` diff --git a/docs/typehandlers.md b/docs/typehandlers.md new file mode 100644 index 00000000..0489f2f0 --- /dev/null +++ b/docs/typehandlers.md @@ -0,0 +1,21 @@ +# Type Handlers + +At times you might want to customise how a type is read from a query or how it +is saved in a parameter. In Dapper you might use a `SqlMapper.TypeHandler` for +this, which has a slightly altered interface in the AOT version and a different +way of registering them. + +To register your own type handler, use either an assembly or module level +attribute to specify the mapping (you can replace `module` with `assembly` +below, it has the same effect): + +``` csharp +using Dapper; + +[module: TypeHandler] +``` + +Your type handler must inherit from `Dapper.TypeHandler` and be default +constructable. The methods are virtual, so you can override only which ones you +need (e.g. if you're just interested in reading your values and not using them +as parameters, you only need to override the `Read` method). \ No newline at end of file diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs index cc07fa11..78055ced 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Diagnostics.cs @@ -12,6 +12,8 @@ internal static readonly DiagnosticDescriptor LanguageVersionTooLow = LibraryWarning("DAP004", "Language version too low", "Interceptors require at least C# version 11"), CommandPropertyNotFound = LibraryWarning("DAP033", "Command property not found", "Command property {0}.{1} was not found or was not valid; attribute will be ignored"), - CommandPropertyReserved = LibraryWarning("DAP034", "Command property reserved", "Command property {1} is reserved for internal usage; attribute will be ignored"); + CommandPropertyReserved = LibraryWarning("DAP034", "Command property reserved", "Command property {1} is reserved for internal usage; attribute will be ignored"), + + DuplicateTypeHandlers = LibraryError("DAP048", "Duplicate type handlers", "Type {0} has multiple type handlers registered"); } } diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs index 83fd1be8..62888b11 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs @@ -218,7 +218,8 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm { try { - Generate(new(ctx, state)); + var typeHandlers = IdentifyTypeHandlers(ctx, state.Compilation); + Generate(new(ctx, state.Compilation, state.Nodes, typeHandlers)); } catch (Exception ex) { @@ -490,11 +491,11 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory else { sb.Append("public override void AddParameters(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args)").Indent().NewLine(); - WriteArgs(type, sb, WriteArgsMode.Add, map, ref flags); + WriteArgs(ctx, type, sb, WriteArgsMode.Add, map, ref flags); sb.Outdent().NewLine(); sb.Append("public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args)").Indent().NewLine(); - WriteArgs(type, sb, WriteArgsMode.Update, map, ref flags); + WriteArgs(ctx, type, sb, WriteArgsMode.Update, map, ref flags); sb.Outdent().NewLine(); if ((flags & (WriteArgsFlags.NeedsRowCount | WriteArgsFlags.NeedsPostProcess)) != 0) @@ -507,11 +508,11 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory sb.Append("public override void PostProcess(in global::Dapper.UnifiedCommand cmd, ").Append(declaredType).Append(" args, int rowCount)").Indent().NewLine(); if ((flags & WriteArgsFlags.NeedsPostProcess) != 0) { - WriteArgs(type, sb, WriteArgsMode.PostProcess, map, ref flags); + WriteArgs(ctx, type, sb, WriteArgsMode.PostProcess, map, ref flags); } if ((flags & WriteArgsFlags.NeedsRowCount) != 0) { - WriteArgs(type, sb, WriteArgsMode.SetRowCount, map, ref flags); + WriteArgs(ctx, type, sb, WriteArgsMode.SetRowCount, map, ref flags); } if (baseFactory != DapperBaseCommandFactory) { @@ -524,7 +525,7 @@ private static void WriteCommandFactory(in GenerateState ctx, string baseFactory { sb.Append("public override global::System.Threading.CancellationToken GetCancellationToken(").Append(declaredType).Append(" args)") .Indent().NewLine(); - WriteArgs(type, sb, WriteArgsMode.GetCancellationToken, map, ref flags); + WriteArgs(ctx, type, sb, WriteArgsMode.GetCancellationToken, map, ref flags); sb.Outdent().NewLine(); } } @@ -702,7 +703,7 @@ static bool IsReserved(string name) } } - private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITypeSymbol type, int index) + private static void WriteRowFactory(in GenerateState ctx, CodeWriter sb, ITypeSymbol type, int index) { var map = MemberMap.CreateForResults(type); if (map is null) return; @@ -723,6 +724,7 @@ private static void WriteRowFactory(in GenerateState context, CodeWriter sb, ITy var hasGetOnlyMembers = members.Any(member => member is { IsGettable: true, IsSettable: false, IsInitOnly: false }); var useConstructorDeferred = map.Constructor is not null; var useFactoryMethodDeferred = map.FactoryMethod is not null; + var typeHandlers = ctx.TypeHandlers; // Prevent ctx getting captured // Implementation detail: // constructor takes advantage over factory method. @@ -756,18 +758,29 @@ void WriteTokenizeMethod() .Append("var type = reader.GetFieldType(columnOffset);").NewLine() .Append("switch (NormalizedHash(name))").Indent().NewLine(); - int token = 0; + int firstToken = 0; + int secondToken = map.Members.Length; foreach (var member in members) { var dbName = member.DbName; sb.Append("case ").Append(StringHashing.NormalizedHash(dbName)) .Append(" when NormalizedEquals(name, ") - .AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine() - .Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(token) - .Append(" : ").Append(token + map.Members.Length).Append(";") - .Append(token == 0 ? " // two tokens for right-typed and type-flexible" : "").NewLine() + .AppendVerbatimLiteral(StringHashing.Normalize(dbName)).Append("):").Indent(false).NewLine(); + + if (typeHandlers.TryGetValue(member.CodeType, out var typeHandler)) + { + sb.Append("token = ").Append(firstToken).Append(";"); + } + else + { + sb.Append("token = type == typeof(").Append(Inspection.MakeNonNullable(member.CodeType)).Append(") ? ").Append(firstToken) + .Append(" : ").Append(secondToken).Append(";"); + secondToken++; + } + + sb.Append(firstToken == 0 ? " // two tokens for right-typed and type-flexible" : "").NewLine() .Append("break;").Outdent(false).NewLine(); - token++; + firstToken++; } sb.Outdent().NewLine() .Append("tokens[i] = token;").NewLine() @@ -825,45 +838,55 @@ void WriteReadMethod() sb.Append("foreach (var token in tokens)").Indent().NewLine() .Append("switch (token)").Indent().NewLine(); - token = 0; + int firstToken = 0; + int secondToken = members.Length; foreach (var member in members) { var memberType = member.CodeType; member.GetDbType(out var readerMethod); var nullCheck = Inspection.CouldBeNullable(memberType) ? $"reader.IsDBNull(columnOffset) ? ({CodeWriter.GetTypeName(memberType.WithNullableAnnotation(NullableAnnotation.Annotated))})null : " : ""; - sb.Append("case ").Append(token).Append(":").NewLine().Indent(false); + sb.Append("case ").Append(firstToken).Append(":").NewLine().Indent(false); // write `result.X = ` or `member0 = ` - if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token); + if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(firstToken); else sb.Append("result.").Append(member.CodeName); sb.Append(" = "); sb.Append(nullCheck); - if (readerMethod is null) + if (typeHandlers.TryGetValue(memberType, out var handler)) { - sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);"); + sb.Append("new ").Append(handler).Append("().Read(reader, columnOffset);").NewLine() + .Append("break;").NewLine().Outdent(false); } else { - sb.Append("reader.").Append(readerMethod).Append("(columnOffset);"); - } + if (readerMethod is null) + { + sb.Append("reader.GetFieldValue<").Append(memberType).Append(">(columnOffset);"); + } + else + { + sb.Append("reader.").Append(readerMethod).Append("(columnOffset);"); + } + sb.NewLine().Append("break;").NewLine().Outdent(false) + .Append("case ").Append(secondToken).Append(":").NewLine().Indent(false); - sb.NewLine().Append("break;").NewLine().Outdent(false) - .Append("case ").Append(token + map.Members.Length).Append(":").NewLine().Indent(false); + // write `result.X = ` or `member0 = ` + if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(firstToken); + else sb.Append("result.").Append(member.CodeName); - // write `result.X = ` or `member0 = ` - if (useDeferredConstruction) sb.Append(DeferredConstructionVariableName).Append(token); - else sb.Append("result.").Append(member.CodeName); + sb.Append(" = ") + .Append(nullCheck) + .Append("GetValue<") + .Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine() + .Append("break;").NewLine().Outdent(false); - sb.Append(" = ") - .Append(nullCheck) - .Append("GetValue<") - .Append(Inspection.MakeNonNullable(memberType)).Append(">(reader, columnOffset);").NewLine() - .Append("break;").NewLine().Outdent(false); + secondToken++; + } - token++; + firstToken++; } sb.Outdent().NewLine().Append("columnOffset++;").NewLine().Outdent().NewLine(); @@ -966,7 +989,7 @@ enum WriteArgsMode GetCancellationToken } - private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteArgsMode mode, string map, ref WriteArgsFlags flags) + private static void WriteArgs(in GenerateState ctx, ITypeSymbol? parameterType, CodeWriter sb, WriteArgsMode mode, string map, ref WriteArgsFlags flags) { if (parameterType is null) { @@ -1130,7 +1153,7 @@ private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteAr } else { - sb.Append("p.Value = ").Append("AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine(); + AppendSetValue(ctx, sb, "p", source, member); } break; default: @@ -1149,30 +1172,32 @@ private static void WriteArgs(ITypeSymbol? parameterType, CodeWriter sb, WriteAr } break; case WriteArgsMode.Update: - sb.Append("ps["); - if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName); - else sb.Append(parameterIndex); - sb.Append("].Value = "); + var parameter = GetParameterIndex(flags, member.DbName, parameterIndex); switch (direction) { case ParameterDirection.Input: case ParameterDirection.InputOutput: - sb.Append("AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine(); + AppendSetValue(ctx, sb, parameter, source, member); break; default: - sb.Append("global::System.DBNull.Value;").NewLine(); + sb.Append(parameter).Append(".Value = global::System.DBNull.Value;").NewLine(); break; } break; case WriteArgsMode.PostProcess: // we already eliminated args that we don't need to look at - sb.Append(source).Append(".").Append(member.CodeName).Append(" = Parse<") - .Append(member.CodeType).Append(">(ps["); - if ((flags & WriteArgsFlags.NeedsTest) != 0) sb.AppendVerbatimLiteral(member.DbName); - else sb.Append(parameterIndex); - sb.Append("].Value);").NewLine(); - + parameter = GetParameterIndex(flags, member.DbName, parameterIndex); + sb.Append(source).Append(".").Append(member.CodeName).Append(" = "); + if (ctx.TypeHandlers.TryGetValue(member.CodeType, out var handler)) + { + sb.Append("new ").Append(handler).Append("().Parse(").Append(parameter).Append(");").NewLine(); + } + else + { + sb.Append(source).Append(".").Append(member.CodeName).Append("Parse<") + .Append(member.CodeType).Append(">(").Append(parameter).Append(".Value);").NewLine(); + } break; } if (test) @@ -1198,6 +1223,20 @@ static void AppendDbParameterSetting(CodeWriter sb, string memberName, byte? val } } + private static void AppendSetValue(in GenerateState ctx, CodeWriter sb, string parameter, string? source, in Inspection.ElementMember member) + { + if (ctx.TypeHandlers.TryGetValue(member.CodeType, out var handler)) + { + sb.Append("new ").Append(handler).Append("().SetValue(") + .Append(parameter).Append(", ").Append(source).Append(".").Append(member.CodeName) + .Append(");").NewLine(); + } + else + { + sb.Append(parameter).Append(".Value = AsValue(").Append(source).Append(".").Append(member.CodeName).Append(");").NewLine(); + } + } + private static void AppendShapeLambda(CodeWriter sb, ITypeSymbol parameterType) { var members = parameterType.GetMembers(); @@ -1227,6 +1266,15 @@ private static void AppendShapeLambda(CodeWriter sb, ITypeSymbol parameterType) } } + private static string GetParameterIndex(WriteArgsFlags flags, string dbName, int parameterIndex) + { + string index = ((flags & WriteArgsFlags.NeedsTest) != 0) + ? CodeWriter.CreateVerbatimLiteral(dbName) + : parameterIndex.ToString(CultureInfo.InvariantCulture); + + return "ps[" + index + "]"; + } + private static SpecialCommandFlags GetSpecialCommandFlags(ITypeSymbol type) { // check whether these command-types need special handling @@ -1336,6 +1384,31 @@ static bool IsDerived(ITypeSymbol? type, ITypeSymbol baseType) } } + private static IImmutableDictionary IdentifyTypeHandlers(in SourceProductionContext ctx, Compilation compilation) + { + var assembly = compilation.Assembly; + var attributes = assembly.GetAttributes() + .Concat(assembly.Modules.SelectMany(x => x.GetAttributes())) + .Where(x => Inspection.IsDapperAttribute(x) && x.AttributeClass!.Name == "TypeHandlerAttribute"); + + var dictionary = ImmutableDictionary.CreateBuilder(SymbolEqualityComparer.Default); + foreach (var attribute in attributes) + { + var valueType = attribute.AttributeClass!.TypeArguments[0]; + var typeHandler = attribute.AttributeClass!.TypeArguments[1]; + if (dictionary.ContainsKey(valueType)) + { + ctx.ReportDiagnostic(Diagnostic.Create(Diagnostics.DuplicateTypeHandlers, null, valueType.Name)); + } + else + { + dictionary.Add(valueType, typeHandler); + } + } + + return dictionary.ToImmutable(); + } + internal abstract class SourceState { public Location? Location { get; } diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs index a1b2ea09..0c8939ad 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/ParseState.cs @@ -70,11 +70,13 @@ public GenerateState(GenerateContextProxy proxy) Nodes = proxy.Nodes; ctx = default; this.proxy = proxy; + TypeHandlers = ImmutableDictionary.Empty; } - public GenerateState(SourceProductionContext ctx, in (Compilation Compilation, ImmutableArray Nodes) state) + public GenerateState(SourceProductionContext ctx, Compilation compilation, ImmutableArray nodes, IImmutableDictionary typeHandlers) { - Compilation = state.Compilation; - Nodes = state.Nodes; + Compilation = compilation; + Nodes = nodes; + TypeHandlers = typeHandlers; this.ctx = ctx; proxy = null; } @@ -82,6 +84,7 @@ public GenerateState(SourceProductionContext ctx, in (Compilation Compilation, I private readonly GenerateContextProxy? proxy; public readonly ImmutableArray Nodes; public readonly Compilation Compilation; + public readonly IImmutableDictionary TypeHandlers; internal void ReportDiagnostic(Diagnostic diagnostic) { diff --git a/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs b/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs index edb74d7e..82f8fd89 100644 --- a/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs +++ b/src/Dapper.AOT.Analyzers/Internal/CodeWriter.cs @@ -237,8 +237,13 @@ public CodeWriter AppendEnumLiteral(ITypeSymbol enumType, int value) return Append("(").Append(enumType).Append(")").Append(value).Append("); "); } + public CodeWriter AppendVerbatimLiteral(string? value) => Append( - value is null ? "null" : SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(value)).ToFullString()); + CreateVerbatimLiteral(value)); + + public static string CreateVerbatimLiteral(string? value) => + value is null ? "null" : SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(value)).ToFullString(); + public CodeWriter Append(char value) { Core.Append(value); diff --git a/src/Dapper.AOT/TypeHandlerT.cs b/src/Dapper.AOT/TypeHandlerT.cs index 51b15bd7..0f2775a3 100644 --- a/src/Dapper.AOT/TypeHandlerT.cs +++ b/src/Dapper.AOT/TypeHandlerT.cs @@ -10,7 +10,7 @@ namespace Dapper; /// when processing values of type /// [ImmutableObject(true)] -[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, AllowMultiple = true)] +[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module, AllowMultiple = true)] public sealed class TypeHandlerAttribute : Attribute where TTypeHandler : TypeHandler, new() {} @@ -31,4 +31,10 @@ public virtual void SetValue(DbParameter parameter, T value) /// public virtual T Parse(DbParameter parameter) => CommandUtils.As(parameter.Value); + + /// + /// Reads the value from the results + /// + public virtual T Read(DbDataReader reader, int columnOffset) + => CommandUtils.As(reader.GetValue(columnOffset)); } \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs b/test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs new file mode 100644 index 00000000..c98a7ac6 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.input.cs @@ -0,0 +1,35 @@ +using Dapper; +using System.Data; +using System.Data.Common; + +[module: DapperAot] +[module: TypeHandler] + +public class CustomClassTypeHandler : TypeHandler +{ +} + +public class CustomClass +{ +} + +public static class Foo +{ + static void SomeCode(DbConnection connection, string bar, bool isBuffered) + { + _ = connection.Query("def"); + _ = connection.Query("def", new { Param = new CustomClass() }); + _ = connection.Query("@OutputValue = def", new CommandParameters()); + } + + public class CommandParameters + { + [DbValue(Direction = ParameterDirection.Output)] + public CustomClass OutputValue { get; set; } + } + + public class MyType + { + public CustomClass C { get; set; } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs new file mode 100644 index 00000000..b81b00be --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.cs @@ -0,0 +1,192 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 20, 24)] + internal static global::System.Collections.Generic.IEnumerable Query0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, Buffered, StoredProcedure, BindResultsByName + // returns data: global::Foo.MyType + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).QueryBuffered(param, RowFactory0.Instance); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 21, 24)] + internal static global::System.Collections.Generic.IEnumerable Query1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, StoredProcedure, KnownParameters + // takes parameter: + // parameter map: Param + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QueryBuffered(param, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 22, 24)] + internal static global::System.Collections.Generic.IEnumerable Query2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, Text, KnownParameters + // takes parameter: global::Foo.CommandParameters + // parameter map: OutputValue + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryBuffered((global::Foo.CommandParameters)param!, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 3859557458U when NormalizedEquals(name, "c"): + token = 0; // two tokens for right-typed and type-flexible + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::Foo.MyType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + global::Foo.MyType result = new(); + foreach (var token in tokens) + { + switch (token) + { + case 0: + result.C = reader.IsDBNull(columnOffset) ? (global::CustomClass?)null : new global::CustomClassTypeHandler().Read(reader, columnOffset); + break; + + } + columnOffset++; + + } + return result; + + } + + } + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Param"; + p.Direction = global::System.Data.ParameterDirection.Input; + new global::CustomClassTypeHandler().SetValue(p, typed.Param); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + new global::CustomClassTypeHandler().SetValue(ps[0], typed.Param); + + } + + } + + private sealed class CommandFactory1 : CommonCommandFactory + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "OutputValue"; + p.Direction = global::System.Data.ParameterDirection.Output; + p.Value = global::System.DBNull.Value; + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + ps[0].Value = global::System.DBNull.Value; + + } + public override bool RequirePostProcess => true; + + public override void PostProcess(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args, int rowCount) + { + var ps = cmd.Parameters; + args.OutputValue = new global::CustomClassTypeHandler().Parse(ps[0]); + base.PostProcess(in cmd, args, rowCount); + + } + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs new file mode 100644 index 00000000..b81b00be --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.cs @@ -0,0 +1,192 @@ +#nullable enable +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 20, 24)] + internal static global::System.Collections.Generic.IEnumerable Query0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, Buffered, StoredProcedure, BindResultsByName + // returns data: global::Foo.MyType + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), DefaultCommandFactory).QueryBuffered(param, RowFactory0.Instance); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 21, 24)] + internal static global::System.Collections.Generic.IEnumerable Query1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, StoredProcedure, KnownParameters + // takes parameter: + // parameter map: Param + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.StoredProcedure, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).QueryBuffered(param, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\TypeHandler.input.cs", 22, 24)] + internal static global::System.Collections.Generic.IEnumerable Query2(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, bool buffered, int? commandTimeout, global::System.Data.CommandType? commandType) + { + // Query, TypedResult, HasParameters, Buffered, Text, KnownParameters + // takes parameter: global::Foo.CommandParameters + // parameter map: OutputValue + // returns data: int + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql)); + global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(buffered is true); + global::System.Diagnostics.Debug.Assert(param is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).QueryBuffered((global::Foo.CommandParameters)param!, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class RowFactory0 : global::Dapper.RowFactory + { + internal static readonly RowFactory0 Instance = new(); + private RowFactory0() {} + public override object? Tokenize(global::System.Data.Common.DbDataReader reader, global::System.Span tokens, int columnOffset) + { + for (int i = 0; i < tokens.Length; i++) + { + int token = -1; + var name = reader.GetName(columnOffset); + var type = reader.GetFieldType(columnOffset); + switch (NormalizedHash(name)) + { + case 3859557458U when NormalizedEquals(name, "c"): + token = 0; // two tokens for right-typed and type-flexible + break; + + } + tokens[i] = token; + columnOffset++; + + } + return null; + } + public override global::Foo.MyType Read(global::System.Data.Common.DbDataReader reader, global::System.ReadOnlySpan tokens, int columnOffset, object? state) + { + global::Foo.MyType result = new(); + foreach (var token in tokens) + { + switch (token) + { + case 0: + result.C = reader.IsDBNull(columnOffset) ? (global::CustomClass?)null : new global::CustomClassTypeHandler().Read(reader, columnOffset); + break; + + } + columnOffset++; + + } + return result; + + } + + } + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "Param"; + p.Direction = global::System.Data.ParameterDirection.Input; + new global::CustomClassTypeHandler().SetValue(p, typed.Param); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { Param = default(global::CustomClass)! }); // expected shape + var ps = cmd.Parameters; + new global::CustomClassTypeHandler().SetValue(ps[0], typed.Param); + + } + + } + + private sealed class CommandFactory1 : CommonCommandFactory + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "OutputValue"; + p.Direction = global::System.Data.ParameterDirection.Output; + p.Value = global::System.DBNull.Value; + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args) + { + var ps = cmd.Parameters; + ps[0].Value = global::System.DBNull.Value; + + } + public override bool RequirePostProcess => true; + + public override void PostProcess(in global::Dapper.UnifiedCommand cmd, global::Foo.CommandParameters args, int rowCount) + { + var ps = cmd.Parameters; + args.OutputValue = new global::CustomClassTypeHandler().Parse(ps[0]); + base.PostProcess(in cmd, args, rowCount); + + } + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt new file mode 100644 index 00000000..d4a5c195 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.netfx.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 3 of 3 possible call-sites using 3 interceptors, 2 commands and 1 readers diff --git a/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt new file mode 100644 index 00000000..d4a5c195 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/TypeHandler.output.txt @@ -0,0 +1,4 @@ +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 3 of 3 possible call-sites using 3 interceptors, 2 commands and 1 readers