diff --git a/gfmstudio/celery_worker.py b/gfmstudio/celery_worker.py
index 2590979..78edd85 100644
--- a/gfmstudio/celery_worker.py
+++ b/gfmstudio/celery_worker.py
@@ -11,13 +11,14 @@
 from gfmstudio.amo.utils import invoke_model_offboarding_handler
 from gfmstudio.config import settings
 from gfmstudio.fine_tuning.core.kubernetes import (
-    check_k8s_job_status,
+    check_tuning_task_status,
     deploy_hpo_tuning_job,
     deploy_tuning_job,
 )
 from gfmstudio.fine_tuning.utils.webhook_event_handlers import (
     handle_dataset_factory_webhooks,
     handle_fine_tuning_webhooks,
+    update_tune_status,
 )
 from gfmstudio.inference.services import (
     invoke_cancel_inference_handler,
@@ -55,14 +56,14 @@
 )
 def deploy_tuning_job_celery_task(**kwargs):
     # Inject the monitoring task into kwargs to avoid circular import
-    kwargs['_monitor_task'] = monitor_k8_job_completion_task
+    kwargs["_monitor_task"] = monitor_k8_job_completion_task
     return asyncio.run(deploy_tuning_job(**kwargs))


 @celery_app.task(
     name="monitor_k8_job_completion_task",
     queue=FT_SERVICE_NAME,
-    bind=True,  # Bind to get access to self for retry
+    bind=True,
     max_retries=30,  # Allow many retries
     default_retry_delay=30,  # Start with 30 seconds
 )
@@ -72,31 +73,50 @@
 def monitor_k8_job_completion_task(self, ftune_id: str):
     max_wait = settings.KJOB_MAX_WAIT_SECONDS or 7200

     try:
-        k8s_job_status, _ = asyncio.run(check_k8s_job_status(ftune_id))
+        k8s_job_status, _ = asyncio.run(check_tuning_task_status(ftune_id))
     except Exception as exc:
         if "not found" in str(exc):
             # Job not found, consider it done (likely already completed and deleted)
-            logger.debug(f"{ftune_id}: Job not found, assuming completed and cleaned up")
+            logger.debug(
+                f"{ftune_id}: Job not found, assuming completed and cleaned up"
+            )
             return "Completed"
         # Unexpected error, retry with exponential backoff
         logger.warning(f"{ftune_id}: Error checking job status, will retry: {exc}")
-        raise self.retry(exc=exc, countdown=min(2 ** self.request.retries * 30, max_wait))
-
-    # Handle None status (job not found after retries)
+        raise self.retry(exc=exc, countdown=min(2**self.request.retries * 30, max_wait))
+
     if k8s_job_status is None:
-        # Job doesn't exist - either completed and deleted, or never created
-        logger.debug(f"{ftune_id}: Job status is None, assuming completed and cleaned up")
+        logger.debug(
+            f"{ftune_id}: Job status is None, assuming completed and cleaned up"
+        )
         return "Completed"
-
-    if k8s_job_status in ["Complete", "Failed"]:
-        # Job is done
-        logger.info(f"{ftune_id}: Job finished with status: {k8s_job_status}")
+
+    if k8s_job_status == "Unknown":
+        logger.info(
+            f"{ftune_id}: Job status is Unknown (resources deleted), assuming completed and cleaned up"
+        )
+        return "Completed"
+
+    terminal_statuses = (
+        f"{settings.K8S_JOB_SUCCESS_STATUSES},{settings.K8S_JOB_FAILURE_STATUSES}"
+    )
+    terminal_statuses = [s.strip().lower() for s in terminal_statuses.split(",")]
+    if k8s_job_status.lower() in terminal_statuses:
+        logger.info(f"{ftune_id}: Job finished with status: {k8s_job_status}")
         return k8s_job_status
-
+
+    if k8s_job_status == "Running":
+        try:
+            asyncio.run(update_tune_status(ftune_id, "In_progress"))
+        except Exception as e:
+            logger.warning(f"{ftune_id}: Failed to update status to In_progress: {e}")
+
     # Job still running, retry with exponential backoff
     # countdown: 30s, 60s, 120s, 240s, 480s, 960s (max with default 600s setting)
-    countdown = min(2 ** self.request.retries * 30, max_wait)
-    logger.info(f"{ftune_id}: Job status={k8s_job_status}, will check again in {countdown}s")
+    countdown = min(2**self.request.retries * 30, max_wait)
+    logger.info(
+        f"{ftune_id}: Job status={k8s_job_status}, will check again in {countdown}s"
+    )
     raise self.retry(countdown=countdown)


@@ -105,7 +125,7 @@ def monitor_k8_job_completion_task(self, ftune_id: str):
     queue=FT_SERVICE_NAME,
 )
 def deploy_hpo_tuning_celery_task(**kwargs):
-    kwargs['_monitor_task'] = monitor_k8_job_completion_task
+    kwargs["_monitor_task"] = monitor_k8_job_completion_task
     return asyncio.run(deploy_hpo_tuning_job(**kwargs))
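Note: the retry arithmetic in the monitor task above is easier to see in isolation. The sketch below reproduces the `min(2**retries * 30, max_wait)` schedule outside of Celery; `max_retries=30` and the 7200s cap mirror the task's defaults, everything else is illustrative.

```python
# Standalone sketch of the monitor task's retry schedule (not part of this diff).
def backoff_schedule(max_retries: int = 30, base: int = 30, cap: int = 7200):
    """Yield the countdown in seconds before each retry, mirroring
    min(2**self.request.retries * 30, max_wait) from the task above."""
    for retries in range(max_retries):
        yield min(2**retries * base, cap)


if __name__ == "__main__":
    # 30, 60, 120, 240, 480, 960, 1920, 3840, 7200, 7200, ...
    print(list(backoff_schedule())[:10])
```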
diff --git a/gfmstudio/config.py b/gfmstudio/config.py
index 6e933d3..a56addf 100644
--- a/gfmstudio/config.py
+++ b/gfmstudio/config.py
@@ -32,10 +32,14 @@ class Settings(BaseSettings):
     EIS_API_KEY: Optional[str] = ""

     # Object storage / COS details
-    OBJECT_STORAGE_ENDPOINT: Optional[str] = Field(description="COS endpoint", default="")
+    OBJECT_STORAGE_ENDPOINT: Optional[str] = Field(
+        description="COS endpoint", default=""
+    )
     OBJECT_STORAGE_KEY_ID: str = Field(description="Key ID for COS authentication")
     OBJECT_STORAGE_SEC_KEY: str = Field(description="Secret Key for COS authentication")
-    OBJECT_STORAGE_REGION: Optional[str] = Field(description="URL with the region", default="")
+    OBJECT_STORAGE_REGION: Optional[str] = Field(
+        description="URL with the region", default=""
+    )
     OBJECT_STORAGE_SIGNATURE_VERSION: Optional[str] = Field(default="s3v4")

     TEMP_UPLOADS_BUCKET: Optional[str] = Field(
@@ -43,9 +47,13 @@
         default="geospatial-studio-temporary-uploads",
     )
     # Add pipelines v2 COS credentials
-    PIPELINES_V2_COS_BUCKET: Optional[str] = Field(default="test-geo-inference-pipelines")
+    PIPELINES_V2_COS_BUCKET: Optional[str] = Field(
+        default="test-geo-inference-pipelines"
+    )
     PIPELINES_V2_INFERENCE_ROOT_FOLDER: Optional[str] = Field(default=None)
-    PIPELINES_V2_INTEGRATION_TYPE: Optional[str] = Field(default="database")  # Options: database, kafka, api
+    PIPELINES_V2_INTEGRATION_TYPE: Optional[str] = Field(
+        default="database"
+    )  # Options: database, kafka, api

     INFERENCE_LOGS_BASE_PATH: Optional[str] = Field(default="/data")
     DEFAULT_SYSTEM_USER: Optional[str] = "system@ibm.com"
@@ -69,8 +77,12 @@
     INFERENCE_PIPELINE_BASE_URL: Optional[str] = Field(
         default="https://pipelines-orchestration-nasageospatial-uat.cash.sl.cloud9.ibm.com/v1"
     )
-    INFERENCE_PIPELINE_ID: Optional[str] = Field(default="23a3e4e9-d81d-4694-a2b2-543581e63c12")
-    DEPLOY_FOR_INFERENCE_PIPELINE_ID: Optional[str] = Field(default="ad249995-58ed-4d56-9ec4-41021f75ee23")
+    INFERENCE_PIPELINE_ID: Optional[str] = Field(
+        default="23a3e4e9-d81d-4694-a2b2-543581e63c12"
+    )
+    DEPLOY_FOR_INFERENCE_PIPELINE_ID: Optional[str] = Field(
+        default="ad249995-58ed-4d56-9ec4-41021f75ee23"
+    )

     DATA_ADVISOR_ENABLED: Optional[bool] = Field(default=False)
     DATA_ADVISOR_MAX_CLOUD_COVER: Optional[float] = Field(default=80)
@@ -121,7 +133,9 @@
             "Task": "WGS-1135",
         },
     )
-    JIRA_API_KEY: Optional[str] = Field(description="Jira API Key with write access", default="")
+    JIRA_API_KEY: Optional[str] = Field(
+        description="Jira API Key with write access", default=""
+    )
     JIRA_API_URI: str = Field(
         description="Jira API URI",
         default="https://jsw.ibm.com/rest/api/2",
@@ -135,7 +149,9 @@
     ####################
     # FINE TUNING
     ####################
-    FT_IMAGE_PULL_SECRETS: str = Field(default="ris-private-registry", description="Image pull secret to pull images.")
+    FT_IMAGE_PULL_SECRETS: str = Field(
+        default="ris-private-registry", description="Image pull secret to pull images."
+    )
     MMSEGMENTATION_GEO_IMAGE: str = Field(
         default="us.icr.io/gfmaas/mmsegmentation_geospatial:v0.1.0",
         description="The mmsegmentation docker image to run the fine tune process",
@@ -164,7 +180,9 @@
         description="Path in the pod where the backbone models PVC is mounted",
         default="/terratorch/",
     )
-    FILES_PVC: Optional[str] = Field(description="Name of the Persistent Volume ", default="gfm-ft-files-pvc")
+    FILES_PVC: Optional[str] = Field(
+        description="Name of the Persistent Volume ", default="gfm-ft-files-pvc"
+    )
     NAMESPACE: str = Field(
         default="geoft",
         description="This is the namespace (or OCP project) where the helm upgrade is run",
@@ -251,6 +269,13 @@
         description="Cut-off data after which terratorch v2 should be in use",
     )

+    # Job Status Configuration - stored as comma-separated strings
+    K8S_JOB_SUCCESS_STATUSES: str = Field(
+        default="Complete,Succeeded,SuccessCriteriaMet"
+    )
+
+    K8S_JOB_FAILURE_STATUSES: str = Field(default="Failed,Error,FailureTarget")
+
     ####################
     # DATASET FACTORY
     ####################
@@ -258,10 +283,18 @@
     DATA_PIPELINE_BASE_URL: Optional[str] = Field(
         default="https://geofm-workflow-orchestrator-internal-nasageospatial-dev.cash.sl.cloud9.ibm.com/v1"
     )
-    DATA_ONBOARD_PIPELINE_ID: Optional[str] = Field(default="3148e9aa-ee1e-40ba-a9a7-8e783deac6b7")
-    DATA_ONBOARD_PIPELINE_V2_ID: Optional[str] = Field(default="335d21ce-269e-47d5-91eb-65f315b5c728")
-    DATASET_PIPELINE_IMAGE: Optional[str] = Field(default="us.icr.io/gfmaas/geostudio-curated-upload:latest")
-    model_config = ConfigDict(extra="allow", case_sensitive=True, env_file=os.path.join(BASE_DIR, ".env"))
+    DATA_ONBOARD_PIPELINE_ID: Optional[str] = Field(
+        default="3148e9aa-ee1e-40ba-a9a7-8e783deac6b7"
+    )
+    DATA_ONBOARD_PIPELINE_V2_ID: Optional[str] = Field(
+        default="335d21ce-269e-47d5-91eb-65f315b5c728"
+    )
+    DATASET_PIPELINE_IMAGE: Optional[str] = Field(
+        default="us.icr.io/gfmaas/geostudio-curated-upload:latest"
+    )
+    model_config = ConfigDict(
+        extra="allow", case_sensitive=True, env_file=os.path.join(BASE_DIR, ".env")
+    )

     ####################
     # AMO
@@ -288,7 +321,9 @@
         description="COS bucket where new model artifacts are stored",
         default="geodev-amo-input-bucket",
     )
-    AMO_INFERENCE_SHARED_PVC: Optional[str] = Field(description="Inference Shared PVC", default="inference-shared-pvc")
+    AMO_INFERENCE_SHARED_PVC: Optional[str] = Field(
+        description="Inference Shared PVC", default="inference-shared-pvc"
+    )
     GENERIC_PROCESSOR_BUCKET: Optional[str] = Field(
         description="COS bucket for generic processor scripts",
         default="geodev-generic-processor",
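Note: the new `K8S_JOB_SUCCESS_STATUSES` / `K8S_JOB_FAILURE_STATUSES` settings are consumed in several places by concatenating, splitting on commas, and lower-casing. A minimal sketch of that contract, with the defaults inlined and all names illustrative:

```python
# Sketch of how the comma-separated status settings are parsed (defaults inlined).
K8S_JOB_SUCCESS_STATUSES = "Complete,Succeeded,SuccessCriteriaMet"
K8S_JOB_FAILURE_STATUSES = "Failed,Error,FailureTarget"


def terminal_statuses() -> list[str]:
    """Lower-cased union of the success and failure statuses."""
    raw = f"{K8S_JOB_SUCCESS_STATUSES},{K8S_JOB_FAILURE_STATUSES}"
    return [s.strip().lower() for s in raw.split(",")]


def is_terminal(status: str | None) -> bool:
    # Both sides must be lower-cased, since kubectl reports conditions
    # like "Complete" or "SuccessCriteriaMet" in mixed case.
    return bool(status) and status.lower() in terminal_statuses()


assert is_terminal("SuccessCriteriaMet")
assert not is_terminal("Running")
```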
diff --git a/gfmstudio/fine_tuning/core/kubernetes.py b/gfmstudio/fine_tuning/core/kubernetes.py
index 037894a..8f2a665 100644
--- a/gfmstudio/fine_tuning/core/kubernetes.py
+++ b/gfmstudio/fine_tuning/core/kubernetes.py
@@ -72,7 +72,7 @@ def get_sa_token():
     """
     token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
    try:
-        with open(token_path, 'r') as f:
+        with open(token_path, "r") as f:
            return f.read().strip()
    except FileNotFoundError:
        raise ValueError(f"Service account token not found at {token_path}")
@@ -80,19 +80,19 @@ def get_k8s_server_url():
     """Get the Kubernetes API server URL.
-    
+
     First tries to read from the service account (when running inside a cluster),
     then falls back to settings.
-    
+
     Returns
     -------
     str
         The Kubernetes API server URL
     """
     try:
-        k8s_host = os.getenv('KUBERNETES_SERVICE_HOST')
-        k8s_port = os.getenv('KUBERNETES_SERVICE_PORT', '443')
-    
+        k8s_host = os.getenv("KUBERNETES_SERVICE_HOST")
+        k8s_port = os.getenv("KUBERNETES_SERVICE_PORT", "443")
+
         if k8s_host:
             server_url = f"https://{k8s_host}:{k8s_port}"
             logging.debug(f"Detected Kubernetes server from environment: {server_url}")
@@ -121,7 +121,7 @@ async def ensure_logged_in(command=COMMAND):
         return
     except ProcessError:
         logging.debug("Logging into the cluster")
-    
+
     # Get server URL and token
     try:
         k8s_server = get_k8s_server_url()
@@ -129,7 +129,7 @@ async def ensure_logged_in(command=COMMAND):
     except ValueError as e:
         logging.error(f"Failed to get cluster credentials: {e}")
         raise
-    
+
     # Check if we're using OpenShift (oc) or Kubernetes (kubectl)
     # Try oc login first (for OpenShift)
     try:
@@ -140,20 +140,20 @@ async def ensure_logged_in(command=COMMAND):
             "set-cluster",
             "default-cluster",
             f"--server={k8s_server}",
-            "--insecure-skip-tls-verify=true"
+            "--insecure-skip-tls-verify=true",
         ]
         await check_output(*set_cluster_cmd)
-        
+
         # Set credentials
         set_credentials_cmd = [
             "kubectl",
             "config",
             "set-credentials",
             "default-user",
-            f"--token={sa_token}"
+            f"--token={sa_token}",
         ]
         await check_output(*set_credentials_cmd)
-        
+
         # Set context
         set_context_cmd = [
             "kubectl",
@@ -161,19 +161,14 @@ async def ensure_logged_in(command=COMMAND):
             "config",
             "set-context",
             "default-context",
             "--cluster=default-cluster",
-            "--user=default-user"
+            "--user=default-user",
         ]
         await check_output(*set_context_cmd)
-        
+
         # Use context
-        use_context_cmd = [
-            "kubectl",
-            "config",
-            "use-context",
-            "default-context"
-        ]
+        use_context_cmd = ["kubectl", "config", "use-context", "default-context"]
         await check_output(*use_context_cmd)
-        
+
         logging.info("Successfully configured kubectl for Kubernetes")
     except ProcessError:
         logging.error("Failed to configure kubectl for Kubernetes")
@@ -319,7 +314,9 @@ async def deploy_hpo_tuning_job(
     if settings.CELERY_TASKS_ENABLED and status == "In_progress":
         monitor_task = kwargs.get("_monitor_task")
         # For celery tasks, wait untill the kubernetes job is complete before exiting.
-        await monitor_k8_job_completion(f"{deployment_id}-hpo",monitor_task=monitor_task)
+        await monitor_k8_job_completion(
+            f"{deployment_id}-hpo", monitor_task=monitor_task
+        )

     return deployment_id, status
@@ -446,7 +443,7 @@ async def deploy_tuning_job(
     if settings.CELERY_TASKS_ENABLED and status == "In_progress":
         # For celery tasks, wait untill the kubernetes job is complete before exiting.
         # Extract monitor_task from kwargs if provided
-        monitor_task = kwargs.get('_monitor_task')
+        monitor_task = kwargs.get("_monitor_task")
         await monitor_k8_job_completion(ftune_id, monitor_task=monitor_task)

     return deployment_id, status
@@ -454,7 +451,7 @@ async def monitor_k8_job_completion(ftune_id: str, monitor_task=None):
     """Trigger Celery task to monitor Kubernetes job completion.
-    
+
     This function schedules a Celery task that will monitor the job
     with exponential backoff, releasing the worker between checks.
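Note: `get_k8s_server_url()` above relies on the standard in-cluster environment variables. A self-contained sketch of that detection; the `fallback` parameter stands in for the settings-based path and is illustrative:

```python
# Sketch of the in-cluster API-server detection used by get_k8s_server_url().
import os


def detect_k8s_server(fallback: str | None = None) -> str | None:
    host = os.getenv("KUBERNETES_SERVICE_HOST")
    port = os.getenv("KUBERNETES_SERVICE_PORT", "443")
    if host:
        # Inside a pod, the kubelet injects these variables automatically.
        return f"https://{host}:{port}"
    return fallback
```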
""" if monitor_task is None: - logger.warning(f"{ftune_id}: No monitoring task provided, job will not be monitored") + logger.warning( + f"{ftune_id}: No monitoring task provided, job will not be monitored" + ) return - + # Schedule the monitoring task asynchronously # This releases the current worker immediately monitor_task.apply_async(args=[ftune_id]) # type: ignore[attr-defined] logger.info(f"{ftune_id}: Scheduled monitoring task for job completion") -async def check_k8s_job_status(tune_id: str, retry_label_lookup=True): +async def get_pod_phase(job_name: str) -> str | None: + """Check the status of a pod associated with a Kubernetes job. + + This function checks if the pod is actually running, not just pending. + Useful for determining if a job is truly in progress or just waiting for resources. + + Parameters + ---------- + job_name : str + The Kubernetes job name + + Returns + ------- + str + The pod phase: 'Running', 'Pending', 'Succeeded', 'Failed', 'Unknown', or None if no pod found + """ + try: + await ensure_logged_in(f"kubectl get job --namespace={settings.NAMESPACE}") + + # Get pod status using the job-name label + command = [ + "kubectl", + "get", + "pods", + "-l", + f"job-name={job_name}", + "-o", + "jsonpath={.items[0].status.phase}", + ] + + result = await run_subprocess_cmds(command=command) + return result[0].strip() if result and result[0] else None + + except Exception as e: + # Handle case where job/pod has been deleted by webhook + logger.debug(f"{job_name}: Error checking pod status (likely deleted): {e}") + return None + + +async def get_job_conditions(job_name: str) -> str | None: + """ + Get the conditions of a Kubernetes job. + Parameters + ---------- + job_name : str + The name of the job to check. + Returns + ------- + str + The conditions of the job. + None + If the job has no conditions. + """ + try: + cmd = [ + "kubectl", + "get", + "job", + job_name, + "-o", + "jsonpath={.status.conditions[0].type}", + ] + result = await run_subprocess_cmds(cmd) + return result[0].strip() if result and result[0] else None + except Exception as e: + logger.debug(f"Error checking job conditions: {e}") + return None + + +async def get_aggregate_job_and_pod_status(job_name: str) -> str: + """Get the status of a Kubernetes job. + + Parameters + ---------- + job_name : str + The name of the job to check. + Returns + str + The status of the job. + """ + condition = await get_job_conditions(job_name) + terminal_statuses = ( + f"{settings.K8S_JOB_SUCCESS_STATUSES},{settings.K8S_JOB_FAILURE_STATUSES}" + ) + terminal_statuses = [s.strip().lower() for s in terminal_statuses.split(",")] + + if condition in terminal_statuses: + return condition + + # Job exists but no terminal condition → check pod + pod_phase = await get_pod_phase(job_name) + if pod_phase: + return pod_phase + return "Unknown" + + +async def check_tuning_task_status(tune_id: str, retry_label_lookup=True): """Function to check Kubernetes job status + This function checks both the job status and optionally the pod phase to determine + if a job is truly running or just waiting for resources (pending). + Parameters ---------- tune_id : str Tune id retry_label_lookup: bool - Wheather to retry lookup with labels. + Whether to retry lookup with labels. 
Returns ------- @@ -496,22 +594,10 @@ async def check_k8s_job_status(tune_id: str, retry_label_lookup=True): # Log in await ensure_logged_in(f"kubectl get job --namespace={settings.NAMESPACE}") - command = [ - "kubectl", - "get", - "job", - kjob_id, - "-o", - "jsonpath={.status.conditions[*].type}", - ] - - result = await run_subprocess_cmds(command=command) - logger.info(f"kubectl run cmds result: {command} ---> {result}") + # Direct resolution via unified status function + status = await get_aggregate_job_and_pod_status(kjob_id) - if result and result[0] != "": - # Job has completion status (Complete or Failed) - status = result[0].strip() - logger.info(f"{kjob_id}: Status for job {status}") + if status not in ["Running"]: return status, kjob_id else: @@ -535,15 +621,19 @@ async def check_k8s_job_status(tune_id: str, retry_label_lookup=True): job_name = job_name.split("/")[-1] logger.info(f"kubectl retry job_name: {job_name}") if job_name: - result = await check_k8s_job_status(job_name, retry_label_lookup=False) + result = await check_tuning_task_status( + job_name, retry_label_lookup=False + ) logger.info(f"kubectl retry result: {result}") # If still no status after retry, treat as Running if result and result[0] is None: - logger.info(f"{job_name}: Job exists but no status yet, treating as Running") + logger.info( + f"{job_name}: Job exists but no status yet, treating as Running" + ) return "Running", job_name return result if result else ("Running", job_name) - - # Job exists but has no conditions - verify it exists and treat as Running + + # Job exists but has no conditions - verify it exists and check pod status verify_cmd = [ "kubectl", "get", @@ -553,15 +643,15 @@ async def check_k8s_job_status(tune_id: str, retry_label_lookup=True): "name", ] verify_result = await run_subprocess_cmds(command=verify_cmd) - + if verify_result and verify_result[0]: - # Job exists but no status conditions yet - it's running or pending - logger.info(f"{kjob_id}: Job exists but no status conditions yet, treating as Running") + # Job exists but no status conditions yet + # Check if we should verify the pod phase + logger.info(f"{kjob_id}: Job exists but no status yet → Running") return "Running", kjob_id - else: - # Job doesn't exist at all - logger.warning(f"{kjob_id}: Job not found in cluster") - return None, tune_id + # Job doesn't exist at all + logger.warning(f"{kjob_id}: Job not found in cluster") + return None, tune_id async def delete_k8s_job_resources(tune_id: str): @@ -711,7 +801,7 @@ async def collect_pod_logs(tune_id: str, retry_label_lookup=True): job_name = job_name.split("/")[-1] logger.info(f"kubectl retry job_name: {job_name}") if job_name: - result = await check_k8s_job_status( + result = await check_tuning_task_status( job_name, retry_label_lookup=False ) logger.info(f"kubectl retry result: {result}") diff --git a/gfmstudio/fine_tuning/utils/tune_handlers.py b/gfmstudio/fine_tuning/utils/tune_handlers.py index b5a9cf6..ce25125 100644 --- a/gfmstudio/fine_tuning/utils/tune_handlers.py +++ b/gfmstudio/fine_tuning/utils/tune_handlers.py @@ -4,6 +4,7 @@ """Helper functions for tune submission and management.""" +import asyncio import base64 import logging import os @@ -22,7 +23,10 @@ from gfmstudio.config import settings from gfmstudio.fine_tuning import schemas from gfmstudio.fine_tuning.core import object_storage, tunes -from gfmstudio.fine_tuning.core.kubernetes import deploy_tuning_job +from gfmstudio.fine_tuning.core.kubernetes import ( + check_tuning_task_status, + 
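Note: the job-condition / pod-phase precedence introduced by `get_aggregate_job_and_pod_status()` reduces to a small decision table. A pure-logic sketch under that reading, with inputs as plain values so it can be unit-tested without a cluster:

```python
# Pure-logic sketch of the aggregate-status decision (names illustrative).
def aggregate_status(condition: str | None, pod_phase: str | None,
                     terminal: set[str]) -> str:
    if condition and condition.lower() in terminal:
        return condition   # The job already carries a terminal condition
    if pod_phase:
        return pod_phase   # No terminal condition yet, so trust the pod phase
    return "Unknown"       # Neither job condition nor pod found (e.g. deleted)


terminal = {"complete", "succeeded", "successcriteriamet",
            "failed", "error", "failuretarget"}
assert aggregate_status("Complete", None, terminal) == "Complete"
assert aggregate_status(None, "Pending", terminal) == "Pending"
assert aggregate_status(None, None, terminal) == "Unknown"
```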
diff --git a/gfmstudio/fine_tuning/utils/tune_handlers.py b/gfmstudio/fine_tuning/utils/tune_handlers.py
index b5a9cf6..ce25125 100644
--- a/gfmstudio/fine_tuning/utils/tune_handlers.py
+++ b/gfmstudio/fine_tuning/utils/tune_handlers.py
@@ -4,6 +4,7 @@
 """Helper functions for tune submission and management."""

+import asyncio
 import base64
 import logging
 import os
@@ -22,7 +23,10 @@
 from gfmstudio.config import settings
 from gfmstudio.fine_tuning import schemas
 from gfmstudio.fine_tuning.core import object_storage, tunes
-from gfmstudio.fine_tuning.core.kubernetes import deploy_tuning_job
+from gfmstudio.fine_tuning.core.kubernetes import (
+    check_tuning_task_status,
+    deploy_tuning_job,
+)
 from gfmstudio.fine_tuning.core.schema import TuneTemplateParameters
 from gfmstudio.fine_tuning.core.tuning_config_utils import (
     convert_to_jinja2_compatible_braces,
@@ -40,6 +44,9 @@
 from gfmstudio.fine_tuning.models import BaseModels, GeoDataset, Tunes, TuneTemplate
 from gfmstudio.fine_tuning.utils.geoserver_handlers import convert_to_geoserver_sld

+from gfmstudio.common.api import crud, utils
+tune_crud = crud.ItemCrud(model=Tunes)
+
 logger = logging.getLogger(__name__)

 tunes_crud = crud.ItemCrud(model=Tunes)
@@ -732,8 +739,7 @@ async def submit_tune_job(

     try:
         if settings.CELERY_TASKS_ENABLED:
-            # Submit via Celery
-            deploy_tuning_job_celery_task.apply_async(
+            result = deploy_tuning_job_celery_task.apply_async(
                 kwargs={
                     "ftune_id": tune_id,
                     "ftune_config_file": config_path,
@@ -742,17 +748,41 @@
                 },
                 task_id=tune_id,
             )
-            ftune_job_id = f"kjob-{tune_id}".lower()
-            status = "In_progress"
+
+            try:
+                job_result = await asyncio.to_thread(result.get, timeout=5)
+                ftune_job_id, job_status = job_result
+
+                if job_status == "Error":
+                    status = "Failed"
+                else:
+                    ftune_job_id = f"kjob-{tune_id}".lower()
+                    status = "Pending"
+            except Exception as e:
+                logger.debug(f"{tune_id}: Job creation in progress: {e}")
+                ftune_job_id = f"kjob-{tune_id}".lower()
+                status = "Pending"
         else:
-            # Submit directly
             ftune_job_id, updated_status = await deploy_tuning_job(
                 ftune_id=tune_id,
                 ftune_config_file=config_path,
                 ftuning_runtime_image=runtime_image,
                 tune_type=schemas.TuneOptionEnum.K8_JOB,
             )
-            status = updated_status or "Submitted"
+
+            if updated_status == "Error":
+                status = "Failed"
+            elif updated_status == "In_progress":
+                k8s_status, _ = await check_tuning_task_status(tune_id)
+
+                if k8s_status == "Running":
+                    status = "In_progress"
+                elif k8s_status == "Pending":
+                    status = "Pending"
+                else:
+                    status = updated_status
+            else:
+                status = updated_status or "Submitted"

         logger.info(f"Tune job {ftune_job_id} submitted with status: {status}")
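Note: the Celery branch above enqueues the deploy task and then peeks at the result for a few seconds so an immediate "Error" can be surfaced as "Failed", while a timeout is treated as "Pending". A condensed sketch of that pattern; `deploy_task` stands in for `deploy_tuning_job_celery_task` and the names are illustrative:

```python
# Sketch of the submit-and-peek pattern used in submit_tune_job (assumption-laden).
import asyncio


async def submit_and_peek(deploy_task, tune_id: str, wait_seconds: int = 5):
    result = deploy_task.apply_async(kwargs={"ftune_id": tune_id}, task_id=tune_id)
    try:
        # result.get() blocks, so run it in a thread to keep the event loop free.
        job_id, job_status = await asyncio.to_thread(result.get, timeout=wait_seconds)
        return job_id, ("Failed" if job_status == "Error" else "Pending")
    except Exception:
        # Timeout or broker hiccup: assume the job is still starting.
        return f"kjob-{tune_id}".lower(), "Pending"
```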
diff --git a/gfmstudio/fine_tuning/utils/webhook_event_handlers.py b/gfmstudio/fine_tuning/utils/webhook_event_handlers.py
index 2dc288d..2e33068 100644
--- a/gfmstudio/fine_tuning/utils/webhook_event_handlers.py
+++ b/gfmstudio/fine_tuning/utils/webhook_event_handlers.py
@@ -26,6 +26,45 @@
 tune_crud = crud.ItemCrud(model=Tunes)
 dataset_crud = crud.ItemCrud(model=GeoDataset)

+terminal_statuses = (
+    f"{settings.K8S_JOB_SUCCESS_STATUSES},{settings.K8S_JOB_FAILURE_STATUSES}"
+)
+terminal_statuses = [s.strip().lower() for s in terminal_statuses.split(",")]
+
+
+async def update_tune_status(tune_id: str, new_status: str, db: Session = None):
+    """Update tune status if the current status is Pending.
+
+    This is used by the monitoring task to update the status when the pod starts running.
+
+    Parameters
+    ----------
+    tune_id : str
+        The tune ID to update
+    new_status : str
+        The new status to set (e.g., "In_progress")
+    db : Session, optional
+        Database session, by default None
+    """
+    if db is None:
+        db_gen = utils.get_db()
+        session = await anext(db_gen)
+    else:
+        session = db
+
+    try:
+        tune = tune_crud.get_by_id(db=session, item_id=tune_id)
+        if tune and tune.status == "Pending":
+            tune_crud.update(
+                db=session,
+                item_id=tune_id,
+                item={"status": new_status},
+                protected=False,
+            )
+            logger.info(f"{tune_id}: Updated status from Pending to {new_status}")
+    except Exception as e:
+        logger.warning(f"{tune_id}: Failed to update status: {e}")
+
+
 async def free_k8s_resources(tune_id: str, max_wait_seconds: int = 3600):
     """Function that checks status of job and if in terminal state,
     deletes the job, pvc, configMap
@@ -40,26 +79,28 @@
         Unique tune_id
     """
-    k8s_job_status, job_id = await kubernetes.check_k8s_job_status(tune_id)
+    k8s_job_status, job_id = await kubernetes.check_tuning_task_status(tune_id)
     logger.info(f"{tune_id} Webhook: Job status: {k8s_job_status}")

     # delete resources
-    k8s_job_status = str(k8s_job_status).lower()
+    k8s_job_status_lower = str(k8s_job_status).lower()
     start_time = asyncio.get_event_loop().time()
-    while ("complete" not in k8s_job_status) and ("failed" not in k8s_job_status):
+    while k8s_job_status_lower not in terminal_statuses:
         elapsed = asyncio.get_event_loop().time() - start_time
         if elapsed > max_wait_seconds:
-            logger.error(f"{tune_id} Job status check timeout after {max_wait_seconds}s")
+            logger.error(
+                f"{tune_id} Job status check timeout after {max_wait_seconds}s"
+            )
             break
         await asyncio.sleep(30)
-        k8s_job_status, job_id = await kubernetes.check_k8s_job_status(tune_id)
+        k8s_job_status, job_id = await kubernetes.check_tuning_task_status(tune_id)
         if k8s_job_status is None:
             logger.warning(f"{tune_id} Job status is None during poll")
             break
-        k8s_job_status = str(k8s_job_status).lower()
+        k8s_job_status_lower = str(k8s_job_status).lower()

     # delete resources; job, pvc, ConfigMap
     try:
@@ -82,25 +123,27 @@ async def free_k8s_resources_by_label(tune_id: str, max_wait_seconds: int = 3600
         Unique tune_id
     """
-    k8s_job_status, _ = await kubernetes.check_k8s_job_status(tune_id)
+    k8s_job_status, _ = await kubernetes.check_tuning_task_status(tune_id)
     logger.info(f"{tune_id} Webhook: Job status: {k8s_job_status}")

-    k8s_job_status = str(k8s_job_status).lower()
+    k8s_job_status_lower = str(k8s_job_status).lower()
     start_time = asyncio.get_event_loop().time()
-    while ("complete" not in k8s_job_status) and ("failed" not in k8s_job_status):
+    while k8s_job_status_lower not in terminal_statuses:
         elapsed = asyncio.get_event_loop().time() - start_time
         if elapsed > max_wait_seconds:
-            logger.error(f"{tune_id} Job status check timeout after {max_wait_seconds}s")
+            logger.error(
+                f"{tune_id} Job status check timeout after {max_wait_seconds}s"
+            )
             break
         await asyncio.sleep(30)
-        k8s_job_status, _ = await kubernetes.check_k8s_job_status(tune_id)
+        k8s_job_status, _ = await kubernetes.check_tuning_task_status(tune_id)
         if k8s_job_status is None:
             logger.warning(f"{tune_id} Job status is None during poll")
             break
-        k8s_job_status = str(k8s_job_status).lower()
+        k8s_job_status_lower = str(k8s_job_status).lower()

     # append kjob to tune-id
     label = f"app=kjob-{tune_id}".lower()
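Note: both `free_k8s_resources*` functions share the same bounded polling shape. A minimal sketch with the status callback injected; `check` stands in for `check_tuning_task_status` and the helper name is illustrative:

```python
# Sketch of the bounded polling loop used by free_k8s_resources().
import asyncio


async def wait_for_terminal(check, terminal: set[str],
                            max_wait: float = 3600.0, interval: float = 30.0):
    """Poll `check` until it returns None or a terminal status, or until timeout."""
    loop = asyncio.get_running_loop()
    start = loop.time()
    status = await check()
    while status is not None and str(status).lower() not in terminal:
        if loop.time() - start > max_wait:
            break  # give up; the caller proceeds to cleanup regardless
        await asyncio.sleep(interval)
        status = await check()
    return status
```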
@@ -137,14 +180,18 @@
             Body=errored_logs_str,
             Key=full_s3_log_file_path,
         )
-        logger.info(f"Log file successfully uploaded to s3://{settings.TUNES_FILES_BUCKET}/{full_s3_log_file_path}")
+        logger.info(
+            f"Log file successfully uploaded to s3://{settings.TUNES_FILES_BUCKET}/{full_s3_log_file_path}"
+        )
     except Exception:
         detail = "Could not put pod logs to COS bucket."
         logger.exception(detail)
         raise HTTPException(status_code=500, detail=detail) from None


-async def handle_fine_tuning_webhooks(event: Union[NotificationCreate, dict], user: str, db: Session = None):
+async def handle_fine_tuning_webhooks(
+    event: Union[NotificationCreate, dict], user: str, db: Session = None
+):
     """Handle fine tuning service webhook events.

     For Tuning tasks, if status is;
@@ -194,7 +241,9 @@ async def handle_fine_tuning_webhooks(event: Union[NotificationCreate, dict], us
         elif event.detail["status"] == "Finished":
             logger.debug(f"{tune_id}: Tuning Task finished successfully")
         elif event.detail["status"] == "Error":
-            logger.debug(f"{tune_id}: Tuning Task Errored and resources already deleted.")
+            logger.debug(
+                f"{tune_id}: Tuning Task Errored and resources already deleted."
+            )
             await free_k8s_resources(tune_id)

         try:
@@ -221,7 +270,9 @@ async def handle_fine_tuning_webhooks(event: Union[NotificationCreate, dict], us
     return notification_id


-async def handle_dataset_factory_webhooks(event: Union[NotificationCreate, dict], user: str, db: Session = None):
+async def handle_dataset_factory_webhooks(
+    event: Union[NotificationCreate, dict], user: str, db: Session = None
+):
     """Handle dataset factory service webhook events.

     For Dataset Factory tasks, if status is;
@@ -266,7 +317,9 @@ async def handle_dataset_factory_webhooks(event: Union[NotificationCreate, dict]
             user = dataset.created_by or user
             item = {
                 "status": event.detail["status"],
-                "error": transform_error_message(event.detail["error_code"], event.detail["error_message"]),
+                "error": transform_error_message(
+                    event.detail["error_code"], event.detail["error_message"]
+                ),
                 "logs": cos_log_path,
             }
             try:
@@ -309,44 +362,62 @@ async def handle_dataset_factory_webhooks(event: Union[NotificationCreate, dict]
         # Job cleanup; when status is Succeeded or Failed
         if event.detail["status"] != "Onboarding":
-            k8s_delete_job_command = f"kubectl delete job onboarding-v2-pipeline-{dataset_id}"
+            k8s_delete_job_command = (
+                f"kubectl delete job onboarding-v2-pipeline-{dataset_id}"
+            )
             k8s_delete_secret_command = f"kubectl delete secret dataset-onboarding-v2-pipeline-params-{dataset_id}"

             # Validate BASE_DIR before using it in rm command
             if not BASE_DIR or str(BASE_DIR).strip() == "":
-                logger.info(f"BASE_DIR is empty or invalid: '{BASE_DIR}'. Try using /app as BASE_DIR")
-                deployment_file_path = f"/app/deployment/jobs/onboarding-v2-pipeline-{dataset_id}.yaml"
+                logger.info(
+                    f"BASE_DIR is empty or invalid: '{BASE_DIR}'. Try using /app as BASE_DIR"
+                )
+                deployment_file_path = (
+                    f"/app/deployment/jobs/onboarding-v2-pipeline-{dataset_id}.yaml"
+                )
                 remove_job_deployment_file_command = f"rm -f {deployment_file_path}"
             else:
                 deployment_file_path = f"{BASE_DIR}/deployment/jobs/onboarding-v2-pipeline-{dataset_id}.yaml"
                 remove_job_deployment_file_command = f"rm -f {deployment_file_path}"

             try:
-                delete_job_output = subprocess.check_output(k8s_delete_job_command, shell=True)
+                delete_job_output = subprocess.check_output(
+                    k8s_delete_job_command, shell=True
+                )
                 logger.info(delete_job_output)
             except subprocess.CalledProcessError as exc:
                 error_message = str(exc.output)
                 logger.error("Unable to remove the job. Error - " + error_message)

             try:
-                delete_secret_output = subprocess.check_output(k8s_delete_secret_command, shell=True)
+                delete_secret_output = subprocess.check_output(
+                    k8s_delete_secret_command, shell=True
+                )
                 logger.info(delete_secret_output)
             except subprocess.CalledProcessError as exc:
                 error_message = str(exc.output)
-                logger.error("Unable to remove secrets from the job. Error - " + error_message)
+                logger.error(
+                    "Unable to remove secrets from the job. Error - " + error_message
+                )

             try:
-                delete_deployment_file_output = subprocess.check_output(remove_job_deployment_file_command, shell=True)
+                delete_deployment_file_output = subprocess.check_output(
+                    remove_job_deployment_file_command, shell=True
+                )
                 logger.info(delete_deployment_file_output)
             except subprocess.CalledProcessError as exc:
                 error_message = str(exc.output)
-                logger.error("Unable to remove deployment file, Error - " + error_message)
+                logger.error(
+                    "Unable to remove deployment file, Error - " + error_message
+                )

     except Exception:
         logger.exception("Dataset status was not updated.")
         raise HTTPException(
             status_code=500,
-            detail={"message": f"Internal server error occurred. Dataset-{dataset_id} not updated."},
+            detail={
+                "message": f"Internal server error occurred. Dataset-{dataset_id} not updated."
+            },
         )

     logger.info(f"Dataset status and details has been updated for {dataset_id}")