From 3f1c464366c88fb6cf774d5a23730e98bd3d1982 Mon Sep 17 00:00:00 2001 From: Dave Shoup Date: Wed, 16 Aug 2023 11:40:52 -0400 Subject: [PATCH] add `IntegratedAI*` message models (#155) * add BooleanReplyData * add IntegratedAI* modeling * changelog * more models * cleanup --- CHANGELOG.md | 2 + origami/models/rtu/base.py | 5 ++ origami/models/rtu/channels/files.py | 26 ++--------- origami/models/rtu/channels/kernels.py | 64 +++++++++++++++++++++++++- 4 files changed, 74 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 50e6e24..2f71916 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/origami/blob/0.0.35/CHANGELOG.md) ## [Unreleased] +### Added +- `integrated_ai*` message models for the `kernels` channel ### [1.0.0-alpha.4] - 2023-08-08 ### Added diff --git a/origami/models/rtu/base.py b/origami/models/rtu/base.py index 6e61483..40d608c 100644 --- a/origami/models/rtu/base.py +++ b/origami/models/rtu/base.py @@ -5,6 +5,11 @@ from pydantic import BaseModel, Field, root_validator +class BooleanReplyData(BaseModel): + # Gate will reply to most RTU requests with an RTU reply that's just success=True/False + success: bool + + class BaseRTURequest(BaseModel): transaction_id: uuid.UUID = Field(default_factory=uuid.uuid4) channel: str diff --git a/origami/models/rtu/channels/files.py b/origami/models/rtu/channels/files.py index e24d8d7..904035f 100644 --- a/origami/models/rtu/channels/files.py +++ b/origami/models/rtu/channels/files.py @@ -19,7 +19,7 @@ from origami.models.deltas.discriminators import FileDelta from origami.models.kernels import CellState, KernelStatusUpdate -from origami.models.rtu.base import BaseRTURequest, BaseRTUResponse +from origami.models.rtu.base import BaseRTURequest, BaseRTUResponse, BooleanReplyData class FilesRequest(BaseRTURequest): @@ -78,13 +78,9 @@ class FileUnsubscribeRequest(FilesRequest): event: Literal['unsubscribe_request'] = 'unsubscribe_request' -class FileUnsubscribeReplyData(BaseModel): - success: bool - - class FileUnsubscribeReply(FilesResponse): event: Literal['unsubscribe_reply'] = 'unsubscribe_reply' - data: FileUnsubscribeReplyData + data: BooleanReplyData # Deltas are requests to change a document content or perform cell execution. The API server ensures @@ -101,13 +97,9 @@ class NewDeltaRequest(FilesRequest): data: NewDeltaRequestData -class NewDeltaReplyData(BaseModel): - success: bool - - class NewDeltaReply(FilesResponse): event: Literal['new_delta_reply'] = 'new_delta_reply' - data: NewDeltaReplyData + data: BooleanReplyData class NewDeltaEvent(FilesResponse): @@ -138,13 +130,9 @@ class UpdateUserCellSelectionRequest(FilesRequest): data: UpdateUserCellSelectionRequestData -class UpdateUserCellSelectionReplyData(BaseModel): - success: bool - - class UpdateUserCellSelectionReply(FilesResponse): event: Literal['update_user_cell_selection_reply'] = 'update_user_cell_selection_reply' - data: UpdateUserCellSelectionReplyData + data: BooleanReplyData class UpdateUserFileSubscriptionEventData(BaseModel): @@ -197,13 +185,9 @@ class TransformViewToCodeRequest(FilesRequest): data: TransformViewToCodeRequestData -class TransformViewToCodeReplyData(BaseModel): - success: bool - - class TransformViewToCodeReply(FilesResponse): event: Literal['transform_view_to_code_reply'] = 'transform_view_to_code_reply' - data: TransformViewToCodeReplyData + data: BooleanReplyData # When the API squashes Deltas, it will emit a new file versions changed event diff --git a/origami/models/rtu/channels/kernels.py b/origami/models/rtu/channels/kernels.py index 177961b..bafb6f3 100644 --- a/origami/models/rtu/channels/kernels.py +++ b/origami/models/rtu/channels/kernels.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field from origami.models.kernels import CellState, KernelStatusUpdate -from origami.models.rtu.base import BaseRTURequest, BaseRTUResponse +from origami.models.rtu.base import BaseRTURequest, BaseRTUResponse, BooleanReplyData class KernelsRequest(BaseRTURequest): @@ -66,8 +66,64 @@ class VariableExplorerResponse(KernelsResponse): event: Literal['variable_explorer_event'] = 'variable_explorer_event' +class IntegratedAIRequestData(BaseModel): + prompt: str + # this may not be called on a specific cell, but at a specific point in time at a generic + # "document" level, so we don't require a cell_id + cell_id: Optional[str] + # if a cell_id is provided and this is True, the result will be added to the cell's output + # instead of just sent back as an RTU reply + output_for_response: bool = False + + +class IntegratedAIRequest(KernelsRequest): + event: Literal['integrated_ai_request'] = 'integrated_ai_request' + data: IntegratedAIRequestData + + +class IntegratedAIReply(KernelsResponse): + event: Literal['integrated_ai_reply'] = 'integrated_ai_reply' + data: BooleanReplyData + + +class IntegratedAIEvent(KernelsResponse): + event: Literal['integrated_ai_event'] = 'integrated_ai_event' + # same data as the IntegratedAIRequest, just echoed back out + data: IntegratedAIRequestData + + +class IntegratedAIResultData(BaseModel): + # the full response from OpenAI; in most cases, sidecar will have either created a new cell + # or an output, so this result should really only be used when the RTU client needs it to exist + # outside of the cell/output structure + result: str + + +# this is sidecar to gate as a result of calling the OpenAIHandler method (OpenAI response, +# error, etc); after that, Gate propogates the data out as an IntegratedAIEvent +class IntegratedAIResult(KernelsRequest): + event: Literal['integrated_ai_result'] = 'integrated_ai_result' + data: IntegratedAIResultData + + +class IntegratedAIResultReply(KernelsResponse): + event: Literal['integrated_ai_result_reply'] = 'integrated_ai_result_reply' + data: BooleanReplyData + + +class IntegratedAIResultEvent(KernelsResponse): + event: Literal['integrated_ai_result_event'] = 'integrated_ai_result_event' + data: IntegratedAIResultData + + KernelRequests = Annotated[ - Union[KernelSubscribeRequest, VariableExplorerUpdateRequest], Field(discriminator="event") + Union[ + KernelSubscribeRequest, + VariableExplorerUpdateRequest, + IntegratedAIRequest, + IntegratedAIResult, + ], + Field(discriminator="event"), ] KernelResponses = Annotated[ @@ -76,6 +132,10 @@ class VariableExplorerResponse(KernelsResponse): KernelStatusUpdateResponse, BulkCellStateUpdateResponse, VariableExplorerResponse, + IntegratedAIReply, + IntegratedAIResultReply, + IntegratedAIEvent, + IntegratedAIResultEvent, ], Field(discriminator="event"), ]