diff --git a/README.md b/README.md index 2adcbba1..d46dd70d 100644 --- a/README.md +++ b/README.md @@ -716,14 +716,34 @@ cp .env.oauth21 .env | `get_gmail_messages_content_batch` | **Core** | Batch retrieve message content | | `send_gmail_message` | **Core** | Send emails | | `get_gmail_thread_content` | Extended | Get full thread content | +| `get_gmail_thread_metadata` | Extended | Get thread metadata (headers, labels, snippet) | | `modify_gmail_message_labels` | Extended | Modify message labels | +| `modify_gmail_thread_labels` | Extended | Modify labels for all messages in a thread | | `list_gmail_labels` | Extended | List available labels | | `manage_gmail_label` | Extended | Create/update/delete labels | | `draft_gmail_message` | Extended | Create drafts | | `get_gmail_threads_content_batch` | Complete | Batch retrieve thread content | +| `get_gmail_threads_metadata_batch` | Complete | Batch retrieve thread metadata | | `batch_modify_gmail_message_labels` | Complete | Batch modify labels | +| `batch_modify_gmail_thread_labels` | Complete | Batch modify labels for all messages in threads | | `start_google_auth` | Complete | Initialize authentication | +**Gmail Message Format Options:** + +Both `get_gmail_thread_content` and `get_gmail_threads_content_batch` support an optional `format` parameter (defaults to `"full"`): + +| Format | Description | +|--------|-------------| +| `"minimal"` | Returns only the `id` and `threadId` of each message. | +| `"metadata"` | Returns message metadata such as headers, labels, and snippet, but *not* the full body. | +| `"full"` | Returns the full email message data, including headers and body (Base64 encoded). This is the default and most commonly used for reading messages. | + +**Convenience Metadata Functions:** + +For quick access to metadata only, use the dedicated metadata functions: +- `get_gmail_thread_metadata` - Equivalent to `get_gmail_thread_content` with `format="metadata"` +- `get_gmail_threads_metadata_batch` - Equivalent to `get_gmail_threads_content_batch` with `format="metadata"` + diff --git a/auth/credential_store.py b/auth/credential_store.py index 24dda70b..385c77ba 100644 --- a/auth/credential_store.py +++ b/auth/credential_store.py @@ -151,9 +151,21 @@ def store_credential(self, user_email: str, credentials: Credentials) -> bool: """Store credentials to local JSON file.""" creds_path = self._get_credential_path(user_email) + # Preserve existing refresh token if new credentials don't have one + # This prevents losing refresh tokens during re-authorization flows + refresh_token_to_store = credentials.refresh_token + if not refresh_token_to_store: + try: + existing_creds = self.get_credential(user_email) + if existing_creds and existing_creds.refresh_token: + refresh_token_to_store = existing_creds.refresh_token + logger.info(f"Preserved existing refresh token for {user_email} in credential store") + except Exception as e: + logger.debug(f"Could not check existing credentials to preserve refresh token: {e}") + creds_data = { "token": credentials.token, - "refresh_token": credentials.refresh_token, + "refresh_token": refresh_token_to_store, "token_uri": credentials.token_uri, "client_id": credentials.client_id, "client_secret": credentials.client_secret, diff --git a/auth/google_auth.py b/auth/google_auth.py index aef73a4d..22549170 100644 --- a/auth/google_auth.py +++ b/auth/google_auth.py @@ -353,7 +353,7 @@ async def start_auth_flow( state=oauth_state, ) - auth_url, _ = flow.authorization_url(access_type="offline", 
prompt="consent") + auth_url, _ = flow.authorization_url(access_type="offline") session_id = None try: @@ -486,16 +486,41 @@ def handle_auth_callback( user_google_email = user_info["email"] logger.info(f"Identified user_google_email: {user_google_email}") - # Save the credentials + # Preserve existing refresh token if new credentials don't have one + # Google often doesn't return refresh_token on re-authorization if one already exists + existing_credentials = None credential_store = get_credential_store() + try: + existing_credentials = credential_store.get_credential(user_google_email) + except Exception as e: + logger.debug(f"Could not load existing credentials to preserve refresh token: {e}") + + # Also check OAuth21SessionStore for existing refresh token + store = get_oauth21_session_store() + existing_session_creds = store.get_credentials(user_google_email) + + # Preserve refresh token from existing credentials if new one is missing + preserved_refresh_token = credentials.refresh_token + if not preserved_refresh_token: + if existing_credentials and existing_credentials.refresh_token: + preserved_refresh_token = existing_credentials.refresh_token + logger.info(f"Preserved existing refresh token from credential store for {user_google_email}") + elif existing_session_creds and existing_session_creds.refresh_token: + preserved_refresh_token = existing_session_creds.refresh_token + logger.info(f"Preserved existing refresh token from session store for {user_google_email}") + + # Update credentials object with preserved refresh token if we found one + if preserved_refresh_token and not credentials.refresh_token: + credentials.refresh_token = preserved_refresh_token + + # Save the credentials credential_store.store_credential(user_google_email, credentials) # Always save to OAuth21SessionStore for centralized management - store = get_oauth21_session_store() store.store_session( user_email=user_google_email, access_token=credentials.token, - refresh_token=credentials.refresh_token, + refresh_token=preserved_refresh_token or credentials.refresh_token, token_uri=credentials.token_uri, client_id=credentials.client_id, client_secret=credentials.client_secret, @@ -505,9 +530,8 @@ def handle_auth_callback( issuer="https://accounts.google.com" # Add issuer for Google tokens ) - # If session_id is provided, also save to session cache for compatibility - if session_id: - save_credentials_to_session(session_id, credentials) + # Note: No need to call save_credentials_to_session() here as we've already + # saved to the OAuth21SessionStore above with the correct issuer and mcp_session_id return user_google_email, credentials @@ -570,6 +594,9 @@ def get_credentials( user_email=user_email, access_token=credentials.token, refresh_token=credentials.refresh_token, + token_uri=credentials.token_uri, + client_id=credentials.client_id, + client_secret=credentials.client_secret, scopes=credentials.scopes, expiry=credentials.expiry, mcp_session_id=session_id diff --git a/auth/oauth21_session_store.py b/auth/oauth21_session_store.py index 15f7d8b3..695cd7d3 100644 --- a/auth/oauth21_session_store.py +++ b/auth/oauth21_session_store.py @@ -311,10 +311,19 @@ def store_session( issuer: Token issuer (e.g., "https://accounts.google.com") """ with self._lock: + # Preserve existing refresh token if new one is not provided + # This prevents losing refresh tokens during re-authorization flows + preserved_refresh_token = refresh_token + if not preserved_refresh_token: + existing_session = self._sessions.get(user_email) + if 
existing_session and existing_session.get("refresh_token"): + preserved_refresh_token = existing_session["refresh_token"] + logger.info(f"Preserved existing refresh token for {user_email} in session store") + normalized_expiry = _normalize_expiry_to_naive_utc(expiry) session_info = { "access_token": access_token, - "refresh_token": refresh_token, + "refresh_token": preserved_refresh_token, "token_uri": token_uri, "client_id": client_id, "client_secret": client_secret, diff --git a/core/tool_tiers.yaml b/core/tool_tiers.yaml index ab2dcfda..0ec9e6b5 100644 --- a/core/tool_tiers.yaml +++ b/core/tool_tiers.yaml @@ -7,14 +7,18 @@ gmail: extended: - get_gmail_thread_content + - get_gmail_thread_metadata - modify_gmail_message_labels + - modify_gmail_thread_labels - list_gmail_labels - manage_gmail_label - draft_gmail_message complete: - get_gmail_threads_content_batch + - get_gmail_threads_metadata_batch - batch_modify_gmail_message_labels + - batch_modify_gmail_thread_labels - start_google_auth drive: diff --git a/gmail/gmail_tools.py b/gmail/gmail_tools.py index a1bad350..f0d116d5 100644 --- a/gmail/gmail_tools.py +++ b/gmail/gmail_tools.py @@ -781,23 +781,38 @@ async def draft_gmail_message( return f"Draft created! Draft ID: {draft_id}" -def _format_thread_content(thread_data: dict, thread_id: str) -> str: +def _format_thread_content(thread_data: dict, thread_id: str, format: str = "full") -> str: """ Helper function to format thread content from Gmail API response. Args: thread_data (dict): Thread data from Gmail API thread_id (str): Thread ID for display + format (str): Message format - "minimal", "metadata", or "full" Returns: - str: Formatted thread content + str: Formatted thread content based on the specified format """ messages = thread_data.get("messages", []) if not messages: return f"No messages found in thread '{thread_id}'." - # Extract thread subject from the first message + # Handle minimal format - only return message IDs (threadId is included in the Thread ID header) + if format == "minimal": + content_lines = [ + f"Thread ID: {thread_id}", + f"Messages: {len(messages)}", + "", + ] + for i, message in enumerate(messages, 1): + message_id = message.get("id", "unknown") + content_lines.append(f" {i}. 
Message ID: {message_id}") + return "\n".join(content_lines) + + # For other formats, extract thread subject from the first message first_message = messages[0] + + # For metadata and full formats, extract headers first_headers = { h["name"]: h["value"] for h in first_message.get("payload", {}).get("headers", []) @@ -814,7 +829,9 @@ def _format_thread_content(thread_data: dict, thread_id: str) -> str: # Process each message in the thread for i, message in enumerate(messages, 1): - # Extract headers + message_id = message.get("id", "unknown") + + # Extract headers for metadata and full formats headers = { h["name"]: h["value"] for h in message.get("payload", {}).get("headers", []) } @@ -823,19 +840,11 @@ def _format_thread_content(thread_data: dict, thread_id: str) -> str: date = headers.get("Date", "(unknown date)") subject = headers.get("Subject", "(no subject)") - # Extract both text and HTML bodies - payload = message.get("payload", {}) - bodies = _extract_message_bodies(payload) - text_body = bodies.get("text", "") - html_body = bodies.get("html", "") - - # Format body content with HTML fallback - body_data = _format_body_content(text_body, html_body) - - # Add message to content + # Add message header info content_lines.extend( [ f"=== Message {i} ===", + f"Message ID: {message_id}", f"From: {sender}", f"Date: {date}", ] @@ -845,22 +854,67 @@ def _format_thread_content(thread_data: dict, thread_id: str) -> str: if subject != thread_subject: content_lines.append(f"Subject: {subject}") - content_lines.extend( - [ - "", - body_data, - "", - ] - ) + # Handle metadata format - include labels and snippet + if format == "metadata": + label_ids = message.get("labelIds", []) + snippet = message.get("snippet", "") + size_estimate = message.get("sizeEstimate", 0) + + if label_ids: + content_lines.append(f"Labels: {', '.join(label_ids)}") + if snippet: + content_lines.append(f"Snippet: {snippet}") + if size_estimate: + content_lines.append(f"Size: {size_estimate} bytes") + content_lines.append("") + + # Handle full format - include body + elif format == "full": + # Extract both text and HTML bodies + payload = message.get("payload", {}) + bodies = _extract_message_bodies(payload) + text_body = bodies.get("text", "") + html_body = bodies.get("html", "") + + # Format body content with HTML fallback + body_data = _format_body_content(text_body, html_body) + + content_lines.extend( + [ + "", + body_data, + "", + ] + ) return "\n".join(content_lines) +def _validate_format_parameter(format: str) -> str: + """ + Validate and normalize the format parameter for Gmail API calls. + + Args: + format (str): The format parameter to validate + + Returns: + str: The normalized format value (lowercase) or "full" if invalid + """ + valid_formats = ["minimal", "metadata", "full"] + if not format or not isinstance(format, str): + return "full" + normalized_format = format.lower() + if normalized_format not in valid_formats: + logger.warning(f"Invalid format '{format}' provided, defaulting to 'full'") + normalized_format = "full" + return normalized_format + + @server.tool() @require_google_service("gmail", "gmail_read") @handle_http_errors("get_gmail_thread_content", is_read_only=True, service_type="gmail") async def get_gmail_thread_content( - service, thread_id: str, user_google_email: str + service, thread_id: str, user_google_email: str, format: str = "full" ) -> str: """ Retrieves the complete content of a Gmail conversation thread, including all messages. 
@@ -868,20 +922,26 @@ async def get_gmail_thread_content( Args: thread_id (str): The unique ID of the Gmail thread to retrieve. user_google_email (str): The user's Google email address. Required. + format (str): Message format. Options: "minimal" (returns thread info and message IDs), + "metadata" (returns headers, labels, and snippet without body), + "full" (returns complete message with body - default). Returns: str: The complete thread content with all messages formatted for reading. """ logger.info( - f"[get_gmail_thread_content] Invoked. Thread ID: '{thread_id}', Email: '{user_google_email}'" + f"[get_gmail_thread_content] Invoked. Thread ID: '{thread_id}', Email: '{user_google_email}', Format: '{format}'" ) + # Validate and normalize format parameter + normalized_format = _validate_format_parameter(format) + # Fetch the complete thread with all messages thread_response = await asyncio.to_thread( - service.users().threads().get(userId="me", id=thread_id, format="full").execute + service.users().threads().get(userId="me", id=thread_id, format=normalized_format).execute ) - return _format_thread_content(thread_response, thread_id) + return _format_thread_content(thread_response, thread_id, normalized_format) @server.tool() @@ -891,6 +951,7 @@ async def get_gmail_threads_content_batch( service, thread_ids: List[str], user_google_email: str, + format: str = "full", ) -> str: """ Retrieves the content of multiple Gmail threads in a single batch request. @@ -899,14 +960,20 @@ async def get_gmail_threads_content_batch( Args: thread_ids (List[str]): A list of Gmail thread IDs to retrieve. The function will automatically batch requests in chunks of 25. user_google_email (str): The user's Google email address. Required. + format (str): Message format. Options: "minimal" (returns thread info and message IDs), + "metadata" (returns headers, labels, and snippet without body), + "full" (returns complete message with body - default). Returns: str: A formatted list of thread contents with separators. """ logger.info( - f"[get_gmail_threads_content_batch] Invoked. Thread count: {len(thread_ids)}, Email: '{user_google_email}'" + f"[get_gmail_threads_content_batch] Invoked. 
Thread count: {len(thread_ids)}, Email: '{user_google_email}', Format: '{format}'" ) + # Validate and normalize format parameter + normalized_format = _validate_format_parameter(format) + if not thread_ids: raise ValueError("No thread IDs provided") @@ -926,7 +993,7 @@ def _batch_callback(request_id, response, exception): batch = service.new_batch_http_request(callback=_batch_callback) for tid in chunk_ids: - req = service.users().threads().get(userId="me", id=tid, format="full") + req = service.users().threads().get(userId="me", id=tid, format=normalized_format) batch.add(req, request_id=tid) # Execute batch request @@ -945,7 +1012,7 @@ async def fetch_thread_with_retry(tid: str, max_retries: int = 3): thread = await asyncio.to_thread( service.users() .threads() - .get(userId="me", id=tid, format="full") + .get(userId="me", id=tid, format=normalized_format) .execute ) return tid, thread, None @@ -984,13 +1051,150 @@ async def fetch_thread_with_retry(tid: str, max_retries: int = 3): output_threads.append(f"⚠️ Thread {tid}: No data returned\n") continue - output_threads.append(_format_thread_content(thread, tid)) + output_threads.append(_format_thread_content(thread, tid, normalized_format)) # Combine all threads with separators header = f"Retrieved {len(thread_ids)} threads:" return header + "\n\n" + "\n---\n\n".join(output_threads) +@server.tool() +@require_google_service("gmail", "gmail_read") +@handle_http_errors("get_gmail_thread_metadata", is_read_only=True, service_type="gmail") +async def get_gmail_thread_metadata( + service, thread_id: str, user_google_email: str +) -> str: + """ + Retrieves the metadata of a Gmail conversation thread, including headers, labels, and snippet for all messages. + This is equivalent to calling get_gmail_thread_content with format='metadata'. + + Args: + thread_id (str): The unique ID of the Gmail thread to retrieve. + user_google_email (str): The user's Google email address. Required. + + Returns: + str: The thread metadata with all messages formatted for reading (headers, labels, snippet - no body). + """ + logger.info( + f"[get_gmail_thread_metadata] Invoked. Thread ID: '{thread_id}', Email: '{user_google_email}'" + ) + + # Fetch the thread with metadata format + thread_response = await asyncio.to_thread( + service.users().threads().get(userId="me", id=thread_id, format="metadata").execute + ) + + return _format_thread_content(thread_response, thread_id, "metadata") + + +@server.tool() +@require_google_service("gmail", "gmail_read") +@handle_http_errors("get_gmail_threads_metadata_batch", is_read_only=True, service_type="gmail") +async def get_gmail_threads_metadata_batch( + service, + thread_ids: List[str], + user_google_email: str, +) -> str: + """ + Retrieves the metadata of multiple Gmail threads in a single batch request. + This is equivalent to calling get_gmail_threads_content_batch with format='metadata'. + Supports up to 25 threads per batch to prevent SSL connection exhaustion. + + Args: + thread_ids (List[str]): A list of Gmail thread IDs to retrieve. The function will automatically batch requests in chunks of 25. + user_google_email (str): The user's Google email address. Required. + + Returns: + str: A formatted list of thread metadata (headers, labels, snippet - no body) with separators. + """ + logger.info( + f"[get_gmail_threads_metadata_batch] Invoked. 
Thread count: {len(thread_ids)}, Email: '{user_google_email}'" + ) + + if not thread_ids: + raise ValueError("No thread IDs provided") + + output_threads = [] + + def _batch_callback(request_id, response, exception): + """Callback for batch requests""" + results[request_id] = {"data": response, "error": exception} + + # Process in smaller chunks to prevent SSL connection exhaustion + for chunk_start in range(0, len(thread_ids), GMAIL_BATCH_SIZE): + chunk_ids = thread_ids[chunk_start : chunk_start + GMAIL_BATCH_SIZE] + results: Dict[str, Dict] = {} + + # Try to use batch API + try: + batch = service.new_batch_http_request(callback=_batch_callback) + + for tid in chunk_ids: + req = service.users().threads().get(userId="me", id=tid, format="metadata") + batch.add(req, request_id=tid) + + # Execute batch request + await asyncio.to_thread(batch.execute) + + except Exception as batch_error: + # Fallback to sequential processing instead of parallel to prevent SSL exhaustion + logger.warning( + f"[get_gmail_threads_metadata_batch] Batch API failed, falling back to sequential processing: {batch_error}" + ) + + async def fetch_thread_with_retry(tid: str, max_retries: int = 3): + """Fetch a single thread with exponential backoff retry for SSL errors""" + for attempt in range(max_retries): + try: + thread = await asyncio.to_thread( + service.users() + .threads() + .get(userId="me", id=tid, format="metadata") + .execute + ) + return tid, thread, None + except ssl.SSLError as ssl_error: + if attempt < max_retries - 1: + # Exponential backoff: 1s, 2s, 4s + delay = 2 ** attempt + logger.warning( + f"[get_gmail_threads_metadata_batch] SSL error for thread {tid} on attempt {attempt + 1}: {ssl_error}. Retrying in {delay}s..." + ) + await asyncio.sleep(delay) + else: + logger.error( + f"[get_gmail_threads_metadata_batch] SSL error for thread {tid} on final attempt: {ssl_error}" + ) + return tid, None, ssl_error + except Exception as e: + return tid, None, e + + # Process threads sequentially with small delays to prevent connection exhaustion + for tid in chunk_ids: + tid_result, thread_data, error = await fetch_thread_with_retry(tid) + results[tid_result] = {"data": thread_data, "error": error} + # Brief delay between requests to allow connection cleanup + await asyncio.sleep(GMAIL_REQUEST_DELAY) + + # Process results for this chunk + for tid in chunk_ids: + entry = results.get(tid, {"data": None, "error": "No result"}) + + if entry["error"]: + output_threads.append(f"⚠️ Thread {tid}: {entry['error']}\n") + else: + thread = entry["data"] + if not thread: + output_threads.append(f"⚠️ Thread {tid}: No data returned\n") + continue + + output_threads.append(_format_thread_content(thread, tid, "metadata")) + + # Combine all threads with separators + header = f"Retrieved {len(thread_ids)} threads (metadata only):" + return header + "\n\n" + "\n---\n\n".join(output_threads) + + @server.tool() @handle_http_errors("list_gmail_labels", is_read_only=True, service_type="gmail") @require_google_service("gmail", "gmail_read") @@ -1218,3 +1422,208 @@ async def batch_modify_gmail_message_labels( actions.append(f"Removed labels: {', '.join(remove_label_ids)}") return f"Labels updated for {len(message_ids)} messages: {'; '.join(actions)}" + + +@server.tool() +@handle_http_errors("modify_gmail_thread_labels", service_type="gmail") +@require_google_service("gmail", GMAIL_MODIFY_SCOPE) +async def modify_gmail_thread_labels( + service, + user_google_email: str, + thread_id: str, + add_label_ids: List[str] = Field(default=[], 
description="Label IDs to add to all messages in the thread."), + remove_label_ids: List[str] = Field(default=[], description="Label IDs to remove from all messages in the thread."), +) -> str: + """ + Adds or removes labels from all messages in a Gmail thread. + This applies the same label changes to every message within the specified thread. + To archive all messages in a thread, remove the INBOX label. + To delete all messages in a thread, add the TRASH label. + + Args: + user_google_email (str): The user's Google email address. Required. + thread_id (str): The ID of the thread whose messages should be modified. + add_label_ids (Optional[List[str]]): List of label IDs to add to all messages in the thread. + remove_label_ids (Optional[List[str]]): List of label IDs to remove from all messages in the thread. + + Returns: + str: Confirmation message of the label changes applied to all messages in the thread. + """ + logger.info( + f"[modify_gmail_thread_labels] Invoked. Email: '{user_google_email}', Thread ID: '{thread_id}'" + ) + + if not add_label_ids and not remove_label_ids: + raise Exception( + "At least one of add_label_ids or remove_label_ids must be provided." + ) + + # Fetch the thread to get all message IDs + thread_response = await asyncio.to_thread( + service.users().threads().get(userId="me", id=thread_id, format="minimal").execute + ) + + messages = thread_response.get("messages", []) + if not messages: + return f"No messages found in thread '{thread_id}'." + + message_ids = [msg["id"] for msg in messages] + + # Use the batch modify functionality to update all messages in the thread + body = {"ids": message_ids} + if add_label_ids: + body["addLabelIds"] = add_label_ids + if remove_label_ids: + body["removeLabelIds"] = remove_label_ids + + await asyncio.to_thread( + service.users().messages().batchModify(userId="me", body=body).execute + ) + + actions = [] + if add_label_ids: + actions.append(f"Added labels: {', '.join(add_label_ids)}") + if remove_label_ids: + actions.append(f"Removed labels: {', '.join(remove_label_ids)}") + + return f"Labels updated for all {len(message_ids)} messages in thread {thread_id}: {'; '.join(actions)}" + + +@server.tool() +@handle_http_errors("batch_modify_gmail_thread_labels", service_type="gmail") +@require_google_service("gmail", GMAIL_MODIFY_SCOPE) +async def batch_modify_gmail_thread_labels( + service, + user_google_email: str, + thread_ids: List[str], + add_label_ids: List[str] = Field(default=[], description="Label IDs to add to all messages in the threads."), + remove_label_ids: List[str] = Field(default=[], description="Label IDs to remove from all messages in the threads."), +) -> str: + """ + Adds or removes labels from all messages in multiple Gmail threads. + This applies the same label changes to every message within all specified threads. + + Args: + user_google_email (str): The user's Google email address. Required. + thread_ids (List[str]): A list of thread IDs whose messages should be modified. + add_label_ids (Optional[List[str]]): List of label IDs to add to all messages in the threads. + remove_label_ids (Optional[List[str]]): List of label IDs to remove from all messages in the threads. + + Returns: + str: Confirmation message of the label changes applied to all messages in the threads. + """ + logger.info( + f"[batch_modify_gmail_thread_labels] Invoked. 
Email: '{user_google_email}', Thread IDs: '{thread_ids}'" + ) + + if not add_label_ids and not remove_label_ids: + raise Exception( + "At least one of add_label_ids or remove_label_ids must be provided." + ) + + if not thread_ids: + raise ValueError("No thread IDs provided") + + # Collect all message IDs from all threads + all_message_ids = [] + thread_message_counts = {} + + # Process threads in chunks to prevent SSL connection exhaustion + for chunk_start in range(0, len(thread_ids), GMAIL_BATCH_SIZE): + chunk_ids = thread_ids[chunk_start : chunk_start + GMAIL_BATCH_SIZE] + results: Dict[str, Dict] = {} + + def _batch_callback(request_id, response, exception): + """Callback for batch requests""" + results[request_id] = {"data": response, "error": exception} + + # Try to use batch API to fetch threads + try: + batch = service.new_batch_http_request(callback=_batch_callback) + + for tid in chunk_ids: + req = service.users().threads().get(userId="me", id=tid, format="minimal") + batch.add(req, request_id=tid) + + # Execute batch request + await asyncio.to_thread(batch.execute) + + except Exception as batch_error: + # Fallback to sequential processing + logger.warning( + f"[batch_modify_gmail_thread_labels] Batch API failed, falling back to sequential processing: {batch_error}" + ) + + async def fetch_thread_with_retry(tid: str, max_retries: int = 3): + """Fetch a single thread with exponential backoff retry for SSL errors""" + for attempt in range(max_retries): + try: + thread = await asyncio.to_thread( + service.users() + .threads() + .get(userId="me", id=tid, format="minimal") + .execute + ) + return tid, thread, None + except ssl.SSLError as ssl_error: + if attempt < max_retries - 1: + delay = 2 ** attempt + logger.warning( + f"[batch_modify_gmail_thread_labels] SSL error for thread {tid} on attempt {attempt + 1}: {ssl_error}. Retrying in {delay}s..." + ) + await asyncio.sleep(delay) + else: + logger.error( + f"[batch_modify_gmail_thread_labels] SSL error for thread {tid} on final attempt: {ssl_error}" + ) + return tid, None, ssl_error + except Exception as e: + return tid, None, e + + # Process threads sequentially with small delays to prevent connection exhaustion + for tid in chunk_ids: + tid_result, thread_data, error = await fetch_thread_with_retry(tid) + results[tid_result] = {"data": thread_data, "error": error} + # Brief delay between requests to allow connection cleanup + await asyncio.sleep(GMAIL_REQUEST_DELAY) + + # Process results for this chunk + for tid in chunk_ids: + entry = results.get(tid, {"data": None, "error": "No result"}) + + if entry["error"]: + logger.warning(f"[batch_modify_gmail_thread_labels] Error fetching thread {tid}: {entry['error']}") + continue + + thread = entry["data"] + if not thread: + logger.warning(f"[batch_modify_gmail_thread_labels] No data returned for thread {tid}") + continue + + messages = thread.get("messages", []) + if messages: + message_ids = [msg["id"] for msg in messages] + all_message_ids.extend(message_ids) + thread_message_counts[tid] = len(message_ids) + + if not all_message_ids: + return "No messages found in the specified threads." 
+ + # Use the batch modify functionality to update all messages + body = {"ids": all_message_ids} + if add_label_ids: + body["addLabelIds"] = add_label_ids + if remove_label_ids: + body["removeLabelIds"] = remove_label_ids + + await asyncio.to_thread( + service.users().messages().batchModify(userId="me", body=body).execute + ) + + actions = [] + if add_label_ids: + actions.append(f"Added labels: {', '.join(add_label_ids)}") + if remove_label_ids: + actions.append(f"Removed labels: {', '.join(remove_label_ids)}") + + return f"Labels updated for {len(all_message_ids)} messages across {len(thread_message_counts)} threads: {'; '.join(actions)}"
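
For reference, the new `format` parameter is passed straight through to the `format` argument of the Gmail `threads.get` endpoint. Below is a minimal standalone sketch of the underlying calls, assuming an authenticated `service` handle built with `googleapiclient.discovery.build` and a placeholder thread ID; in the server itself these calls are made by the MCP tools above, which receive `service` via `@require_google_service`.

```python
from googleapiclient.discovery import build

# Assumes `creds` is an authorized google.oauth2.credentials.Credentials object.
service = build("gmail", "v1", credentials=creds)
thread_id = "THREAD_ID"  # placeholder

# "metadata": per-message headers, labelIds, snippet, and sizeEstimate -- no body.
meta = service.users().threads().get(
    userId="me", id=thread_id, format="metadata"
).execute()

# "full" (default): complete payload, including base64url-encoded bodies.
full = service.users().threads().get(
    userId="me", id=thread_id, format="full"
).execute()

# "minimal": lightweight per-message entries (IDs only, per the README table above).
minimal = service.users().threads().get(
    userId="me", id=thread_id, format="minimal"
).execute()
```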
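
The thread label tools use a two-step flow: fetch the thread with `format="minimal"` to enumerate its message IDs, then apply the label change with a single `messages.batchModify` call. A hedged sketch of that flow (same assumed `service` and placeholder `thread_id` as above), archiving a thread by removing the `INBOX` label:

```python
# Step 1: minimal thread fetch -- just enough to collect message IDs.
thread = service.users().threads().get(
    userId="me", id=thread_id, format="minimal"
).execute()
message_ids = [msg["id"] for msg in thread.get("messages", [])]

# Step 2: one batchModify call applies the label change to every message in the thread.
if message_ids:
    service.users().messages().batchModify(
        userId="me",
        body={"ids": message_ids, "removeLabelIds": ["INBOX"]},
    ).execute()
```

`batch_modify_gmail_thread_labels` follows the same pattern, but first gathers message IDs from all requested threads (fetched in chunks of `GMAIL_BATCH_SIZE`) before issuing the single `batchModify` call.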
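
Finally, the auth changes apply one refresh-token preservation rule in three places (the credential store, the OAuth callback, and the OAuth 2.1 session store). Distilled as a sketch, with `merge_refresh_token` being a hypothetical helper name used only for illustration:

```python
from typing import Optional

from google.oauth2.credentials import Credentials


def merge_refresh_token(
    new_creds: Credentials, existing_creds: Optional[Credentials]
) -> Optional[str]:
    """Keep the previously stored refresh token when a re-authorization returns none.

    Google often omits refresh_token on repeat authorizations (especially now that
    the flow no longer forces prompt="consent"), so persisting the new credentials
    verbatim would leave the user with an access token that cannot be refreshed.
    """
    if new_creds.refresh_token:
        return new_creds.refresh_token
    if existing_creds is not None and existing_creds.refresh_token:
        return existing_creds.refresh_token
    return None
```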