Skip to content

Commit abd7f5c

Browse files
authored
Merge pull request #87 from maiot-io/baris/upgrade
Baris/upgrade
2 parents 0d0fb8e + 2be3fa1 commit abd7f5c

39 files changed

+353
-624
lines changed

requirements.txt

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
absl-py==0.10.0
22
pip-check-reqs>=2.0.1,<3
33
click>=7.0,<8
4-
setuptools>=38.4.0
4+
setuptools==46.4.0
55
nbformat>=5.0.4
66
panel==0.8.3
77
plotly==4.0.0
88
tabulate==0.8.7
9-
numpy==1.18.0
9+
numpy==1.19.2
1010
httplib2==0.17.0
11-
tfx==0.26.1
11+
six==1.15.0
12+
tfx==0.30.0
13+
tensorflow_datasets==4.3.0
1214
fire==0.3.1
1315
gitpython==3.1.11
1416
analytics-python==1.2.9
1517
distro==1.5.0
16-
tensorflow>=2.3.0,<2.4.0
17-
tensorflow-serving-api==2.3.0
18-
18+
tensorflow==2.4.1
19+
grpcio==1.32.0
20+
dill==0.3.1.1
21+
google-cloud-bigquery==1.28.0
1922

2023
# docs
2124
jupyter-book==0.9.1
@@ -32,8 +35,8 @@ sphinxext-opengraph==0.3.1
3235
cortex==0.29.0
3336

3437
# gcp
35-
apache-beam[gcp]==2.27.0
36-
apache-beam==2.27.0
38+
apache-beam[gcp]==2.28.0
39+
apache-beam==2.28.0
3740
google-apitools==0.5.31
3841

3942
# pytorch

zenml/backends/orchestrator/base/zenml_local_orchestrator.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,20 @@
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
1414

15-
from tfx.orchestration import data_types
15+
import os
16+
from typing import Union
17+
18+
from absl import logging
19+
from tfx.dsl.compiler import compiler
20+
from tfx.dsl.compiler import constants
1621
from tfx.orchestration import metadata
17-
from tfx.orchestration import pipeline
18-
from tfx.orchestration.config import config_utils
22+
from tfx.orchestration import pipeline as pipeline_py
23+
from tfx.orchestration.local import runner_utils
1924
from tfx.orchestration.local.local_dag_runner import LocalDagRunner
25+
from tfx.orchestration.portable import launcher
26+
from tfx.orchestration.portable import runtime_parameter_utils
27+
from tfx.proto.orchestration import pipeline_pb2
28+
from tfx.utils import telemetry_utils
2029

2130
from zenml.logger import get_logger
2231

