62
62
* @author Chirag Tailor
63
63
* @author Diego Krupitza
64
64
* @author Sergey Korotaev
65
+ * @author Mikhail Polivakha
65
66
* @since 1.1
66
67
*/
67
68
public class DefaultDataAccessStrategy implements DataAccessStrategy {
@@ -73,6 +74,8 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
73
74
private final SqlParametersFactory sqlParametersFactory ;
74
75
private final InsertStrategyFactory insertStrategyFactory ;
75
76
77
+ private final QueryMappingConfiguration queryMappingConfiguration ;
78
+
76
79
/**
77
80
* Creates a {@link DefaultDataAccessStrategy}
78
81
*
@@ -84,21 +87,23 @@ public class DefaultDataAccessStrategy implements DataAccessStrategy {
84
87
*/
85
88
public DefaultDataAccessStrategy (SqlGeneratorSource sqlGeneratorSource , RelationalMappingContext context ,
86
89
JdbcConverter converter , NamedParameterJdbcOperations operations , SqlParametersFactory sqlParametersFactory ,
87
- InsertStrategyFactory insertStrategyFactory ) {
90
+ InsertStrategyFactory insertStrategyFactory , QueryMappingConfiguration queryMappingConfiguration ) {
88
91
89
92
Assert .notNull (sqlGeneratorSource , "SqlGeneratorSource must not be null" );
90
93
Assert .notNull (context , "RelationalMappingContext must not be null" );
91
94
Assert .notNull (converter , "JdbcConverter must not be null" );
92
95
Assert .notNull (operations , "NamedParameterJdbcOperations must not be null" );
93
96
Assert .notNull (sqlParametersFactory , "SqlParametersFactory must not be null" );
94
97
Assert .notNull (insertStrategyFactory , "InsertStrategyFactory must not be null" );
98
+ Assert .notNull (queryMappingConfiguration , "InsertStrategyFactory must not be null" );
95
99
96
100
this .sqlGeneratorSource = sqlGeneratorSource ;
97
101
this .context = context ;
98
102
this .converter = converter ;
99
103
this .operations = operations ;
100
104
this .sqlParametersFactory = sqlParametersFactory ;
101
105
this .insertStrategyFactory = insertStrategyFactory ;
106
+ this .queryMappingConfiguration = queryMappingConfiguration ;
102
107
}
103
108
104
109
@ Override
@@ -272,15 +277,15 @@ public <T> T findById(Object id, Class<T> domainType) {
272
277
SqlIdentifierParameterSource parameter = sqlParametersFactory .forQueryById (id , domainType , ID_SQL_PARAMETER );
273
278
274
279
try {
275
- return operations .queryForObject (findOneSql , parameter , getEntityRowMapper (domainType ));
280
+ return operations .queryForObject (findOneSql , parameter , getRowMapper (domainType ));
276
281
} catch (EmptyResultDataAccessException e ) {
277
282
return null ;
278
283
}
279
284
}
280
285
281
286
@ Override
282
287
public <T > List <T > findAll (Class <T > domainType ) {
283
- return operations .query (sql (domainType ).getFindAll (), getEntityRowMapper (domainType ));
288
+ return operations .query (sql (domainType ).getFindAll (), getRowMapper (domainType ));
284
289
}
285
290
286
291
@ Override
@@ -298,7 +303,7 @@ public <T> List<T> findAllById(Iterable<?> ids, Class<T> domainType) {
298
303
299
304
SqlParameterSource parameterSource = sqlParametersFactory .forQueryByIds (ids , domainType );
300
305
String findAllInListSql = sql (domainType ).getFindAllInList ();
301
- return operations .query (findAllInListSql , parameterSource , getEntityRowMapper (domainType ));
306
+ return operations .query (findAllInListSql , parameterSource , getRowMapper (domainType ));
302
307
}
303
308
304
309
@ Override
@@ -365,7 +370,7 @@ public <T> boolean existsById(Object id, Class<T> domainType) {
365
370
366
371
@ Override
367
372
public <T > List <T > findAll (Class <T > domainType , Sort sort ) {
368
- return operations .query (sql (domainType ).getFindAll (sort ), getEntityRowMapper (domainType ));
373
+ return operations .query (sql (domainType ).getFindAll (sort ), getRowMapper (domainType ));
369
374
}
370
375
371
376
@ Override
@@ -376,7 +381,7 @@ public <T> Stream<T> streamAll(Class<T> domainType, Sort sort) {
376
381
377
382
@ Override
378
383
public <T > List <T > findAll (Class <T > domainType , Pageable pageable ) {
379
- return operations .query (sql (domainType ).getFindAll (pageable ), getEntityRowMapper (domainType ));
384
+ return operations .query (sql (domainType ).getFindAll (pageable ), getRowMapper (domainType ));
380
385
}
381
386
382
387
@ Override
@@ -386,7 +391,7 @@ public <T> Optional<T> findOne(Query query, Class<T> domainType) {
386
391
String sqlQuery = sql (domainType ).selectByQuery (query , parameterSource );
387
392
388
393
try {
389
- return Optional .ofNullable (operations .queryForObject (sqlQuery , parameterSource , getEntityRowMapper (domainType )));
394
+ return Optional .ofNullable (operations .queryForObject (sqlQuery , parameterSource , getRowMapper (domainType )));
390
395
} catch (EmptyResultDataAccessException e ) {
391
396
return Optional .empty ();
392
397
}
@@ -398,7 +403,7 @@ public <T> List<T> findAll(Query query, Class<T> domainType) {
398
403
MapSqlParameterSource parameterSource = new MapSqlParameterSource ();
399
404
String sqlQuery = sql (domainType ).selectByQuery (query , parameterSource );
400
405
401
- return operations .query (sqlQuery , parameterSource , getEntityRowMapper (domainType ));
406
+ return operations .query (sqlQuery , parameterSource , getRowMapper (domainType ));
402
407
}
403
408
404
409
@ Override
@@ -416,7 +421,7 @@ public <T> List<T> findAll(Query query, Class<T> domainType, Pageable pageable)
416
421
MapSqlParameterSource parameterSource = new MapSqlParameterSource ();
417
422
String sqlQuery = sql (domainType ).selectByQuery (query , parameterSource , pageable );
418
423
419
- return operations .query (sqlQuery , parameterSource , getEntityRowMapper (domainType ));
424
+ return operations .query (sqlQuery , parameterSource , getRowMapper (domainType ));
420
425
}
421
426
422
427
@ Override
@@ -445,7 +450,13 @@ public <T> long count(Query query, Class<T> domainType) {
445
450
return result ;
446
451
}
447
452
448
- private <T > EntityRowMapper <T > getEntityRowMapper (Class <T > domainType ) {
453
+ private <T > RowMapper <? extends T > getRowMapper (Class <T > domainType ) {
454
+ RowMapper <? extends T > targetRowMapper ;
455
+
456
+ if ((targetRowMapper = queryMappingConfiguration .getRowMapper (domainType )) != null ) {
457
+ return targetRowMapper ;
458
+ }
459
+
449
460
return new EntityRowMapper <>(getRequiredPersistentEntity (domainType ), converter );
450
461
}
451
462
0 commit comments