16
16
17
17
package org .springframework .batch .integration .partition ;
18
18
19
- import java .util .Arrays ;
20
19
import java .util .Collection ;
21
20
import java .util .Collections ;
22
21
import java .util .HashSet ;
22
+ import java .util .Set ;
23
23
import java .util .concurrent .TimeoutException ;
24
+ import java .util .stream .Collectors ;
24
25
25
26
import org .junit .jupiter .api .Test ;
26
27
@@ -175,12 +176,11 @@ void testHandleWithJobRepositoryPolling() throws Exception {
175
176
stepExecutions .add (partition2 );
176
177
stepExecutions .add (partition3 );
177
178
when (stepExecutionSplitter .split (any (StepExecution .class ), eq (1 ))).thenReturn (stepExecutions );
178
- JobExecution runningJobExecution = new JobExecution (5L , new JobParameters ());
179
- runningJobExecution .addStepExecutions (Arrays .asList (partition2 , partition1 , partition3 ));
180
- JobExecution completedJobExecution = new JobExecution (5L , new JobParameters ());
181
- completedJobExecution .addStepExecutions (Arrays .asList (partition2 , partition1 , partition4 ));
182
- when (jobExplorer .getJobExecution (5L )).thenReturn (runningJobExecution , runningJobExecution , runningJobExecution ,
183
- completedJobExecution );
179
+ Set <Long > stepExecutionIds = stepExecutions .stream ().map (StepExecution ::getId ).collect (Collectors .toSet ());
180
+ when (jobExplorer .getStepExecutionCount (stepExecutionIds , BatchStatus .RUNNING_STATUSES )).thenReturn (3L , 2L , 1L ,
181
+ 0L );
182
+ Set <StepExecution > completedStepExecutions = Set .of (partition2 , partition1 , partition4 );
183
+ when (jobExplorer .getStepExecutions (jobExecution .getId (), stepExecutionIds )).thenReturn (completedStepExecutions );
184
184
185
185
// set
186
186
messageChannelPartitionHandler .setMessagingOperations (operations );
@@ -200,6 +200,8 @@ void testHandleWithJobRepositoryPolling() throws Exception {
200
200
assertTrue (executions .contains (partition4 ));
201
201
202
202
// verify
203
+ verify (jobExplorer , times (4 )).getStepExecutionCount (stepExecutionIds , BatchStatus .RUNNING_STATUSES );
204
+ verify (jobExplorer , times (1 )).getStepExecutions (jobExecution .getId (), stepExecutionIds );
203
205
verify (operations , times (3 )).send (any (Message .class ));
204
206
}
205
207
@@ -225,9 +227,8 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception {
225
227
stepExecutions .add (partition2 );
226
228
stepExecutions .add (partition3 );
227
229
when (stepExecutionSplitter .split (any (StepExecution .class ), eq (1 ))).thenReturn (stepExecutions );
228
- JobExecution runningJobExecution = new JobExecution (5L , new JobParameters ());
229
- runningJobExecution .addStepExecutions (Arrays .asList (partition2 , partition1 , partition3 ));
230
- when (jobExplorer .getJobExecution (5L )).thenReturn (runningJobExecution );
230
+ Set <Long > stepExecutionIds = stepExecutions .stream ().map (StepExecution ::getId ).collect (Collectors .toSet ());
231
+ when (jobExplorer .getStepExecutionCount (stepExecutionIds , BatchStatus .RUNNING_STATUSES )).thenReturn (1L );
231
232
232
233
// set
233
234
messageChannelPartitionHandler .setMessagingOperations (operations );
0 commit comments