Skip to content

Commit 985d289

Browse files
1. refactoring around conditional codegen logic (#217)
1 parent b8c47e3 commit 985d289

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+2862
-1806
lines changed

CodeGenerator/Generators/CsprojGen.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public File GenerateFile()
2525

2626
private string GetFileContents()
2727
{
28-
var optionalNullableProperty = options.DotnetFramework.LatestDotnetSupported() ? Environment.NewLine + " <Nullable>enable</Nullable>" : "";
28+
var optionalNullableProperty = options.DotnetFramework.IsDotnetCore() ? Environment.NewLine + " <Nullable>enable</Nullable>" : "";
2929
return $"""
3030
<!--{Consts.AutoGeneratedComment}-->
3131
<!--Run the following to add the project to the solution:

CodeGenerator/Generators/DataClassesGen.cs

+12-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using Microsoft.CodeAnalysis.CSharp.Syntax;
33
using Plugin;
44
using SqlcGenCsharp.Drivers;
5-
using System;
65
using System.Collections.Generic;
76
using System.Linq;
87
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
@@ -14,7 +13,7 @@ internal class DataClassesGen(DbDriver dbDriver)
1413
public MemberDeclarationSyntax Generate(string name, ClassMember classMember, IList<Column> columns, Options options)
1514
{
1615
var className = classMember.Name(name);
17-
if (options.DotnetFramework.LatestDotnetSupported() && !options.UseDapper)
16+
if (options.DotnetFramework.IsDotnetCore() && !options.UseDapper)
1817
return GenerateAsRecord(className, columns);
1918
return GenerateAsCLass(className, columns);
2019
}
@@ -30,7 +29,7 @@ private MemberDeclarationSyntax GenerateAsRecord(string className, IList<Column>
3029

3130
private ClassDeclarationSyntax GenerateAsCLass(string className, IList<Column> columns)
3231
{
33-
var modernDotnetSupported = dbDriver.Options.DotnetFramework.LatestDotnetSupported();
32+
var modernDotnetSupported = dbDriver.Options.DotnetFramework.IsDotnetCore();
3433
return ClassDeclaration(className)
3534
.AddModifiers(Token(SyntaxKind.PublicKeyword))
3635
.AddMembers(ColumnsToProperties())
@@ -42,10 +41,7 @@ MemberDeclarationSyntax[] ColumnsToProperties()
4241
return columns.Select(column =>
4342
{
4443
var csharpType = dbDriver.GetCsharpType(column);
45-
var requiredModifierNeeded = modernDotnetSupported && // required modifier supported by .Net framework
46-
column.NotNull && // the field is not null, hence required
47-
!dbDriver.IsTypeNullableForAllRuntimes(csharpType); // TODO document
48-
var optionalRequiredModifier = requiredModifierNeeded ? "required" : string.Empty;
44+
var optionalRequiredModifier = RequiredModifierNeeded(column) ? "required" : string.Empty;
4945
var setterMethod = modernDotnetSupported ? "init" : "set";
5046
return ParseMemberDeclaration(
5147
$$"""
@@ -55,6 +51,15 @@ MemberDeclarationSyntax[] ColumnsToProperties()
5551
.Cast<MemberDeclarationSyntax>()
5652
.ToArray();
5753
}
54+
55+
bool RequiredModifierNeeded(Column column)
56+
{
57+
if (!dbDriver.Options.DotnetFramework.IsDotnetCore())
58+
return false;
59+
if (column.EmbedTable != null)
60+
return true;
61+
return column.NotNull;
62+
}
5863
}
5964

6065
private static string GetFieldName(Column column, Dictionary<string, int> seenEmbed)

CodeGenerator/Generators/ModelsGen.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ namespace SqlcGenCsharp.Generators;
1010

1111
internal class ModelsGen(DbDriver dbDriver, Options options, string namespaceName)
1212
{
13+
private const string ClassName = "Models";
14+
1315
private RootGen RootGen { get; } = new(options);
1416

1517
private DataClassesGen DataClassesGen { get; } = new(dbDriver);
@@ -23,7 +25,7 @@ public File GenerateFile(Dictionary<string, Table> tables)
2325

2426
return new File
2527
{
26-
Name = "Models.cs",
28+
Name = $"{ClassName}.cs",
2729
Contents = root.ToByteString()
2830
};
2931
}

