Skip to content

Commit 45a04de

Browse files
committed
Properly escape interpolation-like strings in resolved configs
Fixes #1112 Fixes #1081
1 parent b155a60 commit 45a04de

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

omegaconf/_impl.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
_get_value,
1111
is_primitive_container,
1212
is_structured_config,
13+
maybe_escape,
1314
)
1415

1516

@@ -33,7 +34,7 @@ def _resolve_container_value(cfg: Container, key: Any) -> None:
3334
if isinstance(resolved, Container) and isinstance(node, ValueNode):
3435
cfg[key] = resolved
3536
else:
36-
node._set_value(_get_value(resolved))
37+
node._set_value(maybe_escape(_get_value(resolved)))
3738
else:
3839
_resolve(node)
3940

@@ -46,7 +47,7 @@ def _resolve(cfg: Node) -> Node:
4647
except InterpolationToMissingValueError:
4748
cfg._set_value(MISSING)
4849
else:
49-
cfg._set_value(resolved._value())
50+
cfg._set_value(maybe_escape(resolved._value()))
5051

5152
if isinstance(cfg, DictConfig):
5253
for k in cfg.keys():

omegaconf/_utils.py

+37
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,43 @@ def is_primitive_container(obj: Any) -> bool:
683683
return is_primitive_list(obj) or is_primitive_dict(obj)
684684

685685

686+
def maybe_escape(value: Any) -> Any:
687+
"""Escape interpolation strings and return other values unchanged.
688+
689+
When the input value is an interpolation string, the returned value is such that
690+
it yields the original input string when resolved.
691+
"""
692+
if not isinstance(value, str) or not _is_interpolation_string(
693+
value, strict_interpolation_validation=False
694+
):
695+
return value
696+
start = 0
697+
tokens = []
698+
while True:
699+
# Find next ${ that needs escaping.
700+
first_inter = value.find("${", start)
701+
if first_inter < 0:
702+
tokens.append(value[start:]) # ensure we keep the end of the string
703+
break
704+
# Any backslash that comes before ${ will need to be escaped as well.
705+
count_esc = 0
706+
while (
707+
first_inter - count_esc - 1 >= 0
708+
and value[first_inter - count_esc - 1] == "\\"
709+
):
710+
count_esc += 1
711+
tokens += [
712+
# Characters that need not be changed.
713+
value[start : first_inter - count_esc],
714+
# Escaped backslashes before the interpolation.
715+
"\\" * (count_esc * 2),
716+
# Escaped interpolation.
717+
"\\${",
718+
]
719+
start = first_inter + 2
720+
return "".join(tokens)
721+
722+
686723
def get_list_element_type(ref_type: Optional[Type[Any]]) -> Any:
687724
args = getattr(ref_type, "__args__", None)
688725
if ref_type is not List and args is not None and args[0]:

0 commit comments

Comments
 (0)