From 8efc26a9feb423ae06953a9671c8b80ec5e58a4c Mon Sep 17 00:00:00 2001 From: Taylor Braun-Jones Date: Tue, 10 Jun 2025 19:45:31 +0000 Subject: [PATCH 1/5] crude initial optimized ECS describe_tasks usage --- dask_cloudprovider/aws/ecs.py | 298 ++++++++++++++++++++++++++++++---- 1 file changed, 265 insertions(+), 33 deletions(-) diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py index 79e7bb0d..b74d318d 100644 --- a/dask_cloudprovider/aws/ecs.py +++ b/dask_cloudprovider/aws/ecs.py @@ -3,7 +3,9 @@ import uuid import warnings import weakref +from collections import defaultdict from typing import List, Optional +from cachetools import TTLCache import dask @@ -36,12 +38,198 @@ raise ImportError(msg) from e logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) DEFAULT_TAGS = { "createdBy": "dask-cloudprovider" } # Package tags to apply to all resources +MAX_TASKS_PER_DESCRIBE_CALL = 100 +DEFAULT_POLL_INTERVAL_S = 1 # How often to check if there are tasks to describe +MAX_RETRY_ATTEMPTS = 5 +INITIAL_BACKOFF_S = 1 +MAX_BACKOFF_S = 20 + +class TaskPoller: + def __init__(self, client_factory, poll_interval_s=DEFAULT_POLL_INTERVAL_S): + self._client_factory = client_factory + self._poll_interval_s = poll_interval_s + self._tasks_to_poll = defaultdict(lambda: {"arns": set(), "futures": {}}) + self._polled_task_details_cache: TTLCache[str, dict] = TTLCache(maxsize=1000, ttl=DEFAULT_POLL_INTERVAL_S) # Cache for task_arn -> task_detail + self._lock = asyncio.Lock() + self._poll_loop_task = None + + async def ensure_running(self): + async with self._lock: + if self._poll_loop_task is None or self._poll_loop_task.done(): + self._poll_loop_task = asyncio.create_task(self._poll_loop()) + logger.info("TaskPoller started.") + + async def stop(self): + async with self._lock: + if self._poll_loop_task: + self._poll_loop_task.cancel() + try: + await self._poll_loop_task + except asyncio.CancelledError: + logger.info("TaskPoller poll loop cancelled.") + except Exception as e: + logger.error(f"Error during TaskPoller stop: {e}", exc_info=True) + self._poll_loop_task = None + logger.info("TaskPoller stopped.") + # Clear futures to prevent tasks from hanging if poller is stopped then restarted + for cluster_data in self._tasks_to_poll.values(): + for future in cluster_data["futures"].values(): + if not future.done(): + future.set_exception(RuntimeError("TaskPoller stopped before task details could be fetched.")) + self._tasks_to_poll.clear() + self._polled_task_details_cache.clear() + + + async def get_task_details(self, cluster_arn, task_arn): + await self.ensure_running() + async with self._lock: + # Check cache first + if (task_details := self._polled_task_details_cache.get(task_arn)) is not None: + # logger.debug(f"Task {task_arn} found in cache for cluster {cluster_arn}.") + return task_details + + # Check if a future already exists for this task_arn + cluster_data = self._tasks_to_poll[cluster_arn] + if task_arn in cluster_data["futures"]: + future = cluster_data["futures"][task_arn] + else: + future = asyncio.Future() + cluster_data["arns"].add(task_arn) + cluster_data["futures"][task_arn] = future + return await future + + async def _poll_loop(self): + while self._poll_loop_task is not None and not self._poll_loop_task.cancelled(): + try: + await asyncio.sleep(self._poll_interval_s) + await self._process_poll_queue() + except asyncio.CancelledError: + logger.info("TaskPoller poll loop gracefully exiting due to cancellation.") + break + except Exception as e: + 
logger.error(f"Error in TaskPoller poll loop: {e}", exc_info=True) + # Avoid tight loop on persistent errors, wait longer + await asyncio.sleep(self._poll_interval_s * 5) + logger.info("TaskPoller poll loop finished.") + + + async def _process_poll_queue(self): + # logger.debug("Processing task poll queue...") + clusters_to_process = [] + async with self._lock: + if not self._tasks_to_poll: + return + for cluster_arn, data in self._tasks_to_poll.items(): + if data["arns"]: + clusters_to_process.append(cluster_arn) + + # logger.debug(f"{len(clusters_to_process)} clusters to process") + if not clusters_to_process: + return + + async with self._client_factory("ecs") as ecs: + for cluster_arn in clusters_to_process: + task_arns_to_fetch_for_cluster = [] + current_futures_for_cluster = {} + + async with self._lock: + if cluster_arn in self._tasks_to_poll and self._tasks_to_poll[cluster_arn]["arns"]: + task_arns_to_fetch_for_cluster = list(self._tasks_to_poll[cluster_arn]["arns"]) + # Keep track of futures associated with this specific fetch attempt + current_futures_for_cluster = { + arn: self._tasks_to_poll[cluster_arn]["futures"][arn] + for arn in task_arns_to_fetch_for_cluster + if arn in self._tasks_to_poll[cluster_arn]["futures"] and not self._tasks_to_poll[cluster_arn]["futures"][arn].done() + } + self._tasks_to_poll[cluster_arn]["arns"].clear() # Clear ARNs for this poll cycle + else: + continue + + if not task_arns_to_fetch_for_cluster: + continue + + for i in range(0, len(task_arns_to_fetch_for_cluster), MAX_TASKS_PER_DESCRIBE_CALL): + batch_arns = task_arns_to_fetch_for_cluster[i:i + MAX_TASKS_PER_DESCRIBE_CALL] + if not batch_arns: + continue + + wait_duration = INITIAL_BACKOFF_S + success = False + for attempt in range(MAX_RETRY_ATTEMPTS): + try: + response = await ecs.describe_tasks(cluster=cluster_arn, tasks=batch_arns) + fetched_task_details = {task["taskArn"]: task for task in response.get("tasks", [])} + + async with self._lock: + for arn, detail in fetched_task_details.items(): + self._polled_task_details_cache[arn] = detail + if arn in current_futures_for_cluster and not current_futures_for_cluster[arn].done(): + current_futures_for_cluster[arn].set_result(detail) + + # For ARNs in batch but not in response (e.g. task terminated quickly) + for arn_in_batch in batch_arns: + if arn_in_batch not in fetched_task_details and arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): + err = RuntimeError(f"Task {arn_in_batch} not found in describe_tasks response for cluster {cluster_arn}. It might have terminated.") + logger.warning(str(err)) + current_futures_for_cluster[arn_in_batch].set_exception(err) + success = True + break # Break from retry loop on success + except ClientError as e: + if e.response["Error"].get("Code") == "ThrottlingException": + logger.warning( + f"describe_tasks throttled for cluster {cluster_arn}, batch starting with {batch_arns[0]} (attempt {attempt + 1}/{MAX_RETRY_ATTEMPTS}). Retrying in {wait_duration}s." 
+ ) + if attempt == MAX_RETRY_ATTEMPTS - 1: + logger.error(f"Max retries exceeded for describe_tasks on cluster {cluster_arn} (batch starting {batch_arns[0]}) due to throttling.") + async with self._lock: + for arn_in_batch in batch_arns: + if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): + current_futures_for_cluster[arn_in_batch].set_exception(e) + # Do not raise, let other batches/clusters proceed + else: + await asyncio.sleep(wait_duration) + wait_duration = min(wait_duration * 2, MAX_BACKOFF_S) + else: + logger.error(f"ClientError describing tasks for cluster {cluster_arn} (batch starting {batch_arns[0]}): {e}", exc_info=True) + async with self._lock: + for arn_in_batch in batch_arns: + if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): + current_futures_for_cluster[arn_in_batch].set_exception(e) + success = True # Break retry loop for non-throttling ClientErrors + break + except Exception as e: + logger.error(f"Unexpected error describing tasks for cluster {cluster_arn} (batch starting {batch_arns[0]}): {e}", exc_info=True) + async with self._lock: + for arn_in_batch in batch_arns: + if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): + current_futures_for_cluster[arn_in_batch].set_exception(e) + success = True # Break retry loop for other unexpected errors + break + + if not success: # Should only happen if all retries for throttling failed + logger.error(f"Failed to describe tasks for batch starting with {batch_arns[0]} in cluster {cluster_arn} after all retries.") + # Futures for this failed batch were already handled with exceptions in the retry loop. + + # Clean up futures from self._tasks_to_poll that were part of this processing cycle + async with self._lock: + if cluster_arn in self._tasks_to_poll: + cluster_futures = self._tasks_to_poll[cluster_arn]["futures"] + for arn in list(cluster_futures.keys()): # Iterate over copy of keys for safe deletion + if arn in current_futures_for_cluster and cluster_futures[arn].done(): + del cluster_futures[arn] + if not cluster_futures and not self._tasks_to_poll[cluster_arn]["arns"]: # if no pending arns and no pending futures + try: + del self._tasks_to_poll[cluster_arn] + except KeyError: + pass # already deleted by another part of the code + class Task: """A superclass for managing ECS Tasks @@ -120,6 +308,8 @@ def __init__( fargate_use_private_ip=False, fargate_capacity_provider=None, task_kwargs=None, + is_task_long_arn_format_enabled=True, + task_poller=None, **kwargs, ): self.lock = asyncio.Lock() @@ -129,7 +319,7 @@ def __init__( self.task_definition_arn = task_definition_arn self.task = None self.task_arn = None - self.task_type = None + self.task_type = None # Should be set by Scheduler/Worker subclasses self.public_ip = None self.private_ip = None self.connection = None @@ -144,7 +334,9 @@ def __init__( self._fargate_capacity_provider = fargate_capacity_provider self.kwargs = kwargs self.task_kwargs = task_kwargs + self._is_task_long_arn_format_enabled = is_task_long_arn_format_enabled self.status = Status.created + self._task_poller = task_poller def __await__(self): async def _(): @@ -160,36 +352,23 @@ async def _(): def _use_public_ip(self): return self.fargate and not self._fargate_use_private_ip - async def _is_long_arn_format_enabled(self): - async with self._client("ecs") as ecs: - [response] = ( - await ecs.list_account_settings( - name="taskLongArnFormat", 
effectiveSettings=True - ) - )["settings"] - return response["value"] == "enabled" - async def _update_task(self): - async with self._client("ecs") as ecs: - wait_duration = 1 - while True: - try: - [self.task] = ( - await ecs.describe_tasks( - cluster=self.cluster_arn, tasks=[self.task_arn] - ) - )["tasks"] - except ClientError as e: - if e.response["Error"]["Code"] == "ThrottlingException": - wait_duration = min(wait_duration * 2, 20) - else: - raise - else: - break - await asyncio.sleep(wait_duration) - async def _task_is_running(self): - await self._update_task() + if not self._task_poller: + raise RuntimeError(f"TaskPoller not available to Task {self.task_arn}") + + if not self.task_arn: + raise RuntimeError(f"Task {self.name} (type: {self.task_type}) has no task_arn") + + try: + # logger.debug(f"Task {self.task_arn} requesting details from poller for cluster {self.cluster_arn}.") + self.task = await self._task_poller.get_task_details(self.cluster_arn, self.task_arn) + # logger.debug(f"Task {self.name} updated via poller. Status: {self.task.get('lastStatus')}") + except Exception as e: + logger.error(f"Failed to get task details for {self.task_arn} via poller: {e}", exc_info=True) + raise + + def _task_is_running(self): return self.task["lastStatus"] == "RUNNING" async def start(self): @@ -199,7 +378,7 @@ async def start(self): kwargs = self.task_kwargs.copy() if self.task_kwargs is not None else {} # Tags are only supported if you opt into long arn format so we need to check for that - if await self._is_long_arn_format_enabled(): + if self._is_task_long_arn_format_enabled: kwargs["tags"] = dict_to_aws(self.tags) if self.platform_version and self.fargate: kwargs["platformVersion"] = self.platform_version @@ -254,12 +433,16 @@ async def start(self): break except Exception as e: timeout.set_exception(e) + logger.warning(f"Failed to start {self.task_type} task: {e}") + logger.info("Retrying in 1 second...") await asyncio.sleep(1) self.task_arn = self.task["taskArn"] + await self._update_task() while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]: + await asyncio.sleep(0.5) await self._update_task() - if not await self._task_is_running(): + if not self._task_is_running(): raise RuntimeError("%s failed to start" % type(self).__name__) [eni] = [ attachment @@ -287,7 +470,8 @@ async def close(self, **kwargs): await ecs.stop_task(cluster=self.cluster_arn, task=self.task_arn) await self._update_task() while self.task["lastStatus"] in ["RUNNING"]: - await asyncio.sleep(1) + # logger.debug(f"Waiting for {self.task_type} task {self.task_arn} to close...") + await asyncio.sleep(0.5) await self._update_task() self.status = Status.closed @@ -813,6 +997,8 @@ def __init__( self._platform_version = platform_version self._lock = asyncio.Lock() self.session = get_session() + self._is_task_long_arn_format_enabled = None + self._task_poller = TaskPoller(client_factory=self._client) super().__init__(**kwargs) def _client(self, name: str): @@ -823,6 +1009,34 @@ def _client(self, name: str): region_name=self._region_name, ) + async def __aenter__(self): + await super().__aenter__() + await self._task_poller.ensure_running() + return self + + # async def __aexit__(self, exc_type, exc, tb): + # try: + # await self._task_poller.stop() + # except Exception as e: + # logger.error(f"Failed to stop TaskPoller cleanly: {e}", exc_info=True) + # await super().__aexit__(exc_type, exc, tb) + + # Override new_spec_object to inject the task_poller + # def new_spec_object(self, spec): + # cls = spec["cls"] + # 
options = spec.get("options", {})  # Ensure options exist
+
+    #     # Inject the task_poller
+    #     # Ensure self._task_poller is initialized before this can be called.
+    #     # It is initialized in ECSCluster.__init__
+    #     options_with_poller = {**options, "task_poller": self._task_poller}
+
+    #     # Original SpecCluster.new_spec_object just does: return cls(**options)
+    #     # We need to ensure that the cls (Scheduler or Worker) correctly handles task_poller.
+    #     # Their __init__ methods now accept **kwargs and pass to super().__init__(..., **kwargs)
+    #     # and Task.__init__ explicitly takes task_poller.
+    #     return cls(**options_with_poller)
+
     async def _start(
         self,
     ):
@@ -950,6 +1164,10 @@ async def _start(
         self.worker_task_definition_arn = (
             await self._create_worker_task_definition_arn()
         )
+        if self._is_task_long_arn_format_enabled is None:
+            self._is_task_long_arn_format_enabled = (
+                await self._get_is_task_long_arn_format_enabled()
+            )
 
         options = {
             "client": self._client,
@@ -962,6 +1180,8 @@ async def _start(
             "tags": self.tags,
             "platform_version": self._platform_version,
             "fargate_use_private_ip": self._fargate_use_private_ip,
+            "task_poller": self._task_poller,
+            "is_task_long_arn_format_enabled": self._is_task_long_arn_format_enabled,
         }
         scheduler_options = {
             "task_definition_arn": self.scheduler_task_definition_arn,
@@ -1175,7 +1395,9 @@ async def _delete_security_groups(self):
                     await ec2.delete_security_group(
                         GroupName=self.cluster_name, DryRun=False
                     )
-                except Exception:
+                except Exception as e:
+                    logger.warning(f"Failed to delete security group {self.cluster_name}: {e}")
+                    logger.info("Retrying security group deletion in 2 seconds...")
                     await asyncio.sleep(2)
                 break
 
@@ -1319,6 +1541,16 @@ async def _delete_worker_task_definition_arn(self):
             taskDefinition=self.worker_task_definition_arn
         )
 
+    async def _get_is_task_long_arn_format_enabled(self):
+        # TODO: Throttling backoff if this fails
+        async with self._client("ecs") as ecs:
+            [response] = (
+                await ecs.list_account_settings(
+                    name="taskLongArnFormat", effectiveSettings=True
+                )
+            )["settings"]
+            return response["value"] == "enabled"
+
     def logs(self):
         async def get_logs(task):
             log = ""

From 40303452826c61e0098dc20490dc91e4f8ee5070 Mon Sep 17 00:00:00 2001
From: Taylor Braun-Jones
Date: Tue, 10 Jun 2025 22:42:13 +0000
Subject: [PATCH 2/5] Use the better throttling backoff built into the Boto3 client

---
 dask_cloudprovider/aws/ecs.py | 150 +++++++++++++---------------------
 1 file changed, 55 insertions(+), 95 deletions(-)

diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py
index b74d318d..61efa398 100644
--- a/dask_cloudprovider/aws/ecs.py
+++ b/dask_cloudprovider/aws/ecs.py
@@ -27,6 +27,7 @@
 
 try:
     from botocore.exceptions import ClientError
+    from aiobotocore.config import AioConfig
     from aiobotocore.session import get_session
 except ImportError as e:
     msg = (
@@ -47,9 +48,6 @@
 
 MAX_TASKS_PER_DESCRIBE_CALL = 100
 DEFAULT_POLL_INTERVAL_S = 1  # How often to check if there are tasks to describe
-MAX_RETRY_ATTEMPTS = 5
-INITIAL_BACKOFF_S = 1
-MAX_BACKOFF_S = 20
 
 class TaskPoller:
     def __init__(self, client_factory, poll_interval_s=DEFAULT_POLL_INTERVAL_S):
@@ -160,62 +158,28 @@ async def _process_poll_queue(self):
                         if not batch_arns:
                             continue
 
-                        wait_duration = INITIAL_BACKOFF_S
-                        success = False
-                        for attempt in range(MAX_RETRY_ATTEMPTS):
-                            try:
-                                response = await ecs.describe_tasks(cluster=cluster_arn, tasks=batch_arns)
-                                fetched_task_details = {task["taskArn"]: task for task in response.get("tasks", [])}
-
-                                async with self._lock:
-                                    for 
arn, detail in fetched_task_details.items(): - self._polled_task_details_cache[arn] = detail - if arn in current_futures_for_cluster and not current_futures_for_cluster[arn].done(): - current_futures_for_cluster[arn].set_result(detail) - - # For ARNs in batch but not in response (e.g. task terminated quickly) - for arn_in_batch in batch_arns: - if arn_in_batch not in fetched_task_details and arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): - err = RuntimeError(f"Task {arn_in_batch} not found in describe_tasks response for cluster {cluster_arn}. It might have terminated.") - logger.warning(str(err)) - current_futures_for_cluster[arn_in_batch].set_exception(err) - success = True - break # Break from retry loop on success - except ClientError as e: - if e.response["Error"].get("Code") == "ThrottlingException": - logger.warning( - f"describe_tasks throttled for cluster {cluster_arn}, batch starting with {batch_arns[0]} (attempt {attempt + 1}/{MAX_RETRY_ATTEMPTS}). Retrying in {wait_duration}s." - ) - if attempt == MAX_RETRY_ATTEMPTS - 1: - logger.error(f"Max retries exceeded for describe_tasks on cluster {cluster_arn} (batch starting {batch_arns[0]}) due to throttling.") - async with self._lock: - for arn_in_batch in batch_arns: - if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): - current_futures_for_cluster[arn_in_batch].set_exception(e) - # Do not raise, let other batches/clusters proceed - else: - await asyncio.sleep(wait_duration) - wait_duration = min(wait_duration * 2, MAX_BACKOFF_S) - else: - logger.error(f"ClientError describing tasks for cluster {cluster_arn} (batch starting {batch_arns[0]}): {e}", exc_info=True) - async with self._lock: - for arn_in_batch in batch_arns: - if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): - current_futures_for_cluster[arn_in_batch].set_exception(e) - success = True # Break retry loop for non-throttling ClientErrors - break - except Exception as e: - logger.error(f"Unexpected error describing tasks for cluster {cluster_arn} (batch starting {batch_arns[0]}): {e}", exc_info=True) - async with self._lock: - for arn_in_batch in batch_arns: - if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): - current_futures_for_cluster[arn_in_batch].set_exception(e) - success = True # Break retry loop for other unexpected errors - break - - if not success: # Should only happen if all retries for throttling failed - logger.error(f"Failed to describe tasks for batch starting with {batch_arns[0]} in cluster {cluster_arn} after all retries.") - # Futures for this failed batch were already handled with exceptions in the retry loop. + try: + response = await ecs.describe_tasks(cluster=cluster_arn, tasks=batch_arns) + fetched_task_details = {task["taskArn"]: task for task in response.get("tasks", [])} + + async with self._lock: + for arn, detail in fetched_task_details.items(): + self._polled_task_details_cache[arn] = detail + if arn in current_futures_for_cluster and not current_futures_for_cluster[arn].done(): + current_futures_for_cluster[arn].set_result(detail) + + # For ARNs in batch but not in response (e.g. 
task terminated quickly) + for arn_in_batch in batch_arns: + if arn_in_batch not in fetched_task_details and arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): + err = RuntimeError(f"Task {arn_in_batch} not found in describe_tasks response for cluster {cluster_arn}. It might have terminated.") + logger.warning(str(err)) + current_futures_for_cluster[arn_in_batch].set_exception(err) + except Exception as e: + logger.error(f"Unexpected error describing tasks for cluster {cluster_arn} (batch starting {batch_arns[0]}): {e}", exc_info=True) + async with self._lock: + for arn_in_batch in batch_arns: + if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): + current_futures_for_cluster[arn_in_batch].set_exception(e) # Clean up futures from self._tasks_to_poll that were part of this processing cycle async with self._lock: @@ -488,48 +452,35 @@ def _log_stream_name(self): ) async def logs(self, follow=False): - current_try = 0 next_token = None read_from = 0 while True: - try: - async with self._client("logs") as logs: - if next_token: - l = await logs.get_log_events( - logGroupName=self.log_group, - logStreamName=self._log_stream_name, - nextToken=next_token, - ) - else: - l = await logs.get_log_events( - logGroupName=self.log_group, - logStreamName=self._log_stream_name, - startTime=read_from, - ) - if next_token != l["nextForwardToken"]: - next_token = l["nextForwardToken"] + async with self._client("logs") as logs: + if next_token: + l = await logs.get_log_events( + logGroupName=self.log_group, + logStreamName=self._log_stream_name, + nextToken=next_token, + ) else: - next_token = None - if not l["events"]: - if follow: - await asyncio.sleep(1) - else: - break - for event in l["events"]: - read_from = event["timestamp"] - yield event["message"] - except ClientError as e: - if e.response["Error"]["Code"] == "ThrottlingException": - warnings.warn( - "get_log_events rate limit exceeded, retrying after delay.", - RuntimeWarning, + l = await logs.get_log_events( + logGroupName=self.log_group, + logStreamName=self._log_stream_name, + startTime=read_from, ) - backoff_duration = get_sleep_duration(current_try) - await asyncio.sleep(backoff_duration) - current_try += 1 + if next_token != l["nextForwardToken"]: + next_token = l["nextForwardToken"] + else: + next_token = None + if not l["events"]: + if follow: + await asyncio.sleep(1) else: - raise + break + for event in l["events"]: + read_from = event["timestamp"] + yield event["message"] def __repr__(self): return "" % (type(self).__name__, self.status) @@ -1007,6 +958,16 @@ def _client(self, name: str): aws_access_key_id=self._aws_access_key_id, aws_secret_access_key=self._aws_secret_access_key, region_name=self._region_name, + config=AioConfig( + retries={ + # Use Standard retry mode which provides: + # - Jittered exponential backoff in the event of failures + # - Never delays the first request attempt, only the retries + # - Supports circuit-breaking to prevent the SDK from retrying during outages + "mode": "standard", + "total_max_attempts": 5, # Includes the initial request + } + ), ) async def __aenter__(self): @@ -1542,7 +1503,6 @@ async def _delete_worker_task_definition_arn(self): ) async def _get_is_task_long_arn_format_enabled(self): - # TODO: Throttling backoff if this fails async with self._client("ecs") as ecs: [response] = ( await ecs.list_account_settings( From 511e7a45429c542babc07dceefb965a550ce67d2 Mon Sep 17 00:00:00 2001 From: 
Taylor Braun-Jones Date: Tue, 10 Jun 2025 18:57:38 -0400 Subject: [PATCH 3/5] Revert TaskPoller --- dask_cloudprovider/aws/ecs.py | 212 ++-------------------------------- 1 file changed, 11 insertions(+), 201 deletions(-) diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py index 61efa398..3b09e056 100644 --- a/dask_cloudprovider/aws/ecs.py +++ b/dask_cloudprovider/aws/ecs.py @@ -3,9 +3,7 @@ import uuid import warnings import weakref -from collections import defaultdict from typing import List, Optional -from cachetools import TTLCache import dask @@ -46,154 +44,6 @@ "createdBy": "dask-cloudprovider" } # Package tags to apply to all resources -MAX_TASKS_PER_DESCRIBE_CALL = 100 -DEFAULT_POLL_INTERVAL_S = 1 # How often to check if there are tasks to describe - -class TaskPoller: - def __init__(self, client_factory, poll_interval_s=DEFAULT_POLL_INTERVAL_S): - self._client_factory = client_factory - self._poll_interval_s = poll_interval_s - self._tasks_to_poll = defaultdict(lambda: {"arns": set(), "futures": {}}) - self._polled_task_details_cache: TTLCache[str, dict] = TTLCache(maxsize=1000, ttl=DEFAULT_POLL_INTERVAL_S) # Cache for task_arn -> task_detail - self._lock = asyncio.Lock() - self._poll_loop_task = None - - async def ensure_running(self): - async with self._lock: - if self._poll_loop_task is None or self._poll_loop_task.done(): - self._poll_loop_task = asyncio.create_task(self._poll_loop()) - logger.info("TaskPoller started.") - - async def stop(self): - async with self._lock: - if self._poll_loop_task: - self._poll_loop_task.cancel() - try: - await self._poll_loop_task - except asyncio.CancelledError: - logger.info("TaskPoller poll loop cancelled.") - except Exception as e: - logger.error(f"Error during TaskPoller stop: {e}", exc_info=True) - self._poll_loop_task = None - logger.info("TaskPoller stopped.") - # Clear futures to prevent tasks from hanging if poller is stopped then restarted - for cluster_data in self._tasks_to_poll.values(): - for future in cluster_data["futures"].values(): - if not future.done(): - future.set_exception(RuntimeError("TaskPoller stopped before task details could be fetched.")) - self._tasks_to_poll.clear() - self._polled_task_details_cache.clear() - - - async def get_task_details(self, cluster_arn, task_arn): - await self.ensure_running() - async with self._lock: - # Check cache first - if (task_details := self._polled_task_details_cache.get(task_arn)) is not None: - # logger.debug(f"Task {task_arn} found in cache for cluster {cluster_arn}.") - return task_details - - # Check if a future already exists for this task_arn - cluster_data = self._tasks_to_poll[cluster_arn] - if task_arn in cluster_data["futures"]: - future = cluster_data["futures"][task_arn] - else: - future = asyncio.Future() - cluster_data["arns"].add(task_arn) - cluster_data["futures"][task_arn] = future - return await future - - async def _poll_loop(self): - while self._poll_loop_task is not None and not self._poll_loop_task.cancelled(): - try: - await asyncio.sleep(self._poll_interval_s) - await self._process_poll_queue() - except asyncio.CancelledError: - logger.info("TaskPoller poll loop gracefully exiting due to cancellation.") - break - except Exception as e: - logger.error(f"Error in TaskPoller poll loop: {e}", exc_info=True) - # Avoid tight loop on persistent errors, wait longer - await asyncio.sleep(self._poll_interval_s * 5) - logger.info("TaskPoller poll loop finished.") - - - async def _process_poll_queue(self): - # logger.debug("Processing task poll 
queue...") - clusters_to_process = [] - async with self._lock: - if not self._tasks_to_poll: - return - for cluster_arn, data in self._tasks_to_poll.items(): - if data["arns"]: - clusters_to_process.append(cluster_arn) - - # logger.debug(f"{len(clusters_to_process)} clusters to process") - if not clusters_to_process: - return - - async with self._client_factory("ecs") as ecs: - for cluster_arn in clusters_to_process: - task_arns_to_fetch_for_cluster = [] - current_futures_for_cluster = {} - - async with self._lock: - if cluster_arn in self._tasks_to_poll and self._tasks_to_poll[cluster_arn]["arns"]: - task_arns_to_fetch_for_cluster = list(self._tasks_to_poll[cluster_arn]["arns"]) - # Keep track of futures associated with this specific fetch attempt - current_futures_for_cluster = { - arn: self._tasks_to_poll[cluster_arn]["futures"][arn] - for arn in task_arns_to_fetch_for_cluster - if arn in self._tasks_to_poll[cluster_arn]["futures"] and not self._tasks_to_poll[cluster_arn]["futures"][arn].done() - } - self._tasks_to_poll[cluster_arn]["arns"].clear() # Clear ARNs for this poll cycle - else: - continue - - if not task_arns_to_fetch_for_cluster: - continue - - for i in range(0, len(task_arns_to_fetch_for_cluster), MAX_TASKS_PER_DESCRIBE_CALL): - batch_arns = task_arns_to_fetch_for_cluster[i:i + MAX_TASKS_PER_DESCRIBE_CALL] - if not batch_arns: - continue - - try: - response = await ecs.describe_tasks(cluster=cluster_arn, tasks=batch_arns) - fetched_task_details = {task["taskArn"]: task for task in response.get("tasks", [])} - - async with self._lock: - for arn, detail in fetched_task_details.items(): - self._polled_task_details_cache[arn] = detail - if arn in current_futures_for_cluster and not current_futures_for_cluster[arn].done(): - current_futures_for_cluster[arn].set_result(detail) - - # For ARNs in batch but not in response (e.g. task terminated quickly) - for arn_in_batch in batch_arns: - if arn_in_batch not in fetched_task_details and arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): - err = RuntimeError(f"Task {arn_in_batch} not found in describe_tasks response for cluster {cluster_arn}. 
It might have terminated.") - logger.warning(str(err)) - current_futures_for_cluster[arn_in_batch].set_exception(err) - except Exception as e: - logger.error(f"Unexpected error describing tasks for cluster {cluster_arn} (batch starting {batch_arns[0]}): {e}", exc_info=True) - async with self._lock: - for arn_in_batch in batch_arns: - if arn_in_batch in current_futures_for_cluster and not current_futures_for_cluster[arn_in_batch].done(): - current_futures_for_cluster[arn_in_batch].set_exception(e) - - # Clean up futures from self._tasks_to_poll that were part of this processing cycle - async with self._lock: - if cluster_arn in self._tasks_to_poll: - cluster_futures = self._tasks_to_poll[cluster_arn]["futures"] - for arn in list(cluster_futures.keys()): # Iterate over copy of keys for safe deletion - if arn in current_futures_for_cluster and cluster_futures[arn].done(): - del cluster_futures[arn] - if not cluster_futures and not self._tasks_to_poll[cluster_arn]["arns"]: # if no pending arns and no pending futures - try: - del self._tasks_to_poll[cluster_arn] - except KeyError: - pass # already deleted by another part of the code - class Task: """A superclass for managing ECS Tasks @@ -273,7 +123,6 @@ def __init__( fargate_capacity_provider=None, task_kwargs=None, is_task_long_arn_format_enabled=True, - task_poller=None, **kwargs, ): self.lock = asyncio.Lock() @@ -300,7 +149,6 @@ def __init__( self.task_kwargs = task_kwargs self._is_task_long_arn_format_enabled = is_task_long_arn_format_enabled self.status = Status.created - self._task_poller = task_poller def __await__(self): async def _(): @@ -317,22 +165,15 @@ def _use_public_ip(self): return self.fargate and not self._fargate_use_private_ip async def _update_task(self): + async with self._client("ecs") as ecs: + [self.task] = ( + await ecs.describe_tasks( + cluster=self.cluster_arn, tasks=[self.task_arn] + ) + )["tasks"] - if not self._task_poller: - raise RuntimeError(f"TaskPoller not available to Task {self.task_arn}") - - if not self.task_arn: - raise RuntimeError(f"Task {self.name} (type: {self.task_type}) has no task_arn") - - try: - # logger.debug(f"Task {self.task_arn} requesting details from poller for cluster {self.cluster_arn}.") - self.task = await self._task_poller.get_task_details(self.cluster_arn, self.task_arn) - # logger.debug(f"Task {self.name} updated via poller. 
Status: {self.task.get('lastStatus')}") - except Exception as e: - logger.error(f"Failed to get task details for {self.task_arn} via poller: {e}", exc_info=True) - raise - - def _task_is_running(self): + async def _task_is_running(self): + await self._update_task() return self.task["lastStatus"] == "RUNNING" async def start(self): @@ -402,11 +243,10 @@ async def start(self): await asyncio.sleep(1) self.task_arn = self.task["taskArn"] - await self._update_task() while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]: - await asyncio.sleep(0.5) + await asyncio.sleep(1) await self._update_task() - if not self._task_is_running(): + if not await self._task_is_running(): raise RuntimeError("%s failed to start" % type(self).__name__) [eni] = [ attachment @@ -435,7 +275,7 @@ async def close(self, **kwargs): await self._update_task() while self.task["lastStatus"] in ["RUNNING"]: # logger.debug(f"Waiting for {self.task_type} task {self.task_arn} to close...") - await asyncio.sleep(0.5) + await asyncio.sleep(1) await self._update_task() self.status = Status.closed @@ -949,7 +789,6 @@ def __init__( self._lock = asyncio.Lock() self.session = get_session() self._is_task_long_arn_format_enabled = None - self._task_poller = TaskPoller(client_factory=self._client) super().__init__(**kwargs) def _client(self, name: str): @@ -970,34 +809,6 @@ def _client(self, name: str): ), ) - async def __aenter__(self): - await super().__aenter__() - await self._task_poller.ensure_running() - return self - - # async def __aexit__(self, exc_type, exc, tb): - # try: - # await self._task_poller.stop() - # except Exception as e: - # logger.error(f"Failed to stop TaskPoller cleanly: {e}", exc_info=True) - # await super().__aexit__(exc_type, exc, tb) - - # Override new_spec_object to inject the task_poller - # def new_spec_object(self, spec): - # cls = spec["cls"] - # options = spec.get("options", {}) # Ensure options exist - - # # Inject the task_poller - # # Ensure self._task_poller is initialized before this can be called. - # # It is initialized in ECSCluster.__init__ - # options_with_poller = {**options, "task_poller": self._task_poller} - - # # Original SpecCluster.new_spec_object just does: return cls(**options) - # # We need to ensure that the cls (Scheduler or Worker) correctly handles task_poller. - # # Their __init__ methods now accept **kwargs and pass to super().__init__(..., **kwargs) - # # and Task.__init__ explicitly takes task_poller. 
-    #     return cls(**options_with_poller)
-
     async def _start(
         self,
     ):
@@ -1141,7 +952,6 @@ async def _start(
             "tags": self.tags,
             "platform_version": self._platform_version,
             "fargate_use_private_ip": self._fargate_use_private_ip,
-            "task_poller": self._task_poller,
             "is_task_long_arn_format_enabled": self._is_task_long_arn_format_enabled,
         }
         scheduler_options = {

From 01346ed504012e4e9fc5e0d5e7e55b6f6947c019 Mon Sep 17 00:00:00 2001
From: Taylor Braun-Jones
Date: Wed, 11 Jun 2025 14:12:37 +0000
Subject: [PATCH 4/5] (Re)try harder

---
 dask_cloudprovider/aws/ecs.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py
index 3b09e056..909406f3 100644
--- a/dask_cloudprovider/aws/ecs.py
+++ b/dask_cloudprovider/aws/ecs.py
@@ -800,11 +800,11 @@ def _client(self, name: str):
             config=AioConfig(
                 retries={
                     # Use Standard retry mode which provides:
-                    # - Jittered exponential backoff in the event of failures
+                    # - Jittered exponential backoff with a max of 20s in the event of failures
                     # - Never delays the first request attempt, only the retries
                     # - Supports circuit-breaking to prevent the SDK from retrying during outages
                     "mode": "standard",
-                    "total_max_attempts": 5,  # Includes the initial request
+                    "max_attempts": 10,  # Not including the initial request
                 }
             ),
         )

From da5e36b39b2a7dd9d60a02f3db4716c7b352a44f Mon Sep 17 00:00:00 2001
From: Taylor Braun-Jones
Date: Wed, 11 Jun 2025 16:47:56 +0000
Subject: [PATCH 5/5] cleanup

---
 dask_cloudprovider/aws/ecs.py       | 25 +++++++++++--------------
 dask_cloudprovider/utils/timeout.py |  9 ++++++++-
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py
index 909406f3..a13f4fcb 100644
--- a/dask_cloudprovider/aws/ecs.py
+++ b/dask_cloudprovider/aws/ecs.py
@@ -12,7 +12,6 @@
 from dask_cloudprovider.aws.helper import (
     dict_to_aws,
     aws_to_dict,
-    get_sleep_duration,
     get_default_vpc,
     get_vpc_subnets,
     create_default_security_group,
@@ -37,7 +36,6 @@
     raise ImportError(msg) from e
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
 
 
 DEFAULT_TAGS = {
@@ -132,7 +130,7 @@ def __init__(
         self.task_definition_arn = task_definition_arn
         self.task = None
         self.task_arn = None
-        self.task_type = None  # Should be set by Scheduler/Worker subclasses
+        self.task_type = None
         self.public_ip = None
         self.private_ip = None
         self.connection = None
@@ -172,8 +170,7 @@ async def _update_task(self):
             )
         )["tasks"]
 
-    async def _task_is_running(self):
-        await self._update_task()
+    def _task_is_running(self):
         return self.task["lastStatus"] == "RUNNING"
 
     async def start(self):
@@ -237,16 +234,19 @@ async def start(self):
                 [self.task] = response["tasks"]
                 break
             except Exception as e:
+                # Retries due to throttle errors are handled by the aiobotocore client, so this should be an uncommon case
                 timeout.set_exception(e)
-                logger.warning(f"Failed to start {self.task_type} task: {e}")
-                logger.info("Retrying in 1 second...")
-                await asyncio.sleep(1)
+                logger.debug(f"Failed to start {self.task_type} task after {timeout.elapsed_time:.1f}s, retrying in 2s: {e}")
+                await asyncio.sleep(2)
 
         self.task_arn = self.task["taskArn"]
+
+        # Wait for the task to come up
         while self.task["lastStatus"] in ["PENDING", "PROVISIONING"]:
+            # Try to avoid hitting throttling rate limits when bringing up a large cluster
             await asyncio.sleep(1)
             await self._update_task()
-        if not await self._task_is_running():
+        if not self._task_is_running():
             raise RuntimeError("%s failed to start" % type(self).__name__)
         [eni] = [
             attachment
@@ -273,8 +273,7 @@ async def close(self, **kwargs):
         async with self._client("ecs") as ecs:
             await ecs.stop_task(cluster=self.cluster_arn, task=self.task_arn)
             await self._update_task()
-        while self.task["lastStatus"] in ["RUNNING"]:
-            # logger.debug(f"Waiting for {self.task_type} task {self.task_arn} to close...")
+        while self._task_is_running():
             await asyncio.sleep(1)
             await self._update_task()
         self.status = Status.closed
@@ -1166,9 +1165,7 @@ async def _delete_security_groups(self):
                     await ec2.delete_security_group(
                         GroupName=self.cluster_name, DryRun=False
                     )
-                except Exception as e:
-                    logger.warning(f"Failed to delete security group {self.cluster_name}: {e}")
-                    logger.info("Retrying security group deletion in 2 seconds...")
+                except Exception:
                     await asyncio.sleep(2)
                 break
 
diff --git a/dask_cloudprovider/utils/timeout.py b/dask_cloudprovider/utils/timeout.py
index 07c8ffc5..197a8077 100644
--- a/dask_cloudprovider/utils/timeout.py
+++ b/dask_cloudprovider/utils/timeout.py
@@ -66,7 +66,7 @@ def run(self):
             self.start = datetime.now()
             self.running = True
 
-        if self.start + timedelta(seconds=self.timeout) < datetime.now():
+        if self.elapsed_time >= self.timeout:
             if self.warn:
                 warnings.warn(self.error_message)
                 return False
@@ -82,3 +82,10 @@ def set_exception(self, e):
         the thing you are trying rather than a TimeoutException.
         """
         self.exception = e
+
+    @property
+    def elapsed_time(self):
+        """Return the elapsed time in seconds since the timeout started."""
+        if self.start is None:
+            return 0
+        return (datetime.now() - self.start).total_seconds()
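
Notes on the series follow; the sketches below are illustrative and not part of the patches.

Patches 1-3 build and then revert a TaskPoller that coalesced status checks into batched describe_tasks calls. The batching rule it encoded is the ECS API limit of at most 100 task ARNs per describe_tasks call, so n tasks can be fetched in ceil(n / 100) calls instead of n single-task calls. A minimal sketch of that chunking (the constant and the slicing mirror patch 1; the function name is illustrative):

    MAX_TASKS_PER_DESCRIBE_CALL = 100

    def chunk_arns(task_arns):
        # Yield slices of at most 100 ARNs, matching the describe_tasks limit
        for i in range(0, len(task_arns), MAX_TASKS_PER_DESCRIBE_CALL):
            yield task_arns[i : i + MAX_TASKS_PER_DESCRIBE_CALL]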
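Patches 2 and 4 replace the hand-rolled throttling loops with the retry logic built into the client. A standalone sketch of the same configuration (the cluster name and task ARN are placeholders; credentials and region are assumed to come from the environment):

    import asyncio

    from aiobotocore.config import AioConfig
    from aiobotocore.session import get_session

    # Standard retry mode: jittered exponential backoff capped at 20s, no
    # delay before the first attempt, and a client-side retry quota that
    # stops retrying during sustained outages.
    RETRY_CONFIG = AioConfig(retries={"mode": "standard", "max_attempts": 10})

    async def main():
        session = get_session()
        async with session.create_client("ecs", config=RETRY_CONFIG) as ecs:
            # Throttling errors are now retried inside the client, so the
            # caller only sees a ClientError after the retry budget is spent.
            response = await ecs.describe_tasks(
                cluster="example-cluster",  # placeholder
                tasks=["arn:aws:ecs:..."],  # placeholder ARN
            )
            print([t["lastStatus"] for t in response.get("tasks", [])])

    asyncio.run(main())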
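Patch 5's Timeout.elapsed_time property lets start() log how long it has been retrying, while Timeout.set_exception() ensures the last real error is re-raised instead of a generic TimeoutException once the deadline passes. The same pattern sketched independently of the Timeout class (all names here are illustrative, not the library's API):

    import asyncio
    from datetime import datetime

    async def retry_until_deadline(attempt, timeout_s=60, backoff_s=2):
        # Retry a flaky coroutine; after the deadline, re-raise the last
        # real error rather than a generic timeout.
        start = datetime.now()
        last_error = None
        while True:
            elapsed = (datetime.now() - start).total_seconds()
            if elapsed >= timeout_s:
                raise last_error or TimeoutError(f"gave up after {elapsed:.1f}s")
            try:
                return await attempt()
            except Exception as e:
                last_error = e
                print(f"Failed after {elapsed:.1f}s, retrying in {backoff_s}s: {e}")
                await asyncio.sleep(backoff_s)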