Skip to content

Commit 66e1e2b

Browse files
authored
add [BatchSize] and pass thru to multi-row execute (#76)
1 parent b8d5dd0 commit 66e1e2b

10 files changed

+365
-7
lines changed

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,7 @@ enum ParameterMode
747747
}
748748
}
749749

750+
int? batchSize = null;
750751
foreach (var attrib in methodAttribs)
751752
{
752753
if (IsDapperAttribute(attrib))
@@ -778,6 +779,12 @@ enum ParameterMode
778779
case Types.CommandPropertyAttribute:
779780
cmdPropsCount++;
780781
break;
782+
case Types.BatchSizeAttribute:
783+
if (attrib.ConstructorArguments.Length == 1 && attrib.ConstructorArguments[0].Value is int batchTmp)
784+
{
785+
batchSize = batchTmp;
786+
}
787+
break;
781788
}
782789
}
783790
}
@@ -806,8 +813,8 @@ enum ParameterMode
806813
}
807814

808815

809-
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null
810-
? null : new(rowCountHint, rowCountHintMember?.Member.Name, cmdProps);
816+
return cmdProps.IsDefaultOrEmpty && rowCountHint <= 0 && rowCountHintMember is null && batchSize is null
817+
? null : new(rowCountHint, rowCountHintMember?.Member.Name, batchSize, cmdProps);
811818
}
812819

813820
internal static ImmutableArray<ElementMember>? SharedGetParametersToInclude(MemberMap? map, ref OperationFlags flags, string? sql, Action<Diagnostic>? reportDiagnostic, out SqlParseOutputFlags parseFlags)

src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Multi.cs

+4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType)
4444
bool isAsync = flags.HasAny(OperationFlags.Async);
4545
sb.Append("Execute").Append(isAsync ? "Async" : "").Append("(");
4646
sb.Append("(").Append(castType).Append(")param!");
47+
if (additionalCommandState?.BatchSize is { } batchSize)
48+
{
49+
sb.Append(", batchSize: ").Append(batchSize);
50+
}
4751
if (isAsync && HasParam(methodParameters, "cancellationToken"))
4852
{
4953
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));

src/Dapper.AOT.Analyzers/Internal/AdditionalCommandState.cs

+12-4
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public bool Equals(in CommandProperty other)
3535
internal sealed class AdditionalCommandState : IEquatable<AdditionalCommandState>
3636
{
3737
public readonly int RowCountHint;
38+
public readonly int? BatchSize;
3839
public readonly string? RowCountHintMemberName;
3940
public readonly ImmutableArray<CommandProperty> CommandProperties;
4041

@@ -72,7 +73,8 @@ private static AdditionalCommandState Combine(AdditionalCommandState inherited,
7273
countMember = null;
7374
}
7475

75-
return new(count, countMember, Concat(inherited.CommandProperties, overrides.CommandProperties));
76+
return new(count, countMember, inherited.BatchSize ?? overrides.BatchSize,
77+
Concat(inherited.CommandProperties, overrides.CommandProperties));
7678
}
7779

7880
static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x, ImmutableArray<CommandProperty> y)
@@ -85,10 +87,13 @@ static ImmutableArray<CommandProperty> Concat(ImmutableArray<CommandProperty> x,
8587
return builder.ToImmutable();
8688
}
8789

88-
internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName, ImmutableArray<CommandProperty> commandProperties)
90+
internal AdditionalCommandState(
91+
int rowCountHint, string? rowCountHintMemberName, int? batchSize,
92+
ImmutableArray<CommandProperty> commandProperties)
8993
{
9094
RowCountHint = rowCountHint;
9195
RowCountHintMemberName = rowCountHintMemberName;
96+
BatchSize = batchSize;
9297
CommandProperties = commandProperties;
9398
}
9499

