Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ keywords = ["ai", "machine-learning", "openai", "library"]
[dependencies]
serde_json = "1.0.94"
derive_builder = "0.20.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"], optional = true }
serde = { version = "1.0.157", features = ["derive"] }
reqwest = { version = "0.12", default-features = false, features = [
"json",
"stream",
"multipart",
], optional = true }
serde = { version = "^1.0", features = ["derive"] }
reqwest-eventsource = "0.6"
tokio = { version = "1.26.0", features = ["full"] }
anyhow = "1.0.70"
tokio = { version = "1.0", features = ["full"] }
anyhow = "1.0"
futures-util = "0.3.28"
bytes = "1.4.0"
schemars = "0.8"
either = { version = "1.8.1", features = ["serde"] }
serde-double-tag = "0.0.4"
log = "0.4"
strum = { version = "0.26", features = ["derive"] }
strum_macros = "0.26"
once_cell = "^1"

[dev-dependencies]
dotenvy = "0.15.7"
Expand Down
144 changes: 144 additions & 0 deletions src/assistants/assistants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use std::collections::HashMap;

use schemars::schema::RootSchema;
use serde::{Deserialize, Serialize};

use crate::{
client::{Empty, OpenAiClient},
ApiResponseOrError,
};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Assistant {
pub id: String,
pub object: String,
pub created_at: u32,
/// The name of the assistant. The maximum length is 256 characters.
pub name: Option<String>,
/// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
pub model: String,
/// The system instructions that the assistant uses. The maximum length is 256,000 characters.
pub instructions: Option<String>,
pub tools: Vec<Tool>,
/// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs.
pub tool_resources: Option<ToolResources>,
/// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.
pub metadata: Option<HashMap<String, String>>,
/// The default model to use for this assistant.
pub response_format: Option<ResponseFormat>,
}

#[derive(Debug, Clone, serde_double_tag::Deserialize, serde_double_tag::Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Tool {
CodeInterpreter,
Function(Function),
FileSearch(FileSearch),
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Function {
pub name: String,
pub description: String,
pub parameters: RootSchema,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionParameters {
pub title: String,
pub description: String,
#[serde(rename = "type")]
pub type_: String,
pub required: Vec<String>,
pub properties: HashMap<String, FunctionProperty>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionProperty {
pub description: String,
#[serde(rename = "type")]
pub type_: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct FileSearch {
pub max_num_results: Option<usize>,
}

#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct ToolResources {
pub code_interpreter: Option<CodeInterpreterResources>,
pub file_search: Option<FileSearchResources>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CodeInterpreterResources {
/// A list of file IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.
pub file_ids: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FileSearchResources {
/// The ID of the vector store attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.
pub vector_store_ids: Vec<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
Auto,
}

#[derive(Serialize, Default, Debug, Clone)]
pub struct CreateAssistantRequest {
/// ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them.
pub model: String,

/// The name of the assistant. The maximum length is 256 characters.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// The description of the assistant. The maximum length is 256 characters.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// The system instructions that the assistant uses. The maximum length is 256,000 characters.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// A set of tools that the assistant can use.
pub tools: Vec<Tool>,
/// A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the code_interpreter tool requires a list of file IDs, while the file_search tool requires a list of vector store IDs.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_resources: Option<ToolResources>,
/// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
/// The default model to use for this assistant.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
}

impl OpenAiClient {
pub async fn create_assistant(
&self,
request: CreateAssistantRequest,
) -> ApiResponseOrError<Assistant> {
self.post("assistants", Some(request)).await
}

pub async fn get_assistant(&self, assistant_id: &str) -> ApiResponseOrError<Assistant> {
self.get(format!("assistants/{}", assistant_id)).await
}

pub async fn delete_assistant(&self, assistant_id: &str) -> ApiResponseOrError<Empty> {
self.delete(format!("assistants/{}", assistant_id)).await
}

pub async fn update_assistant(
&self,
assistant_id: &str,
request: CreateAssistantRequest,
) -> ApiResponseOrError<Assistant> {
self.post(format!("assistants/{}", assistant_id), Some(request))
.await
}
}
49 changes: 49 additions & 0 deletions src/assistants/files.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use crate::{client::OpenAiClient, ApiResponseOrError};
use reqwest::{
multipart::{Form, Part},
Body,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct File {
pub id: String,
pub object: String,
pub created_at: u32,
pub bytes: u32,
pub filename: String,
pub purpose: FilePurpose,
}

#[derive(Debug, Serialize, Deserialize, Clone, strum_macros::Display)]
#[strum(serialize_all = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum FilePurpose {
Assistants,
AssistantsOutput,
Batch,
BatchOutput,
FineTune,
FineTuneResults,
Vision,
}

impl OpenAiClient {
pub async fn upload_file<B: Into<Body>>(
&self,
filename: &str,
mime_type: &str,
bytes: B,
purpose: FilePurpose,
) -> ApiResponseOrError<File> {
let file_part = Part::stream(bytes)
.file_name(filename.to_string())
.mime_str(mime_type)?;

let form = Form::new()
.part("file", file_part)
.text("purpose", purpose.to_string());

self.post_multipart("files", form).await
}
}
116 changes: 116 additions & 0 deletions src/assistants/messages.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use crate::{assistants::Tool, client::OpenAiClient, ApiResponseOrError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
pub id: String,
pub object: String,
pub created_at: u32,
/// The thread ID that this message belongs to.
pub thread_id: String,
/// The status of the message, which can be either in_progress, incomplete, or completed.
pub status: Option<String>,
/// On an incomplete message, details about why the message is incomplete.
pub incomplete_details: Option<IncompleteDetails>,
/// The Unix timestamp (in seconds) for when the message was completed.
pub completed_at: Option<u32>,
/// The Unix timestamp (in seconds) for when the message was marked as incomplete.
pub incomplete_at: Option<u32>,
/// The entity that produced the message. One of user or assistant
pub role: Role,
/// The content of the message.
pub content: Vec<Content>,
/// The assistant that produced the message.
pub assistant_id: Option<String>,
/// The ID of the run associated with the creation of this message. Value is null when messages are created manually using the create message or create thread endpoints.
pub run_id: Option<String>,
/// A list of files attached to the message.
pub attachments: Option<Vec<Attachment>>,
/// A set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.
pub metadata: Option<HashMap<String, String>>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Status {
InProgress,
Incomplete,
Completed,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IncompleteDetails {
pub reason: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Role {
User,
Assistant,
}

#[derive(Debug, serde_double_tag::Serialize, serde_double_tag::Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "type")]
pub enum Content {
Text(Text),
ImageFile(ImageFile),
ImageUrl(ImageUrl),
Refusal(Refusal),
}

#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct Text {
pub value: String,
pub annotations: Vec<Annotation>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Annotation {
#[serde(rename = "type")]
pub kind: String,
pub text: String,
pub start_index: u32,
pub end_index: u32,
pub file_citation: FileCitation,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FileCitation {
pub file_id: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ImageFile {
pub file_id: String,
pub detail: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ImageUrl {
pub image_url: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Refusal {
pub refusal: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Attachment {
pub file_id: String,
pub tools: Tool,
}

impl OpenAiClient {
pub async fn list_messages(
&self,
thread_id: &str,
after_id: Option<String>,
) -> ApiResponseOrError<Vec<Message>> {
self.list(format!("threads/{thread_id}/messages"), after_id)
.await
}
}
8 changes: 8 additions & 0 deletions src/assistants/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
pub mod assistants;
pub use assistants::*;

pub mod files;
pub mod messages;
pub mod runs;
pub mod threads;
pub mod vector_stores;
Loading
Loading