diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java index 531e715a7b..7d98140cc3 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java @@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric { private long shuffleReadTime = -1; private long shuffleWriteBytes = -1; private long shuffleWriteTime = -1; + private int currentIteratorIndex = -1; + private int totalIteratorNumber = -1; + private long taskPreparationTime = -1; private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName()); @@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) { this.shuffleWriteTime = shuffleWriteTime; } + public final int getCurrentIteratorIndex() { + return this.currentIteratorIndex; + } + + private void setCurrentIteratorIndex(final int currentIteratorIndex) { + this.currentIteratorIndex = currentIteratorIndex; + } + + public final int getTotalIteratorNumber() { + return this.totalIteratorNumber; + } + + private void setTotalIteratorNumber(final int totalIteratorNumber) { + this.totalIteratorNumber = totalIteratorNumber; + } + + public final long getTaskPreparationTime() { + return this.taskPreparationTime; + } + + private void setTaskPreparationTime(final long taskPreparationTime) { + this.taskPreparationTime = taskPreparationTime; + } + @Override public final String getId() { return id; @@ -317,6 +344,14 @@ public final boolean processMetricMessage(final String metricField, final byte[] case "shuffleWriteTime": setShuffleWriteTime(SerializationUtils.deserialize(metricValue)); break; + case "currentIteratorIndex": + setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue)); + break; + case "totalIteratorNumber": + setTotalIteratorNumber(SerializationUtils.deserialize(metricValue)); + break; + case "taskPreparationTime": + setTaskPreparationTime(SerializationUtils.deserialize(metricValue)); default: LOG.warn("metricField {} is not supported.", metricField); return false; diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 97e30fb4e7..53adea3d3a 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -86,6 +86,8 @@ enum MessageType { PipeInit = 13; RequestPipeLoc = 14; PipeLocInfo = 15; + ParentTaskDataCollected = 16; + CurrentlyProcessedBytesCollected = 17; } message Message { @@ -107,6 +109,8 @@ message Message { optional PipeInitMessage pipeInitMsg = 16; optional RequestPipeLocationMessage requestPipeLocMsg = 17; optional PipeLocationInfoMessage pipeLocInfoMsg = 18; + optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; + optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20; } // Messages from Master to Executors @@ -256,3 +260,13 @@ message PipeLocationInfoMessage { required int64 requestId = 1; // To find the matching request msg required string executorId = 2; } + +message ParentTaskDataCollectMsg { + required string taskId = 1; + required bytes partitionSizeMap = 2; +} + +message CurrentlyProcessedBytesCollectMsg { + required string taskId = 1; + required int64 processedDataBytes = 2; +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java index 97cde037c6..b9cf9ef129 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter { private long writtenBytes; + private Optional> partitionSizeMap; + /** * Constructor. * @@ -109,7 +111,7 @@ public void close() { final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge .getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new); - final Optional> partitionSizeMap = blockToWrite.commit(); + partitionSizeMap = blockToWrite.commit(); // Return the total size of the committed block. if (partitionSizeMap.isPresent()) { long blockSizeTotal = 0; @@ -123,6 +125,16 @@ public void close() { blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence); } + @Override + public Optional> getPartitionSizeMap() { + if (partitionSizeMap.isPresent()) { + return partitionSizeMap; + } else { + return Optional.empty(); + } + } + + @Override public Optional getWrittenBytes() { if (writtenBytes == -1) { return Optional.empty(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java index bf6ff84e69..a1862f5f2d 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.punctuation.Watermark; +import java.util.Map; import java.util.Optional; /** @@ -45,5 +46,10 @@ public interface OutputWriter { */ Optional getWrittenBytes(); + /** + * @return the map of hashed key to partition size. + */ + Optional> getPartitionSizeMap(); + void close(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java index 544d64d921..d0025428aa 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -113,6 +114,11 @@ public Optional getWrittenBytes() { return Optional.empty(); } + @Override + public Optional> getPartitionSizeMap() { + return Optional.empty(); + } + @Override public void close() { if (!initialized) { diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java index 7af08852eb..b1a828c13c 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.runtime.executor.MetricMessageSender; import java.io.IOException; @@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable { */ abstract Object fetchDataElement() throws IOException; + /** + * Identical with fetchDataElement(), except it sends intermediate serializedReadBytes to MetricStore + * on every iterator advance. + * This method is for WorkStealing implementation in Nemo. + * + * @param taskId task id + * @param metricMessageSender metricMessageSender + * + * @return data element + * @throws IOException upon I/O error + * @throws java.util.NoSuchElementException if no more element is available + */ + abstract Object fetchDataElementWithTrace(String taskId, + MetricMessageSender metricMessageSender) throws IOException; + OutputCollector getOutputCollector() { return outputCollector; } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index 797818ce44..d7947e8c78 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -22,6 +22,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.*; import org.slf4j.Logger; @@ -100,6 +101,12 @@ Object fetchDataElement() throws IOException { } } + @Override + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { + return fetchDataElement(); + } + private void fetchDataLazily() { final List> futures = readersForParentTask.read(); numOfIterators = futures.size(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index a8ae4a9306..3a92cbc8a9 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -18,10 +18,12 @@ */ package org.apache.nemo.runtime.executor.task; +import org.apache.commons.lang3.SerializationUtils; import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty; import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.InputReader; import org.slf4j.Logger; @@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException { return Finishmark.getInstance(); } + @Override + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { + try { + if (firstFetch) { + fetchDataLazily(); + advanceIterator(); + firstFetch = false; + } + + while (true) { + // This iterator has the element + if (this.currentIterator.hasNext()) { + return this.currentIterator.next(); + } + + // This iterator does not have the element + if (currentIteratorIndex < expectedNumOfIterators) { + // Next iterator has the element + countBytes(currentIterator); + // Send the cumulative serBytes to MetricStore + metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes", + SerializationUtils.serialize(serBytes)); + advanceIterator(); + continue; + } else { + // We've consumed all the iterators + break; + } + + } + } catch (final Throwable e) { + // Any failure is caught and thrown as an IOException, so that the task is retried. + // In particular, we catch unchecked exceptions like RuntimeException thrown by DataUtil.IteratorWithNumBytes + // when remote data fetching fails for whatever reason. + // Note that we rely on unchecked exceptions because the Iterator interface does not provide the standard + // "throw Exception" that the TaskExecutor thread can catch and handle. + throw new IOException(e); + } + + return Finishmark.getInstance(); + } + private void advanceIterator() throws IOException { // Take from iteratorQueue final Object iteratorOrThrowable; diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java index 2d82898d7a..68a3362d27 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -23,6 +23,7 @@ import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -74,6 +75,11 @@ Object fetchDataElement() { } } + @Override + Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) { + return fetchDataElement(); + } + final long getBoundedSourceReadTime() { return boundedSourceReadTime; } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 2bf574d396..91e8212640 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -19,6 +19,7 @@ package org.apache.nemo.runtime.executor.task; import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; import org.apache.commons.lang3.SerializationUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.nemo.common.Pair; @@ -458,7 +459,7 @@ private boolean handleDataFetchers(final List fetchers) { while (availableIterator.hasNext()) { final DataFetcher dataFetcher = availableIterator.next(); try { - final Object element = dataFetcher.fetchDataElement(); + final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender); onEventFromDataFetcher(element, dataFetcher); if (element instanceof Finishmark) { availableIterator.remove(); @@ -688,12 +689,21 @@ public void setIRVertexPutOnHold(final IRVertex irVertex) { */ private void finalizeOutputWriters(final VertexHarness vertexHarness) { final List writtenBytesList = new ArrayList<>(); + final HashMap partitionSizeMap = new HashMap<>(); // finalize OutputWriters for main children vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Send partitionSizeMap to Scheduler + if (true) { + final Optional> partitionSizes = outputWriter.getPartitionSizeMap(); + if (partitionSizes.isPresent()) { + computePartitionSizeMap(partitionSizeMap, partitionSizes.get()); + } + } }); // finalize OutputWriters for additional tagged children @@ -702,6 +712,14 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Send partitionSizeMap to Scheduler + if (true) { + final Optional> partitionSizes = outputWriter.getPartitionSizeMap(); + if (partitionSizes.isPresent()) { + computePartitionSizeMap(partitionSizeMap, partitionSizes.get()); + } + } }) ); @@ -713,5 +731,57 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { // TODO #236: Decouple metric collection and sending logic metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes", SerializationUtils.serialize(totalWrittenBytes)); + + if (!partitionSizeMap.isEmpty()) { + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.ParentTaskDataCollected) + .setParentTaskDataCollected(ControlMessage.ParentTaskDataCollectMsg.newBuilder() + .setTaskId(taskId) + .setPartitionSizeMap(ByteString.copyFrom(SerializationUtils.serialize(partitionSizeMap))) + .build()) + .build()); + } + } + + // Methods for work stealing + /** + * Gather the KV statistics of processed data when execution is completed. + * This method is for work stealing implementation: the accumulated statistics will be used to + * detect skewed tasks of the child stage. + * + * @param totalPartitionSizeMap accumulated partitionSizeMap of task. + * @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter. + */ + private void computePartitionSizeMap(final Map totalPartitionSizeMap, + final Map singlePartitionSizeMap) { + for (Integer hashedKey : singlePartitionSizeMap.keySet()) { + final Long partitionSize = singlePartitionSizeMap.get(hashedKey); + if (totalPartitionSizeMap.containsKey(hashedKey)) { + totalPartitionSizeMap.compute(hashedKey, (existingKey, existingValue) -> existingValue + partitionSize); + } else { + totalPartitionSizeMap.put(hashedKey, partitionSize); + } + } + } + + /** + * Send the temporally processed bytes of the current task on request from the scheduler. + * This method is for work stealing implementation. + */ + public void onRequestForProcessedData() { + LOG.error("{}, bytes {}, replying for the request", taskId, serializedReadBytes); + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.CurrentlyProcessedBytesCollected) + .setCurrentlyProcessedBytesCollected(ControlMessage.CurrentlyProcessedBytesCollectMsg.newBuilder() + .setTaskId(this.taskId) + .setProcessedDataBytes(serializedReadBytes) + .build()) + .build()); } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index 53cab57810..65b5306f5b 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -85,6 +85,10 @@ public final class PlanStateManager { private final Map> stageIdToCompletedTaskTimeMsList = new HashMap<>(); private final Map> stageIdToTaskIndexToNumOfClones = new HashMap<>(); + /** + * Used for work stealing. + */ + private final Map>> stageIdToTaskIdxToWSAttemptStates = new HashMap<>(); /** * Represents the plan to manage. */ @@ -127,7 +131,7 @@ public static PlanStateManager newInstance(final String dagDirectory) { } /** - * @param metricStore set the metric store of the paln state manager. + * @param metricStore set the metric store of the plan state manager. */ public void setMetricStore(final MetricStore metricStore) { this.metricStore = metricStore; @@ -326,16 +330,8 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // Log not-yet-completed tasks for us humans to track progress final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); - final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream() - .filter(attempts -> { - final List states = attempts - .stream() - .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) - .collect(Collectors.toList()); - return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) // one of them is ON_HOLD - || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); // one of them is COMPLETE - }) - .count(); + final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + if (newTaskState.equals(TaskState.State.COMPLETE)) { LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", taskId, taskStatesOfThisStage.size() - numOfCompletedTaskIndicesInThisStage, taskStatesOfThisStage.size()); @@ -577,6 +573,59 @@ private List getPeerAttemptsForTheSameTaskIndex(final String ta .collect(Collectors.toList()); } + /** + * Get number of remaining tasks of the stage. + * + * @param stageId stage id. + * @return number of remaining tasks. + */ + public int getNumberOfTasksRemainingInStage(final String stageId) { + final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); + final Map> wsTaskStatesOfThisStage = stageIdToTaskIdxToWSAttemptStates + .getOrDefault(stageId, new HashMap<>()); + final long numOfCompletedTaskIndices = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + if (wsTaskStatesOfThisStage.isEmpty()) { + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices); + } else { + final long numOfCompletedWorkStealingTaskIndices = getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices + + wsTaskStatesOfThisStage.size() - numOfCompletedWorkStealingTaskIndices); + } + } + + /** + * Get tasks which are currently being executed. + * + * @param stageId stage id. + * @return Set of tasksIds in execution. + */ + public Set getOngoingTaskIdsInStage(final String stageId) { + final Map> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId); + final Set onGoingTaskIds = new HashSet<>(); + for (final int taskIndex : taskIdToState.keySet()) { + final List attemptStates = taskIdToState.get(taskIndex); + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + if (attemptStates.get(attempt).getStateMachine().getCurrentState().equals(TaskState.State.EXECUTING)) { + onGoingTaskIds.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt)); + } + } + } + return onGoingTaskIds; + } + + private long getNumberOfCompletedTasksInStage(final Map> taskIdxToState) { + return taskIdxToState.values().stream() + .filter(attempts -> { + final List states = attempts + .stream() + .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) + .collect(Collectors.toList()); + return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) + || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); + }) + .count(); + } + /** * @return the physical plan. */ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..44eb0f7a78 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -58,10 +58,7 @@ import javax.inject.Inject; import java.io.Serializable; import java.nio.file.Paths; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; @@ -85,9 +82,11 @@ public final class RuntimeMaster { private static final int METRIC_ARRIVE_TIMEOUT = 10000; private static final int REST_SERVER_PORT = 10101; private static final int SPECULATION_CHECKING_PERIOD_MS = 100; + private static final int WORK_STEALING_CHECKING_PERIOD_MS = 100; private final ExecutorService runtimeMasterThread; private final ScheduledExecutorService speculativeTaskCloningThread; + private final ScheduledExecutorService workStealingThread; private final Scheduler scheduler; private final ContainerManager containerManager; @@ -160,6 +159,16 @@ private RuntimeMaster(final Scheduler scheduler, SPECULATION_CHECKING_PERIOD_MS, TimeUnit.MILLISECONDS); + // Check for work stealing every second + this.workStealingThread = Executors + .newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "WorkStealing master thread")); + this.workStealingThread.scheduleWithFixedDelay( + () -> this.runtimeMasterThread.submit(scheduler::onWorkStealingCheck), + WORK_STEALING_CHECKING_PERIOD_MS, + WORK_STEALING_CHECKING_PERIOD_MS, + TimeUnit.MILLISECONDS); + + this.scheduler = scheduler; this.containerManager = containerManager; this.executorRegistry = executorRegistry; @@ -481,6 +490,22 @@ private void handleControlMessage(final ControlMessage.Message message) { .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(serializedData).build()) .build()); break; + case ParentTaskDataCollected: + if (scheduler instanceof BatchScheduler) { + final ControlMessage.ParentTaskDataCollectMsg workStealingMsg = message.getParentTaskDataCollected(); + final String taskId = workStealingMsg.getTaskId(); + final Map partitionSizeMap = SerializationUtils + .deserialize(workStealingMsg.getPartitionSizeMap().toByteArray()); + ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap); + } + break; + case CurrentlyProcessedBytesCollected: + if (scheduler instanceof BatchScheduler) { + ((BatchScheduler) scheduler).aggregateTaskIdToProcessedBytes( + message.getCurrentlyProcessedBytesCollected().getTaskId(), + message.getCurrentlyProcessedBytesCollected().getProcessedDataBytes() + ); + } case MetricFlushed: metricCountDownLatch.countDown(); break; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java index 16c9a70db9..ebec804132 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java @@ -170,6 +170,11 @@ public void onTaskExecutionFailed(final String taskId) { failedTasks.add(failedTask); } + @Override + public boolean isExecutorSlotAvailable() { + return getExecutorCapacity() - getNumOfRunningTasks() > 0; + } + /** * @return how many Tasks can this executor simultaneously run */ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java index 26649a81db..dcfb53eb1c 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java @@ -108,4 +108,9 @@ public interface ExecutorRepresenter { * @param taskId id of the Task */ void onTaskExecutionFailed(String taskId); + + /** + * @return true if this executor has an available slot. + */ + boolean isExecutorSlotAvailable(); } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 086c9d08bd..3caa4118ce 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -20,16 +20,19 @@ import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.nemo.common.Pair; +import org.apache.nemo.common.dag.Vertex; import org.apache.nemo.common.exception.UnknownExecutionStateException; import org.apache.nemo.common.exception.UnrecoverableFailureException; import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty; import org.apache.nemo.runtime.common.RuntimeIdManager; +import org.apache.nemo.runtime.common.metric.TaskMetric; import org.apache.nemo.runtime.common.plan.*; import org.apache.nemo.runtime.common.state.StageState; import org.apache.nemo.runtime.common.state.TaskState; import org.apache.nemo.runtime.master.BlockManagerMaster; import org.apache.nemo.runtime.master.PlanAppender; import org.apache.nemo.runtime.master.PlanStateManager; +import org.apache.nemo.runtime.master.metric.MetricStore; import org.apache.nemo.runtime.master.resource.ExecutorRepresenter; import org.apache.reef.annotations.audience.DriverSide; import org.slf4j.Logger; @@ -66,6 +69,7 @@ public final class BatchScheduler implements Scheduler { private final PendingTaskCollectionPointer pendingTaskCollectionPointer; // A 'pointer' to the list of pending tasks. private final ExecutorRegistry executorRegistry; // A registry for executors available for the job. private final PlanStateManager planStateManager; // A component that manages the state of the plan. + private final MetricStore metricStore = MetricStore.getStore(); /** * Other necessary components of this {@link org.apache.nemo.runtime.master.RuntimeMaster}. @@ -77,6 +81,14 @@ public final class BatchScheduler implements Scheduler { */ private List> sortedScheduleGroups; // Stages, sorted in the order to be scheduled. + /** + * Data Structures for work stealing. + */ + private final Set workStealingCandidates = new HashSet<>(); + private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); + private final Map taskIdToProcessedBytes = new HashMap<>(); + private final Map stageIdToWorkStealingExecuted = new HashMap<>(); + @Inject private BatchScheduler(final PlanRewriter planRewriter, final TaskDispatcher taskDispatcher, @@ -111,6 +123,11 @@ public void updatePlan(final PhysicalPlan newPhysicalPlan) { private void updatePlan(final PhysicalPlan newPhysicalPlan, final int maxScheduleAttempt) { planStateManager.updatePlan(newPhysicalPlan, maxScheduleAttempt); + + for (Stage stage : planStateManager.getPhysicalPlan().getStageDAG().getVertices()) { + stageIdToWorkStealingExecuted.putIfAbsent(stage.getId(), false); + } + this.sortedScheduleGroups = newPhysicalPlan.getStageDAG().getVertices().stream() .collect(Collectors.groupingBy(Stage::getScheduleGroup)) .entrySet().stream() @@ -258,6 +275,25 @@ public void onSpeculativeExecutionCheck() { } } + @Override + public void onWorkStealingCheck() { + MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false); + List scheduleGroup = BatchSchedulerUtils + .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>()); + List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); + isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId)); + + if (!isWorkStealingConditionSatisfied.booleanValue()) { + return; + } + taskIdToProcessedBytes.clear(); + final List skewedTasks = detectSkew(scheduleGroupInId); + + // TODO #469 Split tasks using iterator interface. + + return; + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName()); @@ -304,6 +340,9 @@ public void terminate() { * - We make {@link TaskDispatcher} dispatch only the tasks that are READY. */ private void doSchedule() { + taskIdToProcessedBytes.clear(); + workStealingCandidates.clear(); + final Optional> earliest = BatchSchedulerUtils.selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager); @@ -383,4 +422,232 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, return false; } + + ///////////////////////////////////////////////////////////////// Methods for work stealing + + /** + * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format. + * KEY is assumed to be Integer because of the HashPartition. + * + * @param taskId id of task to accumulate. + * @param partitionSizeMap map of (K) - (partition size) of the task. + */ + public void aggregateStageIdToPartitionSizeMap(final String taskId, + final Map partitionSizeMap) { + final Map partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap + .getOrDefault(RuntimeIdManager.getStageIdFromTaskId(taskId), new HashMap<>()); + for (Integer hashedKey : partitionSizeMap.keySet()) { + final Long partitionSize = partitionSizeMap.get(hashedKey); + if (partitionSizeMapForThisStage.containsKey(hashedKey)) { + partitionSizeMapForThisStage.put(hashedKey, partitionSize + partitionSizeMapForThisStage.get(hashedKey)); + } else { + partitionSizeMapForThisStage.put(hashedKey, partitionSize); + } + } + stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage); + } + + /** + * Store the tracked processed bytes per task by the current time. + * + * @param taskId id of task to track. + * @param processedBytes size of the processed bytes till now. + */ + public void aggregateTaskIdToProcessedBytes(final String taskId, + final long processedBytes) { + taskIdToProcessedBytes.put(taskId, processedBytes); + } + + /** + * Check if work stealing can be conducted. + * + * @param scheduleGroup schedule group. + * @return true if work stealing is possible. + */ + private boolean checkForWorkStealingBaseConditions(final List scheduleGroup) { + if (scheduleGroup.isEmpty()) { + return false; + } + + /* If the stage of the given schedule group contains sharded tasks, return false */ + if (scheduleGroup.stream().anyMatch(stageId -> stageIdToWorkStealingExecuted.get(stageId).equals(true))) { + return false; + } + + /* If there are idle executors and the number of remaining tasks are smaller than number of executors, + * return true. */ + final boolean executorStatus = executorRegistry.isExecutorSlotAvailable(); + final int totalNumberOfSlots = executorRegistry.getTotalNumberOfExecutorSlots(); + int remainingTasks = 0; + for (String stage : scheduleGroup) { + remainingTasks += planStateManager.getNumberOfTasksRemainingInStage(stage); // ready + executing? + } + return executorStatus && (totalNumberOfSlots > remainingTasks); + } + + /** + * Get the ids of tasks in execution. + * + * @param scheduleGroup schedule group. + * @return ids of running tasks. + */ + private Set getRunningTaskId(final List scheduleGroup) { + final Set onGoingTasksOfSchedulingGroup = new HashSet<>(); + for (String stageId : scheduleGroup) { + onGoingTasksOfSchedulingGroup.addAll(planStateManager.getOngoingTaskIdsInStage(stageId)); + } + return onGoingTasksOfSchedulingGroup; + } + + /** + * Get parent stages of given schedule group. + * + * @param scheduleGroup schedule group. + * @return Map of stage and set of its parent. + */ + private Map> getParentStages(final List scheduleGroup) { + Map> parentStages = new HashMap<>(); + for (String stageId : scheduleGroup) { + parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId) + .stream() + .map(Vertex::getId) + .collect(Collectors.toSet())); + } + return parentStages; + } + + /** + * Get the input size of running tasks. + * + * @param parentStageIds id of parent stages. + * @param runningTaskIds id of running tasks. + * @return Map of task id to its input size. + */ + private Map getInputSizeOfRunningTasks(final Set parentStageIds, + final Set runningTaskIds) { + Map currentlyRunningTaskIdsToTotalSize = new HashMap<>(); + for (String parent : parentStageIds) { + Map taskIdxToSize = stageIdToOutputPartitionSizeMap.get(parent); + for (String taskId : runningTaskIds) { + if (currentlyRunningTaskIdsToTotalSize.containsKey(taskId)) { + final long existingValue = currentlyRunningTaskIdsToTotalSize.get(taskId); + currentlyRunningTaskIdsToTotalSize.put(taskId, + existingValue + taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId))); + } else { + currentlyRunningTaskIdsToTotalSize + .put(taskId, taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId))); + } + } + } + return currentlyRunningTaskIdsToTotalSize; + } + + /** + * get current execution time of running tasks in millisecond. + * Note that this is the execution time of incomplete tasks. + * + * @param scheduleGroup schedule group. + * @return Map of task id to its execution time. + */ + private Map getCurrentExecutionTimeMsOfRunningTasks(final List scheduleGroup) { + final Map taskToExecutionTime = new HashMap<>(); + for (String stageId : scheduleGroup) { + taskToExecutionTime.putAll(planStateManager.getExecutingTaskToRunningTimeMs(stageId)); + } + return taskToExecutionTime; + } + + private List getScheduleGroupByStage(final String stageId) { + return sortedScheduleGroups.get( + planStateManager.getPhysicalPlan().getStageDAG().getVertexById(stageId).getScheduleGroup()) + .stream() + .map(Vertex::getId) + .collect(Collectors.toList()); + } + + /** + * Detect skewed tasks. + * + * @param scheduleGroup current schedule group. + * @return List of skewed tasks. + */ + private List detectSkew(final List scheduleGroup) { + final Map> taskIdToIteratorInformation = new HashMap<>(); + final Map taskIdToInitializationOverhead = new HashMap<>(); + final Map inputSizeOfCandidateTasks = new HashMap<>(); + final Map> parentStageId = getParentStages(scheduleGroup); + + + /* if this schedule group contains a source stage, return empty list */ + if (scheduleGroup.stream().anyMatch(stage -> + planStateManager.getPhysicalPlan().getStageDAG().getParents(stage).isEmpty())) { + return new ArrayList<>(); + } + + workStealingCandidates.addAll(getRunningTaskId(scheduleGroup)); + + /* Gather statistics of work stealing candidates */ + + /* get size of running tasks */ + for (String stage : scheduleGroup) { + inputSizeOfCandidateTasks.putAll( + getInputSizeOfRunningTasks(parentStageId.get(stage), workStealingCandidates)); + } + + /* get elapsed time */ + Map taskIdToElapsedTime = getCurrentExecutionTimeMsOfRunningTasks(scheduleGroup); + + /* gather task metric */ + for (String taskId : workStealingCandidates) { + TaskMetric taskMetric = metricStore.getMetricWithId(TaskMetric.class, taskId); + + taskIdToProcessedBytes.put(taskId, taskMetric.getSerializedReadBytes()); + taskIdToIteratorInformation.put(taskId, Pair.of( + taskMetric.getCurrentIteratorIndex(), taskMetric.getTotalIteratorNumber())); + taskIdToInitializationOverhead.put(taskId, taskMetric.getTaskPreparationTime()); + } + + /* If gathered statistic is not sufficient for skew detection, return empty list. */ + if (taskIdToProcessedBytes.size() <= workStealingCandidates.size() / 2) { + return new ArrayList<>(); + } + + /* estimate the remaining time */ + List> estimatedTimeToFinishPerTask = new ArrayList<>(taskIdToElapsedTime.size()); + + for (String taskId : taskIdToProcessedBytes.keySet()) { + // if processed bytes are not available, do not detect skew. + if (taskIdToProcessedBytes.get(taskId) <= 0) { + return new ArrayList<>(); + } + + // if this task is almost finished, ignore it. + Pair iteratorInformation = taskIdToIteratorInformation.get(taskId); + if (iteratorInformation.right() - iteratorInformation.left() <= 2) { + continue; + } + + long timeToFinishExecute = taskIdToElapsedTime.get(taskId) * inputSizeOfCandidateTasks.get(taskId) + / taskIdToProcessedBytes.get(taskId); + + // if the estimated left time is shorter than the initialization overhead, stop! + if (timeToFinishExecute < taskIdToInitializationOverhead.get(taskId) * 2) { + continue; + } + + estimatedTimeToFinishPerTask.add(Pair.of(taskId, timeToFinishExecute)); + } + + /* detect skew */ + Collections.sort(estimatedTimeToFinishPerTask, new Comparator>() { + @Override + public int compare(final Pair o1, final Pair o2) { + return o2.right().compareTo(o1.right()); + } + }); + + return estimatedTimeToFinishPerTask + .subList(0, estimatedTimeToFinishPerTask.size() / 2) + .stream().map(Pair::left).collect(Collectors.toList()); + } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java index 11d40c73b8..5cead6e290 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java @@ -126,6 +126,14 @@ private Set getRunningExecutors() { .collect(Collectors.toSet()); } + public int getTotalNumberOfExecutorSlots() { + return getRunningExecutors().stream().mapToInt(ExecutorRepresenter::getExecutorCapacity).sum(); + } + + public boolean isExecutorSlotAvailable() { + return getRunningExecutors().stream().anyMatch(ExecutorRepresenter::isExecutorSlotAvailable); + } + @Override public String toString() { return executors.toString(); diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java index cc4661df64..afe30f6e73 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java @@ -86,6 +86,11 @@ void onTaskStateReportFromExecutor(String executorId, */ void onSpeculativeExecutionCheck(); + /** + * Called to check for work stealing condition. + */ + void onWorkStealingCheck(); + /** * To be called when a job should be terminated. * Any clean up code should be implemented in this method. diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java index 42870f609e..5885aa0ada 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java @@ -451,6 +451,12 @@ public void onSpeculativeExecutionCheck() { return; } + @Override + public void onWorkStealingCheck() { + // we don't simulate work stealing yet. + return; + } + @Override public void terminate() { this.taskDispatcher.terminate(); diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java index 24e30bec87..ffa2c586da 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java @@ -149,6 +149,11 @@ public void onSpeculativeExecutionCheck() { throw new UnsupportedOperationException(); } + @Override + public void onWorkStealingCheck() { + throw new UnsupportedOperationException(); + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName());