Skip to content

Commit 686b224

Browse files
author
Nikita Sokolov
committed
Support persistent function.__globals__
1 parent cdc704d commit 686b224

File tree

2 files changed

+92
-61
lines changed

2 files changed

+92
-61
lines changed

cloudpickle/cloudpickle_fast.py

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -134,47 +134,6 @@ def _file_reconstructor(retval):
134134
return retval
135135

136136

137-
# COLLECTION OF OBJECTS STATE GETTERS
138-
# -----------------------------------
139-
def _function_getstate(func):
140-
# - Put func's dynamic attributes (stored in func.__dict__) in state. These
141-
# attributes will be restored at unpickling time using
142-
# f.__dict__.update(state)
143-
# - Put func's members into slotstate. Such attributes will be restored at
144-
# unpickling time by iterating over slotstate and calling setattr(func,
145-
# slotname, slotvalue)
146-
slotstate = {
147-
"__name__": func.__name__,
148-
"__qualname__": func.__qualname__,
149-
"__annotations__": func.__annotations__,
150-
"__kwdefaults__": func.__kwdefaults__,
151-
"__defaults__": func.__defaults__,
152-
"__module__": func.__module__,
153-
"__doc__": func.__doc__,
154-
"__closure__": func.__closure__,
155-
}
156-
157-
f_globals_ref = _extract_code_globals(func.__code__)
158-
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
159-
func.__globals__}
160-
161-
closure_values = (
162-
list(map(_get_cell_contents, func.__closure__))
163-
if func.__closure__ is not None else ()
164-
)
165-
166-
# Extract currently-imported submodules used by func. Storing these modules
167-
# in a smoke _cloudpickle_subimports attribute of the object's state will
168-
# trigger the side effect of importing these modules at unpickling time
169-
# (which is necessary for func to work correctly once depickled)
170-
slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
171-
func.__code__, itertools.chain(f_globals.values(), closure_values))
172-
slotstate["__globals__"] = f_globals
173-
174-
state = func.__dict__
175-
return state, slotstate
176-
177-
178137
def _class_getstate(obj):
179138
clsdict = _extract_class_dict(obj)
180139
clsdict.pop('__weakref__', None)
@@ -504,7 +463,7 @@ class CloudPickler(Pickler):
504463
def _dynamic_function_reduce(self, func):
505464
"""Reduce a function that is not pickleable via attribute lookup."""
506465
newargs = self._function_getnewargs(func)
507-
state = _function_getstate(func)
466+
state = self._function_getstate(func)
508467
return (types.FunctionType, newargs, state, None, None,
509468
_function_setstate)
510469

@@ -528,25 +487,28 @@ def _function_reduce(self, obj):
528487
def _function_getnewargs(self, func):
529488
code = func.__code__
530489

531-
# base_globals represents the future global namespace of func at
532-
# unpickling time. Looking it up and storing it in
533-
# CloudpiPickler.globals_ref allow functions sharing the same globals
534-
# at pickling time to also share them once unpickled, at one condition:
535-
# since globals_ref is an attribute of a CloudPickler instance, and
536-
# that a new CloudPickler is created each time pickle.dump or
537-
# pickle.dumps is called, functions also need to be saved within the
538-
# same invocation of cloudpickle.dump/cloudpickle.dumps (for example:
539-
# cloudpickle.dumps([f1, f2])). There is no such limitation when using
540-
# CloudPickler.dump, as long as the multiple invocations are bound to
541-
# the same CloudPickler.
542-
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
543-
544-
if base_globals == {}:
545-
# Add module attributes used to resolve relative imports
546-
# instructions inside func.
547-
for k in ["__package__", "__name__", "__path__", "__file__"]:
548-
if k in func.__globals__:
549-
base_globals[k] = func.__globals__[k]
490+
if self.persistent_id(func.__globals__) is None:
491+
# base_globals represents the future global namespace of func at
492+
# unpickling time. Looking it up and storing it in
493+
# CloudPickler.globals_ref allows functions sharing the same globals
494+
# at pickling time to also share them once unpickled, on one condition:
495+
# since globals_ref is an attribute of a CloudPickler instance, and
496+
# that a new CloudPickler is created each time pickle.dump or
497+
# pickle.dumps is called, functions also need to be saved within the
498+
# same invocation of cloudpickle.dump/cloudpickle.dumps (for example:
499+
# cloudpickle.dumps([f1, f2])). There is no such limitation when using
500+
# CloudPickler.dump, as long as the multiple invocations are bound to
501+
# the same CloudPickler.
502+
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
503+
504+
if base_globals == {}:
505+
# Add module attributes used to resolve relative imports
506+
# instructions inside func.
507+
for k in ["__package__", "__name__", "__path__", "__file__"]:
508+
if k in func.__globals__:
509+
base_globals[k] = func.__globals__[k]
510+
else:
511+
base_globals = func.__globals__
550512

