Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions cloudpickle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,34 @@

import sys
import pickle
import types
import warnings

import cloudpickle.cloudpickle as cp


class CustomModuleType(types.ModuleType):
def __getattr__(self, name):
if name == 'Pickler':
warnings.warn(
'Pickler will point to Cloudpickler in two releases.',
FutureWarning
)
return self._Pickler
raise AttributeError

def __reduce__(self):
return __import__, ("cloudpickle.cloudpickle",)


cp.__class__ = CustomModuleType

if sys.version_info[:2] >= (3, 7):
def __getattr__(name):
return cp.__class__.__getattr__(cp, name)

from cloudpickle.cloudpickle import *

if sys.version_info[:2] >= (3, 8):
from cloudpickle.cloudpickle_fast import CloudPickler, dumps, dump

Expand Down
16 changes: 8 additions & 8 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from enum import Enum

from typing import Generic, Union, Tuple, Callable
from pickle import _Pickler as Pickler
from pickle import _Pickler
from pickle import _getattribute
from io import BytesIO
from importlib._bootstrap import _find_spec
Expand Down Expand Up @@ -460,21 +460,21 @@ def _create_parametrized_type_hint(origin, args):
return origin[args]


class CloudPickler(Pickler):
class CloudPickler(_Pickler):

dispatch = Pickler.dispatch.copy()
dispatch = _Pickler.dispatch.copy()

def __init__(self, file, protocol=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
_Pickler.__init__(self, file, protocol=protocol)
# map ids to dictionary. used to ensure that functions can share global env
self.globals_ref = {}

def dump(self, obj):
self.inject_addons()
try:
return Pickler.dump(self, obj)
return _Pickler.dump(self, obj)
except RuntimeError as e:
if 'recursion' in e.args[0]:
msg = """Could not pickle object as excessively deep recursion required."""
Expand Down Expand Up @@ -537,7 +537,7 @@ def save_function(self, obj, name=None):
interactive prompt, etc) and handles the pickling appropriately.
"""
if _is_importable_by_name(obj, name=name):
return Pickler.save_global(self, obj, name=name)
return _Pickler.save_global(self, obj, name=name)
elif PYPY and isinstance(obj.__code__, builtin_code_type):
return self.save_pypy_builtin_func(obj)
else:
Expand Down Expand Up @@ -839,11 +839,11 @@ def save_global(self, obj, name=None, pack=struct.pack):
# dispatch with type-specific savers.
self._save_parametrized_type_hint(obj)
elif name is not None:
Pickler.save_global(self, obj, name=name)
_Pickler.save_global(self, obj, name=name)
elif not _is_importable_by_name(obj, name=name):
self.save_dynamic_class(obj)
else:
Pickler.save_global(self, obj, name=name)
_Pickler.save_global(self, obj, name=name)

dispatch[type] = save_global

Expand Down
15 changes: 15 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import enum
import typing
import warnings
from functools import wraps

import pytest
Expand Down Expand Up @@ -56,6 +57,20 @@
_TEST_GLOBAL_VARIABLE = "default_value"


def test_future_warning_pickler():
# FutureWarning should be raised when accessing Pickler
with warnings.catch_warnings(record=True) as warning:
warnings.simplefilter("always")
pickler = cloudpickle.Pickler
assert len(warning) == 1
assert issubclass(warning[-1].category, FutureWarning)
assert "Pickler will point to Cloudpickler in two releases." \
in str(warning[-1].message)

# cloudpickle.Pickler should still be pointing to pickle._Pickler
assert pickler == pickle._Pickler


class RaiserOnPickle(object):

def __init__(self, exc):
Expand Down