-
Notifications
You must be signed in to change notification settings - Fork 791
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Initial commit for spin steps #2036
Open
talsperre
wants to merge
17
commits into
master
Choose a base branch
from
dev/baby-steps
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
2b78ef4
WIP: Initial commit for spin steps
talsperre 2006c5b
Update input spin datastores
talsperre 936f0d8
Update step execution code
talsperre e5b3e1e
Working POC for linear and foreach steps
talsperre 11c3354
Remove debug logs, fix list attribute access issue
talsperre 1a23dfc
Update current singleton stubs for spin steps
talsperre 75f9e99
Fix bugs in spin for join steps
talsperre 726e76d
Rename baby steps to spin steps
talsperre 3a51eff
Persist output artifacts
talsperre 33b2059
Support spin command in metaflow runner
talsperre 0b883b0
Pass in correct values for deco hooks in spin steps
talsperre 4dc662d
Refactor spin steps, move spin logic to start subcommand
talsperre 2512909
Fix bug in join inputs datastore
talsperre ee1e09d
initial commit for spin runtime
talsperre 60eba72
Test commit
talsperre 517203d
SpinRuntime kind of works
talsperre 8f3b071
Set _namespace_check to False in spin datastore
talsperre File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
class SpinDatastore(object):
    """Base class exposing the parsed spin arguments and previous-run lookups.

    Every property below forwards to the ``spin_parser_validator`` object passed
    to the constructor; this class stores no state of its own besides that
    reference.
    """

    def __init__(self, spin_parser_validator):
        self.spin_parser_validator = spin_parser_validator

    def get_task_for_step(self, step_name):
        """
        Returns an instance of the task corresponding to the step name and run id.

        The first task yielded by ``Step.tasks()`` is returned.

        Raises
        ------
        ValueError
            If no ``run_id`` is available or the step has no tasks.
        """
        if self.run_id is None:
            raise ValueError("No run_id provided")

        from metaflow import Step

        step = Step(
            f"{self.flow_name}/{self.run_id}/{step_name}", _namespace_check=False
        )
        # A bare `next(iter(...))` would leak StopIteration to callers (and turn
        # into a RuntimeError inside generators), so report it explicitly.
        first_task = next(iter(step.tasks()), None)
        if first_task is None:
            raise ValueError(
                f"No tasks found for step `{step_name}` in run `{self.run_id}`"
            )
        return first_task

    @property
    def flow_name(self):
        return self.spin_parser_validator.flow_name

    @property
    def parsed_ancestor_tasks(self):
        return self.spin_parser_validator.parsed_ancestor_tasks

    @property
    def foreach_var(self):
        return self.spin_parser_validator.foreach_var

    @property
    def foreach_index(self):
        return self.spin_parser_validator.foreach_index

    @property
    def foreach_value(self):
        return self.spin_parser_validator.foreach_value

    @property
    def foreach_stack(self):
        return self.spin_parser_validator.foreach_stack

    @property
    def artifacts(self):
        return self.spin_parser_validator.artifacts

    @property
    def run_id(self):
        return self.spin_parser_validator.run_id

    @property
    def step_name(self):
        return self.spin_parser_validator.step_name

    @property
    def task(self):
        return self.spin_parser_validator.task

    @property
    def step_type(self):
        return self.spin_parser_validator.step_type

    @property
    def previous_steps(self):
        return self.spin_parser_validator.previous_steps

    @property
    def required_ancestor_tasks(self):
        return self.spin_parser_validator.required_ancestor_tasks

    @property
    def is_foreach_step(self):
        return len(self.spin_parser_validator.foreach_stack) > 0

    @staticmethod
    def _parse_foreach_stack(foreach_stack):
        """Map each frame in ``foreach_stack.data`` to its value and index, keyed by step."""
        return {
            entry.step: {
                "task_val": entry.value,
                "task_index": entry.index,
            }
            for entry in foreach_stack.data
        }

    def _parse_required_ancestor_tasks(self, required_ancestor_tasks):
        """Turn the required-ancestor specifiers into ``{step_name: constraint}``.

        Raises ``ValueError`` for a specifier other than ``task_id``,
        ``task_index`` or ``task_val``.
        """
        result = {}
        for required_ancestor in required_ancestor_tasks or ():
            step_name = required_ancestor.step_name
            task_specifier = required_ancestor.task_specifier
            task_val = required_ancestor.value
            if task_specifier == "task_id":
                from metaflow import Task

                # Same namespace handling as the Step lookups in this class.
                task = Task(
                    f"{self.flow_name}/{self.run_id}/{step_name}/{task_val}",
                    _namespace_check=False,
                )
                # Merge rather than return: returning here would silently drop
                # the constraints of every other required ancestor.
                result.update(self._parse_foreach_stack(task["_foreach_stack"]))
            elif task_specifier == "task_index":
                result[step_name] = {
                    "task_index": task_val,
                }
            elif task_specifier == "task_val":
                result[step_name] = {
                    "task_val": str(task_val),
                }
            else:
                raise ValueError("Invalid task specifier")
        return result

    @staticmethod
    def _is_ancestor(foreach_stack, required_ancestors):
        """Return True when the parsed foreach stack satisfies every constraint."""
        for step_name, required_ancestor in required_ancestors.items():
            if step_name not in foreach_stack:
                return False
            frame = foreach_stack[step_name]
            if (
                "task_val" in required_ancestor
                and str(required_ancestor["task_val"]) != frame["task_val"]
            ):
                return False
            if (
                "task_index" in required_ancestor
                and required_ancestor["task_index"] != frame["task_index"]
            ):
                return False
        return True

    def get_all_previous_tasks(self, prev_step_name):
        """Return the tasks of ``prev_step_name`` that match the required ancestors.

        Raises
        ------
        ValueError
            If no task of the previous step matches.
        """
        # We go through all the required ancestors for the current step
        # and filter the tasks from the previous step whose foreach stack
        # entries match the required ancestors
        from metaflow import Step

        # The constraints do not depend on the candidate task, so parse them once
        # instead of once per task (parsing may fetch a Task through the client).
        required_ancestors = self._parse_required_ancestor_tasks(
            self.required_ancestor_tasks
        )

        prev_step = Step(
            f"{self.flow_name}/{self.run_id}/{prev_step_name}", _namespace_check=False
        )
        previous_tasks = [
            task
            for task in prev_step.tasks()
            if self._is_ancestor(
                self._parse_foreach_stack(task["_foreach_stack"]), required_ancestors
            )
        ]

        if not previous_tasks:
            raise ValueError(
                "No previous tasks found for the current step. Please check if the values provided for "
                "ancestor tasks are correct."
            )
        return previous_tasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from . import SpinDatastore | ||
