Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added rbs type support in generated code #28

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 51 additions & 28 deletions CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public Task<GenerateResponse> Generate(GenerateRequest generateRequest)
DbDriver = InstantiateDriver();
var fileQueries = GetFileQueries();
var files = fileQueries
.Select(fq => GenerateFile(fq.Value, fq.Key))
.SelectMany(fq => GenerateFiles(fq.Value, fq.Key))
.AppendIfNotNull(GenerateGemfile());
return Task.FromResult(new GenerateResponse { Files = { files } });

Expand All @@ -67,20 +67,38 @@ string QueryFilenameToClassName(string filenameWithExtension)
}
}

private File GenerateFile(IList<Query> queries, string className)
private IEnumerable<File> GenerateFiles(IList<Query> queries, string className)
{
var (requiredGems, moduleDeclaration) = GenerateModule(queries, className);
var contents = $"""
{AutoGeneratedComment}
{requiredGems.Select(r => r.Build()).JoinByNewLine()}

{moduleDeclaration.Build()}
""";
return new File
IEnumerable<File> files = new List<File>
{
Name = $"{className.SnakeCase()}.rb",
Contents = ByteString.CopyFromUtf8(contents)
new()
{
Name = $"{className.SnakeCase()}.rb",
Contents = ByteString.CopyFromUtf8(
$"""
{AutoGeneratedComment}
{requiredGems.Select(r => r.Build()).JoinByNewLine()}

{moduleDeclaration.BuildCode()}
"""
)
}
};
if (!Options.GenerateTypes)
return files;

files = files.Append(new File
{
Name = $"{className.SnakeCase()}.rbs",
Contents = ByteString.CopyFromUtf8(
$"""
{AutoGeneratedComment}
{moduleDeclaration.BuildType()}
"""
)
});
return files;
}

private File? GenerateGemfile()
Expand All @@ -91,15 +109,18 @@ private File GenerateFile(IList<Query> queries, string className)
return new File
{
Name = "Gemfile",
Contents = ByteString.CopyFromUtf8($"""
source 'https://rubygems.org'
Contents = ByteString.CopyFromUtf8(
$"""
source 'https://rubygems.org'

{requireGems}
""")
{requireGems}
"""
)
};
}

private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries, string className)
private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries,
string className)
{
var requiredGems = DbDriver.GetRequiredGems();
var initMethod = DbDriver.GetInitMethod();
Expand Down Expand Up @@ -130,37 +151,39 @@ ClassDeclaration GetClassDeclaration()
}
}

