Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from copy import deepcopy
import asyncio
from pathlib import Path
from typing import Any, Dict
import tomllib
Expand Down Expand Up @@ -244,6 +245,8 @@ def __init__(self):
self._defaults = {}
self._code_defaults = {}
self._defaults_loaded = False
self._loaded = False
self._load_lock = asyncio.Lock()

def register_defaults(self, defaults: Dict[str, Any]):
"""注册代码中定义的默认值"""
Expand Down Expand Up @@ -330,9 +333,20 @@ async def load(self):
)

self._config = merged
self._loaded = True
except Exception as e:
logger.error(f"Error loading config: {e}")
self._config = {}
self._loaded = False

async def ensure_loaded(self):
"""确保配置至少成功加载一次(按需懒加载,线程安全)"""
if self._loaded:
return
async with self._load_lock:
if self._loaded:
return
await self.load()

def get(self, key: str, default: Any = None) -> Any:
"""
Expand Down
11 changes: 8 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
if env_file.exists():
load_dotenv(env_file)

from fastapi import FastAPI # noqa: E402
from fastapi import FastAPI, Request # noqa: E402
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
from fastapi import Depends # noqa: E402

from app.core.auth import verify_api_key # noqa: E402
from app.core.config import get_config # noqa: E402
from app.core.config import config, get_config # noqa: E402
from app.core.logger import logger, setup_logging # noqa: E402
from app.core.exceptions import register_exception_handlers # noqa: E402
from app.core.response_middleware import ResponseLoggerMiddleware # noqa: E402
Expand Down Expand Up @@ -61,7 +61,7 @@ async def lifespan(app: FastAPI):
register_defaults(get_grok_defaults())

# 2. 加载配置
await config.load()
await config.ensure_loaded()

# 3. 启动服务显示
logger.info("Starting Grok2API...")
Expand Down Expand Up @@ -131,6 +131,11 @@ def create_app() -> FastAPI:
# 请求日志和 ID 中间件
app.add_middleware(ResponseLoggerMiddleware)

@app.middleware("http")
async def ensure_config_loaded(request: Request, call_next):
await config.ensure_loaded()
return await call_next(request)

# 注册异常处理器
register_exception_handlers(app)

Expand Down