diff --git a/CHANGELOG.md b/CHANGELOG.md index 7668223..681fd35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/orig ### Changed - `rtu_client_type` renamed to `client_creator_type`, now used both in RTU auth subscribe and when creating Files/Projects + - `RTUClient` crashes if trying to instantiate with no file version id, which can happen after a Notebook has been changed from non-RTU mechanism ### [1.1.0] - 2023-09-28 ### Added diff --git a/origami/clients/rtu.py b/origami/clients/rtu.py index 9309aba..5f55932 100644 --- a/origami/clients/rtu.py +++ b/origami/clients/rtu.py @@ -36,6 +36,7 @@ from origami.models.rtu.channels.files import ( FileSubscribeReply, FileSubscribeRequest, + FileSubscribeRequestData, NewDeltaEvent, NewDeltaRequest, NewDeltaRequestData, @@ -255,6 +256,10 @@ def __init__( self.jwt = jwt self.file_id = file_id self.file_version_id = file_version_id + if not self.file_version_id: + raise ValueError( + "File version id cannot be None. This can happen if a Notebook has had its file version wiped due to inconsistent state (changes made through non RTU mechanism)." # noqa: E501 + ) self.builder = builder self.rtu_client_type = rtu_client_type self.user_id = None # set during authenticate_reply handling, used in new_delta_request @@ -458,7 +463,11 @@ async def _on_auth(self, msg: AuthenticateReply): self.manager.authed_ws = asyncio.Future() self.manager.authed_ws.set_result(self.manager.unauth_ws.result()) - await self.send_file_subscribe() + try: + await self.send_file_subscribe() + except Exception: + logger.exception("Error sending file subscribe request") + await self.on_auth(msg) async def send_file_subscribe(self): @@ -477,23 +486,27 @@ async def send_file_subscribe(self): # # Second note, subscribing by delta id all-0's throws an error in Gate. if self.builder.last_applied_delta_id and self.builder.last_applied_delta_id != uuid.UUID(int=0): # type: ignore # noqa: E501 - req = FileSubscribeRequest( - channel=f"files/{self.file_id}", - data={"from_delta_id": self.builder.last_applied_delta_id}, - ) logger.info( "Sending File subscribe request by last applied delta id", - extra={"from_delta_id": str(req.data.from_delta_id)}, + extra={"from_delta_id": str(self.builder.last_applied_delta_id)}, ) - else: + req_data = FileSubscribeRequestData(from_delta_id=self.builder.last_applied_delta_id) req = FileSubscribeRequest( channel=f"files/{self.file_id}", - data={"from_version_id": self.file_version_id}, + data=req_data, ) + + else: logger.info( "Sending File subscribe request by version id", - extra={"from_version_id": str(req.data.from_version_id)}, + extra={"from_version_id": str(self.file_version_id)}, + ) + req_data = FileSubscribeRequestData(from_version_id=self.file_version_id) + req = FileSubscribeRequest( + channel=f"files/{self.file_id}", + data=req_data, ) + self.file_subscribe_timeout_task = asyncio.create_task(self.on_file_subscribe_timeout()) self.manager.send(req) diff --git a/origami/models/rtu/channels/files.py b/origami/models/rtu/channels/files.py index a6c9e73..3a2ed69 100644 --- a/origami/models/rtu/channels/files.py +++ b/origami/models/rtu/channels/files.py @@ -15,7 +15,7 @@ from datetime import datetime from typing import Annotated, Any, List, Literal, Optional, Union -from pydantic import BaseModel, Field, ValidationError, root_validator +from pydantic import BaseModel, Field, root_validator from origami.models.api.outputs import KernelOutput from origami.models.deltas.discriminators import FileDelta @@ -47,11 +47,8 @@ def exactly_one_field(cls, values): num_set_fields = sum(value is not None for value in values.values()) # If exactly one field is set, return the values as they are - if num_set_fields == 1: - return values - - # If not, raise a validation error - raise ValidationError("Exactly one field must be set") + assert num_set_fields == 1, "Exactly one field must be set" + return values class FileSubscribeRequest(FilesRequest): diff --git a/tests/e2e/rtu/test_notebook.py b/tests/e2e/rtu/test_notebook.py index fab2bd8..a955212 100644 --- a/tests/e2e/rtu/test_notebook.py +++ b/tests/e2e/rtu/test_notebook.py @@ -17,7 +17,7 @@ async def test_add_and_remove_cell(api_client: APIClient, notebook_maker): assert rtu_client.builder.nb.cells == [] cell = await rtu_client.add_cell(source='print("hello world")') - assert cell.cell_type == 'code' + assert cell.cell_type == "code" assert cell.id in rtu_client.cell_ids await rtu_client.delete_cell(cell.id) @@ -35,17 +35,17 @@ async def test_change_cell_type(api_client: APIClient, notebook_maker): try: assert rtu_client.builder.nb.cells == [] - source_cell = await rtu_client.add_cell(source='1 + 1') + source_cell = await rtu_client.add_cell(source="1 + 1") _, cell = rtu_client.builder.get_cell(source_cell.id) - assert cell.cell_type == 'code' + assert cell.cell_type == "code" - await rtu_client.change_cell_type(cell.id, 'markdown') + await rtu_client.change_cell_type(cell.id, "markdown") _, cell = rtu_client.builder.get_cell(source_cell.id) - assert cell.cell_type == 'markdown' + assert cell.cell_type == "markdown" - await rtu_client.change_cell_type(cell.id, 'sql') + await rtu_client.change_cell_type(cell.id, "sql") _, cell = rtu_client.builder.get_cell(source_cell.id) - assert cell.cell_type == 'code' + assert cell.cell_type == "code" assert cell.is_sql_cell finally: await rtu_client.shutdown() @@ -59,11 +59,11 @@ async def test_update_cell_content(api_client: APIClient, notebook_maker): try: assert rtu_client.builder.nb.cells == [] - source_cell = await rtu_client.add_cell(source='1 + 1') + source_cell = await rtu_client.add_cell(source="1 + 1") _, cell = rtu_client.builder.get_cell(source_cell.id) - cell = await rtu_client.update_cell_content(cell.id, '@@ -1,5 +1,5 @@\n-1 + 1\n+2 + 2\n') - assert cell.source == '2 + 2' + cell = await rtu_client.update_cell_content(cell.id, "@@ -1,5 +1,5 @@\n-1 + 1\n+2 + 2\n") + assert cell.source == "2 + 2" finally: await rtu_client.shutdown() @@ -76,10 +76,31 @@ async def test_replace_cell_content(api_client: APIClient, notebook_maker): try: assert rtu_client.builder.nb.cells == [] - source_cell = await rtu_client.add_cell(source='1 + 1') + source_cell = await rtu_client.add_cell(source="1 + 1") _, cell = rtu_client.builder.get_cell(source_cell.id) - cell = await rtu_client.replace_cell_content(cell.id, '2 + 2') - assert cell.source == '2 + 2' + cell = await rtu_client.replace_cell_content(cell.id, "2 + 2") + assert cell.source == "2 + 2" + finally: + await rtu_client.shutdown() + + +async def test_notebook_has_no_current_file_version(api_client: APIClient, notebook_maker): + file: File = await notebook_maker() + # TODO: remove sleep when Gate stops permission denied on newly created files (db time-travel) + await asyncio.sleep(2) + rtu_client = await api_client.connect_realtime(file) + # now that we have a working rtu_client, we'll try and make a second one that is missing the + # file_version_id, which should raise an error + try: + with pytest.raises(ValueError) as e: + RTUClient( + rtu_url=rtu_client.manager.ws_url, + jwt=rtu_client.jwt, + file_id=rtu_client.file_id, + file_version_id=None, + builder=rtu_client.builder, + ) + assert e.value.args[0].startswith("File version id cannot be None.") finally: await rtu_client.shutdown()