116 changes: 92 additions & 24 deletions sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py
@@ -66,6 +66,7 @@
import itertools
import logging
import opcode
import os
import pickle
from pickle import _getattribute as _pickle_getattribute
import platform
@@ -108,9 +109,28 @@ def uuid_generator(_):

@dataclasses.dataclass
class CloudPickleConfig:
"""Configuration for cloudpickle behavior."""
"""Configuration for cloudpickle behavior.

This class controls various aspects of how cloudpickle serializes objects.

Attributes:
id_generator: Callable that generates unique identifiers for dynamic
types, controlling whether isinstance semantics are preserved. If
None, type tracking is disabled and isinstance relationships are not
preserved across pickle/unpickle cycles; if callable, the generated
IDs are used to maintain object identity.
Default: uuid_generator (generates UUID hex strings).

skip_reset_dynamic_type_state: Whether to skip resetting state when
reconstructing dynamic types. If True, skips state reset for
already-reconstructed types.

filepath_interceptor: Optional callable used to rewrite file paths
stored in `co_filename` and in function.__globals__['__file__'].
"""
id_generator: typing.Optional[callable] = uuid_generator
skip_reset_dynamic_type_state: bool = False
filepath_interceptor: typing.Optional[callable] = None


DEFAULT_CONFIG = CloudPickleConfig()
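
[Note: a minimal usage sketch, not part of the diff; the strip_user_prefix helper is hypothetical, while CloudPickleConfig and the config parameter of dumps() are introduced by this change:]

import apache_beam.internal.cloudpickle.cloudpickle as cloudpickle

def strip_user_prefix(path):
    # Hypothetical interceptor: drop a machine-specific prefix so pickled
    # code objects carry stable, portable filenames.
    return path.removeprefix('/home/alice')

config = cloudpickle.CloudPickleConfig(filepath_interceptor=strip_user_prefix)

def add_one(x):
    return x + 1

# Both co_filename and __globals__['__file__'] pass through the interceptor.
payload = cloudpickle.dumps(add_one, config=config)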
@@ -396,6 +416,27 @@ def func():
return subimports


def get_relative_path(path):
"""Returns the path of a filename relative to the longest matching directory
in sys.path.

Args:
path: The path to the file.
"""
abs_path = os.path.abspath(path)
longest_match = ""

for dir_path in sys.path:
if not dir_path.endswith(os.path.sep):
dir_path += os.path.sep

if abs_path.startswith(dir_path) and len(dir_path) > len(longest_match):
longest_match = dir_path

if not longest_match:
return path
return os.path.relpath(abs_path, longest_match)

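[Note: a behavior sketch, not from the PR, assuming sys.path holds exactly the entries shown; on a real interpreter every sys.path entry participates in the longest-prefix match:]

import sys
sys.path = ['/opt/project', '/opt']

get_relative_path('/opt/project/pkg/module.py')  # -> 'pkg/module.py'
# '/opt/project/' is a longer matching prefix than '/opt/', so it wins.

get_relative_path('/elsewhere/module.py')  # -> '/elsewhere/module.py'
# No sys.path entry is a prefix of this path, so it is returned unchanged.
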

# relevant opcodes
STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"]
DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"]
@@ -608,7 +649,7 @@ def _make_typevar(
return _lookup_class_or_track(class_tracker_id, tv)


def _decompose_typevar(obj, config):
def _decompose_typevar(obj, config: CloudPickleConfig):
return (
obj.__name__,
obj.__bound__,
@@ -619,7 +660,7 @@ def _decompose_typevar(obj, config):
)


def _typevar_reduce(obj, config):
def _typevar_reduce(obj, config: CloudPickleConfig):
# TypeVar instances require the module information, hence we do not
# use _should_pickle_by_reference directly.
module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)
@@ -671,7 +712,7 @@ def _make_dict_items(obj, is_ordered=False):
# -------------------------------------------------


def _class_getnewargs(obj, config):
def _class_getnewargs(obj, config: CloudPickleConfig):
type_kwargs = {}
if "__module__" in obj.__dict__:
type_kwargs["__module__"] = obj.__module__
@@ -690,7 +731,7 @@ def _class_getnewargs(obj, config):
)


def _enum_getnewargs(obj, config):
def _enum_getnewargs(obj, config: CloudPickleConfig):
members = {e.name: e.value for e in obj}
return (
obj.__bases__,
@@ -831,7 +872,7 @@ def _enum_getstate(obj):
# these holes".


def _code_reduce(obj):
def _code_reduce(obj, config: CloudPickleConfig):
"""code object reducer."""
# If you are not sure about the order of arguments, take a look at help
# of the specific type from types, for example:
@@ -850,6 +891,11 @@ def _code_reduce(obj):
co_varnames = tuple(name for name in obj.co_varnames)
co_freevars = tuple(name for name in obj.co_freevars)
co_cellvars = tuple(name for name in obj.co_cellvars)

co_filename = obj.co_filename
if config and config.filepath_interceptor:
co_filename = config.filepath_interceptor(co_filename)

if hasattr(obj, "co_exceptiontable"):
# Python 3.11 and later: there are some new attributes
# related to the enhanced exceptions.
@@ -864,7 +910,7 @@ def _code_reduce(obj):
obj.co_consts,
co_names,
co_varnames,
obj.co_filename,
co_filename,
co_name,
obj.co_qualname,
obj.co_firstlineno,
@@ -887,7 +933,7 @@ def _code_reduce(obj):
obj.co_consts,
co_names,
co_varnames,
obj.co_filename,
co_filename,
co_name,
obj.co_firstlineno,
obj.co_linetable,
@@ -908,7 +954,7 @@ def _code_reduce(obj):
obj.co_code,
obj.co_consts,
co_varnames,
obj.co_filename,
co_filename,
co_name,
obj.co_firstlineno,
obj.co_lnotab,
@@ -932,7 +978,7 @@ def _code_reduce(obj):
obj.co_consts,
co_names,
co_varnames,
obj.co_filename,
co_filename,
co_name,
obj.co_firstlineno,
obj.co_lnotab,
@@ -1043,7 +1089,7 @@ def _weakset_reduce(obj):
return weakref.WeakSet, (list(obj), )


def _dynamic_class_reduce(obj, config):
def _dynamic_class_reduce(obj, config: CloudPickleConfig):
"""Save a class that can't be referenced as a module attribute.

This method is used to serialize classes that are defined inside
@@ -1074,7 +1120,7 @@ def _dynamic_class_reduce(obj, config):
)


def _class_reduce(obj, config):
def _class_reduce(obj, config: CloudPickleConfig):
"""Select the reducer depending on the dynamic nature of the class obj."""
if obj is type(None): # noqa
return type, (None, )
@@ -1169,7 +1215,7 @@ def _function_setstate(obj, state):
setattr(obj, k, v)


def _class_setstate(obj, state, skip_reset_dynamic_type_state):
def _class_setstate(obj, state, skip_reset_dynamic_type_state=False):
# Lock while potentially modifying class state.
with _DYNAMIC_CLASS_TRACKER_LOCK:
if skip_reset_dynamic_type_state and obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS:
@@ -1240,7 +1286,6 @@ class Pickler(pickle.Pickler):
_dispatch_table[property] = _property_reduce
_dispatch_table[staticmethod] = _classmethod_reduce
_dispatch_table[CellType] = _cell_reduce
_dispatch_table[types.CodeType] = _code_reduce
_dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce
_dispatch_table[types.ModuleType] = _module_reduce
_dispatch_table[types.MethodType] = _method_reduce
@@ -1300,9 +1345,15 @@ def _function_getnewargs(self, func):
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})

if base_globals == {}:
if "__file__" in func.__globals__:
# Apply normalization ONLY to the __file__ attribute
file_path = func.__globals__["__file__"]
if self.config.filepath_interceptor:
file_path = self.config.filepath_interceptor(file_path)
base_globals["__file__"] = file_path
# Add module attributes used to resolve relative imports
# instructions inside func.
for k in ["__package__", "__name__", "__path__", "__file__"]:
for k in ["__package__", "__name__", "__path__"]:
if k in func.__globals__:
base_globals[k] = func.__globals__[k]

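[Note: __file__ is dropped from the generic copy loop above because it is now populated separately, after passing through filepath_interceptor, so the intercepted value is what lands in base_globals.]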
@@ -1318,15 +1369,16 @@ def _function_getnewargs(self, func):
def dump(self, obj):
try:
return super().dump(obj)
except RuntimeError as e:
if len(e.args) > 0 and "recursion" in e.args[0]:
msg = "Could not pickle object as excessively deep recursion required."
raise pickle.PicklingError(msg) from e
else:
raise
except RecursionError as e:
[Review comment, Collaborator Author: This is a change from the cloudpickle lib.]
msg = "Could not pickle object as excessively deep recursion required."
raise pickle.PicklingError(msg) from e

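[Note: a minimal illustration, separate from the PR, of why RecursionError is the exception to catch here; on Python 3, exceeding the recursion limit raises RecursionError, a subclass of RuntimeError, rather than a bare RuntimeError carrying "recursion" in its message:]

import pickle

nested = []
node = nested
for _ in range(100_000):
    node.append([])
    node = node[0]

try:
    pickle.dumps(nested)  # the pickler recurses once per nesting level
except RecursionError as exc:
    assert isinstance(exc, RuntimeError)
    print("caught RecursionError")
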
def __init__(
self, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
self,
file,
protocol=None,
buffer_callback=None,
config: CloudPickleConfig = DEFAULT_CONFIG):
if protocol is None:
protocol = DEFAULT_PROTOCOL
super().__init__(file, protocol=protocol, buffer_callback=buffer_callback)
@@ -1405,6 +1457,8 @@ def reducer_override(self, obj):
return _class_reduce(obj, self.config)
elif isinstance(obj, typing.TypeVar): # Add this check
return _typevar_reduce(obj, self.config)
elif isinstance(obj, types.CodeType):
return _code_reduce(obj, self.config)
elif isinstance(obj, types.FunctionType):
return self._function_reduce(obj)
else:
@@ -1487,6 +1541,11 @@ def save_typevar(self, obj, name=None):

dispatch[typing.TypeVar] = save_typevar

def save_code(self, obj, name=None):
return self.save_reduce(*_code_reduce(obj, self.config), obj=obj)

dispatch[types.CodeType] = save_code

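[Note: CodeType is routed through reducer_override and save_code instead of the static _dispatch_table entry removed above so that _code_reduce can receive the per-pickler config, which filepath_interceptor needs in order to rewrite co_filename.]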
def save_function(self, obj, name=None):
"""Registered with the dispatch to handle all function types.

@@ -1532,7 +1591,12 @@ def save_pypy_builtin_func(self, obj):
# Shorthands similar to pickle.dump/pickle.dumps


def dump(obj, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
def dump(
obj,
file,
protocol=None,
buffer_callback=None,
config: CloudPickleConfig = DEFAULT_CONFIG):
"""Serialize obj as bytes streamed into file

protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1550,7 +1614,11 @@ def dump(obj, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
config=config).dump(obj)


def dumps(obj, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG):
def dumps(
obj,
protocol=None,
buffer_callback=None,
config: CloudPickleConfig = DEFAULT_CONFIG):
"""Serialize obj as a string of bytes allocated in memory

protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to