Skip to content

Commit 0ccb1fd

Browse files
authored
Fix input resolution for steps with dynamic artifact names (#3228)
* Fix input resolution for steps with dynamic artifact names
* Improve logic
* Linting
* Add test
* Fix variable access
* Really fix test
* Rename
1 parent ee48d1a commit 0ccb1fd

File tree

6 files changed

+75
-15
lines changed

6 files changed

+75
-15
lines changed

src/zenml/models/v2/core/pipeline_run.py

+9
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,15 @@ def is_templatable(self) -> bool:
550550
"""
551551
return self.get_metadata().is_templatable
552552

553+
@property
554+
def step_substitutions(self) -> Dict[str, Dict[str, str]]:
555+
"""The `step_substitutions` property.
556+
557+
Returns:
558+
the value of the property.
559+
"""
560+
return self.get_metadata().step_substitutions
561+
553562
@property
554563
def model_version(self) -> Optional[ModelVersionResponse]:
555564
"""The `model_version` property.

src/zenml/orchestrators/input_utils.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from zenml.config.step_configurations import Step
2121
from zenml.enums import ArtifactSaveType, StepRunInputArtifactType
2222
from zenml.exceptions import InputResolutionError
23-
from zenml.utils import pagination_utils
23+
from zenml.utils import pagination_utils, string_utils
2424

2525
if TYPE_CHECKING:
2626
from zenml.models import PipelineRunResponse
@@ -53,7 +53,8 @@ def resolve_step_inputs(
5353
current_run_steps = {
5454
run_step.name: run_step
5555
for run_step in pagination_utils.depaginate(
56-
Client().list_run_steps, pipeline_run_id=pipeline_run.id
56+
Client().list_run_steps,
57+
pipeline_run_id=pipeline_run.id,
5758
)
5859
}
5960

@@ -66,11 +67,23 @@ def resolve_step_inputs(
6667
f"No step `{input_.step_name}` found in current run."
6768
)
6869

70+
# Try to get the substitutions from the pipeline run first, as we
71+
# already have a hydrated version of that. In the unlikely case
72+
# that the pipeline run is outdated, we fetch it from the step
73+
# run instead, which will cost us one hydration call.
74+
substitutions = (
75+
pipeline_run.step_substitutions.get(step_run.name)
76+
or step_run.config.substitutions
77+
)
78+
output_name = string_utils.format_name_template(
79+
input_.output_name, substitutions=substitutions
80+
)
81+
6982
try:
70-
outputs = step_run.outputs[input_.output_name]
83+
outputs = step_run.outputs[output_name]
7184
except KeyError:
7285
raise InputResolutionError(
73-
f"No step output `{input_.output_name}` found for step "
86+
f"No step output `{output_name}` found for step "
7487
f"`{input_.step_name}`."
7588
)
7689

@@ -83,12 +96,12 @@ def resolve_step_inputs(
8396
# This should never happen, there can only be a single regular step
8497
# output for a name
8598
raise InputResolutionError(
86-
f"Too many step outputs for output `{input_.output_name}` of "
99+
f"Too many step outputs for output `{output_name}` of "
87100
f"step `{input_.step_name}`."
88101
)
89102
elif len(step_outputs) == 0:
90103
raise InputResolutionError(
91-
f"No step output `{input_.output_name}` found for step "
104+
f"No step output `{output_name}` found for step "
92105
f"`{input_.step_name}`."
93106
)
94107

src/zenml/orchestrators/step_run_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,9 @@ def create_cached_step_runs(
309309
for invocation_id in cache_candidates:
310310
visited_invocations.add(invocation_id)
311311

312+
# Make sure the request factory has the most up-to-date pipeline
313+
# run to avoid hydration calls
314+
request_factory.pipeline_run = pipeline_run
312315
try:
313316
step_run_request = request_factory.create_request(
314317
invocation_id

tests/integration/functional/steps/test_step_naming.py renamed to tests/integration/functional/steps/test_dynamic_artifact_names.py

+20
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
1414

15+
from contextlib import ExitStack as does_not_raise
1516
from typing import Callable, Tuple
1617

1718
import pytest
@@ -122,6 +123,11 @@ def mixed_with_unannotated_returns() -> (
122123
)
123124

124125

126+
@step
127+
def step_with_string_input(input_: str) -> None:
128+
pass
129+
130+
125131
@pytest.mark.parametrize(
126132
"step",
127133
[
@@ -362,3 +368,17 @@ def _inner(pass_to_step: str = ""):
362368
assert p2_step_subs["date"] == "step_level"
363369
assert p1_step_subs["funny_name"] == "pipeline_level"
364370
assert p2_step_subs["funny_name"] == "step_level"
371+
372+
373+
def test_dynamically_named_artifacts_in_downstream_steps(
374+
clean_client: "Client",
375+
):
376+
"""Test that dynamically named artifacts can be used in downstream steps."""
377+
378+
@pipeline(enable_cache=False)
379+
def _inner(ret: str):
380+
artifact = dynamic_single_string_standard()
381+
step_with_string_input(artifact)
382+
383+
with does_not_raise():
384+
_inner("output_1")

tests/unit/conftest.py

+13
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
14+
from collections import defaultdict
1415
from datetime import datetime
1516
from typing import Any, Callable, Dict, List, Optional
1617
from uuid import uuid4
@@ -416,6 +417,12 @@ def sample_pipeline_run(
416417
sample_workspace_model: WorkspaceResponse,
417418
) -> PipelineRunResponse:
418419
"""Return sample pipeline run view for testing purposes."""
420+
now = datetime.utcnow()
421+
substitutions = {
422+
"date": now.strftime("%Y_%m_%d"),
423+
"time": now.strftime("%H_%M_%S_%f"),
424+
}
425+
419426
return PipelineRunResponse(
420427
id=uuid4(),
421428
name="sample_run_name",
@@ -430,6 +437,7 @@ def sample_pipeline_run(
430437
workspace=sample_workspace_model,
431438
config=PipelineConfiguration(name="aria_pipeline"),
432439
is_templatable=False,
440+
steps_substitutions=defaultdict(lambda: substitutions.copy()),
433441
),
434442
resources=PipelineRunResponseResources(tags=[]),
435443
)
@@ -543,10 +551,15 @@ def f(
543551
spec = StepSpec.model_validate(
544552
{"source": "module.step_class", "upstream_steps": []}
545553
)
554+
now = datetime.utcnow()
546555
config = StepConfiguration.model_validate(
547556
{
548557
"name": step_name,
549558
"outputs": outputs or {},
559+
"substitutions": {
560+
"date": now.strftime("%Y_%m_%d"),
561+
"time": now.strftime("%H_%M_%S_%f"),
562+
},
550563
}
551564
)
552565
return StepRunResponse(

tests/unit/orchestrators/test_input_utils.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
1414

15-
from uuid import uuid4
1615

1716
import pytest
1817

1918
from zenml.config.step_configurations import Step
2019
from zenml.enums import StepRunInputArtifactType
2120
from zenml.exceptions import InputResolutionError
22-
from zenml.models import Page, PipelineRunResponse
21+
from zenml.models import Page
2322
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
2423
from zenml.models.v2.core.step_run import StepRunInputResponse
2524
from zenml.orchestrators import input_utils
@@ -29,6 +28,7 @@ def test_input_resolution(
2928
mocker,
3029
sample_artifact_version_model: ArtifactVersionResponse,
3130
create_step_run,
31+
sample_pipeline_run,
3232
):
3333
"""Tests that input resolution works if the correct models exist in the
3434
zen store."""
@@ -60,7 +60,7 @@ def test_input_resolution(
6060
)
6161

6262
input_artifacts, parent_ids = input_utils.resolve_step_inputs(
63-
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
63+
step=step, pipeline_run=sample_pipeline_run
6464
)
6565
assert input_artifacts == {
6666
"input_name": StepRunInputResponse(
@@ -71,7 +71,7 @@ def test_input_resolution(
7171
assert parent_ids == [step_run.id]
7272

7373

74-
def test_input_resolution_with_missing_step_run(mocker):
74+
def test_input_resolution_with_missing_step_run(mocker, sample_pipeline_run):
7575
"""Tests that input resolution fails if the upstream step run is missing."""
7676
mocker.patch(
7777
"zenml.zen_stores.sql_zen_store.SqlZenStore.list_run_steps",
@@ -97,11 +97,13 @@ def test_input_resolution_with_missing_step_run(mocker):
9797

9898
with pytest.raises(InputResolutionError):
9999
input_utils.resolve_step_inputs(
100-
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
100+
step=step, pipeline_run=sample_pipeline_run
101101
)
102102

103103

104-
def test_input_resolution_with_missing_artifact(mocker, create_step_run):
104+
def test_input_resolution_with_missing_artifact(
105+
mocker, create_step_run, sample_pipeline_run
106+
):
105107
"""Tests that input resolution fails if the upstream step run output
106108
artifact is missing."""
107109
step_run = create_step_run(
@@ -132,12 +134,12 @@ def test_input_resolution_with_missing_artifact(mocker, create_step_run):
132134

133135
with pytest.raises(InputResolutionError):
134136
input_utils.resolve_step_inputs(
135-
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
137+
step=step, pipeline_run=sample_pipeline_run
136138
)
137139

138140

139141
def test_input_resolution_fetches_all_run_steps(
140-
mocker, sample_artifact_version_model, create_step_run
142+
mocker, sample_artifact_version_model, create_step_run, sample_pipeline_run
141143
):
142144
"""Tests that input resolution fetches all step runs of the pipeline run."""
143145
step_run = create_step_run(
@@ -178,7 +180,7 @@ def test_input_resolution_fetches_all_run_steps(
178180
)
179181

180182
input_utils.resolve_step_inputs(
181-
step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo")
183+
step=step, pipeline_run=sample_pipeline_run
182184
)
183185

184186
# `resolve_step_inputs(...)` depaginates the run steps so we fetch all

0 commit comments

Comments
 (0)