diff --git a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py index e4fbf0c72f87..b236949a24c3 100644 --- a/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py +++ b/sdks/python/apache_beam/internal/cloudpickle/cloudpickle.py @@ -66,6 +66,7 @@ import itertools import logging import opcode +import os import pickle from pickle import _getattribute as _pickle_getattribute import platform @@ -108,9 +109,28 @@ def uuid_generator(_): @dataclasses.dataclass class CloudPickleConfig: - """Configuration for cloudpickle behavior.""" + """Configuration for cloudpickle behavior. + + This class controls various aspects of how cloudpickle serializes objects. + + Attributes: + id_generator: Callable that generates unique identifiers for dynamic + types. Controls isinstance semantics preservation. If None, + disables type tracking and isinstance relationships are not + preserved across pickle/unpickle cycles. If callable, generates + unique IDs to maintain object identity. + Default: uuid_generator (generates UUID hex strings). + + skip_reset_dynamic_type_state: Whether to skip resetting state when + reconstructing dynamic types. If True, skips state reset for + already-reconstructed types. + + filepath_interceptor: Used to modify filepaths in `co_filename` and + function.__globals__['__file__']. + """ id_generator: typing.Optional[callable] = uuid_generator skip_reset_dynamic_type_state: bool = False + filepath_interceptor: typing.Optional[callable] = None DEFAULT_CONFIG = CloudPickleConfig() @@ -396,6 +416,27 @@ def func(): return subimports +def get_relative_path(path): + """Returns the path of a filename relative to the longest matching directory + in sys.path. + Args: + path: The path to the file. + """ + abs_path = os.path.abspath(path) + longest_match = "" + + for dir_path in sys.path: + if not dir_path.endswith(os.path.sep): + dir_path += os.path.sep + + if abs_path.startswith(dir_path) and len(dir_path) > len(longest_match): + longest_match = dir_path + + if not longest_match: + return path + return os.path.relpath(abs_path, longest_match) + + # relevant opcodes STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"] DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"] @@ -608,7 +649,7 @@ def _make_typevar( return _lookup_class_or_track(class_tracker_id, tv) -def _decompose_typevar(obj, config): +def _decompose_typevar(obj, config: CloudPickleConfig): return ( obj.__name__, obj.__bound__, @@ -619,7 +660,7 @@ def _decompose_typevar(obj, config): ) -def _typevar_reduce(obj, config): +def _typevar_reduce(obj, config: CloudPickleConfig): # TypeVar instances require the module information hence why we # are not using the _should_pickle_by_reference directly module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__) @@ -671,7 +712,7 @@ def _make_dict_items(obj, is_ordered=False): # ------------------------------------------------- -def _class_getnewargs(obj, config): +def _class_getnewargs(obj, config: CloudPickleConfig): type_kwargs = {} if "__module__" in obj.__dict__: type_kwargs["__module__"] = obj.__module__ @@ -690,7 +731,7 @@ def _class_getnewargs(obj, config): ) -def _enum_getnewargs(obj, config): +def _enum_getnewargs(obj, config: CloudPickleConfig): members = {e.name: e.value for e in obj} return ( obj.__bases__, @@ -831,7 +872,7 @@ def _enum_getstate(obj): # these holes". -def _code_reduce(obj): +def _code_reduce(obj, config: CloudPickleConfig): """code object reducer.""" # If you are not sure about the order of arguments, take a look at help # of the specific type from types, for example: @@ -850,6 +891,11 @@ def _code_reduce(obj): co_varnames = tuple(name for name in obj.co_varnames) co_freevars = tuple(name for name in obj.co_freevars) co_cellvars = tuple(name for name in obj.co_cellvars) + + co_filename = obj.co_filename + if (config and config.filepath_interceptor): + co_filename = config.filepath_interceptor(co_filename) + if hasattr(obj, "co_exceptiontable"): # Python 3.11 and later: there are some new attributes # related to the enhanced exceptions. @@ -864,7 +910,7 @@ def _code_reduce(obj): obj.co_consts, co_names, co_varnames, - obj.co_filename, + co_filename, co_name, obj.co_qualname, obj.co_firstlineno, @@ -887,7 +933,7 @@ def _code_reduce(obj): obj.co_consts, co_names, co_varnames, - obj.co_filename, + co_filename, co_name, obj.co_firstlineno, obj.co_linetable, @@ -908,7 +954,7 @@ def _code_reduce(obj): obj.co_code, obj.co_consts, co_varnames, - obj.co_filename, + co_filename, co_name, obj.co_firstlineno, obj.co_lnotab, @@ -932,7 +978,7 @@ def _code_reduce(obj): obj.co_consts, co_names, co_varnames, - obj.co_filename, + co_filename, co_name, obj.co_firstlineno, obj.co_lnotab, @@ -1043,7 +1089,7 @@ def _weakset_reduce(obj): return weakref.WeakSet, (list(obj), ) -def _dynamic_class_reduce(obj, config): +def _dynamic_class_reduce(obj, config: CloudPickleConfig): """Save a class that can't be referenced as a module attribute. This method is used to serialize classes that are defined inside @@ -1074,7 +1120,7 @@ def _dynamic_class_reduce(obj, config): ) -def _class_reduce(obj, config): +def _class_reduce(obj, config: CloudPickleConfig): """Select the reducer depending on the dynamic nature of the class obj.""" if obj is type(None): # noqa return type, (None, ) @@ -1169,7 +1215,7 @@ def _function_setstate(obj, state): setattr(obj, k, v) -def _class_setstate(obj, state, skip_reset_dynamic_type_state): +def _class_setstate(obj, state, skip_reset_dynamic_type_state=False): # Lock while potentially modifying class state. with _DYNAMIC_CLASS_TRACKER_LOCK: if skip_reset_dynamic_type_state and obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS: @@ -1240,7 +1286,6 @@ class Pickler(pickle.Pickler): _dispatch_table[property] = _property_reduce _dispatch_table[staticmethod] = _classmethod_reduce _dispatch_table[CellType] = _cell_reduce - _dispatch_table[types.CodeType] = _code_reduce _dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce _dispatch_table[types.ModuleType] = _module_reduce _dispatch_table[types.MethodType] = _method_reduce @@ -1300,9 +1345,15 @@ def _function_getnewargs(self, func): base_globals = self.globals_ref.setdefault(id(func.__globals__), {}) if base_globals == {}: + if "__file__" in func.__globals__: + # Apply normalization ONLY to the __file__ attribute + file_path = func.__globals__["__file__"] + if self.config.filepath_interceptor: + file_path = self.config.filepath_interceptor(file_path) + base_globals["__file__"] = file_path # Add module attributes used to resolve relative imports # instructions inside func. - for k in ["__package__", "__name__", "__path__", "__file__"]: + for k in ["__package__", "__name__", "__path__"]: if k in func.__globals__: base_globals[k] = func.__globals__[k] @@ -1318,15 +1369,16 @@ def _function_getnewargs(self, func): def dump(self, obj): try: return super().dump(obj) - except RuntimeError as e: - if len(e.args) > 0 and "recursion" in e.args[0]: - msg = "Could not pickle object as excessively deep recursion required." - raise pickle.PicklingError(msg) from e - else: - raise + except RecursionError as e: + msg = "Could not pickle object as excessively deep recursion required." + raise pickle.PicklingError(msg) from e def __init__( - self, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG): + self, + file, + protocol=None, + buffer_callback=None, + config: CloudPickleConfig = DEFAULT_CONFIG): if protocol is None: protocol = DEFAULT_PROTOCOL super().__init__(file, protocol=protocol, buffer_callback=buffer_callback) @@ -1405,6 +1457,8 @@ def reducer_override(self, obj): return _class_reduce(obj, self.config) elif isinstance(obj, typing.TypeVar): # Add this check return _typevar_reduce(obj, self.config) + elif isinstance(obj, types.CodeType): + return _code_reduce(obj, self.config) elif isinstance(obj, types.FunctionType): return self._function_reduce(obj) else: @@ -1487,6 +1541,11 @@ def save_typevar(self, obj, name=None): dispatch[typing.TypeVar] = save_typevar + def save_code(self, obj, name=None): + return self.save_reduce(*_code_reduce(obj, self.config), obj=obj) + + dispatch[types.CodeType] = save_code + def save_function(self, obj, name=None): """Registered with the dispatch to handle all function types. @@ -1532,7 +1591,12 @@ def save_pypy_builtin_func(self, obj): # Shorthands similar to pickle.dump/pickle.dumps -def dump(obj, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG): +def dump( + obj, + file, + protocol=None, + buffer_callback=None, + config: CloudPickleConfig = DEFAULT_CONFIG): """Serialize obj as bytes streamed into file protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to @@ -1550,7 +1614,11 @@ def dump(obj, file, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG): config=config).dump(obj) -def dumps(obj, protocol=None, buffer_callback=None, config=DEFAULT_CONFIG): +def dumps( + obj, + protocol=None, + buffer_callback=None, + config: CloudPickleConfig = DEFAULT_CONFIG): """Serialize obj as a string of bytes allocated in memory protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to