|
||
|
||
class SpinInput(object):
    """Attribute-style access to the artifacts of one input of a join step.

    Lookups first consult the user-provided ``artifacts`` mapping and then fall
    back to ``task.artifacts.<name>.data``.
    """

    def __init__(self, artifacts, task=None):
        self.artifacts = artifacts
        self.task = task

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If our own fields are
        # missing (e.g. the instance was created without __init__), bail out
        # instead of recursing through `self.artifacts` / `self.task` forever.
        if name in ("artifacts", "task"):
            raise AttributeError(name)

        # We always look for any artifacts provided by the user first
        if self.artifacts is not None and name in self.artifacts:
            return self.artifacts[name]

        if self.task is None:
            raise AttributeError(
                f"Attribute '{name}' not provided by the user and no `task` was provided."
            )

        try:
            return getattr(self.task.artifacts, name).data
        except AttributeError as err:
            # This class has no `step_name`; reading it here used to re-enter
            # __getattr__ and end in RecursionError. Identify the task instead.
            # NOTE(review): assumes the task exposes `pathspec` as Metaflow client
            # objects do; falls back to repr() otherwise.
            task_label = getattr(self.task, "pathspec", None) or repr(self.task)
            raise AttributeError(
                f"Attribute '{name}' not found in the artifacts provided by the user or in "
                f"the previous execution of the task `{task_label}`."
            ) from err
