Skip to content

Commit 538318f

Browse files
added rbs type support in generated code (#28)
1 parent bfb32fd commit 538318f

16 files changed

+654
-112
lines changed

CodeGenerator/CodeGenerator.cs

+51-28
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public Task<GenerateResponse> Generate(GenerateRequest generateRequest)
4646
DbDriver = InstantiateDriver();
4747
var fileQueries = GetFileQueries();
4848
var files = fileQueries
49-
.Select(fq => GenerateFile(fq.Value, fq.Key))
49+
.SelectMany(fq => GenerateFiles(fq.Value, fq.Key))
5050
.AppendIfNotNull(GenerateGemfile());
5151
return Task.FromResult(new GenerateResponse { Files = { files } });
5252

@@ -67,20 +67,38 @@ string QueryFilenameToClassName(string filenameWithExtension)
6767
}
6868
}
6969

70-
private File GenerateFile(IList<Query> queries, string className)
70+
private IEnumerable<File> GenerateFiles(IList<Query> queries, string className)
7171
{
7272
var (requiredGems, moduleDeclaration) = GenerateModule(queries, className);
73-
var contents = $"""
74-
{AutoGeneratedComment}
75-
{requiredGems.Select(r => r.Build()).JoinByNewLine()}
76-
77-
{moduleDeclaration.Build()}
78-
""";
79-
return new File
73+
IEnumerable<File> files = new List<File>
8074
{
81-
Name = $"{className.SnakeCase()}.rb",
82-
Contents = ByteString.CopyFromUtf8(contents)
75+
new()
76+
{
77+
Name = $"{className.SnakeCase()}.rb",
78+
Contents = ByteString.CopyFromUtf8(
79+
$"""
80+
{AutoGeneratedComment}
81+
{requiredGems.Select(r => r.Build()).JoinByNewLine()}
82+
83+
{moduleDeclaration.BuildCode()}
84+
"""
85+
)
86+
}
8387
};
88+
if (!Options.GenerateTypes)
89+
return files;
90+
91+
files = files.Append(new File
92+
{
93+
Name = $"{className.SnakeCase()}.rbs",
94+
Contents = ByteString.CopyFromUtf8(
95+
$"""
96+
{AutoGeneratedComment}
97+
{moduleDeclaration.BuildType()}
98+
"""
99+
)
100+
});
101+
return files;
84102
}
85103

86104
private File? GenerateGemfile()
@@ -91,15 +109,18 @@ private File GenerateFile(IList<Query> queries, string className)
91109
return new File
92110
{
93111
Name = "Gemfile",
94-
Contents = ByteString.CopyFromUtf8($"""
95-
source 'https://rubygems.org'
112+
Contents = ByteString.CopyFromUtf8(
113+
$"""
114+
source 'https://rubygems.org'
96115
97-
{requireGems}
98-
""")
116+
{requireGems}
117+
"""
118+
)
99119
};
100120
}
101121

102-
private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries, string className)
122+
private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries,
123+
string className)
103124
{
104125
var requiredGems = DbDriver.GetRequiredGems();
105126
var initMethod = DbDriver.GetInitMethod();
@@ -130,37 +151,39 @@ ClassDeclaration GetClassDeclaration()
130151
}
131152
}
132153

