diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 7f63b19a7..fe429e425 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -539,9 +539,10 @@ Configuration flags ------------------- OmegaConf support several configuration flags. -Configuration flags can be set on any configuration node (Sequence or Mapping). if a configuration flag is not set +Configuration flags can be set on any configuration node (Sequence or Mapping). If a configuration flag is not set it inherits the value from the parent of the node. The default value inherited from the root node is always false. +To avoid this inhertence and set a flag explicitly to all children nodes, specify the option `recursive=True`. .. _read-only-flag: @@ -570,6 +571,17 @@ You can temporarily remove the read only flag from a config object: >>> conf.a.b 20 +Example using `recursive=True`: + +.. doctest:: loaded + + >>> conf = OmegaConf.create({"a": {"b": 10}}) + >>> OmegaConf.set_readonly(conf.a, True) + >>> OmegaConf.set_readonly(conf, False, recursive = True) + >>> conf.a.b = 20 + >>> conf.a.b + 20 + .. _struct-flag: Struct flag diff --git a/news/1124.feature b/news/1124.feature new file mode 100644 index 000000000..474d4cabd --- /dev/null +++ b/news/1124.feature @@ -0,0 +1 @@ +OmegaConf.set_readonly() and OmegaConf.set_struct() have an option recursive, that explicitly sets the flag to all children nodes \ No newline at end of file diff --git a/omegaconf/base.py b/omegaconf/base.py index 77e951058..1e28f4128 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -152,6 +152,7 @@ def _set_flag( self, flags: Union[List[str], str], values: Union[List[Optional[bool]], Optional[bool]], + recursive: bool = False, ) -> "Node": if isinstance(flags, str): flags = [flags] @@ -175,6 +176,21 @@ def _set_flag( assert self._metadata.flags is not None self._metadata.flags[flag] = value self._invalidate_flags_cache() + + if recursive: + from . import DictConfig, ListConfig + + if isinstance(self, DictConfig): + for key in self.keys(): + child = self._get_child(key) + if child is not None: + child._set_flag(flags, values, recursive=recursive) # type: ignore + elif isinstance(self, ListConfig): + for index in range(len(self)): + child = self._get_child(index) + if child is not None: + child._set_flag(flags, values, recursive=recursive) # type: ignore + return self def _get_node_flag(self, flag: str) -> Optional[bool]: diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 8b130895b..c794d0f2f 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -518,9 +518,11 @@ def copy_cache(from_config: BaseContainer, to_config: BaseContainer) -> None: OmegaConf.set_cache(to_config, OmegaConf.get_cache(from_config)) @staticmethod - def set_readonly(conf: Node, value: Optional[bool]) -> None: + def set_readonly( + conf: Node, value: Optional[bool], recursive: bool = False + ) -> None: # noinspection PyProtectedMember - conf._set_flag("readonly", value) + conf._set_flag("readonly", value, recursive=recursive) @staticmethod def is_readonly(conf: Node) -> Optional[bool]: @@ -528,9 +530,11 @@ def is_readonly(conf: Node) -> Optional[bool]: return conf._get_flag("readonly") @staticmethod - def set_struct(conf: Container, value: Optional[bool]) -> None: + def set_struct( + conf: Container, value: Optional[bool], recursive: bool = False + ) -> None: # noinspection PyProtectedMember - conf._set_flag("struct", value) + conf._set_flag("struct", value, recursive=recursive) @staticmethod def is_struct(conf: Container) -> Optional[bool]: diff --git a/tests/test_base_config.py b/tests/test_base_config.py index 6129120c1..833ca0d4c 100644 --- a/tests/test_base_config.py +++ b/tests/test_base_config.py @@ -184,6 +184,25 @@ def test_set_flags() -> None: c._set_flag(["readonly", "struct"], [True, False, False]) +@mark.parametrize("config", [{"a": {"b": 1}}, {"a": [1, 2, 3]}]) +def test_set_flag_recursively(config: Any) -> None: + c = OmegaConf.create(config) + assert not c._get_flag("readonly") + assert not c._get_flag("struct") + c.a._set_flag(["readonly", "struct"], [True, True], recursive=False) + assert c.a._get_flag("readonly") + assert c.a._get_flag("struct") + assert not c._get_flag("readonly") + assert not c._get_flag("struct") + + c._set_flag("readonly", False, recursive=True) + assert not c.a._get_flag("readonly") + assert c.a._get_flag("struct") + + c._set_flag("struct", False, recursive=True) + assert not c.a._get_flag("struct") + + @mark.parametrize("no_deepcopy_set_nodes", [True, False]) @mark.parametrize("node", [20, {"b": 10}, [1, 2]]) def test_get_flag_after_dict_assignment(no_deepcopy_set_nodes: bool, node: Any) -> None: