Skip to content

Commit 1e39db2

Browse files
authored
Add decorator syntax (#92)
1 parent 03c06de commit 1e39db2

File tree

4 files changed

+248
-14
lines changed

4 files changed

+248
-14
lines changed

datalayer_core/sdk/datalayer.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ def create_runtime(
220220
if name is None:
221221
name = f"runtime-{environment}-{uuid.uuid4()}"
222222

223+
# print(f"Runtime {name}")
224+
223225
if snapshot_name is not None:
224226
snapshots = self.list_snapshots()
225227
for snapshot in snapshots:
@@ -738,7 +740,9 @@ def _start(self) -> None:
738740

739741
if self._kernel_client is None:
740742
self._runtime = self._create_runtime(self._environment_name)
741-
runtime: dict[str, str] = self._runtime.get("runtime") # type: ignore
743+
# print(self._runtime)
744+
runtime: dict[str, str] = self._runtime["runtime"] # type: ignore
745+
# print("runtime", runtime)
742746
self._ingress = runtime["ingress"]
743747
self._kernel_token = runtime["token"]
744748
self._pod_name = runtime["pod_name"]
@@ -820,7 +824,10 @@ def set_variables(self, variables: dict[str, Any]) -> Response:
820824
return Response([])
821825

822826
def execute_file(
823-
self, path: Union[str, Path], variables: Optional[dict[str, Any]] = None
827+
self,
828+
path: Union[str, Path],
829+
variables: Optional[dict[str, Any]] = None,
830+
output: Optional[str] = None,
824831
) -> Response:
825832
"""
826833
Execute a Python file in the runtime.
@@ -831,6 +838,8 @@ def execute_file(
831838
Path to the Python file to execute.
832839
variables: Optional[dict[str, Any]]
833840
Optional variables to set before executing the code.
841+
output: Optional[str]
842+
Optional output variable to return as result.
834843
835844
Returns
836845
-------
@@ -841,18 +850,26 @@ def execute_file(
841850
if variables:
842851
self.set_variables(variables)
843852

844-
for _id, cell in _get_cells(fname):
845-
if self._kernel_client:
853+
if self._kernel_client:
854+
outputs = []
855+
for _id, cell in _get_cells(fname):
846856
reply = self._kernel_client.execute_interactive(
847857
cell,
848858
silent=False,
849859
)
850-
return Response(reply.get("outputs", []))
860+
outputs.append(reply.get("outputs", []))
861+
if output is not None:
862+
return self.get_variable(output)
863+
864+
return Response(outputs)
851865
return Response([])
852866

853867
def execute_code(
854-
self, code: str, variables: Optional[dict[str, Any]] = None
855-
) -> Response:
868+
self,
869+
code: str,
870+
variables: Optional[dict[str, Any]] = None,
871+
output: Optional[str] = None,
872+
) -> Union[Response, Any]:
856873
"""
857874
Execute code in the runtime.
858875
@@ -862,6 +879,8 @@ def execute_code(
862879
The Python code to execute.
863880
variables: Optional[dict[str, Any]]
864881
Optional variables to set before executing the code.
882+
output: Optional[str]
883+
Optional output variable to return as result.
865884
866885
Returns
867886
-------
@@ -874,6 +893,8 @@ def execute_code(
874893
self.set_variables(variables)
875894
reply = self._kernel_client.execute(code)
876895
result = reply.get("outputs", {})
896+
if output is not None:
897+
return self.get_variable(output)
877898
else:
878899
raise RuntimeError(
879900
"Kernel client is not started. Call `start()` first."
@@ -884,8 +905,11 @@ def execute_code(
884905
return Response([])
885906

886907
def execute(
887-
self, code_or_path: Union[str, Path], variables: Optional[dict[str, Any]] = None
888-
) -> Response:
908+
self,
909+
code_or_path: Union[str, Path],
910+
variables: Optional[dict[str, Any]] = None,
911+
output: Optional[str] = None,
912+
) -> Union[Response, Any]:
889913
"""
890914
Execute code in the runtime.
891915
@@ -895,10 +919,13 @@ def execute(
895919
The Python code or path to the file to execute.
896920
variables: Optional[dict[str, Any]]
897921
Optional variables to set before executing the code.
922+
output: Optional[str]
923+
Optional output variable to return as result.
898924
899925
Returns
900926
-------
901-
dict: The result of the code execution.
927+
dict:
928+
The result of the code execution.
902929
903930
904931
{
@@ -916,9 +943,13 @@ def execute(
916943
}
917944
"""
918945
if self._check_file(code_or_path):
919-
return self.execute_file(str(code_or_path), variables)
946+
return self.execute_file(
947+
str(code_or_path), variables=variables, output=output
948+
)
920949
else:
921-
return self.execute_code(str(code_or_path), variables)
950+
return self.execute_code(
951+
str(code_or_path), variables=variables, output=output
952+
)
922953

923954
def terminate(self) -> bool:
924955
"""Terminate the Runtime."""

datalayer_core/sdk/decorators.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2023-2025 Datalayer, Inc.
2+
# Distributed under the terms of the Modified BSD License.
3+
4+
import functools
5+
import inspect
6+
from typing import Any, Callable, Optional, Union
7+
8+
from datalayer_core.sdk.datalayer import DatalayerClient
9+
10+
# TODO:
11+
# - inputs are different from args and kwargs (rename)
12+
# - inputs cannot be kewyword args of the function
13+
# - incorrect number of args
14+
15+
16+
def datalayer(
17+
runtime_name: Union[Callable[..., Any], Optional[str]] = None,
18+
inputs: Optional[list[str]] = None,
19+
output: Optional[str] = None,
20+
snapshot_name: Optional[str] = None,
21+
) -> Any:
22+
"""
23+
Decorator to execute a function in a Datalayer runtime.
24+
25+
Parameters
26+
----------
27+
runtime_name : str, optional
28+
The name of the runtime to use. If not provided, a default runtime will be used.
29+
inputs : list[str], optional
30+
A list of input variable names for the function.
31+
output : str, optional
32+
The name of the output variable for the function
33+
snapshot_name : str, optional
34+
The name of the runtime snapshot to use
35+
36+
Returns
37+
-------
38+
Callable[..., Any]
39+
A decorator that wraps the function to be executed in a Datalayer runtime.
40+
41+
Examples
42+
--------
43+
44+
>>> from datalayer_core.sdk.decorators import datalayer
45+
>>> @datalayer
46+
... def example(x: float, y: float) -> float:
47+
... return x + y
48+
49+
>>> from datalayer_core.sdk.decorators import datalayer
50+
>>> @datalayer(runtime_name="example-runtime", inputs=["x", "y"], output="z")
51+
... def example(x: float, y: float) -> float:
52+
... return x + y
53+
"""
54+
variables = {}
55+
inputs_decorated = inputs
56+
output_decorated = output
57+
snapshot_name_decorated = snapshot_name
58+
59+
if callable(runtime_name):
60+
runtime_name_decorated = None
61+
else:
62+
runtime_name_decorated = runtime_name
63+
64+
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
65+
if output_decorated is None:
66+
output = f"DATALAYER_RUNTIME_OUTPUT_{func.__name__}".upper()
67+
68+
sig = inspect.signature(func)
69+
if inputs_decorated is None:
70+
inputs = []
71+
for name, _param in sig.parameters.items():
72+
inputs.append(name)
73+
variables[name] = (
74+
_param.default
75+
if _param.default is not inspect.Parameter.empty
76+
else None
77+
)
78+
else:
79+
if len(sig.parameters) != len(inputs_decorated):
80+
raise ValueError(
81+
f"Function {func.__name__} has {len(sig.parameters)} parameters, "
82+
f"but {len(inputs_decorated)} inputs were provided."
83+
)
84+
85+
@functools.wraps(func)
86+
def wrapper(*args: Any, **kwargs: Any) -> Any:
87+
sig = inspect.signature(func)
88+
mapping = {}
89+
for idx, (name, _param) in enumerate(sig.parameters.items()):
90+
mapping[name] = (inputs_decorated or inputs)[idx]
91+
92+
for kwarg, kwarg_value in kwargs.items():
93+
variables[mapping[kwarg]] = kwarg_value
94+
95+
for idx, (arg_value) in enumerate(args):
96+
kwarg = (inputs_decorated or inputs)[idx]
97+
variables[kwarg] = arg_value
98+
99+
function_call = (
100+
f"{output_decorated or output} = {func.__name__}("
101+
+ ", ".join(inputs_decorated or inputs)
102+
+ ")"
103+
)
104+
105+
start = 0
106+
func_source_lines = inspect.getsource(func).split("\n")
107+
for start, line in enumerate(func_source_lines):
108+
if line.startswith("def "):
109+
break
110+
function_source = "\n".join(func_source_lines[start:])
111+
112+
# print("inputs", inputs_decorated or inputs)
113+
# print("variables", variables)
114+
# print([function_source])
115+
# print([function_call])
116+
117+
client = DatalayerClient()
118+
with client.create_runtime(
119+
name=runtime_name_decorated, snapshot_name=snapshot_name_decorated
120+
) as runtime:
121+
runtime.execute(function_source)
122+
return runtime.execute(
123+
function_call,
124+
variables=variables,
125+
output=output_decorated or output,
126+
)
127+
128+
return wrapper
129+
130+
# print(f"Using runtime: {runtime_name}, inputs: {inputs}, output: {output}")
131+
if callable(runtime_name):
132+
return decorator(runtime_name)
133+
else:
134+
return decorator
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2023-2025 Datalayer, Inc.
2+
# Distributed under the terms of the Modified BSD License.
3+
4+
import os
5+
import time
6+
7+
import pytest
8+
from dotenv import load_dotenv
9+
10+
from datalayer_core.sdk.decorators import datalayer
11+
12+
load_dotenv()
13+
14+
15+
DATALAYER_TEST_TOKEN = os.environ.get("DATALAYER_TEST_TOKEN")
16+
17+
18+
def sum_test(x: float, y: float, z: float = 1) -> float:
19+
return x + y + z
20+
21+
22+
@pytest.mark.parametrize(
23+
"args,expected_output,decorator",
24+
[
25+
([1, 4.5, 2], 7.5, datalayer),
26+
([1, 4.5, 2], 7.5, datalayer(runtime_name="runtime-test")),
27+
([1, 4.5, 2], 7.5, datalayer(output="result")),
28+
([1, 4.5, 2], 7.5, datalayer(inputs=["a", "b", "c"])),
29+
],
30+
)
31+
@pytest.mark.skipif(
32+
not bool(DATALAYER_TEST_TOKEN),
33+
reason="DATALAYER_TEST_TOKEN is not set, skipping secret tests.",
34+
)
35+
def test_decorator(args, expected_output, decorator): # type: ignore
36+
"""
37+
Test the Datalayer decorator.
38+
"""
39+
time.sleep(10)
40+
func = decorator(sum_test)
41+
assert func(*args) == expected_output
42+
time.sleep(10)

examples/sdk.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,43 @@
11
# Copyright (c) 2023-2025 Datalayer, Inc.
22
# Distributed under the terms of the Modified BSD License.
3+
import inspect
34

45
from dotenv import load_dotenv
56

67
from datalayer_core import DatalayerClient
8+
from datalayer_core.sdk.decorators import datalayer
79

810
# Using .env file with DATALAYER_RUN_URL and DATALAYER_TOKEN defined
911
load_dotenv()
1012

11-
client = DatalayerClient()
12-
print(client.list_runtimes())
13+
14+
# @datalayer
15+
# @datalayer()
16+
# @datalayer(runtime_name="example-runtime")
17+
@datalayer(snapshot_name="snapshot-iris-model")
18+
# @datalayer(runtime_name="example-runtime", output="result")
19+
# @datalayer(runtime_name="example-runtime", inputs=["a", "b", "c"])
20+
def sum(x: float, y: float, z: int = 1) -> float:
21+
return x + y
22+
23+
24+
print([sum(1, 4.5, z=2)])
25+
26+
# sig = inspect.signature(example)
27+
# print("\nParameters:")
28+
# for name, param in sig.parameters.items():
29+
# print(f" Name: {name}")
30+
# print(f" Kind: {param.kind}")
31+
# print(f" Default Value: {param.default}")
32+
# print(f" Annotation: {param.annotation}")
33+
# print("---")
34+
35+
# print(client.list_runtimes())
1336
# with client.create_runtime() as runtime:
37+
# runtime.execute('x = 1')
38+
# runtime.execute('y = 4.5')
39+
# runtime.execute('def example(x: float, y: float) -> float:\n return x + y\n')
40+
# runtime.execute('print(example(x, y))')
1441
# response = runtime.execute("import os;print(len(os.environ['MY_SECRET']))")
1542
# print(response.stdout)
1643
# response = runtime.execute(

0 commit comments

Comments
 (0)