diff --git a/README.md b/README.md
index 2adcbba1..d46dd70d 100644
--- a/README.md
+++ b/README.md
@@ -716,14 +716,34 @@ cp .env.oauth21 .env
| `get_gmail_messages_content_batch` | **Core** | Batch retrieve message content |
| `send_gmail_message` | **Core** | Send emails |
| `get_gmail_thread_content` | Extended | Get full thread content |
+| `get_gmail_thread_metadata` | Extended | Get thread metadata (headers, labels, snippet) |
| `modify_gmail_message_labels` | Extended | Modify message labels |
+| `modify_gmail_thread_labels` | Extended | Modify labels for all messages in a thread |
| `list_gmail_labels` | Extended | List available labels |
| `manage_gmail_label` | Extended | Create/update/delete labels |
| `draft_gmail_message` | Extended | Create drafts |
| `get_gmail_threads_content_batch` | Complete | Batch retrieve thread content |
+| `get_gmail_threads_metadata_batch` | Complete | Batch retrieve thread metadata |
| `batch_modify_gmail_message_labels` | Complete | Batch modify labels |
+| `batch_modify_gmail_thread_labels` | Complete | Batch modify labels for all messages in threads |
| `start_google_auth` | Complete | Initialize authentication |
+
+**Gmail Message Format Options:**
+
+Both `get_gmail_thread_content` and `get_gmail_threads_content_batch` support an optional `format` parameter (defaults to `"full"`); a usage example follows at the end of this section:
+
+| Format | Description |
+|--------|-------------|
+| `"minimal"` | Returns only the `id` and `threadId` of each message. |
+| `"metadata"` | Returns message metadata such as headers, labels, and snippet, but *not* the full body. |
+| `"full"` | Returns the full email message data, including headers and body (Base64 encoded). This is the default and most commonly used for reading messages. |
+
+**Convenience Metadata Functions:**
+
+For quick access to metadata only, use the dedicated metadata functions:
+- `get_gmail_thread_metadata` - Equivalent to `get_gmail_thread_content` with `format="metadata"`
+- `get_gmail_threads_metadata_batch` - Equivalent to `get_gmail_threads_content_batch` with `format="metadata"`
+
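+For example, to fetch only headers, labels, and snippets for a thread, pass `format="metadata"`.
+The call below is an illustrative Python-style sketch; the exact invocation syntax depends on your MCP client, and the thread ID and email address are placeholders:
+
+```python
+get_gmail_thread_content(
+    thread_id="18c1a2b3d4e5f6a7",       # placeholder thread ID
+    user_google_email="user@example.com",
+    format="metadata",                   # headers, labels, snippet; no body
+)
+```
+
+Calling `get_gmail_thread_metadata` with the same `thread_id` and `user_google_email` returns the same result.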
diff --git a/auth/credential_store.py b/auth/credential_store.py
index 24dda70b..385c77ba 100644
--- a/auth/credential_store.py
+++ b/auth/credential_store.py
@@ -151,9 +151,21 @@ def store_credential(self, user_email: str, credentials: Credentials) -> bool:
"""Store credentials to local JSON file."""
creds_path = self._get_credential_path(user_email)
+ # Preserve existing refresh token if new credentials don't have one
+ # This prevents losing refresh tokens during re-authorization flows
+ refresh_token_to_store = credentials.refresh_token
+ if not refresh_token_to_store:
+ try:
+ existing_creds = self.get_credential(user_email)
+ if existing_creds and existing_creds.refresh_token:
+ refresh_token_to_store = existing_creds.refresh_token
+ logger.info(f"Preserved existing refresh token for {user_email} in credential store")
+ except Exception as e:
+ logger.debug(f"Could not check existing credentials to preserve refresh token: {e}")
+
creds_data = {
"token": credentials.token,
- "refresh_token": credentials.refresh_token,
+ "refresh_token": refresh_token_to_store,
"token_uri": credentials.token_uri,
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
diff --git a/auth/google_auth.py b/auth/google_auth.py
index aef73a4d..22549170 100644
--- a/auth/google_auth.py
+++ b/auth/google_auth.py
@@ -353,7 +353,7 @@ async def start_auth_flow(
state=oauth_state,
)
- auth_url, _ = flow.authorization_url(access_type="offline", prompt="consent")
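+ # With access_type="offline" and no prompt="consent", Google may omit the refresh_token on
+ # repeat authorizations; handle_auth_callback compensates by preserving any previously
+ # stored refresh token.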
+ auth_url, _ = flow.authorization_url(access_type="offline")
session_id = None
try:
@@ -486,16 +486,41 @@ def handle_auth_callback(
user_google_email = user_info["email"]
logger.info(f"Identified user_google_email: {user_google_email}")
- # Save the credentials
+ # Preserve existing refresh token if new credentials don't have one
+ # Google often doesn't return refresh_token on re-authorization if one already exists
+ existing_credentials = None
credential_store = get_credential_store()
+ try:
+ existing_credentials = credential_store.get_credential(user_google_email)
+ except Exception as e:
+ logger.debug(f"Could not load existing credentials to preserve refresh token: {e}")
+
+ # Also check OAuth21SessionStore for existing refresh token
+ store = get_oauth21_session_store()
+ existing_session_creds = store.get_credentials(user_google_email)
+
+ # Preserve refresh token from existing credentials if new one is missing
+ preserved_refresh_token = credentials.refresh_token
+ if not preserved_refresh_token:
+ if existing_credentials and existing_credentials.refresh_token:
+ preserved_refresh_token = existing_credentials.refresh_token
+ logger.info(f"Preserved existing refresh token from credential store for {user_google_email}")
+ elif existing_session_creds and existing_session_creds.refresh_token:
+ preserved_refresh_token = existing_session_creds.refresh_token
+ logger.info(f"Preserved existing refresh token from session store for {user_google_email}")
+
+ # Update credentials object with preserved refresh token if we found one
+ if preserved_refresh_token and not credentials.refresh_token:
+ credentials.refresh_token = preserved_refresh_token
+
+ # Save the credentials
credential_store.store_credential(user_google_email, credentials)
# Always save to OAuth21SessionStore for centralized management
- store = get_oauth21_session_store()
store.store_session(
user_email=user_google_email,
access_token=credentials.token,
- refresh_token=credentials.refresh_token,
+ refresh_token=preserved_refresh_token or credentials.refresh_token,
token_uri=credentials.token_uri,
client_id=credentials.client_id,
client_secret=credentials.client_secret,
@@ -505,9 +530,8 @@ def handle_auth_callback(
issuer="https://accounts.google.com" # Add issuer for Google tokens
)
- # If session_id is provided, also save to session cache for compatibility
- if session_id:
- save_credentials_to_session(session_id, credentials)
+ # Note: No need to call save_credentials_to_session() here as we've already
+ # saved to the OAuth21SessionStore above with the correct issuer and mcp_session_id
return user_google_email, credentials
@@ -570,6 +594,9 @@ def get_credentials(
user_email=user_email,
access_token=credentials.token,
refresh_token=credentials.refresh_token,
+ token_uri=credentials.token_uri,
+ client_id=credentials.client_id,
+ client_secret=credentials.client_secret,
scopes=credentials.scopes,
expiry=credentials.expiry,
mcp_session_id=session_id
diff --git a/auth/oauth21_session_store.py b/auth/oauth21_session_store.py
index 15f7d8b3..695cd7d3 100644
--- a/auth/oauth21_session_store.py
+++ b/auth/oauth21_session_store.py
@@ -311,10 +311,19 @@ def store_session(
issuer: Token issuer (e.g., "https://accounts.google.com")
"""
with self._lock:
+ # Preserve existing refresh token if new one is not provided
+ # This prevents losing refresh tokens during re-authorization flows
+ preserved_refresh_token = refresh_token
+ if not preserved_refresh_token:
+ existing_session = self._sessions.get(user_email)
+ if existing_session and existing_session.get("refresh_token"):
+ preserved_refresh_token = existing_session["refresh_token"]
+ logger.info(f"Preserved existing refresh token for {user_email} in session store")
+
normalized_expiry = _normalize_expiry_to_naive_utc(expiry)
session_info = {
"access_token": access_token,
- "refresh_token": refresh_token,
+ "refresh_token": preserved_refresh_token,
"token_uri": token_uri,
"client_id": client_id,
"client_secret": client_secret,
diff --git a/core/tool_tiers.yaml b/core/tool_tiers.yaml
index ab2dcfda..0ec9e6b5 100644
--- a/core/tool_tiers.yaml
+++ b/core/tool_tiers.yaml
@@ -7,14 +7,18 @@ gmail:
extended:
- get_gmail_thread_content
+ - get_gmail_thread_metadata
- modify_gmail_message_labels
+ - modify_gmail_thread_labels
- list_gmail_labels
- manage_gmail_label
- draft_gmail_message
complete:
- get_gmail_threads_content_batch
+ - get_gmail_threads_metadata_batch
- batch_modify_gmail_message_labels
+ - batch_modify_gmail_thread_labels
- start_google_auth
drive:
diff --git a/gmail/gmail_tools.py b/gmail/gmail_tools.py
index a1bad350..f0d116d5 100644
--- a/gmail/gmail_tools.py
+++ b/gmail/gmail_tools.py
@@ -781,23 +781,38 @@ async def draft_gmail_message(
return f"Draft created! Draft ID: {draft_id}"
-def _format_thread_content(thread_data: dict, thread_id: str) -> str:
+def _format_thread_content(thread_data: dict, thread_id: str, format: str = "full") -> str:
"""
Helper function to format thread content from Gmail API response.
Args:
thread_data (dict): Thread data from Gmail API
thread_id (str): Thread ID for display
+ format (str): Message format - "minimal", "metadata", or "full"
Returns:
- str: Formatted thread content
+ str: Formatted thread content based on the specified format
"""
messages = thread_data.get("messages", [])
if not messages:
return f"No messages found in thread '{thread_id}'."
- # Extract thread subject from the first message
+ # Handle minimal format - only return message IDs (threadId is included in the Thread ID header)
+ if format == "minimal":
+ content_lines = [
+ f"Thread ID: {thread_id}",
+ f"Messages: {len(messages)}",
+ "",
+ ]
+ for i, message in enumerate(messages, 1):
+ message_id = message.get("id", "unknown")
+ content_lines.append(f" {i}. Message ID: {message_id}")
+ return "\n".join(content_lines)
+
+ # For other formats, extract thread subject from the first message
first_message = messages[0]
+
+ # For metadata and full formats, extract headers
first_headers = {
h["name"]: h["value"]
for h in first_message.get("payload", {}).get("headers", [])
@@ -814,7 +829,9 @@ def _format_thread_content(thread_data: dict, thread_id: str) -> str:
# Process each message in the thread
for i, message in enumerate(messages, 1):
- # Extract headers
+ message_id = message.get("id", "unknown")
+
+ # Extract headers for metadata and full formats
headers = {
h["name"]: h["value"] for h in message.get("payload", {}).get("headers", [])
}
@@ -823,19 +840,11 @@ def _format_thread_content(thread_data: dict, thread_id: str) -> str:
date = headers.get("Date", "(unknown date)")
subject = headers.get("Subject", "(no subject)")
- # Extract both text and HTML bodies
- payload = message.get("payload", {})
- bodies = _extract_message_bodies(payload)
- text_body = bodies.get("text", "")
- html_body = bodies.get("html", "")
-
- # Format body content with HTML fallback
- body_data = _format_body_content(text_body, html_body)
-
- # Add message to content
+ # Add message header info
content_lines.extend(
[
f"=== Message {i} ===",
+ f"Message ID: {message_id}",
f"From: {sender}",
f"Date: {date}",
]
@@ -845,22 +854,67 @@ def _format_thread_content(thread_data: dict, thread_id: str) -> str:
if subject != thread_subject:
content_lines.append(f"Subject: {subject}")
- content_lines.extend(
- [
- "",
- body_data,
- "",
- ]
- )
+ # Handle metadata format - include labels and snippet
+ if format == "metadata":
+ label_ids = message.get("labelIds", [])
+ snippet = message.get("snippet", "")
+ size_estimate = message.get("sizeEstimate", 0)
+
+ if label_ids:
+ content_lines.append(f"Labels: {', '.join(label_ids)}")
+ if snippet:
+ content_lines.append(f"Snippet: {snippet}")
+ if size_estimate:
+ content_lines.append(f"Size: {size_estimate} bytes")
+ content_lines.append("")
+
+ # Handle full format - include body
+ elif format == "full":
+ # Extract both text and HTML bodies
+ payload = message.get("payload", {})
+ bodies = _extract_message_bodies(payload)
+ text_body = bodies.get("text", "")
+ html_body = bodies.get("html", "")
+
+ # Format body content with HTML fallback
+ body_data = _format_body_content(text_body, html_body)
+
+ content_lines.extend(
+ [
+ "",
+ body_data,
+ "",
+ ]
+ )
return "\n".join(content_lines)
+def _validate_format_parameter(format: str) -> str:
+ """
+ Validate and normalize the format parameter for Gmail API calls.
+
+ Args:
+ format (str): The format parameter to validate
+
+ Returns:
+ str: The normalized format value (lowercase) or "full" if invalid
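+
+ Examples (illustrative):
+ _validate_format_parameter("Metadata") -> "metadata"
+ _validate_format_parameter("invalid") -> "full" (a warning is logged)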
+ """
+ valid_formats = ["minimal", "metadata", "full"]
+ if not format or not isinstance(format, str):
+ return "full"
+ normalized_format = format.lower()
+ if normalized_format not in valid_formats:
+ logger.warning(f"Invalid format '{format}' provided, defaulting to 'full'")
+ normalized_format = "full"
+ return normalized_format
+
+
@server.tool()
@require_google_service("gmail", "gmail_read")
@handle_http_errors("get_gmail_thread_content", is_read_only=True, service_type="gmail")
async def get_gmail_thread_content(
- service, thread_id: str, user_google_email: str
+ service, thread_id: str, user_google_email: str, format: str = "full"
) -> str:
"""
Retrieves the complete content of a Gmail conversation thread, including all messages.
@@ -868,20 +922,26 @@ async def get_gmail_thread_content(
Args:
thread_id (str): The unique ID of the Gmail thread to retrieve.
user_google_email (str): The user's Google email address. Required.
+ format (str): Message format. Options: "minimal" (returns thread info and message IDs),
+ "metadata" (returns headers, labels, and snippet without body),
+ "full" (returns complete message with body - default).
Returns:
str: The complete thread content with all messages formatted for reading.
"""
logger.info(
- f"[get_gmail_thread_content] Invoked. Thread ID: '{thread_id}', Email: '{user_google_email}'"
+ f"[get_gmail_thread_content] Invoked. Thread ID: '{thread_id}', Email: '{user_google_email}', Format: '{format}'"
)
+ # Validate and normalize format parameter
+ normalized_format = _validate_format_parameter(format)
+
# Fetch the complete thread with all messages
thread_response = await asyncio.to_thread(
- service.users().threads().get(userId="me", id=thread_id, format="full").execute
+ service.users().threads().get(userId="me", id=thread_id, format=normalized_format).execute
)
- return _format_thread_content(thread_response, thread_id)
+ return _format_thread_content(thread_response, thread_id, normalized_format)
@server.tool()
@@ -891,6 +951,7 @@ async def get_gmail_threads_content_batch(
service,
thread_ids: List[str],
user_google_email: str,
+ format: str = "full",
) -> str:
"""
Retrieves the content of multiple Gmail threads in a single batch request.
@@ -899,14 +960,20 @@ async def get_gmail_threads_content_batch(
Args:
thread_ids (List[str]): A list of Gmail thread IDs to retrieve. The function will automatically batch requests in chunks of 25.
user_google_email (str): The user's Google email address. Required.
+ format (str): Message format. Options: "minimal" (returns thread info and message IDs),
+ "metadata" (returns headers, labels, and snippet without body),
+ "full" (returns complete message with body - default).
Returns:
str: A formatted list of thread contents with separators.
"""
logger.info(
- f"[get_gmail_threads_content_batch] Invoked. Thread count: {len(thread_ids)}, Email: '{user_google_email}'"
+ f"[get_gmail_threads_content_batch] Invoked. Thread count: {len(thread_ids)}, Email: '{user_google_email}', Format: '{format}'"
)
+ # Validate and normalize format parameter
+ normalized_format = _validate_format_parameter(format)
+
if not thread_ids:
raise ValueError("No thread IDs provided")
@@ -926,7 +993,7 @@ def _batch_callback(request_id, response, exception):
batch = service.new_batch_http_request(callback=_batch_callback)
for tid in chunk_ids:
- req = service.users().threads().get(userId="me", id=tid, format="full")
+ req = service.users().threads().get(userId="me", id=tid, format=normalized_format)
batch.add(req, request_id=tid)
# Execute batch request
@@ -945,7 +1012,7 @@ async def fetch_thread_with_retry(tid: str, max_retries: int = 3):
thread = await asyncio.to_thread(
service.users()
.threads()
- .get(userId="me", id=tid, format="full")
+ .get(userId="me", id=tid, format=normalized_format)
.execute
)
return tid, thread, None
@@ -984,13 +1051,150 @@ async def fetch_thread_with_retry(tid: str, max_retries: int = 3):
output_threads.append(f"⚠️ Thread {tid}: No data returned\n")
continue
- output_threads.append(_format_thread_content(thread, tid))
+ output_threads.append(_format_thread_content(thread, tid, normalized_format))
# Combine all threads with separators
header = f"Retrieved {len(thread_ids)} threads:"
return header + "\n\n" + "\n---\n\n".join(output_threads)
+@server.tool()
+@require_google_service("gmail", "gmail_read")
+@handle_http_errors("get_gmail_thread_metadata", is_read_only=True, service_type="gmail")
+async def get_gmail_thread_metadata(
+ service, thread_id: str, user_google_email: str
+) -> str:
+ """
+ Retrieves the metadata of a Gmail conversation thread, including headers, labels, and snippets for all messages.
+ This is equivalent to calling get_gmail_thread_content with format='metadata'.
+
+ Args:
+ thread_id (str): The unique ID of the Gmail thread to retrieve.
+ user_google_email (str): The user's Google email address. Required.
+
+ Returns:
+ str: The thread metadata with all messages formatted for reading (headers, labels, snippet - no body).
+ """
+ logger.info(
+ f"[get_gmail_thread_metadata] Invoked. Thread ID: '{thread_id}', Email: '{user_google_email}'"
+ )
+
+ # Fetch the thread with metadata format
+ thread_response = await asyncio.to_thread(
+ service.users().threads().get(userId="me", id=thread_id, format="metadata").execute
+ )
+
+ return _format_thread_content(thread_response, thread_id, "metadata")
+
+
+@server.tool()
+@require_google_service("gmail", "gmail_read")
+@handle_http_errors("get_gmail_threads_metadata_batch", is_read_only=True, service_type="gmail")
+async def get_gmail_threads_metadata_batch(
+ service,
+ thread_ids: List[str],
+ user_google_email: str,
+) -> str:
+ """
+ Retrieves the metadata of multiple Gmail threads in a single batch request.
+ This is equivalent to calling get_gmail_threads_content_batch with format='metadata'.
+ Supports up to 25 threads per batch to prevent SSL connection exhaustion.
+
+ Args:
+ thread_ids (List[str]): A list of Gmail thread IDs to retrieve. The function will automatically batch requests in chunks of 25.
+ user_google_email (str): The user's Google email address. Required.
+
+ Returns:
+ str: A formatted list of thread metadata (headers, labels, snippet - no body) with separators.
+ """
+ logger.info(
+ f"[get_gmail_threads_metadata_batch] Invoked. Thread count: {len(thread_ids)}, Email: '{user_google_email}'"
+ )
+
+ if not thread_ids:
+ raise ValueError("No thread IDs provided")
+
+ output_threads = []
+
+ def _batch_callback(request_id, response, exception):
+ """Callback for batch requests"""
+ results[request_id] = {"data": response, "error": exception}
+
+ # Process in smaller chunks to prevent SSL connection exhaustion
+ for chunk_start in range(0, len(thread_ids), GMAIL_BATCH_SIZE):
+ chunk_ids = thread_ids[chunk_start : chunk_start + GMAIL_BATCH_SIZE]
+ results: Dict[str, Dict] = {}
+
+ # Try to use batch API
+ try:
+ batch = service.new_batch_http_request(callback=_batch_callback)
+
+ for tid in chunk_ids:
+ req = service.users().threads().get(userId="me", id=tid, format="metadata")
+ batch.add(req, request_id=tid)
+
+ # Execute batch request
+ await asyncio.to_thread(batch.execute)
+
+ except Exception as batch_error:
+ # Fallback to sequential processing instead of parallel to prevent SSL exhaustion
+ logger.warning(
+ f"[get_gmail_threads_metadata_batch] Batch API failed, falling back to sequential processing: {batch_error}"
+ )
+
+ async def fetch_thread_with_retry(tid: str, max_retries: int = 3):
+ """Fetch a single thread with exponential backoff retry for SSL errors"""
+ for attempt in range(max_retries):
+ try:
+ thread = await asyncio.to_thread(
+ service.users()
+ .threads()
+ .get(userId="me", id=tid, format="metadata")
+ .execute
+ )
+ return tid, thread, None
+ except ssl.SSLError as ssl_error:
+ if attempt < max_retries - 1:
+ # Exponential backoff: 1s, 2s, 4s
+ delay = 2 ** attempt
+ logger.warning(
+ f"[get_gmail_threads_metadata_batch] SSL error for thread {tid} on attempt {attempt + 1}: {ssl_error}. Retrying in {delay}s..."
+ )
+ await asyncio.sleep(delay)
+ else:
+ logger.error(
+ f"[get_gmail_threads_metadata_batch] SSL error for thread {tid} on final attempt: {ssl_error}"
+ )
+ return tid, None, ssl_error
+ except Exception as e:
+ return tid, None, e
+
+ # Process threads sequentially with small delays to prevent connection exhaustion
+ for tid in chunk_ids:
+ tid_result, thread_data, error = await fetch_thread_with_retry(tid)
+ results[tid_result] = {"data": thread_data, "error": error}
+ # Brief delay between requests to allow connection cleanup
+ await asyncio.sleep(GMAIL_REQUEST_DELAY)
+
+ # Process results for this chunk
+ for tid in chunk_ids:
+ entry = results.get(tid, {"data": None, "error": "No result"})
+
+ if entry["error"]:
+ output_threads.append(f"⚠️ Thread {tid}: {entry['error']}\n")
+ else:
+ thread = entry["data"]
+ if not thread:
+ output_threads.append(f"⚠️ Thread {tid}: No data returned\n")
+ continue
+
+ output_threads.append(_format_thread_content(thread, tid, "metadata"))
+
+ # Combine all threads with separators
+ header = f"Retrieved {len(thread_ids)} threads (metadata only):"
+ return header + "\n\n" + "\n---\n\n".join(output_threads)
+
+
@server.tool()
@handle_http_errors("list_gmail_labels", is_read_only=True, service_type="gmail")
@require_google_service("gmail", "gmail_read")
@@ -1218,3 +1422,208 @@ async def batch_modify_gmail_message_labels(
actions.append(f"Removed labels: {', '.join(remove_label_ids)}")
return f"Labels updated for {len(message_ids)} messages: {'; '.join(actions)}"
+
+
+@server.tool()
+@handle_http_errors("modify_gmail_thread_labels", service_type="gmail")
+@require_google_service("gmail", GMAIL_MODIFY_SCOPE)
+async def modify_gmail_thread_labels(
+ service,
+ user_google_email: str,
+ thread_id: str,
+ add_label_ids: List[str] = Field(default=[], description="Label IDs to add to all messages in the thread."),
+ remove_label_ids: List[str] = Field(default=[], description="Label IDs to remove from all messages in the thread."),
+) -> str:
+ """
+ Adds or removes labels from all messages in a Gmail thread.
+ This applies the same label changes to every message within the specified thread.
+ To archive all messages in a thread, remove the INBOX label.
+ To move all messages in a thread to the trash, add the TRASH label.
+
+ Args:
+ user_google_email (str): The user's Google email address. Required.
+ thread_id (str): The ID of the thread whose messages should be modified.
+ add_label_ids (Optional[List[str]]): List of label IDs to add to all messages in the thread.
+ remove_label_ids (Optional[List[str]]): List of label IDs to remove from all messages in the thread.
+
+ Returns:
+ str: Confirmation message of the label changes applied to all messages in the thread.
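+
+ Example (label IDs as described above):
+ Archive the whole thread: remove_label_ids=["INBOX"]
+ Move the whole thread to the trash: add_label_ids=["TRASH"]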
+ """
+ logger.info(
+ f"[modify_gmail_thread_labels] Invoked. Email: '{user_google_email}', Thread ID: '{thread_id}'"
+ )
+
+ if not add_label_ids and not remove_label_ids:
+ raise Exception(
+ "At least one of add_label_ids or remove_label_ids must be provided."
+ )
+
+ # Fetch the thread to get all message IDs
+ thread_response = await asyncio.to_thread(
+ service.users().threads().get(userId="me", id=thread_id, format="minimal").execute
+ )
+
+ messages = thread_response.get("messages", [])
+ if not messages:
+ return f"No messages found in thread '{thread_id}'."
+
+ message_ids = [msg["id"] for msg in messages]
+
+ # Use the batch modify functionality to update all messages in the thread
+ body = {"ids": message_ids}
+ if add_label_ids:
+ body["addLabelIds"] = add_label_ids
+ if remove_label_ids:
+ body["removeLabelIds"] = remove_label_ids
+
+ await asyncio.to_thread(
+ service.users().messages().batchModify(userId="me", body=body).execute
+ )
+
+ actions = []
+ if add_label_ids:
+ actions.append(f"Added labels: {', '.join(add_label_ids)}")
+ if remove_label_ids:
+ actions.append(f"Removed labels: {', '.join(remove_label_ids)}")
+
+ return f"Labels updated for all {len(message_ids)} messages in thread {thread_id}: {'; '.join(actions)}"
+
+
+@server.tool()
+@handle_http_errors("batch_modify_gmail_thread_labels", service_type="gmail")
+@require_google_service("gmail", GMAIL_MODIFY_SCOPE)
+async def batch_modify_gmail_thread_labels(
+ service,
+ user_google_email: str,
+ thread_ids: List[str],
+ add_label_ids: List[str] = Field(default=[], description="Label IDs to add to all messages in the threads."),
+ remove_label_ids: List[str] = Field(default=[], description="Label IDs to remove from all messages in the threads."),
+) -> str:
+ """
+ Adds or removes labels from all messages in multiple Gmail threads.
+ This applies the same label changes to every message within all specified threads.
+
+ Args:
+ user_google_email (str): The user's Google email address. Required.
+ thread_ids (List[str]): A list of thread IDs whose messages should be modified.
+ add_label_ids (Optional[List[str]]): List of label IDs to add to all messages in the threads.
+ remove_label_ids (Optional[List[str]]): List of label IDs to remove from all messages in the threads.
+
+ Returns:
+ str: Confirmation message of the label changes applied to all messages in the threads.
+ """
+ logger.info(
+ f"[batch_modify_gmail_thread_labels] Invoked. Email: '{user_google_email}', Thread IDs: '{thread_ids}'"
+ )
+
+ if not add_label_ids and not remove_label_ids:
+ raise Exception(
+ "At least one of add_label_ids or remove_label_ids must be provided."
+ )
+
+ if not thread_ids:
+ raise ValueError("No thread IDs provided")
+
+ # Collect all message IDs from all threads
+ all_message_ids = []
+ thread_message_counts = {}
+
+ # Process threads in chunks to prevent SSL connection exhaustion
+ for chunk_start in range(0, len(thread_ids), GMAIL_BATCH_SIZE):
+ chunk_ids = thread_ids[chunk_start : chunk_start + GMAIL_BATCH_SIZE]
+ results: Dict[str, Dict] = {}
+
+ def _batch_callback(request_id, response, exception):
+ """Callback for batch requests"""
+ results[request_id] = {"data": response, "error": exception}
+
+ # Try to use batch API to fetch threads
+ try:
+ batch = service.new_batch_http_request(callback=_batch_callback)
+
+ for tid in chunk_ids:
+ req = service.users().threads().get(userId="me", id=tid, format="minimal")
+ batch.add(req, request_id=tid)
+
+ # Execute batch request
+ await asyncio.to_thread(batch.execute)
+
+ except Exception as batch_error:
+ # Fallback to sequential processing
+ logger.warning(
+ f"[batch_modify_gmail_thread_labels] Batch API failed, falling back to sequential processing: {batch_error}"
+ )
+
+ async def fetch_thread_with_retry(tid: str, max_retries: int = 3):
+ """Fetch a single thread with exponential backoff retry for SSL errors"""
+ for attempt in range(max_retries):
+ try:
+ thread = await asyncio.to_thread(
+ service.users()
+ .threads()
+ .get(userId="me", id=tid, format="minimal")
+ .execute
+ )
+ return tid, thread, None
+ except ssl.SSLError as ssl_error:
+ if attempt < max_retries - 1:
+ delay = 2 ** attempt
+ logger.warning(
+ f"[batch_modify_gmail_thread_labels] SSL error for thread {tid} on attempt {attempt + 1}: {ssl_error}. Retrying in {delay}s..."
+ )
+ await asyncio.sleep(delay)
+ else:
+ logger.error(
+ f"[batch_modify_gmail_thread_labels] SSL error for thread {tid} on final attempt: {ssl_error}"
+ )
+ return tid, None, ssl_error
+ except Exception as e:
+ return tid, None, e
+
+ # Process threads sequentially with small delays to prevent connection exhaustion
+ for tid in chunk_ids:
+ tid_result, thread_data, error = await fetch_thread_with_retry(tid)
+ results[tid_result] = {"data": thread_data, "error": error}
+ # Brief delay between requests to allow connection cleanup
+ await asyncio.sleep(GMAIL_REQUEST_DELAY)
+
+ # Process results for this chunk
+ for tid in chunk_ids:
+ entry = results.get(tid, {"data": None, "error": "No result"})
+
+ if entry["error"]:
+ logger.warning(f"[batch_modify_gmail_thread_labels] Error fetching thread {tid}: {entry['error']}")
+ continue
+
+ thread = entry["data"]
+ if not thread:
+ logger.warning(f"[batch_modify_gmail_thread_labels] No data returned for thread {tid}")
+ continue
+
+ messages = thread.get("messages", [])
+ if messages:
+ message_ids = [msg["id"] for msg in messages]
+ all_message_ids.extend(message_ids)
+ thread_message_counts[tid] = len(message_ids)
+
+ if not all_message_ids:
+ return "No messages found in the specified threads."
+
+ # Use the batch modify functionality to update all messages
+ body = {"ids": all_message_ids}
+ if add_label_ids:
+ body["addLabelIds"] = add_label_ids
+ if remove_label_ids:
+ body["removeLabelIds"] = remove_label_ids
+
+ await asyncio.to_thread(
+ service.users().messages().batchModify(userId="me", body=body).execute
+ )
+
+ actions = []
+ if add_label_ids:
+ actions.append(f"Added labels: {', '.join(add_label_ids)}")
+ if remove_label_ids:
+ actions.append(f"Removed labels: {', '.join(remove_label_ids)}")
+
+ return f"Labels updated for {len(all_message_ids)} messages across {len(thread_message_counts)} threads: {'; '.join(actions)}"
|