Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][NEMO-468]Enable scheduler to detect skewed tasks periodically #310

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public class TaskMetric implements StateMetric<TaskState.State> {
private long shuffleReadTime = -1;
private long shuffleWriteBytes = -1;
private long shuffleWriteTime = -1;
// Each of the following starts at -1 and keeps that value until its private setter is invoked.
private int currentIteratorIndex = -1;
private int totalIteratorNumber = -1;
private long taskPreparationTime = -1;

private static final Logger LOG = LoggerFactory.getLogger(TaskMetric.class.getName());

Expand Down Expand Up @@ -252,6 +255,30 @@ private void setShuffleWriteTime(final long shuffleWriteTime) {
this.shuffleWriteTime = shuffleWriteTime;
}

/**
 * @return the last recorded value of currentIteratorIndex; the field is initialized to -1.
 */
public final int getCurrentIteratorIndex() {
return this.currentIteratorIndex;
}

/**
 * Overwrites the stored current iterator index.
 * Invoked from processMetricMessage when a "currentIteratorIndex" metric arrives.
 *
 * @param currentIteratorIndex the new value to store.
 */
private void setCurrentIteratorIndex(final int currentIteratorIndex) {
this.currentIteratorIndex = currentIteratorIndex;
}

/**
 * @return the last recorded value of totalIteratorNumber; the field is initialized to -1.
 */
public final int getTotalIteratorNumber() {
return this.totalIteratorNumber;
}

/**
 * Overwrites the stored total iterator number.
 * Invoked from processMetricMessage when a "totalIteratorNumber" metric arrives.
 *
 * @param totalIteratorNumber the new value to store.
 */
private void setTotalIteratorNumber(final int totalIteratorNumber) {
this.totalIteratorNumber = totalIteratorNumber;
}

/**
 * @return the last recorded value of taskPreparationTime; the field is initialized to -1.
 *         NOTE(review): the time unit is not visible here - confirm with the sender of this metric.
 */
public final long getTaskPreparationTime() {
return this.taskPreparationTime;
}

/**
 * Overwrites the stored task preparation time.
 * Invoked from processMetricMessage when a "taskPreparationTime" metric arrives.
 *
 * @param taskPreparationTime the new value to store.
 */
private void setTaskPreparationTime(final long taskPreparationTime) {
this.taskPreparationTime = taskPreparationTime;
}

@Override
public final String getId() {
return id;
Expand Down Expand Up @@ -317,6 +344,14 @@ public final boolean processMetricMessage(final String metricField, final byte[]
case "shuffleWriteTime":
setShuffleWriteTime(SerializationUtils.deserialize(metricValue));
break;
case "currentIteratorIndex":
setCurrentIteratorIndex(SerializationUtils.deserialize(metricValue));
break;
case "totalIteratorNumber":
setTotalIteratorNumber(SerializationUtils.deserialize(metricValue));
break;
case "taskPreparationTime":
setTaskPreparationTime(SerializationUtils.deserialize(metricValue));
// Terminate the case explicitly: without this break, control falls through to the default branch,
// which logs the field as unsupported and returns false even though the value was stored.
break;
default:
LOG.warn("metricField {} is not supported.", metricField);
return false;
Expand Down
14 changes: 14 additions & 0 deletions runtime/common/src/main/proto/ControlMessage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ enum MessageType {
PipeInit = 13;
RequestPipeLoc = 14;
PipeLocInfo = 15;
ParentTaskDataCollected = 16;
CurrentlyProcessedBytesCollected = 17;
}

message Message {
Expand All @@ -107,6 +109,8 @@ message Message {
optional PipeInitMessage pipeInitMsg = 16;
optional RequestPipeLocationMessage requestPipeLocMsg = 17;
optional PipeLocationInfoMessage pipeLocInfoMsg = 18;
optional ParentTaskDataCollectMsg ParentTaskDataCollected = 19;
optional CurrentlyProcessedBytesCollectMsg currentlyProcessedBytesCollected = 20;
}

// Messages from Master to Executors
Expand Down Expand Up @@ -256,3 +260,13 @@ message PipeLocationInfoMessage {
required int64 requestId = 1; // To find the matching request msg
required string executorId = 2;
}

// Carries a partition size map together with the id of the task it belongs to.
// NOTE(review): presumably paired with MessageType ParentTaskDataCollected - confirm against the sender.
message ParentTaskDataCollectMsg {
required string taskId = 1;
required bytes partitionSizeMap = 2; // opaque bytes; presumably a serialized map - TODO confirm the encoding
}

// Carries a processed-byte count together with the id of the task it belongs to.
// NOTE(review): presumably paired with MessageType CurrentlyProcessedBytesCollected - confirm against the sender.
message CurrentlyProcessedBytesCollectMsg {
required string taskId = 1;
required int64 processedDataBytes = 2;
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public final class BlockOutputWriter implements OutputWriter {

private long writtenBytes;

// Holds the result of blockToWrite.commit(), assigned in close(); it has no initializer at its declaration.
private Optional<Map<Integer, Long>> partitionSizeMap;

/**
* Constructor.
*
Expand Down Expand Up @@ -109,7 +111,7 @@ public void close() {
final DataPersistenceProperty.Value persistence = (DataPersistenceProperty.Value) runtimeEdge
.getPropertyValue(DataPersistenceProperty.class).orElseThrow(IllegalStateException::new);

final Optional<Map<Integer, Long>> partitionSizeMap = blockToWrite.commit();
partitionSizeMap = blockToWrite.commit();
// Return the total size of the committed block.
if (partitionSizeMap.isPresent()) {
long blockSizeTotal = 0;
Expand All @@ -123,6 +125,16 @@ public void close() {
blockManagerWorker.writeBlock(blockToWrite, blockStoreValue, getExpectedRead(), persistence);
}

/**
 * Returns the partition size map produced when the block was committed.
 *
 * The backing field is assigned in {@code close()} and has no initializer at its declaration, so it can
 * still be {@code null} when this method is called earlier; that case is reported as an empty Optional
 * instead of failing with a NullPointerException.
 *
 * @return the map of hashed key to partition size, or an empty Optional if none is available.
 */
@Override
public Optional<Map<Integer, Long>> getPartitionSizeMap() {
  // A present Optional is returned as-is; both an absent Optional and an unassigned field yield empty.
  return partitionSizeMap == null ? Optional.empty() : partitionSizeMap;
}

@Override
public Optional<Long> getWrittenBytes() {
if (writtenBytes == -1) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.nemo.common.punctuation.Watermark;

import java.util.Map;
import java.util.Optional;

/**
Expand All @@ -45,5 +46,10 @@ public interface OutputWriter {
*/
Optional<Long> getWrittenBytes();

/**
 * Returns the partition sizes known to this writer, if any.
 *
 * @return the map of hashed key to partition size, or an empty Optional when the writer has none to report.
 */
Optional<Map<Integer, Long>> getPartitionSizeMap();

void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
Expand Down Expand Up @@ -113,6 +114,11 @@ public Optional<Long> getWrittenBytes() {
return Optional.empty();
}

/**
 * This writer reports no partition sizes.
 *
 * @return always an empty Optional.
 */
@Override
public Optional<Map<Integer, Long>> getPartitionSizeMap() {
return Optional.empty();
}

@Override
public void close() {
if (!initialized) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.runtime.executor.MetricMessageSender;

import java.io.IOException;

Expand Down Expand Up @@ -49,6 +50,21 @@ abstract class DataFetcher implements AutoCloseable {
*/
abstract Object fetchDataElement() throws IOException;

/**
 * Fetches the next data element like {@link #fetchDataElement()}, and additionally lets the implementation
 * report intermediate read metrics (such as serializedReadBytes) through the given sender while it advances
 * its iterators.
 * This method is for the WorkStealing implementation in Nemo.
 *
 * @param taskId              id of the task the reported metric belongs to.
 * @param metricMessageSender sender used to deliver the intermediate metric.
 *
 * @return data element
 * @throws IOException upon I/O error
 * @throws java.util.NoSuchElementException if no more element is available
 */
abstract Object fetchDataElementWithTrace(String taskId,
MetricMessageSender metricMessageSender) throws IOException;

OutputCollector getOutputCollector() {
return outputCollector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.runtime.executor.MetricMessageSender;
import org.apache.nemo.runtime.executor.data.DataUtil;
import org.apache.nemo.runtime.executor.datatransfer.*;
import org.slf4j.Logger;
Expand Down Expand Up @@ -100,6 +101,12 @@ Object fetchDataElement() throws IOException {
}
}

/**
 * Delegates to {@code fetchDataElement()}; neither argument is used here, so this method itself
 * sends no metric.
 *
 * @param taskId              unused.
 * @param metricMessageSender unused.
 * @return whatever {@code fetchDataElement()} returns.
 * @throws IOException if {@code fetchDataElement()} throws it.
 */
@Override
Object fetchDataElementWithTrace(final String taskId,
final MetricMessageSender metricMessageSender) throws IOException {
return fetchDataElement();
}

private void fetchDataLazily() {
final List<CompletableFuture<DataUtil.IteratorWithNumBytes>> futures = readersForParentTask.read();
numOfIterators = futures.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
*/
package org.apache.nemo.runtime.executor.task;

import org.apache.commons.lang3.SerializationUtils;
import org.apache.nemo.common.ir.OutputCollector;
import org.apache.nemo.common.ir.edge.executionproperty.BlockFetchFailureProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.runtime.executor.MetricMessageSender;
import org.apache.nemo.runtime.executor.data.DataUtil;
import org.apache.nemo.runtime.executor.datatransfer.InputReader;
import org.slf4j.Logger;
Expand Down Expand Up @@ -100,6 +102,49 @@ Object fetchDataElement() throws IOException {
return Finishmark.getInstance();
}

/**
 * Fetches the next data element and, each time an exhausted iterator is replaced by the next one, reports the
 * cumulative {@code serBytes} as the "serializedReadBytes" field of "TaskMetric" for the given task.
 *
 * @param taskId              id of the task the metric is reported for.
 * @param metricMessageSender sender used to report the metric.
 * @return the next element, or the {@link Finishmark} instance once no iterator is left to advance to.
 * @throws IOException wrapping any failure raised while fetching.
 */
@Override
Object fetchDataElementWithTrace(final String taskId,
                                 final MetricMessageSender metricMessageSender) throws IOException {
  try {
    // The very first call has to start the lazy fetch and pick up the first iterator.
    if (firstFetch) {
      fetchDataLazily();
      advanceIterator();
      firstFetch = false;
    }

    for (;;) {
      // Serve from the current iterator while it still has elements.
      if (currentIterator.hasNext()) {
        return currentIterator.next();
      }

      // The current iterator is exhausted and there is no further iterator to move on to.
      if (currentIteratorIndex >= expectedNumOfIterators) {
        break;
      }

      // Account for the finished iterator, publish the running total, then move on to the next iterator.
      countBytes(currentIterator);
      metricMessageSender.send("TaskMetric", taskId, "serializedReadBytes",
        SerializationUtils.serialize(serBytes));
      advanceIterator();
    }
  } catch (final Throwable e) {
    // Every failure, including unchecked exceptions raised by the iterator when remote fetching fails,
    // is rethrown as an IOException so that the task is retried. Unchecked exceptions are relied upon
    // because the Iterator interface cannot declare checked ones for the TaskExecutor thread to handle.
    throw new IOException(e);
  }

  return Finishmark.getInstance();
}

private void advanceIterator() throws IOException {
// Take from iteratorQueue
final Object iteratorOrThrowable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.punctuation.Finishmark;
import org.apache.nemo.common.punctuation.Watermark;
import org.apache.nemo.runtime.executor.MetricMessageSender;

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
Expand Down Expand Up @@ -74,6 +75,11 @@ Object fetchDataElement() {
}
}

/**
 * Delegates to {@code fetchDataElement()}; neither argument is used here, so this method itself
 * sends no metric.
 *
 * @param taskId              unused.
 * @param metricMessageSender unused.
 * @return whatever {@code fetchDataElement()} returns.
 */
@Override
Object fetchDataElementWithTrace(final String taskId, final MetricMessageSender metricMessageSender) {
return fetchDataElement();
}

/**
 * @return the current value of boundedSourceReadTime.
 *         NOTE(review): the time unit is not visible here - confirm where the field is updated.
 */
final long getBoundedSourceReadTime() {
return boundedSourceReadTime;
}
Expand Down
Loading