Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ def _setup_state(self) -> None:
async_mode="asgi",
cors_allowed_origins=(
"*"
if config.cors_allowed_origins == ["*"]
else config.cors_allowed_origins
if config.cors_allowed_origins == ("*",)
else list(config.cors_allowed_origins)
),
cors_credentials=True,
max_http_buffer_size=environment.REFLEX_SOCKET_MAX_HTTP_BUFFER_SIZE.get(),
Expand Down
150 changes: 96 additions & 54 deletions reflex/config.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
"""The Reflex config."""

from __future__ import annotations

import dataclasses
import importlib
import os
import sys
import threading
import urllib.parse
from collections.abc import Sequence
from importlib.util import find_spec
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar

import pydantic.v1 as pydantic
from typing import TYPE_CHECKING, Any, ClassVar

from reflex import constants
from reflex.base import Base
from reflex.constants.base import LogLevel
from reflex.environment import EnvironmentVariables as EnvironmentVariables
from reflex.environment import EnvVar as EnvVar
Expand All @@ -30,10 +27,10 @@
from reflex.plugins import Plugin
from reflex.utils import console
from reflex.utils.exceptions import ConfigError
from reflex.utils.types import true_type_for_pydantic_field


class DBConfig(Base):
@dataclasses.dataclass(kw_only=True)
class DBConfig:
"""Database config."""

engine: str
Expand All @@ -51,7 +48,7 @@ def postgresql(
password: str | None = None,
host: str | None = None,
port: int | None = 5432,
) -> DBConfig:
) -> "DBConfig":
"""Create an instance with postgresql engine.

Args:
Expand Down Expand Up @@ -81,7 +78,7 @@ def postgresql_psycopg(
password: str | None = None,
host: str | None = None,
port: int | None = 5432,
) -> DBConfig:
) -> "DBConfig":
"""Create an instance with postgresql+psycopg engine.

Args:
Expand All @@ -107,7 +104,7 @@ def postgresql_psycopg(
def sqlite(
cls,
database: str,
) -> DBConfig:
) -> "DBConfig":
"""Create an instance with sqlite engine.

Args:
Expand Down Expand Up @@ -145,32 +142,9 @@ def get_url(self) -> str:
_sensitive_env_vars = {"DB_URL", "ASYNC_DB_URL", "REDIS_URL"}


class Config(Base):
"""The config defines runtime settings for the app.

By default, the config is defined in an `rxconfig.py` file in the root of the app.

```python
# rxconfig.py
import reflex as rx

config = rx.Config(
app_name="myapp",
api_url="http://localhost:8000",
)
```

Every config value can be overridden by an environment variable with the same name in uppercase.
For example, `db_url` can be overridden by setting the `DB_URL` environment variable.

