Skip to content

Commit

Permalink
add creator_client type for file/project creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kafonek committed Oct 3, 2023
1 parent 27f7a8e commit f544617
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 9 deletions.
24 changes: 17 additions & 7 deletions origami/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
headers: Optional[dict] = None,
transport: Optional[httpx.AsyncHTTPTransport] = None,
timeout: httpx.Timeout = httpx.Timeout(5.0),
rtu_client_type: str = "origami",
creator_client_type: str = "origami",
):
# jwt and api_base_url saved as attributes because they're re-used when creating rtu client
self.jwt = authorization_token or os.environ.get("NOTEABLE_TOKEN")
Expand All @@ -47,10 +47,13 @@ def __init__(
transport=transport,
timeout=timeout,
)
# Hack until Gate changes out rtu_client_type from enum to str
if rtu_client_type not in ["origami", "origamist", "planar_ally", "geas"]:
rtu_client_type = "unknown"
self.rtu_client_type = rtu_client_type # Only used when generating an RTUClient
# creator_client_type helps log what kind of client created Resources like Files/Projects
# or is interacting with Notebooks through RTU / Deltas. If you're not sure what to use
# yourself, go with the default 'origami'
if creator_client_type not in ["origami", "origamist", "planar_ally", "geas"]:
# this list of valid creator client types is sourced from Gate's FrontendType enum
creator_client_type = "unknown"
self.creator_client_type = creator_client_type # Only used when generating an RTUClient

def add_tags_and_contextvars(self, **tags):
"""Hook for Apps to override so they can set structlog contextvars or ddtrace tags etc"""
Expand Down Expand Up @@ -108,7 +111,13 @@ async def create_project(
self.add_tags_and_contextvars(space_id=str(space_id))
endpoint = "/projects"
resp = await self.client.post(
endpoint, json={"space_id": str(space_id), "name": name, "description": description}
endpoint,
json={
"space_id": str(space_id),
"name": name,
"description": description,
"creator_client_type": self.creator_client_type,
},
)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
Expand Down Expand Up @@ -159,6 +168,7 @@ async def _multi_step_file_create(
"path": path,
"type": file_type,
"file_size_bytes": len(content),
"creator_client_type": self.creator_client_type,
}
resp = await self.client.post("/v1/files", json=body)
resp.raise_for_status()
Expand Down Expand Up @@ -346,7 +356,7 @@ async def connect_realtime(self, file: Union[File, uuid.UUID, str]) -> RTUClient
file_id=file.id,
file_version_id=file.current_version_id,
builder=nb_builder,
rtu_client_type=self.rtu_client_type,
rtu_client_type=self.creator_client_type,
)
await rtu_client.initialize()
await rtu_client.deltas_to_apply_event.wait()
Expand Down
34 changes: 32 additions & 2 deletions tests/e2e/api/test_files.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
import uuid

import httpx

from origami.clients.api import APIClient
from origami.models.api.files import File


async def test_file_crud(api_client: APIClient, test_project_id):
name = uuid.uuid4().hex + ".txt"
content = b"foo\nbar\nbaz"
new_file = await api_client.create_file(project_id=test_project_id, path=name, content=content)
assert new_file.filename == name
assert new_file.project_id == test_project_id

# on create, Gate does not give back a presigned url, need to request it.
f: File = await api_client.get_file(new_file.id)
assert f.presigned_download_url is not None

async with httpx.AsyncClient() as plain_client:
resp = await plain_client.get(f.presigned_download_url)
assert resp.status_code == 200
assert resp.content == content

# Delete file
deleted_file = await api_client.delete_file(new_file.id)
assert deleted_file.id == new_file.id
assert deleted_file.deleted_at is not None


async def test_get_file_version(api_client: APIClient, notebook_maker):
f: File = await notebook_maker()
versions = await api_client.get_file_versions(f.id)
Expand All @@ -13,11 +39,15 @@ async def test_get_file_version(api_client: APIClient, notebook_maker):
assert versions[0].content_presigned_url is not None

# Trigger a version save -- something needs to change (i.e. make a delta) or save as named
endpoint = f'/v1/files/{f.id}/versions'
resp = await api_client.client.post(endpoint, json={'name': 'foo'})
endpoint = f"/v1/files/{f.id}/versions"
resp = await api_client.client.post(endpoint, json={"name": "foo"})
assert resp.status_code == 201

new_versions = await api_client.get_file_versions(f.id)
assert new_versions[0].number == 1

assert len(new_versions) == 2

assert len(new_versions) == 2

assert len(new_versions) == 2

0 comments on commit f544617

Please sign in to comment.