From 28adec8187f6be0898e7a9375c2293df47fb1848 Mon Sep 17 00:00:00 2001 From: Kangji Date: Sun, 15 Aug 2021 14:29:45 +0900 Subject: [PATCH 01/10] [add] unit test for intermediate combine --- .../beam/transform/CombineFnTest.java | 57 +++ ...ormTest.java => CombineTransformTest.java} | 367 +++++++++++++----- 2 files changed, 325 insertions(+), 99 deletions(-) rename compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/{GBKTransformTest.java => CombineTransformTest.java} (57%) diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java index acf7cc4d1b..af6721e245 100644 --- a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java +++ b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java @@ -34,6 +34,14 @@ public class CombineFnTest extends TestCase { public static final class CountFn extends Combine.CombineFn { public static final class Accum { int sum = 0; + + @Override + public boolean equals(Object o) { + if (Accum.class != o.getClass()) { + return false; + } + return (sum == ((Accum) o).sum); + } } @Override @@ -120,6 +128,55 @@ public void testPartialCombineFn() { } } + @Test + public void testIntermediateCombineFn() { + // Initialize intermediate combine function. + final IntermediateCombineFn intermediateCombineFn = + new IntermediateCombineFn<>(combineFn, accumCoder); + + // Create accumulator. + final CountFn.Accum accum1 = intermediateCombineFn.createAccumulator(); + final CountFn.Accum accum2 = intermediateCombineFn.createAccumulator(); + final CountFn.Accum accum3 = intermediateCombineFn.createAccumulator(); + + final CountFn.Accum expectedMergedAccum = intermediateCombineFn.createAccumulator(); + expectedMergedAccum.sum = 6; + + // Check whether accumulators are initialized correctly. 
+ assertEquals(0, accum1.sum); + assertEquals(0, accum2.sum); + assertEquals(0, accum3.sum); + + // Change the parameter for the sake of unit testing. + accum1.sum = 1; + accum2.sum = 2; + accum3.sum = 3; + + // Add input. Intermediate combineFn's addInput method takes accumulators as input + // and merges them into a single accumulator. + final CountFn.Accum addedAccum = intermediateCombineFn.addInput(accum1, accum2); + + // Check whether inputs are added correctly. + assertEquals(3, addedAccum.sum); + + // Merge accumulators. + CountFn.Accum mergedAccum = intermediateCombineFn.mergeAccumulators(Arrays.asList(accum1, accum2, accum3)); + + // Check whether accumulators are merged correctly. + assertEquals(expectedMergedAccum, mergedAccum); + + // Extract output. + assertEquals(expectedMergedAccum, intermediateCombineFn.extractOutput(mergedAccum)); + + // Get accumulator coder. Check if the accumulator coder from intermediate combineFn is equal + // to the one from original combineFn. + try { + assertEquals(accumCoder, intermediateCombineFn.getAccumulatorCoder(CoderRegistry.createDefault(), INTEGER_CODER)); + } catch (CannotProvideCoderException e) { + throw new RuntimeException("Failed to provide an accumulator coder"); + } + } + @Test public void testFinalCombineFn() { // Initialize final combine function. 
diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformTest.java similarity index 57% rename from compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java rename to compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformTest.java index 3c08c50fb2..0f67164f3d 100644 --- a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java +++ b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformTest.java @@ -18,54 +18,57 @@ */ package org.apache.nemo.compiler.frontend.beam.transform; -import com.google.common.collect.Iterables; -import junit.framework.TestCase; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.sdk.coders.*; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineFnBase; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.*; import org.apache.beam.sdk.util.AppliedCombineFn; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.*; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.nemo.common.ir.vertex.transform.Transform; import org.apache.nemo.common.punctuation.Watermark; import org.apache.nemo.compiler.frontend.beam.NemoPipelineOptions; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import java.util.*; import static 
org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.*; import static org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES; import static org.mockito.Mockito.mock; -public class GBKTransformTest extends TestCase { - private static final Logger LOG = LoggerFactory.getLogger(GBKTransformTest.class.getName()); +public class CombineTransformTest { + private static final Logger LOG = LoggerFactory.getLogger(CombineTransformTest.class.getName()); private final static Coder STRING_CODER = StringUtf8Coder.of(); private final static Coder INTEGER_CODER = BigEndianIntegerCoder.of(); private void checkOutput(final KV expected, final KV result) { // check key - assertEquals(expected.getKey(), result.getKey()); + Assert.assertEquals(expected.getKey(), result.getKey()); // check value - assertEquals(expected.getValue(), result.getValue()); + Assert.assertEquals(expected.getValue(), result.getValue()); } private void checkOutput2(final KV> expected, final KV> result) { // check key - assertEquals(expected.getKey(), result.getKey()); + Assert.assertEquals(expected.getKey(), result.getKey()); // check value final List resultValue = new ArrayList<>(); final List expectedValue = new ArrayList<>(expected.getValue()); result.getValue().iterator().forEachRemaining(resultValue::add); Collections.sort(resultValue); Collections.sort(expectedValue); - assertEquals(expectedValue, resultValue); + Assert.assertEquals(expectedValue, resultValue); } @@ -123,7 +126,7 @@ public Coder getAccumulatorCoder(CoderRegistry registry, Coder outputTag = new TupleTag<>("main-output"); final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(10)) .every(Duration.standardSeconds(5)); @@ -142,7 +145,7 @@ public void test_combine() { final Watermark watermark3 = new Watermark(18000); final Watermark watermark4 = new Watermark(21000); - AppliedCombineFn applied_combine_fn = + AppliedCombineFn appliedCombineFn = AppliedCombineFn.withInputCoder( 
combine_fn, CoderRegistry.createDefault(), @@ -151,14 +154,14 @@ public void test_combine() { WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES) ); - final GBKTransform combine_transform = - new GBKTransform( + final CombineTransform combineTransform = + new CombineTransform( KvCoder.of(STRING_CODER, INTEGER_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)), outputTag, WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES), PipelineOptionsFactory.as(NemoPipelineOptions.class), - SystemReduceFn.combining(STRING_CODER, applied_combine_fn), + SystemReduceFn.combining(STRING_CODER, appliedCombineFn), DoFnSchemaInformation.create(), DisplayData.none(), false); @@ -168,69 +171,69 @@ public void test_combine() { // window3 : [5000, 15000) // window4 : [10000, 20000) List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window1 = sortedWindows.get(0); final IntervalWindow window2 = sortedWindows.get(1); sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts5)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window3 = sortedWindows.get(0); final IntervalWindow window4 = sortedWindows.get(1); // Prepare to test CombineStreamTransform final Transform.Context context = mock(Transform.Context.class); final TestOutputCollector> oc = new TestOutputCollector(); - combine_transform.prepare(context, oc); + combineTransform.prepare(context, oc); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("c", 1), ts2, slidingWindows.assignWindows(ts2), PaneInfo.NO_FIRING)); 
- combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts3, slidingWindows.assignWindows(ts3), PaneInfo.NO_FIRING)); // Emit outputs of window1 - combine_transform.onWatermark(watermark1); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark1); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); - assertEquals(2, oc.outputs.size()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(2, oc.outputs.size()); checkOutput(KV.of("a", 1), oc.outputs.get(0).getValue()); checkOutput(KV.of("c", 1), oc.outputs.get(1).getValue()); oc.outputs.clear(); oc.watermarks.clear(); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts4, slidingWindows.assignWindows(ts4), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("c", 1), ts5, slidingWindows.assignWindows(ts5), PaneInfo.NO_FIRING)); // Emit outputs of window2 - combine_transform.onWatermark(watermark2); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark2); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window2), oc.outputs.get(0).getWindows()); - assertEquals(3, oc.outputs.size()); + Assert.assertEquals(Collections.singletonList(window2), oc.outputs.get(0).getWindows()); + Assert.assertEquals(3, oc.outputs.size()); checkOutput(KV.of("a", 2), oc.outputs.get(0).getValue()); checkOutput(KV.of("b", 1), oc.outputs.get(1).getValue()); checkOutput(KV.of("c", 1), oc.outputs.get(2).getValue()); oc.outputs.clear(); oc.watermarks.clear(); - 
combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts6, slidingWindows.assignWindows(ts6), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts7, slidingWindows.assignWindows(ts7), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts8, slidingWindows.assignWindows(ts8), PaneInfo.NO_FIRING)); // Emit outputs of window3 - combine_transform.onWatermark(watermark3); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark3); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window3), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window3), oc.outputs.get(0).getWindows()); checkOutput(KV.of("a", 1), oc.outputs.get(0).getValue()); checkOutput(KV.of("b", 2), oc.outputs.get(1).getValue()); checkOutput(KV.of("c", 1), oc.outputs.get(2).getValue()); @@ -238,15 +241,15 @@ public void test_combine() { oc.watermarks.clear(); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("c", 3), ts9, slidingWindows.assignWindows(ts9), PaneInfo.NO_FIRING)); // Emit outputs of window3 - combine_transform.onWatermark(watermark4); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark4); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window4), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window4), oc.outputs.get(0).getWindows()); checkOutput(KV.of("a", 1), oc.outputs.get(0).getValue()); checkOutput(KV.of("b", 2), oc.outputs.get(1).getValue()); checkOutput(KV.of("c", 4), 
oc.outputs.get(2).getValue()); @@ -255,10 +258,176 @@ public void test_combine() { oc.watermarks.clear(); } + private void clearOutputCollectors(final List outputCollectors) { + for (TestOutputCollector oc : outputCollectors) { + oc.outputs.clear(); + oc.watermarks.clear(); + } + } + + @SuppressWarnings("unchecked") + private void processIntermediateCombineElementsPerWindow(final List elements, + final CombineTransform partialCombineTransform, + final CombineTransform intermediateCombineTransform, + final CombineTransform finalCombineTransform, + final Watermark watermark, + final TestOutputCollector ocPartial, + final TestOutputCollector ocIntermediate) { + for (WindowedValue element : elements) { + partialCombineTransform.onData(element); + } + partialCombineTransform.onWatermark(watermark); + for (WindowedValue output: ocPartial.outputs) { + intermediateCombineTransform.onData(output); + } + intermediateCombineTransform.onWatermark(watermark); + for (WindowedValue output: ocIntermediate.outputs) { + finalCombineTransform.onData(output); + } + finalCombineTransform.onWatermark(watermark); + } + + /** + * Test intermediate combine. 
+ */ + @Test + @SuppressWarnings("unchecked") + public void testIntermediateCombine() { + final TupleTag outputTag = new TupleTag<>("main-output"); + final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(10)) + .every(Duration.standardSeconds(5)); + final WindowingStrategy windowingStrategy = WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES); + + final Instant ts1 = new Instant(1000); + final Instant ts2 = new Instant(2000); + final Instant ts3 = new Instant(6000); + final Instant ts4 = new Instant(8000); + final Instant ts5 = new Instant(11000); + final Instant ts6 = new Instant(14000); + final Instant ts7 = new Instant(16000); + final Instant ts8 = new Instant(17000); + final Instant ts9 = new Instant(19000); + final Watermark watermark1 = new Watermark(7000); + final Watermark watermark2 = new Watermark(12000); + final Watermark watermark3 = new Watermark(18000); + final Watermark watermark4 = new Watermark(21000); + + final KvCoder inputCoder = KvCoder.of(STRING_CODER, INTEGER_CODER); + final Coder accumulatorCoder; + try { + accumulatorCoder = combine_fn.getAccumulatorCoder(CoderRegistry.createDefault(), INTEGER_CODER); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + + final CombineFnBase.GlobalCombineFn partialCombineFn = new PartialCombineFn(combine_fn, accumulatorCoder); + final CombineFnBase.GlobalCombineFn intermediateCombineFn = new IntermediateCombineFn(combine_fn, accumulatorCoder); + final CombineFnBase.GlobalCombineFn finalCombineFn = new FinalCombineFn(combine_fn, accumulatorCoder); + + final SystemReduceFn partialSystemReduceFn = SystemReduceFn.combining(STRING_CODER, + AppliedCombineFn.withInputCoder(partialCombineFn, CoderRegistry.createDefault(), + inputCoder, null, windowingStrategy)); + final SystemReduceFn intermediateSystemReduceFn = SystemReduceFn.combining(STRING_CODER, + AppliedCombineFn.withInputCoder(intermediateCombineFn, CoderRegistry.createDefault(), 
+ KvCoder.of(STRING_CODER, accumulatorCoder), null, windowingStrategy)); + final SystemReduceFn finalSystemReduceFn = SystemReduceFn.combining(STRING_CODER, + AppliedCombineFn.withInputCoder(finalCombineFn, CoderRegistry.createDefault(), + KvCoder.of(STRING_CODER, accumulatorCoder), null, windowingStrategy)); + + final CombineTransformFactory combineTransformFactory = new CombineTransformFactory( + inputCoder, new TupleTag<>(), KvCoder.of(STRING_CODER, accumulatorCoder), + Collections.singletonMap(outputTag, inputCoder), outputTag, windowingStrategy, + PipelineOptionsFactory.as(NemoPipelineOptions.class), partialSystemReduceFn, intermediateSystemReduceFn, + finalSystemReduceFn, DoFnSchemaInformation.create(), DisplayData.none()); + + final CombineTransform partialCombineTransform = combineTransformFactory.getPartialCombineTransform(); + final CombineTransform intermediateCombineTransform = combineTransformFactory.getIntermediateCombineTransform(); + final CombineTransform finalCombineTransform = combineTransformFactory.getFinalCombineTransform(); + + // window1 : [-5000, 5000) in millisecond + // window2 : [0, 10000) + // window3 : [5000, 15000) + // window4 : [10000, 20000) + List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); + sortedWindows.sort(IntervalWindow::compareTo); + final IntervalWindow window1 = sortedWindows.get(0); + final IntervalWindow window2 = sortedWindows.get(1); + sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts5)); + sortedWindows.sort(IntervalWindow::compareTo); + final IntervalWindow window3 = sortedWindows.get(0); + final IntervalWindow window4 = sortedWindows.get(1); + + // Prepare to test CombineStreamTransform + final Transform.Context context = mock(Transform.Context.class); + final TestOutputCollector> ocPartial = new TestOutputCollector<>(); + final TestOutputCollector> ocIntermediate = new TestOutputCollector<>(); + final TestOutputCollector> ocFinal = new TestOutputCollector<>(); + 
partialCombineTransform.prepare(context, ocPartial); + intermediateCombineTransform.prepare(context, ocIntermediate); + finalCombineTransform.prepare(context, ocFinal); + + processIntermediateCombineElementsPerWindow( + Arrays.asList( + WindowedValue.of(KV.of("a", 1), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("c", 1), ts2, slidingWindows.assignWindows(ts2), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("b", 1), ts3, slidingWindows.assignWindows(ts3), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark1, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window1), ocFinal.outputs.get(0).getWindows()); + Assert.assertEquals(2, ocFinal.outputs.size()); + checkOutput(KV.of("a", 1), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("c", 1), ocFinal.outputs.get(1).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + + processIntermediateCombineElementsPerWindow( + Arrays.asList( + WindowedValue.of(KV.of("a", 1), ts4, slidingWindows.assignWindows(ts4), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("c", 1), ts5, slidingWindows.assignWindows(ts5), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark2, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window2), ocFinal.outputs.get(0).getWindows()); + Assert.assertEquals(3, ocFinal.outputs.size()); + checkOutput(KV.of("a", 2), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("b", 1), ocFinal.outputs.get(1).getValue()); + checkOutput(KV.of("c", 1), ocFinal.outputs.get(2).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + + processIntermediateCombineElementsPerWindow( + 
Arrays.asList( + WindowedValue.of(KV.of("b", 1), ts6, slidingWindows.assignWindows(ts6), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("b", 1), ts7, slidingWindows.assignWindows(ts7), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("a", 1), ts8, slidingWindows.assignWindows(ts8), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark3, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window3), ocFinal.outputs.get(0).getWindows()); + checkOutput(KV.of("a", 1), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("b", 2), ocFinal.outputs.get(1).getValue()); + checkOutput(KV.of("c", 1), ocFinal.outputs.get(2).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + + processIntermediateCombineElementsPerWindow( + Arrays.asList(WindowedValue.of(KV.of("c", 3), ts9, slidingWindows.assignWindows(ts9), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark4, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window4), ocFinal.outputs.get(0).getWindows()); + checkOutput(KV.of("a", 1), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("b", 2), ocFinal.outputs.get(1).getValue()); + checkOutput(KV.of("c", 4), ocFinal.outputs.get(2).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + } + // Test with late data @Test @SuppressWarnings("unchecked") - public void test_combine_lateData() { + public void testCombineLateData() { final TupleTag outputTag = new TupleTag<>("main-output"); final Duration lateness = Duration.standardSeconds(2); final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(10)) @@ -271,7 +440,7 @@ public void test_combine_lateData() { final 
Watermark watermark1 = new Watermark(6500); final Watermark watermark2 = new Watermark(8000); - AppliedCombineFn applied_combine_fn = + AppliedCombineFn appliedCombineFn = AppliedCombineFn.withInputCoder( combine_fn, CoderRegistry.createDefault(), @@ -280,14 +449,14 @@ public void test_combine_lateData() { WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES).withAllowedLateness(lateness) ); - final GBKTransform combine_transform = - new GBKTransform( + final CombineTransform combineTransform = + new CombineTransform( KvCoder.of(STRING_CODER, INTEGER_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)), outputTag, WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES).withAllowedLateness(lateness), PipelineOptionsFactory.as(NemoPipelineOptions.class), - SystemReduceFn.combining(STRING_CODER, applied_combine_fn), + SystemReduceFn.combining(STRING_CODER, appliedCombineFn), DoFnSchemaInformation.create(), DisplayData.none(), false); @@ -297,52 +466,52 @@ public void test_combine_lateData() { // window3 : [5000, 15000) // window4 : [10000, 20000) List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window1 = sortedWindows.get(0); final IntervalWindow window2 = sortedWindows.get(1); sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts4)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window3 = sortedWindows.get(0); final IntervalWindow window4 = sortedWindows.get(1); // Prepare to test final Transform.Context context = mock(Transform.Context.class); final TestOutputCollector> oc = new TestOutputCollector(); - combine_transform.prepare(context, oc); + combineTransform.prepare(context, oc); - combine_transform.onData(WindowedValue.of( + 
combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts2, slidingWindows.assignWindows(ts2), PaneInfo.NO_FIRING)); // On-time firing of window1. Skipping checking outputs since test1 checks output from non-late data - combine_transform.onWatermark(watermark1); + combineTransform.onWatermark(watermark1); oc.outputs.clear(); // Late data in window 1. Should be accumulated since EOW + allowed lateness > current Watermark - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 5), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING)); // Check outputs - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); - assertEquals(1,oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); checkOutput(KV.of("a", 6), oc.outputs.get(0).getValue()); oc.outputs.clear(); oc.watermarks.clear(); // Late data in window 1. 
Should NOT be accumulated to outputs of window1 since EOW + allowed lateness > current Watermark - combine_transform.onWatermark(watermark2); - combine_transform.onData(WindowedValue.of( + combineTransform.onWatermark(watermark2); + combineTransform.onData(WindowedValue.of( KV.of("a", 10), ts3, slidingWindows.assignWindows(ts3), PaneInfo.NO_FIRING)); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); - assertEquals(1, oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); checkOutput(KV.of("a", 10), oc.outputs.get(0).getValue()); oc.outputs.clear(); oc.watermarks.clear(); @@ -369,14 +538,14 @@ public void test_combine_lateData() { @Test @SuppressWarnings("unchecked") - public void test_gbk() { + public void testGBK() { final TupleTag outputTag = new TupleTag<>("main-output"); final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(2)) .every(Duration.standardSeconds(1)); - final GBKTransform> doFnTransform = - new GBKTransform( + final CombineTransform> doFnTransform = + new CombineTransform( KvCoder.of(STRING_CODER, STRING_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))), outputTag, @@ -403,7 +572,7 @@ public void test_gbk() { List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); // [0---1000) final IntervalWindow window0 = sortedWindows.get(0); @@ -412,7 +581,7 @@ public void test_gbk() { sortedWindows.clear(); sortedWindows 
= new ArrayList<>(slidingWindows.assignWindows(ts4)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); // [1000--3000) final IntervalWindow window2 = sortedWindows.get(1); @@ -436,21 +605,21 @@ public void test_gbk() { // output // 1: ["hello", "world"] // 2: ["hello"] - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // windowed result for key 1 - assertEquals(Arrays.asList(window0), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window0), oc.outputs.get(0).getWindows()); checkOutput2(KV.of("1", Arrays.asList("hello", "world")), oc.outputs.get(0).getValue()); // windowed result for key 2 - assertEquals(Arrays.asList(window0), oc.outputs.get(1).getWindows()); - checkOutput2(KV.of("2", Arrays.asList("hello")), oc.outputs.get(1).getValue()); + Assert.assertEquals(Collections.singletonList(window0), oc.outputs.get(1).getWindows()); + checkOutput2(KV.of("2", Collections.singletonList("hello")), oc.outputs.get(1).getValue()); - assertEquals(2, oc.outputs.size()); - assertEquals(2, oc.watermarks.size()); + Assert.assertEquals(2, oc.outputs.size()); + Assert.assertEquals(2, oc.watermarks.size()); // check output watermark - assertEquals(1000, + Assert.assertEquals(1000, oc.watermarks.get(0).getTimestamp()); oc.outputs.clear(); @@ -462,8 +631,8 @@ public void test_gbk() { doFnTransform.onWatermark(watermark2); - assertEquals(0, oc.outputs.size()); // do not emit anything - assertEquals(0, oc.watermarks.size()); + Assert.assertEquals(0, oc.outputs.size()); // do not emit anything + Assert.assertEquals(0, oc.watermarks.size()); doFnTransform.onData(WindowedValue.of( KV.of("3", "a"), ts5, slidingWindows.assignWindows(ts5), PaneInfo.NO_FIRING)); @@ -481,23 +650,23 @@ public void test_gbk() { // 1: ["hello", "world", "a"] // 2: ["hello"] // 3: ["a", "a", "b"] - 
Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // windowed result for key 1 - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); checkOutput2(KV.of("1", Arrays.asList("hello", "world", "a")), oc.outputs.get(0).getValue()); // windowed result for key 2 - assertEquals(Arrays.asList(window1), oc.outputs.get(1).getWindows()); - checkOutput2(KV.of("2", Arrays.asList("hello")), oc.outputs.get(1).getValue()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(1).getWindows()); + checkOutput2(KV.of("2", Collections.singletonList("hello")), oc.outputs.get(1).getValue()); // windowed result for key 3 - assertEquals(Arrays.asList(window1), oc.outputs.get(2).getWindows()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(2).getWindows()); checkOutput2(KV.of("3", Arrays.asList("a", "a", "b")), oc.outputs.get(2).getValue()); // check output watermark - assertEquals(2000, + Assert.assertEquals(2000, oc.watermarks.get(0).getTimestamp()); oc.outputs.clear(); @@ -516,20 +685,20 @@ public void test_gbk() { // output // 1: ["a", "a"] // 3: ["a", "a", "b", "b"] - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); - assertEquals(2, oc.outputs.size()); + Assert.assertEquals(2, oc.outputs.size()); // windowed result for key 1 - assertEquals(Arrays.asList(window2), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window2), oc.outputs.get(0).getWindows()); checkOutput2(KV.of("1", Arrays.asList("a", "a")), oc.outputs.get(0).getValue()); // windowed result for key 3 - assertEquals(Arrays.asList(window2), oc.outputs.get(1).getWindows()); + 
Assert.assertEquals(Collections.singletonList(window2), oc.outputs.get(1).getWindows()); checkOutput2(KV.of("3", Arrays.asList("a", "a", "b", "b")), oc.outputs.get(1).getValue()); // check output watermark - assertEquals(3000, + Assert.assertEquals(3000, oc.watermarks.get(0).getTimestamp()); doFnTransform.close(); @@ -539,7 +708,7 @@ public void test_gbk() { * Test complex triggers that emit early and late firing. */ @Test - public void test_gbk_eventTimeTrigger() { + public void testGBKEventTimeTrigger() { final Duration lateness = Duration.standardSeconds(1); final AfterWatermark.AfterWatermarkEarlyAndLate trigger = AfterWatermark.pastEndOfWindow() // early firing @@ -561,8 +730,8 @@ public void test_gbk_eventTimeTrigger() { final TupleTag outputTag = new TupleTag<>("main-output"); - final GBKTransform> doFnTransform = - new GBKTransform( + final CombineTransform> doFnTransform = + new CombineTransform( KvCoder.of(STRING_CODER, STRING_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))), outputTag, @@ -591,8 +760,8 @@ public void test_gbk_eventTimeTrigger() { // early firing is not related to the watermark progress doFnTransform.onWatermark(new Watermark(2)); - assertEquals(1, oc.outputs.size()); - assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); oc.outputs.clear(); doFnTransform.onData(WindowedValue.of( @@ -607,8 +776,8 @@ public void test_gbk_eventTimeTrigger() { // GBKTransform emits data when receiving watermark // TODO #250: element-wise processing doFnTransform.onWatermark(new Watermark(5)); - assertEquals(1, oc.outputs.size()); - assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); // ACCUMULATION MODE checkOutput2(KV.of("1", Arrays.asList("hello", "world")), 
oc.outputs.get(0).getValue()); oc.outputs.clear(); @@ -617,8 +786,8 @@ public void test_gbk_eventTimeTrigger() { doFnTransform.onData(WindowedValue.of( KV.of("1", "!!"), new Instant(3), window.assignWindow(new Instant(3)), PaneInfo.NO_FIRING)); doFnTransform.onWatermark(new Watermark(5001)); - assertEquals(1, oc.outputs.size()); - assertEquals(ON_TIME, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(ON_TIME, oc.outputs.get(0).getPane().getTiming()); // ACCUMULATION MODE checkOutput2(KV.of("1", Arrays.asList("hello", "world", "!!")), oc.outputs.get(0).getValue()); oc.outputs.clear(); @@ -634,8 +803,8 @@ public void test_gbk_eventTimeTrigger() { KV.of("1", "bye!"), new Instant(1000), window.assignWindow(new Instant(1000)), PaneInfo.NO_FIRING)); doFnTransform.onWatermark(new Watermark(6000)); - assertEquals(1, oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); // The data should be accumulated to the previous window because it allows 1 second lateness checkOutput2(KV.of("1", Arrays.asList("hello", "world", "!!", "bye!")), oc.outputs.get(0).getValue()); oc.outputs.clear(); @@ -651,8 +820,8 @@ public void test_gbk_eventTimeTrigger() { KV.of("1", "hello again!"), new Instant(4800), window.assignWindow(new Instant(4800)), PaneInfo.NO_FIRING)); doFnTransform.onWatermark(new Watermark(6300)); - assertEquals(1, oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); checkOutput2(KV.of("1", Arrays.asList("hello again!")), oc.outputs.get(0).getValue()); oc.outputs.clear(); doFnTransform.close(); From fb1ca4f7c2cb157e7338162d2f6b68da9e170da3 Mon Sep 17 00:00:00 2001 From: Kangji Date: Sun, 15 Aug 2021 14:30:10 +0900 Subject: [PATCH 
02/10] [add] implemented intermediate combine --- .../beam/PipelineTranslationContext.java | 2 +- .../frontend/beam/PipelineTranslator.java | 52 ++++--- ...BKTransform.java => CombineTransform.java} | 45 ++++-- .../transform/CombineTransformFactory.java | 136 ++++++++++++++++++ .../beam/transform/IntermediateCombineFn.java | 74 ++++++++++ 5 files changed, 274 insertions(+), 35 deletions(-) rename compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/{GBKTransform.java => CombineTransform.java} (86%) create mode 100644 compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java create mode 100644 compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java index 28aafb2c90..3ba6f0e7fe 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java @@ -277,7 +277,7 @@ private CommunicationPatternProperty.Value getCommPattern(final IRVertex src, fi } // If GBKTransform represents a partial CombinePerKey transformation, we do NOT need to shuffle its input, // since its output will be shuffled before going through a final CombinePerKey transformation. 
- if ((dstTransform instanceof GBKTransform && !((GBKTransform) dstTransform).getIsPartialCombining()) + if ((dstTransform instanceof CombineTransform && !((CombineTransform) dstTransform).getIsPartialCombining()) || dstTransform instanceof GroupByKeyTransform) { return CommunicationPatternProperty.Value.SHUFFLE; } diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java index cd9d7ad223..3794c3687f 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java @@ -358,10 +358,12 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato } final CombineFnBase.GlobalCombineFn combineFn = perKey.getFn(); + final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); + final PCollection mainInput = (PCollection) Iterables.getOnlyElement( - TransformInputs.nonAdditionalInputs(beamNode.toAppliedPTransform(ctx.getPipeline()))); + TransformInputs.nonAdditionalInputs(pTransform)); final PCollection inputs = (PCollection) Iterables.getOnlyElement( - TransformInputs.nonAdditionalInputs(beamNode.toAppliedPTransform(ctx.getPipeline()))); + TransformInputs.nonAdditionalInputs(pTransform)); final KvCoder inputCoder = (KvCoder) inputs.getCoder(); final Coder accumulatorCoder; @@ -386,48 +388,52 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato finalCombine = new OperatorVertex(new CombineFnFinalTransform<>(combineFn)); } else { // Stream data processing, using GBKTransform - final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); final CombineFnBase.GlobalCombineFn partialCombineFn = new PartialCombineFn( (Combine.CombineFn) combineFn, accumulatorCoder); + 
final CombineFnBase.GlobalCombineFn intermediateCombineFn = new IntermediateCombineFn( + (Combine.CombineFn) combineFn, accumulatorCoder); final CombineFnBase.GlobalCombineFn finalCombineFn = new FinalCombineFn( (Combine.CombineFn) combineFn, accumulatorCoder); + final SystemReduceFn partialSystemReduceFn = SystemReduceFn.combining( inputCoder.getKeyCoder(), AppliedCombineFn.withInputCoder(partialCombineFn, - ctx.getPipeline().getCoderRegistry(), inputCoder, - null, - mainInput.getWindowingStrategy())); + ctx.getPipeline().getCoderRegistry(), + inputCoder, + null, mainInput.getWindowingStrategy())); + final SystemReduceFn intermediateSystemReduceFn = + SystemReduceFn.combining( + inputCoder.getKeyCoder(), + AppliedCombineFn.withInputCoder(intermediateCombineFn, + ctx.getPipeline().getCoderRegistry(), + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), + null, mainInput.getWindowingStrategy())); final SystemReduceFn finalSystemReduceFn = SystemReduceFn.combining( inputCoder.getKeyCoder(), AppliedCombineFn.withInputCoder(finalCombineFn, ctx.getPipeline().getCoderRegistry(), - KvCoder.of(inputCoder.getKeyCoder(), - accumulatorCoder), + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), null, mainInput.getWindowingStrategy())); final TupleTag partialMainOutputTag = new TupleTag<>(); - final GBKTransform partialCombineStreamTransform = - new GBKTransform(inputCoder, - Collections.singletonMap(partialMainOutputTag, KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)), - partialMainOutputTag, - mainInput.getWindowingStrategy(), - ctx.getPipelineOptions(), - partialSystemReduceFn, - DoFnSchemaInformation.create(), - DisplayData.from(beamNode.getTransform()), - true); - final GBKTransform finalCombineStreamTransform = - new GBKTransform(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), + final CombineTransformFactory combineTransformFactory = + new CombineTransformFactory(inputCoder, + partialMainOutputTag, + KvCoder.of(inputCoder.getKeyCoder(), 
accumulatorCoder), getOutputCoders(pTransform), Iterables.getOnlyElement(beamNode.getOutputs().keySet()), mainInput.getWindowingStrategy(), ctx.getPipelineOptions(), + partialSystemReduceFn, + intermediateSystemReduceFn, finalSystemReduceFn, DoFnSchemaInformation.create(), - DisplayData.from(beamNode.getTransform()), - false); + DisplayData.from(beamNode.getTransform())); + + final CombineTransform partialCombineStreamTransform = combineTransformFactory.getPartialCombineTransform(); + final CombineTransform finalCombineStreamTransform = combineTransformFactory.getFinalCombineTransform(); partialCombine = new OperatorVertex(partialCombineStreamTransform); finalCombine = new OperatorVertex(finalCombineStreamTransform); @@ -564,7 +570,7 @@ private static Transform createGBKTransform( return new GroupByKeyTransform(); } else { // GroupByKey Transform when using a non-global windowing strategy. - return new GBKTransform<>( + return new CombineTransform<>( (KvCoder) mainInput.getCoder(), getOutputCoders(pTransform), mainOutputTag, diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransform.java similarity index 86% rename from compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java rename to compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransform.java index 4aa366f456..6ab8037903 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransform.java @@ -45,9 +45,9 @@ * @param input type * @param output type */ -public final class GBKTransform +public final class CombineTransform extends AbstractDoFnTransform, KeyedWorkItem, KV> { - private static final Logger 
LOG = LoggerFactory.getLogger(GBKTransform.class.getName()); + private static final Logger LOG = LoggerFactory.getLogger(CombineTransform.class.getName()); private final SystemReduceFn reduceFn; private transient InMemoryTimerInternalsFactory inMemoryTimerInternalsFactory; private transient InMemoryStateInternalsFactory inMemoryStateInternalsFactory; @@ -57,16 +57,31 @@ public final class GBKTransform private boolean dataReceived = false; private transient OutputCollector originOc; private final boolean isPartialCombining; + private final CombineTransform intermediateCombine; - public GBKTransform(final Coder> inputCoder, - final Map, Coder> outputCoders, - final TupleTag> mainOutputTag, - final WindowingStrategy windowingStrategy, - final PipelineOptions options, - final SystemReduceFn reduceFn, - final DoFnSchemaInformation doFnSchemaInformation, - final DisplayData displayData, - final boolean isPartialCombining) { + public CombineTransform(final Coder> inputCoder, + final Map, Coder> outputCoders, + final TupleTag> mainOutputTag, + final WindowingStrategy windowingStrategy, + final PipelineOptions options, + final SystemReduceFn reduceFn, + final DoFnSchemaInformation doFnSchemaInformation, + final DisplayData displayData, + final boolean isPartialCombining) { + this(inputCoder, outputCoders, mainOutputTag, windowingStrategy, options, reduceFn, + doFnSchemaInformation, displayData, isPartialCombining, null); + } + + public CombineTransform(final Coder> inputCoder, + final Map, Coder> outputCoders, + final TupleTag> mainOutputTag, + final WindowingStrategy windowingStrategy, + final PipelineOptions options, + final SystemReduceFn reduceFn, + final DoFnSchemaInformation doFnSchemaInformation, + final DisplayData displayData, + final boolean isPartialCombining, + final CombineTransform intermediateCombine) { super(null, inputCoder, outputCoders, @@ -80,6 +95,7 @@ public GBKTransform(final Coder> inputCoder, Collections.emptyMap()); /* does not have side inputs */ 
this.reduceFn = reduceFn; this.isPartialCombining = isPartialCombining; + this.intermediateCombine = intermediateCombine; } /** @@ -272,6 +288,13 @@ public boolean getIsPartialCombining() { return isPartialCombining; } + /** + * Get the intermediate combine transform if possible. + * @return the intermediate transform if possible. + */ + public Optional getIntermediateCombine() { + return Optional.ofNullable(intermediateCombine); + } /** Wrapper class for {@link OutputCollector}. */ public class GBKOutputCollector implements OutputCollector>> { diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java new file mode 100644 index 0000000000..86916cbbcf --- /dev/null +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.nemo.compiler.frontend.beam.transform; + +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.Map; + +/** + * Factory for the combine transform that combines the results during the group by key. + * @param the type of the key. + * @param the type of the input values. + * @param the type of the accumulators. + * @param the type of the output. + */ +public class CombineTransformFactory { + private static final Logger LOG = LoggerFactory.getLogger(CombineTransformFactory.class.getName()); + + private final SystemReduceFn combineFn; + private final SystemReduceFn intermediateCombineFn; + private final SystemReduceFn finalReduceFn; + + private final Coder> inputCoder; + private final TupleTag> partialMainOutputTag; + private final Coder> accumulatorCoder; + private final Map, Coder> outputCoders; + + private final TupleTag> mainOutputTag; + private final WindowingStrategy windowingStrategy; + private final PipelineOptions options; + + private final DoFnSchemaInformation doFnSchemaInformation; + private final DisplayData displayData; + + public CombineTransformFactory(final Coder> inputCoder, + final TupleTag> partialMainOutputTag, + final Coder> accumulatorCoder, + final Map, Coder> outputCoders, + final TupleTag> mainOutputTag, + final WindowingStrategy windowingStrategy, + final PipelineOptions options, + final SystemReduceFn combineFn, + final SystemReduceFn intermediateCombineFn, + final SystemReduceFn finalReduceFn, + final DoFnSchemaInformation doFnSchemaInformation, + final 
DisplayData displayData) { + this.combineFn = combineFn; + this.intermediateCombineFn = intermediateCombineFn; + this.finalReduceFn = finalReduceFn; + + this.inputCoder = inputCoder; + this.partialMainOutputTag = partialMainOutputTag; + this.accumulatorCoder = accumulatorCoder; + this.outputCoders = outputCoders; + + this.mainOutputTag = mainOutputTag; + this.windowingStrategy = windowingStrategy; + this.options = options; + + this.doFnSchemaInformation = doFnSchemaInformation; + this.displayData = displayData; + } + + + /** + * Get the partial combine transform of the combine transform. + * @return the partial combine transform for the combine transform. + */ + public CombineTransform getPartialCombineTransform() { + return new CombineTransform<>(inputCoder, + Collections.singletonMap(partialMainOutputTag, accumulatorCoder), + partialMainOutputTag, + windowingStrategy, + options, + combineFn, + doFnSchemaInformation, + displayData, true); + } + + /** + * Get the intermediate combine transform of the combine transform. + * @return the intermediate combine transform for the combine transform. + */ + public CombineTransform getIntermediateCombineTransform() { + return new CombineTransform<>(accumulatorCoder, + Collections.singletonMap(partialMainOutputTag, accumulatorCoder), + partialMainOutputTag, + windowingStrategy, + options, + intermediateCombineFn, + doFnSchemaInformation, + displayData, false); + } + + /** + * Get the final combine transform of the combine transform. + * @return the final combine transform for the combine transform. 
+ */ + public CombineTransform getFinalCombineTransform() { + return new CombineTransform<>(accumulatorCoder, + outputCoders, + mainOutputTag, + windowingStrategy, + options, + finalReduceFn, + doFnSchemaInformation, + displayData, false, this.getIntermediateCombineTransform()); + } +} diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java new file mode 100644 index 0000000000..73a272b0a0 --- /dev/null +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.compiler.frontend.beam.transform; + +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.transforms.Combine; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; + +/** + * Wrapper class for {@link Combine.CombineFn}. 
+ * When adding input, it merges its accumulator and input accumulator into a single accumulator. + * Afterwards, it returns the accumulator so that it can be merged later on by the {@link FinalCombineFn}. + * @param <AccumT> accumulator type. + */ +public final class IntermediateCombineFn extends Combine.CombineFn { + private static final Logger LOG = LoggerFactory.getLogger(IntermediateCombineFn.class.getName()); + private final Combine.CombineFn originFn; + private final Coder accumCoder; + + public IntermediateCombineFn(final Combine.CombineFn originFn, + final Coder accumCoder) { + this.originFn = originFn; + this.accumCoder = accumCoder; + } + + @Override + public Coder getAccumulatorCoder(final CoderRegistry registry, final Coder inputCoder) + throws CannotProvideCoderException { + return accumCoder; + } + + @Override + public AccumT createAccumulator() { + return originFn.createAccumulator(); + } + + @Override + public AccumT addInput(final AccumT mutableAccumulator, final AccumT input) { + final AccumT result = originFn.mergeAccumulators(Arrays.asList(mutableAccumulator, input)); + return result; + } + + @Override + public AccumT mergeAccumulators(final Iterable accumulators) { + return originFn.mergeAccumulators(accumulators); + } + + @Override + public AccumT extractOutput(final AccumT accumulator) { + return accumulator; + } +} From c97de0557fb1fe0eca291770418d15fed78ee9f9 Mon Sep 17 00:00:00 2001 From: Kangji Date: Sun, 15 Aug 2021 15:55:03 +0900 Subject: [PATCH 03/10] [add] new type of data comm channel --- .../edge/executionproperty/CommunicationPatternProperty.java | 3 ++- .../nemo/common/ir/executionproperty/ExecutionPropertyMap.java | 1 + .../nemo/runtime/executor/datatransfer/PipeInputReader.java | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java
b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java index 23909b43bd..accb8ac054 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java +++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java @@ -52,6 +52,7 @@ public static CommunicationPatternProperty of(final Value value) { public enum Value { ONE_TO_ONE, BROADCAST, - SHUFFLE + SHUFFLE, + PARTIAL_SHUFFLE } } diff --git a/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java b/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java index d0ef23549d..19f52432a1 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java +++ b/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java @@ -73,6 +73,7 @@ public static ExecutionPropertyMap of( map.put(EncoderProperty.of(EncoderFactory.DUMMY_ENCODER_FACTORY)); map.put(DecoderProperty.of(DecoderFactory.DUMMY_DECODER_FACTORY)); switch (commPattern) { + case PARTIAL_SHUFFLE: case SHUFFLE: map.put(DataFlowProperty.of(DataFlowProperty.Value.PULL)); map.put(PartitionerProperty.of(PartitionerProperty.Type.HASH)); diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java index cab2ed2f43..194fc258c3 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java @@ -71,7 +71,8 @@ public List> read() { if (comValue.equals(CommunicationPatternProperty.Value.ONE_TO_ONE)) { return Collections.singletonList(pipeManagerWorker.read(dstTaskIndex, runtimeEdge, dstTaskIndex)); 
} else if (comValue.equals(CommunicationPatternProperty.Value.BROADCAST) - || comValue.equals(CommunicationPatternProperty.Value.SHUFFLE)) { + || comValue.equals(CommunicationPatternProperty.Value.SHUFFLE) + || comValue.equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE)) { final int numSrcTasks = InputReader.getSourceParallelism(this); final List> futures = new ArrayList<>(); for (int srcTaskIdx = 0; srcTaskIdx < numSrcTasks; srcTaskIdx++) { From f5b139d6c4e5a82bead52bd3fea53e7358bb4332 Mon Sep 17 00:00:00 2001 From: Kangji Date: Sun, 15 Aug 2021 15:58:16 +0900 Subject: [PATCH 04/10] [add] new vertex property about network hierarchy --- .../ShuffleExecutorSetProperty.java | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java new file mode 100644 index 0000000000..2f75ac52a6 --- /dev/null +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.common.ir.vertex.executionproperty; + +import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty; + +import java.util.ArrayList; +import java.util.HashSet; + +/** + * List of sets of node names to which the scheduling of the tasks of the vertex is limited while shuffling. + */ +public final class ShuffleExecutorSetProperty extends VertexExecutionProperty>> { + + /** + * Default constructor. + * @param value value of the execution property. + */ + private ShuffleExecutorSetProperty(final ArrayList> value) { + super(value); + } + + /** + * Static method for constructing {@link ShuffleExecutorSetProperty}. + * + * @param setsOfExecutors the list of executors to schedule the tasks of the vertex on. + * Leave empty for the property to have no effect. + * @return the new execution property + */ + public static ShuffleExecutorSetProperty of(final HashSet> setsOfExecutors) { + return new ShuffleExecutorSetProperty(new ArrayList<>(setsOfExecutors)); + } +} From 60537badaa6c0721793265d25a140879d7c98173 Mon Sep 17 00:00:00 2001 From: Kangji Date: Sun, 15 Aug 2021 16:00:33 +0900 Subject: [PATCH 05/10] [add] accumulator vertex insertion logic --- .../java/org/apache/nemo/common/ir/IRDAG.java | 34 +++++++++++++++++++ .../apache/nemo/common/ir/IRDAGChecker.java | 31 +++++++++++++++-- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java index c619b563b0..d6a929c6b7 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java +++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java @@ -35,6 +35,7 @@ import org.apache.nemo.common.ir.executionproperty.ResourceSpecification; import org.apache.nemo.common.ir.vertex.IRVertex; import org.apache.nemo.common.ir.vertex.LoopVertex; +import
org.apache.nemo.common.ir.vertex.OperatorVertex; import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty; import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty; @@ -799,6 +800,39 @@ public void insert(final TaskSizeSplitterVertex toInsert) { modifiedDAG = builder.build(); } + public void insert(final OperatorVertex accumulatorVertex, final IREdge targetEdge) { + // Create a completely new DAG with the vertex inserted. + final DAGBuilder builder = new DAGBuilder<>(); + + builder.addVertex(accumulatorVertex); + modifiedDAG.topologicalDo(v -> { + builder.addVertex(v); + + modifiedDAG.getIncomingEdgesOf(v).forEach(e -> { + if (e == targetEdge) { + // Edge to the accumulatorVertex + final IREdge toAV = new IREdge(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE, + e.getSrc(), accumulatorVertex); + e.copyExecutionPropertiesTo(toAV); + toAV.setProperty(CommunicationPatternProperty.of(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE)); + + // Edge from the accumulatorVertex + final IREdge fromAV = new IREdge(CommunicationPatternProperty.Value.SHUFFLE, accumulatorVertex, e.getDst()); + e.copyExecutionPropertiesTo(fromAV); + + // Connect the new edges + builder.connectVertices(toAV); + builder.connectVertices(fromAV); + } else { + // Simply connect vertices as before + builder.connectVertices(e); + } + }); + }); + + modifiedDAG = builder.build(); + } + /** * Reshape unsafely, without guarantees on preserving application semantics. 
* TODO #330: Refactor Unsafe Reshaping Passes diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java index ae8a8b3889..8eb33324e3 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java +++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java @@ -79,6 +79,7 @@ private IRDAGChecker() { addLoopVertexCheckers(); addScheduleGroupCheckers(); addCacheCheckers(); + addIntermediateAccumulatorVertexCheckers(); } /** @@ -284,23 +285,25 @@ void addShuffleEdgeCheckers() { final NeighborChecker shuffleChecker = ((v, inEdges, outEdges) -> { for (final IREdge inEdge : inEdges) { if (CommunicationPatternProperty.Value.SHUFFLE + .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get()) + || CommunicationPatternProperty.Value.PARTIAL_SHUFFLE .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())) { // Shuffle edges must have the following properties if (!inEdge.getPropertyValue(KeyExtractorProperty.class).isPresent() || !inEdge.getPropertyValue(KeyEncoderProperty.class).isPresent() || !inEdge.getPropertyValue(KeyDecoderProperty.class).isPresent()) { - return failure("Shuffle edge does not have a Key-related property: " + inEdge.getId()); + return failure("(Partial)Shuffle edge does not have a Key-related property: " + inEdge.getId()); } } else { // Non-shuffle edges must not have the following properties final Optional> partitioner = inEdge.getPropertyValue(PartitionerProperty.class); if (partitioner.isPresent() && partitioner.get().left().equals(PartitionerProperty.Type.HASH)) { - return failure("Only shuffle can have the hash partitioner", + return failure("Only (partial)shuffle can have the hash partitioner", inEdge, CommunicationPatternProperty.class, PartitionerProperty.class); } if (inEdge.getPropertyValue(PartitionSetProperty.class).isPresent()) { - return failure("Only shuffle can select partition sets", + return 
failure("Only (partial)shuffle can select partition sets", inEdge, CommunicationPatternProperty.class, PartitionSetProperty.class); } } @@ -486,6 +489,28 @@ void addEncodingCompressionCheckers() { singleEdgeCheckerList.add(compressAndDecompress); } + void addIntermediateAccumulatorVertexCheckers() { + final NeighborChecker shuffleExecutorSet = ((v, inEdges, outEdges) -> { + if (v.getPropertyValue(ShuffleExecutorSetProperty.class).isPresent()) { + if (inEdges.size() != 1 || outEdges.size() != 1 || inEdges.stream().anyMatch(e -> + !e.getPropertyValue(CommunicationPatternProperty.class).get() + .equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE))) { + return failure("Only intermediate accumulator vertex can have shuffle executor set property", v); + } else if (v.getPropertyValue(ParallelismProperty.class).get() + < v.getPropertyValue(ShuffleExecutorSetProperty.class).get().size()) { + return failure("Parallelism must be greater or equal to the number of shuffle executor set", v); + } + } else { + if (inEdges.stream().anyMatch(e -> e.getPropertyValue(CommunicationPatternProperty.class).get() + .equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE))) { + return failure("Intermediate accumulator vertex must have shuffle executor set property", v); + } + } + return success(); + }); + neighborCheckerList.add(shuffleExecutorSet); + } + /** * Group outgoing edges by the additional output tag property. * @param outEdges the outedges to group. 
From 2e871b21ab3883627de16b03c307b58f78620a4c Mon Sep 17 00:00:00 2001 From: Kangji Date: Sun, 15 Aug 2021 16:01:09 +0900 Subject: [PATCH 06/10] [add] unit test for insertion --- .../test/java/org/apache/nemo/common/ir/IRDAGTest.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java index 6113c85266..7a0a6741be 100644 --- a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java +++ b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java @@ -327,6 +327,15 @@ public void testSplitterVertex() { mustPass(); } + @Test + public void testAccumulatorVertex() { + final OperatorVertex cv = new OperatorVertex(new EmptyComponents.EmptyTransform("iav")); + cv.setProperty(ShuffleExecutorSetProperty.of(new HashSet<>())); + cv.setProperty(ParallelismProperty.of(5)); + irdag.insert(cv, shuffleEdge); + mustPass(); + } + private MessageAggregatorVertex insertNewTriggerVertex(final IRDAG dag, final IREdge edgeToGetStatisticsOf) { final MessageGeneratorVertex mb = new MessageGeneratorVertex<>((l, r) -> null); final MessageAggregatorVertex ma = new MessageAggregatorVertex<>(() -> new Object(), (l, r) -> null); From eb1feeb403eb24ad9a7df48af6f2e2245696bf48 Mon Sep 17 00:00:00 2001 From: Kangji Date: Mon, 16 Aug 2021 18:56:07 +0900 Subject: [PATCH 07/10] [add] new stage property about allocated executor --- .../TaskIndexToExecutorIDProperty.java | 51 +++++++++++++++++++ .../nemo/runtime/common/plan/Stage.java | 3 ++ .../apache/nemo/runtime/common/plan/Task.java | 8 +++ .../master/scheduler/TaskDispatcher.java | 5 ++ 4 files changed, 67 insertions(+) create mode 100644 common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/TaskIndexToExecutorIDProperty.java diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/TaskIndexToExecutorIDProperty.java 
b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/TaskIndexToExecutorIDProperty.java new file mode 100644 index 0000000000..d341475f1b --- /dev/null +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/TaskIndexToExecutorIDProperty.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.common.ir.vertex.executionproperty; + +import org.apache.nemo.common.Pair; +import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty; + +import java.util.HashMap; +import java.util.List; + +/** + * Keep track of where the tasks are located by its executor ID. + */ +public final class TaskIndexToExecutorIDProperty + extends VertexExecutionProperty>>> { + /** + * Default constructor. + * @param taskIDToExecutorIDsMap value of the execution property. + */ + private TaskIndexToExecutorIDProperty(final HashMap>> taskIDToExecutorIDsMap) { + super(taskIDToExecutorIDsMap); + } + + /** + * Static method for constructing {@link TaskIndexToExecutorIDProperty}. + * + * @param taskIndexToExecutorIDsMap the map indicating the executor IDs where the tasks are located on. 
+ * @return the new execution property + */ + public static TaskIndexToExecutorIDProperty of( + final HashMap>> taskIndexToExecutorIDsMap) { + return new TaskIndexToExecutorIDProperty(taskIndexToExecutorIDsMap); + } +} diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java index a7f472c0da..c2c0a2e182 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Stage.java @@ -31,8 +31,10 @@ import org.apache.nemo.common.ir.vertex.executionproperty.EnableDynamicTaskSizingProperty; import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty; import org.apache.nemo.common.ir.vertex.executionproperty.ScheduleGroupProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.TaskIndexToExecutorIDProperty; import java.io.Serializable; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,6 +69,7 @@ public Stage(final String stageId, this.irDag = irDag; this.serializedIRDag = SerializationUtils.serialize(irDag); this.executionProperties = executionProperties; + this.executionProperties.put(TaskIndexToExecutorIDProperty.of(new HashMap<>())); this.vertexIdToReadables = vertexIdToReadables; } diff --git a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java index 719075b456..5dc280345c 100644 --- a/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java +++ b/runtime/common/src/main/java/org/apache/nemo/runtime/common/plan/Task.java @@ -109,6 +109,14 @@ public List getTaskOutgoingEdges() { return taskOutgoingEdges; } + /** + * + * @return the task index. 
+ */ + public int getTaskIdx() { + return RuntimeIdManager.getIndexFromTaskId(taskId); + } + /** * @return the attempt index. */ diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/TaskDispatcher.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/TaskDispatcher.java index e15d8ea13e..4e19d367bb 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/TaskDispatcher.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/TaskDispatcher.java @@ -19,6 +19,8 @@ package org.apache.nemo.runtime.master.scheduler; import org.apache.commons.lang3.mutable.MutableObject; +import org.apache.nemo.common.Pair; +import org.apache.nemo.common.ir.vertex.executionproperty.TaskIndexToExecutorIDProperty; import org.apache.nemo.runtime.common.plan.Task; import org.apache.nemo.runtime.common.state.TaskState; import org.apache.nemo.runtime.master.PlanStateManager; @@ -150,6 +152,9 @@ private void doScheduleTaskList() { planStateManager.onTaskStateChanged(task.getTaskId(), TaskState.State.EXECUTING); LOG.info("{} scheduled to {}", task.getTaskId(), selectedExecutor.getExecutorId()); + task.getPropertyValue(TaskIndexToExecutorIDProperty.class).get() + .computeIfAbsent(task.getTaskIdx(), i -> new ArrayList<>()) + .add(task.getAttemptIdx(), Pair.of(selectedExecutor.getExecutorId(), selectedExecutor.getNodeName())); // send the task selectedExecutor.onTaskScheduled(task); } else { From 52bd37dd6ea5e6682a5c9cfc057c11b57babebb4 Mon Sep 17 00:00:00 2001 From: Kangji Date: Mon, 16 Aug 2021 19:07:31 +0900 Subject: [PATCH 08/10] [add] intermediate combine scheduling constraint --- ...ediateAccumulatorSchedulingConstraint.java | 51 +++++++++++++++++++ .../SchedulingConstraintRegistry.java | 4 +- 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraint.java diff --git 
a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraint.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraint.java new file mode 100644 index 0000000000..7ecf87694f --- /dev/null +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraint.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.runtime.master.scheduler; + +import org.apache.nemo.common.ir.executionproperty.AssociatedProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.ShuffleExecutorSetProperty; +import org.apache.nemo.runtime.common.plan.Task; +import org.apache.nemo.runtime.master.resource.ExecutorRepresenter; + +import javax.inject.Inject; +import java.util.ArrayList; +import java.util.HashSet; + +/** + * Compare shuffle executor set and the executor. 
+ */ +@AssociatedProperty(ShuffleExecutorSetProperty.class) +public final class IntermediateAccumulatorSchedulingConstraint implements SchedulingConstraint { + + @Inject + private IntermediateAccumulatorSchedulingConstraint() { + } + + @Override + public boolean testSchedulability(final ExecutorRepresenter executor, final Task task) { + if (!task.getPropertyValue(ShuffleExecutorSetProperty.class).isPresent()) { + return true; + } + + final ArrayList> setsOfExecutors = task.getPropertyValue(ShuffleExecutorSetProperty.class).get(); + final int numOfSets = setsOfExecutors.size(); + final int idx = task.getTaskIdx(); + return setsOfExecutors.get(idx % numOfSets).contains(executor.getNodeName()); + } +} diff --git a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java index 277dd5056c..a6c4c9ad7a 100644 --- a/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java +++ b/runtime/master/src/main/java/org/apache/nemo/runtime/master/scheduler/SchedulingConstraintRegistry.java @@ -44,12 +44,14 @@ private SchedulingConstraintRegistry( final FreeSlotSchedulingConstraint freeSlotSchedulingConstraint, final LocalitySchedulingConstraint localitySchedulingConstraint, final AntiAffinitySchedulingConstraint antiAffinitySchedulingConstraint, - final NodeShareSchedulingConstraint nodeShareSchedulingConstraint) { + final NodeShareSchedulingConstraint nodeShareSchedulingConstraint, + final IntermediateAccumulatorSchedulingConstraint intermediateAccumulatorSchedulingConstraint) { registerSchedulingConstraint(containerTypeAwareSchedulingConstraint); registerSchedulingConstraint(freeSlotSchedulingConstraint); registerSchedulingConstraint(localitySchedulingConstraint); registerSchedulingConstraint(antiAffinitySchedulingConstraint); 
registerSchedulingConstraint(nodeShareSchedulingConstraint); + registerSchedulingConstraint(intermediateAccumulatorSchedulingConstraint); } /** From 5a0f2fe801e0f4fe9b9a436af9432df9701503d3 Mon Sep 17 00:00:00 2001 From: Kangji Date: Mon, 16 Aug 2021 19:08:35 +0900 Subject: [PATCH 09/10] [add] scheduling constraint unit test --- ...teAccumulatorSchedulingConstraintTest.java | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 runtime/master/src/test/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraintTest.java diff --git a/runtime/master/src/test/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraintTest.java b/runtime/master/src/test/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraintTest.java new file mode 100644 index 0000000000..a82dbf72a1 --- /dev/null +++ b/runtime/master/src/test/java/org/apache/nemo/runtime/master/scheduler/IntermediateAccumulatorSchedulingConstraintTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.nemo.runtime.master.scheduler; + +import org.apache.nemo.common.ir.vertex.executionproperty.ShuffleExecutorSetProperty; +import org.apache.nemo.runtime.common.plan.Task; +import org.apache.nemo.runtime.master.resource.ExecutorRepresenter; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.exceptions.InjectionException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.*; +import java.util.concurrent.Executor; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test cases for {@link IntermediateAccumulatorSchedulingConstraint}. + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({ExecutorRepresenter.class, Task.class}) +public class IntermediateAccumulatorSchedulingConstraintTest { + + private static ExecutorRepresenter mockExecutorRepresenter(final String nodeName) { + final ExecutorRepresenter executorRepresenter = mock(ExecutorRepresenter.class); + when(executorRepresenter.getNodeName()).thenReturn(nodeName); + return executorRepresenter; + } + + @Test + public void testIntermediateAccumulator() throws InjectionException { + final SchedulingConstraint schedulingConstraint = Tang.Factory.getTang().newInjector() + .getInstance(IntermediateAccumulatorSchedulingConstraint.class); + + final ExecutorRepresenter e0 = mockExecutorRepresenter("mulan-0"); + final ExecutorRepresenter e1 = mockExecutorRepresenter("mulan-0"); + final ExecutorRepresenter e2 = mockExecutorRepresenter("mulan-1"); + final ExecutorRepresenter e3 = mockExecutorRepresenter("mulan-1"); + final ExecutorRepresenter e4 = mockExecutorRepresenter("mulan-2"); + final ExecutorRepresenter e5 = mockExecutorRepresenter("mulan-2"); + final Set executorRepresenters = new HashSet<>(Arrays.asList(e0, e1, 
e2, e3, e4, e5)); + final Set expectedExecutors1 = new HashSet<>(Arrays.asList(e0, e1, e2, e3)); + final Set expectedExecutors2 = new HashSet<>(Arrays.asList(e4, e5)); + + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + final Task task3 = mock(Task.class); + ArrayList> setsOfExecutors = new ArrayList<>(Arrays.asList( + new HashSet<>(Arrays.asList("mulan-0", "mulan-1")), + new HashSet<>(Arrays.asList("mulan-2")) + )); + when(task1.getPropertyValue(ShuffleExecutorSetProperty.class)).thenReturn(Optional.of(setsOfExecutors)); + when(task1.getTaskIdx()).thenReturn(0); + when(task2.getPropertyValue(ShuffleExecutorSetProperty.class)).thenReturn(Optional.of(setsOfExecutors)); + when(task2.getTaskIdx()).thenReturn(3); + when(task3.getPropertyValue(ShuffleExecutorSetProperty.class)).thenReturn(Optional.of(setsOfExecutors)); + when(task3.getTaskIdx()).thenReturn(6); + + final Set candidateExecutors1 = executorRepresenters.stream() + .filter(e -> schedulingConstraint.testSchedulability(e, task1)).collect(Collectors.toSet()); + assertEquals(expectedExecutors1, candidateExecutors1); + + final Set candidateExecutors2 = executorRepresenters.stream() + .filter(e -> schedulingConstraint.testSchedulability(e, task2)).collect(Collectors.toSet()); + assertEquals(expectedExecutors2, candidateExecutors2); + + final Set candidateExecutors3 = executorRepresenters.stream() + .filter(e -> schedulingConstraint.testSchedulability(e, task3)).collect(Collectors.toSet()); + assertEquals(expectedExecutors1, candidateExecutors3); + } +} From be6f3a2e7202767b53e0f0fbcb30af23fe9f7d75 Mon Sep 17 00:00:00 2001 From: Kangji Date: Mon, 16 Aug 2021 19:13:57 +0900 Subject: [PATCH 10/10] [add] data transfer on partial shuffle --- .../common/partitioner/HashPartitioner.java | 4 +++ .../datatransfer/PipeOutputWriter.java | 26 ++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git 
a/common/src/main/java/org/apache/nemo/common/partitioner/HashPartitioner.java b/common/src/main/java/org/apache/nemo/common/partitioner/HashPartitioner.java index 241f0564f7..02b4aaddb6 100644 --- a/common/src/main/java/org/apache/nemo/common/partitioner/HashPartitioner.java +++ b/common/src/main/java/org/apache/nemo/common/partitioner/HashPartitioner.java @@ -45,4 +45,8 @@ public HashPartitioner(final int numOfPartitions, public Integer partition(final Object element) { return Math.abs(keyExtractor.extractKey(element).hashCode() % numOfPartitions); } + + public Integer partition(final Object element, final int numOfSubPartitions) { + return Math.abs(keyExtractor.extractKey(element).hashCode() % numOfSubPartitions); + } } diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java index 544d64d921..419aa99b69 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeOutputWriter.java @@ -18,7 +18,11 @@ */ package org.apache.nemo.runtime.executor.datatransfer; +import org.apache.nemo.common.Pair; import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.ShuffleExecutorSetProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.TaskIndexToExecutorIDProperty; +import org.apache.nemo.common.partitioner.HashPartitioner; import org.apache.nemo.common.partitioner.Partitioner; import org.apache.nemo.common.punctuation.Watermark; import org.apache.nemo.runtime.common.RuntimeIdManager; @@ -32,9 +36,9 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Optional; +import java.util.*; +import 
java.util.stream.Collectors; +import java.util.stream.IntStream; /** * Represents the output data transfer from a task. @@ -147,6 +151,22 @@ private List getPipeToWrite(final Object element) { return Collections.singletonList(pipes.get(0)); case BROADCAST: return pipes; + case PARTIAL_SHUFFLE: + final List> listOfSrcNodeNames = ((StageEdge) runtimeEdge).getSrc() + .getPropertyValue(TaskIndexToExecutorIDProperty.class).get().get(srcTaskIndex); + final String nodeName = listOfSrcNodeNames.get(listOfSrcNodeNames.size() - 1).right(); + + final ArrayList> setsOfExecutors = ((StageEdge) runtimeEdge).getDst() + .getPropertyValue(ShuffleExecutorSetProperty.class).get(); + final int numOfSets = setsOfExecutors.size(); + final int dstParallelism = ((StageEdge) runtimeEdge).getDst().getParallelism(); + final List listOfDstTaskIdx = IntStream.range(0, dstParallelism) + .filter(i -> setsOfExecutors.get(i % numOfSets).contains(nodeName)) + .boxed().collect(Collectors.toList()); + + final int numOfPartitions = listOfDstTaskIdx.size(); + final int pipeIndex = listOfDstTaskIdx.get(((HashPartitioner) partitioner).partition(element, numOfPartitions)); + return Collections.singletonList(pipes.get(pipeIndex)); default: return Collections.singletonList(pipes.get((int) partitioner.partition(element))); }