Skip to content

Commit a0e1600

Browse files
[release] 1. simplified drivers generators usage (#97)
2. allow to override data type for all drivers, currently meaningful only in postgres 3. fix null warnings in CI 4. refactor copy flow & tests
1 parent d154e40 commit a0e1600

File tree

16 files changed

+453
-315
lines changed

16 files changed

+453
-315
lines changed

Drivers/ColumnMapping.cs

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
namespace SqlcGenCsharp.Drivers;
5+
6+
public class ColumnMapping(string csharpType, Func<int, string> readerFn, Dictionary<string, string?> dbTypes)
7+
{
8+
public string CsharpType { get; } = csharpType;
9+
10+
public Func<int, string> ReaderFn { get; } = readerFn;
11+
12+
public Dictionary<string, string?> DbTypes { get; } = dbTypes;
13+
}

Drivers/DbDriver.cs

+21-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Plugin;
33
using System;
44
using System.Collections.Generic;
5+
using System.Linq;
56
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
67
using static System.String;
78

@@ -15,6 +16,8 @@ public abstract class DbDriver(DotnetFramework dotnetFramework)
1516

1617
private HashSet<string> CsharpPrimitives { get; } = ["long", "double", "int", "float", "bool", "DateTime"];
1718

19+
protected abstract List<ColumnMapping> ColumnMappings { get; }
20+
1821
public virtual UsingDirectiveSyntax[] GetUsingDirectives()
1922
{
2023
return
@@ -25,8 +28,6 @@ public virtual UsingDirectiveSyntax[] GetUsingDirectives()
2528
];
2629
}
2730

28-
protected abstract List<(string, Func<int, string>, HashSet<string>)> GetColumnMapping();
29-
3031
public string AddNullableSuffix(string csharpType, bool notNull)
3132
{
3233
if (notNull) return csharpType;
@@ -42,10 +43,10 @@ public string GetColumnType(Column column)
4243
string GetTypeWithoutNullableSuffix()
4344
{
4445
var columnType = column.Type.Name.ToLower();
45-
foreach (var (csharpType, _, dbTypes) in GetColumnMapping())
46+
foreach (var columnMapping in ColumnMappings
47+
.Where(columnMapping => columnMapping.DbTypes.ContainsKey(columnType)))
4648
{
47-
if (dbTypes.Contains(columnType))
48-
return csharpType;
49+
return columnMapping.CsharpType;
4950
}
5051
throw new NotSupportedException($"Unsupported column type: {column.Type.Name}");
5152
}
@@ -54,11 +55,23 @@ string GetTypeWithoutNullableSuffix()
5455
public string GetColumnReader(Column column, int ordinal)
5556
{
5657
var columnType = column.Type.Name.ToLower();
57-
foreach (var (_, getDataReader, dbTypes) in GetColumnMapping())
58+
foreach (var columnMapping in ColumnMappings
59+
.Where(columnMapping => columnMapping.DbTypes.ContainsKey(columnType)))
60+
{
61+
return columnMapping.ReaderFn(ordinal);
62+
}
63+
throw new NotSupportedException($"Unsupported column type: {column.Type.Name}");
64+
}
65+
66+
public string? GetColumnDbTypeOverride(Column column)
67+
{
68+
var columnType = column.Type.Name.ToLower();
69+
foreach (var columnMapping in ColumnMappings)
5870
{
59-
if (dbTypes.Contains(columnType))
60-
return getDataReader(ordinal);
71+
if (columnMapping.DbTypes.TryGetValue(columnType, out var dbTypeOverride))
72+
return dbTypeOverride;
6173
}
74+
6275
throw new NotSupportedException($"Unsupported column type: {column.Type.Name}");
6376
}
6477

Drivers/Generators/CommonGen.cs

+9-4
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,15 @@ string GetNullExpression(Column column)
6161
}
6262
}
6363

