12
12
# or implied. See the License for the specific language governing
13
13
# permissions and limitations under the License.
14
14
15
- from uuid import uuid4
16
15
17
16
import pytest
18
17
19
18
from zenml .config .step_configurations import Step
20
19
from zenml .enums import StepRunInputArtifactType
21
20
from zenml .exceptions import InputResolutionError
22
- from zenml .models import Page , PipelineRunResponse
21
+ from zenml .models import Page
23
22
from zenml .models .v2 .core .artifact_version import ArtifactVersionResponse
24
23
from zenml .models .v2 .core .step_run import StepRunInputResponse
25
24
from zenml .orchestrators import input_utils
@@ -29,6 +28,7 @@ def test_input_resolution(
29
28
mocker ,
30
29
sample_artifact_version_model : ArtifactVersionResponse ,
31
30
create_step_run ,
31
+ sample_pipeline_run ,
32
32
):
33
33
"""Tests that input resolution works if the correct models exist in the
34
34
zen store."""
@@ -60,7 +60,7 @@ def test_input_resolution(
60
60
)
61
61
62
62
input_artifacts , parent_ids = input_utils .resolve_step_inputs (
63
- step = step , pipeline_run = PipelineRunResponse ( id = uuid4 (), name = "foo" )
63
+ step = step , pipeline_run = sample_pipeline_run
64
64
)
65
65
assert input_artifacts == {
66
66
"input_name" : StepRunInputResponse (
@@ -71,7 +71,7 @@ def test_input_resolution(
71
71
assert parent_ids == [step_run .id ]
72
72
73
73
74
- def test_input_resolution_with_missing_step_run (mocker ):
74
+ def test_input_resolution_with_missing_step_run (mocker , sample_pipeline_run ):
75
75
"""Tests that input resolution fails if the upstream step run is missing."""
76
76
mocker .patch (
77
77
"zenml.zen_stores.sql_zen_store.SqlZenStore.list_run_steps" ,
@@ -97,11 +97,13 @@ def test_input_resolution_with_missing_step_run(mocker):
97
97
98
98
with pytest .raises (InputResolutionError ):
99
99
input_utils .resolve_step_inputs (
100
- step = step , pipeline_run = PipelineRunResponse ( id = uuid4 (), name = "foo" )
100
+ step = step , pipeline_run = sample_pipeline_run
101
101
)
102
102
103
103
104
- def test_input_resolution_with_missing_artifact (mocker , create_step_run ):
104
+ def test_input_resolution_with_missing_artifact (
105
+ mocker , create_step_run , sample_pipeline_run
106
+ ):
105
107
"""Tests that input resolution fails if the upstream step run output
106
108
artifact is missing."""
107
109
step_run = create_step_run (
@@ -132,12 +134,12 @@ def test_input_resolution_with_missing_artifact(mocker, create_step_run):
132
134
133
135
with pytest .raises (InputResolutionError ):
134
136
input_utils .resolve_step_inputs (
135
- step = step , pipeline_run = PipelineRunResponse ( id = uuid4 (), name = "foo" )
137
+ step = step , pipeline_run = sample_pipeline_run
136
138
)
137
139
138
140
139
141
def test_input_resolution_fetches_all_run_steps (
140
- mocker , sample_artifact_version_model , create_step_run
142
+ mocker , sample_artifact_version_model , create_step_run , sample_pipeline_run
141
143
):
142
144
"""Tests that input resolution fetches all step runs of the pipeline run."""
143
145
step_run = create_step_run (
@@ -178,7 +180,7 @@ def test_input_resolution_fetches_all_run_steps(
178
180
)
179
181
180
182
input_utils .resolve_step_inputs (
181
- step = step , pipeline_run = PipelineRunResponse ( id = uuid4 (), name = "foo" )
183
+ step = step , pipeline_run = sample_pipeline_run
182
184
)
183
185
184
186
# `resolve_step_inputs(...)` depaginates the run steps so we fetch all
0 commit comments