@@ -32,24 +41,58 @@ class ZenMLLocalDagRunner(LocalDagRunner):
3241
https://github.com/tensorflow/tfx/blob/master/tfx/orchestration/local/
3342
"""
3443

35-
def run(self, tfx_pipeline: pipeline.Pipeline) -> None:
36-
for component in tfx_pipeline.components:
37-
(component_launcher_class, component_config) = (
38-
config_utils.find_component_launch_info(self._config,
39-
component))
40-
driver_args = data_types.DriverArgs(
41-
enable_cache=tfx_pipeline.enable_cache)
42-
metadata_connection = metadata.Metadata(
43-
tfx_pipeline.metadata_connection_config)
44-
component_launcher = component_launcher_class.create(
45-
component=component,
46-
pipeline_info=tfx_pipeline.pipeline_info,
47-
driver_args=driver_args,
48-
metadata_connection=metadata_connection,
49-
beam_pipeline_args=tfx_pipeline.beam_pipeline_args,
50-
additional_pipeline_args=tfx_pipeline
51-
.additional_pipeline_args,
52-
component_config=component_config)
53-
logger.info('Component %s is running.', component.id)
54-
component_launcher.launch()
55-
logger.info('Component %s is finished.', component.id)
44+
def run(self, pipeline: Union[pipeline_pb2.Pipeline,
45+
pipeline_py.Pipeline]) -> None:
46+
"""Runs given logical pipeline locally.
47+
48+
Args:
49+
pipeline: Logical pipeline containing pipeline args and components.
50+
"""
51+
# For CLI, while creating or updating pipeline, pipeline_args are extracted
52+
# and hence we avoid executing the pipeline.
53+
if 'TFX_JSON_EXPORT_PIPELINE_ARGS_PATH' in os.environ:
54+
return
55+
run_id = pipeline.pipeline_info.run_id
56+
57+
if isinstance(pipeline, pipeline_py.Pipeline):
58+
c = compiler.Compiler()
59+
pipeline = c.compile(pipeline)
60+
61+
# Substitute the runtime parameter to be a concrete run_id
62+
runtime_parameter_utils.substitute_runtime_parameter(
63+
pipeline, {
64+
constants.PIPELINE_RUN_ID_PARAMETER_NAME: run_id
65+
})
66+
67+
deployment_config = runner_utils.extract_local_deployment_config(
68+
pipeline)
69+
connection_config = deployment_config.metadata_connection_config
70+
71+
logging.info('Running pipeline:\n %s', pipeline)
72+
logging.info('Using deployment config:\n %s', deployment_config)
73+
logging.info('Using connection config:\n %s', connection_config)
74+
75+
with telemetry_utils.scoped_labels(
76+
{telemetry_utils.LABEL_TFX_RUNNER: 'local'}):
77+
# Run each component. Note that the pipeline.components list is in
78+
# topological order.
79+
# TODO(b/171319478): After IR-based execution is used, used multi-threaded
80+
# execution so that independent components can be run in parallel.
81+
for node in pipeline.nodes:
82+
pipeline_node = node.pipeline_node
83+
node_id = pipeline_node.node_info.id
84+
executor_spec = runner_utils.extract_executor_spec(
85+
deployment_config, node_id)
86+
custom_driver_spec = runner_utils.extract_custom_driver_spec(
87+
deployment_config, node_id)
88+
89+
component_launcher = launcher.Launcher(
90+
pipeline_node=pipeline_node,
91+
mlmd_connection=metadata.Metadata(connection_config),
92+
pipeline_info=pipeline.pipeline_info,
93+
pipeline_runtime_spec=pipeline.runtime_spec,
94+
executor_spec=executor_spec,
95+
custom_driver_spec=custom_driver_spec)
96+
logging.info('Component %s is running.', node_id)
97+
component_launcher.launch()
98+
logging.info('Component %s is finished.', node_id)

zenml/backends/processing/processing_spark_backend.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# permissions and limitations under the License.
1414
"""Definition of the Spark Processing Backend"""
1515

16+
import multiprocessing
17+
from typing import Text, Optional, List
18+
1619
from zenml.backends.processing import ProcessingBaseBackend
1720

1821

@@ -30,6 +33,39 @@ class ProcessingSparkBackend(ProcessingBaseBackend):
3033
This backend is not implemented yet.
3134
"""
3235

33-
def __init__(self, **kwargs):
34-
super().__init__(**kwargs)
35-
raise NotImplementedError('Its coming soon!')
36+
def __init__(self,
37+
spark_rest_url: Text,
38+
environment_type: Text = 'LOOPBACK',
39+
environment_cache_millis: int = 1000000,
40+
spark_submit_uber_jar: bool = True):
41+
42+
self.spark_rest_url = spark_rest_url
43+
self.environment_type = environment_type
44+
self.environment_cache_millis = environment_cache_millis
45+
self.spark_submit_uber_jar = spark_submit_uber_jar
46+
47+
try:
48+
parallelism = multiprocessing.cpu_count()
49+
except NotImplementedError:
50+
parallelism = 1
51+
self.sdk_worker_parallelism = parallelism
52+
53+
super(ProcessingSparkBackend, self).__init__(
54+
environment_type=environment_type,
55+
environment_cache_millis=environment_cache_millis,
56+
spark_submit_uber_jar=spark_submit_uber_jar,
57+
spark_rest_url=self.spark_rest_url)
58+
59+
def get_beam_args(self,
60+
pipeline_name: Text = None,
61+
pipeline_root: Text = None) -> Optional[List[Text]]:
62+
63+
return [
64+
'--runner=SparkRunner',
65+
'--spark_rest_url=' + self.spark_rest_url,
66+
'--environment_type=' + self.environment_type,
67+
'--environment_cache_millis=' + str(self.environment_cache_millis),
68+
'--sdk_worker_parallelism=' + str(self.sdk_worker_parallelism),
69+
'--experiments=use_loopback_process_worker=True',
70+
'--experiments=pre_optimize=all',
71+
'--spark_submit_uber_jar']

