Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
786 changes: 784 additions & 2 deletions backend_py/primary/poetry.lock

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions backend_py/primary/primary/auth/enforce_logged_in_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
perf_metrics = PerfMetrics()

path_to_check = scope.get("path", "")

# Look for root_path path as specified when initializing FastAPI
# If there is one, strip it out before comparing paths
root_path = scope.get("root_path", "")
if root_path:
path_to_check = path_to_check.replace(root_path, "")

path_is_protected = True
if path_to_check in ["/login", "/auth-callback"] + self._unprotected_paths:
path_is_protected = False

for unprotected in ["/login", "/auth-callback"] + self._unprotected_paths:
if path_to_check.startswith(unprotected):
path_is_protected = False
break

if path_is_protected:
request = Request(scope, receive)
Expand Down
4 changes: 4 additions & 0 deletions backend_py/primary/primary/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@
DEFAULT_STALE_WHILE_REVALIDATE = 3600 * 24 # 24 hour
REDIS_USER_SESSION_URL = "redis://redis-user-session:6379"
REDIS_CACHE_URL = "redis://redis-cache:6379"

COSMOS_DB_PROD_CONNECTION_STRING = os.environ.get("WEBVIZ_DB_CONNECTION_STRING", None)
COSMOS_DB_EMULATOR_URI = "https://host.docker.internal:8081/"
COSMOS_DB_EMULATOR_KEY = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;"
12 changes: 11 additions & 1 deletion backend_py/primary/primary/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from starsessions.stores.redis import RedisStore
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from primary.persistence.setup_local_database import maybe_setup_local_database
from primary.auth.auth_helper import AuthHelper
from primary.auth.enforce_logged_in_middleware import EnforceLoggedInMiddleware
from primary.middleware.add_process_time_to_server_timing_middleware import AddProcessTimeToServerTimingMiddleware
Expand All @@ -34,6 +35,7 @@
from primary.routers.vfp.router import router as vfp_router
from primary.routers.well.router import router as well_router
from primary.routers.well_completions.router import router as well_completions_router
from primary.routers.persistence.router import router as persistence_router
from primary.services.sumo_access.sumo_fingerprinter import SumoFingerprinterFactory
from primary.services.utils.httpx_async_client_wrapper import HTTPX_ASYNC_CLIENT_WRAPPER
from primary.services.utils.task_meta_tracker import TaskMetaTrackerFactory
Expand All @@ -57,12 +59,19 @@
logging.getLogger("primary.routers.grid3d").setLevel(logging.DEBUG)
logging.getLogger("primary.routers.dev").setLevel(logging.DEBUG)
logging.getLogger("primary.routers.surface").setLevel(logging.DEBUG)
logging.getLogger("primary.persistence.cosmosdb").setLevel(logging.DEBUG)
logging.getLogger("primary.persistence.session_store").setLevel(logging.DEBUG)
logging.getLogger("primary.persistence.snapshot_store").setLevel(logging.DEBUG)
logging.getLogger("primary.persistence.tasks").setLevel(logging.DEBUG)
# logging.getLogger("primary.auth").setLevel(logging.DEBUG)
# logging.getLogger("uvicorn.error").setLevel(logging.DEBUG)
# logging.getLogger("uvicorn.access").setLevel(logging.DEBUG)

LOGGER = logging.getLogger(__name__)

# Setup Cosmos DB emulator database if running locally
maybe_setup_local_database()


def custom_generate_unique_id(route: APIRoute) -> str:
return f"{route.name}"
Expand Down Expand Up @@ -115,6 +124,7 @@ async def lifespan_handler_async(_fastapi_app: FastAPI) -> AsyncIterator[None]:
app.include_router(rft_router, prefix="/rft", tags=["rft"])
app.include_router(vfp_router, prefix="/vfp", tags=["vfp"])
app.include_router(dev_router, prefix="/dev", tags=["dev"], include_in_schema=False)
app.include_router(persistence_router, prefix="/persistence", tags=["persistence"])

auth_helper = AuthHelper()
app.include_router(auth_helper.router)
Expand All @@ -129,7 +139,7 @@ async def lifespan_handler_async(_fastapi_app: FastAPI) -> AsyncIterator[None]:

# Add out custom middleware to enforce that user is logged in
# Also redirects to /login endpoint for some select paths
unprotected_paths = ["/logout", "/logged_in_user", "/alive", "/openapi.json"]
unprotected_paths = ["/logout", "/logged_in_user", "/alive", "/openapi.json", "/persistence/snapshot_preview"]
paths_redirected_to_login = ["/", "/alive_protected"]