private static SimpleStatement GenerateDataclass(string name, ClassMember classMember, IEnumerable<Column> columns,
private IComposableRbsType GenerateDataclass(string funcName, ClassMember classMember, IList<Column> columns,
Options options)
{
var dataclassName = $"{name.FirstCharToUpper()}{classMember.Name()}";
var dataColumns = columns.Select(c => $":{c.Name.ToLower()}").ToList();
var dataColumnsStr = dataColumns.JoinByCommaAndFormat();
return new SimpleStatement(dataclassName,
new SimpleExpression(options.RubyVersion.ImmutableDataSupported()
? $"Data.define({dataColumnsStr})"
: $"Struct.new({dataColumnsStr})"));
var dataclassName = $"{funcName.FirstCharToUpper()}{classMember.Name()}";
var nameToType = columns.ToDictionary(
kv => kv.Name,
kv => DbDriver.GetColumnType(kv)
);
return options.RubyVersion.ImmutableDataSupported()
? new DataDefine(dataclassName, nameToType)
: new NewStruct(dataclassName, nameToType);
}

private SimpleStatement? GetQueryColumnsDataclass(Query query)
private IComposableRbsType? GetQueryColumnsDataclass(Query query)
{
return query.Columns.Count <= 0
? null
: GenerateDataclass(query.Name, ClassMember.Row, query.Columns, Options);
}

private SimpleStatement? GetQueryParamsDataclass(Query query)

private IComposableRbsType? GetQueryParamsDataclass(Query query)
{
if (query.Params.Count <= 0)
return null;
var columns = query.Params.Select(p => p.Column);
var columns = query.Params.Select(p => p.Column).ToList();
return GenerateDataclass(query.Name, ClassMember.Args, columns, Options);
}

private MethodDeclaration GetMethodDeclaration(Query query)
{
var queryTextConstant = GetInterfaceName(ClassMember.Sql);
var argInterface = GetInterfaceName(ClassMember.Args).SnakeCase();
var argInterface = GetInterfaceName(ClassMember.Args);
var returnInterface = GetInterfaceName(ClassMember.Row);
var funcName = query.Name.SnakeCase();

Expand Down
16 changes: 15 additions & 1 deletion Drivers/DbDriver.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Plugin;
using RubyCodegen;
using System;
using System.Collections.Generic;

namespace SqlcGenRuby.Drivers;
Expand All @@ -15,7 +16,20 @@ protected static IEnumerable<RequireGem> GetCommonGems()

public abstract MethodDeclaration GetInitMethod();

public abstract SimpleStatement QueryTextConstantDeclare(Query query);
protected abstract List<(string, HashSet<string>)> GetColumnMapping();

public string GetColumnType(Column column)
{
var columnType = column.Type.Name.ToLower();
foreach (var (csharpType, dbTypes) in GetColumnMapping())
{
if (dbTypes.Contains(columnType))
return csharpType;
}
throw new NotSupportedException($"Unsupported column type: {column.Type.Name}");
}

public abstract PropertyDeclaration QueryTextConstantDeclare(Query query);

public abstract IComposable PrepareStmt(string funcName, string queryTextConstant);

Expand Down
24 changes: 19 additions & 5 deletions Drivers/MethodGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ public MethodDeclaration OneDeclare(string funcName, string queryTextConstant, s
]
).ToList();

return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
return new MethodDeclaration(
funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
returnInterface,
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList())
Expand Down Expand Up @@ -58,7 +62,11 @@ public MethodDeclaration ManyDeclare(string funcName, string queryTextConstant,
]
);