zenml/components/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,3 @@
1919
from zenml.components.split_gen.component import SplitGen
2020
from zenml.components.tokenizer.component import Tokenizer
2121
from zenml.components.trainer.component import Trainer
22-
from zenml.components.transform_simple.component import SimpleTransform

zenml/components/bulk_inferrer/component.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
class BulkInferrerSpec(ComponentSpec):
3030
PARAMETERS = {
3131
StepKeys.SOURCE: ExecutionParameter(type=Text),
32-
StepKeys.ARGS: ExecutionParameter(type=Dict[Text, Any]),
32+
StepKeys.ARGS: ExecutionParameter(type=Text),
3333
}
3434
INPUTS = {
3535
MODEL: ChannelParameter(type=standard_artifacts.Model, optional=True),

zenml/components/bulk_inferrer/executor.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""ZenML bulk inferrer executor."""
15-
15+
import json
1616
from typing import Any, Dict, List, Text
1717
from typing import Optional
1818

1919
import apache_beam as beam
2020
from absl import logging
2121
from tfx import types
22+
from tfx.components.bulk_inferrer.executor import _RunInference
2223
from tfx.components.util import model_utils
2324
from tfx.dsl.components.base import base_executor
2425
from tfx.proto import bulk_inferrer_pb2
2526
from tfx.types import artifact_utils
2627
from tfx.utils import path_utils
2728
from tfx_bsl.public.proto import model_spec_pb2
28-
from tfx.components.bulk_inferrer.executor import _RunInference
2929

30-
from zenml.components.bulk_inferrer.utils import convert_to_dict
3130
from zenml.components.bulk_inferrer.constants import MODEL, EXAMPLES, \
3231
MODEL_BLESSING, PREDICTIONS
32+
from zenml.components.bulk_inferrer.utils import convert_to_dict
3333
from zenml.standards.standard_keys import StepKeys
3434
from zenml.steps.inferrer import BaseInferrer
3535
from zenml.utils import source_utils
@@ -58,7 +58,7 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
5858
self._log_startup(input_dict, output_dict, exec_properties)
5959

6060
source = exec_properties[StepKeys.SOURCE]
61-
args = exec_properties[StepKeys.ARGS]
61+
args = json.loads(exec_properties[StepKeys.ARGS])
6262
c = source_utils.load_source_path_class(source)
6363
inferrer_step: BaseInferrer = c(**args)
6464

@@ -148,12 +148,12 @@ def _run_model_inference(
148148
logging.info('Path of output examples split `%s` is %s.',
149149
split, output_examples_split_uri)
150150
_ = (
151-
pipeline
152-
| 'RunInference[{}]'.format(split) >>
153-
_RunInference(example_uri, inference_endpoint)
154-
| 'ConvertToDict[{}]'.format(split) >>
155-
beam.Map(convert_to_dict, output_example_spec)
156-
| 'WriteOutput[{}]'.format(split) >>
157-
inferrer_step.write_inference_results())
151+
pipeline
152+
| 'RunInference[{}]'.format(split) >>
153+
_RunInference(example_uri, inference_endpoint)
154+
| 'ConvertToDict[{}]'.format(split) >>
155+
beam.Map(convert_to_dict, output_example_spec)
156+
| 'WriteOutput[{}]'.format(split) >>
157+
inferrer_step.write_inference_results())
158158

159159
logging.info('Output examples written to %s.', output_examples.uri)

zenml/components/data_gen/component.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class DataGenSpec(ComponentSpec):
1515
PARAMETERS = {
1616
StepKeys.NAME: ExecutionParameter(type=Text),
1717
StepKeys.SOURCE: ExecutionParameter(type=Text),
18-
StepKeys.ARGS: ExecutionParameter(type=Dict[Text, Any]),
18+
StepKeys.ARGS: ExecutionParameter(type=Text),
1919
}
2020
INPUTS = {}
2121
OUTPUTS = {
@@ -31,7 +31,6 @@ def __init__(self,
3131
name: Text,
3232
source: Text,
3333
source_args: Dict[Text, Any],
34-
instance_name: Optional[Text] = None,
3534
examples: Optional[ChannelParameter] = None):
3635
"""
3736
Interface for all DataGen components, the main component responsible
@@ -53,5 +52,4 @@ def __init__(self,
5352
args=source_args,
5453
examples=examples)
5554

56-
super(DataGen, self).__init__(spec=spec,
57-
instance_name=instance_name)
55+
super(DataGen, self).__init__(spec=spec)

zenml/components/data_gen/executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
1212
# or implied. See the License for the specific language governing
1313
# permissions and limitations under the License.
14-
14+
import json
1515
from typing import Dict, Text, Any, List
1616

1717
from tfx import types
@@ -37,7 +37,7 @@ def Do(self,
3737
exec_properties:
3838
"""
3939
source = exec_properties[StepKeys.SOURCE]
40-
args = exec_properties[StepKeys.ARGS]
40+
args = json.loads(exec_properties[StepKeys.ARGS])
4141
name = exec_properties[StepKeys.NAME]
4242

4343
c = source_utils.load_source_path_class(source)

zenml/components/evaluator/component.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class ZenMLEvaluatorSpec(ComponentSpec):
1818
PARAMETERS = {constants.SOURCE: ExecutionParameter(type=Text),
19-
constants.ARGS: ExecutionParameter(type=Dict[Text, Any])}
19+
constants.ARGS: ExecutionParameter(Text)}
2020

2121
INPUTS = {constants.EXAMPLES: ChannelParameter(type=Examples),
2222
constants.MODEL: ChannelParameter(type=Model, optional=True),
@@ -42,7 +42,6 @@ def __init__(
4242
examples: types.Channel = None,
4343
model: types.Channel = None,
4444
output: Optional[types.Channel] = None,
45-
instance_name: Optional[Text] = None,
4645
schema: Optional[types.Channel] = None):
4746

4847
# Create the output artifact if not provided
@@ -55,4 +54,4 @@ def __init__(
5554
model=model,
5655
schema=schema,
5756
evaluation=evaluation)
58-
super(Evaluator, self).__init__(spec=spec, instance_name=instance_name)
57+
super(Evaluator, self).__init__(spec=spec)

zenml/components/evaluator/executor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from typing import Any, Dict, List, Text, Callable, Optional
23

34
import apache_beam as beam
@@ -49,7 +50,7 @@ def Do(self,
4950

5051
# Create the step with the schema attached if provided
5152
source = exec_properties[StepKeys.SOURCE]
52-
args = exec_properties[StepKeys.ARGS]
53+
args = json.loads(exec_properties[StepKeys.ARGS])
5354
c = source_utils.load_source_path_class(source)
5455
evaluator_step: BaseEvaluatorStep = c(**args)
5556

@@ -160,7 +161,7 @@ def Do(self,
160161
)
161162
examples_list.append(data)
162163
# Resolve custom extractors
163-
custom_extractors = try_get_fn(evaluator_step.CUSTOM_MODULE,
164+
custom_extractors = try_get_fn(evaluator_step.CUSTOM_MODULE or '',
164165
'custom_extractors')
165166
extractors = None
166167
if custom_extractors:
@@ -170,7 +171,7 @@ def Do(self,
170171
tensor_adapter_config=tensor_adapter_config)
171172

172173
# Resolve custom evaluators
173-
custom_evaluators = try_get_fn(evaluator_step.CUSTOM_MODULE,
174+
custom_evaluators = try_get_fn(evaluator_step.CUSTOM_MODULE or '',
174175
'custom_evaluators')
175176
evaluators = None
176177
if custom_evaluators:

0 commit comments

Comments
 (0)