27
27
import org .springframework .dao .OptimisticLockingFailureException ;
28
28
import org .springframework .data .domain .Pageable ;
29
29
import org .springframework .data .domain .Sort ;
30
+ import org .springframework .data .jdbc .repository .QueryMappingConfiguration ;
30
31
import org .springframework .data .mapping .PersistentPropertyPath ;
31
32
import org .springframework .data .relational .core .conversion .IdValueSource ;
32
33
import org .springframework .data .relational .core .mapping .AggregatePath ;
60
61
* @author Radim Tlusty
61
62
* @author Chirag Tailor
62
63
* @author Diego Krupitza
64
+ * @author Mikhail Polivakha
63
65
* @since 1.1
64
66
*/
65
67
public class DefaultDataAccessStrategy implements DataAccessStrategy {
@@ -71,6 +73,8 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
71
73
private final SqlParametersFactory sqlParametersFactory ;
72
74
private final InsertStrategyFactory insertStrategyFactory ;
73
75
76
+ private final QueryMappingConfiguration queryMappingConfiguration ;
77
+
74
78
/**
75
79
* Creates a {@link DefaultDataAccessStrategy}
76
80
*
@@ -82,21 +86,23 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
82
86
*/
83
87
public DefaultDataAccessStrategy (SqlGeneratorSource sqlGeneratorSource , RelationalMappingContext context ,
84
88
JdbcConverter converter , NamedParameterJdbcOperations operations , SqlParametersFactory sqlParametersFactory ,
85
- InsertStrategyFactory insertStrategyFactory ) {
89
+ InsertStrategyFactory insertStrategyFactory , QueryMappingConfiguration queryMappingConfiguration ) {
86
90
87
91
Assert .notNull (sqlGeneratorSource , "SqlGeneratorSource must not be null" );
88
92
Assert .notNull (context , "RelationalMappingContext must not be null" );
89
93
Assert .notNull (converter , "JdbcConverter must not be null" );
90
94
Assert .notNull (operations , "NamedParameterJdbcOperations must not be null" );
91
95
Assert .notNull (sqlParametersFactory , "SqlParametersFactory must not be null" );
92
96
Assert .notNull (insertStrategyFactory , "InsertStrategyFactory must not be null" );
97
+ Assert .notNull (queryMappingConfiguration , "InsertStrategyFactory must not be null" );
93
98
94
99
this .sqlGeneratorSource = sqlGeneratorSource ;
95
100
this .context = context ;
96
101
this .converter = converter ;
97
102
this .operations = operations ;
98
103
this .sqlParametersFactory = sqlParametersFactory ;
99
104
this .insertStrategyFactory = insertStrategyFactory ;
105
+ this .queryMappingConfiguration = queryMappingConfiguration ;
100
106
}
101
107
102
108
@ Override
@@ -265,15 +271,15 @@ public <T> T findById(Object id, Class<T> domainType) {
265
271
SqlIdentifierParameterSource parameter = sqlParametersFactory .forQueryById (id , domainType , ID_SQL_PARAMETER );
266
272
267
273
try {
268
- return operations .queryForObject (findOneSql , parameter , getEntityRowMapper (domainType ));
274
+ return operations .queryForObject (findOneSql , parameter , getRowMapper (domainType ));
269
275
} catch (EmptyResultDataAccessException e ) {
270
276
return null ;
271
277
}
272
278
}
273
279
274
280
@ Override
275
281
public <T > Iterable <T > findAll (Class <T > domainType ) {
276
- return operations .query (sql (domainType ).getFindAll (), getEntityRowMapper (domainType ));
282
+ return operations .query (sql (domainType ).getFindAll (), getRowMapper (domainType ));
277
283
}
278
284
279
285
@ Override
@@ -285,7 +291,7 @@ public <T> Iterable<T> findAllById(Iterable<?> ids, Class<T> domainType) {
285
291
286
292
SqlParameterSource parameterSource = sqlParametersFactory .forQueryByIds (ids , domainType );
287
293
String findAllInListSql = sql (domainType ).getFindAllInList ();
288
- return operations .query (findAllInListSql , parameterSource , getEntityRowMapper (domainType ));
294
+ return operations .query (findAllInListSql , parameterSource , getRowMapper (domainType ));
289
295
}
290
296
291
297
@ Override
@@ -339,12 +345,12 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
339
345
340
346
@ Override
341
347
public <T > Iterable <T > findAll (Class <T > domainType , Sort sort ) {
342
- return operations .query (sql (domainType ).getFindAll (sort ), getEntityRowMapper (domainType ));
348
+ return operations .query (sql (domainType ).getFindAll (sort ), getRowMapper (domainType ));
343
349
}
344
350
345
351
@ Override
346
352
public <T > Iterable <T > findAll (Class <T > domainType , Pageable pageable ) {
347
- return operations .query (sql (domainType ).getFindAll (pageable ), getEntityRowMapper (domainType ));
353
+ return operations .query (sql (domainType ).getFindAll (pageable ), getRowMapper (domainType ));
348
354
}
349
355
350
356
@ Override
@@ -354,7 +360,7 @@ public <T> Optional<T> findOne(Query query, Class<T> domainType) {
354
360
String sqlQuery = sql (domainType ).selectByQuery (query , parameterSource );
355
361
356
362
try {
357
- return Optional .ofNullable (operations .queryForObject (sqlQuery , parameterSource , getEntityRowMapper (domainType )));
363
+ return Optional .ofNullable (operations .queryForObject (sqlQuery , parameterSource , getRowMapper (domainType )));
358
364
} catch (EmptyResultDataAccessException e ) {
359
365
return Optional .empty ();
360
366
}
@@ -366,7 +372,7 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType) {
366
372
MapSqlParameterSource parameterSource = new MapSqlParameterSource ();
367
373
String sqlQuery = sql (domainType ).selectByQuery (query , parameterSource );
368
374
369
- return operations .query (sqlQuery , parameterSource , getEntityRowMapper (domainType ));
375
+ return operations .query (sqlQuery , parameterSource , getRowMapper (domainType ));
370
376
}
371
377
372
378
@ Override
@@ -375,7 +381,7 @@ public <T> Iterable<T> findAll(Query query, Class<T> domainType, Pageable pageab
375
381
MapSqlParameterSource parameterSource = new MapSqlParameterSource ();
376
382
String sqlQuery = sql (domainType ).selectByQuery (query , parameterSource , pageable );
377
383
378
- return operations .query (sqlQuery , parameterSource , getEntityRowMapper (domainType ));
384
+ return operations .query (sqlQuery , parameterSource , getRowMapper (domainType ));
379
385
}
380
386
381
387
@ Override
@@ -404,7 +410,13 @@ public <T> long count(Query query, Class<T> domainType) {
404
410
return result ;
405
411
}
406
412
407
- private <T > EntityRowMapper <T > getEntityRowMapper (Class <T > domainType ) {
413
+ private <T > RowMapper <? extends T > getRowMapper (Class <T > domainType ) {
414
+ RowMapper <? extends T > targetRowMapper ;
415
+
416
+ if ((targetRowMapper = queryMappingConfiguration .getRowMapper (domainType )) != null ) {
417
+ return targetRowMapper ;
418
+ }
419
+
408
420
return new EntityRowMapper <>(getRequiredPersistentEntity (domainType ), converter );
409
421
}
410
422
0 commit comments