Changes from all commits (52 commits)
76a09d4
increase in-memory caching retention time for repeated calls
Mesh-ach Oct 14, 2025
e23d6a0
feat: model version retrieval testing
Mesh-ach Oct 23, 2025
5763b79
feat: model version retrieval testing
Mesh-ach Oct 23, 2025
de8b250
feat: model version retrieval testing
Mesh-ach Oct 23, 2025
c0f7250
fix: formatting style
Mesh-ach Oct 23, 2025
5a97168
fix: formatting style
Mesh-ach Oct 23, 2025
2b51021
fix: formatting style
Mesh-ach Oct 23, 2025
7b8a777
fix: formatting style
Mesh-ach Oct 23, 2025
167112a
fix: formatting style
Mesh-ach Oct 23, 2025
bd44916
fix: formatting style
Mesh-ach Oct 23, 2025
3e39347
fix: formatting style
Mesh-ach Oct 23, 2025
87d2854
fix: formatting style
Mesh-ach Oct 23, 2025
d213a8c
fix: formatting style
Mesh-ach Oct 23, 2025
1969f73
fix: linting
Mesh-ach Oct 23, 2025
3d83f7f
Feat: Added backfill endpoint
Mesh-ach Oct 23, 2025
e1a687c
Feat: Added backfill endpoint
Mesh-ach Oct 23, 2025
1fe495d
Feat: Added backfill endpoint
Mesh-ach Oct 23, 2025
d92cea1
Fix: linting
Mesh-ach Oct 23, 2025
a176d62
added func description
Mesh-ach Oct 23, 2025
03f0275
Merge pull request #178 from datakind/BackfillEndpoint
Mesh-ach Oct 23, 2025
bdf1d47
added func description
Mesh-ach Oct 24, 2025
280df44
Merge pull request #179 from datakind/BackfillEndpoint
Mesh-ach Oct 24, 2025
ca3c4e5
added func description
Mesh-ach Oct 24, 2025
a391bcb
added func description
Mesh-ach Oct 24, 2025
36ec01e
added func description
Mesh-ach Oct 24, 2025
903e9d8
added func description
Mesh-ach Oct 24, 2025
d400f26
added func description
Mesh-ach Oct 24, 2025
9dc2513
feat: adjusted run output endpoint to return model_run_id
Mesh-ach Oct 27, 2025
f389b7d
Delete .DS_Store
Mesh-ach Oct 27, 2025
cd8189f
Delete src/.DS_Store
Mesh-ach Oct 27, 2025
20bd5f5
Delete terraform/.DS_Store
Mesh-ach Oct 27, 2025
dbb00ff
Merge pull request #180 from datakind/AdjustModelRunOutput
vishpillai123 Oct 27, 2025
94824d1
feat: added model deletion endpoint
Mesh-ach Nov 4, 2025
bd6cafe
feat: added model deletion endpoint
Mesh-ach Nov 4, 2025
8305181
feat: added model deletion endpoint
Mesh-ach Nov 4, 2025
9b9d8cd
fix: linting
Mesh-ach Nov 4, 2025
9792046
fix: linting
Mesh-ach Nov 4, 2025
5cfad35
fix: linting
Mesh-ach Nov 4, 2025
6d7682e
fix: linting
Mesh-ach Nov 4, 2025
1feabc7
fix: linting
Mesh-ach Nov 4, 2025
d2130b3
fix: linting
Mesh-ach Nov 4, 2025
b0f69a9
fix: linting
Mesh-ach Nov 4, 2025
3e0cb4b
Merge pull request #181 from datakind/ModelDeletionEndpoint
Mesh-ach Nov 5, 2025
5b0d590
fixed model name malformation
Mesh-ach Nov 5, 2025
82a2452
Merge pull request #182 from datakind/ModelDeletionEndpoint
Mesh-ach Nov 5, 2025
314ef2c
fix: removed databricks deletion functionality
Mesh-ach Nov 5, 2025
5647200
Merge pull request #183 from datakind/ModelDeletionEndpoint
Mesh-ach Nov 5, 2025
2319d7a
fix: removed query results not needed
Mesh-ach Nov 5, 2025
a41b71b
fix: removed query results not needed
Mesh-ach Nov 5, 2025
809d1db
fix: added status
Mesh-ach Nov 5, 2025
ae5fe8f
fix: added status
Mesh-ach Nov 5, 2025
b10ed71
fix: formatting fix
Mesh-ach Nov 5, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"pandera~=0.13",
"mlflow~=2.15.0",
"cachetools",
"types-cachetools",
]

[project.urls]
7 changes: 5 additions & 2 deletions src/webapp/database.py
@@ -551,8 +551,11 @@ class JobTable(Base):
String(VAR_CHAR_STANDARD_LENGTH), nullable=True
)
completed: Mapped[bool] = mapped_column(nullable=True)
framework: Mapped[str | None] = mapped_column(
String(VAR_CHAR_STANDARD_LENGTH), nullable=False, default="sklearn"
model_version: Mapped[str | None] = mapped_column(
String(VAR_CHAR_STANDARD_LENGTH), nullable=True
)
model_run_id: Mapped[str | None] = mapped_column(
String(VAR_CHAR_STANDARD_LENGTH), nullable=True
)
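
Not part of the diff: a minimal sketch of reading the two new nullable columns back through a session. Only JobTable, model_version, and model_run_id come from this change; the import path and session wiring are assumptions.

from sqlalchemy import select
from sqlalchemy.orm import Session
from webapp.database import JobTable  # assumed import path

def list_job_model_info(session: Session) -> list[dict]:
    # model_version and model_run_id are the columns added above; both are
    # nullable, so values may be None for jobs created before this change.
    rows = session.execute(
        select(JobTable.id, JobTable.model_version, JobTable.model_run_id)
    ).all()
    return [
        {"run_id": r.id, "model_version": r.model_version, "model_run_id": r.model_run_id}
        for r in rows
    ]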


76 changes: 65 additions & 11 deletions src/webapp/databricks.py
@@ -50,7 +50,6 @@ class DatabricksInferenceRunRequest(BaseModel):
# Note that the following should be the filepath.
filepath_to_type: dict[str, list[SchemaType]]
model_name: str
model_type: str
# The email where notifications will get sent.
email: str
gcp_external_bucket_name: str
@@ -89,10 +88,10 @@ def _sha256_json(obj: Any) -> str:
).hexdigest()


L1_RESP_CACHE_TTL = int("120") # seconds
L1_VER_CACHE_TTL = int("60") # seconds
L1_RESP_CACHE = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
L1_VER_CACHE = TTLCache(maxsize=256, ttl=L1_VER_CACHE_TTL)
L1_RESP_CACHE_TTL = int("600") # seconds
L1_VER_CACHE_TTL = int("3600") # seconds
L1_RESP_CACHE: Any = TTLCache(maxsize=128, ttl=L1_RESP_CACHE_TTL)
L1_VER_CACHE: Any = TTLCache(maxsize=256, ttl=L1_VER_CACHE_TTL)
_L1_LOCK = threading.RLock()
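
The hunk above only raises the two TTL constants. As a rough, self-contained sketch of the kind of lock-guarded TTL lookup these caches support (the module's real lookup code is not shown in this hunk, and the helper name below is hypothetical):

import threading
from typing import Any, Callable

from cachetools import TTLCache

def cached_call(cache: TTLCache, lock: Any, key: str, compute: Callable[[], Any]) -> Any:
    # Serve from the cache while the entry's TTL has not expired; otherwise
    # compute once, store, and return. Repeated calls within the retention
    # window (600s for responses, 3600s for versions above) hit the cache.
    with lock:
        if key in cache:
            return cache[key]
    value = compute()
    with lock:
        cache[key] = value
    return value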


@@ -252,7 +251,6 @@ def run_pdp_inference(
], # is this value the same PER environ? dev/staging/prod
"gcp_bucket_name": req.gcp_external_bucket_name,
"model_name": req.model_name,
"model_type": req.model_type,
"notification_email": req.email,
},
)
@@ -334,7 +332,7 @@ def fetch_table_data(
inst_name: str,
table_name: str,
warehouse_id: str,
) -> List[Dict[str, Any]]:
) -> Any:
"""
Execute SELECT * via Databricks SQL Statement Execution API using EXTERNAL_LINKS.
Blocks server-side for up to 30s; if not SUCCEEDED, raises. Downloads presigned
@@ -367,9 +365,9 @@

if not ver_resp.status or ver_resp.status.state != StatementState.SUCCEEDED:
raise TimeoutError("DESCRIBE HISTORY did not finish within 30s")
cols = [c.name for c in ver_resp.manifest.schema.columns]
cols = [c.name for c in ver_resp.manifest.schema.columns] # type: ignore
idx = {n: i for i, n in enumerate(cols)}
rows = ver_resp.result.data_array or []
rows = ver_resp.result.data_array or [] # type: ignore
if not rows or "version" not in idx:
raise ValueError("DESCRIBE HISTORY returned no version")
table_version = str(rows[0][idx["version"]])
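
Not shown in this hunk is how the module turns the DESCRIBE HISTORY result into a cache key. A standalone sketch of the apparent idea, hashing the table identity, its Delta version, and the statement so a version bump naturally misses the cache (names here are illustrative; only table_version comes from the code above):

import hashlib
import json

def versioned_cache_key(full_table_name: str, table_version: str, statement: str) -> str:
    # Hash every input that determines the query result, including the Delta
    # table version from DESCRIBE HISTORY, so stale entries are never reused
    # after the table changes.
    payload = json.dumps(
        {"table": full_table_name, "version": table_version, "sql": statement},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()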
@@ -433,13 +431,13 @@ def fetch_table_data(
resp.manifest and resp.manifest.schema and resp.manifest.schema.columns
):
raise ValueError("Schema/columns missing (EXTERNAL_LINKS).")
cols: List[str] = []
cols: List[str] = [] # type: ignore
for c in resp.manifest.schema.columns:
if c.name is None:
raise ValueError("Encountered a column without a name.")
cols.append(c.name)

records: List[Dict[str, Any]] = []
records: Any = []

# Helper: consume one chunk-like object (first result or subsequent chunk)
def _consume_chunk(chunk_obj: Any) -> int | None:
@@ -505,6 +503,62 @@ def _consume_chunk(chunk_obj: Any) -> int | None:
pass
return records

def fetch_model_version(
self, catalog_name: str, inst_name: str, model_name: str
) -> Any:
schema = databricksify_inst_name(inst_name)
model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"

try:
w = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
except Exception as e:
LOGGER.exception(
"Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
databricks_vars["DATABRICKS_HOST_URL"],
gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
raise ValueError(f"setup_new_inst(): Workspace client creation failed: {e}")

model_versions: Any = list(
w.model_versions.list(
full_name=model_name_path,
)
)

if not model_versions:
raise ValueError(f"No versions found for model: {model_name_path}")

latest_version = max(model_versions, key=lambda v: int(v.version))

return latest_version

def delete_model(self, catalog_name: str, inst_name: str, model_name: str) -> None:
schema = databricksify_inst_name(inst_name)
model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"

try:
w = WorkspaceClient(
host=databricks_vars["DATABRICKS_HOST_URL"],
google_service_account=gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
except Exception as e:
LOGGER.exception(
"Failed to create Databricks WorkspaceClient with host: %s and service account: %s",
databricks_vars["DATABRICKS_HOST_URL"],
gcs_vars["GCP_SERVICE_ACCOUNT_EMAIL"],
)
raise ValueError(f"setup_new_inst(): Workspace client creation failed: {e}")

try:
w.registered_models.delete(full_name=model_name_path)
LOGGER.info("Deleted registration model: %s", model_name_path)
except Exception:
LOGGER.exception("Failed to delete registered model: %s", model_name_path)
raise
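
A hedged usage sketch of the two new methods above. The catalog, institution, and model names are placeholders; the parameters, the .version/.run_id fields, and the {catalog}.{schema}_gold.{model} path convention come from this diff.

# Placeholders throughout; whether DatabricksControl() needs constructor
# arguments is not visible in this diff.
control = DatabricksControl()

latest = control.fetch_model_version(
    catalog_name="dev_catalog",
    inst_name="Example University",
    model_name="retention_model",
)
print(latest.version, latest.run_id)

control.delete_model(
    catalog_name="dev_catalog",
    inst_name="Example University",
    model_name="retention_model",
)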

def get_key_for_file(
self, mapping: Dict[str, Any], file_name: str
) -> Optional[str]:
185 changes: 181 additions & 4 deletions src/webapp/routers/models.py
@@ -5,7 +5,7 @@
import jsonpickle
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import and_
from sqlalchemy import and_, update, or_
from sqlalchemy.orm import Session
from sqlalchemy.future import select
from ..databricks import DatabricksControl, DatabricksInferenceRunRequest
@@ -33,6 +33,7 @@
import traceback
import logging
from ..gcsdbutils import update_db_from_bucket
from ..config import env_vars

from ..gcsutil import StorageControl

@@ -310,6 +311,50 @@ def read_inst_model(
}


@router.delete("/{inst_id}/models/{model_name}")
def delete_model(
inst_id: str,
model_name: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
) -> Any:
transformed_model_name = str(decode_url_piece(model_name)).strip()
has_access_to_inst_or_err(inst_id, current_user)
model_owner_and_higher_or_err(current_user, "delete model")

local_session.set(sql_session)
sess = local_session.get()

model_list = sess.execute(
select(ModelTable).where(
ModelTable.name == transformed_model_name,
ModelTable.inst_id == str_to_uuid(inst_id),
)
).scalar_one_or_none()
if model_list is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Model not found."
)

# Optionally delete the model from Databricks itself
# TODO: Add databricks deletion functionality

try:
sess.delete(model_list)
sess.commit()
except Exception as e:
sess.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"DB model delete failed: {e}"
)

return {
"inst_id": inst_id,
"model_name": transformed_model_name,
"status": "Model deleted",
}
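
For reference, a sketch of exercising the new deletion route with httpx; the base URL, any router prefix, and the auth header are assumptions not visible in this diff.

import httpx

BASE_URL = "https://api.example.org"  # placeholder deployment URL
inst_id = "00000000-0000-0000-0000-000000000000"
model_name = "retention_model"

resp = httpx.delete(
    f"{BASE_URL}/{inst_id}/models/{model_name}",
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
)
resp.raise_for_status()
print(resp.json())  # expected: {"inst_id": ..., "model_name": ..., "status": "Model deleted"}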


@router.get("/{inst_id}/models/{model_name}/runs", response_model=list[RunInfo])
def read_inst_model_outputs(
inst_id: str,
@@ -364,6 +409,8 @@ def read_inst_model_outputs(
"inst_id": uuid_to_str(query_result[0][0].inst_id),
"m_name": query_result[0][0].name,
"run_id": elem.id,
"model_run_id": elem.model_run_id,
"model_version": elem.model_version,
"created_by": uuid_to_str(elem.created_by),
"triggered_at": elem.triggered_at,
"batch_name": elem.batch_name,
@@ -555,7 +602,6 @@ def trigger_inference_run(
gcp_external_bucket_name=get_external_bucket_name(inst_id),
# The institution email to which pipeline success/failure notifications will get sent.
email=cast(str, current_user.email),
model_type=query_result[0][0].framework,
)
try:
res = databricks_control.run_pdp_inference(db_req)
@@ -567,14 +613,20 @@
detail=f"Databricks run_pdp_inference error. Error = {str(e)}",
) from e
triggered_timestamp = datetime.now()
latest_model_version = databricks_control.fetch_model_version(
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=inst_result[0][0].name,
model_name=model_name,
)
job = JobTable(
id=res.job_run_id,
triggered_at=triggered_timestamp,
created_by=str_to_uuid(current_user.user_id),
batch_name=req.batch_name,
model_id=query_result[0][0].id,
output_valid=False,
framework=query_result[0][0].framework,
model_version=latest_model_version.version,
model_run_id=latest_model_version.run_id,
)
local_session.get().add(job)
return {
@@ -585,5 +637,130 @@
"triggered_at": triggered_timestamp,
"batch_name": req.batch_name,
"output_valid": False,
"framework": query_result[0][0].framework,
"model_version": latest_model_version.version,
"model_run_id": latest_model_version.run_id,
}


@router.get("/{inst_id}/models/{model_name}/get-model-versions")
def get_model_versions(
inst_id: str,
model_name: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
transformed_model_name = str(decode_url_piece(model_name)).strip()
has_access_to_inst_or_err(inst_id, current_user)

local_session.set(sql_session)
query_result = (
local_session.get()
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
.all()
)
if not query_result or len(query_result) == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Institution not found.",
)
if len(query_result) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution duplicates found.",
)

print(f"Initial model name = {model_name}")
print(f"Converted model name {transformed_model_name}")

latest_model_version = databricks_control.fetch_model_version(
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=f"{query_result[0][0].name}",
model_name=transformed_model_name,
)

return latest_model_version


@router.post("/{inst_id}/models/{model_name}/backfill-model-runs")
def backfill_model_runs(
inst_id: str,
model_name: str,
current_user: Annotated[BaseUser, Depends(get_current_active_user)],
sql_session: Annotated[Session, Depends(get_session)],
databricks_control: Annotated[DatabricksControl, Depends(DatabricksControl)],
) -> Any:
"""Backfills missing model run metadata and returns the latest model version info.

Temporary endpoint to populate model_run_id and model_version on existing jobs for this model.
Use only when backfilling historical job runs, not for regular operation.
"""
model_name = str(decode_url_piece(model_name)).strip()
has_access_to_inst_or_err(inst_id, current_user)

# Load institution
local_session.set(sql_session)
inst_row = (
local_session.get()
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
.all()
)

model_id = (
local_session.get()
.execute(
select(ModelTable).where(
and_(
ModelTable.inst_id == str_to_uuid(inst_id),
ModelTable.name == model_name,
)
)
)
.all()
)

if not inst_row or len(inst_row) == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Institution not found.",
)
if len(inst_row) > 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Institution duplicates found.",
)

latest_mv = databricks_control.fetch_model_version(
catalog_name=str(env_vars["CATALOG_NAME"]),
inst_name=f"{inst_row[0][0].name}",
model_name=model_name,
)

mv_version = str(latest_mv.version)
mv_run_id = str(latest_mv.run_id)

# UPDATE existing jobs for this model (only those missing values)
stmt = (
update(JobTable)
.where(JobTable.model_id == model_id[0][0].id)
.where(
or_(
JobTable.model_run_id.is_(None),
JobTable.model_run_id == "",
JobTable.model_version.is_(None),
JobTable.model_version == "",
)
)
.values(model_run_id=mv_run_id, model_version=mv_version)
)
result = local_session.get().execute(stmt)
updated_count = result.rowcount or 0 # type: ignore
local_session.get().commit()

return {
"inst_id": str(inst_id),
"model_id": str(model_id[0][0].id),
"model_name": model_name,
"latest_model_version": {"version": mv_version, "run_id": mv_run_id},
"updated_count": updated_count,
}
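
A similar sketch for the temporary backfill route documented above; base URL, prefix, and auth are placeholders.

import httpx

BASE_URL = "https://api.example.org"  # placeholder deployment URL
inst_id = "00000000-0000-0000-0000-000000000000"
model_name = "retention_model"

resp = httpx.post(
    f"{BASE_URL}/{inst_id}/models/{model_name}/backfill-model-runs",
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
)
resp.raise_for_status()
data = resp.json()
# updated_count is the number of existing jobs whose model_run_id/model_version were filled in.
print(data["updated_count"], data["latest_model_version"])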