CodeGenerator/Generators/QueriesGen.cs

+10-7
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,16 @@ class {{className}}
8080
private IEnumerable<MemberDeclarationSyntax> GetMembersForSingleQuery(Query query)
8181
{
8282
return new List<MemberDeclarationSyntax>()
83-
.Append(GetQueryTextConstant(query))
83+
.AppendIfNotNull(GetQueryTextConstant(query))
8484
.AppendIfNotNull(GetQueryColumnsDataclass(query))
8585
.AppendIfNotNull(GetQueryParamsDataclass(query))
8686
.Append(AddMethodDeclaration(query));
8787
}
8888

8989
private MemberDeclarationSyntax? GetQueryColumnsDataclass(Query query)
9090
{
91-
return query.Columns.Count <= 0
92-
? null
93-
: DataClassesGen.Generate(query.Name, ClassMember.Row, query.Columns, options);
91+
if (query.Columns.Count <= 0) return null;
92+
return DataClassesGen.Generate(query.Name, ClassMember.Row, query.Columns, options);
9493
}
9594

9695
private MemberDeclarationSyntax? GetQueryParamsDataclass(Query query)
@@ -100,11 +99,15 @@ private IEnumerable<MemberDeclarationSyntax> GetMembersForSingleQuery(Query quer
10099
return DataClassesGen.Generate(query.Name, ClassMember.Args, columns, options);
101100
}
102101

103-
private MemberDeclarationSyntax GetQueryTextConstant(Query query)
102+
private MemberDeclarationSyntax? GetQueryTextConstant(Query query)
104103
{
104+
var transformQueryText = dbDriver.TransformQueryText(query);
105+
if (transformQueryText == string.Empty)
106+
return null;
105107
return ParseMemberDeclaration(
106-
$"private const string {ClassMember.Sql.Name(query.Name)} = \"{dbDriver.TransformQueryText(query)}\";")
107-
!
108+
$"""
109+
private const string {ClassMember.Sql.Name(query.Name)} = "{transformQueryText}";
110+
""")!
108111
.AppendNewLine();
109112
}
110113

