Skip to content

Commit b78c6d9

Browse files
authoredJan 20, 2025··
Polish errors and miscelaneous fixes for workspaces (#649)
* Polish errors in workspace activation This raises specific errors which are then handled excplicitly. Thus giving us more detailed error messages. Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com> * Fix unit tests Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com> * Don't delete dashboard addition Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com> * Disregard output from adding workspace Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com> * Fix workspace creation Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com> --------- Signed-off-by: Juan Antonio Osorio <ozz@stacklok.com>
1 parent 9f20ec0 commit b78c6d9

File tree

5 files changed

+55
-41
lines changed

5 files changed

+55
-41
lines changed
 

‎src/codegate/api/v1.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
from fastapi import APIRouter, Response
2-
from fastapi.exceptions import HTTPException
1+
from fastapi import APIRouter, HTTPException, Response
32
from fastapi.routing import APIRoute
43
from pydantic import ValidationError
54

65
from codegate.api import v1_models
7-
from codegate.db.connection import AlreadyExistsError
8-
from codegate.workspaces.crud import WorkspaceCrud
96
from codegate.api.dashboard.dashboard import dashboard_router
7+
from codegate.db.connection import AlreadyExistsError
8+
from codegate.workspaces import crud
109

1110
v1 = APIRouter()
1211
v1.include_router(dashboard_router)
13-
14-
wscrud = WorkspaceCrud()
12+
wscrud = crud.WorkspaceCrud()
1513

1614

1715
def uniq_name(route: APIRoute):
@@ -44,21 +42,24 @@ async def list_active_workspaces() -> v1_models.ListActiveWorkspacesResponse:
4442
@v1.post("/workspaces/active", tags=["Workspaces"], generate_unique_id_function=uniq_name)
4543
async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status_code=204):
4644
"""Activate a workspace by name."""
47-
activated = await wscrud.activate_workspace(request.name)
48-
49-
# TODO: Refactor
50-
if not activated:
45+
try:
46+
await wscrud.activate_workspace(request.name)
47+
except crud.WorkspaceAlreadyActiveError:
5148
return HTTPException(status_code=409, detail="Workspace already active")
49+
except crud.WorkspaceDoesNotExistError:
50+
return HTTPException(status_code=404, detail="Workspace does not exist")
51+
except Exception:
52+
return HTTPException(status_code=500, detail="Internal server error")
5253

5354
return Response(status_code=204)
5455

5556

5657
@v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201)
57-
async def create_workspace(request: v1_models.CreateWorkspaceRequest):
58+
async def create_workspace(request: v1_models.CreateWorkspaceRequest) -> v1_models.Workspace:
5859
"""Create a new workspace."""
5960
# Input validation is done in the model
6061
try:
61-
created = await wscrud.add_workspace(request.name)
62+
_ = await wscrud.add_workspace(request.name)
6263
except AlreadyExistsError:
6364
raise HTTPException(status_code=409, detail="Workspace already exists")
6465
except ValidationError:
@@ -68,8 +69,7 @@ async def create_workspace(request: v1_models.CreateWorkspaceRequest):
6869
except Exception:
6970
raise HTTPException(status_code=500, detail="Internal server error")
7071

71-
if created:
72-
return v1_models.Workspace(name=created.name)
72+
return v1_models.Workspace(name=request.name, is_active=False)
7373

7474

7575
@v1.delete(

‎src/codegate/db/connection.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ async def update_session(self, session: Session) -> Optional[Session]:
284284
"""
285285
)
286286
# We only pass an object to respect the signature of the function
287-
active_session = await self._execute_update_pydantic_model(session, sql)
287+
active_session = await self._execute_update_pydantic_model(session, sql, should_raise=True)
288288
return active_session
289289

290290

@@ -317,14 +317,18 @@ async def _execute_select_pydantic_model(
317317
return None
318318

319319
async def _exec_select_conditions_to_pydantic(
320-
self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict
320+
self, model_type: Type[BaseModel], sql_command: TextClause, conditions: dict,
321+
should_raise: bool = False
321322
) -> Optional[List[BaseModel]]:
322323
async with self._async_db_engine.begin() as conn:
323324
try:
324325
result = await conn.execute(sql_command, conditions)
325326
return await self._dump_result_to_pydantic_model(model_type, result)
326327
except Exception as e:
327328
logger.error(f"Failed to select model with conditions: {model_type}.", error=str(e))
329+
# Exposes errors to the caller
330+
if should_raise:
331+
raise e
328332
return None
329333

330334
async def get_prompts_with_output(self) -> List[GetPromptWithOutputsRow]:
@@ -392,7 +396,8 @@ async def get_workspace_by_name(self, name: str) -> List[Workspace]:
392396
"""
393397
)
394398
conditions = {"name": name}
395-
workspaces = await self._exec_select_conditions_to_pydantic(Workspace, sql, conditions)
399+
workspaces = await self._exec_select_conditions_to_pydantic(
400+
Workspace, sql, conditions, should_raise=True)
396401
return workspaces[0] if workspaces else None
397402