133-
private static SimpleStatement GenerateDataclass(string name, ClassMember classMember, IEnumerable<Column> columns,
154+
private IComposableRbsType GenerateDataclass(string funcName, ClassMember classMember, IList<Column> columns,
134155
Options options)
135156
{
136-
var dataclassName = $"{name.FirstCharToUpper()}{classMember.Name()}";
137-
var dataColumns = columns.Select(c => $":{c.Name.ToLower()}").ToList();
138-
var dataColumnsStr = dataColumns.JoinByCommaAndFormat();
139-
return new SimpleStatement(dataclassName,
140-
new SimpleExpression(options.RubyVersion.ImmutableDataSupported()
141-
? $"Data.define({dataColumnsStr})"
142-
: $"Struct.new({dataColumnsStr})"));
157+
var dataclassName = $"{funcName.FirstCharToUpper()}{classMember.Name()}";
158+
var nameToType = columns.ToDictionary(
159+
kv => kv.Name,
160+
kv => DbDriver.GetColumnType(kv)
161+
);
162+
return options.RubyVersion.ImmutableDataSupported()
163+
? new DataDefine(dataclassName, nameToType)
164+
: new NewStruct(dataclassName, nameToType);
143165
}
144166

145-
private SimpleStatement? GetQueryColumnsDataclass(Query query)
167+
private IComposableRbsType? GetQueryColumnsDataclass(Query query)
146168
{
147169
return query.Columns.Count <= 0
148170
? null
149171
: GenerateDataclass(query.Name, ClassMember.Row, query.Columns, Options);
150172
}
151173

152-
private SimpleStatement? GetQueryParamsDataclass(Query query)
174+
175+
private IComposableRbsType? GetQueryParamsDataclass(Query query)
153176
{
154177
if (query.Params.Count <= 0)
155178
return null;
156-
var columns = query.Params.Select(p => p.Column);
179+
var columns = query.Params.Select(p => p.Column).ToList();
157180
return GenerateDataclass(query.Name, ClassMember.Args, columns, Options);
158181
}
159182

160183
private MethodDeclaration GetMethodDeclaration(Query query)
161184
{
162185
var queryTextConstant = GetInterfaceName(ClassMember.Sql);
163-
var argInterface = GetInterfaceName(ClassMember.Args).SnakeCase();
186+
var argInterface = GetInterfaceName(ClassMember.Args);
164187
var returnInterface = GetInterfaceName(ClassMember.Row);
165188
var funcName = query.Name.SnakeCase();
166189

Drivers/DbDriver.cs

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Plugin;
22
using RubyCodegen;
3+
using System;
34
using System.Collections.Generic;
45

56
namespace SqlcGenRuby.Drivers;
@@ -15,7 +16,20 @@ protected static IEnumerable<RequireGem> GetCommonGems()
1516

1617
public abstract MethodDeclaration GetInitMethod();
1718

18-
public abstract SimpleStatement QueryTextConstantDeclare(Query query);
19+
protected abstract List<(string, HashSet<string>)> GetColumnMapping();
20+
21+
public string GetColumnType(Column column)
22+
{
23+
var columnType = column.Type.Name.ToLower();
24+
foreach (var (csharpType, dbTypes) in GetColumnMapping())
25+
{
26+
if (dbTypes.Contains(columnType))
27+
return csharpType;
28+
}
29+
throw new NotSupportedException($"Unsupported column type: {column.Type.Name}");
30+
}
31+
32+
public abstract PropertyDeclaration QueryTextConstantDeclare(Query query);
1933

2034
public abstract IComposable PrepareStmt(string funcName, string queryTextConstant);
2135

Drivers/MethodGen.cs

+19-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ public MethodDeclaration OneDeclare(string funcName, string queryTextConstant, s
2626
]
2727
).ToList();
2828

29-
return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
29+
return new MethodDeclaration(
30+
funcName,
31+
argInterface,
32+
GetMethodArgs(argInterface, parameters),
33+
returnInterface,
3034
new List<IComposable>
3135
{
3236
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList())
@@ -58,7 +62,11 @@ public MethodDeclaration ManyDeclare(string funcName, string queryTextConstant,
5862
]
5963
);
6064

61-
return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
65+
return new MethodDeclaration(
66+
funcName,
67+
argInterface,
68+
GetMethodArgs(argInterface, parameters),
69+
returnInterface,
6270
new List<IComposable>
6371
{
6472
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList())
@@ -76,7 +84,10 @@ public MethodDeclaration ExecDeclare(string funcName, string queryTextConstant,
7684
.Append(dbDriver.ExecuteStmt(funcName, queryParams))
7785
.ToList();
7886

79-
return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
87+
return new MethodDeclaration(funcName,
88+
argInterface,
89+
GetMethodArgs(argInterface, parameters),
90+
null,
8091
new List<IComposable>
8192
{
8293
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList()
@@ -100,7 +111,10 @@ public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextCons
100111
);
101112

102113
return new MethodDeclaration(
103-
funcName, GetMethodArgs(argInterface, parameters),
114+
funcName,
115+
argInterface,
116+
GetMethodArgs(argInterface, parameters),
117+
"Integer",
104118
new List<IComposable>
105119
{
106120
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(),
@@ -111,7 +125,7 @@ public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextCons
111125

112126
private static SimpleStatement? GetQueryParams(string argInterface, IList<Parameter> parameters)
113127
{
114-
var queryParams = parameters.Select(p => $"{argInterface}.{p.Column.Name}").ToList();
128+
var queryParams = parameters.Select(p => $"{argInterface.SnakeCase()}.{p.Column.Name}").ToList();
115129
return queryParams.Count == 0
116130
? null
117131
: new SimpleStatement(Variable.QueryParams.AsVar(),

Drivers/Mysql2Driver.cs

+53-5
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,65 @@ public override IEnumerable<RequireGem> GetRequiredGems()
2121

2222
public override MethodDeclaration GetInitMethod()
2323
{
24-
return new MethodDeclaration("initialize", "connection_pool_params, mysql2_params",
24+
var connectionPoolInit = new NewObject("ConnectionPool",
25+
new[] { new SimpleExpression("**connection_pool_params") },
26+
new List<IComposable> { new SimpleExpression("Mysql2::Client.new(**mysql2_params)") });
27+
return new MethodDeclaration(
28+
"initialize",
29+
"Hash[String, String], Hash[String, String]",
30+
"connection_pool_params, mysql2_params",
31+
null,
2532
[
26-
new SimpleStatement(Variable.Pool.AsProperty(), new SimpleExpression(
27-
"ConnectionPool::new(**connection_pool_params) { Mysql2::Client.new(**mysql2_params) }"))
33+
new PropertyDeclaration(Variable.Pool.AsProperty(), "untyped", connectionPoolInit)
2834
]
2935
);
3036
}
3137

32-
public override SimpleStatement QueryTextConstantDeclare(Query query)
38+
protected override List<(string, HashSet<string>)> GetColumnMapping()
3339
{
34-
return new SimpleStatement($"{query.Name}{ClassMember.Sql}", new SimpleExpression($"%q({query.Text})"));
40+
return
41+
[
42+
("Array[Integer]", [
43+
"binary",
44+
"bit",
45+
"blob",
46+
"longblob",
47+
"mediumblob",
48+
"tinyblob",
49+
"varbinary"
50+
]),
51+
("String", [
52+
"char",
53+
"date",
54+
"datetime",
55+
"decimal",
56+
"longtext",
57+
"mediumtext",
58+
"text",
59+
"time",
60+
"timestamp",
61+
"tinytext",
62+
"varchar",
63+
"json"
64+
]),
65+
("Integer", [
66+
"bigint",
67+
"int",
68+
"mediumint",
69+
"smallint",
70+
"tinyint",
71+
"year"
72+
]),
73+
("Float", ["double", "float"]),
74+
];
75+
}
76+
77+
public override PropertyDeclaration QueryTextConstantDeclare(Query query)
78+
{
79+
return new PropertyDeclaration(
80+
$"{query.Name}{ClassMember.Sql}",
81+
"String",
82+
new SimpleExpression($"%q({query.Text})"));
3583
}
3684

3785
public override IComposable PrepareStmt(string _, string queryTextConstant)

Drivers/PgDriver.cs

+63-10
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@ public override IEnumerable<RequireGem> GetRequiredGems()
2424

2525
public override MethodDeclaration GetInitMethod()
2626
{
27-
return new MethodDeclaration("initialize", "connection_pool_params, pg_params",
27+
var connectionPoolInit = new NewObject("ConnectionPool",
28+
new[] { new SimpleExpression("**connection_pool_params") },
29+
PgClientCreate());
30+
return new MethodDeclaration(
31+
"initialize",
32+
"Hash[String, String], Hash[String, String]",
33+
"connection_pool_params, pg_params",
34+
null,
2835
[
29-
new SimpleStatement(
30-
Variable.Pool.AsProperty(),
31-
new NewObject("ConnectionPool",
32-
new[] { new SimpleExpression("**connection_pool_params") },
33-
PgClientCreate())),
34-
new SimpleStatement(Variable.PreparedStatements.AsProperty(), new SimpleExpression("Set[]"))
36+
new PropertyDeclaration(Variable.Pool.AsProperty(), "untyped", connectionPoolInit),
37+
new PropertyDeclaration(Variable.PreparedStatements.AsProperty(), "Set[String]", new SimpleExpression("Set[]"))
3538
]
3639
);
3740

@@ -50,11 +53,61 @@ IList<IComposable> PgClientCreate()
5053
}
5154
}
5255

53-
public override SimpleStatement QueryTextConstantDeclare(Query query)
56+
protected override List<(string, HashSet<string>)> GetColumnMapping()
57+
{
58+
return
59+
[
60+
("bool", [
61+
"bool",
62+
"boolean"
63+
]),
64+
("Array[Integer]", [
65+
"binary",
66+
"bit",
67+
"bytea",
68+
"blob",
69+
"longblob",
70+
"mediumblob",
71+
"tinyblob",
72+
"varbinary"
73+
]),
74+
("String", [
75+
"char",
76+
"date",
77+
"datetime",
78+
"longtext",
79+
"mediumtext",
80+
"text",
81+
"bpchar",
82+
"time",
83+
"timestamp",
84+
"tinytext",
85+
"varchar",
86+
"json"
87+
]),
88+
("Integer", [
89+
"int2",
90+
"int4",
91+
"int8",
92+
"serial",
93+
"bigserial"
94+
]),
95+
("Float", [
96+
"numeric",
97+
"float4",
98+
"float8",
99+
"decimal"
100+
])
101+
];
102+
}
103+
104+
public override PropertyDeclaration QueryTextConstantDeclare(Query query)
54105
{
55106
var counter = 1;
56-
var transformedQueryText = BindRegexToReplace().Replace(query.Text, m => $"${counter++}");
57-
return new SimpleStatement($"{query.Name}{ClassMember.Sql}",
107+
var transformedQueryText = BindRegexToReplace().Replace(query.Text, _ => $"${counter++}");
108+
return new PropertyDeclaration(
109+
$"{query.Name}{ClassMember.Sql}",
110+
"String",
58111
new SimpleExpression($"%q({transformedQueryText})"));
59112
}
60113

Extensions/ListExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static string JoinByNewLine(this IEnumerable<string> me, int cnt = 1)
1717
public static string JoinByCommaAndFormat(this IList<string> me)
1818
{
1919
return me.Count < MaxElementsPerLine
20-
? string.Join(", ", me).Indent()
20+
? string.Join(", ", me)
2121
: $"\n{string.Join(",\n", me).Indent()}\n";
2222
}
2323
}

0 commit comments

Comments
 (0)