64-
public static IEnumerable<string> GetCommandParameters(IEnumerable<Parameter> parameters)
64+
public IEnumerable<string> GetCommandParameters(IEnumerable<Parameter> parameters)
6565
{
66-
return parameters.Select(param =>
67-
$"{Variable.Command.Name()}.Parameters.AddWithValue(\"@{param.Column.Name}\", " +
68-
$"args.{param.Column.Name.FirstCharToUpper()});");
66+
return parameters.Select(p =>
67+
{
68+
var varName = Variable.Command.Name();
69+
var columnName = p.Column.Name;
70+
var param = p.Column.Name.FirstCharToUpper();
71+
var nullCheck = dbDriver.DotnetFramework.LatestDotnetSupported() && !p.Column.NotNull ? "!" : "";
72+
return $"{varName}.Parameters.AddWithValue(\"@{columnName}\", args.{param}{nullCheck});";
73+
});
6974
}
7075
}
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using Microsoft.CodeAnalysis.CSharp.Syntax;
2+
using Plugin;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
6+
7+
namespace SqlcGenCsharp.Drivers.Generators;
8+
9+
public class CopyFromDeclareGen(DbDriver dbDriver)
10+
{
11+
public MemberDeclarationSyntax Generate(string queryTextConstant, string argInterface, Query query)
12+
{
13+
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
14+
var beginBinaryImport = $"{Variable.Connection.Name()}.BeginBinaryImportAsync({queryTextConstant}";
15+
var addRowsToCopyCommand = AddRowsToCopyCommand(query);
16+
var methodBody = dbDriver.DotnetFramework.LatestDotnetSupported() ?
17+
$$"""
18+
{
19+
await using {{establishConnection}};
20+
{{connectionOpen.AppendSemicolonUnlessEmpty()}}
21+
await {{Variable.Connection.Name()}}.OpenAsync();
22+
await using var {{Variable.Writer.Name()}} = await {{beginBinaryImport}});
23+
{{addRowsToCopyCommand}}
24+
await {{Variable.Writer.Name()}}.CompleteAsync();
25+
await {{Variable.Connection.Name()}}.CloseAsync();
26+
}
27+
""" :
28+
$$"""
29+
{
30+
using ({{establishConnection}})
31+
{
32+
{{connectionOpen.AppendSemicolonUnlessEmpty()}}
33+
await {{Variable.Connection.Name()}}.OpenAsync();
34+
using (var {{Variable.Writer.Name()}} = await {{beginBinaryImport}}))
35+
{
36+
{{addRowsToCopyCommand}}
37+
await {{Variable.Writer.Name()}}.CompleteAsync();
38+
}
39+
await {{Variable.Connection.Name()}}.CloseAsync();
40+
}
41+
}
42+
""";
43+
44+
return ParseMemberDeclaration($$"""
45+
public async Task {{query.Name}}(List<{{argInterface}}> args)
46+
{
47+
{{methodBody}}
48+
}
49+
""")!;
50+
}
51+
52+
private string AddRowsToCopyCommand(Query query)
53+
{
54+
var constructRow = new List<string>()
55+
.Append($"await {Variable.Writer.Name()}.StartRowAsync();")
56+
.Concat(query.Params
57+
.Select(p =>
58+
{
59+
var typeOverride = dbDriver.GetColumnDbTypeOverride(p.Column);
60+
var partialStmt =
61+
$"await {Variable.Writer.Name()}.WriteAsync({Variable.Row.Name()}.{p.Column.Name.FirstCharToUpper()}";
62+
return typeOverride is null
63+
? $"{partialStmt});"
64+
: $"{partialStmt}, {typeOverride});";
65+
}))
66+
.JoinByNewLine();
67+
return $$"""
68+
foreach (var {{Variable.Row.Name()}} in args)
69+
{
70+
{{constructRow}}
71+
}
72+
""";
73+
}
74+
}

Drivers/Generators/ExecDeclareGen.cs

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ namespace SqlcGenCsharp.Drivers.Generators;
66

