Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package ru.tinkoff.kora.database.annotation.processor;

import jakarta.annotation.Nullable;
import ru.tinkoff.kora.annotation.processor.common.ProcessingErrorException;

import jakarta.annotation.Nullable;
import javax.annotation.processing.Filer;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.type.DeclaredType;
Expand All @@ -14,10 +14,14 @@
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.regex.Pattern;

public record QueryWithParameters(String rawQuery, List<QueryParameter> parameters) {

public record QueryParameter(String sqlParameterName, int methodIndex, List<Integer> sqlIndexes) {}
public record QueryParameter(String sqlParameterName, int methodIndex, List<QueryIndex> queryIndexes, List<Integer> sqlIndexes) {

public record QueryIndex(int start, int end) { }
}

@Nullable
public QueryParameter find(String name) {
Expand Down Expand Up @@ -104,7 +108,7 @@ public static QueryWithParameters parse(Filer filer,
.toList();

params = params.stream()
.map(p -> new QueryParameter(p.sqlParameterName(), p.methodIndex(), p.sqlIndexes()
.map(p -> new QueryParameter(p.sqlParameterName(), p.methodIndex(), p.queryIndexes(), p.sqlIndexes()
.stream()
.map(paramsNumbers::indexOf)
.toList()
Expand All @@ -116,40 +120,38 @@ public static QueryWithParameters parse(Filer filer,


private static Optional<QueryParameter> parseSimpleParameter(String rawSql, int methodParameterNumber, String sqlParameterName) {
int index = -1;
var result = new ArrayList<Integer>();
while ((index = rawSql.indexOf(":" + sqlParameterName, index + 1)) >= 0) {
var indexAfter = index + sqlParameterName.length() + 1;
if (rawSql.length() >= indexAfter + 1) {
var charAfter = rawSql.charAt(indexAfter);
if (Character.isAlphabetic(charAfter) || charAfter == '_' || charAfter == '$' || Character.isDigit(charAfter)) {
continue;
}
}
result.add(index);
var result = new ArrayList<QueryParameter.QueryIndex>();
var pattern = Pattern.compile("[\\s\\n,(](?<param>:" + sqlParameterName + ")(?=[\\s\\n,:)]|$)");
var matcher = pattern.matcher(rawSql);
while (matcher.find()) {
var mr = matcher.toMatchResult();
var start = mr.start(1);
var end = mr.end();
result.add(new QueryParameter.QueryIndex(start, end));
}

return (result.isEmpty())
? Optional.empty()
: Optional.of(new QueryParameter(sqlParameterName, methodParameterNumber, result));
: Optional.of(new QueryParameter(sqlParameterName, methodParameterNumber, result, result.stream()
.map(QueryParameter.QueryIndex::start)
.toList()));
}

private static Optional<QueryParameter> parseEntityDirectParameter(String rawSql, int methodParameterNumber, String sqlParameterName) {
int index = -1;
var result = new ArrayList<Integer>();
while ((index = rawSql.indexOf(":" + sqlParameterName, index + 1)) >= 0) {
var indexAfter = index + sqlParameterName.length() + 1;
if (rawSql.length() >= indexAfter + 1) {
var charAfter = rawSql.charAt(indexAfter);
if ('.' == charAfter || Character.isAlphabetic(charAfter) || charAfter == '_' || charAfter == '$' || Character.isDigit(charAfter)) {
continue;
}
}
result.add(index);
var result = new ArrayList<QueryParameter.QueryIndex>();
var pattern = Pattern.compile("[\\s\\n,(](?<param>:" + sqlParameterName + ")(?=[\\s\\n,:)]|$)");
var matcher = pattern.matcher(rawSql);
while (matcher.find()) {
var mr = matcher.toMatchResult();
var start = mr.start(1);
var end = mr.end();
result.add(new QueryParameter.QueryIndex(start, end));
}

return (result.isEmpty())
? Optional.empty()
: Optional.of(new QueryParameter(sqlParameterName, methodParameterNumber, result));
: Optional.of(new QueryParameter(sqlParameterName, methodParameterNumber, result, result.stream()
.map(QueryParameter.QueryIndex::start)
.toList()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ru.tinkoff.kora.database.annotation.processor.RepositoryGenerator;
import ru.tinkoff.kora.database.annotation.processor.model.QueryParameter;
import ru.tinkoff.kora.database.annotation.processor.model.QueryParameterParser;
import ru.tinkoff.kora.database.annotation.processor.vertx.VertxRepositoryGenerator;

import javax.annotation.processing.Filer;
import javax.annotation.processing.ProcessingEnvironment;
Expand Down Expand Up @@ -69,11 +70,21 @@ public TypeSpec generate(TypeElement repositoryElement, TypeSpec.Builder type, M
return type.addMethod(constructor.build()).build();
}

private record QueryReplace(int start, int end, String name) {}

private MethodSpec generate(TypeSpec.Builder type, int methodNumber, ExecutableElement method, ExecutableType methodType, QueryWithParameters query, List<QueryParameter> parameters, @Nullable String resultMapperName, FieldFactory parameterMappers) {
var sql = query.rawQuery();
for (var parameter : query.parameters().stream().sorted(Comparator.<QueryWithParameters.QueryParameter, Integer>comparing(p -> p.sqlParameterName().length()).reversed()).toList()) {
sql = sql.replace(":" + parameter.sqlParameterName(), "?");
List<QueryReplace> replaceParams = query.parameters().stream()
.flatMap(p -> p.queryIndexes().stream().map(i -> new QueryReplace(i.start(), i.end(), p.sqlParameterName())))
.sorted(Comparator.comparingInt(QueryReplace::start))
.toList();
int sqlIndexDiff = 0;
for (var parameter : replaceParams) {
int queryIndexAdjusted = parameter.start() - sqlIndexDiff;
sql = sql.substring(0, queryIndexAdjusted) + "?" + sql.substring(queryIndexAdjusted + parameter.name().length() + 1);
sqlIndexDiff += parameter.name().length();
}

var b = DbUtils.queryMethodBuilder(method, methodType);

var queryContextFieldName = "QUERY_CONTEXT_" + methodNumber;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,20 @@ public ClassName repositoryInterface() {
return JdbcTypes.JDBC_REPOSITORY;
}

private record QueryReplace(int start, int end, String name) {}

public MethodSpec generate(TypeElement repositoryElement, TypeSpec.Builder type, int methodNumber, ExecutableElement method, ExecutableType methodType, QueryWithParameters query, List<QueryParameter> parameters, @Nullable String resultMapperName, FieldFactory parameterMappers) {
var batchParam = parameters.stream().filter(QueryParameter.BatchParameter.class::isInstance).findFirst().orElse(null);
var sql = query.rawQuery();
for (var parameter : query.parameters().stream().sorted(Comparator.<QueryWithParameters.QueryParameter>comparingInt(s -> s.sqlParameterName().length()).reversed()).toList()) {
sql = sql.replace(":" + parameter.sqlParameterName(), "?");
List<QueryReplace> replaceParams = query.parameters().stream()
.flatMap(p -> p.queryIndexes().stream().map(i -> new QueryReplace(i.start(), i.end(), p.sqlParameterName())))
.sorted(Comparator.comparingInt(QueryReplace::start))
.toList();
int sqlIndexDiff = 0;
for (var parameter : replaceParams) {
int queryIndexAdjusted = parameter.start() - sqlIndexDiff;
sql = sql.substring(0, queryIndexAdjusted) + "?" + sql.substring(queryIndexAdjusted + parameter.name().length() + 1);
sqlIndexDiff += parameter.name().length();
}

var b = DbUtils.queryMethodBuilder(method, methodType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ru.tinkoff.kora.database.annotation.processor.DbUtils;
import ru.tinkoff.kora.database.annotation.processor.QueryWithParameters;
import ru.tinkoff.kora.database.annotation.processor.RepositoryGenerator;
import ru.tinkoff.kora.database.annotation.processor.jdbc.JdbcRepositoryGenerator;
import ru.tinkoff.kora.database.annotation.processor.model.QueryParameter;
import ru.tinkoff.kora.database.annotation.processor.model.QueryParameterParser;

Expand All @@ -20,10 +21,7 @@
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import javax.lang.model.util.Types;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.*;

public final class R2dbcRepositoryGenerator implements RepositoryGenerator {
private final Types types;
Expand Down Expand Up @@ -72,14 +70,32 @@ public TypeSpec generate(TypeElement repositoryElement, TypeSpec.Builder type, M
return type.addMethod(constructor.build()).build();
}

private record QueryReplace(int start, int end, int sqlIndex, String name) {}

private MethodSpec generate(TypeSpec.Builder type, int methodNumber, ExecutableElement method, ExecutableType methodType, QueryWithParameters query, List<QueryParameter> parameters, @Nullable String resultMapperName, FieldFactory parameterMappers) {
final boolean generatedKeys = AnnotationUtils.isAnnotationPresent(method, DbUtils.ID_ANNOTATION);

var sql = query.rawQuery();
for (var parameter : query.parameters().stream().sorted(Comparator.<QueryWithParameters.QueryParameter>comparingInt(s -> s.sqlParameterName().length()).reversed()).toList()) {
for (var sqlIndex : parameter.sqlIndexes()) {
sql = sql.replace(":" + parameter.sqlParameterName(), "$" + (sqlIndex + 1));
}
List<QueryReplace> replaceParams = query.parameters().stream()
.flatMap(p -> {
List<QueryReplace> replaces = new ArrayList<>();
for (int i = 0; i < p.queryIndexes().size(); i++) {
var queryIndex = p.queryIndexes().get(i);
var sqlIndex = p.sqlIndexes().get(i);
replaces.add(new QueryReplace(queryIndex.start(), queryIndex.end(), sqlIndex, p.sqlParameterName()));
}
return replaces.stream();
})
.sorted(Comparator.comparingInt(QueryReplace::start))
.toList();
int sqlIndexDiff = 0;
for (var parameter : replaceParams) {
int queryIndexAdjusted = parameter.start() - sqlIndexDiff;
int index = parameter.sqlIndex() + 1;
sql = sql.substring(0, queryIndexAdjusted) + "$" + index + sql.substring(queryIndexAdjusted + parameter.name().length() + 1);
sqlIndexDiff += (parameter.name().length() - String.valueOf(index).length());
}

var connectionParameter = parameters.stream().filter(QueryParameter.ConnectionParameter.class::isInstance).findFirst().orElse(null);

var b = DbUtils.queryMethodBuilder(method, methodType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ru.tinkoff.kora.database.annotation.processor.RepositoryGenerator;
import ru.tinkoff.kora.database.annotation.processor.model.QueryParameter;
import ru.tinkoff.kora.database.annotation.processor.model.QueryParameterParser;
import ru.tinkoff.kora.database.annotation.processor.r2dbc.R2dbcRepositoryGenerator;

import javax.annotation.processing.Filer;
import javax.annotation.processing.ProcessingEnvironment;
Expand Down Expand Up @@ -69,18 +70,26 @@ public TypeSpec generate(TypeElement repositoryElement, TypeSpec.Builder type, M
return type.addMethod(constructor.build()).build();
}

private record QueryReplace(int start, int end, int paramIndex, String name) {}

private MethodSpec generate(TypeSpec.Builder type, int methodNumber, ExecutableElement method, ExecutableType methodType, QueryWithParameters query, List<QueryParameter> parameters, @Nullable String resultMapperName, FieldFactory parameterMappers) {
var sql = query.rawQuery();
{
var params = new ArrayList<Map.Entry<QueryWithParameters.QueryParameter, Integer>>(query.parameters().size());
for (int i = 0; i < query.parameters().size(); i++) {
var parameter = query.parameters().get(i);
params.add(Map.entry(parameter, i));
}
for (var parameter : params.stream().sorted(Comparator.<Map.Entry<QueryWithParameters.QueryParameter, Integer>, Integer>comparing(p -> p.getKey().sqlParameterName().length()).reversed()).toList()) {
sql = sql.replace(":" + parameter.getKey().sqlParameterName(), "$" + (parameter.getValue() + 1));

List<QueryReplace> replaceParams = new ArrayList<>();
for (int i = 0; i < query.parameters().size(); i++) {
var parameter = query.parameters().get(i);
for (var queryIndex : parameter.queryIndexes()) {
replaceParams.add(new QueryReplace(queryIndex.start(), queryIndex.end(), i + 1, parameter.sqlParameterName()));
}
}
replaceParams.sort(Comparator.comparingInt(QueryReplace::start));
int sqlIndexDiff = 0;
for (var parameter : replaceParams) {
int queryIndexAdjusted = parameter.start() - sqlIndexDiff;
sql = sql.substring(0, queryIndexAdjusted) + "$" + parameter.paramIndex() + sql.substring(queryIndexAdjusted + parameter.name().length() + 1);
sqlIndexDiff += (parameter.name().length() - String.valueOf(parameter.paramIndex()).length());
}

var b = DbUtils.queryMethodBuilder(method, methodType);

var queryContextFieldName = "QUERY_CONTEXT_" + methodNumber;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,4 +470,34 @@ public record TestRecord(String value){}
assertThat(tag).isNotNull();
assertThat(tag.value()).isEqualTo(new Class<?>[]{compileResult.loadClass("TestRepository")});
}

@Test
public void testSamePrefixParameterNameMapping() {
var repository = compileCassandra(List.of(), """
@Repository
public interface TestRepository extends CassandraRepository {
@Query("SELECT * FROM test WHERE user_status = 'CREATED'::status_type AND status = :status")
void test(String status);
}
""");

repository.invoke("test", "someStatus");

verify(executor.mockSession).prepare("SELECT * FROM test WHERE user_status = 'CREATED'::status_type AND status = ?");
}

@Test
public void testSamePrefixMultiParameterNameMapping() {
var repository = compileCassandra(List.of(), """
@Repository
public interface TestRepository extends CassandraRepository {
@Query("SELECT * FROM test WHERE some_status = :status AND user_status = 'CREATED'::status_type AND diff_status = :statusDiff AND other_status = :status AND status = :status")
void test(String status, String statusDiff);
}
""");

repository.invoke("test", "someStatus", "otherStatus");

verify(executor.mockSession).prepare("SELECT * FROM test WHERE some_status = ? AND user_status = 'CREATED'::status_type AND diff_status = ? AND other_status = ? AND status = ?");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import ru.tinkoff.kora.database.jdbc.JdbcConnectionFactory;
import ru.tinkoff.kora.database.jdbc.JdbcDatabaseConfig;
import ru.tinkoff.kora.database.jdbc.mapper.parameter.JdbcParameterColumnMapper;
import ru.tinkoff.kora.database.jdbc.mapper.result.JdbcResultSetMapper;
import ru.tinkoff.kora.json.common.annotation.Json;

import java.math.BigDecimal;
Expand Down Expand Up @@ -162,7 +163,7 @@ public TestRepository(JdbcDatabaseConfig config) {
}

@Test
public void testNativeParameter() throws SQLException {
public void testNativeParameter() {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {
Expand Down Expand Up @@ -230,7 +231,7 @@ public interface TestRepository extends JdbcRepository {
}

@Test
public void testEntityFieldMappingByTag() throws SQLException, ClassNotFoundException {
public void testEntityFieldMappingByTag() throws SQLException {
var mapper = Mockito.mock(JdbcParameterColumnMapper.class);
var repository = compileJdbc(List.of(mapper), """
public record SomeEntity(long id, @ru.tinkoff.kora.json.common.annotation.Json String value) {}
Expand Down Expand Up @@ -341,7 +342,7 @@ public interface TestRepository extends JdbcRepository {
}

@Test
public void testMultipleParametersWithSameMapper() throws SQLException {
public void testMultipleParametersWithSameMapper() {
var repository = compileJdbc(List.of(newGeneratedObject("TestMapper")), """
public class TestMapper implements JdbcParameterColumnMapper<String> {
@Override
Expand All @@ -361,7 +362,7 @@ public interface TestRepository extends JdbcRepository {
}

@Test
public void testMultipleParameterFieldsWithSameMapper() throws SQLException {
public void testMultipleParameterFieldsWithSameMapper() {
var repository = compileJdbc(List.of(newGeneratedObject("TestMapper")), """
public class TestMapper implements JdbcParameterColumnMapper<TestRecord> {
@Override
Expand All @@ -385,7 +386,7 @@ public record TestRecord(@Mapping(TestMapper.class) TestRecord f1, @Mapping(Test
}

@Test
public void testParameterMappingByTag() throws ClassNotFoundException, SQLException {
public void testParameterMappingByTag() throws SQLException {
var mapper = Mockito.mock(JdbcParameterColumnMapper.class);
var repository = compileJdbc(List.of(mapper), """
@Repository
Expand All @@ -407,7 +408,7 @@ public interface TestRepository extends JdbcRepository {
}

@Test
public void testRecordParameterMapping() throws ClassNotFoundException, SQLException {
public void testRecordParameterMapping() {
var mapper = Mockito.mock(JdbcParameterColumnMapper.class);
var repository = compileJdbc(List.of(mapper), """
@Repository
Expand All @@ -426,4 +427,35 @@ public record TestRecord(String value){}
assertThat(tag.value()).isEqualTo(new Class<?>[]{compileResult.loadClass("TestRepository")});
}

@Test
public void testSamePrefixParameterNameMapping() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {
@Query("SELECT * FROM test WHERE user_status = 'CREATED'::status_type AND status = :status")
void test(String status);
}
""");

repository.invoke("test", "someStatus");

verify(executor.mockConnection).prepareStatement("SELECT * FROM test WHERE user_status = 'CREATED'::status_type AND status = ?");
verify(executor.preparedStatement).execute();
}

@Test
public void testSamePrefixMultiParameterNameMapping() throws SQLException {
var repository = compileJdbc(List.of(), """
@Repository
public interface TestRepository extends JdbcRepository {
@Query("SELECT * FROM test WHERE some_status = :status AND user_status = 'CREATED'::status_type AND diff_status = :statusDiff AND other_status = :status AND status = :status")
void test(String status, String statusDiff);
}
""");

repository.invoke("test", "someStatus", "otherStatus");

verify(executor.mockConnection).prepareStatement("SELECT * FROM test WHERE some_status = ? AND user_status = 'CREATED'::status_type AND diff_status = ? AND other_status = ? AND status = ?");
verify(executor.preparedStatement).execute();
}
}
Loading