app.add_middleware(
Expand Down
Empty file.
18 changes: 18 additions & 0 deletions backend_py/primary/primary/persistence/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import hashlib
from typing import Any, cast


# Utility function to hash a JSON string using SHA-256
# This function mimics the behavior of TextEncoder in JavaScript, which encodes strings to
# UTF-8 before hashing. The output is a hexadecimal string representation of the hash.
#
# It is important that this function returns the same hash as the JavaScript version
def hash_sha256(json_string: str) -> str:
data = json_string.encode("utf-8") # Matches TextEncoder behavior
hash_bytes = hashlib.sha256(data).digest()
hash_hex = "".join(f"{b:02x}" for b in hash_bytes)
return hash_hex


def cast_query_params(params: list[dict[str, Any]]) -> list[dict[str, object]]:
return cast(list[dict[str, object]], params)
Empty file.
213 changes: 213 additions & 0 deletions backend_py/primary/primary/persistence/cosmosdb/cosmos_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import logging
from types import TracebackType
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from azure.cosmos.aio import ContainerProxy
from azure.cosmos import exceptions
from pydantic import BaseModel, ValidationError

from .cosmos_database import CosmosDatabase
from .exceptions import (
DatabaseAccessError,
DatabaseAccessNotFoundError,
DatabaseAccessConflictError,
DatabaseAccessPreconditionFailedError,
DatabaseAccessPermissionError,
DatabaseAccessThrottledError,
DatabaseAccessTransportError,
)


logger = logging.getLogger(__name__)

T = TypeVar("T", bound=BaseModel)

"""
CosmosContainer provides access to a specific container in a Cosmos DB database.
It allows for querying, inserting, updating, and deleting items in the container.
It uses a Pydantic model for item validation and serialization.

It is designed to be used with asynchronous context management, ensuring proper resource cleanup.
"""


class CosmosContainer(Generic[T]):
def __init__(
self,
database_name: str,
container_name: str,
database: CosmosDatabase,
container: ContainerProxy,
validation_model: Type[T],
):
self._database_name = database_name
self._container_name = container_name
self._database = database
self._container = container
self._validation_model: Type[T] = validation_model

@classmethod
def create(cls, database_name: str, container_name: str, validation_model: Type[T]) -> "CosmosContainer[T]":
"""Create a CosmosContainer instance."""
database = CosmosDatabase.create(database_name)
container = database.get_container(container_name)
logger.debug("[CosmosContainer] Created for container '%s' in database '%s'", container_name, database_name)
return cls(database_name, container_name, database, container, validation_model)

async def __aenter__(self) -> "CosmosContainer[T]": # pylint: disable=C9001
return self

async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None: # pylint: disable=C9001
await self.close_async()

def _make_exception(self, op: str, exc: exceptions.CosmosHttpResponseError) -> DatabaseAccessError:
"""Map Cosmos error to a data-access exception with rich context and re-raise."""
headers = getattr(exc, "headers", {}) or {}
status = getattr(exc, "status_code", None)
# Cosmos uses x-ms-substatus for more detail (e.g., 1002)
substatus_raw = headers.get("x-ms-substatus")
try:
substatus = int(substatus_raw) if substatus_raw is not None else None
except ValueError:
substatus = None
activity_id = headers.get("x-ms-activity-id")

msg = (
f"[{op}] Cosmos error on {self._database_name}/{self._container_name}: "
f"{getattr(exc, 'message', None) or str(exc)} "
f"(status={status}, substatus={substatus}, activity_id={activity_id})"
)

# Log with stack trace
logger.exception(
"[CosmosContainer] %s",
msg,
extra={
"database": self._database_name,
"container": self._container_name,
"operation": op,
"status_code": status,
"sub_status": substatus,
"activity_id": activity_id,
},
)

if status == 404:
return DatabaseAccessNotFoundError(msg, status_code=status, sub_status=substatus, activity_id=activity_id)
if status == 409:
return DatabaseAccessConflictError(msg, status_code=status, sub_status=substatus, activity_id=activity_id)
if status == 412:
return DatabaseAccessPreconditionFailedError(
msg, status_code=status, sub_status=substatus, activity_id=activity_id
)
if status in (401, 403):
return DatabaseAccessPermissionError(msg, status_code=status, sub_status=substatus, activity_id=activity_id)
if status in (429, 503):
# Typically retryable
return DatabaseAccessThrottledError(msg, status_code=status, sub_status=substatus, activity_id=activity_id)

# Fallback
return DatabaseAccessTransportError(msg, status_code=status, sub_status=substatus, activity_id=activity_id)

async def query_items_async(self, query: str, parameters: Optional[List[Dict[str, object]]] = None) -> List[T]:
try:
items_iterable = self._container.query_items(
query=query,
parameters=parameters or [],
)
items = [item async for item in items_iterable]
return [self._validation_model.model_validate(item) for item in items]
except ValidationError as validation_error:
logger.error("[CosmosContainer] Validation error in '%s': %s", self._container_name, validation_error)
raise
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("query_items_async", error)

async def get_item_async(self, item_id: str, partition_key: str) -> T:
try:
item = await self._container.read_item(item=item_id, partition_key=partition_key)
return self._validation_model.model_validate(item)
except ValidationError as validation_error:
logger.error("[CosmosContainer] Validation error in '%s': %s", self._container_name, validation_error)
raise
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("get_item_async", error) from error

async def insert_item_async(self, item: T) -> str:
try:
body: Dict[str, Any] = self._validation_model.model_validate(item).model_dump(by_alias=True, mode="json")
result = await self._container.create_item(body)
return result["id"]
except ValidationError as validation_error:
logger.error("[CosmosContainer] Validation error in '%s': %s", self._container_name, validation_error)
raise
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("insert_item_async", error) from error

async def delete_item_async(self, item_id: str, partition_key: str) -> None:
try:
await self._container.delete_item(item=item_id, partition_key=partition_key)
logger.debug("[CosmosContainer] Deleted item '%s' from '%s'", item_id, self._container_name)
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("delete_item_async", error) from error

async def update_item_async(self, item_id: str, partition_key: str, updated_item: T) -> None:
try:
validated = self._validation_model.model_validate(updated_item).model_dump(by_alias=True, mode="json")

if validated.get("id") and validated["id"] != item_id:
raise ValueError(f"id mismatch: payload id {validated['id']} != path id {item_id}")

await self._container.replace_item(item=item_id, body=validated, partition_key=partition_key)

logger.debug("[CosmosContainer] Updated item '%s' in '%s'", item_id, self._container_name)
except ValidationError as validation_error:
logger.error("[CosmosContainer] Validation error in '%s': %s", self._container_name, validation_error)
raise
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("update_item_async", error) from error

async def patch_item_async(
self,
item_id: str,
partition_key: str,
patch_operations: list[dict],
*,
filter_predicate: str | None = None,
) -> None:
try:
await self._container.patch_item(
item=item_id,
partition_key=partition_key,
patch_operations=patch_operations,
filter_predicate=filter_predicate,
no_response=True,
)
logger.debug("[CosmosContainer] Patched item '%s' in '%s'", item_id, self._container_name)
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("patch_item_async", error) from error

async def query_projection_async(
self,
query: str,
parameters: Optional[List[dict]] = None,
) -> List[Dict[str, Any]]:
"""
Run a query that returns raw dicts (no Pydantic validation), useful for
projections like SELECT c.id, c.partitionKey.
"""
try:
items_iterable = self._container.query_items(
query=query,
parameters=parameters or [],
)
return [item async for item in items_iterable]
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception("query_projection_async", error) from error

async def close_async(self) -> None:
"""Close the container."""
if self._database:
logger.debug("[CosmosContainer] Closing container '%s/%s'", self._database_name, self._container_name)
await self._database.close_async()
54 changes: 54 additions & 0 deletions backend_py/primary/primary/persistence/cosmosdb/cosmos_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from types import TracebackType
from typing import Optional, Type
from azure.cosmos.aio import CosmosClient, ContainerProxy
from azure.cosmos import exceptions

from primary.config import COSMOS_DB_PROD_CONNECTION_STRING, COSMOS_DB_EMULATOR_URI, COSMOS_DB_EMULATOR_KEY
from primary.services.service_exceptions import Service, ServiceRequestError


class CosmosDatabase:
def __init__(self, database_name: str, client: CosmosClient):
self._database_name = database_name
self._client = client
self._database = self._client.get_database_client(database_name)

@classmethod
def create(cls, database_name: str) -> "CosmosDatabase":
if COSMOS_DB_PROD_CONNECTION_STRING:
client = CosmosClient.from_connection_string(COSMOS_DB_PROD_CONNECTION_STRING)
elif COSMOS_DB_EMULATOR_URI and COSMOS_DB_EMULATOR_KEY:
client = CosmosClient(COSMOS_DB_EMULATOR_URI, COSMOS_DB_EMULATOR_KEY, connection_verify=False)
else:
raise ServiceRequestError(
"No Cosmos DB production connection string or emulator URI/key provided.", Service.DATABASE
)
self = cls(database_name, client)
return self

async def __aenter__(self) -> "CosmosDatabase": # pylint: disable=C9001
return self

async def __aexit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
) -> None: # pylint: disable=C9001
await self.close_async()

def _make_exception(self, message: str) -> ServiceRequestError:
return ServiceRequestError(f"CosmosDatabase ({self._database_name}): {message}", Service.DATABASE)

def get_container(self, container_name: str) -> ContainerProxy:
if not self._client or not self._database:
raise self._make_exception("Database client is not initialized or already closed.")
if not container_name or not isinstance(container_name, str):
raise self._make_exception("Invalid container name.")

try:
container = self._database.get_container_client(container_name)
return container
except exceptions.CosmosHttpResponseError as error:
raise self._make_exception(f"Unable to access container '{container_name}': {error.message}") from error

async def close_async(self) -> None:
if self._client:
await self._client.close()
Loading
Loading