1212# or implied. See the License for the specific language governing
1313# permissions and limitations under the License.
1414
15- from tfx .orchestration import data_types
15+ import os
16+ from typing import Union
17+
18+ from absl import logging
19+ from tfx .dsl .compiler import compiler
20+ from tfx .dsl .compiler import constants
1621from tfx .orchestration import metadata
17- from tfx .orchestration import pipeline
18- from tfx .orchestration .config import config_utils
22+ from tfx .orchestration import pipeline as pipeline_py
23+ from tfx .orchestration .local import runner_utils
1924from tfx .orchestration .local .local_dag_runner import LocalDagRunner
25+ from tfx .orchestration .portable import launcher
26+ from tfx .orchestration .portable import runtime_parameter_utils
27+ from tfx .proto .orchestration import pipeline_pb2
28+ from tfx .utils import telemetry_utils
2029
2130from zenml .logger import get_logger
2231
@@ -32,24 +41,58 @@ class ZenMLLocalDagRunner(LocalDagRunner):
3241 https://github.com/tensorflow/tfx/blob/master/tfx/orchestration/local/
3342 """
3443
35- def run (self , tfx_pipeline : pipeline .Pipeline ) -> None :
36- for component in tfx_pipeline .components :
37- (component_launcher_class , component_config ) = (
38- config_utils .find_component_launch_info (self ._config ,
39- component ))
40- driver_args = data_types .DriverArgs (
41- enable_cache = tfx_pipeline .enable_cache )
42- metadata_connection = metadata .Metadata (
43- tfx_pipeline .metadata_connection_config )
44- component_launcher = component_launcher_class .create (
45- component = component ,
46- pipeline_info = tfx_pipeline .pipeline_info ,
47- driver_args = driver_args ,
48- metadata_connection = metadata_connection ,
49- beam_pipeline_args = tfx_pipeline .beam_pipeline_args ,
50- additional_pipeline_args = tfx_pipeline
51- .additional_pipeline_args ,
52- component_config = component_config )
53- logger .info ('Component %s is running.' , component .id )
54- component_launcher .launch ()
55- logger .info ('Component %s is finished.' , component .id )
44+ def run (self , pipeline : Union [pipeline_pb2 .Pipeline ,
45+ pipeline_py .Pipeline ]) -> None :
46+ """Runs given logical pipeline locally.
47+
48+ Args:
49+ pipeline: Logical pipeline containing pipeline args and components.
50+ """
51+ # For CLI, while creating or updating pipeline, pipeline_args are extracted
52+ # and hence we avoid executing the pipeline.
53+ if 'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH' in os .environ :
54+ return
55+ run_id = pipeline .pipeline_info .run_id
56+
57+ if isinstance (pipeline , pipeline_py .Pipeline ):
58+ c = compiler .Compiler ()
59+ pipeline = c .compile (pipeline )
60+
61+ # Substitute the runtime parameter to be a concrete run_id
62+ runtime_parameter_utils .substitute_runtime_parameter (
63+ pipeline , {
64+ constants .PIPELINE_RUN_ID_PARAMETER_NAME : run_id
65+ })
66+
67+ deployment_config = runner_utils .extract_local_deployment_config (
68+ pipeline )
69+ connection_config = deployment_config .metadata_connection_config
70+
71+ logging .info ('Running pipeline:\n %s' , pipeline )
72+ logging .info ('Using deployment config:\n %s' , deployment_config )
73+ logging .info ('Using connection config:\n %s' , connection_config )
74+
75+ with telemetry_utils .scoped_labels (
76+ {telemetry_utils .LABEL_TFX_RUNNER : 'local' }):
77+ # Run each component. Note that the pipeline.components list is in
78+ # topological order.
79+ # TODO(b/171319478): After IR-based execution is used, used multi-threaded
80+ # execution so that independent components can be run in parallel.
81+ for node in pipeline .nodes :
82+ pipeline_node = node .pipeline_node
83+ node_id = pipeline_node .node_info .id
84+ executor_spec = runner_utils .extract_executor_spec (
85+ deployment_config , node_id )
86+ custom_driver_spec = runner_utils .extract_custom_driver_spec (
87+ deployment_config , node_id )
88+
89+ component_launcher = launcher .Launcher (
90+ pipeline_node = pipeline_node ,
91+ mlmd_connection = metadata .Metadata (connection_config ),
92+ pipeline_info = pipeline .pipeline_info ,
93+ pipeline_runtime_spec = pipeline .runtime_spec ,
94+ executor_spec = executor_spec ,
95+ custom_driver_spec = custom_driver_spec )
96+ logging .info ('Component %s is running.' , node_id )
97+ component_launcher .launch ()
98+ logging .info ('Component %s is finished.' , node_id )
0 commit comments