Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Initial commit for spin steps #2036

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
516 changes: 488 additions & 28 deletions metaflow/cli.py

Large diffs are not rendered by default.

153 changes: 153 additions & 0 deletions metaflow/datastore/spin_datastore/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
class SpinDatastore(object):
    """
    Base class for the spin datastores.

    Exposes the values held by the `spin_parser_validator` object as read-only
    properties and provides helpers that use the Metaflow client API to locate
    tasks of the run identified by `flow_name` / `run_id`.
    """

    def __init__(self, spin_parser_validator):
        self.spin_parser_validator = spin_parser_validator

    def get_task_for_step(self, step_name):
        """
        Returns an instance of the task corresponding to the step name and run id.

        Only the first task yielded by the client for that step is returned.

        Raises
        ------
        ValueError
            If no `run_id` is available or the step has no tasks.
        """
        if self.run_id is None:
            raise ValueError("No run_id provided")

        from metaflow import Step

        step = Step(
            f"{self.flow_name}/{self.run_id}/{step_name}", _namespace_check=False
        )
        # Use a default so an empty step surfaces as a ValueError instead of a
        # bare StopIteration escaping to the caller.
        task = next(iter(step.tasks()), None)
        if task is None:
            raise ValueError(
                f"No tasks found for step `{step_name}` in run `{self.run_id}`."
            )
        return task

    @property
    def flow_name(self):
        return self.spin_parser_validator.flow_name

    @property
    def parsed_ancestor_tasks(self):
        return self.spin_parser_validator.parsed_ancestor_tasks

    @property
    def foreach_var(self):
        return self.spin_parser_validator.foreach_var

    @property
    def foreach_index(self):
        return self.spin_parser_validator.foreach_index

    @property
    def foreach_value(self):
        return self.spin_parser_validator.foreach_value

    @property
    def foreach_stack(self):
        return self.spin_parser_validator.foreach_stack

    @property
    def artifacts(self):
        return self.spin_parser_validator.artifacts

    @property
    def run_id(self):
        return self.spin_parser_validator.run_id

    @property
    def step_name(self):
        return self.spin_parser_validator.step_name

    @property
    def task(self):
        return self.spin_parser_validator.task

    @property
    def step_type(self):
        return self.spin_parser_validator.step_type

    @property
    def previous_steps(self):
        return self.spin_parser_validator.previous_steps

    @property
    def required_ancestor_tasks(self):
        return self.spin_parser_validator.required_ancestor_tasks

    @property
    def is_foreach_step(self):
        return len(self.spin_parser_validator.foreach_stack) > 0

    @staticmethod
    def _parse_foreach_stack(foreach_stack):
        """Map each step in `foreach_stack.data` to its foreach value and index."""
        return {
            entry.step: {
                "task_val": entry.value,
                "task_index": entry.index,
            }
            for entry in foreach_stack.data
        }

    def _parse_required_ancestor_tasks(self, required_ancestor_tasks):
        """
        Turn the required ancestors into a `{step_name: constraint}` dictionary.

        A constraint holds `task_index` and/or `task_val`. An ancestor given by
        `task_id` is expanded into the constraints of that task's foreach stack.

        Raises
        ------
        ValueError
            If an ancestor uses an unknown task specifier.
        """
        result = {}
        for required_ancestor in required_ancestor_tasks:
            step_name = required_ancestor.step_name
            task_specifier = required_ancestor.task_specifier
            task_val = required_ancestor.value
            if task_specifier == "task_id":
                from metaflow import Task

                task = Task(
                    f"{self.flow_name}/{self.run_id}/{step_name}/{task_val}",
                    _namespace_check=False,
                )
                # Merge instead of returning so that constraints coming from
                # the other required ancestors are not silently dropped.
                result.update(self._parse_foreach_stack(task["_foreach_stack"]))
            elif task_specifier == "task_index":
                result[step_name] = {
                    "task_index": task_val,
                }
            elif task_specifier == "task_val":
                result[step_name] = {
                    "task_val": str(task_val),
                }
            else:
                raise ValueError("Invalid task specifier")
        return result

    @staticmethod
    def _is_ancestor(foreach_stack, required_ancestors):
        """Return True if the parsed `foreach_stack` meets every constraint."""
        for step_name, required_ancestor in required_ancestors.items():
            if step_name not in foreach_stack:
                return False
            if (
                "task_val" in required_ancestor
                and str(required_ancestor["task_val"])
                != foreach_stack[step_name]["task_val"]
            ):
                return False
            if (
                "task_index" in required_ancestor
                and required_ancestor["task_index"]
                != foreach_stack[step_name]["task_index"]
            ):
                return False
        return True

    def get_all_previous_tasks(self, prev_step_name):
        """
        Return the tasks of `prev_step_name` compatible with the required ancestors.

        We go through all the required ancestors for the current step and keep
        the tasks from the previous step whose foreach stack entries match them.

        Raises
        ------
        ValueError
            If no task of the previous step matches.
        """
        from metaflow import Step

        # The constraints do not depend on the candidate task, so build them
        # once rather than once per task (this may involve a client lookup).
        required_ancestors = self._parse_required_ancestor_tasks(
            self.required_ancestor_tasks
        )

        prev_step = Step(
            f"{self.flow_name}/{self.run_id}/{prev_step_name}", _namespace_check=False
        )
        previous_tasks = [
            task
            for task in prev_step.tasks()
            if self._is_ancestor(
                self._parse_foreach_stack(task["_foreach_stack"]), required_ancestors
            )
        ]

        if not previous_tasks:
            raise ValueError(
                "No previous tasks found for the current step. Please check if the values provided for "
                "ancestor tasks are correct."
            )
        return previous_tasks