551513
# Do not bind the free variables before the function is created to
552514
# avoid infinite recursion.
@@ -558,6 +520,50 @@ def _function_getnewargs(self, func):
558520

559521
return code, base_globals, None, None, closure
560522

523+
# COLLECTION OF OBJECTS STATE GETTERS
524+
# -----------------------------------
525+
def _function_getstate(self, func):
526+
# - Put func's dynamic attributes (stored in func.__dict__) in state. These
527+
# attributes will be restored at unpickling time using
528+
# f.__dict__.update(state)
529+
# - Put func's members into slotstate. Such attributes will be restored at
530+
# unpickling time by iterating over slotstate and calling setattr(func,
531+
# slotname, slotvalue)
532+
slotstate = {
533+
"__name__": func.__name__,
534+
"__qualname__": func.__qualname__,
535+
"__annotations__": func.__annotations__,
536+
"__kwdefaults__": func.__kwdefaults__,
537+
"__defaults__": func.__defaults__,
538+
"__module__": func.__module__,
539+
"__doc__": func.__doc__,
540+
"__closure__": func.__closure__,
541+
}
542+
543+
if self.persistent_id(func.__globals__) is None:
544+
f_globals_ref = _extract_code_globals(func.__code__)
545+
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
546+
func.__globals__}
547+
else:
548+
f_globals = func.__globals__
549+
550+
closure_values = (
551+
list(map(_get_cell_contents, func.__closure__))
552+
if func.__closure__ is not None else ()
553+
)
554+
555+
# Extract currently-imported submodules used by func. Storing these modules
556+
# in a smoke _cloudpickle_submodules attribute of the object's state will
557+
# trigger the side effect of importing these modules at unpickling time
558+
# (which is necessary for func to work correctly once unpickled)
559+
slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
560+
func.__code__, itertools.chain(f_globals.values(), closure_values))
561+
slotstate["__globals__"] = f_globals
562+
563+
state = func.__dict__
564+
return state, slotstate
565+
566+
561567
def dump(self, obj):
562568
try:
563569
return Pickler.dump(self, obj)

tests/cloudpickle_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,31 @@ def f1():
15911591
finally:
15921592
_TEST_GLOBAL_VARIABLE = orig_value
15931593

1594+
def test_persistent_function_globals(self):
1595+
__globals__ = {"a": "foo"}
1596+
1597+
class Pickler(cloudpickle.CloudPickler):
1598+
@staticmethod
1599+
def persistent_id(obj):
1600+
if id(obj) == id(__globals__):
1601+
return "__globals__"
1602+
1603+
class Unpickler(pickle.Unpickler):
1604+
@staticmethod
1605+
def persistent_load(pid):
1606+
return {"__globals__": __globals__}[pid]
1607+
1608+
get = eval('lambda: a', __globals__)
1609+
file = io.BytesIO()
1610+
Pickler(file).dump(get)
1611+
dumped = file.getvalue()
1612+
self.assertNotIn(b'foo', dumped)
1613+
get = Unpickler(io.BytesIO(dumped)).load()
1614+
self.assertEqual(id(__globals__), id(get.__globals__))
1615+
self.assertEqual('foo', get())
1616+
__globals__['a'] = 'bar'
1617+
self.assertEqual('bar', get())
1618+
15941619
def test_interactive_remote_function_calls(self):
15951620
code = """if __name__ == "__main__":
15961621
from testutils import subprocess_worker

0 commit comments

Comments
 (0)