|
||
|
||
class StaticSpinInputsDatastore(SpinDatastore):
    """Inputs of a join step, addressed by the name of each previous step.

    ``inputs.<prev_step_name>`` yields a ``SpinInput`` for that step; iterating
    yields one ``SpinInput`` per entry of ``previous_steps``.
    """

    def __init__(self, spin_parser_validator):
        super(StaticSpinInputsDatastore, self).__init__(spin_parser_validator)
        self._previous_tasks = {}

    def __getattr__(self, name):
        # Guard our own fields so a partially constructed instance raises
        # AttributeError instead of recursing through the properties below.
        if name in ("spin_parser_validator", "_previous_tasks"):
            raise AttributeError(name)

        if name not in self.previous_steps:
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{self.step_name}`."
            )

        # get_all_previous_tasks returns a non-empty list; SpinInput reads
        # `task.artifacts`, so it needs a single task, not the list.
        # NOTE(review): assumes one matching task per input step of a static
        # join -- confirm against the parser/validator.
        input_step = SpinInput(
            self.spin_parser_validator.artifacts["join"][name],
            self.get_previous_tasks[name][0],
        )
        # Cache on the instance so later lookups bypass __getattr__.
        setattr(self, name, input_step)
        return input_step

    def __iter__(self):
        # The class defines no __getitem__, so `self[name]` raised TypeError;
        # go through attribute access, which builds and caches the SpinInput.
        for prev_step_name in self.previous_steps:
            yield getattr(self, prev_step_name)

    def __len__(self):
        return len(self.get_previous_tasks)

    @property
    def get_previous_tasks(self):
        """Mapping of previous step name to its list of matching tasks (built once)."""
        if self._previous_tasks:
            return self._previous_tasks

        for prev_step_name in self.previous_steps:
            self._previous_tasks[prev_step_name] = self.get_all_previous_tasks(
                prev_step_name
            )
        return self._previous_tasks
|
||
|
||
class SpinInputsDatastore(SpinDatastore):
    """Index-addressable inputs of a join step that closes a foreach split.

    Each item is a ``SpinInput`` pairing the user-provided artifacts with one
    task of the previous step.
    """

    def __init__(self, spin_parser_validator):
        super(SpinInputsDatastore, self).__init__(spin_parser_validator)
        # Filled on first access of `get_previous_tasks`.
        self._previous_tasks = None

    def __len__(self):
        return len(self.get_previous_tasks)

    def __getitem__(self, idx):
        selected_task = self.get_previous_tasks[idx]
        user_artifacts = self.spin_parser_validator.artifacts
        return SpinInput(user_artifacts, selected_task)

    def __iter__(self):
        task_count = len(self.get_previous_tasks)
        yield from (self[position] for position in range(task_count))

    @property
    def get_previous_tasks(self):
        """Tasks of the single previous step, ordered by their ``index``."""
        cached = self._previous_tasks
        if cached:
            return cached

        # This is a join step for a foreach split, so it only has one previous step
        foreach_step_name = self.previous_steps[0]
        self._previous_tasks = self.get_all_previous_tasks(foreach_step_name)
        # Sort the tasks by index
        self._previous_tasks = sorted(
            self._previous_tasks, key=lambda candidate: candidate.index
        )
        return self._previous_tasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from . import SpinDatastore | ||
|
||
|
||
class SpinStepDatastore(SpinDatastore):
    """Artifact lookup for a non-join spin step.

    Resolution order in ``__getattr__``: values set through ``__setitem__``,
    then user-provided artifacts, then the special ``input`` / ``index`` names,
    then the artifacts of the previous task.
    """

    # Names that must never be resolved as artifacts: they back this object's
    # own state, and resolving them via __getattr__ would recurse.
    _INTERNAL_ATTRS = (
        "_data",
        "_previous_task",
        "_objects",
        "_info",
        "spin_parser_validator",
    )

    def __init__(self, spin_parser_validator):
        super(SpinStepDatastore, self).__init__(spin_parser_validator)
        self._previous_task = None
        self._data = {}

        # Set them to empty dictionaries in order to persist artifacts
        # See `persist` method in `TaskDatastore` for more details
        self._objects = {}
        self._info = {}

    def __contains__(self, name):
        try:
            self.__getattr__(name)
        except AttributeError:
            return False
        return True

    def __getitem__(self, name):
        return self.__getattr__(name)

    def __setitem__(self, name, value):
        self._data[name] = value

    def __getattr__(self, name):
        if name in self._INTERNAL_ATTRS:
            raise AttributeError(name)

        # Check internal data first
        if name in self._data:
            return self._data[name]

        # We always look for any artifacts provided by the user first
        user_artifacts = self.artifacts
        if user_artifacts is not None and name in user_artifacts:
            return user_artifacts[name]

        if self.run_id is None:
            raise AttributeError(
                f"Attribute '{name}' not provided by the user and no `run_id` was provided. "
            )

        # If the linear step is part of a foreach step, we need to set the input attribute
        if name == "input":
            # `is not None`, not truthiness: a foreach index of 0 is valid.
            if self.foreach_index is not None:
                foreach_task = self.get_task_for_step(self.step_name)
                foreach_value = foreach_task[self.foreach_var].data[self.foreach_index]
            elif self.is_foreach_step:
                # Use the innermost frame of the previous task's foreach stack.
                frame = self.previous_task["_foreach_stack"].data[-1]
                foreach_task = self.get_task_for_step(frame.step)
                foreach_value = foreach_task[frame.var].data[frame.index]
            else:
                # Previously this path fell through to an unbound local.
                raise AttributeError(
                    f"Attribute input does not exist for step `{self.step_name}` as it is not part of a foreach step."
                )
            setattr(self, name, foreach_value)
            return foreach_value

        # If the linear step is part of a foreach step, we need to set the index attribute
        if name == "index":
            if self.foreach_index is not None:
                setattr(self, name, self.foreach_index)
                return self.foreach_index
            if self.is_foreach_step:
                cur_foreach_step_index = (
                    self.previous_task["_foreach_stack"].data[-1].index
                )
                setattr(self, name, cur_foreach_step_index)
                return cur_foreach_step_index
            raise AttributeError(
                f"Attribute index does not exist for step `{self.step_name}` as it is not part of a foreach step."
            )

        # If the user has not provided the artifact, we look for it in the
        # task using the client API
        try:
            return getattr(self.previous_task.artifacts, name).data
        except AttributeError:
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{self.step_name}`."
            )

    @property
    def previous_task(self):
        """Task of the previous step whose artifacts feed this step."""
        # Since this is not a join step, we can safely assume that there is only one
        # previous step and one corresponding previous task
        if self.spin_parser_validator.previous_steps_task:
            return self.spin_parser_validator.previous_steps_task

        if self._previous_task:
            return self._previous_task

        prev_step_name = self.previous_steps[0]
        self._previous_task = self.get_all_previous_tasks(prev_step_name)[0]
        return self._previous_task

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` when the lookup raises AttributeError."""
        try:
            return self.__getattr__(key)
        except AttributeError:
            return default
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this ever reached?