98 changes: 98 additions & 0 deletions metaflow/datastore/spin_datastore/inputs_datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from . import SpinDatastore


class SpinInput(object):
    """
    Attribute-style view over the inputs of a single previous task.

    Attribute lookups are answered from the user-provided `artifacts` mapping
    first and, failing that, from `task.artifacts.<name>.data`.

    Parameters
    ----------
    artifacts : mapping or None
        Artifacts provided by the user; these take precedence.
    task : object, optional
        Object exposing an `artifacts` attribute used as a fallback.
    """

    def __init__(self, artifacts, task=None):
        self.artifacts = artifacts
        self.task = task

    def __getattr__(self, name):
        # `__getattr__` only runs when normal lookup fails. Read our own state
        # from the instance dict so a missing `artifacts`/`task` attribute can
        # never re-enter this method and recurse.
        state = self.__dict__
        artifacts = state.get("artifacts")
        task = state.get("task")

        # We always look for any artifacts provided by the user first
        if artifacts is not None and name in artifacts:
            return artifacts[name]

        if task is None:
            raise AttributeError(
                f"Attribute '{name}' not provided by the user and no `task` was provided."
            )

        try:
            return getattr(task.artifacts, name).data
        except AttributeError as err:
            # This class has no `step_name`; referencing it here used to
            # recurse until RecursionError. Identify the task instead.
            raise AttributeError(
                f"Attribute '{name}' not found in the artifacts provided by the user or in "
                f"the previous execution of the task `{getattr(task, 'pathspec', task)}`."
            ) from err