@@ -98,7 +103,9 @@ internal AdditionalCommandState(int rowCountHint, string? rowCountHintMemberName
98103
bool IEquatable<AdditionalCommandState>.Equals(AdditionalCommandState other) => Equals(in other);
99104

100105
public bool Equals(in AdditionalCommandState other)
101-
=> RowCountHint == other.RowCountHint && RowCountHintMemberName == other.RowCountHintMemberName
106+
=> RowCountHint == other.RowCountHint
107+
&& BatchSize == other.BatchSize
108+
&& RowCountHintMemberName == other.RowCountHintMemberName
102109
&& ((CommandProperties.IsDefaultOrEmpty && other.CommandProperties.IsDefaultOrEmpty) || Equals(CommandProperties, other.CommandProperties));
103110

104111
private static bool Equals(in ImmutableArray<CommandProperty> x, in ImmutableArray<CommandProperty> y)
@@ -136,6 +143,7 @@ static int GetHashCode(in ImmutableArray<CommandProperty> x)
136143
}
137144

138145
public override int GetHashCode()
139-
=> (RowCountHint + (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
146+
=> (RowCountHint + BatchSize.GetValueOrDefault()
147+
+ (RowCountHintMemberName is null ? 0 : RowCountHintMemberName.GetHashCode()))
140148
^ (CommandProperties.IsDefaultOrEmpty ? 0 : GetHashCode(in CommandProperties));
141149
}

src/Dapper.AOT.Analyzers/Internal/Types.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ public const string
1616
IDynamicParameters = nameof(IDynamicParameters),
1717
SqlMapper = nameof(SqlMapper),
1818
SqlAttribute = nameof(SqlAttribute),
19-
ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute);
19+
ExplicitConstructorAttribute = nameof(ExplicitConstructorAttribute),
20+
BatchSizeAttribute = nameof(BatchSizeAttribute);
2021
}

