Skip to content

Commit 89be338

Browse files
committed
set context var on tool call and use to set org id
1 parent 6403597 commit 89be338

File tree

5 files changed

+104
-23
lines changed

5 files changed

+104
-23
lines changed

src/api/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from src.config.config import get_session_request, get_settings
99
from src.logger import get_logger
1010
from src.utils.async_to_sync import async_to_sync
11+
from src.api.context import get_session_settings
1112

1213
# Set up logger for this module
1314
logger = get_logger()
@@ -340,13 +341,17 @@ def get_org_id() -> str | None:
340341
str or None: The organization ID, or None if using API key authentication
341342
"""
342343
settings = get_settings()
344+
session_settings = get_session_settings()
343345

344346
# If using API key authentication, no org_id is needed
345347
if not settings.is_remote and settings.api_key:
346348
logger.debug("Using API key authentication, no organization ID needed")
347349
return None
348350

349-
org_id = settings.org_id
351+
if settings.is_remote and session_settings is not None:
352+
org_id = session_settings.get("org_id", None)
353+
else:
354+
org_id = settings.org_id
350355

351356
if not org_id:
352357
logger.debug(

src/api/context.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Context variable for the API."""
2+
3+
from contextvars import ContextVar
4+
from mcp.server.fastmcp import Context
5+
from functools import wraps
6+
import inspect
7+
from typing import Callable, Dict, Any
8+
9+
# Context variable to hold the current context object of a given tool call
10+
session_context = ContextVar("session_context")
11+
12+
13+
def get_session_context() -> Context | None:
14+
return session_context.get(None)
15+
16+
17+
# Using session's context to store session-specific settings like org_id
18+
def get_session_settings() -> Dict[str, Any] | None:
19+
current_context = get_session_context()
20+
if current_context is not None:
21+
return current_context.request_context.lifespan_context
22+
23+
raise Exception(
24+
"No session context available for this tool call, please try again."
25+
)
26+
27+
28+
def tool_wrapper(func: Callable) -> Callable:
29+
"""
30+
This function wraps tool functions to set the context variable.
31+
"""
32+
33+
# Check if the underlying function accepts a 'ctx' parameter, or generic **kwargs
34+
func_signature = inspect.signature(func)
35+
accepts_ctx = "ctx" in func_signature.parameters or any(
36+
p.kind == inspect.Parameter.VAR_KEYWORD
37+
for p in func_signature.parameters.values()
38+
)
39+
40+
@wraps(func)
41+
async def wrapper(ctx, *args, **kwargs):
42+
# Ensures that ctx is passed to the original function if it accepts it
43+
kwargs_cpy = dict(kwargs)
44+
if accepts_ctx:
45+
kwargs_cpy["ctx"] = ctx
46+
47+
current_context = get_session_context()
48+
if ctx is not None and current_context is None:
49+
session_context.set(ctx.request_context)
50+
51+
if inspect.iscoroutinefunction(func):
52+
return await func(*args, **kwargs_cpy)
53+
else:
54+
return func(*args, **kwargs_cpy)
55+
56+
# Ensures the mcp library knows that the wrapper expects a 'ctx' parameter
57+
if not accepts_ctx:
58+
wrapper.__annotations__["ctx"] = Context
59+
60+
return wrapper

src/api/tools/organization/organization.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import src.api.tools.organization.utils as utils
1010
from src.config import config
1111
from src.api.common import query_graphql_organizations
12+
from src.api.context import get_session_settings
1213
from src.utils.elicitation import try_elicitation, ElicitationError
1314
from src.logger import get_logger
1415

@@ -72,6 +73,8 @@ async def choose_organization(ctx: Context) -> dict:
7273

7374
settings = config.get_settings()
7475
user_id = config.get_user_id()
76+
session_settings = get_session_settings()
77+
7578
# Track tool call event
7679
settings.analytics_manager.track_event(
7780
user_id, "tool_calling", {"name": "choose_organization"}
@@ -154,7 +157,10 @@ class OrganizationChoice(BaseModel):
154157

155158
# Set the selected organization in settings
156159
if selected_org:
157-
settings.org_id = selected_org["orgID"]
160+
if settings.is_remote and session_settings:
161+
session_settings["org_id"] = selected_org["orgID"]
162+
else:
163+
settings.org_id = selected_org["orgID"]
158164

159165
return {
160166
"status": "success",
@@ -211,6 +217,7 @@ async def set_organization(ctx: Context, organization_id: str) -> dict:
211217
3. Call set_organization with the chosen ID
212218
"""
213219
settings = config.get_settings()
220+
session_settings = get_session_settings()
214221
user_id = config.get_user_id()
215222
# Track tool call event
216223
settings.analytics_manager.track_event(
@@ -243,10 +250,13 @@ async def set_organization(ctx: Context, organization_id: str) -> dict:
243250
}
244251

245252
# Set the selected organization in settings
246-
if hasattr(settings, "org_id"):
247-
settings.org_id = selected_org["orgID"]
253+
if settings.is_remote and session_settings:
254+
session_settings["org_id"] = selected_org["orgID"]
248255
else:
249-
setattr(settings, "org_id", selected_org["orgID"])
256+
if hasattr(settings, "org_id"):
257+
settings.org_id = selected_org["orgID"]
258+
else:
259+
setattr(settings, "org_id", selected_org["orgID"])
250260

251261
await ctx.info(
252262
f"Organization set to: {selected_org['name']} (ID: {selected_org['orgID']})"

src/api/tools/tools.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,24 @@
2929

3030
# Define the tools with their metadata
3131
tools_definition = [
32-
{"func": get_user_info},
33-
{"func": organization_info},
34-
{"func": choose_organization},
35-
{"func": set_organization},
36-
{"func": workspace_groups_info},
37-
{"func": workspaces_info},
38-
{"func": resume_workspace},
39-
{"func": list_starter_workspaces},
40-
{"func": create_starter_workspace},
41-
{"func": terminate_starter_workspace},
42-
{"func": list_regions},
43-
{"func": list_sharedtier_regions},
44-
{"func": run_sql},
45-
{"func": create_notebook_file},
46-
{"func": upload_notebook_file},
47-
{"func": create_job_from_notebook},
48-
{"func": get_job},
49-
{"func": delete_job},
32+
{"tool": get_user_info},
33+
{"tool": organization_info},
34+
{"tool": choose_organization},
35+
{"tool": set_organization},
36+
{"tool": workspace_groups_info},
37+
{"tool": workspaces_info},
38+
{"tool": resume_workspace},
39+
{"tool": list_starter_workspaces},
40+
{"tool": create_starter_workspace},
41+
{"tool": terminate_starter_workspace},
42+
{"tool": list_regions},
43+
{"tool": list_sharedtier_regions},
44+
{"tool": run_sql},
45+
{"tool": create_notebook_file},
46+
{"tool": upload_notebook_file},
47+
{"tool": create_job_from_notebook},
48+
{"tool": get_job},
49+
{"tool": delete_job},
5050
]
5151

5252
# Export the tools

src/api/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from enum import Flag, auto
33
from typing import Set
44

5+
from src.api.context import tool_wrapper
6+
57

68
# Define all possible flags here - ADD NEW FLAGS TO THIS LIST ONLY!
79
AVAILABLE_FLAGS = ["deprecated", "internal"]
@@ -83,6 +85,10 @@ def create_from_dict(cls, concept_def: dict):
8385
flag_enum = getattr(MCPConceptFlags, flag_name.upper())
8486
flags |= flag_enum
8587

88+
if "tool" in concept_attrs:
89+
concept_attrs["func"] = tool_wrapper(concept_attrs["tool"])
90+
del concept_attrs["tool"]
91+
8692
# Set title if not explicitly provided and we have a function
8793
if "title" not in concept_attrs and "func" in concept_attrs:
8894
concept_attrs["title"] = getattr(concept_attrs["func"], "__name__", "")

0 commit comments

Comments
 (0)