class StaticSpinInputsDatastore(SpinDatastore):
    """
    Inputs datastore whose inputs are addressed by previous step name.

    Accessing `<datastore>.<previous_step_name>` returns a `SpinInput` built
    from the user artifacts stored under `artifacts["join"][<name>]` and the
    previous tasks found for that step. Iterating yields one `SpinInput` per
    previous step, in the order of `previous_steps`.
    """

    def __init__(self, spin_parser_validator):
        super(StaticSpinInputsDatastore, self).__init__(spin_parser_validator)
        # Cache of {previous step name: tasks}, filled on first use.
        self._previous_tasks = {}

    def __getattr__(self, name):
        if name not in self.previous_steps:
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{self.step_name}`."
            )

        # NOTE(review): `get_previous_tasks[name]` is whatever
        # `get_all_previous_tasks` returns (a list of tasks), while `SpinInput`
        # reads `task.artifacts` -- confirm whether a single task (e.g. the
        # first element) should be passed here instead.
        input_step = SpinInput(
            self.spin_parser_validator.artifacts["join"][name],
            self.get_previous_tasks[name],
        )
        # Cache on the instance so later lookups bypass `__getattr__`.
        setattr(self, name, input_step)
        return input_step

    def __iter__(self):
        # The class defines no `__getitem__`, so `self[name]` would raise
        # TypeError; go through attribute access, which builds the input.
        for prev_step_name in self.previous_steps:
            yield getattr(self, prev_step_name)

    def __len__(self):
        return len(self.get_previous_tasks)

    @property
    def get_previous_tasks(self):
        """Return the `{previous step name: tasks}` mapping, building it once."""
        if self._previous_tasks:
            return self._previous_tasks

        for prev_step_name in self.previous_steps:
            previous_task = self.get_all_previous_tasks(prev_step_name)
            self._previous_tasks[prev_step_name] = previous_task
        return self._previous_tasks


class SpinInputsDatastore(SpinDatastore):
    """
    Inputs datastore whose inputs are addressed by position.

    Each item is a `SpinInput` pairing the user-provided artifacts with one of
    the previous tasks; the tasks are ordered by their `index` attribute.
    """

    def __init__(self, spin_parser_validator):
        super(SpinInputsDatastore, self).__init__(spin_parser_validator)
        # Filled lazily by `get_previous_tasks`.
        self._previous_tasks = None

    def __len__(self):
        return len(self.get_previous_tasks)

    def __getitem__(self, idx):
        selected_task = self.get_previous_tasks[idx]
        # NOTE(review): the original left a per-index lookup
        # (`artifacts[idx]`) commented out; every input currently shares the
        # same user-provided artifacts.
        shared_artifacts = self.spin_parser_validator.artifacts
        return SpinInput(shared_artifacts, selected_task)

    def __iter__(self):
        total = len(self.get_previous_tasks)
        yield from (self[position] for position in range(total))

    @property
    def get_previous_tasks(self):
        """Return the previous tasks sorted by index, computing them once."""
        if not self._previous_tasks:
            # This a join step for a foreach split, so only has one previous step
            sole_prev_step = self.previous_steps[0]
            self._previous_tasks = self.get_all_previous_tasks(sole_prev_step)
            # Sort the tasks by index
            self._previous_tasks = sorted(
                self._previous_tasks, key=lambda task: task.index
            )
        return self._previous_tasks
108 changes: 108 additions & 0 deletions metaflow/datastore/spin_datastore/step_datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from . import SpinDatastore


class SpinStepDatastore(SpinDatastore):
    """
    Datastore for spinning a non-join step.

    Names are resolved, in order, from: values stored on this object via
    item assignment, artifacts provided by the user, the special foreach
    names `input` / `index`, and finally the artifacts of the previous task
    (read through the client API).
    """

    # Attributes assigned in `__init__`. If `__getattr__` is asked for one of
    # them, the instance is not initialised and resolving it would recurse.
    _INSTANCE_ATTRS = frozenset(
        ("spin_parser_validator", "_previous_task", "_data", "_objects", "_info")
    )

    def __init__(self, spin_parser_validator):
        super(SpinStepDatastore, self).__init__(spin_parser_validator)
        self._previous_task = None
        self._data = {}

        # Set them to empty dictionaries in order to persist artifacts
        # See `persist` method in `TaskDatastore` for more details
        self._objects = {}
        self._info = {}

    def __contains__(self, name):
        try:
            self.__getattr__(name)
        except AttributeError:
            return False
        return True

    def __getitem__(self, name):
        return self.__getattr__(name)

    def __setitem__(self, name, value):
        self._data[name] = value

    def __getattr__(self, name):
        """
        Resolve `name` using the lookup order described on the class.

        Raises
        ------
        AttributeError
            If the name cannot be resolved from any source.
        """
        if name in self._INSTANCE_ATTRS:
            raise AttributeError(name)

        # Check internal data first
        if name in self._data:
            return self._data[name]

        # We always look for any artifacts provided by the user first
        user_artifacts = self.artifacts
        if user_artifacts is not None and name in user_artifacts:
            return user_artifacts[name]

        if self.run_id is None:
            raise AttributeError(
                f"Attribute '{name}' not provided by the user and no `run_id` was provided. "
            )

        # If the linear step is part of a foreach step, we need to set the input attribute
        if name == "input":
            # Compare against None: 0 is a valid foreach index.
            if self.foreach_index is not None:
                foreach_task = self.get_task_for_step(self.step_name)
                foreach_value = foreach_task[self.foreach_var].data[
                    self.foreach_index
                ]
            elif len(self.foreach_stack) > 0:
                innermost = self.previous_task["_foreach_stack"].data[-1]
                foreach_task = self.get_task_for_step(innermost.step)
                foreach_value = foreach_task[innermost.var].data[innermost.index]
            else:
                # Previously this fell through to unbound local variables.
                raise AttributeError(
                    f"Attribute input does not exist for step `{self.step_name}` as it is not part of a foreach step."
                )
            setattr(self, name, foreach_value)
            return foreach_value

        # If the linear step is part of a foreach step, we need to set the index attribute
        if name == "index":
            if self.foreach_index is not None:
                setattr(self, name, self.foreach_index)
                return self.foreach_index
            if len(self.foreach_stack) > 0:
                cur_foreach_step_index = (
                    self.previous_task["_foreach_stack"].data[-1].index
                )
                setattr(self, name, cur_foreach_step_index)
                return cur_foreach_step_index
            raise AttributeError(
                f"Attribute index does not exist for step `{self.step_name}` as it is not part of a foreach step."
            )

        # If the user has not provided the artifact, we look for it in the
        # task using the client API
        try:
            return getattr(self.previous_task.artifacts, name).data
        except AttributeError as err:
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{self.step_name}`."
            ) from err

    @property
    def previous_task(self):
        """Return the previous task, preferring the one held by the validator."""
        # Since this is not a join step, we can safely assume that there is only one
        # previous step and one corresponding previous task
        if self.spin_parser_validator.previous_steps_task:
            return self.spin_parser_validator.previous_steps_task

        if self._previous_task:
            return self._previous_task

        prev_step_name = self.previous_steps[0]
        self._previous_task = self.get_all_previous_tasks(prev_step_name)[0]
        return self._previous_task

    def get(self, key, default=None):
        """Return the value resolved for `key`, or `default` if it cannot be resolved."""
        try:
            return self.__getattr__(key)
        except AttributeError:
            return default
15 changes: 15 additions & 0 deletions metaflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,21 @@ def _init_step_decorators(flow, graph, environment, flow_datastore, logger):
)


def _init_step_decorator(flow, graph, environment, flow_datastore, logger, step_name):
    """
    Call `step_init` on every decorator of the step named `step_name`.

    Steps of `flow` whose `__name__` differs from `step_name` are skipped.
    """
    matching_steps = (
        candidate for candidate in flow if candidate.__name__ == step_name
    )
    for matched in matching_steps:
        for decorator in matched.decorators:
            decorator.step_init(
                flow,
                graph,
                matched.__name__,
                matched.decorators,
                environment,
                flow_datastore,
                logger,
            )


FlowSpecDerived = TypeVar("FlowSpecDerived", bound=FlowSpec)

# The StepFlag is a "fake" input item to be able to distinguish
Expand Down
Loading
Loading