src/Dapper.AOT/BatchSizeAttribute.cs

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.ComponentModel;
3+
using System.Diagnostics;
4+
5+
namespace Dapper;
6+
7+
/// <summary>
8+
/// Indicates the batch size to use when executing commands with a sequence of argument rows.
9+
/// </summary>
10+
[Conditional("DEBUG")] // not needed post-build, so: evaporate
11+
[ImmutableObject(true)]
12+
[AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Module | AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Method, AllowMultiple = false)]
13+
public sealed class BatchSizeAttribute : Attribute
14+
{
15+
/// <summary>
16+
/// Indicates the batch size to use when executing commands with a sequence of argument row; a value of zero disables batch usage; a negative value uses a single batch for all rows.
17+
/// </summary>
18+
public BatchSizeAttribute(int batchSize) => _ = batchSize;
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using Dapper;
2+
using System.Data.Common;
3+
4+
[module: DapperAot]
5+
6+
public static class Foo
7+
{
8+
[BatchSize(10)] // should be passed explicitly
9+
static void SomeCode(DbConnection connection, string sql, string bar)
10+
{
11+
var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } };
12+
13+
connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs);
14+
}
15+
16+
// no batch size, should be passed implicitly
17+
static void SomeOtherCode(DbConnection connection, string sql, string bar)
18+
{
19+
var objs = new[] { new { id = 12, bar }, new { id = 34, bar = "def" } };
20+
21+
connection.Execute("insert Foo (Id, Value) values (@id, @bar)", objs);
22+
}
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#nullable enable
2+
namespace Dapper.AOT // interceptors must be in a known namespace
3+
{
4+
file static class DapperGeneratedInterceptors
5+
{
6+
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 13, 20)]
7+
internal static int Execute0(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
8+
{
9+
// Execute, HasParameters, Text, KnownParameters
10+
// takes parameter: global::<anonymous type: int id, string bar>[]
11+
// parameter map: bar id
12+
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
13+
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
14+
global::System.Diagnostics.Debug.Assert(param is not null);
15+
16+
return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory0.Instance).Execute((object?[])param!, batchSize: 10);
17+
18+
}
19+
20+
[global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\BatchSize.input.cs", 21, 20)]
21+
internal static int Execute1(this global::System.Data.IDbConnection cnn, string sql, object? param, global::System.Data.IDbTransaction? transaction, int? commandTimeout, global::System.Data.CommandType? commandType)
22+
{
23+
// Execute, HasParameters, Text, KnownParameters
24+
// takes parameter: global::<anonymous type: int id, string bar>[]
25+
// parameter map: bar id
26+
global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));
27+
global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.Text);
28+
global::System.Diagnostics.Debug.Assert(param is not null);
29+
30+
return global::Dapper.DapperAotExtensions.Command(cnn, transaction, sql, global::System.Data.CommandType.Text, commandTimeout.GetValueOrDefault(), CommandFactory1.Instance).Execute((object?[])param!);
31+
32+
}
33+
34+
private class CommonCommandFactory<T> : global::Dapper.CommandFactory<T>
35+
{
36+
public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args)
37+
{
38+
var cmd = base.GetCommand(connection, sql, commandType, args);
39+
// apply special per-provider command initialization logic for OracleCommand
40+
if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0)
41+
{
42+
cmd0.BindByName = true;
43+
cmd0.InitialLONGFetchSize = -1;
44+
45+
}
46+
return cmd;
47+
}
48+
49+
}
50+
51+
private static readonly CommonCommandFactory<object?> DefaultCommandFactory = new();
52+
53+
private sealed class CommandFactory0 : CommonCommandFactory<object?> // <anonymous type: int id, string bar>
54+
{
55+
internal static readonly CommandFactory0 Instance = new();
56+
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
57+
{
58+
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
59+
var ps = cmd.Parameters;
60+
global::System.Data.Common.DbParameter p;
61+
p = cmd.CreateParameter();
62+
p.ParameterName = "id";
63+
p.DbType = global::System.Data.DbType.Int32;
64+
p.Direction = global::System.Data.ParameterDirection.Input;
65+
p.Value = AsValue(typed.id);
66+
ps.Add(p);
67+
68+
p = cmd.CreateParameter();
69+
p.ParameterName = "bar";
70+
p.DbType = global::System.Data.DbType.String;
71+
p.Size = -1;
72+
p.Direction = global::System.Data.ParameterDirection.Input;
73+
p.Value = AsValue(typed.bar);
74+
ps.Add(p);
75+
76+
}
77+
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
78+
{
79+
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
80+
var ps = cmd.Parameters;
81+
ps[0].Value = AsValue(typed.id);
82+
ps[1].Value = AsValue(typed.bar);
83+
84+
}
85+
public override bool CanPrepare => true;
86+
87+
}
88+
89+
private sealed class CommandFactory1 : CommonCommandFactory<object?> // <anonymous type: int id, string bar>
90+
{
91+
internal static readonly CommandFactory1 Instance = new();
92+
public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args)
93+
{
94+
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
95+
var ps = cmd.Parameters;
96+
global::System.Data.Common.DbParameter p;
97+
p = cmd.CreateParameter();
98+
p.ParameterName = "id";
99+
p.DbType = global::System.Data.DbType.Int32;
100+
p.Direction = global::System.Data.ParameterDirection.Input;
101+
p.Value = AsValue(typed.id);
102+
ps.Add(p);
103+
104+
p = cmd.CreateParameter();
105+
p.ParameterName = "bar";
106+
p.DbType = global::System.Data.DbType.String;
107+
p.Size = -1;
108+
p.Direction = global::System.Data.ParameterDirection.Input;
109+
p.Value = AsValue(typed.bar);
110+
ps.Add(p);
111+
112+
}
113+
public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args)
114+
{
115+
var typed = Cast(args, static () => new { id = default(int), bar = default(string)! }); // expected shape
116+
var ps = cmd.Parameters;
117+
ps[0].Value = AsValue(typed.id);
118+
ps[1].Value = AsValue(typed.bar);
119+
120+
}
121+
public override bool CanPrepare => true;
122+
123+
}
124+
125+
126+
}
127+
}
128+
namespace System.Runtime.CompilerServices
129+
{
130+
// this type is needed by the compiler to implement interceptors - it doesn't need to
131+
// come from the runtime itself, though
132+
133+
[global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate
134+
[global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)]
135+
sealed file class InterceptsLocationAttribute : global::System.Attribute
136+
{
137+
public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber)
138+
{
139+
_ = path;
140+
_ = lineNumber;
141+
_ = columnNumber;
142+
}
143+
}
144+
}

0 commit comments

Comments
 (0)