Skip to content

Commit 5411035

Browse files
committed
Add skip decorator; A few clean ups
1 parent 898b0c0 commit 5411035

7 files changed

+102
-8
lines changed

metaflow/plugins/__init__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def get_plugin_cli():
116116
from .frameworks.pytorch import PytorchParallelDecorator
117117
from .aip.aip_decorator import AIPInternalDecorator
118118
from .aip.accelerator_decorator import AcceleratorDecorator
119-
from .aip.interruptible_decorator import interruptibleDecorator
119+
from .aip.interruptible_decorator import InterruptibleDecorator
120+
from .aip.skip_decorator import SkipDecorator
120121

121122

122123
STEP_DECORATORS = [
@@ -134,8 +135,9 @@ def get_plugin_cli():
134135
PytorchParallelDecorator,
135136
InternalTestUnboundedForeachDecorator,
136137
AcceleratorDecorator,
137-
interruptibleDecorator,
138+
InterruptibleDecorator,
138139
AIPInternalDecorator,
140+
SkipDecorator,
139141
]
140142
_merge_lists(STEP_DECORATORS, _ext_plugins["STEP_DECORATORS"], "name")
141143

@@ -159,6 +161,7 @@ def get_plugin_cli():
159161
from .aws.step_functions.schedule_decorator import ScheduleDecorator
160162
from .project_decorator import ProjectDecorator
161163
from .aip.s3_sensor_decorator import S3SensorDecorator
164+
162165
from .aip.exit_handler_decorator import ExitHandlerDecorator
163166

164167
FLOW_DECORATORS = [

metaflow/plugins/aip/aip.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from metaflow.plugins.aip.aip_decorator import AIPException
5959
from .accelerator_decorator import AcceleratorDecorator
6060
from .argo_client import ArgoClient
61-
from .interruptible_decorator import interruptibleDecorator
61+
from .interruptible_decorator import InterruptibleDecorator
6262
from .aip_foreach_splits import graph_to_task_ids
6363
from ..aws.batch.batch_decorator import BatchDecorator
6464
from ..aws.step_functions.schedule_decorator import ScheduleDecorator
@@ -106,7 +106,7 @@ def __init__(
106106
resource_requirements: Dict[str, str],
107107
aip_decorator: AIPInternalDecorator,
108108
accelerator_decorator: AcceleratorDecorator,
109-
interruptible_decorator: interruptibleDecorator,
109+
interruptible_decorator: InterruptibleDecorator,
110110
environment_decorator: EnvironmentDecorator,
111111
total_retries: int,
112112
minutes_between_retries: str,
@@ -741,7 +741,7 @@ def build_aip_component(node: DAGNode, task_id: str) -> AIPComponent:
741741
(
742742
deco
743743
for deco in node.decorators
744-
if isinstance(deco, interruptibleDecorator)
744+
if isinstance(deco, InterruptibleDecorator)
745745
),
746746
None, # default
747747
),

metaflow/plugins/aip/interruptible_decorator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def _get_ec2_metadata(path: str) -> Optional[str]:
1111
return response.text
1212

1313

14-
class interruptibleDecorator(StepDecorator):
14+
class InterruptibleDecorator(StepDecorator):
1515
"""
1616
For AIP orchestrator plugin only.
1717

metaflow/plugins/aip/s3_sensor_decorator.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from types import FunctionType
2-
from typing import Tuple
32
from urllib.parse import urlparse
43

54
from metaflow.decorators import FlowDecorator
+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Skip decorator is a workaround solution to implement conditional branching in metaflow.
2+
# When condition variable is_skipping is evaluated to True,
3+
# it will skip current step and execute the supplied next step.
4+
5+
from functools import wraps
6+
from metaflow.decorators import StepDecorator
7+
8+
9+
class SkipDecorator(StepDecorator):
10+
"""
11+
The @skip decorator is a workaround for conditional branching. The @skip decorator checks an artifact
12+
and if it is false, skips the evaluation of the step function and jumps to the supplied next step.
13+
14+
**The `start` and `end` steps are always expected and should not be skipped.**
15+
16+
Usage:
17+
class SkipFlow(FlowSpec):
18+
19+
condition = Parameter("condition", default=False)
20+
21+
@step
22+
def start(self):
23+
print("Should skip:", self.condition)
24+
self.next(self.middle)
25+
26+
@skip(check='condition', next='end')
27+
@step
28+
def middle(self):
29+
print("Running the middle step - not skipping")
30+
self.next(self.end)
31+
32+
@step
33+
def end(self):
34+
pass
35+
"""
36+
37+
name = "skip"
38+
39+
def __init__(self, check="", next=""):
40+
super().__init__()
41+
self.check = check
42+
self.next = next
43+
44+
def __call__(self, f):
45+
@wraps(f)
46+
def func(step):
47+
if getattr(step, self.check):
48+
step.next(getattr(step, self.next))
49+
else:
50+
return f(step)
51+
52+
return func

metaflow/plugins/aip/tests/flows/resources_flow.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import pprint
33
import subprocess
4-
import time
54
from typing import Dict, List
65
from multiprocessing.shared_memory import SharedMemory
76

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from metaflow import Parameter, FlowSpec, step, skip
2+
3+
4+
class SkipFlow(FlowSpec):
5+
6+
condition_true = Parameter("condition-true", default=True)
7+
8+
@step
9+
def start(self):
10+
print("Should skip:", self.condition)
11+
self.desired_step_executed = False
12+
self.condition_false = False
13+
self.next(self.skipped_step)
14+
15+
@skip(check="condition_true", next="desired_step")
16+
@step
17+
def skipped_step(self):
18+
raise Exception(
19+
"Unexpectedly ran the skipped_step step. This step should have been skipped."
20+
)
21+
self.next(self.unreachable)
22+
23+
def unreachable(self):
24+
raise Exception(
25+
"Unexpectedly ran the unreachable step. This step should have been skipped."
26+
)
27+
self.next(self.end)
28+
29+
@skip(check="condition_false", next="end")
30+
@step
31+
def desired_step(self):
32+
self.desired_step_executed = True
33+
self.next(self.end)
34+
35+
@step
36+
def end(self):
37+
assert self.desired_step_executed, "Desired step was not executed"
38+
39+
40+
if __name__ == "__main__":
41+
SkipFlow()

0 commit comments

Comments
 (0)