CodeGenerator/Generators/RootGen.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ internal class RootGen(Options options)
99
public CompilationUnitSyntax CompilationRootGen(IdentifierNameSyntax namespaceName,
1010
UsingDirectiveSyntax[] usingDirectives, MemberDeclarationSyntax[] classDeclarations)
1111
{
12-
return options.DotnetFramework.LatestDotnetSupported() ? GetFileScoped() : GetBLockScoped();
12+
return options.DotnetFramework.IsDotnetCore() ? GetFileScoped() : GetBLockScoped();
1313

1414
CompilationUnitSyntax GetFileScoped()
1515
{

Drivers/DbDriver.cs

+38-23
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,30 @@ namespace SqlcGenCsharp.Drivers;
99

1010
public record ConnectionGenCommands(string EstablishConnection, string ConnectionOpen);
1111

12-
public abstract class DbDriver(Options options, Dictionary<string, Table> tables)
12+
public abstract class DbDriver
1313
{
14-
public Options Options { get; } = options;
14+
public Options Options { get; }
1515

16-
public Dictionary<string, Table> Tables { get; } = tables;
16+
public Dictionary<string, Table> Tables { get; }
17+
18+
private HashSet<string> NullableTypesInDotnetCore { get; } = ["string"];
19+
20+
private HashSet<string> NullableTypes { get; } = ["long", "double", "int", "float", "bool", "DateTime"];
1721

1822
protected abstract List<ColumnMapping> ColumnMappings { get; }
1923

24+
protected DbDriver(Options options, Dictionary<string, Table> tables)
25+
{
26+
Options = options;
27+
Tables = tables;
28+
29+
if (!Options.DotnetFramework.IsDotnetCore()) return; // `string?` is possible only in .Net Core
30+
foreach (var t in NullableTypesInDotnetCore)
31+
{
32+
NullableTypes.Add(t);
33+
}
34+
}
35+
2036
public virtual UsingDirectiveSyntax[] GetUsingDirectives()
2137
{
2238
var usingDirectives = new List<UsingDirectiveSyntax>
@@ -31,11 +47,11 @@ public virtual UsingDirectiveSyntax[] GetUsingDirectives()
3147
return usingDirectives.ToArray();
3248
}
3349

34-
public string AddNullableSuffix(string csharpType, bool notNull)
50+
public string AddNullableSuffixIfNeeded(string csharpType, bool notNull)
3551
{
3652
if (notNull) return csharpType;
37-
if (IsTypeNullableForAllRuntimes(csharpType)) return $"{csharpType}?";
38-
return Options.DotnetFramework.LatestDotnetSupported() ? $"{csharpType}?" : csharpType;
53+
if (IsTypeNullable(csharpType)) return $"{csharpType}?";
54+
return Options.DotnetFramework.IsDotnetCore() ? $"{csharpType}?" : csharpType;
3955
}
4056

4157
public string GetCsharpType(Column column)
@@ -44,7 +60,7 @@ public string GetCsharpType(Column column)
4460
return column.EmbedTable.Name.ToModelName();
4561

4662
var columnCsharpType = string.IsNullOrEmpty(column.Type.Name) ? "object" : GetTypeWithoutNullableSuffix();
47-
return AddNullableSuffix(columnCsharpType, column.NotNull);
63+
return AddNullableSuffixIfNeeded(columnCsharpType, column.NotNull);
4864

4965
string GetTypeWithoutNullableSuffix()
5066
{
@@ -89,31 +105,30 @@ public string GetColumnReader(Column column, int ordinal)
89105

90106
public abstract string CreateSqlCommand(string sqlTextConstant);
91107

92-
private HashSet<string> NullableTypesInAllRuntimes { get; } = ["long", "double", "int", "float", "bool", "DateTime"];
93-
94108
// TODO move out from driver + rename
95-
public bool IsTypeNullableForAllRuntimes(string csharpType)
109+
public bool IsTypeNullable(string csharpType)
96110
{
97-
return NullableTypesInAllRuntimes.Contains(csharpType.Replace("?", ""));
111+
return NullableTypes.Contains(csharpType.Replace("?", ""));
98112
}
99113

100-
protected static string GetConnectionStringField()
114+
/*
115+
Since there is no indication of the primary key column in SQLC protobuf (assuming it is a single column even),
116+
this method uses a few heuristics to assess the type of the id column
117+
*/
118+
public string GetIdColumnType(Query query)
101119
{
102-
return Variable.ConnectionString.AsPropertyName();
103-
}
120+
var tableColumns = Tables[query.InsertIntoTable.Name].Columns;
121+
var idColumn = tableColumns.First(c => c.Name.Equals("id", StringComparison.OrdinalIgnoreCase));
122+
if (idColumn is not null)
123+
return GetCsharpType(idColumn);
104124

105-
public string GetIdColumnType()
106-
{
107-
return Options.DriverName switch
108-
{
109-
DriverName.Sqlite => "int",
110-
_ => "long"
111-
};
125+
idColumn = tableColumns.First(c => c.Name.Contains("id", StringComparison.CurrentCultureIgnoreCase));
126+
return GetCsharpType(idColumn ?? tableColumns[0]);
112127
}
113128

114-
public virtual string[] GetLastIdStatement()
129+
public virtual string[] GetLastIdStatement(Query query)
115130
{
116-
var convertFunc = GetIdColumnType() == "int" ? "ToInt32" : "ToInt64";
131+
var convertFunc = GetIdColumnType(query) == "int" ? "ToInt32" : "ToInt64"; // TODO refactor
117132
return
118133
[
119134
$"var {Variable.Result.AsVarName()} = await {Variable.Command.AsVarName()}.ExecuteScalarAsync();",

Drivers/Generators/CommonGen.cs

+34-22
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,30 @@ public static string GetMethodParameterList(string argInterface, IEnumerable<Par
1313
: $"{argInterface} {Variable.Args.AsVarName()}")}";
1414
}
1515

16-
public static string AddParametersToCommand(IEnumerable<Parameter> parameters)
16+
public string AddParametersToCommand(IEnumerable<Parameter> parameters)
1717
{
1818
return parameters.Select(p =>
1919
{
2020
var commandVar = Variable.Command.AsVarName();
2121
var param = p.Column.Name.ToPascalCase();
2222
var argsVar = Variable.Args.AsVarName();
23-
return p.Column.IsSqlcSlice
24-
? $$"""
25-
for (int i = 0; i < {{argsVar}}.{{param}}.Length; i++)
26-
{{commandVar}}.Parameters.AddWithValue($"@{{param}}Arg{i}", {{argsVar}}.{{param}}[i]);
27-
"""
28-
: $$"""
29-
if ({{argsVar}}.{{param}} != null)
30-
{{commandVar}}.Parameters.AddWithValue("@{{p.Column.Name}}", {{argsVar}}.{{param}});
31-
""";
23+
if (p.Column.IsSqlcSlice)
24+
return $$"""
25+
for (int i = 0; i < {{argsVar}}.{{param}}.Length; i++)
26+
{{commandVar}}.Parameters.AddWithValue($"@{{param}}Arg{i}", {{argsVar}}.{{param}}[i]);
27+
""";
28+
29+
var addParamToCommand = $"""{commandVar}.Parameters.AddWithValue("@{p.Column.Name}", {argsVar}.{param});""";
30+
return ShouldCheckParameterForNull(p)
31+
? $"""
32+
if ({argsVar}.{param} != null)
33+
{addParamToCommand}
34+
"""
35+
: addParamToCommand;
3236
}).JoinByNewLine();
3337
}
3438

35-
public static string ConstructDapperParamsDict(IList<Parameter> parameters)
39+
public string ConstructDapperParamsDict(IList<Parameter> parameters)
3640
{
3741
if (!parameters.Any()) return string.Empty;
3842
var initParamsDict = $"var {Variable.QueryParams.AsVarName()} = new Dictionary<string, object>();";
@@ -49,20 +53,28 @@ public static string ConstructDapperParamsDict(IList<Parameter> parameters)
4953
""";
5054

5155
var addParamToDict = $"{queryParamsVar}.Add(\"{p.Column.Name}\", {argsVar}.{param});";
52-
return p.Column.NotNull
53-
? addParamToDict
54-
: $"""
55-
if ({argsVar}.{param} != null)
56-
{addParamToDict}
57-
""";
56+
return ShouldCheckParameterForNull(p)
57+
? $"""
58+
if ({argsVar}.{param} != null)
59+
{addParamToDict}
60+
"""
61+
: addParamToDict;
5862
});
5963

60-
return $$"""
61-
{{initParamsDict}}
62-
{{dapperParamsCommands.JoinByNewLine()}}
64+
return $"""
65+
{initParamsDict}
66+
{dapperParamsCommands.JoinByNewLine()}
6367
""";
6468
}
6569

70+
private bool ShouldCheckParameterForNull(Parameter parameter)
71+
{
72+
if (parameter.Column.IsArray || parameter.Column.NotNull)
73+
return false;
74+
var csharpType = dbDriver.GetCsharpType(parameter.Column);
75+
return dbDriver.IsTypeNullable(csharpType);
76+
}
77+
6678
public static string AwaitReaderRow()
6779
{
6880
return $"await {Variable.Reader.AsVarName()}.ReadAsync()";
@@ -150,7 +162,7 @@ string GetNullExpression(Column column)
150162
var csharpType = dbDriver.GetCsharpType(column);
151163
if (csharpType == "string")
152164
return "string.Empty";
153-
return !dbDriver.Options.DotnetFramework.LatestDotnetSupported() && dbDriver.IsTypeNullableForAllRuntimes(csharpType)
165+
return !dbDriver.Options.DotnetFramework.IsDotnetCore() && dbDriver.IsTypeNullable(csharpType)
154166
? $"({csharpType}) null"
155167
: "null";
156168
}
@@ -165,7 +177,7 @@ string InstantiateDataclassInternal(string name, IEnumerable<string> fieldsInit)
165177
return $$"""
166178
new {{name}}
167179
{
168-
{{string.Join(",\n", fieldsInit)}}
180+
{{fieldsInit.JoinByComma()}}
169181
}
170182
""";
171183
}

Drivers/Generators/ExecLastIdDeclareGen.cs

+8-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public MemberDeclarationSyntax Generate(string queryTextConstant, string argInte
1212
{
1313
var parametersStr = CommonGen.GetMethodParameterList(argInterface, query.Params);
1414
return ParseMemberDeclaration($$"""
15-
public async Task<{{dbDriver.GetIdColumnType()}}> {{query.Name}}({{parametersStr}})
15+
public async Task<{{dbDriver.GetIdColumnType(query)}}> {{query.Name}}({{parametersStr}})
1616
{
1717
{{GetMethodBody(queryTextConstant, query)}}
1818
}
@@ -29,22 +29,23 @@ private string GetMethodBody(string queryTextConstant, Query query)
2929
string GetAsDapper()
3030
{
3131
var dapperParamsSection = CommonGen.ConstructDapperParamsDict(query.Params);
32-
var dapperArgs = dapperParamsSection != string.Empty
33-
? $", {Variable.QueryParams.AsVarName()}"
34-
: string.Empty;
32+
var dapperArgs = dapperParamsSection == string.Empty
33+
? string.Empty
34+
: $", {Variable.QueryParams.AsVarName()}";
3535
return $$"""
3636
using ({{establishConnection}})
3737
{{{sqlTextTransform}}{{dapperParamsSection}}
38-
return await {{Variable.Connection.AsVarName()}}.QuerySingleAsync<{{dbDriver.GetIdColumnType()}}>({{queryTextConstant}}{{dapperArgs}});
38+
return await {{Variable.Connection.AsVarName()}}.QuerySingleAsync<{{dbDriver.GetIdColumnType(query)}}>({{queryTextConstant}}{{dapperArgs}});
3939
}
4040
""";
4141
}
4242

4343
string GetAsDriver()
4444
{
45-
var createSqlCommand = dbDriver.CreateSqlCommand(sqlTextTransform != string.Empty ? Variable.SqlText.AsVarName() : queryTextConstant);
45+
var sqlTextVar = sqlTextTransform == string.Empty ? queryTextConstant : Variable.SqlText.AsVarName();
46+
var createSqlCommand = dbDriver.CreateSqlCommand(sqlTextVar);
4647
var commandParameters = CommonGen.AddParametersToCommand(query.Params);
47-
var returnLastId = ((IExecLastId)dbDriver).GetLastIdStatement().JoinByNewLine();
48+
var returnLastId = ((IExecLastId)dbDriver).GetLastIdStatement(query).JoinByNewLine();
4849
return $$"""
4950
using ({{establishConnection}})
5051
{

Drivers/Generators/ManyDeclareGen.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ string GetAsDapper()
3636
{
3737
var dapperParamsSection = CommonGen.ConstructDapperParamsDict(query.Params);
3838
var dapperArgs = dapperParamsSection != string.Empty ? $", {Variable.QueryParams.AsVarName()}" : string.Empty;
39-
var returnType = dbDriver.AddNullableSuffix(returnInterface, true);
39+
var returnType = dbDriver.AddNullableSuffixIfNeeded(returnInterface, true);
4040
var sqlQuery = sqlTextTransform != string.Empty ? Variable.SqlText.AsVarName() : queryTextConstant;
4141

4242
return $$"""

Drivers/Generators/OneDeclareGen.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class OneDeclareGen(DbDriver dbDriver)
1212

1313
public MemberDeclarationSyntax Generate(string queryTextConstant, string argInterface, string returnInterface, Query query)
1414
{
15-
var returnType = $"Task<{dbDriver.AddNullableSuffix(returnInterface, false)}>";
15+
var returnType = $"Task<{dbDriver.AddNullableSuffixIfNeeded(returnInterface, false)}>";
1616
var parametersStr = CommonGen.GetMethodParameterList(argInterface, query.Params);
1717
return ParseMemberDeclaration($$"""
1818
public async {{returnType}} {{query.Name}}({{parametersStr}})
@@ -37,7 +37,7 @@ string GetAsDapper()
3737
{
3838
var dapperParamsSection = CommonGen.ConstructDapperParamsDict(query.Params);
3939
var dapperArgs = dapperParamsSection != string.Empty ? $", {Variable.QueryParams.AsVarName()}" : string.Empty;
40-
var returnType = dbDriver.AddNullableSuffix(returnInterface, false);
40+
var returnType = dbDriver.AddNullableSuffixIfNeeded(returnInterface, false);
4141

4242
return $$"""
4343
using ({{establishConnection}})

0 commit comments

Comments
 (0)