Skip to content

Commit

Permalink
Port to pydantic 2 (#192)
Browse files Browse the repository at this point in the history
* Progress porting to pydantic 2

* isort

* Tests over BaseRTU

* Test over File.construct_url

* Test over User.construct_auth_type.

* Test over Space.construct_url()

* Test project.construct_url

* Notebook model unit tests.

* isort

* Remove unused import

* Another round of dependency updates

* Test suite passing without warnings

* lint

* Remove bump-pydantic, no longer needed

* Better changelog

* Merge changelog

* Move to new changelog section

* clean up DeltaCallback class (not Pydantic anymore)

---------

Co-authored-by: Kafonek <[email protected]>
  • Loading branch information
James Robinson and Kafonek authored Nov 6, 2023
1 parent 48e0591 commit 583d25c
Show file tree
Hide file tree
Showing 30 changed files with 750 additions and 493 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/origami/blob/0.0.35/CHANGELOG.md)

## [Unreleased]
### Changed
- Upgraded pydantic to 2.4.2 up from 1.X.

### [1.1.5] - 2023-11-06
### Fixed
Expand Down
30 changes: 15 additions & 15 deletions origami/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def user_info(self) -> User:
endpoint = "/users/me"
resp = await self.client.get(endpoint)
resp.raise_for_status()
user = User.parse_obj(resp.json())
user = User.model_validate(resp.json())
self.add_tags_and_contextvars(user_id=str(user.id))
return user

Expand Down Expand Up @@ -191,7 +191,7 @@ async def create_space(self, name: str, description: Optional[str] = None) -> Sp
endpoint = "/spaces"
resp = await self.client.post(endpoint, json={"name": name, "description": description})
resp.raise_for_status()
space = Space.parse_obj(resp.json())
space = Space.model_validate(resp.json())
self.add_tags_and_contextvars(space_id=str(space.id))
return space

Expand All @@ -200,7 +200,7 @@ async def get_space(self, space_id: uuid.UUID) -> Space:
endpoint = f"/spaces/{space_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
space = Space.parse_obj(resp.json())
space = Space.model_validate(resp.json())
return space

async def delete_space(self, space_id: uuid.UUID) -> None:
Expand All @@ -216,7 +216,7 @@ async def list_space_projects(self, space_id: uuid.UUID) -> List[Project]:
endpoint = f"/spaces/{space_id}/projects"
resp = await self.client.get(endpoint)
resp.raise_for_status()
projects = [Project.parse_obj(project) for project in resp.json()]
projects = [Project.model_validate(project) for project in resp.json()]
return projects

async def share_space(
Expand Down Expand Up @@ -267,7 +267,7 @@ async def create_project(
},
)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
self.add_tags_and_contextvars(project_id=str(project.id))
return project

Expand All @@ -276,15 +276,15 @@ async def get_project(self, project_id: uuid.UUID) -> Project:
endpoint = f"/projects/{project_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
return project

async def delete_project(self, project_id: uuid.UUID) -> Project:
self.add_tags_and_contextvars(project_id=str(project_id))
endpoint = f"/projects/{project_id}"
resp = await self.client.delete(endpoint)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
return project

