Skip to content

Commit 2e100e3

Browse files
benedikt-bartschermasenf
authored andcommitted
port enum env var support from #4248 (#4251)
* port enum env var support from #4248 * add some tests for interpret env var functions
1 parent cdbe7f8 commit 2e100e3

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

reflex/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from __future__ import annotations
44

55
import dataclasses
6+
import enum
67
import importlib
8+
import inspect
79
import os
810
import sys
911
import urllib.parse
@@ -221,6 +223,28 @@ def interpret_path_env(value: str, field_name: str) -> Path:
221223
return path
222224

223225

226+
def interpret_enum_env(value: str, field_type: GenericType, field_name: str) -> Any:
227+
"""Interpret an enum environment variable value.
228+
229+
Args:
230+
value: The environment variable value.
231+
field_type: The field type.
232+
field_name: The field name.
233+
234+
Returns:
235+
The interpreted value.
236+
237+
Raises:
238+
EnvironmentVarValueError: If the value is invalid.
239+
"""
240+
try:
241+
return field_type(value)
242+
except ValueError as ve:
243+
raise EnvironmentVarValueError(
244+
f"Invalid enum value: {value} for {field_name}"
245+
) from ve
246+
247+
224248
def interpret_env_var_value(
225249
value: str, field_type: GenericType, field_name: str
226250
) -> Any:
@@ -252,6 +276,8 @@ def interpret_env_var_value(
252276
return interpret_int_env(value, field_name)
253277
elif field_type is Path:
254278
return interpret_path_env(value, field_name)
279+
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
280+
return interpret_enum_env(value, field_type, field_name)
255281

256282
else:
257283
raise ValueError(

tests/units/test_config.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77

88
import reflex as rx
99
import reflex.config
10-
from reflex.config import environment
11-
from reflex.constants import Endpoint
10+
from reflex.config import (
11+
environment,
12+
interpret_boolean_env,
13+
interpret_enum_env,
14+
interpret_int_env,
15+
)
16+
from reflex.constants import Endpoint, Env
1217

1318

1419
def test_requires_app_name():
@@ -208,11 +213,11 @@ def test_replace_defaults(
208213
assert getattr(c, key) == value
209214

210215

211-
def reflex_dir_constant():
216+
def reflex_dir_constant() -> Path:
212217
return environment.REFLEX_DIR
213218

214219

215-
def test_reflex_dir_env_var(monkeypatch, tmp_path):
220+
def test_reflex_dir_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
216221
"""Test that the REFLEX_DIR environment variable is used to set the Reflex.DIR constant.
217222
218223
Args:
@@ -224,3 +229,16 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path):
224229
mp_ctx = multiprocessing.get_context(method="spawn")
225230
with mp_ctx.Pool(processes=1) as pool:
226231
assert pool.apply(reflex_dir_constant) == tmp_path
232+
233+
234+
def test_interpret_enum_env() -> None:
235+
assert interpret_enum_env(Env.PROD.value, Env, "REFLEX_ENV") == Env.PROD
236+
237+
238+
def test_interpret_int_env() -> None:
239+
assert interpret_int_env("3001", "FRONTEND_PORT") == 3001
240+
241+
242+
@pytest.mark.parametrize("value, expected", [("true", True), ("false", False)])
243+
def test_interpret_bool_env(value: str, expected: bool) -> None:
244+
assert interpret_boolean_env(value, "TELEMETRY_ENABLED") == expected

0 commit comments

Comments
 (0)