From cf39d486dddf49e29baa5dea78ca8a539ac74e23 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 16:58:02 +0900 Subject: [PATCH 01/12] calculate processed block size in outputwriter --- .../executor/datatransfer/BlockOutputWriter.java | 14 +++++++++++++- .../executor/datatransfer/OutputWriter.java | 6 ++++++ .../executor/datatransfer/PipeOutputWriter.java | 6 ++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java index 97cde037c6..b9cf9ef129 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/BlockOutputWriter.java @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter { private long writtenBytes; + private Optional> partitionSizeMap; + /** * Constructor. * @@ -109,7 +111,7 @@ public void close() { final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge .getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new); - final Optional> partitionSizeMap = blockToWrite.commit(); + partitionSizeMap = blockToWrite.commit(); // Return the total size of the committed block. 
if (partitionSizeMap.isPresent()) { long blockSizeTotal = 0; @@ -123,6 +125,16 @@ public void close() { blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence); } + @Override + public Optional> getPartitionSizeMap() { + if (partitionSizeMap.isPresent()) { + return partitionSizeMap; + } else { + return Optional.empty(); + } + } + + @Override public Optional getWrittenBytes() { if (writtenBytes == -1) { return Optional.empty(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java index bf6ff84e69..a1862f5f2d 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/OutputWriter.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.punctuation.Watermark; +import java.util.Map; import java.util.Optional; /** @@ -45,5 +46,10 @@ public interface OutputWriter { */ Optional getWrittenBytes(); + /** + * @return the map of hashed key to partition size. 
+ */ + Optional> getPartitionSizeMap(); + void close(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java index 544d64d921..d0025428aa 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; /** @@ -113,6 +114,11 @@ public Optional getWrittenBytes() { return Optional.empty(); } + @Override + public Optional> getPartitionSizeMap() { + return Optional.empty(); + } + @Override public void close() { if (!initialized) { From 1b94e85c262f329ef54a045baa0adf40ba82f488 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:15:15 +0900 Subject: [PATCH 02/12] modify datafetcher to gather trace of the serialized read bytes --- .../runtime/executor/task/DataFetcher.java | 16 +++++++ .../MultiThreadParentTaskDataFetcher.java | 6 +++ .../executor/task/ParentTaskDataFetcher.java | 45 +++++++++++++++++++ .../task/SourceVertexDataFetcher.java | 7 +++ 4 files changed, 74 insertions(+) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java index 7af08852eb..b1a828c13c 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/DataFetcher.java @@ -20,6 +20,7 @@ import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.vertex.IRVertex; +import org.apache.nemo.runtime.executor.MetricMessageSender; import java.io.IOException; @@ -49,6 +50,21 @@ 
abstract class DataFetcher implements AutoCloseable { */ abstract Object fetchDataElement() throws IOException; + /** + * Identical with fetchDataElement(), except it sends intermediate serializedReadBytes to MetricStore + * on every iterator advance. + * This method is for WorkStealing implementation in Nemo. + * + * @param taskId task id + * @param metricMessageSender metricMessageSender + * + * @return data element + * @throws IOException upon I/O error + * @throws java.util.NoSuchElementException if no more element is available + */ + abstract Object fetchDataElementWithTrace(String taskId, + MetricMessageSender metricMessageSender) throws IOException; + OutputCollector getOutputCollector() { return outputCollector; } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index 797818ce44..c1361cb125 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -22,6 +22,7 @@ import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.*; import org.slf4j.Logger; @@ -100,6 +101,11 @@ Object fetchDataElement() throws IOException { } } + @Override + Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) throws IOException { + return fetchDataElement(); + } + private void fetchDataLazily() { final List> futures = readersForParentTask.read(); numOfIterators = futures.size(); diff --git 
a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index a8ae4a9306..4c376ff6b9 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -18,10 +18,12 @@ */ package org.apache.nemo.runtime.executor.task; +import org.apache.commons.lang3.SerializationUtils; import org.apache.nemo.common.ir.OutputCollector; import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty; import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.punctuation.Finishmark; +import org.apache.nemo.runtime.executor.MetricMessageSender; import org.apache.nemo.runtime.executor.data.DataUtil; import org.apache.nemo.runtime.executor.datatransfer.InputReader; import org.slf4j.Logger; @@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException { return Finishmark.getInstance(); } + @Override + Object fetchDataElementWithTrace(String taskId, + MetricMessageSender metricMessageSender) throws IOException { + try { + if (firstFetch) { + fetchDataLazily(); + advanceIterator(); + firstFetch = false; + } + + while (true) { + // This iterator has the element + if (this.currentIterator.hasNext()) { + return this.currentIterator.next(); + } + + // This iterator does not have the element + if (currentIteratorIndex < expectedNumOfIterators) { + // Next iterator has the element + countBytes(currentIterator); + // Send the cumulative serBytes to MetricStore + metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes", + SerializationUtils.serialize(serBytes)); + advanceIterator(); + continue; + } else { + // We've consumed all the iterators + break; + } + + } + } catch (final Throwable e) { + // Any failure is caught and thrown as an IOException, so that the 
task is retried. + // In particular, we catch unchecked exceptions like RuntimeException thrown by DataUtil.IteratorWithNumBytes + // when remote data fetching fails for whatever reason. + // Note that we rely on unchecked exceptions because the Iterator interface does not provide the standard + // "throw Exception" that the TaskExecutor thread can catch and handle. + throw new IOException(e); + } + + return Finishmark.getInstance(); + } + private void advanceIterator() throws IOException { // Take from iteratorQueue final Object iteratorOrThrowable; diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java index 2d82898d7a..8ac8c27eee 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -23,7 +23,9 @@ import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.punctuation.Finishmark; import org.apache.nemo.common.punctuation.Watermark; +import org.apache.nemo.runtime.executor.MetricMessageSender; +import java.io.IOException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -74,6 +76,11 @@ Object fetchDataElement() { } } + @Override + Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) { + return fetchDataElement(); + } + final long getBoundedSourceReadTime() { return boundedSourceReadTime; } From 5f53288749d3383eae2a6045c3513eeef4f10379 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:28:23 +0900 Subject: [PATCH 03/12] handle checkstyle --- .../executor/task/MultiThreadParentTaskDataFetcher.java | 3 ++- .../nemo/runtime/executor/task/ParentTaskDataFetcher.java | 4 ++-- 
.../nemo/runtime/executor/task/SourceVertexDataFetcher.java | 3 +-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java index c1361cb125..d7947e8c78 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/MultiThreadParentTaskDataFetcher.java @@ -102,7 +102,8 @@ Object fetchDataElement() throws IOException { } @Override - Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) throws IOException { + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { return fetchDataElement(); } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java index 4c376ff6b9..3a92cbc8a9 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/ParentTaskDataFetcher.java @@ -103,8 +103,8 @@ Object fetchDataElement() throws IOException { } @Override - Object fetchDataElementWithTrace(String taskId, - MetricMessageSender metricMessageSender) throws IOException { + Object fetchDataElementWithTrace(final String taskId, + final MetricMessageSender metricMessageSender) throws IOException { try { if (firstFetch) { fetchDataLazily(); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java index 
8ac8c27eee..68a3362d27 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/SourceVertexDataFetcher.java @@ -25,7 +25,6 @@ import org.apache.nemo.common.punctuation.Watermark; import org.apache.nemo.runtime.executor.MetricMessageSender; -import java.io.IOException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -77,7 +76,7 @@ Object fetchDataElement() { } @Override - Object fetchDataElementWithTrace(String taskId, MetricMessageSender metricMessageSender) { + Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) { return fetchDataElement(); } From aed78c66f1a75d0756fe39b211d772351841df8f Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:33:37 +0900 Subject: [PATCH 04/12] replace fetchDataElement with fetchDataElementWithTrace in TaskExecutor --- .../org/apache/nemo/runtime/executor/task/TaskExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 2bf574d396..fc3cc4e8b8 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -458,7 +458,7 @@ private boolean handleDataFetchers(final List fetchers) { while (availableIterator.hasNext()) { final DataFetcher dataFetcher = availableIterator.next(); try { - final Object element = dataFetcher.fetchDataElement(); + final Object element = dataFetcher.fetchDataElementWithTrace(taskId, metricMessageSender); onEventFromDataFetcher(element, dataFetcher); if (element instanceof Finishmark) { 
availableIterator.remove(); From e4401f05291527d6ff4848ca3d0976b926667264 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 17:50:54 +0900 Subject: [PATCH 05/12] add work stealing thread in runtime master --- .../apache/nemo/runtime/master/RuntimeMaster.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..4c4b5bb35e 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -85,9 +85,11 @@ public final class RuntimeMaster { private static final int METRIC_ARRIVE_TIMEOUT = 10000; private static final int REST_SERVER_PORT = 10101; private static final int SPECULATION_CHECKING_PERIOD_MS = 100; + private static final int WORK_STEALING_CHECKING_PERIOD_MS = 100; private final ExecutorService runtimeMasterThread; private final ScheduledExecutorService speculativeTaskCloningThread; + private final ScheduledExecutorService workStealingThread; private final Scheduler scheduler; private final ContainerManager containerManager; @@ -160,6 +162,16 @@ private RuntimeMaster(final Scheduler scheduler, SPECULATION_CHECKING_PERIOD_MS, TimeUnit.MILLISECONDS); + // Check for work stealing every second + this.workStealingThread = Executors + .newSingleThreadScheduledExecutor(runnable -> new Thread(runnable, "WorkStealing master thread")); + this.workStealingThread.scheduleWithFixedDelay( + () -> this.runtimeMasterThread.submit(scheduler::onWorkStealingCheck), + WORK_STEALING_CHECKING_PERIOD_MS, + WORK_STEALING_CHECKING_PERIOD_MS, + TimeUnit.MILLISECONDS); + + this.scheduler = scheduler; this.containerManager = containerManager; this.executorRegistry = executorRegistry; From 398ff94c23e052381193dfa7cb9e0ef9b5e3b653 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 
Jul 2021 18:14:03 +0900 Subject: [PATCH 06/12] send accumulated KV statistics to nemo driver --- .../src/main/proto/ControlMessage.proto | 7 +++ .../runtime/executor/task/TaskExecutor.java | 50 +++++++++++++++++++ .../nemo/runtime/master/RuntimeMaster.java | 14 ++++-- .../master/scheduler/BatchScheduler.java | 21 ++++++++ 4 files changed, 88 insertions(+), 4 deletions(-) diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 97e30fb4e7..3d43f10c64 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -86,6 +86,7 @@ enum MessageType { PipeInit = 13; RequestPipeLoc = 14; PipeLocInfo = 15; + ParentTaskDataCollected = 16; } message Message { @@ -107,6 +108,7 @@ message Message { optional PipeInitMessage pipeInitMsg = 16; optional RequestPipeLocationMessage requestPipeLocMsg = 17; optional PipeLocationInfoMessage pipeLocInfoMsg = 18; + optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; } // Messages from Master to Executors @@ -256,3 +258,8 @@ message PipeLocationInfoMessage { required int64 requestId = 1; // To find the matching request msg required string executorId = 2; } + +message ParentTaskDataCollectMsg { + required string taskId = 1; + required bytes partitionSizeMap = 2; +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index fc3cc4e8b8..758e32212e 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -19,6 +19,7 @@ package org.apache.nemo.runtime.executor.task; import com.google.common.collect.Lists; +import com.google.protobuf.ByteString; import org.apache.commons.lang3.SerializationUtils; import 
org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.nemo.common.Pair; @@ -688,12 +689,21 @@ public void setIRVertexPutOnHold(final IRVertex irVertex) { */ private void finalizeOutputWriters(final VertexHarness vertexHarness) { final List writtenBytesList = new ArrayList<>(); + final HashMap partitionSizeMap = new HashMap<>(); // finalize OutputWriters for main children vertexHarness.getWritersToMainChildrenTasks().forEach(outputWriter -> { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Send partitionSizeMap to Scheduler + if (true) { + final Optional> partitionSizes = outputWriter.getPartitionSizeMap(); + if (partitionSizes.isPresent()) { + computePartitionSizeMap(partitionSizeMap, partitionSizes.get()); + } + } }); // finalize OutputWriters for additional tagged children @@ -702,6 +712,14 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { outputWriter.close(); final Optional writtenBytes = outputWriter.getWrittenBytes(); writtenBytes.ifPresent(writtenBytesList::add); + + // Send partitionSizeMap to Scheduler + if (true) { + final Optional> partitionSizes = outputWriter.getPartitionSizeMap(); + if (partitionSizes.isPresent()) { + computePartitionSizeMap(partitionSizeMap, partitionSizes.get()); + } + } }) ); @@ -713,5 +731,37 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { // TODO #236: Decouple metric collection and sending logic metricMessageSender.send(TASK_METRIC_ID, taskId, "taskOutputBytes", SerializationUtils.serialize(totalWrittenBytes)); + + if (!partitionSizeMap.isEmpty()) { + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.ParentTaskDataCollected) 
+ .setParentTaskDataCollected(ControlMessage.ParentTaskDataCollectMsg.newBuilder() + .setTaskId(taskId) + .setPartitionSizeMap(ByteString.copyFrom(SerializationUtils.serialize(partitionSizeMap))) + .build()) + .build()); + } + } + + /** + * Gather the KV statistics of processed data. + * This method is for work stealing implementation. + * + * @param totalPartitionSizeMap accumulated partitionSizeMap of task. + * @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter. + */ + private void computePartitionSizeMap(final Map totalPartitionSizeMap, + final Map singlePartitionSizeMap) { + for (Integer hashedKey : singlePartitionSizeMap.keySet()) { + final Long partitionSize = singlePartitionSizeMap.get(hashedKey); + if (totalPartitionSizeMap.containsKey(hashedKey)) { + totalPartitionSizeMap.compute(hashedKey, (existingKey, existingValue) -> existingValue + partitionSize); + } else { + totalPartitionSizeMap.put(hashedKey, partitionSize); + } + } } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index d3b48f266a..dd6b144e2a 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -58,10 +58,7 @@ import javax.inject.Inject; import java.io.Serializable; import java.nio.file.Paths; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; @@ -481,6 +478,15 @@ private void handleControlMessage(final ControlMessage.Message message) { .setDataCollected(ControlMessage.DataCollectMessage.newBuilder().setData(serializedData).build()) .build()); break; + case ParentTaskDataCollected: + if (scheduler instanceof BatchScheduler) { + final 
ControlMessage.ParentTaskDataCollectMsg workStealingMsg = message.getParentTaskDataCollected(); + final String taskId = workStealingMsg.getTaskId(); + final Map partitionSizeMap = SerializationUtils + .deserialize(workStealingMsg.getPartitionSizeMap().toByteArray()); + ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap); + } + break; case MetricFlushed: metricCountDownLatch.countDown(); break; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 086c9d08bd..23ed8326e7 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -77,6 +77,11 @@ public final class BatchScheduler implements Scheduler { */ private List> sortedScheduleGroups; // Stages, sorted in the order to be scheduled. + /** + * Data Structures for work stealing. 
+ */ + private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); + @Inject private BatchScheduler(final PlanRewriter planRewriter, final TaskDispatcher taskDispatcher, @@ -383,4 +388,20 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, return false; } + + // Methods for work stealing + public void aggregateStageIdToPartitionSizeMap(final String taskId, + final Map partitionSizeMap) { + final Map partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap + .getOrDefault(RuntimeIdManager.getStageIdFromTaskId(taskId), new HashMap<>()); + for (Integer hashedKey : partitionSizeMap.keySet()) { + final Long partitionSize = partitionSizeMap.get(hashedKey); + if (partitionSizeMapForThisStage.containsKey(hashedKey)) { + partitionSizeMapForThisStage.put(hashedKey, partitionSize + partitionSizeMapForThisStage.get(hashedKey)); + } else { + partitionSizeMapForThisStage.put(hashedKey, partitionSize); + } + } + stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage); + } } From 9e2d04775f8c072be7ecbc334e64745740756766 Mon Sep 17 00:00:00 2001 From: hwarim Date: Thu, 15 Jul 2021 18:35:42 +0900 Subject: [PATCH 07/12] track the processed bytes of the current stage: send it to driver --- .../src/main/proto/ControlMessage.proto | 7 ++++++ .../runtime/executor/task/TaskExecutor.java | 24 +++++++++++++++++-- .../nemo/runtime/master/RuntimeMaster.java | 7 ++++++ .../master/scheduler/BatchScheduler.java | 20 ++++++++++++++++ 4 files changed, 56 insertions(+), 2 deletions(-) diff --git a/runtime/common/src/main/proto/ControlMessage.proto b/runtime/common/src/main/proto/ControlMessage.proto index 3d43f10c64..53adea3d3a 100644 --- a/runtime/common/src/main/proto/ControlMessage.proto +++ b/runtime/common/src/main/proto/ControlMessage.proto @@ -87,6 +87,7 @@ enum MessageType { RequestPipeLoc = 14; PipeLocInfo = 15; ParentTaskDataCollected = 16; + CurrentlyProcessedBytesCollected = 17; } message 
Message { @@ -109,6 +110,7 @@ message Message { optional RequestPipeLocationMessage requestPipeLocMsg = 17; optional PipeLocationInfoMessage pipeLocInfoMsg = 18; optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19; + optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20; } // Messages from Master to Executors @@ -263,3 +265,8 @@ message ParentTaskDataCollectMsg { required string taskId = 1; required bytes partitionSizeMap = 2; } + +message CurrentlyProcessedBytesCollectMsg { + required string taskId = 1; + required int64 processedDataBytes = 2; +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java index 758e32212e..91e8212640 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/task/TaskExecutor.java @@ -746,9 +746,11 @@ private void finalizeOutputWriters(final VertexHarness vertexHarness) { } } + // Methods for work stealing /** - * Gather the KV statistics of processed data. - * This method is for work stealing implementation. + * Gather the KV statistics of processed data when execution is completed. + * This method is for work stealing implementation: the accumulated statistics will be used to + * detect skewed tasks of the child stage. * * @param totalPartitionSizeMap accumulated partitionSizeMap of task. * @param singlePartitionSizeMap partitionSizeMap gained from single OutputWriter. @@ -764,4 +766,22 @@ private void computePartitionSizeMap(final Map totalPartitionSize } } } + + /** + * Send the temporally processed bytes of the current task on request from the scheduler. + * This method is for work stealing implementation. 
+ */ + public void onRequestForProcessedData() { + LOG.error("{}, bytes {}, replying for the request", taskId, serializedReadBytes); + persistentConnectionToMasterMap.getMessageSender(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID).send( + ControlMessage.Message.newBuilder() + .setId(RuntimeIdManager.generateMessageId()) + .setListenerId(MessageEnvironment.RUNTIME_MASTER_MESSAGE_LISTENER_ID) + .setType(ControlMessage.MessageType.CurrentlyProcessedBytesCollected) + .setCurrentlyProcessedBytesCollected(ControlMessage.CurrentlyProcessedBytesCollectMsg.newBuilder() + .setTaskId(this.taskId) + .setProcessedDataBytes(serializedReadBytes) + .build()) + .build()); + } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java index dd6b144e2a..40fb5e86fc 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/RuntimeMaster.java @@ -487,6 +487,13 @@ private void handleControlMessage(final ControlMessage.Message message) { ((BatchScheduler) scheduler).aggregateStageIdToPartitionSizeMap(taskId, partitionSizeMap); } break; + case CurrentlyProcessedBytesCollected: + if (scheduler instanceof BatchScheduler) { + ((BatchScheduler) scheduler).aggregateTaskIdToProcessedBytes( + message.getCurrentlyProcessedBytesCollected().getTaskId(), + message.getCurrentlyProcessedBytesCollected().getProcessedDataBytes() + ); + } case MetricFlushed: metricCountDownLatch.countDown(); break; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 23ed8326e7..8941aa093c 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ 
b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -81,6 +81,7 @@ public final class BatchScheduler implements Scheduler { * Data Structures for work stealing. */ private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); + private final Map taskIdToProcessedBytes = new HashMap<>(); @Inject private BatchScheduler(final PlanRewriter planRewriter, @@ -390,6 +391,14 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, } // Methods for work stealing + + /** + * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format. + * KEY is assumed to be Integer because of the HashPartition. + * + * @param taskId id of task to accumulate. + * @param partitionSizeMap map of (K) - (partition size) of the task. + */ public void aggregateStageIdToPartitionSizeMap(final String taskId, final Map partitionSizeMap) { final Map partitionSizeMapForThisStage = stageIdToOutputPartitionSizeMap @@ -404,4 +413,15 @@ public void aggregateStageIdToPartitionSizeMap(final String taskId, } stageIdToOutputPartitionSizeMap.put(RuntimeIdManager.getStageIdFromTaskId(taskId), partitionSizeMapForThisStage); } + + /** + * Store the tracked processed bytes per task by the current time. + * + * @param taskId id of task to track. + * @param processedBytes size of the processed bytes till now. 
+ */ + public void aggregateTaskIdToProcessedBytes(final String taskId, + final long processedBytes) { + taskIdToProcessedBytes.put(taskId, processedBytes); + } } From f88cd30adde973b6e310258a300809542818f6b3 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 14:28:49 +0900 Subject: [PATCH 08/12] check work stealing on scheduler --- .../master/scheduler/BatchScheduler.java | 200 +++++++++++++++++- .../runtime/master/scheduler/Scheduler.java | 5 + .../master/scheduler/SimulationScheduler.java | 6 + .../master/scheduler/StreamingScheduler.java | 5 + 4 files changed, 215 insertions(+), 1 deletion(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index 8941aa093c..f41b359f7b 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -20,16 +20,19 @@ import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.nemo.common.Pair; +import org.apache.nemo.common.dag.Vertex; import org.apache.nemo.common.exception.UnknownExecutionStateException; import org.apache.nemo.common.exception.UnrecoverableFailureException; import org.apache.nemo.common.ir.vertex.executionproperty.ClonedSchedulingProperty; import org.apache.nemo.runtime.common.RuntimeIdManager; +import org.apache.nemo.runtime.common.metric.TaskMetric; import org.apache.nemo.runtime.common.plan.*; import org.apache.nemo.runtime.common.state.StageState; import org.apache.nemo.runtime.common.state.TaskState; import org.apache.nemo.runtime.master.BlockManagerMaster; import org.apache.nemo.runtime.master.PlanAppender; import org.apache.nemo.runtime.master.PlanStateManager; +import org.apache.nemo.runtime.master.metric.MetricStore; import org.apache.nemo.runtime.master.resource.ExecutorRepresenter; import 
org.apache.reef.annotations.audience.DriverSide; import org.slf4j.Logger; @@ -66,6 +69,7 @@ public final class BatchScheduler implements Scheduler { private final PendingTaskCollectionPointer pendingTaskCollectionPointer; // A 'pointer' to the list of pending tasks. private final ExecutorRegistry executorRegistry; // A registry for executors available for the job. private final PlanStateManager planStateManager; // A component that manages the state of the plan. + private final MetricStore metricStore = MetricStore.getStore(); /** * Other necessary components of this {@link org.apache.nemo.runtime.master.RuntimeMaster}. @@ -80,8 +84,10 @@ public final class BatchScheduler implements Scheduler { /** * Data Structures for work stealing. */ + private final Set workStealingCandidates = new HashSet<>(); private final Map> stageIdToOutputPartitionSizeMap = new HashMap<>(); private final Map taskIdToProcessedBytes = new HashMap<>(); + private final Map stageIdToWorkStealingExecuted = new HashMap<>(); @Inject private BatchScheduler(final PlanRewriter planRewriter, @@ -117,6 +123,11 @@ public void updatePlan(final PhysicalPlan newPhysicalPlan) { private void updatePlan(final PhysicalPlan newPhysicalPlan, final int maxScheduleAttempt) { planStateManager.updatePlan(newPhysicalPlan, maxScheduleAttempt); + + for (Stage stage : planStateManager.getPhysicalPlan().getStageDAG().getVertices()) { + stageIdToWorkStealingExecuted.putIfAbsent(stage.getId(), false); + } + this.sortedScheduleGroups = newPhysicalPlan.getStageDAG().getVertices().stream() .collect(Collectors.groupingBy(Stage::getScheduleGroup)) .entrySet().stream() @@ -264,6 +275,24 @@ public void onSpeculativeExecutionCheck() { } } + @Override + public void onWorkStealingCheck() { + MutableBoolean isWorkStealingConditionSatisfied = new MutableBoolean(false); + List scheduleGroup = BatchSchedulerUtils + .selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager).orElse(new ArrayList<>()); + List 
scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); + isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId)); + + if (isWorkStealingConditionSatisfied.booleanValue()) { + taskIdToProcessedBytes.clear(); + final List skewedTasks = detectSkew(scheduleGroupInId); + } + + // TODO #469 Split tasks using iterator interface. + + return; + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName()); @@ -310,6 +339,9 @@ public void terminate() { * - We make {@link TaskDispatcher} dispatch only the tasks that are READY. */ private void doSchedule() { + taskIdToProcessedBytes.clear(); + workStealingCandidates.clear(); + final Optional> earliest = BatchSchedulerUtils.selectEarliestSchedulableGroup(sortedScheduleGroups, planStateManager); @@ -390,7 +422,7 @@ private boolean modifyStageNumCloneUsingMedianTime(final String stageId, return false; } - // Methods for work stealing + ///////////////////////////////////////////////////////////////// Methods for work stealing /** * Accumulate the execution result of each stage in Map[STAGE ID, Map[KEY, SIZE]] format. @@ -424,4 +456,170 @@ public void aggregateTaskIdToProcessedBytes(final String taskId, final long processedBytes) { taskIdToProcessedBytes.put(taskId, processedBytes); } + + /** + * Check if work stealing can be conducted. + * + * @param scheduleGroup schedule group. 
+ */ + private boolean checkForWorkStealingBaseConditions(final List scheduleGroup) { + if (scheduleGroup.isEmpty()) { + return false; + } + + /* If the stage of the given schedule group contains sharded tasks, return false */ + if (scheduleGroup.stream().anyMatch(stageId -> stageIdToWorkStealingExecuted.get(stageId).equals(true))) { + return false; + } + + /* If there are idle executors and the number of remaining tasks are smaller than number of executors, + * return true. + */ + final boolean executorStatus = executorRegistry.isExecutorSlotAvailable(); + final int totalNumberOfSlots = executorRegistry.getTotalNumberOfExecutorSlots(); + int remainingTasks = 0; + for (String stage : scheduleGroup) { + remainingTasks += planStateManager.getNumberOfTasksRemainingInStage(stage); // ready + executing? + } + return executorStatus && (totalNumberOfSlots > remainingTasks); + } + + private Set getCurrentlyRunningTaskId(final List scheduleGroup) { + final Set onGoingTasksOfSchedulingGroup = new HashSet<>(); + for (String stageId : scheduleGroup) { + onGoingTasksOfSchedulingGroup.addAll(planStateManager.getOngoingTaskIdsInStage(stageId)); + } + return onGoingTasksOfSchedulingGroup; + } + + private Map> getParentStages(final List scheduleGroup) { + Map> parentStages = new HashMap<>(); + for (String stageId : scheduleGroup) { + parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId).stream() + .map(Vertex::getId) + .collect(Collectors.toSet())); + } + return parentStages; + } + + private Map getInputSizesOfRunningTaskIds(final Set parentStageIds, + final Set currentlyRunningTaskIds) { + Map currentlyRunningTaskIdsToTotalSize = new HashMap<>(); + for (String parent : parentStageIds) { + Map taskIdxToSize = stageIdToOutputPartitionSizeMap.get(parent); + for (String taskId : currentlyRunningTaskIds) { + if (currentlyRunningTaskIdsToTotalSize.containsKey(taskId)) { + final long existingValue = currentlyRunningTaskIdsToTotalSize.get(taskId); 
+ currentlyRunningTaskIdsToTotalSize.put(taskId, + existingValue + taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId))); + } else { + currentlyRunningTaskIdsToTotalSize + .put(taskId, taskIdxToSize.get(RuntimeIdManager.getIndexFromTaskId(taskId))); + } + } + } + return currentlyRunningTaskIdsToTotalSize; + } + + private Map getCurrentExecutionTimeMsOfRunningTasks(final List scheduleGroup) { + final Map taskToExecutionTime = new HashMap<>(); + for (String stageId : scheduleGroup) { + taskToExecutionTime.putAll(planStateManager.getExecutingTaskToRunningTimeMs(stageId)); + } + return taskToExecutionTime; + } + + private List getScheduleGroupByStage(final String stageId) { + return sortedScheduleGroups.get( + planStateManager.getPhysicalPlan().getStageDAG().getVertexById(stageId).getScheduleGroup()) + .stream() + .map(Vertex::getId) + .collect(Collectors.toList()); + } + + /** + * Detect skewed tasks. + * + * @param scheduleGroup current schedule group. + * @return List of skewed tasks. 
+ */ + private List detectSkew(final List scheduleGroup) { + final Map> taskIdToIteratorInformation = new HashMap<>(); + final Map taskIdToInitializationOverhead = new HashMap<>(); + final Map inputSizeOfCandidateTasks = new HashMap<>(); + final Map> parentStageId = getParentStages(scheduleGroup); + + + /* if this schedule group contains a source stage, return empty list */ + if (scheduleGroup.stream().anyMatch(stage -> + planStateManager.getPhysicalPlan().getStageDAG().getParents(stage).isEmpty())) { + return new ArrayList<>(); + } + + workStealingCandidates.addAll(getCurrentlyRunningTaskId(scheduleGroup)); + + /* Gather statistics of work stealing candidates */ + + /* get size of running tasks */ + for (String stage : scheduleGroup) { + inputSizeOfCandidateTasks.putAll( + getInputSizesOfRunningTaskIds(parentStageId.get(stage), workStealingCandidates)); + } + + /* get elapsed time */ + Map taskIdToElapsedTime = getCurrentExecutionTimeMsOfRunningTasks(scheduleGroup); + + /* gather task metric */ + for (String taskId : workStealingCandidates) { + TaskMetric taskMetric = metricStore.getMetricWithId(TaskMetric.class, taskId); + + taskIdToProcessedBytes.put(taskId, taskMetric.getSerializedReadBytes()); + taskIdToIteratorInformation.put(taskId, Pair.of( + taskMetric.getCurrentIteratorIndex(), taskMetric.getTotalIteratorNumber())); + taskIdToInitializationOverhead.put(taskId, taskMetric.getTaskPreparationTime()); + } + + /* If gathered statistic is not sufficient for skew detection, return empty list. */ + if (taskIdToProcessedBytes.size() <= workStealingCandidates.size() / 2) { + return new ArrayList<>(); + } + + /* estimate the remaining time */ + List> estimatedTimeToFinishPerTask = new ArrayList<>(taskIdToElapsedTime.size()); + + for (String taskId : taskIdToProcessedBytes.keySet()) { + // if processed bytes are not available, do not detect skew. 
+ if (taskIdToProcessedBytes.get(taskId) <= 0) { + return new ArrayList<>(); + } + + // if this task is almost finished, ignore it. + Pair iteratorInformation = taskIdToIteratorInformation.get(taskId); + if (iteratorInformation.right() - iteratorInformation.left() <= 2) { + continue; + } + + long timeToFinishExecute = taskIdToElapsedTime.get(taskId) * inputSizeOfCandidateTasks.get(taskId) + / taskIdToProcessedBytes.get(taskId); + + // if the estimated left time is shorter than the initialization overhead, stop! + if (timeToFinishExecute < taskIdToInitializationOverhead.get(taskId) * 2) { + continue; + } + + estimatedTimeToFinishPerTask.add(Pair.of(taskId, timeToFinishExecute)); + } + + // detect skew + Collections.sort(estimatedTimeToFinishPerTask, new Comparator>() { + @Override + public int compare(final Pair o1, final Pair o2) { + return o2.right().compareTo(o1.right()); + } + }); + + return estimatedTimeToFinishPerTask + .subList(0, estimatedTimeToFinishPerTask.size() / 2) + .stream().map(Pair::left).collect(Collectors.toList()); + } } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java index cc4661df64..afe30f6e73 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/Scheduler.java @@ -86,6 +86,11 @@ void onTaskStateReportFromExecutor(String executorId, */ void onSpeculativeExecutionCheck(); + /** + * Called to check for work stealing condition. + */ + void onWorkStealingCheck(); + /** * To be called when a job should be terminated. * Any clean up code should be implemented in this method. 
diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java index 42870f609e..5885aa0ada 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SimulationScheduler.java @@ -451,6 +451,12 @@ public void onSpeculativeExecutionCheck() { return; } + @Override + public void onWorkStealingCheck() { + // we don't simulate work stealing yet. + return; + } + @Override public void terminate() { this.taskDispatcher.terminate(); diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java index 24e30bec87..ffa2c586da 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/StreamingScheduler.java @@ -149,6 +149,11 @@ public void onSpeculativeExecutionCheck() { throw new UnsupportedOperationException(); } + @Override + public void onWorkStealingCheck() { + throw new UnsupportedOperationException(); + } + @Override public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) { LOG.info("{} added (node: {})", executorRepresenter.getExecutorId(), executorRepresenter.getNodeName()); From 3ba3a5edfcd5abf6dbd2f56f192f3070e1f33981 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:32:37 +0900 Subject: [PATCH 09/12] cleanup skew detection code --- .../master/scheduler/BatchScheduler.java | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java 
b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index f41b359f7b..c9305494dc 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -461,6 +461,7 @@ public void aggregateTaskIdToProcessedBytes(final String taskId, * Check if work stealing can be conducted. * * @param scheduleGroup schedule group. + * @return true if work stealing is possible. */ private boolean checkForWorkStealingBaseConditions(final List scheduleGroup) { if (scheduleGroup.isEmpty()) { @@ -473,8 +474,7 @@ private boolean checkForWorkStealingBaseConditions(final List scheduleGr } /* If there are idle executors and the number of remaining tasks are smaller than number of executors, - * return true. - */ + * return true. */ final boolean executorStatus = executorRegistry.isExecutorSlotAvailable(); final int totalNumberOfSlots = executorRegistry.getTotalNumberOfExecutorSlots(); int remainingTasks = 0; @@ -484,7 +484,13 @@ private boolean checkForWorkStealingBaseConditions(final List scheduleGr return executorStatus && (totalNumberOfSlots > remainingTasks); } - private Set getCurrentlyRunningTaskId(final List scheduleGroup) { + /** + * Get the ids of tasks in execution. + * + * @param scheduleGroup schedule group. + * @return ids of running tasks. + */ + private Set getRunningTaskId(final List scheduleGroup) { final Set onGoingTasksOfSchedulingGroup = new HashSet<>(); for (String stageId : scheduleGroup) { onGoingTasksOfSchedulingGroup.addAll(planStateManager.getOngoingTaskIdsInStage(stageId)); @@ -492,22 +498,36 @@ private Set getCurrentlyRunningTaskId(final List scheduleGroup) return onGoingTasksOfSchedulingGroup; } + /** + * Get parent stages of given schedule group. + * + * @param scheduleGroup schedule group. + * @return Map of stage and set of its parent. 
+ */ private Map> getParentStages(final List scheduleGroup) { Map> parentStages = new HashMap<>(); for (String stageId : scheduleGroup) { - parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId).stream() + parentStages.put(stageId, planStateManager.getPhysicalPlan().getStageDAG().getParents(stageId) + .stream() .map(Vertex::getId) .collect(Collectors.toSet())); } return parentStages; } - private Map getInputSizesOfRunningTaskIds(final Set parentStageIds, - final Set currentlyRunningTaskIds) { + /** + * Get the input size of running tasks. + * + * @param parentStageIds id of parent stages. + * @param runningTaskIds id of running tasks. + * @return Map of task id to its input size. + */ + private Map getInputSizeOfRunningTasks(final Set parentStageIds, + final Set runningTaskIds) { Map currentlyRunningTaskIdsToTotalSize = new HashMap<>(); for (String parent : parentStageIds) { Map taskIdxToSize = stageIdToOutputPartitionSizeMap.get(parent); - for (String taskId : currentlyRunningTaskIds) { + for (String taskId : runningTaskIds) { if (currentlyRunningTaskIdsToTotalSize.containsKey(taskId)) { final long existingValue = currentlyRunningTaskIdsToTotalSize.get(taskId); currentlyRunningTaskIdsToTotalSize.put(taskId, @@ -521,6 +541,13 @@ private Map getInputSizesOfRunningTaskIds(final Set parent return currentlyRunningTaskIdsToTotalSize; } + /** + * get current execution time of running tasks in millisecond. + * Note that this is the execution time of incomplete tasks. + * + * @param scheduleGroup schedule group. + * @return Map of task id to its execution time. 
+ */ private Map getCurrentExecutionTimeMsOfRunningTasks(final List scheduleGroup) { final Map taskToExecutionTime = new HashMap<>(); for (String stageId : scheduleGroup) { @@ -556,14 +583,14 @@ private List detectSkew(final List scheduleGroup) { return new ArrayList<>(); } - workStealingCandidates.addAll(getCurrentlyRunningTaskId(scheduleGroup)); + workStealingCandidates.addAll(getRunningTaskId(scheduleGroup)); /* Gather statistics of work stealing candidates */ /* get size of running tasks */ for (String stage : scheduleGroup) { inputSizeOfCandidateTasks.putAll( - getInputSizesOfRunningTaskIds(parentStageId.get(stage), workStealingCandidates)); + getInputSizeOfRunningTasks(parentStageId.get(stage), workStealingCandidates)); } /* get elapsed time */ @@ -610,7 +637,7 @@ private List detectSkew(final List scheduleGroup) { estimatedTimeToFinishPerTask.add(Pair.of(taskId, timeToFinishExecute)); } - // detect skew + /* detect skew */ Collections.sort(estimatedTimeToFinishPerTask, new Comparator>() { @Override public int compare(final Pair o1, final Pair o2) { From 94cdf152bca3cf642f4617868857ebc936e78763 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:36:35 +0900 Subject: [PATCH 10/12] get executor vacancy information from executor registry --- .../master/resource/DefaultExecutorRepresenter.java | 5 +++++ .../nemo/runtime/master/resource/ExecutorRepresenter.java | 5 +++++ .../nemo/runtime/master/scheduler/ExecutorRegistry.java | 8 ++++++++ 3 files changed, 18 insertions(+) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java index 16c9a70db9..ebec804132 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/DefaultExecutorRepresenter.java @@ -170,6 +170,11 @@ 
public void onTaskExecutionFailed(final String taskId) { failedTasks.add(failedTask); } + @Override + public boolean isExecutorSlotAvailable() { + return getExecutorCapacity() - getNumOfRunningTasks() > 0; + } + /** * @return how many Tasks can this executor simultaneously run */ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java index 26649a81db..dcfb53eb1c 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/resource/ExecutorRepresenter.java @@ -108,4 +108,9 @@ public interface ExecutorRepresenter { * @param taskId id of the Task */ void onTaskExecutionFailed(String taskId); + + /** + * @return true if this executor has an available slot. + */ + boolean isExecutorSlotAvailable(); } diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java index 11d40c73b8..5cead6e290 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/ExecutorRegistry.java @@ -126,6 +126,14 @@ private Set getRunningExecutors() { .collect(Collectors.toSet()); } + public int getTotalNumberOfExecutorSlots() { + return getRunningExecutors().stream().mapToInt(ExecutorRepresenter::getExecutorCapacity).sum(); + } + + public boolean isExecutorSlotAvailable() { + return getRunningExecutors().stream().anyMatch(ExecutorRepresenter::isExecutorSlotAvailable); + } + @Override public String toString() { return executors.toString(); From d269ca144ec25bca26272f846ed0aceb21838385 Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:52:18 +0900 Subject: [PATCH 11/12] add 
helper methods in plan state manager --- .../nemo/runtime/master/PlanStateManager.java | 71 ++++++++++++++++--- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java index 53cab57810..65b5306f5b 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/PlanStateManager.java @@ -85,6 +85,10 @@ public final class PlanStateManager { private final Map> stageIdToCompletedTaskTimeMsList = new HashMap<>(); private final Map> stageIdToTaskIndexToNumOfClones = new HashMap<>(); + /** + * Used for work stealing. + */ + private final Map>> stageIdToTaskIdxToWSAttemptStates = new HashMap<>(); /** * Represents the plan to manage. */ @@ -127,7 +131,7 @@ public static PlanStateManager newInstance(final String dagDirectory) { } /** - * @param metricStore set the metric store of the paln state manager. + * @param metricStore set the metric store of the plan state manager. 
*/ public void setMetricStore(final MetricStore metricStore) { this.metricStore = metricStore; @@ -326,16 +330,8 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState // Log not-yet-completed tasks for us humans to track progress final String stageId = RuntimeIdManager.getStageIdFromTaskId(taskId); final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); - final long numOfCompletedTaskIndicesInThisStage = taskStatesOfThisStage.values().stream() - .filter(attempts -> { - final List states = attempts - .stream() - .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) - .collect(Collectors.toList()); - return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) // one of them is ON_HOLD - || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); // one of them is COMPLETE - }) - .count(); + final long numOfCompletedTaskIndicesInThisStage = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + if (newTaskState.equals(TaskState.State.COMPLETE)) { LOG.info("{} completed: {} Task(s) out of {} are remaining in this stage", taskId, taskStatesOfThisStage.size() - numOfCompletedTaskIndicesInThisStage, taskStatesOfThisStage.size()); @@ -577,6 +573,59 @@ private List getPeerAttemptsForTheSameTaskIndex(final String ta .collect(Collectors.toList()); } + /** + * Get number of remaining tasks of the stage. + * + * @param stageId stage id. + * @return number of remaining tasks. 
+ */ + public int getNumberOfTasksRemainingInStage(final String stageId) { + final Map> taskStatesOfThisStage = stageIdToTaskIdxToAttemptStates.get(stageId); + final Map> wsTaskStatesOfThisStage = stageIdToTaskIdxToWSAttemptStates + .getOrDefault(stageId, new HashMap<>()); + final long numOfCompletedTaskIndices = getNumberOfCompletedTasksInStage(taskStatesOfThisStage); + if (wsTaskStatesOfThisStage.isEmpty()) { + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices); + } else { + final long numOfCompletedWorkStealingTaskIndices = getNumberOfCompletedTasksInStage(wsTaskStatesOfThisStage); + return (int) (taskStatesOfThisStage.size() - numOfCompletedTaskIndices + + wsTaskStatesOfThisStage.size() - numOfCompletedWorkStealingTaskIndices); + } + } + + /** + * Get tasks which are currently being executed. + * + * @param stageId stage id. + * @return Set of tasksIds in execution. + */ + public Set getOngoingTaskIdsInStage(final String stageId) { + final Map> taskIdToState = stageIdToTaskIdxToAttemptStates.get(stageId); + final Set onGoingTaskIds = new HashSet<>(); + for (final int taskIndex : taskIdToState.keySet()) { + final List attemptStates = taskIdToState.get(taskIndex); + for (int attempt = 0; attempt < attemptStates.size(); attempt++) { + if (attemptStates.get(attempt).getStateMachine().getCurrentState().equals(TaskState.State.EXECUTING)) { + onGoingTaskIds.add(RuntimeIdManager.generateTaskId(stageId, taskIndex, attempt)); + } + } + } + return onGoingTaskIds; + } + + private long getNumberOfCompletedTasksInStage(final Map> taskIdxToState) { + return taskIdxToState.values().stream() + .filter(attempts -> { + final List states = attempts + .stream() + .map(state -> (TaskState.State) state.getStateMachine().getCurrentState()) + .collect(Collectors.toList()); + return states.stream().anyMatch(curState -> curState.equals(TaskState.State.ON_HOLD)) + || states.stream().anyMatch(curState -> curState.equals(TaskState.State.COMPLETE)); + }) + .count(); + } 
+ /** * @return the physical plan. */ From e23b012d00c861df2839db1c4d35304a2c6cdb5b Mon Sep 17 00:00:00 2001 From: hwarim Date: Fri, 16 Jul 2021 17:58:01 +0900 Subject: [PATCH 12/12] add task metrics needed for determining work stealing condition --- .../runtime/common/metric/TaskMetric.java | 35 +++++++++++++++++++ .../master/scheduler/BatchScheduler.java | 7 ++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java index 531e715a7b..7d98140cc3 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/metric/TaskMetric.java @@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric { private long shuffleReadTime = -1; private long shuffleWriteBytes = -1; private long shuffleWriteTime = -1; + private int currentIteratorIndex = -1; + private int totalIteratorNumber = -1; + private long taskPreparationTime = -1; private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName()); @@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) { this.shuffleWriteTime = shuffleWriteTime; } + public final int getCurrentIteratorIndex() { + return this.currentIteratorIndex; + } + + private void setCurrentIteratorIndex(final int currentIteratorIndex) { + this.currentIteratorIndex = currentIteratorIndex; + } + + public final int getTotalIteratorNumber() { + return this.totalIteratorNumber; + } + + private void setTotalIteratorNumber(final int totalIteratorNumber) { + this.totalIteratorNumber = totalIteratorNumber; + } + + public final long getTaskPreparationTime() { + return this.taskPreparationTime; + } + + private void setTaskPreparationTime(final long taskPreparationTime) { + this.taskPreparationTime = taskPreparationTime; + } + @Override public final 
String getId() { return id; @@ -317,6 +344,15 @@ public final boolean processMetricMessage(final String metricField, final byte[] case "shuffleWriteTime": setShuffleWriteTime(SerializationUtils.deserialize(metricValue)); break; + case "currentIteratorIndex": + setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue)); + break; + case "totalIteratorNumber": + setTotalIteratorNumber(SerializationUtils.deserialize(metricValue)); + break; + case "taskPreparationTime": + setTaskPreparationTime(SerializationUtils.deserialize(metricValue)); + break; default: LOG.warn("metricField {} is not supported.", metricField); return false; diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java index c9305494dc..3caa4118ce 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/BatchScheduler.java @@ -283,10 +283,11 @@ public void onWorkStealingCheck() { List scheduleGroupInId = scheduleGroup.stream().map(Stage::getId).collect(Collectors.toList()); isWorkStealingConditionSatisfied.setValue(checkForWorkStealingBaseConditions(scheduleGroupInId)); - if (isWorkStealingConditionSatisfied.booleanValue()) { - taskIdToProcessedBytes.clear(); - final List skewedTasks = detectSkew(scheduleGroupInId); + if (!isWorkStealingConditionSatisfied.booleanValue()) { + return; } + taskIdToProcessedBytes.clear(); + final List skewedTasks = detectSkew(scheduleGroupInId); // TODO #469 Split tasks using iterator interface.