Skip to content
127 changes: 126 additions & 1 deletion metaflow/flowspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,35 @@ def _is_primitive_type(item):
value = item if _is_primitive_type(item) else reprlib.Repr().repr(item)
return basestring(value)[:MAXIMUM_FOREACH_VALUE_CHARS]

def _validate_switch_cases(self, switch_cases, step):
    """
    Resolve the keys of a switch dictionary and map them to step names.

    Keys of the form ``"config:<attr>.<key>"`` are replaced by
    ``str(getattr(getattr(self, <attr>), <key>))``. Every other key is
    used as is.

    Parameters
    ----------
    switch_cases : Dict[Any, Callable]
        Mapping of case keys to bound step methods.
    step : str
        Name of the step issuing the transition; only used in error
        messages.

    Returns
    -------
    Dict[Any, str]
        Mapping of resolved case keys to the names of the target step
        functions.

    Raises
    ------
    InvalidNextException
        If a config reference cannot be resolved, or if two cases resolve
        to the same key.
    MetaflowInternalError
        If a ``config:`` key does not contain a ``<attr>.<key>`` path.
    """
    config_prefix = "config:"
    resolved_cases = {}
    for case_key, step_method in switch_cases.items():
        if isinstance(case_key, str) and case_key.startswith(config_prefix):
            full_path = case_key[len(config_prefix) :]
            parts = full_path.split(".", 1)
            if len(parts) != 2:
                raise MetaflowInternalError(
                    "Invalid config path format in switch case."
                )
            config_var_name, config_key_name = parts
            try:
                config_obj = getattr(self, config_var_name)
                resolved_key = str(getattr(config_obj, config_key_name))
            except AttributeError:
                msg = (
                    "Step *{step}* references unknown config '{path}' "
                    "in switch case.".format(step=step, path=full_path)
                )
                raise InvalidNextException(msg)
        else:
            resolved_key = case_key

        # Distinct keys can collapse onto the same value once config
        # references are resolved. Silently keeping the last one would make
        # a declared branch unreachable, so reject the transition instead.
        if resolved_key in resolved_cases:
            msg = (
                "Step *{step}* has an invalid self.next() transition. "
                "Switch case '{case}' is specified more than once after "
                "resolving config values.".format(step=step, case=resolved_key)
            )
            raise InvalidNextException(msg)

        resolved_cases[resolved_key] = step_method.__func__.__name__

    return resolved_cases