async def share_project(
Expand Down Expand Up @@ -323,7 +323,7 @@ async def list_project_files(self, project_id: uuid.UUID) -> List[File]:
endpoint = f"/projects/{project_id}/files"
resp = await self.client.get(endpoint)
resp.raise_for_status()
files = [File.parse_obj(file) for file in resp.json()]
files = [File.model_validate(file) for file in resp.json()]
return files

# Files are flat files (like text, csv, etc) or Notebooks.
Expand Down Expand Up @@ -355,7 +355,7 @@ async def _multi_step_file_create(
upload_url = js["presigned_upload_url_info"]["parts"][0]["upload_url"]
upload_id = js["presigned_upload_url_info"]["upload_id"]
upload_key = js["presigned_upload_url_info"]["key"]
file = File.parse_obj(js)
file = File.model_validate(js)

# (2) Upload to pre-signed url
# TODO: remove this hack if/when we get containers in Skaffold to be able to translate
Expand Down Expand Up @@ -393,7 +393,7 @@ async def create_notebook(
self.add_tags_and_contextvars(project_id=str(project_id))
if notebook is None:
notebook = Notebook()
content = notebook.json().encode()
content = notebook.model_dump_json().encode()
file = await self._multi_step_file_create(project_id, path, "notebook", content)
self.add_tags_and_contextvars(file_id=str(file.id))
logger.info("Created new notebook", extra={"file_id": str(file.id)})
Expand All @@ -405,7 +405,7 @@ async def get_file(self, file_id: uuid.UUID) -> File:
endpoint = f"/v1/files/{file_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
file = File.parse_obj(resp.json())
file = File.model_validate(resp.json())
return file

async def get_file_content(self, file_id: uuid.UUID) -> bytes:
Expand Down Expand Up @@ -433,15 +433,15 @@ async def get_file_versions(self, file_id: uuid.UUID) -> List[FileVersion]:
endpoint = f"/files/{file_id}/versions"
resp = await self.client.get(endpoint)
resp.raise_for_status()
versions = [FileVersion.parse_obj(version) for version in resp.json()]
versions = [FileVersion.model_validate(version) for version in resp.json()]
return versions

async def delete_file(self, file_id: uuid.UUID) -> File:
self.add_tags_and_contextvars(file_id=str(file_id))
endpoint = f"/v1/files/{file_id}"
resp = await self.client.delete(endpoint)
resp.raise_for_status()
file = File.parse_obj(resp.json())
file = File.model_validate(resp.json())
return file

async def share_file(
Expand Down Expand Up @@ -497,7 +497,7 @@ async def launch_kernel(
}
resp = await self.client.post(endpoint, json=data)
resp.raise_for_status()
kernel_session = KernelSession.parse_obj(resp.json())
kernel_session = KernelSession.model_validate(resp.json())
self.add_tags_and_contextvars(kernel_session_id=str(kernel_session.id))
logger.info(
"Launched new kernel",
Expand All @@ -517,7 +517,7 @@ async def get_output_collection(
endpoint = f"/outputs/collection/{output_collection_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
return KernelOutputCollection.parse_obj(resp.json())
return KernelOutputCollection.model_validate(resp.json())

async def connect_realtime(self, file: Union[File, uuid.UUID, str]) -> "RTUClient": # noqa
"""
Expand Down
28 changes: 16 additions & 12 deletions origami/clients/rtu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import httpx
import orjson
from pydantic import BaseModel, parse_obj_as
from sending.backends.websocket import WebsocketManager
from websockets.client import WebSocketClientProtocol

Expand Down Expand Up @@ -51,7 +50,7 @@
KernelStatusUpdateResponse,
)
from origami.models.rtu.channels.system import AuthenticateReply, AuthenticateRequest
from origami.models.rtu.discriminators import RTURequest, RTUResponse
from origami.models.rtu.discriminators import RTURequest, RTUResponse, RTUResponseParser
from origami.models.rtu.errors import InconsistentStateEvent
from origami.notebook.builder import CellNotFound, NotebookBuilder

Expand Down Expand Up @@ -87,7 +86,8 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
# to error or BaseRTUResponse)
data: dict = orjson.loads(contents)
data["channel_prefix"] = data.get("channel", "").split("/")[0]
rtu_event = parse_obj_as(RTUResponse, data)

rtu_event = RTUResponseParser.validate_python(data)

# Debug Logging
extra_dict = {
Expand All @@ -98,15 +98,18 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
if isinstance(rtu_event, NewDeltaEvent):
extra_dict["delta_type"] = rtu_event.data.delta_type
extra_dict["delta_action"] = rtu_event.data.delta_action
logger.debug(f"Received: {data}\nParsed: {rtu_event.dict()}", extra=extra_dict)

if logging.DEBUG >= logging.root.level:
logger.debug(f"Received: {data}\nParsed: {rtu_event.model_dump()}", extra=extra_dict)

return rtu_event

async def outbound_message_hook(self, contents: RTURequest) -> str:
"""
Hook applied to every message we send out over the websocket.
- Anything calling .send() should pass in an RTU Request pydantic model
"""
return contents.json()
return contents.model_dump_json()

def send(self, message: RTURequest) -> None:
"""Override WebsocketManager-defined method for type hinting and logging."""
Expand All @@ -118,7 +121,9 @@ def send(self, message: RTURequest) -> None:
if message.event == "new_delta_request":
extra_dict["delta_type"] = message.data.delta.delta_type
extra_dict["delta_action"] = message.data.delta.delta_action

logger.debug("Sending: RTU request", extra=extra_dict)

super().send(message) # the .outbound_message_hook handles serializing this to json

async def on_exception(self, exc: Exception):
Expand All @@ -143,11 +148,10 @@ class DeltaRejected(Exception):


# Used in registering callback functions that get called right after squashing a Delta
class DeltaCallback(BaseModel):
# callback function should be async and expect one argument: a FileDelta
# Doesn't matter what it returns. Pydantic doesn't validate Callable args/return.
delta_class: Type[FileDelta]
fn: Callable[[FileDelta], Awaitable[None]]
class DeltaCallback:
def __init__(self, delta_class: Type[FileDelta], fn: Callable[[FileDelta], Awaitable[None]]):
self.delta_class = delta_class
self.fn = fn


class DeltaRequestCallbackManager:
Expand Down Expand Up @@ -455,7 +459,7 @@ async def load_seed_notebook(self):
resp = await plain_http_client.get(file.presigned_download_url)
resp.raise_for_status()

seed_notebook = Notebook.parse_obj(resp.json())
seed_notebook = Notebook.model_validate(resp.json())
self.builder = NotebookBuilder(seed_notebook=seed_notebook)

# See Sending backends.websocket for details but a quick refresher on hook timing:
Expand Down Expand Up @@ -494,7 +498,7 @@ async def auth_hook(self, *args, **kwargs):
# we observe the auth reply. Instead use the unauth_ws directly and manually serialize
ws: WebSocketClientProtocol = await self.manager.unauth_ws
logger.info(f"Sending auth request with jwt {jwt[:5]}...{jwt[-5:]}")
await ws.send(auth_request.json())
await ws.send(auth_request.model_dump_json())

async def on_auth(self, msg: AuthenticateReply):
# hook for Application code to override, consider catastrophic failure on auth failure
Expand Down
2 changes: 1 addition & 1 deletion origami/models/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ class ResourceBase(BaseModel):
id: uuid.UUID
created_at: datetime
updated_at: datetime
deleted_at: Optional[datetime]
deleted_at: Optional[datetime] = None
6 changes: 3 additions & 3 deletions origami/models/api/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class DataSource(BaseModel):
type_id: str # e.g. duckdb, postgresql
sql_cell_handle: str # this goes in cell metadata for SQL cells
# One of these three will be not None, and that tells you the scope of the datasource
space_id: Optional[uuid.UUID]
project_id: Optional[uuid.UUID]
user_id: Optional[uuid.UUID]
space_id: Optional[uuid.UUID] = None
project_id: Optional[uuid.UUID] = None
user_id: Optional[uuid.UUID] = None
created_by_id: uuid.UUID
created_at: datetime
updated_at: datetime
Expand Down
10 changes: 6 additions & 4 deletions origami/models/api/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from typing import Literal, Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase

Expand All @@ -22,10 +22,12 @@ class File(ResourceBase):
presigned_download_url: Optional[str] = None
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/f/{values['id']}/{values['path']}"
self.url = f"{noteable_url}/f/{self.id}/{self.path}"

return self


class FileVersion(ResourceBase):
Expand Down
6 changes: 3 additions & 3 deletions origami/models/api/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class KernelOutputContent(BaseModel):

class KernelOutput(ResourceBase):
type: str
display_id: Optional[str]
display_id: Optional[str] = None
available_mimetypes: List[str]
content_metadata: KernelOutputContent
content: Optional[KernelOutputContent]
content_for_llm: Optional[KernelOutputContent]
content: Optional[KernelOutputContent] = None
content_for_llm: Optional[KernelOutputContent] = None
parent_collection_id: uuid.UUID


Expand Down
12 changes: 7 additions & 5 deletions origami/models/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
import uuid
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase


class Project(ResourceBase):
name: str
description: Optional[str]
description: Optional[str] = None
space_id: uuid.UUID
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/p/{values['id']}"
self.url = f"{noteable_url}/p/{self.id}"

return self
12 changes: 7 additions & 5 deletions origami/models/api/spaces.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import os
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase


class Space(ResourceBase):
name: str
description: Optional[str]
description: Optional[str] = None
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/s/{values['id']}"
self.url = f"{noteable_url}/s/{self.id}"

return self
20 changes: 11 additions & 9 deletions origami/models/api/users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase

Expand All @@ -10,14 +10,16 @@ class User(ResourceBase):
"""The user fields sent to/from the server"""

handle: str
email: Optional[str] # not returned if looking up user other than yourself
email: Optional[str] = None # not returned if looking up user other than yourself
first_name: str
last_name: str
origamist_default_project_id: Optional[uuid.UUID]
principal_sub: Optional[str] # from /users/me only, represents auth type
auth_type: Optional[str]
origamist_default_project_id: Optional[uuid.UUID] = None
principal_sub: Optional[str] = None # from /users/me only, represents auth type
auth_type: Optional[str] = None

@validator("auth_type", always=True)
def construct_auth_type(cls, v, values):
if values.get("principal_sub"):
return values["principal_sub"].split("|")[0]
@model_validator(mode="after")
def construct_auth_type(self):
if self.principal_sub:
self.auth_type = self.principal_sub.split("|")[0]

return self
6 changes: 3 additions & 3 deletions origami/models/deltas/delta_types/cell_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CellMetadataDelta(FileDeltaBase):
# A lot of state is stored in cell metadata, including DEX and execute time
class CellMetadataUpdateProperties(BaseModel):
path: list
value: Any
value: Any = None
prior_value: Any = NULL_PRIOR_VALUE_SENTINEL


Expand All @@ -26,8 +26,8 @@ class CellMetadataUpdate(CellMetadataDelta):

# Cell metadata replace is used for changing cell type and language (Python/R/etc)
class CellMetadataReplaceProperties(BaseModel):
type: Optional[str]
language: Optional[str]
type: Optional[str] = None
language: Optional[str] = None


class CellMetadataReplace(CellMetadataDelta):
Expand Down
Loading

0 comments on commit 583d25c

Please sign in to comment.