@@ -9,14 +9,30 @@ namespace SqlcGenCsharp.Drivers;
9
9
10
10
public record ConnectionGenCommands ( string EstablishConnection , string ConnectionOpen ) ;
11
11
12
- public abstract class DbDriver ( Options options , Dictionary < string , Table > tables )
12
+ public abstract class DbDriver
13
13
{
14
- public Options Options { get ; } = options ;
14
+ public Options Options { get ; }
15
15
16
- public Dictionary < string , Table > Tables { get ; } = tables ;
16
+ public Dictionary < string , Table > Tables { get ; }
17
+
18
+ private HashSet < string > NullableTypesInDotnetCore { get ; } = [ "string" ] ;
19
+
20
+ private HashSet < string > NullableTypes { get ; } = [ "long" , "double" , "int" , "float" , "bool" , "DateTime" ] ;
17
21
18
22
protected abstract List < ColumnMapping > ColumnMappings { get ; }
19
23
24
+ protected DbDriver ( Options options , Dictionary < string , Table > tables )
25
+ {
26
+ Options = options ;
27
+ Tables = tables ;
28
+
29
+ if ( ! Options . DotnetFramework . IsDotnetCore ( ) ) return ; // `string?` is possible only in .Net Core
30
+ foreach ( var t in NullableTypesInDotnetCore )
31
+ {
32
+ NullableTypes . Add ( t ) ;
33
+ }
34
+ }
35
+
20
36
public virtual UsingDirectiveSyntax [ ] GetUsingDirectives ( )
21
37
{
22
38
var usingDirectives = new List < UsingDirectiveSyntax >
@@ -31,11 +47,11 @@ public virtual UsingDirectiveSyntax[] GetUsingDirectives()
31
47
return usingDirectives . ToArray ( ) ;
32
48
}
33
49
34
- public string AddNullableSuffix ( string csharpType , bool notNull )
50
+ public string AddNullableSuffixIfNeeded ( string csharpType , bool notNull )
35
51
{
36
52
if ( notNull ) return csharpType ;
37
- if ( IsTypeNullableForAllRuntimes ( csharpType ) ) return $ "{ csharpType } ?";
38
- return Options . DotnetFramework . LatestDotnetSupported ( ) ? $ "{ csharpType } ?" : csharpType ;
53
+ if ( IsTypeNullable ( csharpType ) ) return $ "{ csharpType } ?";
54
+ return Options . DotnetFramework . IsDotnetCore ( ) ? $ "{ csharpType } ?" : csharpType ;
39
55
}
40
56
41
57
public string GetCsharpType ( Column column )
@@ -44,7 +60,7 @@ public string GetCsharpType(Column column)
44
60
return column . EmbedTable . Name . ToModelName ( ) ;
45
61
46
62
var columnCsharpType = string . IsNullOrEmpty ( column . Type . Name ) ? "object" : GetTypeWithoutNullableSuffix ( ) ;
47
- return AddNullableSuffix ( columnCsharpType , column . NotNull ) ;
63
+ return AddNullableSuffixIfNeeded ( columnCsharpType , column . NotNull ) ;
48
64
49
65
string GetTypeWithoutNullableSuffix ( )
50
66
{
@@ -89,31 +105,30 @@ public string GetColumnReader(Column column, int ordinal)
89
105
90
106
public abstract string CreateSqlCommand ( string sqlTextConstant ) ;
91
107
92
- private HashSet < string > NullableTypesInAllRuntimes { get ; } = [ "long" , "double" , "int" , "float" , "bool" , "DateTime" ] ;
93
-
94
108
// TODO move out from driver + rename
95
- public bool IsTypeNullableForAllRuntimes ( string csharpType )
109
+ public bool IsTypeNullable ( string csharpType )
96
110
{
97
- return NullableTypesInAllRuntimes . Contains ( csharpType . Replace ( "?" , "" ) ) ;
111
+ return NullableTypes . Contains ( csharpType . Replace ( "?" , "" ) ) ;
98
112
}
99
113
100
- protected static string GetConnectionStringField ( )
114
+ /*
115
+ Since there is no indication of the primary key column in SQLC protobuf (assuming it is a single column even),
116
+ this method uses a few heuristics to assess the type of the id column
117
+ */
118
+ public string GetIdColumnType ( Query query )
101
119
{
102
- return Variable . ConnectionString . AsPropertyName ( ) ;
103
- }
120
+ var tableColumns = Tables [ query . InsertIntoTable . Name ] . Columns ;
121
+ var idColumn = tableColumns . First ( c => c . Name . Equals ( "id" , StringComparison . OrdinalIgnoreCase ) ) ;
122
+ if ( idColumn is not null )
123
+ return GetCsharpType ( idColumn ) ;
104
124
105
- public string GetIdColumnType ( )
106
- {
107
- return Options . DriverName switch
108
- {
109
- DriverName . Sqlite => "int" ,
110
- _ => "long"
111
- } ;
125
+ idColumn = tableColumns . First ( c => c . Name . Contains ( "id" , StringComparison . CurrentCultureIgnoreCase ) ) ;
126
+ return GetCsharpType ( idColumn ?? tableColumns [ 0 ] ) ;
112
127
}
113
128
114
- public virtual string [ ] GetLastIdStatement ( )
129
+ public virtual string [ ] GetLastIdStatement ( Query query )
115
130
{
116
- var convertFunc = GetIdColumnType ( ) == "int" ? "ToInt32" : "ToInt64" ;
131
+ var convertFunc = GetIdColumnType ( query ) == "int" ? "ToInt32" : "ToInt64" ; // TODO refactor
117
132
return
118
133
[
119
134
$ "var { Variable . Result . AsVarName ( ) } = await { Variable . Command . AsVarName ( ) } .ExecuteScalarAsync();",
0 commit comments