77
public class ExecDeclareGen(DbDriver dbDriver)
88
{
9+
private CommonGen CommonGen { get; } = new(dbDriver);
10+
911
public MemberDeclarationSyntax Generate(string queryTextConstant, string argInterface, Query query)
1012
{
1113
var parametersStr = CommonGen.GetParameterListAsString(argInterface, query.Params);
+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using Microsoft.CodeAnalysis.CSharp.Syntax;
2+
using Plugin;
3+
using System.Collections.Generic;
4+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
5+
6+
namespace SqlcGenCsharp.Drivers.Generators;
7+
8+
public class ExecLastIdDeclareGen(DbDriver dbDriver)
9+
{
10+
private CommonGen CommonGen { get; } = new(dbDriver);
11+
12+
13+
public MemberDeclarationSyntax Generate(string queryTextConstant, string argInterface, Query query)
14+
{
15+
var parametersStr = CommonGen.GetParameterListAsString(argInterface, query.Params);
16+
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
17+
var createSqlCommand = dbDriver.CreateSqlCommand(queryTextConstant);
18+
var commandParameters = CommonGen.GetCommandParameters(query.Params);
19+
var executeScalarAndReturnCreated = ExecuteScalarAndReturnCreated();
20+
var methodBody = dbDriver.DotnetFramework.LatestDotnetSupported()
21+
? GetWithUsingAsStatement()
22+
: GetWithUsingAsBlock();
23+
24+
return ParseMemberDeclaration($$"""
25+
public async Task<long> {{query.Name}}({{parametersStr}})
26+
{
27+
{{methodBody}}
28+
}
29+
""")!;
30+
31+
string GetWithUsingAsStatement()
32+
{
33+
return $$"""
34+
{
35+
await using {{establishConnection}};
36+
{{connectionOpen.AppendSemicolonUnlessEmpty()}}
37+
await using {{createSqlCommand}};
38+
{{commandParameters.JoinByNewLine()}}
39+
{{executeScalarAndReturnCreated.JoinByNewLine()}}
40+
}
41+
""";
42+
}
43+
44+
string GetWithUsingAsBlock()
45+
{
46+
return $$"""
47+
{
48+
using ({{establishConnection}})
49+
{
50+
{{connectionOpen.AppendSemicolonUnlessEmpty()}}
51+
using ({{createSqlCommand}})
52+
{
53+
{{commandParameters.JoinByNewLine()}}
54+
{{executeScalarAndReturnCreated.JoinByNewLine()}}
55+
}
56+
}
57+
}
58+
""";
59+
}
60+
61+
IEnumerable<string> ExecuteScalarAndReturnCreated()
62+
{
63+
return
64+
[
65+
$"await {Variable.Command.Name()}.ExecuteNonQueryAsync();",
66+
$"return {Variable.Command.Name()}.LastInsertedId;"
67+
];
68+
}
69+
}
70+
}
+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using Microsoft.CodeAnalysis.CSharp.Syntax;
2+
using Plugin;
3+
using System.Collections.Generic;
4+
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
5+
6+
namespace SqlcGenCsharp.Drivers.Generators;
7+
8+
public class ExecRowsDeclareGen(DbDriver dbDriver)
9+
{
10+
private CommonGen CommonGen { get; } = new(dbDriver);
11+
12+
public MemberDeclarationSyntax Generate(string queryTextConstant, string argInterface, Query query)
13+
{
14+
var parametersStr = CommonGen.GetParameterListAsString(argInterface, query.Params);
15+
var (establishConnection, connectionOpen) = dbDriver.EstablishConnection(query);
16+
var createSqlCommand = dbDriver.CreateSqlCommand(queryTextConstant);
17+
var commandParameters = CommonGen.GetCommandParameters(query.Params);
18+
var executeScalarAndReturnCreated = ExecuteScalarAndReturnCreated();
19+
var methodBody = dbDriver.DotnetFramework.LatestDotnetSupported()
20+
? GetWithUsingAsStatement()
21+
: GetWithUsingAsBlock();
22+
23+
return ParseMemberDeclaration($$"""
24+
public async Task<long> {{query.Name}}({{parametersStr}})
25+
{
26+
{{methodBody}}
27+
}
28+
""")!;
29+
30+
string GetWithUsingAsStatement()
31+
{
32+
return $$"""
33+
{
34+
await using {{establishConnection}};
35+
{{connectionOpen.AppendSemicolonUnlessEmpty()}}
36+
await using {{createSqlCommand}};
37+
{{commandParameters.JoinByNewLine()}}
38+
{{executeScalarAndReturnCreated.JoinByNewLine()}}
39+
}
40+
""";
41+
}
42+
43+
string GetWithUsingAsBlock()
44+
{
45+
return $$"""
46+
{
47+
using ({{establishConnection}})
48+
{
49+
{{connectionOpen.AppendSemicolonUnlessEmpty()}}
50+
using ({{createSqlCommand}})
51+
{
52+
{{commandParameters.JoinByNewLine()}}
53+
{{executeScalarAndReturnCreated.JoinByNewLine()}}
54+
}
55+
}
56+
}
57+
""";
58+
}
59+
60+
IEnumerable<string> ExecuteScalarAndReturnCreated()
61+
{
62+
return [$"return await {Variable.Command.Name()}.ExecuteNonQueryAsync();"];
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)