See the [configuration](https://reflex.dev/docs/getting-started/configuration/) docs for more info.
"""

class Config: # pyright: ignore [reportIncompatibleVariableOverride]
"""Pydantic config for the config."""

validate_assignment = True
use_enum_values = False
@dataclasses.dataclass(kw_only=True)
class BaseConfig:
"""Base config for the Reflex app."""

# The name of the app (should match the name of the app directory).
app_name: str
Expand Down Expand Up @@ -218,13 +192,13 @@ class Config: # pyright: ignore [reportIncompatibleVariableOverride]
static_page_generation_timeout: int = 60

# List of origins that are allowed to connect to the backend API.
cors_allowed_origins: list[str] = ["*"]
cors_allowed_origins: Sequence[str] = dataclasses.field(default=("*",))

# Whether to use React strict mode.
react_strict_mode: bool = True

# Additional frontend packages to install.
frontend_packages: list[str] = []
frontend_packages: list[str] = dataclasses.field(default_factory=list)

# Indicate which type of state manager to use
state_manager_mode: constants.StateManagerMode = constants.StateManagerMode.DISK
Expand All @@ -239,7 +213,9 @@ class Config: # pyright: ignore [reportIncompatibleVariableOverride]
redis_token_expiration: int = constants.Expiration.TOKEN

# Attributes that were explicitly set by the user.
_non_default_attributes: set[str] = pydantic.PrivateAttr(set())
_non_default_attributes: set[str] = dataclasses.field(
default_factory=set, init=False
)

# Path to file containing key-values pairs to override in the environment; Dotenv format.
env_file: str | None = None
Expand All @@ -257,21 +233,50 @@ class Config: # pyright: ignore [reportIncompatibleVariableOverride]
extra_overlay_function: str | None = None

# List of plugins to use in the app.
plugins: list[Plugin] = []
plugins: list[Plugin] = dataclasses.field(default_factory=list)

_prefixes: ClassVar[list[str]] = ["REFLEX_"]

def __init__(self, *args, **kwargs):
"""Initialize the config values.

@dataclasses.dataclass(kw_only=True, init=False)
class Config(BaseConfig):
"""The config defines runtime settings for the app.

By default, the config is defined in an `rxconfig.py` file in the root of the app.

```python
# rxconfig.py
import reflex as rx

config = rx.Config(
app_name="myapp",
api_url="http://localhost:8000",
)
```

Every config value can be overridden by an environment variable with the same name in uppercase.
For example, `db_url` can be overridden by setting the `DB_URL` environment variable.

See the [configuration](https://reflex.dev/docs/getting-started/configuration/) docs for more info.
"""

def _post_init(self, **kwargs):
"""Post-initialization method to set up the config.

This method is called after the config is initialized. It sets up the
environment variables, updates the config from the environment, and
replaces default URLs if ports were set.

Args:
*args: The args to pass to the Pydantic init method.
**kwargs: The kwargs to pass to the Pydantic init method.
**kwargs: The kwargs passed to the Pydantic init method.

Raises:
ConfigError: If some values in the config are invalid.
"""
super().__init__(*args, **kwargs)
class_fields = self.class_fields()
for key, value in kwargs.items():
if key not in class_fields:
setattr(self, key, value)

# Clean up this code when we remove plain envvar in 0.8.0
env_loglevel = os.environ.get("REFLEX_LOGLEVEL")
Expand All @@ -287,7 +292,7 @@ def __init__(self, *args, **kwargs):

# Update default URLs if ports were set
kwargs.update(env_kwargs)
self._non_default_attributes.update(kwargs)
self._non_default_attributes = set(kwargs.keys())
self._replace_defaults(**kwargs)

if (
Expand All @@ -297,6 +302,41 @@ def __init__(self, *args, **kwargs):
msg = f"{self._prefixes[0]}REDIS_URL is required when using the redis state manager."
raise ConfigError(msg)

@classmethod
def class_fields(cls) -> set[str]:
"""Get the fields of the config class.

Returns:
The fields of the config class.
"""
return {field.name for field in dataclasses.fields(cls)}

if not TYPE_CHECKING:

def __init__(self, **kwargs):
"""Initialize the config values.

Args:
**kwargs: The kwargs to pass to the Pydantic init method.

# noqa: DAR101 self
"""
class_fields = self.class_fields()
super().__init__(**{k: v for k, v in kwargs.items() if k in class_fields})
self._post_init(**kwargs)

def json(self) -> str:
"""Get the config as a JSON string.

Returns:
The config as a JSON string.
"""
import json

from reflex.utils.serializers import serialize

return json.dumps(self, default=serialize)

@property
def app_module(self) -> ModuleType | None:
"""Return the app module if `app_module_import` is set.
Expand Down Expand Up @@ -333,31 +373,33 @@ def update_from_env(self) -> dict[str, Any]:

updated_values = {}
# Iterate over the fields.
for key, field in self.__fields__.items():
for field in dataclasses.fields(self):
# The env var name is the key in uppercase.
environment_variable = None
for prefix in self._prefixes:
if environment_variable := os.environ.get(f"{prefix}{key.upper()}"):
if environment_variable := os.environ.get(
f"{prefix}{field.name.upper()}"
):
break

# If the env var is set, override the config value.
if environment_variable and environment_variable.strip():
# Interpret the value.
value = interpret_env_var_value(
environment_variable,
true_type_for_pydantic_field(field),
field.type,
field.name,
)

# Set the value.
updated_values[key] = value
updated_values[field.name] = value

if key.upper() in _sensitive_env_vars:
if field.name.upper() in _sensitive_env_vars:
environment_variable = "***"

if value != getattr(self, key):
if value != getattr(self, field.name):
console.debug(
f"Overriding config value {key} with env var {key.upper()}={environment_variable}",
f"Overriding config value {field.name} with env var {field.name.upper()}={environment_variable}",
dedupe=True,
)
return updated_values
Expand Down
4 changes: 2 additions & 2 deletions reflex/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import multiprocessing
import os
import platform
from collections.abc import Callable
from collections.abc import Callable, Sequence
from functools import lru_cache
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -227,7 +227,7 @@ def interpret_env_var_value(
return interpret_existing_path_env(value, field_name)
if field_type is Plugin:
return interpret_plugin_env(value, field_name)
if get_origin(field_type) is list:
if get_origin(field_type) in (list, Sequence):
return [
interpret_env_var_value(
v,
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_tailwind.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def index():
app.add_page(index)
if not tailwind_version:
config = rx.config.get_config()
config.tailwind = None
config.plugins = []
elif tailwind_version == 3:
config = rx.config.get_config()
Expand Down
8 changes: 4 additions & 4 deletions tests/units/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

def test_requires_app_name():
"""Test that a config requires an app_name."""
with pytest.raises(ValueError):
rx.Config()
with pytest.raises(TypeError):
rx.Config() # pyright: ignore[reportCallIssue]


def test_set_app_name(base_config_values):
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_event_namespace(mocker: MockerFixture, kwargs, expected):
(
{"backend_port": 8001, "frontend_port": 3001},
{"REFLEX_BACKEND_PORT": 8002},
{"frontend_port": "3005"},
{"frontend_port": 3005},
{
"api_url": "http://localhost:8002",
"backend_port": 8002,
Expand All @@ -182,7 +182,7 @@ def test_event_namespace(mocker: MockerFixture, kwargs, expected):
(
{"api_url": "http://foo.bar:8900", "deploy_url": "http://foo.bar:3001"},
{"REFLEX_BACKEND_PORT": 8002},
{"frontend_port": "3005"},
{"frontend_port": 3005},
{
"api_url": "http://foo.bar:8900",
"backend_port": 8002,
Expand Down