Skip to content

Commit c42f8ef

Browse files
Pipeline run not tracked in cached artifact version (#2713)
* add `pipeline_run_id` to ArtifactVersionResponse * rename * add model test * improve producer run id retrieval * tiny fix * linting * remove unneeded properties * remove unneeded properties * Auto-update of LLM Finetuning template * lint * fix test signature --------- Co-authored-by: GitHub Actions <[email protected]>
1 parent 574dcd0 commit c42f8ef

File tree

4 files changed

+188
-4
lines changed

4 files changed

+188
-4
lines changed

src/zenml/zen_stores/schemas/artifact_schemas.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,17 @@ def to_model(
315315

316316
producer_step_run_id, producer_pipeline_run_id = None, None
317317
if self.output_of_step_runs:
318-
step_run = self.output_of_step_runs[0].step_run
319-
if step_run.status == ExecutionStatus.COMPLETED:
318+
original_step_runs = [
319+
sr
320+
for sr in self.output_of_step_runs
321+
if sr.step_run.status == ExecutionStatus.COMPLETED
322+
]
323+
if len(original_step_runs) == 1:
324+
step_run = original_step_runs[0].step_run
320325
producer_step_run_id = step_run.id
321326
producer_pipeline_run_id = step_run.pipeline_run_id
322327
else:
328+
step_run = self.output_of_step_runs[0].step_run
323329
producer_step_run_id = step_run.original_step_run_id
324330

325331
# Create the body of the model
@@ -348,10 +354,13 @@ def to_model(
348354
run_metadata={m.key: m.to_model() for m in self.run_metadata},
349355
)
350356

357+
resources = None
358+
351359
return ArtifactVersionResponse(
352360
id=self.id,
353361
body=body,
354362
metadata=metadata,
363+
resources=resources,
355364
)
356365

357366
def update(

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ def to_model(
208208
}
209209

210210
output_artifacts = {
211-
artifact.name: artifact.artifact_version.to_model()
211+
artifact.name: artifact.artifact_version.to_model(
212+
pipeline_run_id_in_context=self.pipeline_run_id
213+
)
212214
for artifact in self.output_artifacts
213215
}
214216

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2462,7 +2462,9 @@ def get_artifact_version(
24622462
f"{artifact_version_id}: No artifact version with this ID "
24632463
f"found."
24642464
)
2465-
return artifact_version.to_model(include_metadata=hydrate)
2465+
return artifact_version.to_model(
2466+
include_metadata=hydrate, include_resources=hydrate
2467+
)
24662468

24672469
def list_artifact_versions(
24682470
self,
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
15+
from typing import Callable
16+
from uuid import UUID
17+
18+
import pytest
19+
from typing_extensions import Annotated
20+
21+
from zenml import pipeline, step
22+
from zenml.client import Client
23+
from zenml.enums import ModelStages
24+
from zenml.model.model import Model
25+
26+
27+
@step(enable_cache=True)
28+
def simple_producer_step() -> Annotated[int, "trackable_artifact"]:
29+
return 42
30+
31+
32+
@step(enable_cache=False)
33+
def keep_pipeline_alive() -> None:
34+
pass
35+
36+
37+
@pipeline
38+
def cacheable_pipeline_which_always_run():
39+
simple_producer_step()
40+
keep_pipeline_alive()
41+
42+
43+
@pipeline
44+
def cacheable_pipeline_which_can_be_fully_cached():
45+
simple_producer_step()
46+
47+
48+
@pipeline
49+
def cacheable_pipeline_where_second_step_is_cached():
50+
simple_producer_step(id="simple_producer_step_1")
51+
simple_producer_step(id="simple_producer_step_2")
52+
53+
54+
def _validate_artifacts_state(
55+
clean_client: Client,
56+
pr_id: UUID,
57+
producer_pr_id: UUID,
58+
expected_version: int,
59+
step_name: str = "simple_producer_step",
60+
artifact_name: str = "trackable_artifact",
61+
):
62+
pr = clean_client.get_pipeline_run(pr_id)
63+
outputs_1 = pr.steps[step_name].outputs
64+
step = clean_client.get_run_step(pr.steps[step_name].id)
65+
outputs_2 = step.outputs
66+
for outputs in [outputs_1, outputs_2]:
67+
assert len(outputs) == 1
68+
assert int(outputs[artifact_name].version) == expected_version
69+
# producer ID is always the original PR
70+
assert (
71+
outputs[artifact_name].producer_pipeline_run_id == producer_pr_id
72+
)
73+
74+
artifact = clean_client.get_artifact_version(artifact_name)
75+
assert artifact.name == artifact_name
76+
assert int(artifact.version) == expected_version
77+
# producer ID is always the original PR
78+
assert artifact.producer_pipeline_run_id == producer_pr_id
79+
80+
81+
# TODO: remove clean client, ones clean env for REST is available
82+
@pytest.mark.parametrize(
83+
"pipeline",
84+
[
85+
cacheable_pipeline_which_always_run,
86+
cacheable_pipeline_which_can_be_fully_cached,
87+
],
88+
)
89+
def test_that_cached_artifact_versions_are_created_properly(
90+
pipeline: Callable, clean_client: Client
91+
):
92+
pr_orig = pipeline()
93+
94+
_validate_artifacts_state(
95+
clean_client=clean_client,
96+
pr_id=pr_orig.id,
97+
producer_pr_id=pr_orig.id,
98+
expected_version=1,
99+
)
100+
101+
pr = pipeline()
102+
103+
pr = clean_client.get_pipeline_run(pr.id)
104+
_validate_artifacts_state(
105+
clean_client=clean_client,
106+
pr_id=pr.id,
107+
producer_pr_id=pr_orig.id,
108+
expected_version=1, # cached artifact doesn't produce new version
109+
)
110+
111+
112+
# TODO: remove clean client, ones clean env for REST is available
113+
def test_that_cached_artifact_versions_are_created_properly_for_second_step(
114+
clean_client: Client,
115+
):
116+
pr_orig = cacheable_pipeline_where_second_step_is_cached()
117+
118+
_validate_artifacts_state(
119+
clean_client=clean_client,
120+
pr_id=pr_orig.id,
121+
producer_pr_id=pr_orig.id,
122+
step_name="simple_producer_step_1",
123+
expected_version=1,
124+
)
125+
_validate_artifacts_state(
126+
clean_client=clean_client,
127+
pr_id=pr_orig.id,
128+
producer_pr_id=pr_orig.id,
129+
step_name="simple_producer_step_2",
130+
expected_version=1,
131+
)
132+
133+
pr = cacheable_pipeline_where_second_step_is_cached()
134+
135+
pr = clean_client.get_pipeline_run(pr.id)
136+
_validate_artifacts_state(
137+
clean_client=clean_client,
138+
pr_id=pr.id,
139+
producer_pr_id=pr_orig.id,
140+
step_name="simple_producer_step_1",
141+
expected_version=1, # cached artifact doesn't produce new version
142+
)
143+
_validate_artifacts_state(
144+
clean_client=clean_client,
145+
pr_id=pr.id,
146+
producer_pr_id=pr_orig.id,
147+
step_name="simple_producer_step_2",
148+
expected_version=1, # cached artifact doesn't produce new version
149+
)
150+
151+
152+
def test_that_cached_artifact_versions_are_created_properly_for_model_version(
153+
clean_client: Client,
154+
):
155+
pr_orig = cacheable_pipeline_which_always_run.with_options(
156+
model=Model(name="foo")
157+
)()
158+
159+
mv = clean_client.get_model_version("foo", ModelStages.LATEST)
160+
assert (
161+
mv.data_artifacts["trackable_artifact"]["1"].producer_pipeline_run_id
162+
== pr_orig.id
163+
)
164+
165+
cacheable_pipeline_which_always_run.with_options(model=Model(name="foo"))()
166+
167+
mv = clean_client.get_model_version("foo", ModelStages.LATEST)
168+
assert (
169+
mv.data_artifacts["trackable_artifact"]["1"].producer_pipeline_run_id
170+
== pr_orig.id
171+
)

0 commit comments

Comments
 (0)