def next(self, *dsts: Callable[..., None], **kwargs) -> None:
"""
Indicates the next step to execute after this step has completed.
Expand All @@ -780,6 +809,15 @@ def next(self, *dsts: Callable[..., None], **kwargs) -> None:
evaluates to an iterator. A task will be launched for each value in the iterator and
each task will execute the code specified by the step `foreach_step`.

- Switch statement:
```
self.next({"case1": self.step_a, "case2": self.step_b}, condition='condition_variable')
```
In this situation, `step_a` and `step_b` are methods in the current class decorated
with the `@step` decorator and `condition_variable` is a variable name in the current
class. The value of the condition variable determines which step to execute. If the
value doesn't match any of the dictionary keys, a RuntimeError is raised.

Parameters
----------
dsts : Callable[..., None]
Expand All @@ -794,6 +832,7 @@ def next(self, *dsts: Callable[..., None], **kwargs) -> None:
step = self._current_step

foreach = kwargs.pop("foreach", None)
condition = kwargs.pop("condition", None)
num_parallel = kwargs.pop("num_parallel", None)
if kwargs:
kw = next(iter(kwargs))
Expand All @@ -811,6 +850,92 @@ def next(self, *dsts: Callable[..., None], **kwargs) -> None:
)
raise InvalidNextException(msg)

# check: switch case using condition
if condition is not None:
if len(dsts) != 1 or not isinstance(dsts[0], dict):
msg = (
"Step *{step}* has an invalid self.next() transition. "
"When using 'condition', provide exactly one dictionary argument "
"mapping condition values to step methods.".format(step=step)
)
raise InvalidNextException(msg)

if not isinstance(condition, basestring):
msg = (
"Step *{step}* has an invalid self.next() transition. "
"The argument to 'condition' must be a string.".format(step=step)
)
raise InvalidNextException(msg)

if foreach is not None or num_parallel is not None:
msg = (
"Step *{step}* has an invalid self.next() transition. "
"Switch statements cannot be combined with foreach or num_parallel.".format(
step=step
)
)
raise InvalidNextException(msg)

switch_cases = dsts[0]
if not switch_cases:
msg = (
"Step *{step}* has an invalid self.next() transition. "
"Switch statement cannot have an empty dictionary.".format(
step=step
)
)
raise InvalidNextException(msg)

# Validate that condition variable exists
try:
getattr(self, condition)
except AttributeError:
msg = (
"Condition variable *self.{var}* in step *{step}* "
"does not exist. Make sure you set self.{var} in this step.".format(
step=step, var=condition
)
)
raise InvalidNextException(msg)

# Validate that all switch case values are step methods
funcs = []
for case_value, step_method in switch_cases.items():
try:
func_name = step_method.__func__.__name__
except:
msg = (
"In step *{step}* switch case '{case}', the value is "
"not a function. Make sure all values in the switch dictionary "
"are methods of the Flow class.".format(
step=step, case=case_value
)
)
raise InvalidNextException(msg)
if not hasattr(self, func_name):
msg = (
"Step *{step}* specifies a switch case to an "
"unknown step, *{name}*.".format(step=step, name=func_name)
)
raise InvalidNextException(msg)
funcs.append(func_name)

resolved_switch_cases = self._validate_switch_cases(switch_cases, step)

# Store switch information for runtime evaluation
self._switch_cases = resolved_switch_cases
self._switch_condition = condition
self._transition = (funcs, None)
return

# Convert switch cases back to individual destinations for regular processing
if len(dsts) == 1 and isinstance(dsts[0], dict):
msg = (
"Step *{step}* has an invalid self.next() transition. "
"Dictionary argument requires 'condition' parameter.".format(step=step)
)
raise InvalidNextException(msg)

# check: all destinations are methods of this object
funcs = []
for i, dst in enumerate(dsts):
Expand Down Expand Up @@ -901,7 +1026,7 @@ def next(self, *dsts: Callable[..., None], **kwargs) -> None:
self._foreach_var = foreach

# check: non-keyword transitions are valid
if foreach is None:
if foreach is None and condition is None:
if len(dsts) < 1:
msg = (
"Step *{step}* has an invalid self.next() transition. "
Expand Down
128 changes: 120 additions & 8 deletions metaflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(self, func_ast, decos, doc, source_file, lineno):
self.has_tail_next = False
self.invalid_tail_next = False
self.num_args = 0
self.switch_cases = {}
self.condition = None
self.foreach_param = None
self.num_parallel = 0
self.parallel_foreach = False
Expand All @@ -77,6 +79,55 @@ def __init__(self, func_ast, decos, doc, source_file, lineno):
def _expr_str(self, expr):
    """Render an attribute node of the form ``<name>.<attr>`` as dotted text."""
    owner = expr.value.id
    member = expr.attr
    return f"{owner!s}.{member!s}"

def _parse_switch_dict(self, dict_node):
    """
    Statically extract the case -> step mapping from a switch dictionary.

    Only two kinds of keys are understood: string literals, and attribute
    chains of the exact shape ``self.<config>.<key>``, which are encoded as
    ``"config:<config>.<key>"``. Every value must be of the form
    ``self.<step>``.

    Parameters
    ----------
    dict_node : ast.AST
        AST node expected to be an `ast.Dict`.

    Returns
    -------
    Optional[Dict[str, str]]
        Mapping of case keys to step names, or None if `dict_node` is not a
        dictionary literal, is empty, or contains any key or value that
        cannot be analyzed statically.
    """
    if not isinstance(dict_node, ast.Dict):
        return None

    def _case_key(key):
        # String literal. `ast.Constant` is checked first and the legacy
        # pre-3.8 node is matched by class name, so the deprecated `ast.Str`
        # attribute (removed in Python 3.14) is never looked up.
        if isinstance(key, ast.Constant):
            return key.value if isinstance(key.value, str) else None
        if type(key).__name__ == "Str":
            return key.s
        # Config reference: exactly `self.<config>.<key>`.
        if (
            isinstance(key, ast.Attribute)
            and isinstance(key.value, ast.Attribute)
            and isinstance(key.value.value, ast.Name)
            and key.value.value.id == "self"
        ):
            return f"config:{key.value.attr}.{key.attr}"
        # Variables, `**` expansions (key is None) and any other expression
        # cannot be analyzed statically.
        return None

    switch_cases = {}
    for key, value in zip(dict_node.keys, dict_node.values):
        case_key = _case_key(key)
        if case_key is None:
            return None

        # The target must be a direct reference to a method: `self.<step>`.
        targets_self = (
            isinstance(value, ast.Attribute)
            and isinstance(value.value, ast.Name)
            and value.value.id == "self"
        )
        if not targets_self:
            return None
        switch_cases[case_key] = value.attr

    return switch_cases or None

def _parse(self, func_ast, lineno):
self.num_args = len(func_ast.args.args)
tail = func_ast.body[-1]
Expand All @@ -98,7 +149,38 @@ def _parse(self, func_ast, lineno):
self.has_tail_next = True
self.invalid_tail_next = True
self.tail_next_lineno = lineno + tail.lineno - 1
self.out_funcs = [e.attr for e in tail.value.args]

# Check if first argument is a dictionary (switch case)
if (
len(tail.value.args) == 1
and isinstance(tail.value.args[0], ast.Dict)
and any(k.arg == "condition" for k in tail.value.keywords)
):

# This is a switch statement
switch_cases = self._parse_switch_dict(tail.value.args[0])
condition_name = None

# Get condition parameter
for keyword in tail.value.keywords:
if keyword.arg == "condition":
if isinstance(keyword.value, ast.Str):
condition_name = keyword.value.s
elif isinstance(keyword.value, ast.Constant) and isinstance(
keyword.value.value, str
):
condition_name = keyword.value.value
break

if switch_cases and condition_name:
self.type = "split-switch"
self.condition = condition_name
self.switch_cases = switch_cases
self.out_funcs = list(switch_cases.values())
self.invalid_tail_next = False
return
else:
self.out_funcs = [e.attr for e in tail.value.args]

keywords = dict(
(k.arg, getattr(k.value, "s", None)) for k in tail.value.keywords
Expand Down Expand Up @@ -145,6 +227,7 @@ def __str__(self):
has_tail_next={0.has_tail_next} (line {0.tail_next_lineno})
invalid_tail_next={0.invalid_tail_next}
foreach_param={0.foreach_param}
condition={0.condition}
parallel_step={0.parallel_step}
parallel_foreach={0.parallel_foreach}
-> {out}""".format(
Expand Down Expand Up @@ -207,6 +290,8 @@ def traverse(node, seen, split_parents):
if node.type in ("split", "foreach"):
node.split_parents = split_parents
split_parents = split_parents + [node.name]
elif node.type == "split-switch":
node.split_parents = split_parents
elif node.type == "join":
# ignore joins without splits
if split_parents:
Expand Down Expand Up @@ -247,15 +332,37 @@ def __str__(self):
def output_dot(self):
def edge_specs():
for node in self.nodes.values():
for edge in node.out_funcs:
yield "%s -> %s;" % (node.name, edge)
if node.type == "split-switch":
# Label edges for switch cases
for case_value, step_name in node.switch_cases.items():
yield (
'{0} -> {1} [label="{2}" color="blue" fontcolor="blue"];'.format(
node.name, step_name, case_value
)
)
else:
for edge in node.out_funcs:
yield "%s -> %s;" % (node.name, edge)

def node_specs():
for node in self.nodes.values():
nodetype = "join" if node.num_args > 1 else node.type
yield '"{0.name}"' '[ label = <<b>{0.name}</b> | <font point-size="10">{type}</font>> ' ' fontname = "Helvetica" ' ' shape = "record" ];'.format(
node, type=nodetype
)
if node.type == "split-switch":
# Hexagon shape for switch nodes
condition_label = (
f"switch: {node.condition}" if node.condition else "switch"
)
yield (
'"{0.name}" '
'[ label = <<b>{0.name}</b><br/><font point-size="9">{condition}</font>> '
' fontname = "Helvetica" '
' shape = "hexagon" '
' style = "filled" fillcolor = "lightgreen" ];'
).format(node, condition=condition_label)
else:
nodetype = "join" if node.num_args > 1 else node.type
yield '"{0.name}"' '[ label = <<b>{0.name}</b> | <font point-size="10">{type}</font>> ' ' fontname = "Helvetica" ' ' shape = "record" ];'.format(
node, type=nodetype
)

return (
"digraph {0.name} {{\n"
Expand All @@ -279,6 +386,8 @@ def node_to_type(node):
if node.parallel_foreach:
return "split-parallel"
return "split-foreach"
elif node.type == "split-switch":
return "split-switch"
return "unknown" # Should never happen

def node_to_dict(name, node):
Expand All @@ -303,6 +412,9 @@ def node_to_dict(name, node):
d["foreach_artifact"] = node.foreach_param
elif d["type"] == "split-parallel":
d["num_parallel"] = node.num_parallel
elif d["type"] == "split-switch":
d["condition"] = node.condition
d["switch_cases"] = node.switch_cases
if node.matching_join:
d["matching_join"] = node.matching_join
return d
Expand All @@ -317,7 +429,7 @@ def populate_block(start_name, end_name):
steps_info[cur_name] = node_dict
resulting_list.append(cur_name)

if cur_node.type not in ("start", "linear", "join"):
if cur_node.type not in ("start", "linear", "join", "split-switch"):
# We need to look at the different branches for this
resulting_list.append(
[
Expand Down
Loading
Loading