398403
async def get_sessions(self) -> List[Session]:
@@ -453,7 +458,11 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
453458
last_update=datetime.datetime.now(datetime.timezone.utc),
454459
)
455460
db_recorder = DbRecorder(db_path)
456-
asyncio.run(db_recorder.update_session(session))
461+
try:
462+
asyncio.run(db_recorder.update_session(session))
463+
except Exception as e:
464+
logger.error(f"Failed to initialize session in DB: {e}")
465+
return
457466
logger.info("Session in DB initialized successfully.")
458467

459468

‎src/codegate/pipeline/cli/commands.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from codegate import __version__
77
from codegate.db.connection import AlreadyExistsError
8-
from codegate.workspaces.crud import WorkspaceCrud
8+
from codegate.workspaces import crud
99

1010

1111
class CodegateCommand(ABC):
@@ -41,7 +41,7 @@ def help(self) -> str:
4141
class Workspace(CodegateCommand):
4242

4343
def __init__(self):
44-
self.workspace_crud = WorkspaceCrud()
44+
self.workspace_crud = crud.WorkspaceCrud()
4545
self.commands = {
4646
"list": self._list_workspaces,
4747
"add": self._add_workspace,
@@ -94,12 +94,14 @@ async def _activate_workspace(self, args: List[str]) -> str:
9494
if not workspace_name:
9595
return "Please provide a name. Use `codegate workspace activate workspace_name`"
9696

97-
was_activated = await self.workspace_crud.activate_workspace(workspace_name)
98-
if not was_activated:
99-
return (
100-
f"Workspace **{workspace_name}** does not exist or was already active. "
101-
f"Use `codegate workspace add {workspace_name}` to add it"
102-
)
97+
try:
98+
await self.workspace_crud.activate_workspace(workspace_name)
99+
except crud.WorkspaceAlreadyActiveError:
100+
return f"Workspace **{workspace_name}** is already active"
101+
except crud.WorkspaceDoesNotExistError:
102+
return f"Workspace **{workspace_name}** does not exist"
103+
except Exception:
104+
return "An error occurred while activating the workspace"
103105
return f"Workspace **{workspace_name}** has been activated"
104106

105107
async def run(self, args: List[str]) -> str:

‎src/codegate/workspaces/crud.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
class WorkspaceCrudError(Exception):
99
pass
1010

11+
class WorkspaceDoesNotExistError(WorkspaceCrudError):
12+
pass
13+
14+
class WorkspaceAlreadyActiveError(WorkspaceCrudError):
15+
pass
16+
1117
class WorkspaceCrud:
1218

1319
def __init__(self):
@@ -36,44 +42,41 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
3642
"""
3743
return await self._db_reader.get_active_workspace()
3844

39-
async def _is_workspace_active_or_not_exist(
45+
async def _is_workspace_active(
4046
self, workspace_name: str
4147
) -> Tuple[bool, Optional[Session], Optional[Workspace]]:
4248
"""
43-
Check if the workspace is active
44-
45-
Will return:
46-
- True if the workspace was activated
47-
- False if the workspace is already active or does not exist
49+
Check if the workspace is active alongside the session and workspace objects
4850
"""
51+
# TODO: All of this should be done within a transaction.
52+
4953
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
5054
if not selected_workspace:
51-
return True, None, None
55+
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
5256

5357
sessions = await self._db_reader.get_sessions()
5458
# The current implementation expects only one active session
5559
if len(sessions) != 1:
5660
raise RuntimeError("Something went wrong. No active session found.")
5761

5862
session = sessions[0]
59-
if session.active_workspace_id == selected_workspace.id:
60-
return True, None, None
61-
return False, session, selected_workspace
63+
return (session.active_workspace_id == selected_workspace.id,
64+
session, selected_workspace)
6265

63-
async def activate_workspace(self, workspace_name: str) -> bool:
66+
async def activate_workspace(self, workspace_name: str):
6467
"""
6568
Activate a workspace
6669
6770
Will return:
6871
- True if the workspace was activated
6972
- False if the workspace is already active or does not exist
7073
"""
71-
is_active, session, workspace = await self._is_workspace_active_or_not_exist(workspace_name)
74+
is_active, session, workspace = await self._is_workspace_active(workspace_name)
7275
if is_active:
73-
return False
76+
raise WorkspaceAlreadyActiveError(f"Workspace {workspace_name} is already active.")
7477

7578
session.active_workspace_id = workspace.id
7679
session.last_update = datetime.datetime.now(datetime.timezone.utc)
7780
db_recorder = DbRecorder()
7881
await db_recorder.update_session(session)
79-
return True
82+
return

‎tests/pipeline/workspace/test_workspace.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async def test_add_workspaces(args, existing_workspaces, expected_message):
7878

7979
# We'll also patch DbRecorder to ensure no real DB operations happen
8080
with patch(
81-
"codegate.pipeline.cli.commands.WorkspaceCrud", autospec=True
81+
"codegate.workspaces.crud.WorkspaceCrud", autospec=True
8282
) as mock_recorder_cls:
8383
mock_recorder = mock_recorder_cls.return_value
8484
workspace_commands.workspace_crud = mock_recorder

0 commit comments

Comments
 (0)
Please sign in to comment.