return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
return new MethodDeclaration(
funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
returnInterface,
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList())
Expand All @@ -76,7 +84,10 @@ public MethodDeclaration ExecDeclare(string funcName, string queryTextConstant,
.Append(dbDriver.ExecuteStmt(funcName, queryParams))
.ToList();

return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
return new MethodDeclaration(funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
null,
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList()
Expand All @@ -100,7 +111,10 @@ public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextCons
);

return new MethodDeclaration(
funcName, GetMethodArgs(argInterface, parameters),
funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
"Integer",
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(),
Expand All @@ -111,7 +125,7 @@ public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextCons

private static SimpleStatement? GetQueryParams(string argInterface, IList<Parameter> parameters)
{
var queryParams = parameters.Select(p => $"{argInterface}.{p.Column.Name}").ToList();
var queryParams = parameters.Select(p => $"{argInterface.SnakeCase()}.{p.Column.Name}").ToList();
return queryParams.Count == 0
? null
: new SimpleStatement(Variable.QueryParams.AsVar(),
Expand Down
58 changes: 53 additions & 5 deletions Drivers/Mysql2Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,65 @@ public override IEnumerable<RequireGem> GetRequiredGems()

public override MethodDeclaration GetInitMethod()
{
return new MethodDeclaration("initialize", "connection_pool_params, mysql2_params",
var connectionPoolInit = new NewObject("ConnectionPool",
new[] { new SimpleExpression("**connection_pool_params") },
new List<IComposable> { new SimpleExpression("Mysql2::Client.new(**mysql2_params)") });
return new MethodDeclaration(
"initialize",
"Hash[String, String], Hash[String, String]",
"connection_pool_params, mysql2_params",
null,
[
new SimpleStatement(Variable.Pool.AsProperty(), new SimpleExpression(
"ConnectionPool::new(**connection_pool_params) { Mysql2::Client.new(**mysql2_params) }"))
new PropertyDeclaration(Variable.Pool.AsProperty(), "untyped", connectionPoolInit)
]
);
}

public override SimpleStatement QueryTextConstantDeclare(Query query)
protected override List<(string, HashSet<string>)> GetColumnMapping()
{
return new SimpleStatement($"{query.Name}{ClassMember.Sql}", new SimpleExpression($"%q({query.Text})"));
return
[
("Array[Integer]", [
"binary",
"bit",
"blob",
"longblob",
"mediumblob",
"tinyblob",
"varbinary"
]),
("String", [
"char",
"date",
"datetime",
"decimal",
"longtext",
"mediumtext",
"text",
"time",
"timestamp",
"tinytext",
"varchar",
"json"
]),
("Integer", [
"bigint",
"int",
"mediumint",
"smallint",
"tinyint",
"year"
]),
("Float", ["double", "float"]),
];
}

public override PropertyDeclaration QueryTextConstantDeclare(Query query)
{
return new PropertyDeclaration(
$"{query.Name}{ClassMember.Sql}",
"String",
new SimpleExpression($"%q({query.Text})"));
}

public override IComposable PrepareStmt(string _, string queryTextConstant)
Expand Down
73 changes: 63 additions & 10 deletions Drivers/PgDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ public override IEnumerable<RequireGem> GetRequiredGems()

public override MethodDeclaration GetInitMethod()
{
return new MethodDeclaration("initialize", "connection_pool_params, pg_params",
var connectionPoolInit = new NewObject("ConnectionPool",
new[] { new SimpleExpression("**connection_pool_params") },
PgClientCreate());
return new MethodDeclaration(
"initialize",
"Hash[String, String], Hash[String, String]",
"connection_pool_params, pg_params",
null,
[
new SimpleStatement(
Variable.Pool.AsProperty(),
new NewObject("ConnectionPool",
new[] { new SimpleExpression("**connection_pool_params") },
PgClientCreate())),
new SimpleStatement(Variable.PreparedStatements.AsProperty(), new SimpleExpression("Set[]"))
new PropertyDeclaration(Variable.Pool.AsProperty(), "untyped", connectionPoolInit),
new PropertyDeclaration(Variable.PreparedStatements.AsProperty(), "Set[String]", new SimpleExpression("Set[]"))
]
);

Expand All @@ -50,11 +53,61 @@ IList<IComposable> PgClientCreate()
}
}

public override SimpleStatement QueryTextConstantDeclare(Query query)
protected override List<(string, HashSet<string>)> GetColumnMapping()
{
return
[
("bool", [
"bool",
"boolean"
]),
("Array[Integer]", [
"binary",
"bit",
"bytea",
"blob",
"longblob",
"mediumblob",
"tinyblob",
"varbinary"
]),
("String", [
"char",
"date",
"datetime",
"longtext",
"mediumtext",
"text",
"bpchar",
"time",
"timestamp",
"tinytext",
"varchar",
"json"
]),
("Integer", [
"int2",
"int4",
"int8",
"serial",
"bigserial"
]),
("Float", [
"numeric",
"float4",
"float8",
"decimal"
])
];
}

public override PropertyDeclaration QueryTextConstantDeclare(Query query)
{
var counter = 1;
var transformedQueryText = BindRegexToReplace().Replace(query.Text, m => $"${counter++}");
return new SimpleStatement($"{query.Name}{ClassMember.Sql}",
var transformedQueryText = BindRegexToReplace().Replace(query.Text, _ => $"${counter++}");
return new PropertyDeclaration(
$"{query.Name}{ClassMember.Sql}",
"String",
new SimpleExpression($"%q({transformedQueryText})"));
}

Expand Down
2 changes: 1 addition & 1 deletion Extensions/ListExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static string JoinByNewLine(this IEnumerable<string> me, int cnt = 1)
public static string JoinByCommaAndFormat(this IList<string> me)
{
return me.Count < MaxElementsPerLine
? string.Join(", ", me).Indent()
? string.Join(", ", me)
: $"\n{string.Join(",\n", me).Indent()}\n";
}
}
Loading
Loading