From 6709010a2dc460c1d67aa7dc3485ce8c912f599f Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Thu, 29 Aug 2024 16:22:32 +0200
Subject: [PATCH 01/71] New inference crate builds

---
 Cargo.toml                                    |  1 +
 sdk/openai/inference/Cargo.toml               | 22 +++++++++++++++++++
 .../examples/azure_chat_completions.rs        |  0
 .../examples/non_azure_chat_completions.rs    |  0
 sdk/openai/inference/src/clients/azure.rs     |  0
 sdk/openai/inference/src/clients/mod.rs       |  2 ++
 sdk/openai/inference/src/clients/non_azure.rs |  0
 sdk/openai/inference/src/lib.rs               |  3 +++
 8 files changed, 28 insertions(+)
 create mode 100644 sdk/openai/inference/Cargo.toml
 create mode 100644 sdk/openai/inference/examples/azure_chat_completions.rs
 create mode 100644 sdk/openai/inference/examples/non_azure_chat_completions.rs
 create mode 100644 sdk/openai/inference/src/clients/azure.rs
 create mode 100644 sdk/openai/inference/src/clients/mod.rs
 create mode 100644 sdk/openai/inference/src/clients/non_azure.rs
 create mode 100644 sdk/openai/inference/src/lib.rs

diff --git a/Cargo.toml b/Cargo.toml
index 694c89ae73..69fb2f6d51 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,6 +12,7 @@ members = [
   "eng/test/mock_transport",
   "sdk/storage",
   "sdk/storage/azure_storage_blob",
+  "sdk/openai/inference",
 ]

 [workspace.package]
diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml
new file mode 100644
index 0000000000..e19e05521d
--- /dev/null
+++ b/sdk/openai/inference/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "azure_openai_inference"
+version = "0.1.0"
+description = "Rust wrappers around Microsoft Azure REST APIs - Azure OpenAI Inference"
+authors.workspace = true
+edition.workspace = true
+license.workspace = true
+repository.workspace = true
+rust-version.workspace = true
+readme.workspace = true
+keywords = ["sdk", "azure", "rest"]
+categories = ["api-bindings"]
+
+[lints]
+workspace = true
+
+[dependencies]
+reqwest = { workspace = true, optional = true }
+
+[features]
+default = [ "reqwest" ]
+reqwest = [ "dep:reqwest" ]
diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/openai/inference/examples/non_azure_chat_completions.rs b/sdk/openai/inference/examples/non_azure_chat_completions.rs
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
new file mode 100644
index 0000000000..1b3cf68770
--- /dev/null
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -0,0 +1,2 @@
+pub mod azure;
+pub mod non_azure;
diff --git a/sdk/openai/inference/src/clients/non_azure.rs b/sdk/openai/inference/src/clients/non_azure.rs
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
new file mode 100644
index 0000000000..1d92db587b
--- /dev/null
+++ b/sdk/openai/inference/src/lib.rs
@@ -0,0 +1,3 @@
+mod clients;
+
+pub use clients::*;

From ae57b1b5702287971f858eecf4fa51c4b36445f8 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Thu, 29 Aug 2024 16:28:10 +0200
Subject: [PATCH 02/71] Added more dependencies

---
 sdk/openai/inference/Cargo.toml                         | 2 ++
 sdk/openai/inference/examples/azure_chat_completions.rs | 6 ++++++
 2 files changed, 8 insertions(+)

diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml
index e19e05521d..679117e9bd 100644
--- a/sdk/openai/inference/Cargo.toml
+++ b/sdk/openai/inference/Cargo.toml
@@ -15,7 +15,9 @@ categories = ["api-bindings"]
 workspace = true

 [dependencies]
+azure_core = { workspace = true }
 reqwest = { workspace = true, optional = true }
+tokio = { workspace = true }

 [features]
 default = [ "reqwest" ]
diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index e69de29bb2..de68857c03 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -0,0 +1,6 @@
+use azure_core::Result;
+
+#[tokio::main]
+pub async fn main() -> Result<()>{
+    Ok(())
+}
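The crate is still a skeleton at this point, but the feature wiring in patch 01 is worth a note: `reqwest = [ "dep:reqwest" ]` uses the `dep:` syntax so the optional dependency is only compiled when the feature is enabled. A minimal sketch of how code could be gated on that feature — the helper below is hypothetical and not part of the crate:

    // Hypothetical helper, assuming the `reqwest` feature declared in the
    // Cargo.toml above. It compiles only when the feature is enabled, so a
    // build with `--no-default-features` never pulls in reqwest.
    #[cfg(feature = "reqwest")]
    fn default_http_client() -> reqwest::Client {
        reqwest::Client::new()
    }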
From 897078b6ea2ad4b2d6108e17b76df8910604ce89 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 30 Aug 2024 15:04:14 +0200
Subject: [PATCH 03/71] Internal state of clients is setup

---
 .../examples/azure_chat_completions.rs        | 11 +++++++--
 .../examples/non_azure_chat_completions.rs    |  1 +
 sdk/openai/inference/src/clients/azure.rs     | 24 +++++++++++++++++++
 sdk/openai/inference/src/clients/non_azure.rs |  1 +
 sdk/openai/inference/src/lib.rs               |  3 ++-
 5 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index de68857c03..c593a093b0 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,6 +1,13 @@
-use azure_core::Result;
+use azure_core::{auth::TokenCredential, Result};
+use azure_openai_inference::AzureOpenAIClient;

 #[tokio::main]
-pub async fn main() -> Result<()>{
+pub async fn main() -> Result<()> {
+    let endpoint =
+        std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
+    let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");
+
+    let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?;
+
     Ok(())
 }
diff --git a/sdk/openai/inference/examples/non_azure_chat_completions.rs b/sdk/openai/inference/examples/non_azure_chat_completions.rs
index e69de29bb2..8b13789179 100644
--- a/sdk/openai/inference/examples/non_azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/non_azure_chat_completions.rs
@@ -0,0 +1 @@
+
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index e69de29bb2..48fdc1e6c0 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -0,0 +1,24 @@
+use std::sync::Arc;
+
+use azure_core::auth::{Secret, TokenCredential};
+use azure_core::{self, HttpClient, Result};
+use reqwest::Url;
+
+pub struct AzureOpenAIClient {
+    http_client: Arc<dyn HttpClient>,
+    endpoint: Url,
+    secret: Secret,
+}
+
+impl AzureOpenAIClient {
+    pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
+        let endpoint = Url::parse(endpoint.as_ref())?;
+        let secret = Secret::from(secret);
+
+        Ok(AzureOpenAIClient {
+            http_client: azure_core::new_http_client(),
+            endpoint,
+            secret,
+        })
+    }
+}
diff --git a/sdk/openai/inference/src/clients/non_azure.rs b/sdk/openai/inference/src/clients/non_azure.rs
index e69de29bb2..3853afde6d 100644
--- a/sdk/openai/inference/src/clients/non_azure.rs
+++ b/sdk/openai/inference/src/clients/non_azure.rs
@@ -0,0 +1 @@
+pub struct OpenAIClient {}
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 1d92db587b..228c367b4c 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -1,3 +1,4 @@
 mod clients;

-pub use clients::*;
+pub use clients::azure::*;
+pub use clients::non_azure::*;

From 67e5a12851298146be56ba26c547ff43764429b5 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 30 Aug 2024 16:48:00 +0200
Subject: [PATCH 04/71] Somehow hitting the noop client

---
 sdk/openai/inference/Cargo.toml               |  2 +
 .../examples/azure_chat_completions.rs        | 27 +++++++-
 sdk/openai/inference/src/auth/mod.rs          | 41 ++++++++++++
 sdk/openai/inference/src/clients/azure.rs     | 60 +++++++++++++++--
 sdk/openai/inference/src/clients/mod.rs       | 40 +++++++++++
 sdk/openai/inference/src/lib.rs               |  3 +
 sdk/openai/inference/src/models/mod.rs        |  5 ++
 .../src/models/request/chat_completions.rs    | 66 +++++++++++++++++++
 .../inference/src/models/request/mod.rs       |  3 +
 .../src/models/response/chat_completions.rs   | 36 ++++++++++
 .../inference/src/models/response/mod.rs      |  3 +
 11 files changed, 278 insertions(+), 8 deletions(-)
 create mode 100644 sdk/openai/inference/src/auth/mod.rs
 create mode 100644 sdk/openai/inference/src/models/mod.rs
 create mode 100644 sdk/openai/inference/src/models/request/chat_completions.rs
 create mode 100644 sdk/openai/inference/src/models/request/mod.rs
 create mode 100644 sdk/openai/inference/src/models/response/chat_completions.rs
 create mode 100644 sdk/openai/inference/src/models/response/mod.rs

diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml
index 679117e9bd..dcd68edec9 100644
--- a/sdk/openai/inference/Cargo.toml
+++ b/sdk/openai/inference/Cargo.toml
@@ -18,6 +18,8 @@ workspace = true
 azure_core = { workspace = true }
 reqwest = { workspace = true, optional = true }
 tokio = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }

 [features]
 default = [ "reqwest" ]
diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index c593a093b0..1b2324c848 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,5 +1,7 @@
-use azure_core::{auth::TokenCredential, Result};
-use azure_openai_inference::AzureOpenAIClient;
+use azure_core::Result;
+use azure_openai_inference::{
+    AzureOpenAIClient, AzureServiceVersion, CreateChatCompletionsRequest,
+};

 #[tokio::main]
 pub async fn main() -> Result<()> {
@@ -9,5 +11,26 @@ pub async fn main() -> Result<()> {

     let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?;

+    let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
+        "gpt-4-1106-preview",
+        "Tell me a joke about pineapples",
+    );
+
+    let response = azure_openai_client
+        .create_chat_completions(
+            &chat_completions_request.model,
+            AzureServiceVersion::V2023_12_01Preview,
+            &chat_completions_request,
+        )
+        .await;
+
+    match response {
+        Ok(chat_completions) => {
+            println!("{:#?}", &chat_completions);
+        }
+        Err(e) => {
+            println!("Error: {}", e);
+        }
+    };
     Ok(())
 }
diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs
new file mode 100644
index 0000000000..e03d485fca
--- /dev/null
+++ b/sdk/openai/inference/src/auth/mod.rs
@@ -0,0 +1,41 @@
+use azure_core::{
+    auth::Secret,
+    headers::{HeaderName, HeaderValue, AUTHORIZATION},
+    Header,
+};
+
+pub struct AzureKeyCredential(Secret);
+
+pub struct OpenAIKeyCredential(Secret);
+
+impl OpenAIKeyCredential {
+    pub fn new(access_token: String) -> Self {
+        Self(Secret::new(access_token))
+    }
+}
+
+impl AzureKeyCredential {
+    pub fn new(api_key: String) -> Self {
+        Self(Secret::new(api_key))
+    }
+}
+
+impl Header for AzureKeyCredential {
+    fn name(&self) -> HeaderName {
+        HeaderName::from_static("api-key")
+    }
+
+    fn value(&self) -> HeaderValue {
+        HeaderValue::from_cow(format!("{}", self.0.secret()))
+    }
+}
+
+impl Header for OpenAIKeyCredential {
+    fn name(&self) -> HeaderName {
+        AUTHORIZATION
+    }
+
+    fn value(&self) -> HeaderValue {
+        HeaderValue::from_cow(format!("Bearer {}", &self.0.secret()))
+    }
+}
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 48fdc1e6c0..ce15841a93 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -1,24 +1,72 @@
 use std::sync::Arc;

-use azure_core::auth::{Secret, TokenCredential};
-use azure_core::{self, HttpClient, Result};
-use reqwest::Url;
+use crate::auth::AzureKeyCredential;
+use crate::models::CreateChatCompletionsRequest;
+use crate::CreateChatCompletionsResponse;
+use azure_core::Url;
+use azure_core::{self, HttpClient, Method, Result};
+
+// TODO: Implement using this instead
+// typespec_client_core::json_model!(CreateChatCompletionsResponse);

 pub struct AzureOpenAIClient {
     http_client: Arc<dyn HttpClient>,
     endpoint: Url,
-    secret: Secret,
+    key_credential: AzureKeyCredential,
 }

 impl AzureOpenAIClient {
     pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
         let endpoint = Url::parse(endpoint.as_ref())?;
-        let secret = Secret::from(secret);
+        let key_credential = AzureKeyCredential::new(secret);

         Ok(AzureOpenAIClient {
             http_client: azure_core::new_http_client(),
             endpoint,
-            secret,
+            key_credential,
         })
     }
+
+    pub fn endpoint(&self) -> &Url {
+        &self.endpoint
+    }
+
+    pub async fn create_chat_completions(
+        &self,
+        deployment_name: &str,
+        api_version: impl Into<String>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+    ) -> Result<CreateChatCompletionsResponse> {
+        let url = Url::parse(&format!(
+            "{}/openai/deployments/{}/chat/completions?api-version={}",
+            &self.endpoint,
+            deployment_name,
+            api_version.into()
+        ))?;
+        let request = super::build_request(
+            &self.key_credential,
+            url,
+            Method::Post,
+            chat_completions_request,
+        )?;
+        let response = self.http_client.execute_request(&request).await?;
+        Ok(response.into_body().json().await?)
+    }
+}
+
+pub enum AzureServiceVersion {
+    V2023_09_01Preview,
+    V2023_12_01Preview,
+    V2024_07_01Preview,
+}
+
+impl From<AzureServiceVersion> for String {
+    fn from(version: AzureServiceVersion) -> String {
+        let as_str = match version {
+            AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
+            AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
+            AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
+        };
+        return String::from(as_str);
+    }
+}
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
index 1b3cf68770..8740f354ae 100644
--- a/sdk/openai/inference/src/clients/mod.rs
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -1,2 +1,42 @@
+use azure_core::{
+    headers::{ACCEPT, CONTENT_TYPE},
+    Header, Method, Request, Result, Url,
+};
+use serde::Serialize;
+
 pub mod azure;
 pub mod non_azure;
+
+pub(crate) fn build_request<T>(
+    key_credential: &impl Header,
+    url: Url,
+    method: Method,
+    data: &T,
+) -> Result<Request>
+where
+    T: ?Sized + Serialize,
+{
+    let mut request = Request::new(url, method);
+    request.add_mandatory_header(key_credential);
+    request.insert_header(CONTENT_TYPE, "application/json");
+    request.insert_header(ACCEPT, "application/json");
+    request.set_json(data)?;
+    Ok(request)
+}
+
+// pub(crate) fn build_multipart_request<F>(
+//     key_credential: &impl Header,
+//     url: Url,
+//     form_generator: F,
+// ) -> Result<Request>
+// where
+//     F: FnOnce() -> Result<Form>,
+// {
+//     let mut request = Request::new(url, Method::Post);
+//     request.add_mandatory_header(key_credential);
+//     // handled internally by reqwest
+//     // request.insert_header(CONTENT_TYPE, "multipart/form-data");
+//     // request.insert_header(ACCEPT, "application/json");
+//     request.multipart(form_generator()?);
+//     Ok(request)
+// }
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 228c367b4c..a84c4f69b4 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -1,4 +1,7 @@
+pub mod auth;
 mod clients;
+mod models;

 pub use clients::azure::*;
 pub use clients::non_azure::*;
+pub use models::*;
diff --git a/sdk/openai/inference/src/models/mod.rs b/sdk/openai/inference/src/models/mod.rs
new file mode 100644
index 0000000000..b8be6322b5
--- /dev/null
+++ b/sdk/openai/inference/src/models/mod.rs
@@ -0,0 +1,5 @@
+mod request;
+mod response;
+
+pub use request::*;
+pub use response::*;
diff --git a/sdk/openai/inference/src/models/request/chat_completions.rs b/sdk/openai/inference/src/models/request/chat_completions.rs
new file mode 100644
index 0000000000..3af246183a
--- /dev/null
+++ b/sdk/openai/inference/src/models/request/chat_completions.rs
@@ -0,0 +1,66 @@
+use serde::Serialize;
+
+#[derive(Serialize, Debug, Clone, Default)]
+pub struct CreateChatCompletionsRequest {
+    pub messages: Vec<ChatCompletionRequestMessage>,
+    pub model: String,
+    pub stream: Option<bool>,
+    // pub frequency_penalty: f64,
+    // pub logit_bias: Option<HashMap<String, f64>>,
+    // pub logprobs: Option<bool>,
+    // pub top_logprobs: Option<u32>,
+    // pub max_tokens: Option<u32>,
+}
+
+#[derive(Serialize, Debug, Clone, Default)]
+pub struct ChatCompletionRequestMessageBase {
+    #[serde(skip)]
+    pub name: Option<String>,
+    pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type)
+}
+
+#[derive(Serialize, Debug, Clone)]
+#[serde(tag = "role")]
+pub enum ChatCompletionRequestMessage {
+    #[serde(rename = "system")]
+    System(ChatCompletionRequestMessageBase),
+    #[serde(rename = "user")]
+    User(ChatCompletionRequestMessageBase),
+}
+
+impl ChatCompletionRequestMessage {
+    pub fn new_user(content: impl Into<String>) -> Self {
+        Self::User(ChatCompletionRequestMessageBase {
+            content: content.into(),
+            name: None,
+        })
+    }
+
+    pub fn new_system(content: impl Into<String>) -> Self {
+        Self::System(ChatCompletionRequestMessageBase {
+            content: content.into(),
+            name: None,
+        })
+    }
+}
+impl CreateChatCompletionsRequest {
+    pub fn new_with_user_message(model: &str, prompt: &str) -> Self {
+        Self {
+            model: model.to_string(),
+            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
+            ..Default::default()
+        }
+    }
+
+    pub fn new_stream_with_user_message(
+        model: impl Into<String>,
+        prompt: impl Into<String>,
+    ) -> Self {
+        Self {
+            model: model.into(),
+            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
+            stream: Some(true),
+            ..Default::default()
+        }
+    }
+}
diff --git a/sdk/openai/inference/src/models/request/mod.rs b/sdk/openai/inference/src/models/request/mod.rs
new file mode 100644
index 0000000000..8ccec0e32c
--- /dev/null
+++ b/sdk/openai/inference/src/models/request/mod.rs
@@ -0,0 +1,3 @@
+mod chat_completions;
+
+pub use chat_completions::*;
diff --git a/sdk/openai/inference/src/models/response/chat_completions.rs b/sdk/openai/inference/src/models/response/chat_completions.rs
new file mode 100644
index 0000000000..687fa7dca4
--- /dev/null
+++ b/sdk/openai/inference/src/models/response/chat_completions.rs
@@ -0,0 +1,36 @@
+use serde::Deserialize;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct CreateChatCompletionsResponse {
+    pub choices: Vec<ChatCompletionChoice>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionChoice {
+    pub message: ChatCompletionResponseMessage,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionResponseMessage {
+    pub content: Option<String>,
+    pub role: String,
+}
+
+// region: --- Streaming
+#[derive(Debug, Clone, Deserialize)]
+pub struct CreateChatCompletionsStreamResponse {
+    pub choices: Vec<ChatCompletionStreamChoice>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionStreamChoice {
+    pub delta: Option<ChatCompletionStreamResponseMessage>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ChatCompletionStreamResponseMessage {
+    pub content: Option<String>,
+    pub role: Option<String>,
+}
+
+// endregion: Streaming
diff --git a/sdk/openai/inference/src/models/response/mod.rs b/sdk/openai/inference/src/models/response/mod.rs
new file mode 100644
index 0000000000..8ccec0e32c
--- /dev/null
+++ b/sdk/openai/inference/src/models/response/mod.rs
@@ -0,0 +1,3 @@
+mod chat_completions;
+
+pub use chat_completions::*;
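The request model in patch 04 leans on serde's internal tagging: `#[serde(tag = "role")]` plus the `rename` attributes fold each `ChatCompletionRequestMessage` variant into a single JSON object whose `role` key carries the variant name. A quick sketch of what that serializes to, assuming the crate builds as of this commit and `serde_json` is available to the caller:

    use azure_openai_inference::CreateChatCompletionsRequest;

    fn main() {
        let request = CreateChatCompletionsRequest::new_with_user_message(
            "gpt-4-1106-preview",
            "Tell me a joke about pineapples",
        );
        // `name` is omitted entirely via #[serde(skip)]; `stream` serializes
        // as null because the Option has no skip_serializing_if attribute.
        println!("{}", serde_json::to_string_pretty(&request).unwrap());
        // {
        //   "messages": [
        //     {
        //       "role": "user",
        //       "content": "Tell me a joke about pineapples"
        //     }
        //   ],
        //   "model": "gpt-4-1106-preview",
        //   "stream": null
        // }
    }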
From 5b5e3a9d233f3a19ba2e0a33097835e6c071f43c Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 30 Aug 2024 21:30:51 +0200
Subject: [PATCH 05/71] getting 200 but no content

---
 sdk/openai/inference/Cargo.toml           | 12 +++++++++---
 sdk/openai/inference/src/clients/azure.rs |  7 ++++++-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml
index dcd68edec9..cc4ca823e7 100644
--- a/sdk/openai/inference/Cargo.toml
+++ b/sdk/openai/inference/Cargo.toml
@@ -16,11 +16,17 @@ workspace = true

 [dependencies]
 azure_core = { workspace = true }
-reqwest = { workspace = true, optional = true }
+# reqwest = { workspace = true, optional = true }
 tokio = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }

+[dev-dependencies]
+reqwest = { workspace = true }
+tokio = { workspace = true }
+
 [features]
-default = [ "reqwest" ]
-reqwest = [ "dep:reqwest" ]
+# default = [ "reqwest" ]
+# reqwest = [ "dep:reqwest" ]
+default = ["enable_reqwest"]
+enable_reqwest = ["azure_core/enable_reqwest"]
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index ce15841a93..751137ada5 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -16,6 +16,7 @@ pub struct AzureOpenAIClient {
 }

 impl AzureOpenAIClient {
+    // TODO: not sure if this should be named `with_key_credential` instead
     pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
         let endpoint = Url::parse(endpoint.as_ref())?;
         let key_credential = AzureKeyCredential::new(secret);
@@ -50,7 +51,11 @@ impl AzureOpenAIClient {
             chat_completions_request,
         )?;
         let response = self.http_client.execute_request(&request).await?;
-        Ok(response.into_body().json().await?)
+        let (status_code , headers, body) = response.deconstruct();
+
+        println!("Status code: {:?}", status_code);
+        println!("Headers: {:?}", headers);
+        Ok(body.json().await?)
     }
 }

From 80e2853c39d77683f7a4e1e406fd34f26bd76275 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Tue, 3 Sep 2024 11:23:05 +0200
Subject: [PATCH 06/71] format

---
 sdk/openai/inference/src/clients/azure.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 751137ada5..9795f2b79c 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -51,7 +51,7 @@ impl AzureOpenAIClient {
             chat_completions_request,
         )?;
         let response = self.http_client.execute_request(&request).await?;
-        let (status_code , headers, body) = response.deconstruct();
+        let (status_code, headers, body) = response.deconstruct();

         println!("Status code: {:?}", status_code);
         println!("Headers: {:?}", headers);

From 5157d81494d9932db3a6de85398f76b338cf2d87 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 6 Sep 2024 10:58:07 +0200
Subject: [PATCH 07/71] Attempt at using the pipeline

---
 .../examples/azure_chat_completions.rs    |  2 +-
 sdk/openai/inference/src/auth/mod.rs      |  1 +
 sdk/openai/inference/src/clients/azure.rs | 61 ++++++++++++++++++-
 3 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index 1b2324c848..b6faeb1763 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -17,7 +17,7 @@ pub async fn main() -> Result<()> {
     );

     let response = azure_openai_client
-        .create_chat_completions(
+        .create_chat_completions_through_pipeline(
             &chat_completions_request.model,
             AzureServiceVersion::V2023_12_01Preview,
             &chat_completions_request,
diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs
index e03d485fca..88ad9f3fe8 100644
--- a/sdk/openai/inference/src/auth/mod.rs
+++ b/sdk/openai/inference/src/auth/mod.rs
@@ -4,6 +4,7 @@ use azure_core::{
     Header,
 };

+#[derive(Debug, Clone)]
 pub struct AzureKeyCredential(Secret);

 pub struct OpenAIKeyCredential(Secret);
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 9795f2b79c..ecfa78aa03 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -3,35 +3,90 @@ use std::sync::Arc;
 use crate::auth::AzureKeyCredential;
 use crate::models::CreateChatCompletionsRequest;
 use crate::CreateChatCompletionsResponse;
-use azure_core::Url;
 use azure_core::{self, HttpClient, Method, Result};
+use azure_core::{Context, Url};

 // TODO: Implement using this instead
 // typespec_client_core::json_model!(CreateChatCompletionsResponse);

-pub struct AzureOpenAIClient {
+pub struct AzureOpenAIClient<'a> {
     http_client: Arc<dyn HttpClient>,
     endpoint: Url,
     key_credential: AzureKeyCredential,
+    context: Context<'a>,
+    pipeline: azure_core::Pipeline,
 }

-impl AzureOpenAIClient {
+impl AzureOpenAIClient<'_> {
     // TODO: not sure if this should be named `with_key_credential` instead
     pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
         let endpoint = Url::parse(endpoint.as_ref())?;
         let key_credential = AzureKeyCredential::new(secret);
+        // let auth_header_policy = CustomHeadersPolicy(key_credential.into());
+
+        let mut context = Context::new();
+        context.insert(key_credential.clone());
+
+        let pipeline = Self::new_pipeline();

         Ok(AzureOpenAIClient {
             http_client: azure_core::new_http_client(),
             endpoint,
             key_credential,
+            context,
+            pipeline,
         })
     }

+    fn new_pipeline() -> azure_core::Pipeline {
+        let crate_name = option_env!("CARGO_PKG_NAME");
+        let crate_version = option_env!("CARGO_PKG_VERSION");
+        let options = azure_core::ClientOptions::default();
+        let per_call_policies = Vec::new();
+        let per_retry_policies = Vec::new();
+
+        azure_core::Pipeline::new(
+            crate_name,
+            crate_version,
+            options,
+            per_call_policies,
+            per_retry_policies,
+        )
+    }
+
     pub fn endpoint(&self) -> &Url {
         &self.endpoint
     }

+    pub async fn create_chat_completions_through_pipeline(
+        &self,
+        deployment_name: &str,
+        api_version: impl Into<String>,
+        chat_completions_request: &CreateChatCompletionsRequest,
+        // Should I be using RequestContent ? All the new methods have signatures that would force me to mutate
+        // the request object into &static str, Vec<u8>, etc.
+        // chat_completions_request: RequestContent<CreateChatCompletionsRequest>,
+    ) -> Result<CreateChatCompletionsResponse> {
+        let url = Url::parse(&format!(
+            "{}/openai/deployments/{}/chat/completions?api-version={}",
+            &self.endpoint,
+            deployment_name,
+            api_version.into()
+        ))?;
+
+        let mut request = azure_core::Request::new(url, Method::Post);
+        // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
+        request.add_mandatory_header(&self.key_credential);
+
+        request.set_json(chat_completions_request)?;
+
+        let response = self
+            .pipeline
+            .send::<CreateChatCompletionsResponse>(&self.context, &mut request)
+            .await?;
+        response.into_body().json().await
+    }
+
     pub async fn create_chat_completions(
         &self,
         deployment_name: &str,
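Patch 07's `create_chat_completions_through_pipeline` builds the request URL with a single `format!` over the endpoint, deployment name, and api-version. A worked example of the resulting string, using a made-up endpoint and the deployment name from the example program:

    fn main() {
        let endpoint = "https://myresource.openai.azure.com"; // hypothetical endpoint
        let deployment_name = "gpt-4-1106-preview";
        let api_version = "2023-12-01-preview";
        let url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            endpoint, deployment_name, api_version
        );
        assert_eq!(
            url,
            "https://myresource.openai.azure.com/openai/deployments/gpt-4-1106-preview/chat/completions?api-version=2023-12-01-preview"
        );
    }

Note that a parsed `Url`'s display form typically carries a trailing slash on a bare authority, so the concatenation in the client can yield a double slash before `openai`; the string still parses, but it is one reason URL-building helpers tend to be preferred over `format!`.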
From 4967330461f3f9cb651714881fc729dba941fca1 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 6 Sep 2024 11:57:59 +0200
Subject: [PATCH 08/71] tried implementing custom type as policy and pass in
 the pipeline

---
 sdk/openai/inference/Cargo.toml           |  1 +
 .../examples/azure_chat_completions.rs    |  2 +-
 sdk/openai/inference/src/auth/mod.rs      | 29 +++++++++++++++++--
 sdk/openai/inference/src/clients/azure.rs | 22 +++++++++-----
 4 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml
index cc4ca823e7..f28284a620 100644
--- a/sdk/openai/inference/Cargo.toml
+++ b/sdk/openai/inference/Cargo.toml
@@ -20,6 +20,7 @@ azure_core = { workspace = true }
 tokio = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
+async-trait = { workspace = true }

 [dev-dependencies]
 reqwest = { workspace = true }
diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index b6faeb1763..32f1e609a6 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -9,7 +9,7 @@ pub async fn main() -> Result<()> {
         std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
     let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");

-    let azure_openai_client = AzureOpenAIClient::new(endpoint, secret)?;
+    let azure_openai_client = AzureOpenAIClient::new(endpoint, secret, None)?;

diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs
index 88ad9f3fe8..fcdd926922 100644
--- a/sdk/openai/inference/src/auth/mod.rs
+++ b/sdk/openai/inference/src/auth/mod.rs
@@ -1,7 +1,8 @@
+use std::sync::Arc;
+use async_trait::async_trait;
+
 use azure_core::{
-    auth::Secret,
-    headers::{HeaderName, HeaderValue, AUTHORIZATION},
-    Header,
+    auth::Secret, headers::{HeaderName, HeaderValue, AUTHORIZATION}, Context, Header, Policy, PolicyResult, Request
 };

 #[derive(Debug, Clone)]
@@ -31,6 +32,28 @@ impl Header for AzureKeyCredential {
     }
 }

+// code lifted from BearerTokenCredentialPolicy
+#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
+#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
+impl Policy for AzureKeyCredential {
+
+    async fn send(
+        &self,
+        ctx: &Context,
+        request: &mut Request,
+        next: &[Arc<dyn Policy>],
+    ) -> PolicyResult {
+        request.insert_header(Header::name(self), Header::value(self));
+        next[0].send(ctx, request, &next[1..]).await
+    }
+}
+
+impl Into<Vec<Arc<dyn Policy>>> for AzureKeyCredential {
+    fn into(self) -> Vec<Arc<dyn Policy>> {
+        vec![Arc::new(self)]
+    }
+}
+
 impl Header for OpenAIKeyCredential {
     fn name(&self) -> HeaderName {
         AUTHORIZATION
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index ecfa78aa03..65921d6c7c 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -3,31 +3,38 @@ use std::sync::Arc;

 use crate::auth::AzureKeyCredential;
 use crate::models::CreateChatCompletionsRequest;
 use crate::CreateChatCompletionsResponse;
-use azure_core::{self, HttpClient, Method, Result};
+use azure_core::{self, ClientOptions, HttpClient, Method, Policy, Result};
 use azure_core::{Context, Url};

 // TODO: Implement using this instead
 // typespec_client_core::json_model!(CreateChatCompletionsResponse);

-pub struct AzureOpenAIClient<'a> {
+#[derive(Clone, Debug, Default)]
+pub struct AzureOpenAIClientOptions {
+    client_options: ClientOptions,
+}
+
+pub struct AzureOpenAIClient <'a> {
     http_client: Arc<dyn HttpClient>,
     endpoint: Url,
     key_credential: AzureKeyCredential,
     context: Context<'a>,
     pipeline: azure_core::Pipeline,
+    azure_openai_client_options: AzureOpenAIClientOptions
 }

-impl AzureOpenAIClient<'_> {
+impl AzureOpenAIClient <'_> {
     // TODO: not sure if this should be named `with_key_credential` instead
-    pub fn new(endpoint: impl AsRef<str>, secret: String) -> Result<Self> {
+    pub fn new(endpoint: impl AsRef<str>, secret: String, client_options: Option<AzureOpenAIClientOptions>) -> Result<Self> {
         let endpoint = Url::parse(endpoint.as_ref())?;
         let key_credential = AzureKeyCredential::new(secret);
-        // let auth_header_policy = CustomHeadersPolicy(key_credential.into());

-        let mut context = Context::new();
-        context.insert(key_credential.clone());
+        let context = Context::new();

         let pipeline = Self::new_pipeline();
+        let mut azure_openai_client_options = client_options.unwrap_or_default();
+        let per_call_policies: Vec<Arc<dyn Policy>> = key_credential.clone().into();
+        azure_openai_client_options.client_options.set_per_call_policies(per_call_policies);

         Ok(AzureOpenAIClient {
             http_client: azure_core::new_http_client(),
             endpoint,
             key_credential,
             context,
             pipeline,
+            azure_openai_client_options
         })
     }

From 817cd3b5aceb41c8a855e57f606cd175a13e5bd5 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Fri, 6 Sep 2024 12:13:51 +0200
Subject: [PATCH 09/71] Works with policy

---
 sdk/openai/inference/src/clients/azure.rs | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 65921d6c7c..ec70c2633f 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -31,10 +31,11 @@ impl AzureOpenAIClient <'_> {

         let context = Context::new();

-        let pipeline = Self::new_pipeline();
         let mut azure_openai_client_options = client_options.unwrap_or_default();
         let per_call_policies: Vec<Arc<dyn Policy>> = key_credential.clone().into();
-        azure_openai_client_options.client_options.set_per_call_policies(per_call_policies);
+
+        let pipeline = Self::new_pipeline(per_call_policies);
+        // azure_openai_client_options.client_options.set_per_call_policies(per_call_policies);

         Ok(AzureOpenAIClient {
             http_client: azure_core::new_http_client(),
@@ -46,11 +47,11 @@ impl AzureOpenAIClient <'_> {
         })
     }

-    fn new_pipeline() -> azure_core::Pipeline {
+    fn new_pipeline(per_call_policies: Vec<Arc<dyn Policy>>) -> azure_core::Pipeline {
        let crate_name = option_env!("CARGO_PKG_NAME");
        let crate_version = option_env!("CARGO_PKG_VERSION");
        let options = azure_core::ClientOptions::default();
-        let per_call_policies = Vec::new();
+        // let per_call_policies = Vec::new();
        let per_retry_policies = Vec::new();

        azure_core::Pipeline::new(
@@ -84,7 +85,7 @@ impl AzureOpenAIClient <'_> {

         let mut request = azure_core::Request::new(url, Method::Post);
         // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
-        request.add_mandatory_header(&self.key_credential);
+        // request.add_mandatory_header(&self.key_credential);

         request.set_json(chat_completions_request)?;
From 9056e38315b274503877f55b14a1380a32e7d4b6 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Tue, 10 Sep 2024 11:22:56 +0200
Subject: [PATCH 10/71] WIP

---
 .../examples/azure_chat_completions.rs    |  14 ++-
 sdk/openai/inference/src/auth/mod.rs      |   7 +-
 sdk/openai/inference/src/clients/azure.rs | 104 +++++++++++-------
 sdk/openai/inference/src/clients/mod.rs   |  42 +++----
 4 files changed, 99 insertions(+), 68 deletions(-)

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index 32f1e609a6..4c14987bbf 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,6 +1,7 @@
 use azure_core::Result;
 use azure_openai_inference::{
-    AzureOpenAIClient, AzureServiceVersion, CreateChatCompletionsRequest,
+    builders::AzureOpenAIClientOptionsBuilder, AzureOpenAIClient, AzureOpenAIClientOptions,
+    AzureServiceVersion, CreateChatCompletionsRequest,
 };

 #[tokio::main]
@@ -9,7 +10,13 @@ pub async fn main() -> Result<()> {
         std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable");
     let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable");

-    let azure_openai_client = AzureOpenAIClient::new(endpoint, secret, None)?;
+    let azure_openai_client = AzureOpenAIClient::new(
+        endpoint,
+        secret,
+        Some(AzureOpenAIClientOptions::builder()
+            .with_api_version(AzureServiceVersion::V2023_12_01Preview)
+            .build()),
+    )?;

     let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
         "gpt-4-1106-preview",
@@ -17,7 +24,6 @@ pub async fn main() -> Result<()> {
     let response = azure_openai_client
-        .create_chat_completions_through_pipeline(
+        .create_chat_completions(
             &chat_completions_request.model,
-            AzureServiceVersion::V2023_12_01Preview,
             &chat_completions_request,
         )
         .await;
diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs
index fcdd926922..20ef984aa5 100644
--- a/sdk/openai/inference/src/auth/mod.rs
+++ b/sdk/openai/inference/src/auth/mod.rs
@@ -1,8 +1,10 @@
-use std::sync::Arc;
 use async_trait::async_trait;
+use std::sync::Arc;

 use azure_core::{
-    auth::Secret, headers::{HeaderName, HeaderValue, AUTHORIZATION}, Context, Header, Policy, PolicyResult, Request
+    auth::Secret,
+    headers::{HeaderName, HeaderValue, AUTHORIZATION},
+    Context, Header, Policy, PolicyResult, Request,
 };

 #[derive(Debug, Clone)]
@@ -36,7 +38,6 @@ impl Header for AzureKeyCredential {
 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
 impl Policy for AzureKeyCredential {
-
     async fn send(
         &self,
         ctx: &Context,
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index ec70c2633f..5d4dd0db6b 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -3,47 +3,51 @@ use std::sync::Arc;

 use crate::auth::AzureKeyCredential;
 use crate::models::CreateChatCompletionsRequest;
 use crate::CreateChatCompletionsResponse;
-use azure_core::{self, ClientOptions, HttpClient, Method, Policy, Result};
+use azure_core::{
+    self, builders::ClientOptionsBuilder, AppendToUrlQuery, ClientOptions, HttpClient, Method,
+    Policy, Result,
+};
 use azure_core::{Context, Url};

 // TODO: Implement using this instead
 // typespec_client_core::json_model!(CreateChatCompletionsResponse);

+// TODO: I was not able to find ClientOptions as a derive macros
 #[derive(Clone, Debug, Default)]
 pub struct AzureOpenAIClientOptions {
     client_options: ClientOptions,
+    api_service_version: AzureServiceVersion,
 }

-pub struct AzureOpenAIClient <'a> {
-    http_client: Arc<dyn HttpClient>,
+pub struct AzureOpenAIClient<'a> {
     endpoint: Url,
-    key_credential: AzureKeyCredential,
     context: Context<'a>,
     pipeline: azure_core::Pipeline,
-    azure_openai_client_options: AzureOpenAIClientOptions
+    options: AzureOpenAIClientOptions,
 }

-impl AzureOpenAIClient <'_> {
+impl AzureOpenAIClient<'_> {
     // TODO: not sure if this should be named `with_key_credential` instead
-    pub fn new(endpoint: impl AsRef<str>, secret: String, client_options: Option<AzureOpenAIClientOptions>) -> Result<Self> {
+    pub fn new(
+        endpoint: impl AsRef<str>,
+        secret: String,
+        client_options: Option<AzureOpenAIClientOptions>,
+    ) -> Result<Self> {
         let endpoint = Url::parse(endpoint.as_ref())?;
         let key_credential = AzureKeyCredential::new(secret);

         let context = Context::new();

-        let mut azure_openai_client_options = client_options.unwrap_or_default();
+        let options = client_options.unwrap_or_default();
         let per_call_policies: Vec<Arc<dyn Policy>> = key_credential.clone().into();

         let pipeline = Self::new_pipeline(per_call_policies);

         Ok(AzureOpenAIClient {
-            http_client: azure_core::new_http_client(),
             endpoint,
-            key_credential,
             context,
             pipeline,
-            azure_openai_client_options
+            options,
         })
     }
@@ -67,13 +71,10 @@ impl AzureOpenAIClient<'_> {
     pub fn endpoint(&self) -> &Url {
         &self.endpoint
     }

-    pub async fn create_chat_completions_through_pipeline(
+    pub async fn create_chat_completions(
         &self,
         deployment_name: &str,
-        api_version: impl Into<String>,
         chat_completions_request: &CreateChatCompletionsRequest,
         // Should I be using RequestContent ? All the new methods have signatures that would force me to mutate
         // the request object into &static str, Vec<u8>, etc.
         // chat_completions_request: RequestContent<CreateChatCompletionsRequest>,
     ) -> Result<CreateChatCompletionsResponse> {
         let url = Url::parse(&format!(
             "{}/openai/deployments/{}/chat/completions?api-version={}",
             &self.endpoint,
             deployment_name,
-            api_version.into()
+            &self.options.api_service_version.to_string(),
         ))?;

         let mut request = azure_core::Request::new(url, Method::Post);
         // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
         // request.add_mandatory_header(&self.key_credential);

         request.set_json(chat_completions_request)?;

         let response = self
             .pipeline
             .send::<CreateChatCompletionsResponse>(&self.context, &mut request)
             .await?;
         response.into_body().json().await
     }
+}

-    pub async fn create_chat_completions(
-        &self,
-        deployment_name: &str,
-        api_version: impl Into<String>,
-        chat_completions_request: &CreateChatCompletionsRequest,
-    ) -> Result<CreateChatCompletionsResponse> {
-        let url = Url::parse(&format!(
-            "{}/openai/deployments/{}/chat/completions?api-version={}",
-            &self.endpoint,
-            deployment_name,
-            api_version.into()
-        ))?;
-        let request = super::build_request(
-            &self.key_credential,
-            url,
-            Method::Post,
-            chat_completions_request,
-        )?;
-        let response = self.http_client.execute_request(&request).await?;
-        let (status_code, headers, body) = response.deconstruct();
-
-        println!("Status code: {:?}", status_code);
-        println!("Headers: {:?}", headers);
-        Ok(body.json().await?)
-    }
-}
+impl AzureOpenAIClientOptions {
+    pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder {
+        builders::AzureOpenAIClientOptionsBuilder::new()
+    }
+}
+
+pub mod builders {
+    use super::*;
+
+    #[derive(Clone, Debug, Default)]
+    pub struct AzureOpenAIClientOptionsBuilder {
+        options: AzureOpenAIClientOptions,
+    }
+
+    impl AzureOpenAIClientOptionsBuilder {
+        pub(super) fn new() -> Self {
+            Self::default()
+        }
+        pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self {
+            self.options.api_service_version = api_service_version;
+            self
+        }
+
+        pub fn build(&self) -> AzureOpenAIClientOptions {
+            self.options.clone()
+        }
+    }
+}

+#[derive(Debug, Clone)]
 pub enum AzureServiceVersion {
     V2023_09_01Preview,
     V2023_12_01Preview,
     V2024_07_01Preview,
 }

+impl Default for AzureServiceVersion {
+    fn default() -> AzureServiceVersion {
+        AzureServiceVersion::get_latest()
+    }
+}
+
+impl AzureServiceVersion {
+    pub fn get_latest() -> AzureServiceVersion {
+        AzureServiceVersion::V2024_07_01Preview
+    }
+}
+
 impl From<AzureServiceVersion> for String {
     fn from(version: AzureServiceVersion) -> String {
         let as_str = match version {
             AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
             AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
             AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
         };
         return String::from(as_str);
     }
 }
+
+impl ToString for AzureServiceVersion {
+    fn to_string(&self) -> String {
+        String::from(self.clone())
+    }
+}
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
index 8740f354ae..2e8d14cefc 100644
--- a/sdk/openai/inference/src/clients/mod.rs
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -1,42 +1,42 @@
-use azure_core::{
-    headers::{ACCEPT, CONTENT_TYPE},
-    Header, Method, Request, Result, Url,
-};
-use serde::Serialize;
+// use azure_core::{
+//     headers::{ACCEPT, CONTENT_TYPE},
+//     Header, Method, Request, Result, Url,
+// };
+// use serde::Serialize;

 pub mod azure;
 pub mod non_azure;

-pub(crate) fn build_request<T>(
-    key_credential: &impl Header,
-    url: Url,
-    method: Method,
-    data: &T,
-) -> Result<Request>
-where
-    T: ?Sized + Serialize,
-{
-    let mut request = Request::new(url, method);
-    request.add_mandatory_header(key_credential);
-    request.insert_header(CONTENT_TYPE, "application/json");
-    request.insert_header(ACCEPT, "application/json");
-    request.set_json(data)?;
-    Ok(request)
-}
+// pub(crate) fn build_request<T>(
+//     key_credential: &impl Header,
+//     url: Url,
+//     method: Method,
+//     data: &T,
+// ) -> Result<Request>
+// where
+//     T: ?Sized + Serialize,
+// {
+//     let mut request = Request::new(url, method);
+//     request.add_mandatory_header(key_credential);
+//     request.insert_header(CONTENT_TYPE, "application/json");
+//     request.insert_header(ACCEPT, "application/json");
+//     request.set_json(data)?;
+//     Ok(request)
+// }
From d65edb43b895fdfd268a84efefef117847b872c9 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Thu, 12 Sep 2024 17:49:14 +0200
Subject: [PATCH 11/71] Extracted client_options to a separate module

---
 .../examples/azure_chat_completions.rs        | 16 ++--
 sdk/openai/inference/src/clients/azure.rs     | 89 ++-----------------
 sdk/openai/inference/src/lib.rs               |  2 +
 sdk/openai/inference/src/options/mod.rs       | 40 +++++++++
 .../inference/src/options/service_version.rs  | 35 ++++++++
 5 files changed, 91 insertions(+), 91 deletions(-)
 create mode 100644 sdk/openai/inference/src/options/mod.rs
 create mode 100644 sdk/openai/inference/src/options/service_version.rs

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index 4c14987bbf..5527aec4f1 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,7 +1,6 @@
 use azure_core::Result;
 use azure_openai_inference::{
-    builders::AzureOpenAIClientOptionsBuilder, AzureOpenAIClient, AzureOpenAIClientOptions,
-    AzureServiceVersion, CreateChatCompletionsRequest,
+    AzureOpenAIClient, AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
 };

 #[tokio::main]
@@ -13,9 +12,11 @@ pub async fn main() -> Result<()> {
     let azure_openai_client = AzureOpenAIClient::new(
         endpoint,
         secret,
-        Some(AzureOpenAIClientOptions::builder()
-            .with_api_version(AzureServiceVersion::V2023_12_01Preview)
-            .build()),
+        Some(
+            AzureOpenAIClientOptions::builder()
+                .with_api_version(AzureServiceVersion::V2023_12_01Preview)
+                .build(),
+        ),
     )?;

     let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message(
@@ -24,10 +25,7 @@ pub async fn main() -> Result<()> {
     );

     let response = azure_openai_client
-        .create_chat_completions(
-            &chat_completions_request.model,
-            &chat_completions_request,
-        )
+        .create_chat_completions(&chat_completions_request.model, &chat_completions_request)
         .await;

     match response {
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 5d4dd0db6b..52b6b6d83e 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -2,23 +2,11 @@ use std::sync::Arc;

 use crate::auth::AzureKeyCredential;
 use crate::models::CreateChatCompletionsRequest;
+use crate::options::AzureOpenAIClientOptions;
 use crate::CreateChatCompletionsResponse;
-use azure_core::{
-    self, builders::ClientOptionsBuilder, AppendToUrlQuery, ClientOptions, HttpClient, Method,
-    Policy, Result,
-};
+use azure_core::{self, Method, Policy, Result};
 use azure_core::{Context, Url};

-// TODO: Implement using this instead
-// typespec_client_core::json_model!(CreateChatCompletionsResponse);
-
-// TODO: I was not able to find ClientOptions as a derive macros
-#[derive(Clone, Debug, Default)]
-pub struct AzureOpenAIClientOptions {
-    client_options: ClientOptions,
-    api_service_version: AzureServiceVersion,
-}
-
 pub struct AzureOpenAIClient<'a> {
     endpoint: Url,
     context: Context<'a>,
@@ -41,7 +29,7 @@ impl AzureOpenAIClient<'_> {
         let options = client_options.unwrap_or_default();
         let per_call_policies: Vec<Arc<dyn Policy>> = key_credential.clone().into();

-        let pipeline = Self::new_pipeline(per_call_policies);
+        let pipeline = Self::new_pipeline(per_call_policies, options.client_options.clone());

         Ok(AzureOpenAIClient {
             endpoint,
@@ -51,10 +39,12 @@ impl AzureOpenAIClient<'_> {
         })
     }

-    fn new_pipeline(per_call_policies: Vec<Arc<dyn Policy>>) -> azure_core::Pipeline {
+    fn new_pipeline(
+        per_call_policies: Vec<Arc<dyn Policy>>,
+        options: azure_core::ClientOptions,
+    ) -> azure_core::Pipeline {
         let crate_name = option_env!("CARGO_PKG_NAME");
         let crate_version = option_env!("CARGO_PKG_VERSION");
-        let options = azure_core::ClientOptions::default();
         let per_retry_policies = Vec::new();

         azure_core::Pipeline::new(
@@ -98,68 +88,3 @@ impl AzureOpenAIClient<'_> {
         response.into_body().json().await
     }
 }
-
-impl AzureOpenAIClientOptions {
-    pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder {
-        builders::AzureOpenAIClientOptionsBuilder::new()
-    }
-}
-
-pub mod builders {
-    use super::*;
-
-    #[derive(Clone, Debug, Default)]
-    pub struct AzureOpenAIClientOptionsBuilder {
-        options: AzureOpenAIClientOptions,
-    }
-
-    impl AzureOpenAIClientOptionsBuilder {
-        pub(super) fn new() -> Self {
-            Self::default()
-        }
-        pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self {
-            self.options.api_service_version = api_service_version;
-            self
-        }
-
-        pub fn build(&self) -> AzureOpenAIClientOptions {
-            self.options.clone()
-        }
-    }
-}
-
-#[derive(Debug, Clone)]
-pub enum AzureServiceVersion {
-    V2023_09_01Preview,
-    V2023_12_01Preview,
-    V2024_07_01Preview,
-}
-
-impl Default for AzureServiceVersion {
-    fn default() -> AzureServiceVersion {
-        AzureServiceVersion::get_latest()
-    }
-}
-
-impl AzureServiceVersion {
-    pub fn get_latest() -> AzureServiceVersion {
-        AzureServiceVersion::V2024_07_01Preview
-    }
-}
-
-impl From<AzureServiceVersion> for String {
-    fn from(version: AzureServiceVersion) -> String {
-        let as_str = match version {
-            AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
-            AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
-            AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
-        };
-        return String::from(as_str);
-    }
-}
-
-impl ToString for AzureServiceVersion {
-    fn to_string(&self) -> String {
-        String::from(self.clone())
-    }
-}
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index a84c4f69b4..311b0f2297 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -1,7 +1,9 @@
 pub mod auth;
 mod clients;
 mod models;
+mod options;

 pub use clients::azure::*;
 pub use clients::non_azure::*;
 pub use models::*;
+pub use options::*;
diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs
new file mode 100644
index 0000000000..b3e2d9cbe0
--- /dev/null
+++ b/sdk/openai/inference/src/options/mod.rs
@@ -0,0 +1,40 @@
+mod service_version;
+
+use azure_core::ClientOptions;
+
+pub use service_version::AzureServiceVersion;
+
+// TODO: I was not able to find ClientOptions as a derive macros
+#[derive(Clone, Debug, Default)]
+pub struct AzureOpenAIClientOptions {
+    pub(crate) client_options: ClientOptions,
+    pub(crate) api_service_version: AzureServiceVersion,
+}
+
+impl AzureOpenAIClientOptions {
+    pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder {
+        builders::AzureOpenAIClientOptionsBuilder::new()
+    }
+}
+
+pub mod builders {
+    use super::*;
+
+    #[derive(Clone, Debug, Default)]
+    pub struct AzureOpenAIClientOptionsBuilder {
+        options: AzureOpenAIClientOptions,
+    }
+
+    impl AzureOpenAIClientOptionsBuilder {
+        pub(super) fn new() -> Self {
+            Self::default()
+        }
+        pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self {
+            self.options.api_service_version = api_service_version;
+            self
+        }
+
+        pub fn build(&self) -> AzureOpenAIClientOptions {
+            self.options.clone()
+        }
+    }
+}
diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs
new file mode 100644
index 0000000000..7a7c3bc52d
--- /dev/null
+++ b/sdk/openai/inference/src/options/service_version.rs
@@ -0,0 +1,35 @@
+#[derive(Debug, Clone)]
+pub enum AzureServiceVersion {
+    V2023_09_01Preview,
+    V2023_12_01Preview,
+    V2024_07_01Preview,
+}
+
+impl Default for AzureServiceVersion {
+    fn default() -> AzureServiceVersion {
+        AzureServiceVersion::get_latest()
+    }
+}
+
+impl AzureServiceVersion {
+    pub fn get_latest() -> AzureServiceVersion {
+        AzureServiceVersion::V2024_07_01Preview
+    }
+}
+
+impl From<AzureServiceVersion> for String {
+    fn from(version: AzureServiceVersion) -> String {
+        let as_str = match version {
+            AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
+            AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
+            AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
+        };
+        return String::from(as_str);
+    }
+}
+
+impl ToString for AzureServiceVersion {
+    fn to_string(&self) -> String {
+        String::from(self.clone())
+    }
+}
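One small wrinkle in the extracted `service_version.rs`: implementing `ToString` directly is the pattern clippy's `to_string_trait_impl` lint warns about, since `ToString` comes for free from `Display`. A sketch of the idiomatic drop-in alternative inside that file — same strings, no `Clone` round-trip through `String::from`:

    use std::fmt;

    impl fmt::Display for AzureServiceVersion {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            // The blanket impl `impl<T: Display> ToString for T` then provides
            // to_string(), so the manual ToString impl can be dropped.
            f.write_str(match self {
                AzureServiceVersion::V2023_09_01Preview => "2023-09-01-preview",
                AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
                AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
            })
        }
    }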
From 0e9f8e28c5a07cdd47161790ecc3381a6c9d84dd Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Thu, 12 Sep 2024 17:57:16 +0200
Subject: [PATCH 12/71] Re-organized models

---
 .../examples/azure_chat_completions.rs        |   3 +-
 sdk/openai/inference/src/clients/azure.rs     |   5 +-
 .../inference/src/models/chat_completions.rs  | 109 ++++++++++++++++++
 sdk/openai/inference/src/models/mod.rs        |   6 +-
 .../src/models/request/chat_completions.rs    |  66 -----------
 .../inference/src/models/request/mod.rs       |   3 -
 .../src/models/response/chat_completions.rs   |  36 ------
 .../inference/src/models/response/mod.rs      |   3 -
 8 files changed, 116 insertions(+), 115 deletions(-)
 create mode 100644 sdk/openai/inference/src/models/chat_completions.rs
 delete mode 100644 sdk/openai/inference/src/models/request/chat_completions.rs
 delete mode 100644 sdk/openai/inference/src/models/request/mod.rs
 delete mode 100644 sdk/openai/inference/src/models/response/chat_completions.rs
 delete mode 100644 sdk/openai/inference/src/models/response/mod.rs

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index 5527aec4f1..78b06f65e3 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,6 +1,7 @@
 use azure_core::Result;
 use azure_openai_inference::{
-    AzureOpenAIClient, AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
+    request::CreateChatCompletionsRequest, AzureOpenAIClient, AzureOpenAIClientOptions,
+    AzureServiceVersion,
 };

 #[tokio::main]
diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs
index 52b6b6d83e..01eb19f853 100644
--- a/sdk/openai/inference/src/clients/azure.rs
+++ b/sdk/openai/inference/src/clients/azure.rs
@@ -1,9 +1,10 @@
 use std::sync::Arc;

 use crate::auth::AzureKeyCredential;
-use crate::models::CreateChatCompletionsRequest;
+
 use crate::options::AzureOpenAIClientOptions;
-use crate::CreateChatCompletionsResponse;
+use crate::request::CreateChatCompletionsRequest;
+use crate::response::CreateChatCompletionsResponse;
 use azure_core::{self, Method, Policy, Result};
 use azure_core::{Context, Url};

diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs
new file mode 100644
index 0000000000..bf250e131a
--- /dev/null
+++ b/sdk/openai/inference/src/models/chat_completions.rs
@@ -0,0 +1,109 @@
+pub mod request {
+
+    use serde::Serialize;
+
+    #[derive(Serialize, Debug, Clone, Default)]
+    pub struct CreateChatCompletionsRequest {
+        pub messages: Vec<ChatCompletionRequestMessage>,
+        pub model: String,
+        pub stream: Option<bool>,
+        // pub frequency_penalty: f64,
+        // pub logit_bias: Option<HashMap<String, f64>>,
+        // pub logprobs: Option<bool>,
+        // pub top_logprobs: Option<u32>,
+        // pub max_tokens: Option<u32>,
+    }
+
+    #[derive(Serialize, Debug, Clone, Default)]
+    pub struct ChatCompletionRequestMessageBase {
+        #[serde(skip)]
+        pub name: Option<String>,
+        pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type)
+    }
+
+    #[derive(Serialize, Debug, Clone)]
+    #[serde(tag = "role")]
+    pub enum ChatCompletionRequestMessage {
+        #[serde(rename = "system")]
+        System(ChatCompletionRequestMessageBase),
+        #[serde(rename = "user")]
+        User(ChatCompletionRequestMessageBase),
+    }
+
+    impl ChatCompletionRequestMessage {
+        pub fn new_user(content: impl Into<String>) -> Self {
+            Self::User(ChatCompletionRequestMessageBase {
+                content: content.into(),
+                name: None,
+            })
+        }
+
+        pub fn new_system(content: impl Into<String>) -> Self {
+            Self::System(ChatCompletionRequestMessageBase {
+                content: content.into(),
+                name: None,
+            })
+        }
+    }
+    impl CreateChatCompletionsRequest {
+        pub fn new_with_user_message(model: &str, prompt: &str) -> Self {
+            Self {
+                model: model.to_string(),
+                messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
+                ..Default::default()
+            }
+        }
+
+        pub fn new_stream_with_user_message(
+            model: impl Into<String>,
+            prompt: impl Into<String>,
+        ) -> Self {
+            Self {
+                model: model.into(),
+                messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
+                stream: Some(true),
+                ..Default::default()
+            }
+        }
+    }
+}
+
+pub mod response {
+
+    use serde::Deserialize;
+
+    #[derive(Debug, Clone, Deserialize)]
+    pub struct CreateChatCompletionsResponse {
+        pub choices: Vec<ChatCompletionChoice>,
+    }
+
+    #[derive(Debug, Clone, Deserialize)]
+    pub struct ChatCompletionChoice {
+        pub message: ChatCompletionResponseMessage,
+    }
+
+    #[derive(Debug, Clone, Deserialize)]
+    pub struct ChatCompletionResponseMessage {
+        pub content: Option<String>,
+        pub role: String,
+    }
+
+    // region: --- Streaming
+    #[derive(Debug, Clone, Deserialize)]
+    pub struct CreateChatCompletionsStreamResponse {
+        pub choices: Vec<ChatCompletionStreamChoice>,
+    }
+
+    #[derive(Debug, Clone, Deserialize)]
+    pub struct ChatCompletionStreamChoice {
+        pub delta: Option<ChatCompletionStreamResponseMessage>,
+    }
+
+    #[derive(Debug, Clone, Deserialize)]
+    pub struct ChatCompletionStreamResponseMessage {
+        pub content: Option<String>,
+        pub role: Option<String>,
+    }
+
+    // endregion: Streaming
+}
diff --git a/sdk/openai/inference/src/models/mod.rs b/sdk/openai/inference/src/models/mod.rs
index b8be6322b5..8ccec0e32c 100644
--- a/sdk/openai/inference/src/models/mod.rs
+++ b/sdk/openai/inference/src/models/mod.rs
@@ -1,5 +1,3 @@
-mod request;
-mod response;
+mod chat_completions;

-pub use request::*;
-pub use response::*;
+pub use chat_completions::*;
diff --git a/sdk/openai/inference/src/models/request/chat_completions.rs b/sdk/openai/inference/src/models/request/chat_completions.rs
deleted file mode 100644
index 3af246183a..0000000000
--- a/sdk/openai/inference/src/models/request/chat_completions.rs
+++ /dev/null
@@ -1,66 +0,0 @@
-use serde::Serialize;
-
-#[derive(Serialize, Debug, Clone, Default)]
-pub struct CreateChatCompletionsRequest {
-    pub messages: Vec<ChatCompletionRequestMessage>,
-    pub model: String,
-    pub stream: Option<bool>,
-    // pub frequency_penalty: f64,
-    // pub logit_bias: Option<HashMap<String, f64>>,
-    // pub logprobs: Option<bool>,
-    // pub top_logprobs: Option<u32>,
-    // pub max_tokens: Option<u32>,
-}
-
-#[derive(Serialize, Debug, Clone, Default)]
-pub struct ChatCompletionRequestMessageBase {
-    #[serde(skip)]
-    pub name: Option<String>,
-    pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type)
-}
-
-#[derive(Serialize, Debug, Clone)]
-#[serde(tag = "role")]
-pub enum ChatCompletionRequestMessage {
-    #[serde(rename = "system")]
-    System(ChatCompletionRequestMessageBase),
-    #[serde(rename = "user")]
-    User(ChatCompletionRequestMessageBase),
-}
-
-impl ChatCompletionRequestMessage {
-    pub fn new_user(content: impl Into<String>) -> Self {
-        Self::User(ChatCompletionRequestMessageBase {
-            content: content.into(),
-            name: None,
-        })
-    }
-
-    pub fn new_system(content: impl Into<String>) -> Self {
-        Self::System(ChatCompletionRequestMessageBase {
-            content: content.into(),
-            name: None,
-        })
-    }
-}
-impl CreateChatCompletionsRequest {
-    pub fn new_with_user_message(model: &str, prompt: &str) -> Self {
-        Self {
-            model: model.to_string(),
-            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
-            ..Default::default()
-        }
-    }
-
-    pub fn new_stream_with_user_message(
-        model: impl Into<String>,
-        prompt: impl Into<String>,
-    ) -> Self {
-        Self {
-            model: model.into(),
-            messages: vec![ChatCompletionRequestMessage::new_user(prompt)],
-            stream: Some(true),
-            ..Default::default()
-        }
-    }
-}
diff --git a/sdk/openai/inference/src/models/request/mod.rs b/sdk/openai/inference/src/models/request/mod.rs
deleted file mode 100644
index 8ccec0e32c..0000000000
--- a/sdk/openai/inference/src/models/request/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-mod chat_completions;
-
-pub use chat_completions::*;
diff --git a/sdk/openai/inference/src/models/response/chat_completions.rs b/sdk/openai/inference/src/models/response/chat_completions.rs
deleted file mode 100644
index 687fa7dca4..0000000000
--- a/sdk/openai/inference/src/models/response/chat_completions.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-use serde::Deserialize;
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct CreateChatCompletionsResponse {
-    pub choices: Vec<ChatCompletionChoice>,
-}
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct ChatCompletionChoice {
-    pub message: ChatCompletionResponseMessage,
-}
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct ChatCompletionResponseMessage {
-    pub content: Option<String>,
-    pub role: String,
-}
-
-// region: --- Streaming
-#[derive(Debug, Clone, Deserialize)]
-pub struct CreateChatCompletionsStreamResponse {
-    pub choices: Vec<ChatCompletionStreamChoice>,
-}
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct ChatCompletionStreamChoice {
-    pub delta: Option<ChatCompletionStreamResponseMessage>,
-}
-
-#[derive(Debug, Clone, Deserialize)]
-pub struct ChatCompletionStreamResponseMessage {
-    pub content: Option<String>,
-    pub role: Option<String>,
-}
-
-// endregion: Streaming
diff --git a/sdk/openai/inference/src/models/response/mod.rs b/sdk/openai/inference/src/models/response/mod.rs
deleted file mode 100644
index 8ccec0e32c..0000000000
--- a/sdk/openai/inference/src/models/response/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
-mod chat_completions;
-
-pub use chat_completions::*;
a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs index 01eb19f853..13857c5391 100644 --- a/sdk/openai/inference/src/clients/azure.rs +++ b/sdk/openai/inference/src/clients/azure.rs @@ -8,24 +8,20 @@ use crate::response::CreateChatCompletionsResponse; use azure_core::{self, Method, Policy, Result}; use azure_core::{Context, Url}; -pub struct AzureOpenAIClient<'a> { +pub struct AzureOpenAIClient { endpoint: Url, - context: Context<'a>, pipeline: azure_core::Pipeline, options: AzureOpenAIClientOptions, } -impl AzureOpenAIClient<'_> { - // TODO: not sure if this should be named `with_key_credential` instead - pub fn new( +impl AzureOpenAIClient { + pub fn with_key( endpoint: impl AsRef, - secret: String, + secret: impl Into, client_options: Option, ) -> Result { let endpoint = Url::parse(endpoint.as_ref())?; - let key_credential = AzureKeyCredential::new(secret); - - let context = Context::new(); + let key_credential = AzureKeyCredential::new(secret.into()); let options = client_options.unwrap_or_default(); let per_call_policies: Vec> = key_credential.clone().into(); @@ -34,7 +30,6 @@ impl AzureOpenAIClient<'_> { Ok(AzureOpenAIClient { endpoint, - context, pipeline, options, }) @@ -76,6 +71,8 @@ impl AzureOpenAIClient<'_> { &self.options.api_service_version.to_string(), ))?; + let context = Context::new(); + let mut request = azure_core::Request::new(url, Method::Post); // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) // request.add_mandatory_header(&self.key_credential); @@ -84,7 +81,7 @@ impl AzureOpenAIClient<'_> { let response = self .pipeline - .send::(&self.context, &mut request) + .send::(&context, &mut request) .await?; response.into_body().json().await } diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -0,0 +1 @@ + diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index 2e8d14cefc..c6dba4c0d3 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -5,6 +5,7 @@ // use serde::Serialize; pub mod azure; +pub mod chat_completions_client; pub mod non_azure; // pub(crate) fn build_request( diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs index b3e2d9cbe0..c0d2c929d2 100644 --- a/sdk/openai/inference/src/options/mod.rs +++ b/sdk/openai/inference/src/options/mod.rs @@ -4,7 +4,7 @@ use azure_core::ClientOptions; pub use service_version::AzureServiceVersion; -// TODO: I was not able to find ClientOptions as a derive macros +// TODO: I was not able to find ClientOptions as a derive macros #[derive(Clone, Debug, Default)] pub struct AzureOpenAIClientOptions { pub(crate) client_options: ClientOptions, From 3417181ee73a4628f480f48c55b43106fc2649c4 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 12 Sep 2024 19:11:21 +0200 Subject: [PATCH 14/71] Added api service version as a policy --- sdk/openai/inference/src/auth/mod.rs | 6 ++--- sdk/openai/inference/src/clients/azure.rs | 9 ++++--- .../inference/src/options/service_version.rs | 27 +++++++++++++++++++ 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 20ef984aa5..06463ed675 100644 --- 
a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -49,9 +49,9 @@ impl Policy for AzureKeyCredential { } } -impl Into>> for AzureKeyCredential { - fn into(self) -> Vec> { - vec![Arc::new(self)] +impl Into> for AzureKeyCredential { + fn into(self) -> Arc { + Arc::new(self) } } diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure.rs index 13857c5391..d2476946d9 100644 --- a/sdk/openai/inference/src/clients/azure.rs +++ b/sdk/openai/inference/src/clients/azure.rs @@ -24,7 +24,9 @@ impl AzureOpenAIClient { let key_credential = AzureKeyCredential::new(secret.into()); let options = client_options.unwrap_or_default(); - let per_call_policies: Vec> = key_credential.clone().into(); + let auth_policy: Arc = key_credential.clone().into(); + let version_policy: Arc = options.api_service_version.clone().into(); + let per_call_policies: Vec> = vec![auth_policy, version_policy]; let pipeline = Self::new_pipeline(per_call_policies, options.client_options.clone()); @@ -65,10 +67,9 @@ impl AzureOpenAIClient { // chat_completions_request: RequestContent, ) -> Result { let url = Url::parse(&format!( - "{}/openai/deployments/{}/chat/completions?api-version={}", + "{}/openai/deployments/{}/chat/completions", &self.endpoint, - deployment_name, - &self.options.api_service_version.to_string(), + deployment_name ))?; let context = Context::new(); diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs index 7a7c3bc52d..7f29611136 100644 --- a/sdk/openai/inference/src/options/service_version.rs +++ b/sdk/openai/inference/src/options/service_version.rs @@ -1,3 +1,8 @@ +use std::sync::Arc; +use async_trait::async_trait; + +use azure_core::{Context, Policy, PolicyResult, Request}; + #[derive(Debug, Clone)] pub enum AzureServiceVersion { V2023_09_01Preview, @@ -33,3 +38,25 @@ impl ToString for AzureServiceVersion { String::from(self.clone()) } } + +// Not entirely sure this is a good idea +// code lifted from BearerTokenCredentialPolicy +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl Policy for AzureServiceVersion { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + request.url_mut().query_pairs_mut().append_pair("api-version", &self.to_string()); + next[0].send(ctx, request, &next[1..]).await + } +} + +impl Into> for AzureServiceVersion { + fn into(self) -> Arc { + Arc::new(self) + } +} From d6be9f33e701da8b036c6e681f7917403f4a8a98 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 12 Sep 2024 19:14:54 +0200 Subject: [PATCH 15/71] Refactored clients --- .../src/clients/{azure.rs => azure_openai_client.rs} | 0 sdk/openai/inference/src/clients/mod.rs | 4 ++-- .../inference/src/clients/{non_azure.rs => openai_client.rs} | 0 sdk/openai/inference/src/lib.rs | 4 ++-- 4 files changed, 4 insertions(+), 4 deletions(-) rename sdk/openai/inference/src/clients/{azure.rs => azure_openai_client.rs} (100%) rename sdk/openai/inference/src/clients/{non_azure.rs => openai_client.rs} (100%) diff --git a/sdk/openai/inference/src/clients/azure.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs similarity index 100% rename from sdk/openai/inference/src/clients/azure.rs rename to sdk/openai/inference/src/clients/azure_openai_client.rs diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index c6dba4c0d3..c0608197a6 100644 
--- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -4,9 +4,9 @@ // }; // use serde::Serialize; -pub mod azure; +pub mod azure_openai_client; pub mod chat_completions_client; -pub mod non_azure; +pub mod openai_client; // pub(crate) fn build_request( // key_credential: &impl Header, diff --git a/sdk/openai/inference/src/clients/non_azure.rs b/sdk/openai/inference/src/clients/openai_client.rs similarity index 100% rename from sdk/openai/inference/src/clients/non_azure.rs rename to sdk/openai/inference/src/clients/openai_client.rs diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs index 311b0f2297..5b685503bd 100644 --- a/sdk/openai/inference/src/lib.rs +++ b/sdk/openai/inference/src/lib.rs @@ -3,7 +3,7 @@ mod clients; mod models; mod options; -pub use clients::azure::*; -pub use clients::non_azure::*; +pub use clients::azure_openai_client::*; +pub use clients::openai_client::*; pub use models::*; pub use options::*; From f8b21a8a1164a5370eb4c86a008693dc1ec36be5 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 12 Sep 2024 19:15:45 +0200 Subject: [PATCH 16/71] Added comment for clarity --- sdk/openai/inference/src/clients/azure_openai_client.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index d2476946d9..3b25aff073 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -75,6 +75,7 @@ impl AzureOpenAIClient { let context = Context::new(); let mut request = azure_core::Request::new(url, Method::Post); + // this was replaced by the AzureServiceVersion policy, not sure what is the right approach // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) // request.add_mandatory_header(&self.key_credential); From 96eec97d9c28b974013043388493fe92cc71ab7b Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 12 Sep 2024 19:19:39 +0200 Subject: [PATCH 17/71] More clarity in comments --- sdk/openai/inference/src/clients/azure_openai_client.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 3b25aff073..6a6bb6f359 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -43,6 +43,7 @@ impl AzureOpenAIClient { ) -> azure_core::Pipeline { let crate_name = option_env!("CARGO_PKG_NAME"); let crate_version = option_env!("CARGO_PKG_VERSION"); + // should I be using per_call_policies here too or are they used by default on retries too? 
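+ // (A hedged note rather than a verified answer: in the pipeline design the Azure SDKs
+ // share, per-call policies run once per logical request, ahead of the retry policy,
+ // while per-retry policies run again on every attempt; headers written by a per-call
+ // policy stay on the request across retries, so they need not be repeated here.)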
let per_retry_policies = Vec::new(); azure_core::Pipeline::new( From ba1490b76568aaa80f4471bd89923b271ff2c58a Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 12 Sep 2024 19:27:11 +0200 Subject: [PATCH 18/71] Added client option modules to differentiate between non-Azure and Azure --- .../options/azure_openai_client_options.rs | 38 ++++++++++++++++++ sdk/openai/inference/src/options/mod.rs | 40 ++----------------- .../src/options/openai_client_options.rs | 0 3 files changed, 41 insertions(+), 37 deletions(-) create mode 100644 sdk/openai/inference/src/options/azure_openai_client_options.rs create mode 100644 sdk/openai/inference/src/options/openai_client_options.rs diff --git a/sdk/openai/inference/src/options/azure_openai_client_options.rs b/sdk/openai/inference/src/options/azure_openai_client_options.rs new file mode 100644 index 0000000000..a6e5a9b50a --- /dev/null +++ b/sdk/openai/inference/src/options/azure_openai_client_options.rs @@ -0,0 +1,38 @@ +use azure_core::ClientOptions; + +use crate::AzureServiceVersion; + +// TODO: I was not able to find ClientOptions as a derive macro +#[derive(Clone, Debug, Default)] +pub struct AzureOpenAIClientOptions { + pub(crate) client_options: ClientOptions, + pub(crate) api_service_version: AzureServiceVersion, +} +impl AzureOpenAIClientOptions { + pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder { + builders::AzureOpenAIClientOptionsBuilder::new() + } +} + +pub mod builders { + use super::*; + + #[derive(Clone, Debug, Default)] + pub struct AzureOpenAIClientOptionsBuilder { + options: AzureOpenAIClientOptions, + } + + impl AzureOpenAIClientOptionsBuilder { + pub(super) fn new() -> Self { + Self::default() + } + pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self { + self.options.api_service_version = api_service_version; + self + } + + pub fn build(&self) -> AzureOpenAIClientOptions { + self.options.clone() + } + } +} diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs index c0d2c929d2..602817ff96 100644 --- a/sdk/openai/inference/src/options/mod.rs +++ b/sdk/openai/inference/src/options/mod.rs @@ -1,40 +1,6 @@ mod service_version; - -use azure_core::ClientOptions; +mod azure_openai_client_options; +mod openai_client_options; pub use service_version::AzureServiceVersion; - -// TODO: I was not able to find ClientOptions as a derive macros -#[derive(Clone, Debug, Default)] -pub struct AzureOpenAIClientOptions { - pub(crate) client_options: ClientOptions, - pub(crate) api_service_version: AzureServiceVersion, -} -impl AzureOpenAIClientOptions { - pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder { - builders::AzureOpenAIClientOptionsBuilder::new() - } -} - -pub mod builders { - use super::*; - - #[derive(Clone, Debug, Default)] - pub struct AzureOpenAIClientOptionsBuilder { - options: AzureOpenAIClientOptions, - } - - impl AzureOpenAIClientOptionsBuilder { - pub(super) fn new() -> Self { - Self::default() - } - pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self { - self.options.api_service_version = api_service_version; - self - } - - pub fn build(&self) -> AzureOpenAIClientOptions { - self.options.clone() - } - } -} diff --git a/sdk/openai/inference/src/options/openai_client_options.rs b/sdk/openai/inference/src/options/openai_client_options.rs new file mode 100644 index 0000000000..e69de29bb2 From c92cbe5fe8d78f869fd28aa13d84db55a6e9b7f1 Mon Sep 17 00:00:00 2001 From: 
Jose Alvarez Date: Thu, 12 Sep 2024 19:28:42 +0200 Subject: [PATCH 19/71] added module visibility --- sdk/openai/inference/src/options/mod.rs | 1 + sdk/openai/inference/src/options/openai_client_options.rs | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs index 602817ff96..be871b38b7 100644 --- a/sdk/openai/inference/src/options/mod.rs +++ b/sdk/openai/inference/src/options/mod.rs @@ -4,3 +4,4 @@ mod openai_client_options; pub use service_version::AzureServiceVersion; pub use azure_openai_client_options::*; +pub use openai_client_options::*; diff --git a/sdk/openai/inference/src/options/openai_client_options.rs b/sdk/openai/inference/src/options/openai_client_options.rs index e69de29bb2..9632aa73d2 100644 --- a/sdk/openai/inference/src/options/openai_client_options.rs +++ b/sdk/openai/inference/src/options/openai_client_options.rs @@ -0,0 +1,6 @@ +use azure_core::ClientOptions; + +#[derive(Clone, Debug, Default)] +pub struct OpenAIClientOptions { + pub(crate) client_options: ClientOptions, +} From 12e61e43bd60aacbe4321f0b67eaa704bd797161 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 10:17:48 +0200 Subject: [PATCH 20/71] Made auth classes private --- .../src/clients/azure_openai_client.rs | 45 +++++++++---------- .../src/clients/chat_completions_client.rs | 2 +- sdk/openai/inference/src/lib.rs | 2 +- sdk/openai/inference/src/options/mod.rs | 4 +- .../inference/src/options/service_version.rs | 7 ++- 5 files changed, 31 insertions(+), 29 deletions(-) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 6a6bb6f359..77c8467719 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -21,14 +21,14 @@ impl AzureOpenAIClient { client_options: Option, ) -> Result { let endpoint = Url::parse(endpoint.as_ref())?; - let key_credential = AzureKeyCredential::new(secret.into()); let options = client_options.unwrap_or_default(); - let auth_policy: Arc = key_credential.clone().into(); + + let auth_policy: Arc = AzureKeyCredential::new(secret.into()).into(); let version_policy: Arc = options.api_service_version.clone().into(); let per_call_policies: Vec> = vec![auth_policy, version_policy]; - let pipeline = Self::new_pipeline(per_call_policies, options.client_options.clone()); + let pipeline = new_pipeline(per_call_policies, options.client_options.clone()); Ok(AzureOpenAIClient { endpoint, @@ -37,24 +37,6 @@ impl AzureOpenAIClient { }) } - fn new_pipeline( - per_call_policies: Vec>, - options: azure_core::ClientOptions, - ) -> azure_core::Pipeline { - let crate_name = option_env!("CARGO_PKG_NAME"); - let crate_version = option_env!("CARGO_PKG_VERSION"); - // should I be using per_call_policies here too or are they used by default on retries too? 
- let per_retry_policies = Vec::new(); - - azure_core::Pipeline::new( - crate_name, - crate_version, - options, - per_call_policies, - per_retry_policies, - ) - } - pub fn endpoint(&self) -> &Url { &self.endpoint } @@ -69,8 +51,7 @@ impl AzureOpenAIClient { ) -> Result { let url = Url::parse(&format!( "{}/openai/deployments/{}/chat/completions", - &self.endpoint, - deployment_name + &self.endpoint, deployment_name ))?; let context = Context::new(); @@ -89,3 +70,21 @@ impl AzureOpenAIClient { response.into_body().json().await } } + +fn new_pipeline( + per_call_policies: Vec>, + options: azure_core::ClientOptions, +) -> azure_core::Pipeline { + let crate_name = option_env!("CARGO_PKG_NAME"); + let crate_version = option_env!("CARGO_PKG_VERSION"); + // should I be using per_call_policies here too or are they used by default on retries too? + let per_retry_policies = Vec::new(); + + azure_core::Pipeline::new( + crate_name, + crate_version, + options, + per_call_policies, + per_retry_policies, + ) +} diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 8b13789179..955e32de51 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1 +1 @@ - +pub struct ChatCompletionsClient; diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs index 5b685503bd..0b5dcb8c2e 100644 --- a/sdk/openai/inference/src/lib.rs +++ b/sdk/openai/inference/src/lib.rs @@ -1,4 +1,4 @@ -pub mod auth; +mod auth; mod clients; mod models; mod options; diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs index be871b38b7..bc1cf1eb13 100644 --- a/sdk/openai/inference/src/options/mod.rs +++ b/sdk/openai/inference/src/options/mod.rs @@ -1,7 +1,7 @@ -mod service_version; mod azure_openai_client_options; mod openai_client_options; +mod service_version; -pub use service_version::AzureServiceVersion; pub use azure_openai_client_options::*; pub use openai_client_options::*; +pub use service_version::AzureServiceVersion; diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs index 7f29611136..7637709de1 100644 --- a/sdk/openai/inference/src/options/service_version.rs +++ b/sdk/openai/inference/src/options/service_version.rs @@ -1,5 +1,5 @@ -use std::sync::Arc; use async_trait::async_trait; +use std::sync::Arc; use azure_core::{Context, Policy, PolicyResult, Request}; @@ -50,7 +50,10 @@ impl Policy for AzureServiceVersion { request: &mut Request, next: &[Arc], ) -> PolicyResult { - request.url_mut().query_pairs_mut().append_pair("api-version", &self.to_string()); + request + .url_mut() + .query_pairs_mut() + .append_pair("api-version", &self.to_string()); next[0].send(ctx, request, &next[1..]).await } } From eff2f89b2842a016158a94be0e5914c49e3e03e2 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 10:24:05 +0200 Subject: [PATCH 21/71] wip --- .../src/auth/azure_key_credential.rs | 48 +++++++++++++ sdk/openai/inference/src/auth/mod.rs | 69 ++----------------- .../src/auth/openai_key_credential.rs | 23 +++++++ .../src/clients/azure_openai_client.rs | 2 +- 4 files changed, 76 insertions(+), 66 deletions(-) create mode 100644 sdk/openai/inference/src/auth/azure_key_credential.rs create mode 100644 sdk/openai/inference/src/auth/openai_key_credential.rs diff --git a/sdk/openai/inference/src/auth/azure_key_credential.rs 
b/sdk/openai/inference/src/auth/azure_key_credential.rs new file mode 100644 index 0000000000..d0db604fe4 --- /dev/null +++ b/sdk/openai/inference/src/auth/azure_key_credential.rs @@ -0,0 +1,48 @@ +use async_trait::async_trait; +use std::sync::Arc; + +use azure_core::{ + auth::Secret, + headers::{HeaderName, HeaderValue}, + Context, Header, Policy, PolicyResult, Request, +}; + +#[derive(Debug, Clone)] +pub struct AzureKeyCredential(Secret); + +impl AzureKeyCredential { + pub fn new(api_key: impl Into) -> Self { + Self(Secret::new(api_key.into())) + } +} + +impl Header for AzureKeyCredential { + fn name(&self) -> HeaderName { + HeaderName::from_static("api-key") + } + + fn value(&self) -> HeaderValue { + HeaderValue::from_cow(format!("{}", self.0.secret())) + } +} + +// code lifted from BearerTokenCredentialPolicy +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl Policy for AzureKeyCredential { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + request.insert_header(Header::name(self), Header::value(self)); + next[0].send(ctx, request, &next[1..]).await + } +} + +impl Into> for AzureKeyCredential { + fn into(self) -> Arc { + Arc::new(self) + } +} diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 06463ed675..63db568285 100644 --- a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -1,66 +1,5 @@ -use async_trait::async_trait; -use std::sync::Arc; +mod azure_key_credential; +mod openai_key_credential; -use azure_core::{ - auth::Secret, - headers::{HeaderName, HeaderValue, AUTHORIZATION}, - Context, Header, Policy, PolicyResult, Request, -}; - -#[derive(Debug, Clone)] -pub struct AzureKeyCredential(Secret); - -pub struct OpenAIKeyCredential(Secret); - -impl OpenAIKeyCredential { - pub fn new(access_token: String) -> Self { - Self(Secret::new(access_token)) - } -} - -impl AzureKeyCredential { - pub fn new(api_key: String) -> Self { - Self(Secret::new(api_key)) - } -} - -impl Header for AzureKeyCredential { - fn name(&self) -> HeaderName { - HeaderName::from_static("api-key") - } - - fn value(&self) -> HeaderValue { - HeaderValue::from_cow(format!("{}", self.0.secret())) - } -} - -// code lifted from BearerTokenCredentialPolicy -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl Policy for AzureKeyCredential { - async fn send( - &self, - ctx: &Context, - request: &mut Request, - next: &[Arc], - ) -> PolicyResult { - request.insert_header(Header::name(self), Header::value(self)); - next[0].send(ctx, request, &next[1..]).await - } -} - -impl Into> for AzureKeyCredential { - fn into(self) -> Arc { - Arc::new(self) - } -} - -impl Header for OpenAIKeyCredential { - fn name(&self) -> HeaderName { - AUTHORIZATION - } - - fn value(&self) -> HeaderValue { - HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) - } -} +pub(crate) use azure_key_credential::*; +pub(crate) use openai_key_credential::*; diff --git a/sdk/openai/inference/src/auth/openai_key_credential.rs b/sdk/openai/inference/src/auth/openai_key_credential.rs new file mode 100644 index 0000000000..0b170e208c --- /dev/null +++ b/sdk/openai/inference/src/auth/openai_key_credential.rs @@ -0,0 +1,23 @@ +use azure_core::{ + auth::Secret, + headers::{HeaderName, HeaderValue, AUTHORIZATION}, + Header, +}; + +pub struct OpenAIKeyCredential(Secret); + +impl OpenAIKeyCredential { + 
pub fn new(access_token: String) -> Self { + Self(Secret::new(access_token)) + } +} + +impl Header for OpenAIKeyCredential { + fn name(&self) -> HeaderName { + AUTHORIZATION + } + + fn value(&self) -> HeaderValue { + HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) + } +} diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 77c8467719..7abcff36c6 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -24,7 +24,7 @@ impl AzureOpenAIClient { let options = client_options.unwrap_or_default(); - let auth_policy: Arc = AzureKeyCredential::new(secret.into()).into(); + let auth_policy: Arc = AzureKeyCredential::new(secret).into(); let version_policy: Arc = options.api_service_version.clone().into(); let per_call_policies: Vec> = vec![auth_policy, version_policy]; From b07cecee8e999c0dd3dc45e0c52ec2e3cdfb6e9b Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 10:24:55 +0200 Subject: [PATCH 22/71] cleanup --- sdk/openai/inference/src/clients/mod.rs | 40 ------------------------- 1 file changed, 40 deletions(-) diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index c0608197a6..d40414f7f0 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -1,43 +1,3 @@ -// use azure_core::{ -// headers::{ACCEPT, CONTENT_TYPE}, -// Header, Method, Request, Result, Url, -// }; -// use serde::Serialize; - pub mod azure_openai_client; pub mod chat_completions_client; pub mod openai_client; - -// pub(crate) fn build_request( -// key_credential: &impl Header, -// url: Url, -// method: Method, -// data: &T, -// ) -> Result -// where -// T: ?Sized + Serialize, -// { -// let mut request = Request::new(url, method); -// request.add_mandatory_header(key_credential); -// request.insert_header(CONTENT_TYPE, "application/json"); -// request.insert_header(ACCEPT, "application/json"); -// request.set_json(data)?; -// Ok(request) -// } - -// pub(crate) fn build_multipart_request( -// key_credential: &impl Header, -// url: Url, -// form_generator: F, -// ) -> Result -// where -// F: FnOnce() -> Result, -// { -// let mut request = Request::new(url, Method::Post); -// request.add_mandatory_header(key_credential); -// // handled insternally by reqwest -// // request.insert_header(CONTENT_TYPE, "multipart/form-data"); -// // request.insert_header(ACCEPT, "application/json"); -// request.multipart(form_generator()?); -// Ok(request) -// } From 5cd092527cfff75e8b4dac40276fc48edb80af77 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 12:02:26 +0200 Subject: [PATCH 23/71] Project compiles and runs, but request errors --- .../examples/azure_chat_completions.rs | 12 ++-- sdk/openai/inference/src/auth/mod.rs | 4 +- .../src/clients/azure_openai_client.rs | 72 ++++++++++--------- .../src/clients/chat_completions_client.rs | 52 +++++++++++++- sdk/openai/inference/src/clients/mod.rs | 15 +++- sdk/openai/inference/src/lib.rs | 4 +- 6 files changed, 111 insertions(+), 48 deletions(-) diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 4e1595d65e..81f68e1b0f 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -1,7 +1,8 @@ use azure_core::Result; use azure_openai_inference::{ - 
request::CreateChatCompletionsRequest, AzureOpenAIClient, AzureOpenAIClientOptions, - AzureServiceVersion, + clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, + request::CreateChatCompletionsRequest, + AzureOpenAIClientOptions, AzureServiceVersion, }; #[tokio::main] @@ -10,7 +11,7 @@ pub async fn main() -> Result<()> { std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); - let azure_openai_client = AzureOpenAIClient::with_key( + let chat_completions_client = AzureOpenAIClient::with_key( endpoint, secret, Some( @@ -18,14 +19,15 @@ pub async fn main() -> Result<()> { .with_api_version(AzureServiceVersion::V2023_12_01Preview) .build(), ), - )?; + )? + .chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( "gpt-4-1106-preview", "Tell me a joke about pineapples", ); - let response = azure_openai_client + let response = chat_completions_client .create_chat_completions(&chat_completions_request.model, &chat_completions_request) .await; diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 63db568285..02dcf59f6f 100644 --- a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -1,5 +1,5 @@ mod azure_key_credential; -mod openai_key_credential; +// mod openai_key_credential; pub(crate) use azure_key_credential::*; -pub(crate) use openai_key_credential::*; +// pub(crate) use openai_key_credential::*; diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 7abcff36c6..8e7156d1d1 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -3,19 +3,35 @@ use std::sync::Arc; use crate::auth::AzureKeyCredential; use crate::options::AzureOpenAIClientOptions; -use crate::request::CreateChatCompletionsRequest; -use crate::response::CreateChatCompletionsResponse; -use azure_core::{self, Method, Policy, Result}; -use azure_core::{Context, Url}; +use azure_core::Url; +use azure_core::{self, Policy, Result}; +use super::chat_completions_client::ChatCompletionsClient; +use super::BaseOpenAIClientMethods; + +pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods { + fn with_key( + endpoint: impl AsRef, + secret: impl Into, + client_options: Option, + ) -> Result + where + Self: Sized; + + fn endpoint(&self) -> &Url; + + fn chat_completions_client(&self) -> ChatCompletionsClient; +} + +#[derive(Debug, Clone)] pub struct AzureOpenAIClient { endpoint: Url, pipeline: azure_core::Pipeline, options: AzureOpenAIClientOptions, } -impl AzureOpenAIClient { - pub fn with_key( +impl AzureOpenAIClientMethods for AzureOpenAIClient { + fn with_key( endpoint: impl AsRef, secret: impl Into, client_options: Option, @@ -37,37 +53,25 @@ impl AzureOpenAIClient { }) } - pub fn endpoint(&self) -> &Url { + fn endpoint(&self) -> &Url { &self.endpoint } - pub async fn create_chat_completions( - &self, - deployment_name: &str, - chat_completions_request: &CreateChatCompletionsRequest, - // Should I be using RequestContent ? All the new methods have signatures that would force me to mutate - // the request object into &static str, Vec, etc. 
- // chat_completions_request: RequestContent, - ) -> Result { - let url = Url::parse(&format!( - "{}/openai/deployments/{}/chat/completions", - &self.endpoint, deployment_name - ))?; - - let context = Context::new(); - - let mut request = azure_core::Request::new(url, Method::Post); - // this was replaced by the AzureServiceVersion policy, not sure what is the right approach - // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) - // request.add_mandatory_header(&self.key_credential); - - request.set_json(chat_completions_request)?; - - let response = self - .pipeline - .send::(&context, &mut request) - .await?; - response.into_body().json().await + fn chat_completions_client(&self) -> ChatCompletionsClient { + ChatCompletionsClient::new(Box::new(self.clone())) + } +} + +impl BaseOpenAIClientMethods for AzureOpenAIClient { + fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result { + // TODO gracefully handle this + Ok(self + .endpoint() + .join(deployment_name.expect("deployment_name should be provided"))?) + } + + fn pipeline(&self) -> &azure_core::Pipeline { + &self.pipeline } } diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 955e32de51..54a4f85e4b 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1 +1,51 @@ -pub struct ChatCompletionsClient; +use super::BaseOpenAIClientMethods; +use crate::{request::CreateChatCompletionsRequest, response::CreateChatCompletionsResponse}; +use azure_core::{Context, Method, Result}; + +pub trait ChatCompletionsClientMethods { + #[allow(async_fn_in_trait)] + async fn create_chat_completions( + &self, + deployment_name: impl AsRef, + chat_completions_request: &CreateChatCompletionsRequest, + ) -> Result; +} + +pub struct ChatCompletionsClient { + base_client: Box, +} + +impl ChatCompletionsClient { + pub fn new(base_client: Box) -> Self { + Self { base_client } + } +} + +impl ChatCompletionsClientMethods for ChatCompletionsClient { + async fn create_chat_completions( + &self, + deployment_name: impl AsRef, + chat_completions_request: &CreateChatCompletionsRequest, + ) -> Result { + let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?; + let request_url = base_url.join("chat/completions")?; + + let context = Context::new(); + + let mut request = azure_core::Request::new(request_url, Method::Post); + // this was replaced by the AzureServiceVersion policy, not sure what is the right approach + // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) 
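+ // (it should not be necessary: the AzureKeyCredential per-call policy registered in
+ // with_key inserts the api-key header on every send, which is why the line below can
+ // stay commented out)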
+ // request.add_mandatory_header(&self.key_credential); + + request.set_json(chat_completions_request)?; + + let response = self + .base_client + .pipeline() + .send::(&context, &mut request) + .await?; + response.into_body().json().await + + // todo!() + } +} diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index d40414f7f0..2c7e718436 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -1,3 +1,12 @@ -pub mod azure_openai_client; -pub mod chat_completions_client; -pub mod openai_client; +mod azure_openai_client; +mod chat_completions_client; +mod openai_client; + +pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods}; +pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods}; + +pub trait BaseOpenAIClientMethods { + fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result; + + fn pipeline(&self) -> &azure_core::Pipeline; +} diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs index 0b5dcb8c2e..8ade855521 100644 --- a/sdk/openai/inference/src/lib.rs +++ b/sdk/openai/inference/src/lib.rs @@ -1,9 +1,7 @@ mod auth; -mod clients; +pub mod clients; mod models; mod options; -pub use clients::azure_openai_client::*; -pub use clients::openai_client::*; pub use models::*; pub use options::*; From ff6cba09c296b6d5f0c56e3ee26bd3ec5dab332e Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 12:11:30 +0200 Subject: [PATCH 24/71] Functionality restored --- sdk/openai/inference/src/clients/azure_openai_client.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 8e7156d1d1..beee712108 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -67,7 +67,9 @@ impl BaseOpenAIClientMethods for AzureOpenAIClient { // TODO gracefully handle this Ok(self .endpoint() - .join(deployment_name.expect("deployment_name should be provided"))?) + .join("openai/")? + .join("deployments/")? + .join(&format!("{}/", deployment_name.expect("Deployment name is required")))?) 
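+ // A sketch of why every segment above keeps a trailing slash, assuming the documented
+ // join() semantics of the url crate backing azure_core::Url: a relative join onto a
+ // base without a trailing slash replaces the base's last path segment, e.g.
+ //   Url::parse("https://host/openai")?.join("deployments/")?  -> https://host/deployments/
+ //   Url::parse("https://host/openai/")?.join("deployments/")? -> https://host/openai/deployments/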
} fn pipeline(&self) -> &azure_core::Pipeline { From bcba5b92f50b124c91eef0d5b0d4097f1e082b89 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 14:47:59 +0200 Subject: [PATCH 25/71] Running state after rebase --- sdk/openai/inference/Cargo.toml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index f28284a620..ddad21098c 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -16,18 +16,12 @@ workspace = true [dependencies] azure_core = { workspace = true } -# reqwest = { workspace = true, optional = true } tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } async-trait = { workspace = true } [dev-dependencies] +azure_core = { workspace = true, features = ["reqwest"] } reqwest = { workspace = true } tokio = { workspace = true } - -[features] -# default = [ "reqwest" ] -# reqwest = [ "dep:reqwest" ] -default = ["enable_reqwest"] -enable_reqwest = ["azure_core/enable_reqwest"] From 053b4f7bb027fc788012a39904ee033eea49b934 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 14:48:12 +0200 Subject: [PATCH 26/71] Running state after rebase --- sdk/openai/inference/Cargo.toml | 1 + .../inference/examples/azure_chat_completions.rs | 6 +++++- .../inference/src/clients/azure_openai_client.rs | 5 ++++- .../src/clients/chat_completions_client.rs | 14 +++++--------- .../inference/src/models/chat_completions.rs | 3 ++- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index ddad21098c..80f0cd5994 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -20,6 +20,7 @@ tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } async-trait = { workspace = true } +typespec_client_core = { workspace = true, features = ["derive"] } [dev-dependencies] azure_core = { workspace = true, features = ["reqwest"] } diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 81f68e1b0f..f308bd750b 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -2,6 +2,7 @@ use azure_core::Result; use azure_openai_inference::{ clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, request::CreateChatCompletionsRequest, + response::CreateChatCompletionsResponse, AzureOpenAIClientOptions, AzureServiceVersion, }; @@ -32,7 +33,10 @@ pub async fn main() -> Result<()> { .await; match response { - Ok(chat_completions) => { + Ok(chat_completions_response) => { + let chat_completions = chat_completions_response + .deserialize_body_into::() + .await?; println!("{:#?}", &chat_completions); } Err(e) => { diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index beee712108..682ce4044e 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -69,7 +69,10 @@ impl BaseOpenAIClientMethods for AzureOpenAIClient { .endpoint() .join("openai/")? .join("deployments/")? - .join(&format!("{}/", deployment_name.expect("Deployment name is required")))?) + .join(&format!( + "{}/", + deployment_name.expect("Deployment name is required") + ))?) 
} fn pipeline(&self) -> &azure_core::Pipeline { diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 54a4f85e4b..53b5dfd6b2 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1,6 +1,6 @@ use super::BaseOpenAIClientMethods; use crate::{request::CreateChatCompletionsRequest, response::CreateChatCompletionsResponse}; -use azure_core::{Context, Method, Result}; +use azure_core::{Context, Method, Response, Result}; pub trait ChatCompletionsClientMethods { #[allow(async_fn_in_trait)] @@ -8,7 +8,7 @@ pub trait ChatCompletionsClientMethods { &self, deployment_name: impl AsRef, chat_completions_request: &CreateChatCompletionsRequest, - ) -> Result; + ) -> Result>; } pub struct ChatCompletionsClient { @@ -26,7 +26,7 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { &self, deployment_name: impl AsRef, chat_completions_request: &CreateChatCompletionsRequest, - ) -> Result { + ) -> Result> { let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?; let request_url = base_url.join("chat/completions")?; @@ -39,13 +39,9 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { request.set_json(chat_completions_request)?; - let response = self - .base_client + self.base_client .pipeline() .send::(&context, &mut request) - .await?; - response.into_body().json().await - - // todo!() + .await } } diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs index bf250e131a..0586a51819 100644 --- a/sdk/openai/inference/src/models/chat_completions.rs +++ b/sdk/openai/inference/src/models/chat_completions.rs @@ -70,9 +70,10 @@ pub mod request { pub mod response { + use azure_core::Model; use serde::Deserialize; - #[derive(Debug, Clone, Deserialize)] + #[derive(Debug, Clone, Deserialize, Model)] pub struct CreateChatCompletionsResponse { pub choices: Vec, } From a2c17529683ef1d6272eefb43e0509979a616a6b Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 14:52:54 +0200 Subject: [PATCH 27/71] Adding comment for clarity --- sdk/openai/inference/examples/azure_chat_completions.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index f308bd750b..fb5283e797 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -33,6 +33,7 @@ pub async fn main() -> Result<()> { .await; match response { + // TODO: I don't understand why the Response generic type gets erased when calling `deserialize_body_into` Ok(chat_completions_response) => { let chat_completions = chat_completions_response .deserialize_body_into::() From 35fd48c569c06335af4f863e39da19080356d7a5 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 15:24:32 +0200 Subject: [PATCH 28/71] OpenAIClient code builds --- sdk/openai/inference/src/auth/mod.rs | 4 +- .../src/auth/openai_key_credential.rs | 31 +++++++++- .../src/clients/azure_openai_client.rs | 25 ++------ sdk/openai/inference/src/clients/mod.rs | 21 +++++++ .../inference/src/clients/openai_client.rs | 60 ++++++++++++++++++- 5 files changed, 114 insertions(+), 27 deletions(-) diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 02dcf59f6f..63db568285 100644 --- 
a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -1,5 +1,5 @@ mod azure_key_credential; -// mod openai_key_credential; +mod openai_key_credential; pub(crate) use azure_key_credential::*; -// pub(crate) use openai_key_credential::*; +pub(crate) use openai_key_credential::*; diff --git a/sdk/openai/inference/src/auth/openai_key_credential.rs b/sdk/openai/inference/src/auth/openai_key_credential.rs index 0b170e208c..2562cb288d 100644 --- a/sdk/openai/inference/src/auth/openai_key_credential.rs +++ b/sdk/openai/inference/src/auth/openai_key_credential.rs @@ -1,14 +1,18 @@ +use async_trait::async_trait; +use std::sync::Arc; + use azure_core::{ auth::Secret, headers::{HeaderName, HeaderValue, AUTHORIZATION}, - Header, + Context, Header, Policy, PolicyResult, Request, }; +#[derive(Debug, Clone)] pub struct OpenAIKeyCredential(Secret); impl OpenAIKeyCredential { - pub fn new(access_token: String) -> Self { - Self(Secret::new(access_token)) + pub fn new(access_token: impl Into) -> Self { + Self(Secret::new(access_token.into())) } } @@ -21,3 +25,24 @@ impl Header for OpenAIKeyCredential { HeaderValue::from_cow(format!("Bearer {}", &self.0.secret())) } } + +// code lifted from BearerTokenCredentialPolicy +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl Policy for OpenAIKeyCredential { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + request.insert_header(Header::name(self), Header::value(self)); + next[0].send(ctx, request, &next[1..]).await + } +} + +impl Into> for OpenAIKeyCredential { + fn into(self) -> Arc { + Arc::new(self) + } +} diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 682ce4044e..28f0714858 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -27,6 +27,7 @@ pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods { pub struct AzureOpenAIClient { endpoint: Url, pipeline: azure_core::Pipeline, + #[allow(dead_code)] options: AzureOpenAIClientOptions, } @@ -44,7 +45,7 @@ impl AzureOpenAIClientMethods for AzureOpenAIClient { let version_policy: Arc = options.api_service_version.clone().into(); let per_call_policies: Vec> = vec![auth_policy, version_policy]; - let pipeline = new_pipeline(per_call_policies, options.client_options.clone()); + let pipeline = super::new_pipeline(per_call_policies, options.client_options.clone()); Ok(AzureOpenAIClient { endpoint, @@ -64,14 +65,14 @@ impl AzureOpenAIClientMethods for AzureOpenAIClient { impl BaseOpenAIClientMethods for AzureOpenAIClient { fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result { - // TODO gracefully handle this + // TODO gracefully handle this, if it makes sense. A panic seems appropriate IMO. Ok(self .endpoint() .join("openai/")? .join("deployments/")? .join(&format!( "{}/", - deployment_name.expect("Deployment name is required") + deployment_name.expect("Deployment name is required.") ))?) 
} @@ -79,21 +80,3 @@ impl BaseOpenAIClientMethods for AzureOpenAIClient { &self.pipeline } } - -fn new_pipeline( - per_call_policies: Vec>, - options: azure_core::ClientOptions, -) -> azure_core::Pipeline { - let crate_name = option_env!("CARGO_PKG_NAME"); - let crate_version = option_env!("CARGO_PKG_VERSION"); - // should I be using per_call_policies here too or are they used by default on retries too? - let per_retry_policies = Vec::new(); - - azure_core::Pipeline::new( - crate_name, - crate_version, - options, - per_call_policies, - per_retry_policies, - ) -} diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index 2c7e718436..512928ef4f 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -2,11 +2,32 @@ mod azure_openai_client; mod chat_completions_client; mod openai_client; +use std::sync::Arc; + pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods}; pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods}; +pub use openai_client::{OpenAIClient, OpenAIClientMethods}; pub trait BaseOpenAIClientMethods { fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result; fn pipeline(&self) -> &azure_core::Pipeline; } + +fn new_pipeline( + per_call_policies: Vec>, + options: azure_core::ClientOptions, +) -> azure_core::Pipeline { + let crate_name = option_env!("CARGO_PKG_NAME"); + let crate_version = option_env!("CARGO_PKG_VERSION"); + // should I be using per_call_policies here too or are they used by default on retries too? + let per_retry_policies = Vec::new(); + + azure_core::Pipeline::new( + crate_name, + crate_version, + options, + per_call_policies, + per_retry_policies, + ) +} diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs index 3853afde6d..4321d16337 100644 --- a/sdk/openai/inference/src/clients/openai_client.rs +++ b/sdk/openai/inference/src/clients/openai_client.rs @@ -1 +1,59 @@ -pub struct OpenAIClient {} +use std::sync::Arc; + +use azure_core::{Policy, Result, Url}; + +use crate::{auth::OpenAIKeyCredential, OpenAIClientOptions}; + +use super::{BaseOpenAIClientMethods, ChatCompletionsClient}; + +pub trait OpenAIClientMethods { + fn with_key( + secret: impl Into, + client_options: Option, + ) -> Result + where + Self: Sized; + + fn chat_completions_client(&self) -> ChatCompletionsClient; +} + +#[derive(Debug, Clone)] +pub struct OpenAIClient { + base_url: Url, + pipeline: azure_core::Pipeline, + #[allow(dead_code)] + options: OpenAIClientOptions, +} + +impl OpenAIClientMethods for OpenAIClient { + fn with_key( + secret: impl Into, + client_options: Option, + ) -> Result { + let base_url = Url::parse("https://api.openai.com/v1")?; + let options = client_options.unwrap_or_default(); + let auth_policy: Arc = OpenAIKeyCredential::new(secret).into(); + + let pipeline = super::new_pipeline(vec![auth_policy], options.client_options.clone()); + + Ok(OpenAIClient { + base_url, + pipeline, + options, + }) + } + + fn chat_completions_client(&self) -> ChatCompletionsClient { + ChatCompletionsClient::new(Box::new(self.clone())) + } +} + +impl BaseOpenAIClientMethods for OpenAIClient { + fn pipeline(&self) -> &azure_core::Pipeline { + &self.pipeline + } + + fn base_url(&self, _deployment_name: Option<&str>) -> Result { + Ok(self.base_url.clone()) + } +} From b2fdf3107d8b3121fe521fab7147daa5b4535c59 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 15:54:59 
+0200 Subject: [PATCH 29/71] non-Azure OpenAI works again --- .../examples/non_azure_chat_completions.rs | 34 +++++++++++++++++++ .../src/clients/chat_completions_client.rs | 8 ++++- .../inference/src/clients/openai_client.rs | 2 +- .../options/azure_openai_client_options.rs | 1 + .../src/options/openai_client_options.rs | 25 ++++++++++++++ 5 files changed, 68 insertions(+), 2 deletions(-) diff --git a/sdk/openai/inference/examples/non_azure_chat_completions.rs b/sdk/openai/inference/examples/non_azure_chat_completions.rs index 8b13789179..082a87548e 100644 --- a/sdk/openai/inference/examples/non_azure_chat_completions.rs +++ b/sdk/openai/inference/examples/non_azure_chat_completions.rs @@ -1 +1,35 @@ +use azure_openai_inference::{ + clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods}, + request::CreateChatCompletionsRequest, + response::CreateChatCompletionsResponse, +}; +#[tokio::main] +pub async fn main() -> azure_core::Result<()> { + let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); + + let chat_completions_client = OpenAIClient::with_key(secret, None)?.chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( + "gpt-3.5-turbo-1106", + "Tell me a joke about pineapples", + ); + + let response = chat_completions_client + .create_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await; + + match response { + // TODO: I don't understand why the Response generic type gets erased when calling `deserialize_body_into` + Ok(chat_completions_response) => { + let chat_completions = chat_completions_response + .deserialize_body_into::() + .await?; + println!("{:#?}", &chat_completions); + } + Err(e) => { + println!("Error: {}", e); + } + }; + Ok(()) +} diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 53b5dfd6b2..6f0a834c59 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1,6 +1,9 @@ use super::BaseOpenAIClientMethods; use crate::{request::CreateChatCompletionsRequest, response::CreateChatCompletionsResponse}; -use azure_core::{Context, Method, Response, Result}; +use azure_core::{ + headers::{ACCEPT, CONTENT_TYPE}, + Context, Method, Response, Result, +}; pub trait ChatCompletionsClientMethods { #[allow(async_fn_in_trait)] @@ -37,6 +40,9 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) 
// request.add_mandatory_header(&self.key_credential); + // For some reason non-Azure OpenAI's API is strict about these headers being present + request.insert_header(CONTENT_TYPE, "application/json"); + request.insert_header(ACCEPT, "application/json"); request.set_json(chat_completions_request)?; self.base_client diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs index 4321d16337..01c3f56c5a 100644 --- a/sdk/openai/inference/src/clients/openai_client.rs +++ b/sdk/openai/inference/src/clients/openai_client.rs @@ -30,7 +30,7 @@ impl OpenAIClientMethods for OpenAIClient { secret: impl Into, client_options: Option, ) -> Result { - let base_url = Url::parse("https://api.openai.com/v1")?; + let base_url = Url::parse("https://api.openai.com/v1/")?; let options = client_options.unwrap_or_default(); let auth_policy: Arc = OpenAIKeyCredential::new(secret).into(); diff --git a/sdk/openai/inference/src/options/azure_openai_client_options.rs b/sdk/openai/inference/src/options/azure_openai_client_options.rs index a6e5a9b50a..410e8c4fe7 100644 --- a/sdk/openai/inference/src/options/azure_openai_client_options.rs +++ b/sdk/openai/inference/src/options/azure_openai_client_options.rs @@ -8,6 +8,7 @@ pub struct AzureOpenAIClientOptions { pub(crate) client_options: ClientOptions, pub(crate) api_service_version: AzureServiceVersion, } + impl AzureOpenAIClientOptions { pub fn builder() -> builders::AzureOpenAIClientOptionsBuilder { builders::AzureOpenAIClientOptionsBuilder::new() diff --git a/sdk/openai/inference/src/options/openai_client_options.rs b/sdk/openai/inference/src/options/openai_client_options.rs index 9632aa73d2..65eb66e97e 100644 --- a/sdk/openai/inference/src/options/openai_client_options.rs +++ b/sdk/openai/inference/src/options/openai_client_options.rs @@ -4,3 +4,28 @@ use azure_core::ClientOptions; pub struct OpenAIClientOptions { pub(crate) client_options: ClientOptions, } + +impl OpenAIClientOptions { + pub fn builder() -> builders::OpenAIClientOptionsBuilder { + builders::OpenAIClientOptionsBuilder::new() + } +} + +pub mod builders { + use super::*; + + #[derive(Clone, Debug, Default)] + pub struct OpenAIClientOptionsBuilder { + options: OpenAIClientOptions, + } + + impl OpenAIClientOptionsBuilder { + pub(super) fn new() -> Self { + Self::default() + } + + pub fn build(&self) -> OpenAIClientOptions { + self.options.clone() + } + } +} From 8836f2bfdab1f0b92f4b72ca463bec6abbff6186 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 20:57:28 +0200 Subject: [PATCH 30/71] Old approach to streaming response not working yet --- sdk/openai/inference/Cargo.toml | 2 + .../src/clients/chat_completions_client.rs | 87 ++++++++++++++- sdk/openai/inference/src/helpers/mod.rs | 1 + sdk/openai/inference/src/helpers/streaming.rs | 101 ++++++++++++++++++ sdk/openai/inference/src/lib.rs | 1 + 5 files changed, 191 insertions(+), 1 deletion(-) create mode 100644 sdk/openai/inference/src/helpers/mod.rs create mode 100644 sdk/openai/inference/src/helpers/streaming.rs diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index 80f0cd5994..0b1ce0cc31 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -20,6 +20,8 @@ tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } async-trait = { workspace = true } +futures = { workspace = true } +bytes = { workspace = true } typespec_client_core = { workspace = true, features = ["derive"] } 
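+# A note on the two crates added above, inferred from how the streaming helpers use
+# them rather than stated in the patch: `futures` supplies the Stream/StreamExt/
+# TryStreamExt combinators the event streaming is built on, and `bytes` provides the
+# Bytes chunks yielded by azure_core's ResponseBody.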
[dev-dependencies] diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 6f0a834c59..a801c4e86c 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1,9 +1,16 @@ +use std::pin::Pin; + use super::BaseOpenAIClientMethods; -use crate::{request::CreateChatCompletionsRequest, response::CreateChatCompletionsResponse}; +use crate::{ + helpers::streaming::{string_chunks, EventStreamer}, + request::CreateChatCompletionsRequest, + response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse}, +}; use azure_core::{ headers::{ACCEPT, CONTENT_TYPE}, Context, Method, Response, Result, }; +use futures::{Stream, TryStreamExt}; pub trait ChatCompletionsClientMethods { #[allow(async_fn_in_trait)] @@ -12,6 +19,13 @@ pub trait ChatCompletionsClientMethods { deployment_name: impl AsRef<str>, chat_completions_request: &CreateChatCompletionsRequest, ) -> Result<Response<CreateChatCompletionsResponse>>; + + #[allow(async_fn_in_trait)] + async fn stream_chat_completions( + &self, + deployment_name: impl AsRef<str>, + chat_completions_request: &CreateChatCompletionsRequest, + ) -> Result<impl Stream<Item = Result<CreateChatCompletionsStreamResponse>>>; } pub struct ChatCompletionsClient { @@ -50,4 +64,75 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { .send::<CreateChatCompletionsResponse>(&context, &mut request) .await } + + async fn stream_chat_completions( + &self, + deployment_name: impl AsRef<str>, + chat_completions_request: &CreateChatCompletionsRequest, + ) -> Result<impl Stream<Item = Result<CreateChatCompletionsStreamResponse>>> { + let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?; + let request_url = base_url.join("chat/completions")?; + + let context = Context::new(); + + let mut chat_completions_request = chat_completions_request; + chat_completions_request.stream = Some(true); + + let mut request = azure_core::Request::new(request_url, Method::Post); + // this was replaced by the AzureServiceVersion policy, not sure what is the right approach + // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) + // request.add_mandatory_header(&self.key_credential); + + // For some reason non-Azure OpenAI's API is strict about these headers being present + request.insert_header(CONTENT_TYPE, "application/json"); + request.insert_header(ACCEPT, "application/json"); + request.set_json(chat_completions_request)?; + + let response = self.base_client + .pipeline() + .send(&context, &mut request) + .await?
+ .into_body(); + + let stream_handler = ChatCompletionsStreamHandler::new("\n\n"); + + let stream = stream_handler.event_stream(response).await; + Ok(stream) + } + +} + +struct ChatCompletionsStreamHandler { + stream_event_delimiter: String, +} + +impl ChatCompletionsStreamHandler { + pub fn new(stream_event_delimiter: impl Into<String>) -> Self { + let stream_event_delimiter = stream_event_delimiter.into(); + Self { + stream_event_delimiter, + } + } +} + +impl EventStreamer<Result<CreateChatCompletionsStreamResponse>> for ChatCompletionsStreamHandler { + fn delimiter(&self) -> impl AsRef<str> { + self.stream_event_delimiter.as_str() + } + + async fn event_stream( + &self, + response_body: azure_core::ResponseBody, + ) -> Pin<Box<dyn Stream<Item = Result<CreateChatCompletionsStreamResponse>>>> { + let response_body = response_body; + + let stream = + string_chunks(response_body, self.stream_event_delimiter.as_str()).map_ok(|event| { + // println!("EVENT AS A STRING: {:?}", &event); + serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event) + .expect("Deserialization failed") + // CreateChatCompletionsStreamResponse { choices: vec![] } + }); + Box::pin(stream) + } } diff --git a/sdk/openai/inference/src/helpers/mod.rs new file mode 100644 index 0000000000..c65f3f1305 --- /dev/null +++ b/sdk/openai/inference/src/helpers/mod.rs @@ -0,0 +1 @@ +pub(crate) mod streaming; diff --git a/sdk/openai/inference/src/helpers/streaming.rs new file mode 100644 index 0000000000..93e9e7084b --- /dev/null +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -0,0 +1,101 @@ +use std::pin::Pin; + +use azure_core::{Error, Result}; +use futures::{Stream, StreamExt}; + +pub trait EventStreamer<T> { + fn delimiter(&self) -> impl AsRef<str>; + + #[allow(async_fn_in_trait)] + async fn event_stream( + &self, + response_body: azure_core::ResponseBody, + ) -> Pin<Box<dyn Stream<Item = Result<T>>>> + where + T: serde::de::DeserializeOwned; +} + +pub(crate) fn string_chunks( + response_body: (impl Stream<Item = Result<bytes::Bytes>> + Unpin), + _stream_event_delimiter: &str, // figure out how to use it in the move +) -> impl Stream<Item = Result<String>> { + let chunk_buffer = Vec::new(); + let stream = futures::stream::unfold( + (response_body, chunk_buffer), + |(mut response_body, mut chunk_buffer)| async move { + // Need to figure out a way to move the _stream_event_delimiter into this closure + let delimiter = b"\n\n"; + let delimiter_len = delimiter.len(); + + if let Some(Ok(bytes)) = response_body.next().await { + chunk_buffer.extend_from_slice(&bytes); + // Looking for the next occurrence of the event delimiter + // it's + 4 because the \n\n are escaped and represented as [92, 110, 92, 110] + if let Some(pos) = chunk_buffer + .windows(delimiter_len) + .position(|window| window == delimiter) + { + // the range must include the delimiter bytes + let mut bytes = chunk_buffer + .drain(..pos + delimiter_len) + .collect::<Vec<u8>>(); + bytes.truncate(bytes.len() - delimiter_len); + + return if let Ok(yielded_value) = std::str::from_utf8(&bytes) { + // We strip the "data: " portion of the event.
The rest is always JSON and will be deserialized // by a subsequent mapping function for this stream + let yielded_value = yielded_value.trim_start_matches("data:").trim(); + if yielded_value == "[DONE]" { + return None; + } else { + Some((Ok(yielded_value.to_string()), (response_body, chunk_buffer))) + } + } else { + None + }; + } + if chunk_buffer.len() > 0 { + return Some(( + Err(Error::with_message( + azure_core::error::ErrorKind::DataConversion, + || "Incomplete chunk", + )), + (response_body, chunk_buffer), + )); + } + // We drain the buffer of any messages that may be left over. + // The block above will be skipped, since response_body.next() will be None every time + } else if !chunk_buffer.is_empty() { + // we need to verify if there are any events left in the buffer and emit them individually + // it's + 4 because the \n\n are escaped and represented as [92, 110, 92, 110] + if let Some(pos) = chunk_buffer + .windows(delimiter_len) + .position(|window| window == delimiter) + { + // the range must include the delimiter bytes + let mut bytes = chunk_buffer + .drain(..pos + delimiter_len) + .collect::<Vec<u8>>(); + bytes.truncate(bytes.len() - delimiter_len); + + return if let Ok(yielded_value) = std::str::from_utf8(&bytes) { + let yielded_value = yielded_value.trim_start_matches("data:").trim(); + if yielded_value == "[DONE]" { + return None; + } else { + Some((Ok(yielded_value.to_string()), (response_body, chunk_buffer))) + } + } else { + None + }; + } + // if we get to this point, it means we have drained the buffer of all events, meaning that we haven't been able to find the next delimiter + } + None + }, + ); + + // We filter errors, we should specifically target the error type yielded when we are not able to find an event in a chunk + // Specifically the Error::with_message(ErrorKind::DataConversion, || "Incomplete chunk") + return stream.filter(|it| std::future::ready(it.is_ok())); +} diff --git a/sdk/openai/inference/src/lib.rs index 8ade855521..6e746a2cca 100644 --- a/sdk/openai/inference/src/lib.rs +++ b/sdk/openai/inference/src/lib.rs @@ -1,5 +1,6 @@ mod auth; pub mod clients; +mod helpers; mod models; mod options; From e8885822c514bada3c3e6e6b8f3fcea6d4b9a77f Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 21:52:23 +0200 Subject: [PATCH 31/71] Resolved one issue, but now having ownership problems --- .../src/clients/chat_completions_client.rs | 25 ++++++++----------- sdk/openai/inference/src/helpers/streaming.rs | 9 +++---- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index a801c4e86c..174e14d2a3 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -7,8 +7,7 @@ use crate::{ response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse}, }; use azure_core::{ - headers::{ACCEPT, CONTENT_TYPE}, - Context, Method, Response, Result, + headers::{ACCEPT, CONTENT_TYPE}, Context, Error, Method, Response, Result }; use futures::{Stream, TryStreamExt}; @@ -75,9 +74,6 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { let context = Context::new(); - let mut chat_completions_request = chat_completions_request; - chat_completions_request.stream = Some(true); - let mut request = azure_core::Request::new(request_url, Method::Post); // this was replaced by the AzureServiceVersion policy,
not sure what is the right approach // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) @@ -88,16 +84,16 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { request.insert_header(ACCEPT, "application/json"); request.set_json(chat_completions_request)?; - let response = self.base_client + let response_body = self.base_client .pipeline() - .send(&context, &mut request) + .send::<()>(&context, &mut request) .await? .into_body(); let stream_handler = ChatCompletionsStreamHandler::new("\n\n"); - let stream = stream_handler.event_stream(response).await; - Ok(stream) + let stream = stream_handler.event_stream(response_body); + return Ok(stream); } } @@ -115,17 +111,16 @@ impl ChatCompletionsStreamHandler { } } -impl EventStreamer<Result<CreateChatCompletionsStreamResponse>> for ChatCompletionsStreamHandler { +impl EventStreamer<CreateChatCompletionsStreamResponse> for ChatCompletionsStreamHandler { fn delimiter(&self) -> impl AsRef<str> { self.stream_event_delimiter.as_str() } - async fn event_stream( + fn event_stream( &self, response_body: azure_core::ResponseBody, - ) -> Pin<Box<dyn Stream<Item = Result<CreateChatCompletionsStreamResponse>>>> { - let response_body = response_body; - + ) -> impl Stream<Item = Result<CreateChatCompletionsStreamResponse>> { + // TODO: is there something like try_map_ok? let stream = string_chunks(response_body, self.stream_event_delimiter.as_str()).map_ok(|event| { // println!("EVENT AS A STRING: {:?}", &event); @@ -133,6 +128,6 @@ impl EventStreamer<Result<CreateChatCompletionsStreamResponse>> for ChatCompleti .expect("Deserialization failed") // CreateChatCompletionsStreamResponse { choices: vec![] } }); - Box::pin(stream) + stream } } diff --git a/sdk/openai/inference/src/helpers/streaming.rs index 93e9e7084b..e0dcf32d19 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -3,16 +3,13 @@ use std::pin::Pin; use azure_core::{Error, Result}; use futures::{Stream, StreamExt}; -pub trait EventStreamer<T> { +pub trait EventStreamer<T> where T: serde::de::DeserializeOwned { fn delimiter(&self) -> impl AsRef<str>; - #[allow(async_fn_in_trait)] - async fn event_stream( + fn event_stream( &self, response_body: azure_core::ResponseBody, - ) -> Pin<Box<dyn Stream<Item = Result<T>>>> - where - T: serde::de::DeserializeOwned; + ) -> impl Stream<Item = Result<T>>; } pub(crate) fn string_chunks( From 74ac1806083b4f2fd5537d6035b1127f3e221166 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 22:06:28 +0200 Subject: [PATCH 32/71] Streaming verified working with Azure --- .../examples/azure_stream_chat_completions.rs | 56 ++++++++++++ .../src/clients/chat_completions_client.rs | 87 ++++++++++--------- sdk/openai/inference/src/helpers/streaming.rs | 7 +- 3 files changed, 106 insertions(+), 44 deletions(-) create mode 100644 sdk/openai/inference/examples/azure_stream_chat_completions.rs diff --git a/sdk/openai/inference/examples/azure_stream_chat_completions.rs b/sdk/openai/inference/examples/azure_stream_chat_completions.rs new file mode 100644 index 0000000000..16bb277a37 --- /dev/null +++ b/sdk/openai/inference/examples/azure_stream_chat_completions.rs @@ -0,0 +1,56 @@ +use azure_core::Result; +use azure_openai_inference::{ + clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, + request::CreateChatCompletionsRequest, + AzureOpenAIClientOptions, AzureServiceVersion, +}; +use futures::stream::StreamExt; +use std::io::{self, Write}; + +#[tokio::main] +async fn main() -> Result<()> { + let endpoint = + std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); + let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env
variable"); + + let chat_completions_client = AzureOpenAIClient::with_key( + endpoint, + secret, + Some( + AzureOpenAIClientOptions::builder() + .with_api_version(AzureServiceVersion::V2023_12_01Preview) + .build(), + ), + )? + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::new_stream_with_user_message( + "gpt-4-1106-preview", + "Write me an essay that is at least 200 words long on the nutritional values (or lack thereof) of fast food. + Start the essay by stating 'this essay will be x many words long' where x is the number of words in the essay.",); + + let response = chat_completions_client + .stream_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await?; + + // this pins the stream to the stack so it is safe to poll it (namely, it won't be dealloacted or moved) + futures::pin_mut!(response); + + while let Some(result) = response.next().await { + match result { + Ok(delta) => { + if let Some(choice) = delta.choices.get(0) { + choice.delta.as_ref().map(|d| { + d.content.as_ref().map(|c| { + print!("{}", c); + let _ = io::stdout().flush(); + }); + }); + } + } + Err(e) => println!("Error: {:?}", e), + } + } + + Ok(()) +} diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 174e14d2a3..229f2a8f74 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1,13 +1,12 @@ -use std::pin::Pin; - use super::BaseOpenAIClientMethods; use crate::{ - helpers::streaming::{string_chunks, EventStreamer}, + helpers::streaming::string_chunks, request::CreateChatCompletionsRequest, response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse}, }; use azure_core::{ - headers::{ACCEPT, CONTENT_TYPE}, Context, Error, Method, Response, Result + headers::{ACCEPT, CONTENT_TYPE}, + Context, Method, Response, Result, }; use futures::{Stream, TryStreamExt}; @@ -84,50 +83,56 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { request.insert_header(ACCEPT, "application/json"); request.set_json(chat_completions_request)?; - let response_body = self.base_client + let response_body = self + .base_client .pipeline() .send::<()>(&context, &mut request) .await? .into_body(); - let stream_handler = ChatCompletionsStreamHandler::new("\n\n"); - - let stream = stream_handler.event_stream(response_body); - return Ok(stream); - } + // let stream_handler = ChatCompletionsStreamHandler::new("\n\n"); -} - -struct ChatCompletionsStreamHandler { - stream_event_delimiter: String, -} + // let stream = stream_handler.event_stream(response_body); -impl ChatCompletionsStreamHandler { - pub fn new(stream_event_delimiter: impl Into) -> Self { - let stream_event_delimiter = stream_event_delimiter.into(); - Self { - stream_event_delimiter, - } + let stream = string_chunks(response_body, "\n\n").map_ok(|event| { + // println!("EVENT AS A STRING: {:?}", &event); + serde_json::from_str::(&event) + .expect("Deserialization failed") + // CreateChatCompletionsStreamResponse { choices: vec![] } + }); + return Ok(stream); } } -impl EventStreamer for ChatCompletionsStreamHandler { - fn delimiter(&self) -> impl AsRef { - self.stream_event_delimiter.as_str() - } - - fn event_stream( - &self, - response_body: azure_core::ResponseBody, - ) -> impl Stream> { - // TODO: is there something like try_map_ok? 
- let stream = - string_chunks(response_body, self.stream_event_delimiter.as_str()).map_ok(|event| { - // println!("EVENT AS A STRING: {:?}", &event); - serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event) - .expect("Deserialization failed") - // CreateChatCompletionsStreamResponse { choices: vec![] } - }); - stream - } -} +// struct ChatCompletionsStreamHandler { +// stream_event_delimiter: String, +// } +// impl ChatCompletionsStreamHandler { +// pub fn new(stream_event_delimiter: impl Into<String>) -> Self { +// let stream_event_delimiter = stream_event_delimiter.into(); +// Self { +// stream_event_delimiter, +// } +// } +// } + +// impl EventStreamer<CreateChatCompletionsStreamResponse> for ChatCompletionsStreamHandler { +// fn delimiter(&self) -> impl AsRef<str> { +// self.stream_event_delimiter.as_str() +// } + +// fn event_stream( +// &self, +// response_body: azure_core::ResponseBody, +// ) -> impl Stream<Item = Result<CreateChatCompletionsStreamResponse>> { +// // TODO: is there something like try_map_ok? +// let stream = +// string_chunks(response_body, self.stream_event_delimiter.as_str()).map_ok(|event| { +// // println!("EVENT AS A STRING: {:?}", &event); +// serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event) +// .expect("Deserialization failed") +// // CreateChatCompletionsStreamResponse { choices: vec![] } +// }); +// stream +// } +// } diff --git a/sdk/openai/inference/src/helpers/streaming.rs index e0dcf32d19..70a56cfd64 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -1,9 +1,10 @@ -use std::pin::Pin; - use azure_core::{Error, Result}; use futures::{Stream, StreamExt}; -pub trait EventStreamer<T> where T: serde::de::DeserializeOwned { +pub trait EventStreamer<T> +where + T: serde::de::DeserializeOwned, +{ fn delimiter(&self) -> impl AsRef<str>; fn event_stream( From 69fcf6c9579bd6d1ebc25c8628828d53dfbf6564 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 22:11:29 +0200 Subject: [PATCH 33/71] Added example for non Azure usage for streaming --- ...ns.rs => azure_chat_completions_stream.rs} | 0 ...hat_completions.rs => chat_completions.rs} | 0 .../examples/chat_completions_stream.rs | 44 +++++++++++++++++++ 3 files changed, 44 insertions(+) rename sdk/openai/inference/examples/{azure_stream_chat_completions.rs => azure_chat_completions_stream.rs} (100%) rename sdk/openai/inference/examples/{non_azure_chat_completions.rs => chat_completions.rs} (100%) create mode 100644 sdk/openai/inference/examples/chat_completions_stream.rs diff --git a/sdk/openai/inference/examples/azure_stream_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs similarity index 100% rename from sdk/openai/inference/examples/azure_stream_chat_completions.rs rename to sdk/openai/inference/examples/azure_chat_completions_stream.rs diff --git a/sdk/openai/inference/examples/non_azure_chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs similarity index 100% rename from sdk/openai/inference/examples/non_azure_chat_completions.rs rename to sdk/openai/inference/examples/chat_completions.rs diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs new file mode 100644 index 0000000000..dc3905efa4 --- /dev/null +++ b/sdk/openai/inference/examples/chat_completions_stream.rs @@ -0,0 +1,44 @@ +use azure_core::Result; +use azure_openai_inference::{ + clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods}, + request::CreateChatCompletionsRequest, +}; +use futures::stream::StreamExt; +use
std::io::{self, Write}; + +#[tokio::main] +async fn main() -> Result<()> { + let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); + + let chat_completions_client = OpenAIClient::with_key(secret, None)?.chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::new_stream_with_user_message( + "gpt-3.5-turbo-1106", + "Write me an essay that is at least 200 words long on the nutritional values (or lack thereof) of fast food. + Start the essay by stating 'this essay will be x many words long' where x is the number of words in the essay.",); + + let response = chat_completions_client + .stream_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await?; + + // this pins the stream to the stack so it is safe to poll it (namely, it won't be deallocated or moved) + futures::pin_mut!(response); + + while let Some(result) = response.next().await { + match result { + Ok(delta) => { + if let Some(choice) = delta.choices.get(0) { + choice.delta.as_ref().map(|d| { + d.content.as_ref().map(|c| { + print!("{}", c); + let _ = io::stdout().flush(); + }); + }); + } + } + Err(e) => println!("Error: {:?}", e), + } + } + + Ok(()) +} From acbd3f392150ec183a6cd9a9374b8ee15e183047 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Fri, 13 Sep 2024 22:16:05 +0200 Subject: [PATCH 34/71] Using correct method in examples for deserialization --- sdk/openai/inference/examples/azure_chat_completions.rs | 4 +--- sdk/openai/inference/examples/chat_completions.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index fb5283e797..09629ce89a 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -2,7 +2,6 @@ use azure_core::Result; use azure_openai_inference::{ clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, request::CreateChatCompletionsRequest, - response::CreateChatCompletionsResponse, AzureOpenAIClientOptions, AzureServiceVersion, }; @@ -33,10 +32,9 @@ pub async fn main() -> Result<()> { .await; match response { - // TODO: I don't understand why the Response generic type gets erased when calling `deserialize_body_into` Ok(chat_completions_response) => { let chat_completions = chat_completions_response - .deserialize_body_into::<CreateChatCompletionsResponse>() + .deserialize_body() .await?; println!("{:#?}", &chat_completions); } diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs index 082a87548e..a67f864318 100644 --- a/sdk/openai/inference/examples/chat_completions.rs +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -1,7 +1,6 @@ use azure_openai_inference::{ clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods}, request::CreateChatCompletionsRequest, - response::CreateChatCompletionsResponse, }; #[tokio::main] pub async fn main() -> azure_core::Result<()> { @@ -20,10 +19,9 @@ pub async fn main() -> azure_core::Result<()> { match response { - // TODO: I don't understand why the Response generic type gets erased when calling `deserialize_body_into` Ok(chat_completions_response) => { let chat_completions = chat_completions_response - .deserialize_body_into::<CreateChatCompletionsResponse>() + .deserialize_body() .await?; println!("{:#?}", &chat_completions); } Err(e) => { From 1e288efbd5f0377ca14b06ed6007c570bc93c6bc Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Tue, 17 Sep 2024 17:20:42 +0200 Subject:
[PATCH 35/71] Added lifetime parameter for string slice to be able to move to lambda --- sdk/openai/inference/src/helpers/streaming.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index 70a56cfd64..321efb4cfd 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -13,16 +13,15 @@ where ) -> impl Stream<Item = Result<T>>; } -pub(crate) fn string_chunks( - response_body: (impl Stream<Item = Result<bytes::Bytes>> + Unpin), - _stream_event_delimiter: &str, // figure out how to use it in the move -) -> impl Stream<Item = Result<String>> { +pub(crate) fn string_chunks<'a>( + response_body: (impl Stream<Item = Result<bytes::Bytes>> + Unpin + 'a), + stream_event_delimiter: &'a str, // figure out how to use it in the move +) -> impl Stream<Item = Result<String>> + 'a { let chunk_buffer = Vec::new(); let stream = futures::stream::unfold( (response_body, chunk_buffer), - |(mut response_body, mut chunk_buffer)| async move { - // Need to figure out a way to move the _stream_event_delimiter into this closure - let delimiter = b"\n\n"; + move |(mut response_body, mut chunk_buffer)| async move { + let delimiter = stream_event_delimiter.as_bytes(); let delimiter_len = delimiter.len(); if let Some(Ok(bytes)) = response_body.next().await { From f9869ff94f8a9690e1cb1266b43f3f73dc3ae4e6 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 14:33:14 +0200 Subject: [PATCH 36/71] Added ctx for AAD and related sample --- sdk/openai/inference/Cargo.toml | 1 + .../examples/azure_chat_completions.rs | 4 +- .../examples/azure_chat_completions_aad.rs | 47 +++++++++++++++++++ .../inference/examples/chat_completions.rs | 4 +- sdk/openai/inference/src/auth/mod.rs | 3 ++ .../src/clients/azure_openai_client.rs | 36 +++++++++++++- sdk/openai/inference/src/helpers/streaming.rs | 2 +- .../inference/src/models/chat_completions.rs | 9 ++++ 8 files changed, 98 insertions(+), 8 deletions(-) create mode 100644 sdk/openai/inference/examples/azure_chat_completions_aad.rs diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index 0b1ce0cc31..6dc22764d4 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -26,5 +26,6 @@ typespec_client_core = { workspace = true, features = ["derive"] } [dev-dependencies] azure_core = { workspace = true, features = ["reqwest"] } +azure_identity = { workspace = true } reqwest = { workspace = true } tokio = { workspace = true } diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 09629ce89a..7ff2f0067a 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -33,9 +33,7 @@ pub async fn main() -> Result<()> { match response { Ok(chat_completions_response) => { - let chat_completions = chat_completions_response - .deserialize_body() - .await?; + let chat_completions = chat_completions_response.deserialize_body().await?; println!("{:#?}", &chat_completions); } Err(e) => { diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs new file mode 100644 index 0000000000..a9dfff48e3 --- /dev/null +++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use azure_core::Result; +use azure_identity::DefaultAzureCredentialBuilder; +use azure_openai_inference::{ + clients::{AzureOpenAIClient,
AzureOpenAIClientMethods, ChatCompletionsClientMethods}, + request::CreateChatCompletionsRequest, + AzureOpenAIClientOptions, AzureServiceVersion, +}; + +#[tokio::main] +async fn main() -> Result<()> { + let endpoint = + std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); + + let chat_completions_client = AzureOpenAIClient::new( + endpoint, + Arc::new(DefaultAzureCredentialBuilder::new().build()?), + Some( + AzureOpenAIClientOptions::builder() + .with_api_version(AzureServiceVersion::V2023_12_01Preview) + .build(), + ), + )? + .chat_completions_client(); + + let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( + "gpt-4-1106-preview", + "Tell me a joke about pineapples", + ); + + let response = chat_completions_client + .create_chat_completions(&chat_completions_request.model, &chat_completions_request) + .await; + + match response { + Ok(chat_completions_response) => { + let chat_completions = chat_completions_response.deserialize_body().await?; + println!("{:#?}", &chat_completions); + } + Err(e) => { + println!("Error: {}", e); + } + }; + + Ok(()) +} diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs index a67f864318..5a4742094c 100644 --- a/sdk/openai/inference/examples/chat_completions.rs +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -20,9 +20,7 @@ pub async fn main() -> azure_core::Result<()> { match response { Ok(chat_completions_response) => { - let chat_completions = chat_completions_response - .deserialize_body() - .await?; + let chat_completions = chat_completions_response.deserialize_body().await?; println!("{:#?}", &chat_completions); } Err(e) => { diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/auth/mod.rs index 63db568285..11b63877da 100644 --- a/sdk/openai/inference/src/auth/mod.rs +++ b/sdk/openai/inference/src/auth/mod.rs @@ -3,3 +3,6 @@ mod openai_key_credential; pub(crate) use azure_key_credential::*; pub(crate) use openai_key_credential::*; + +pub(crate) const DEFAULT_SCOPE: [&'static str; 1] = + ["https://cognitiveservices.azure.com/.default"]; diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 28f0714858..e7a8fecd58 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -3,13 +3,22 @@ use std::sync::Arc; use crate::auth::AzureKeyCredential; use crate::options::AzureOpenAIClientOptions; -use azure_core::Url; +use azure_core::auth::{self, TokenCredential, DEFAULT_SCOPE_SUFFIX}; use azure_core::{self, Policy, Result}; +use azure_core::{BearerTokenCredentialPolicy, Url}; use super::chat_completions_client::ChatCompletionsClient; use super::BaseOpenAIClientMethods; pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods { + fn new( + endpoint: impl AsRef<str>, + credentials: Arc<dyn TokenCredential>, + client_options: Option<AzureOpenAIClientOptions>, + ) -> Result<Self> + where + Self: Sized; + fn with_key( endpoint: impl AsRef<str>, secret: impl Into<String>, @@ -32,6 +41,31 @@ pub struct AzureOpenAIClient { } impl AzureOpenAIClientMethods for AzureOpenAIClient { + fn new( + endpoint: impl AsRef<str>, + credential: Arc<dyn TokenCredential>, + client_options: Option<AzureOpenAIClientOptions>, + ) -> Result<Self> { + let endpoint = Url::parse(endpoint.as_ref())?; + + let options = client_options.unwrap_or_default(); + + let auth_policy = Arc::new(BearerTokenCredentialPolicy::new( + credential, + crate::auth::DEFAULT_SCOPE, + )); + let version_policy: Arc<dyn Policy> =
options.api_service_version.clone().into(); + let per_call_policies: Vec<Arc<dyn Policy>> = vec![auth_policy, version_policy]; + + let pipeline = super::new_pipeline(per_call_policies, options.client_options.clone()); + + Ok(AzureOpenAIClient { + endpoint, + pipeline, + options, + }) + } + fn with_key( endpoint: impl AsRef<str>, secret: impl Into<String>, diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index 321efb4cfd..a12de9493d 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -15,7 +15,7 @@ pub(crate) fn string_chunks<'a>( response_body: (impl Stream<Item = Result<bytes::Bytes>> + Unpin + 'a), - stream_event_delimiter: &'a str, // figure out how to use it in the move + stream_event_delimiter: &'a str, ) -> impl Stream<Item = Result<String>> + 'a { diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs index 0586a51819..bc292f7bfc 100644 --- a/sdk/openai/inference/src/models/chat_completions.rs +++ b/sdk/openai/inference/src/models/chat_completions.rs @@ -3,6 +3,7 @@ pub mod request { use serde::Serialize; #[derive(Serialize, Debug, Clone, Default)] + #[non_exhaustive] pub struct CreateChatCompletionsRequest { pub messages: Vec<ChatCompletionRequestMessage>, pub model: String, @@ -15,6 +16,7 @@ } #[derive(Serialize, Debug, Clone, Default)] + #[non_exhaustive] pub struct ChatCompletionRequestMessageBase { #[serde(skip)] pub name: Option<String>, pub content: String, @@ -22,6 +24,7 @@ } #[derive(Serialize, Debug, Clone)] + #[non_exhaustive] #[serde(tag = "role")] pub enum ChatCompletionRequestMessage { #[serde(rename = "system")] @@ -74,16 +77,19 @@ pub mod response { use serde::Deserialize; #[derive(Debug, Clone, Deserialize, Model)] + #[non_exhaustive] pub struct CreateChatCompletionsResponse { pub choices: Vec<ChatCompletionChoice>, } #[derive(Debug, Clone, Deserialize)] + #[non_exhaustive] pub struct ChatCompletionChoice { pub message: ChatCompletionResponseMessage, } #[derive(Debug, Clone, Deserialize)] + #[non_exhaustive] pub struct ChatCompletionResponseMessage { pub content: Option<String>, pub role: String, @@ -91,16 +97,19 @@ // region: --- Streaming #[derive(Debug, Clone, Deserialize)] + #[non_exhaustive] pub struct CreateChatCompletionsStreamResponse { pub choices: Vec<ChatCompletionStreamChoice>, } #[derive(Debug, Clone, Deserialize)] + #[non_exhaustive] pub struct ChatCompletionStreamChoice { pub delta: Option<ChatCompletionStreamResponseMessage>, } #[derive(Debug, Clone, Deserialize)] + #[non_exhaustive] pub struct ChatCompletionStreamResponseMessage { pub content: Option<String>, pub role: Option<String>, From 4757c9fe48d2e020eb673a0f3f67c362b7bcaba7 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 14:38:10 +0200 Subject: [PATCH 37/71] Cleaned up imports --- sdk/openai/inference/src/clients/azure_openai_client.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index e7a8fecd58..eb023ef5a0 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use crate::auth::AzureKeyCredential; +use crate::auth::{AzureKeyCredential, DEFAULT_SCOPE}; use crate::options::AzureOpenAIClientOptions; -use azure_core::auth::{self, TokenCredential, DEFAULT_SCOPE_SUFFIX}; +use azure_core::auth::TokenCredential; use azure_core::{self, Policy, Result}; use
azure_core::{BearerTokenCredentialPolicy, Url}; @@ -50,10 +50,7 @@ impl AzureOpenAIClientMethods for AzureOpenAIClient { let options = client_options.unwrap_or_default(); - let auth_policy = Arc::new(BearerTokenCredentialPolicy::new( - credential, - crate::auth::DEFAULT_SCOPE, - )); + let auth_policy = Arc::new(BearerTokenCredentialPolicy::new(credential, DEFAULT_SCOPE)); let version_policy: Arc<dyn Policy> = options.api_service_version.clone().into(); let per_call_policies: Vec<Arc<dyn Policy>> = vec![auth_policy, version_policy]; From af6ab1421d83a4f1edd275cb644f91bc0039dc49 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 14:40:49 +0200 Subject: [PATCH 38/71] Renamed methods to be guideline compliant --- sdk/openai/inference/examples/azure_chat_completions.rs | 2 +- .../inference/examples/azure_chat_completions_stream.rs | 2 +- sdk/openai/inference/examples/chat_completions.rs | 3 ++- sdk/openai/inference/examples/chat_completions_stream.rs | 3 ++- sdk/openai/inference/src/clients/azure_openai_client.rs | 4 ++-- sdk/openai/inference/src/clients/openai_client.rs | 4 ++-- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 7ff2f0067a..8199a6b60c 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -11,7 +11,7 @@ pub async fn main() -> Result<()> { std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); - let chat_completions_client = AzureOpenAIClient::with_key( + let chat_completions_client = AzureOpenAIClient::with_key_credential( endpoint, secret, Some( diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs index 16bb277a37..5b676c478d 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs @@ -13,7 +13,7 @@ async fn main() -> Result<()> { std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); - let chat_completions_client = AzureOpenAIClient::with_key( + let chat_completions_client = AzureOpenAIClient::with_key_credential( endpoint, secret, Some( diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs index 5a4742094c..e0dd04ef33 100644 --- a/sdk/openai/inference/examples/chat_completions.rs +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -7,7 +7,8 @@ use azure_openai_inference::{ pub async fn main() -> azure_core::Result<()> { let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); - let chat_completions_client = OpenAIClient::with_key(secret, None)?.chat_completions_client(); + let chat_completions_client = + OpenAIClient::with_key_credential(secret, None)?.chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( "gpt-3.5-turbo-1106", diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs
@@ -10,7 +10,8 @@ use std::io::{self, Write}; async fn main() -> Result<()> { let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); - let chat_completions_client = OpenAIClient::with_key(secret, None)?.chat_completions_client(); + let chat_completions_client = + OpenAIClient::with_key_credential(secret, None)?.chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::new_stream_with_user_message( "gpt-3.5-turbo-1106", diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index eb023ef5a0..562790d6cb 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -19,7 +19,7 @@ pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods { where Self: Sized; - fn with_key( + fn with_key_credential( endpoint: impl AsRef<str>, secret: impl Into<String>, client_options: Option<AzureOpenAIClientOptions>, @@ -63,7 +63,7 @@ impl AzureOpenAIClientMethods for AzureOpenAIClient { }) } - fn with_key( + fn with_key_credential( endpoint: impl AsRef<str>, secret: impl Into<String>, client_options: Option<AzureOpenAIClientOptions>, diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs index 01c3f56c5a..93a4a75689 100644 --- a/sdk/openai/inference/src/clients/openai_client.rs +++ b/sdk/openai/inference/src/clients/openai_client.rs @@ -7,7 +7,7 @@ use crate::{auth::OpenAIKeyCredential, OpenAIClientOptions}; use super::{BaseOpenAIClientMethods, ChatCompletionsClient}; pub trait OpenAIClientMethods { - fn with_key( + fn with_key_credential( secret: impl Into<String>, client_options: Option<OpenAIClientOptions>, ) -> Result<Self> @@ -26,7 +26,7 @@ pub struct OpenAIClient { } impl OpenAIClientMethods for OpenAIClient { - fn with_key( + fn with_key_credential( secret: impl Into<String>, client_options: Option<OpenAIClientOptions>, ) -> Result<Self> { From 53d23b0046a31ae5eb67b8719221447328aada24 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 15:01:44 +0200 Subject: [PATCH 39/71] Restored ChatCompletionStreamHandler struct --- .../src/clients/chat_completions_client.rs | 53 +++++-------------- sdk/openai/inference/src/helpers/streaming.rs | 11 +--- 2 files changed, 14 insertions(+), 50 deletions(-) diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 229f2a8f74..f94e52cf79 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1,6 +1,6 @@ use super::BaseOpenAIClientMethods; use crate::{ - helpers::streaming::string_chunks, + helpers::streaming::{string_chunks, EventStreamer}, request::CreateChatCompletionsRequest, response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse}, }; @@ -90,49 +90,22 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { .await? .into_body(); - // let stream_handler = ChatCompletionsStreamHandler::new("\n\n"); + Ok(ChatCompletionsStreamHandler::event_stream(response_body)) + } +} - // let stream = stream_handler.event_stream(response_body); +struct ChatCompletionsStreamHandler; - let stream = string_chunks(response_body, "\n\n").map_ok(|event| { - // println!("EVENT AS A STRING: {:?}", &event); +impl EventStreamer<CreateChatCompletionsStreamResponse> for ChatCompletionsStreamHandler { + fn event_stream( + response_body: azure_core::ResponseBody, + ) -> impl Stream<Item = Result<CreateChatCompletionsStreamResponse>> { + let stream_event_delimiter = "\n\n"; + // TODO: is there something like try_map_ok?
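+ // NOTE (editorial, assumed behavior): each string yielded by `string_chunks` is one complete
+ // SSE event payload with its "data:" prefix already stripped, so it can be deserialized directly.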
+ let stream = string_chunks(response_body, stream_event_delimiter).map_ok(|event| { serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event) .expect("Deserialization failed") }); + stream } } diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index a12de9493d..0cdfe9117b 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -5,12 +5,7 @@ pub trait EventStreamer<T> where T: serde::de::DeserializeOwned, { - fn delimiter(&self) -> impl AsRef<str>; - - fn event_stream( - &self, - response_body: azure_core::ResponseBody, - ) -> impl Stream<Item = Result<T>>; + fn event_stream(response_body: azure_core::ResponseBody) -> impl Stream<Item = Result<T>>; } pub(crate) fn string_chunks<'a>( @@ -26,8 +21,6 @@ pub(crate) fn string_chunks<'a>( if let Some(Ok(bytes)) = response_body.next().await { chunk_buffer.extend_from_slice(&bytes); - // Looking for the next occurrence of the event delimiter - // it's + 4 because the \n\n are escaped and represented as [92, 110, 92, 110] if let Some(pos) = chunk_buffer .windows(delimiter_len) .position(|window| window == delimiter) { @@ -63,8 +56,6 @@ pub(crate) fn string_chunks<'a>( // We drain the buffer of any messages that may be left over.
// The block above will be skipped, since response_body.next() will be None every time } else if !chunk_buffer.is_empty() { - // we need to verify if there are any events left in the buffer and emit them individually - // it's + 4 because the \n\n are escaped and represented as [92, 110, 92, 110] if let Some(pos) = chunk_buffer .windows(delimiter_len) .position(|window| window == delimiter) { From 1465a7b69077792f0326cf8855a73fa7e42f02c9 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 15:27:18 +0200 Subject: [PATCH 40/71] stream mapping function handles errors better --- .../src/clients/chat_completions_client.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index f94e52cf79..6ffefba71c 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -8,7 +8,7 @@ use azure_core::{ headers::{ACCEPT, CONTENT_TYPE}, Context, Method, Response, Result, }; -use futures::{Stream, TryStreamExt}; +use futures::{Stream, StreamExt}; pub trait ChatCompletionsClientMethods { #[allow(async_fn_in_trait)] @@ -101,11 +101,11 @@ impl EventStreamer<CreateChatCompletionsStreamResponse> for ChatCompletionsStrea response_body: azure_core::ResponseBody, ) -> impl Stream<Item = Result<CreateChatCompletionsStreamResponse>> { let stream_event_delimiter = "\n\n"; - // TODO: is there something like try_map_ok? - let stream = string_chunks(response_body, stream_event_delimiter).map_ok(|event| { - serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event) - .expect("Deserialization failed") - }); - stream + + string_chunks(response_body, stream_event_delimiter).map(|event| match event { + Ok(event) => serde_json::from_str::<CreateChatCompletionsStreamResponse>(&event) + .map_err(|e| e.into()), + Err(e) => Err(e), + }) } } From 6349265dcb5c11fdcc575c8723e674f75d20d9ff Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 17:23:11 +0200 Subject: [PATCH 41/71] Ported tests for string_chunks --- sdk/openai/inference/src/helpers/streaming.rs | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index 0cdfe9117b..266b61a147 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -87,3 +87,127 @@ pub(crate) fn string_chunks<'a>( return stream.filter(|it| std::future::ready(it.is_ok())); } + +#[cfg(test)] +mod tests { + use super::*; + use futures::pin_mut; + + #[tokio::test] + async fn clean_chunks() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1\n\n")), + Ok(bytes::Bytes::from_static(b"data: piece 2\n\n")), + Ok(bytes::Bytes::from_static(b"data: [DONE]\n\n")), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec<Result<String>> = actual.collect().await; + + let expected: Vec<Result<String>> = + vec![Ok("piece 1".to_string()), Ok("piece 2".to_string())]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn multiple_message_in_one_chunk() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static( + b"data: piece 1\n\ndata: piece 2\n\n", + )), + Ok(bytes::Bytes::from_static( + b"data: piece 3\n\ndata: [DONE]\n\n", + )), + ]); + + let mut actual = Vec::new(); + + let actual_stream = string_chunks(&mut source_stream, "\n\n");
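+ // The unfold-based stream is not `Unpin`, so it must be pinned to the stack before polling.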
pin_mut!(actual_stream); + + while let Some(event) = actual_stream.next().await { + actual.push(event); + } + + let expected: Vec<Result<String>> = vec![ + Ok("piece 1".to_string()), + Ok("piece 2".to_string()), + Ok("piece 3".to_string()), + ]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn data_marker_in_previous_chunk() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static( + b"data: piece 1\n\ndata: piece 2\n\ndata:", + )), + Ok(bytes::Bytes::from_static(b" piece 3\n\ndata: [DONE]\n\n")), + ]); + + let mut actual = Vec::new(); + + let actual_stream = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual_stream); + + while let Some(event) = actual_stream.next().await { + actual.push(event); + } + + let expected: Vec<Result<String>> = vec![ + Ok("piece 1".to_string()), + Ok("piece 2".to_string()), + Ok("piece 3".to_string()), + ]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn event_delimiter_split_across_chunks() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1\n")), + Ok(bytes::Bytes::from_static(b"\ndata: [DONE]")), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec<Result<String>> = actual.collect().await; + + let expected: Vec<Result<String>> = vec![Ok("piece 1".to_string())]; + assert_result_vectors(expected, actual); + } + + #[tokio::test] + async fn event_delimiter_at_start_of_next_chunk() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1")), + Ok(bytes::Bytes::from_static(b"\n\ndata: [DONE]")), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec<Result<String>> = actual.collect().await; + + let expected: Vec<Result<String>> = vec![Ok("piece 1".to_string())]; + assert_result_vectors(expected, actual); + } + + fn assert_result_vectors<T>(expected: Vec<Result<T>>, actual: Vec<Result<T>>) + where + T: std::fmt::Debug + PartialEq, + { + assert_eq!(expected.len(), actual.len()); + for (expected, actual) in expected.iter().zip(actual.iter()) { + if let Ok(actual) = actual { + assert_eq!(actual, expected.as_ref().unwrap()); + } else { + let actual_err = actual.as_ref().unwrap_err(); + let expected_err = expected.as_ref().unwrap_err(); + assert_eq!(actual_err.kind(), expected_err.kind()); + assert_eq!(actual_err.to_string(), expected_err.to_string()); + } + } + } +} From 624e988d8167bb4a9c5307a0b9406b502edca5e1 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Wed, 18 Sep 2024 17:37:48 +0200 Subject: [PATCH 42/71] Made the BaseOpenAIClientMethods trait crate-level visible --- sdk/openai/inference/src/clients/azure_openai_client.rs | 2 +- sdk/openai/inference/src/clients/chat_completions_client.rs | 2 +- sdk/openai/inference/src/clients/mod.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index 562790d6cb..dd7bb041b6 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -10,7 +10,7 @@ use azure_core::{BearerTokenCredentialPolicy, Url}; use super::chat_completions_client::ChatCompletionsClient; use super::BaseOpenAIClientMethods; -pub trait AzureOpenAIClientMethods: BaseOpenAIClientMethods { +pub trait AzureOpenAIClientMethods { fn new( endpoint: impl AsRef<str>, credentials: Arc<dyn TokenCredential>, diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs
b/sdk/openai/inference/src/clients/chat_completions_client.rs index 6ffefba71c..7cb0692cdb 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -31,7 +31,7 @@ pub struct ChatCompletionsClient { } impl ChatCompletionsClient { - pub fn new(base_client: Box<dyn BaseOpenAIClientMethods>) -> Self { + pub(crate) fn new(base_client: Box<dyn BaseOpenAIClientMethods>) -> Self { Self { base_client } } } diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index 512928ef4f..8209f7b534 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -8,7 +8,7 @@ pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods}; pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods}; pub use openai_client::{OpenAIClient, OpenAIClientMethods}; -pub trait BaseOpenAIClientMethods { +pub(crate) trait BaseOpenAIClientMethods { fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result<azure_core::Url>; fn pipeline(&self) -> &azure_core::Pipeline; From 49a19cf9a1cc69be8bba0aae180556bbc6955712 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 10:00:41 +0200 Subject: [PATCH 43/71] Extracted json request builder method --- .../src/clients/chat_completions_client.rs | 35 ++++--------------- sdk/openai/inference/src/clients/mod.rs | 18 ++++++++++ 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 7cb0692cdb..e249aeed40 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -1,13 +1,10 @@ -use super::BaseOpenAIClientMethods; +use super::{new_json_request, BaseOpenAIClientMethods}; use crate::{ helpers::streaming::{string_chunks, EventStreamer}, request::CreateChatCompletionsRequest, response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse}, }; -use azure_core::{ - headers::{ACCEPT, CONTENT_TYPE}, - Context, Method, Response, Result, -}; +use azure_core::{Context, Method, Response, Result}; use futures::{Stream, StreamExt}; pub trait ChatCompletionsClientMethods { @@ -45,21 +42,11 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?; let request_url = base_url.join("chat/completions")?; - let context = Context::new(); - - let mut request = azure_core::Request::new(request_url, Method::Post); - // this was replaced by the AzureServiceVersion policy, not sure what is the right approach - // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?)
- // request.add_mandatory_header(&self.key_credential); - - // For some reason non-Azure OpenAI's API is strict about these headers being present - request.insert_header(CONTENT_TYPE, "application/json"); - request.insert_header(ACCEPT, "application/json"); - request.set_json(chat_completions_request)?; + let mut request = new_json_request(request_url, Method::Post, &chat_completions_request); self.base_client .pipeline() - .send::<CreateChatCompletionsResponse>(&context, &mut request) + .send::<CreateChatCompletionsResponse>(&Context::new(), &mut request) .await } @@ -71,22 +58,12 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { let base_url = self.base_client.base_url(Some(deployment_name.as_ref()))?; let request_url = base_url.join("chat/completions")?; - let context = Context::new(); - - let mut request = azure_core::Request::new(request_url, Method::Post); - // this was replaced by the AzureServiceVersion policy, not sure what is the right approach - // adding the mandatory header shouldn't be necessary if the pipeline was setup correctly (?) - // request.add_mandatory_header(&self.key_credential); - - // For some reason non-Azure OpenAI's API is strict about these headers being present - request.insert_header(CONTENT_TYPE, "application/json"); - request.insert_header(ACCEPT, "application/json"); - request.set_json(chat_completions_request)?; + let mut request = new_json_request(request_url, Method::Post, &chat_completions_request); let response_body = self .base_client .pipeline() - .send::<()>(&context, &mut request) + .send::<()>(&Context::new(), &mut request) .await? .into_body(); diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index 8209f7b534..92c49ff879 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -31,3 +31,21 @@ fn new_pipeline( per_retry_policies, ) } + +fn new_json_request<T>( + url: azure_core::Url, + method: azure_core::Method, + json_body: &T, +) -> azure_core::Request +where + T: serde::Serialize, +{ + let mut request = azure_core::Request::new(url, method); + + // For some reason non-Azure OpenAI's API is strict about these headers being present + request.insert_header(azure_core::headers::CONTENT_TYPE, "application/json"); + request.insert_header(azure_core::headers::ACCEPT, "application/json"); + + request.set_json(json_body).unwrap(); + request +} From becd225b4c3a75dbd51a57dd35c17a127cc2f348 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 11:28:18 +0200 Subject: [PATCH 44/71] Added unhappy path test --- sdk/openai/inference/src/helpers/streaming.rs | 32 ++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index 266b61a147..7242aefec7 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -85,7 +85,16 @@ pub(crate) fn string_chunks<'a>( // We filter errors, we should specifically target the error type yielded when we are not able to find an event in a chunk - return stream.filter(|it| std::future::ready(it.is_ok())); + return stream.filter(|it| { + std::future::ready( + it.is_ok() + || it.as_ref().unwrap_err().to_string() + != Error::with_message(azure_core::error::ErrorKind::DataConversion, || { + "Incomplete chunk" + }) + .to_string(), + ) + }); } #[cfg(test)] mod tests { ... @@ -194,6 +203,27 @@ mod tests { assert_result_vectors(expected,
actual); } + // This is an oversimplification, reasonable for an MVP. We should: + // 1. propagate errors upwards + // 2. handle an unexpected "data:" marker (this will simply send the string as is, which will fail deserialization in an upper mapping layer) + #[tokio::test] + async fn error_in_response_ends_stream() { + let mut source_stream = futures::stream::iter(vec![ + Ok(bytes::Bytes::from_static(b"data: piece 1\n\n")), + Err(Error::with_message( + azure_core::error::ErrorKind::Other, + || "Incomplete chunk", + )), + ]); + + let actual = string_chunks(&mut source_stream, "\n\n"); + pin_mut!(actual); + let actual: Vec<Result<String>> = actual.collect().await; + + let expected: Vec<Result<String>> = vec![Ok("piece 1".to_string())]; + assert_result_vectors(expected, actual); + } + fn assert_result_vectors(expected: Vec<Result<T>>, actual: Vec<Result<T>>) where T: std::fmt::Debug + PartialEq, From 3f8f4cd1faa1ebdc5d94bedf4eb638eab528db52 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 11:32:31 +0200 Subject: [PATCH 45/71] Updated comment --- sdk/openai/inference/src/helpers/streaming.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index 7242aefec7..00c8f3f624 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -83,8 +83,8 @@ pub(crate) fn string_chunks<'a>( }, ); - // We filter errors, we should specifically target the error type yielded when we are not able to find an event in a chunk - // Specifically the Error::with_message(ErrorKind::DataConversion, || "Incomplete chunk") + // We specifically allow the Error::with_message(ErrorKind::DataConversion, || "Incomplete chunk") + // So that we are able to continue pushing bytes to the buffer until we find the next delimiter return stream.filter(|it| { std::future::ready( it.is_ok() From f471f29fe577ded1c5bab6187a4d512c9cede0f4 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 13:17:17 +0200 Subject: [PATCH 46/71] Added crate readme --- sdk/openai/inference/Cargo.toml | 4 ++-- sdk/openai/inference/README.md | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 sdk/openai/inference/README.md diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index 6dc22764d4..ef5b34a876 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "azure_openai_inference" version = "0.1.0" -description = "Rust wrappers around Microsoft Azure REST APIs - Azure OpenAI Inference" +description = "Rust client SDK for Azure OpenAI Inference" +readme = "README.md" authors.workspace = true edition.workspace = true license.workspace = true repository.workspace = true rust-version.workspace = true -readme.workspace = true keywords = ["sdk", "azure", "rest"] categories = ["api-bindings"] diff --git a/sdk/openai/inference/README.md b/sdk/openai/inference/README.md new file mode 100644 index 0000000000..4da3423fd4 --- /dev/null +++ b/sdk/openai/inference/README.md @@ -0,0 +1,37 @@ +# Azure OpenAI Inference SDK for Rust + +## Introduction + +This SDK provides Rust types to interact with both OpenAI and Azure OpenAI services. + +### Features + +All features are showcased in the `examples` folder of this crate.
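+Each example can be run with `cargo run --example <example_name>` from the crate root.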
The following is a list of what is currently supported: + +- Supporting both usage with OpenAI and Azure OpenAI services by using `OpenAIClient` or `AzureOpenAIClient`, respectively. +- Key credential authentication is supported. +- [Azure Only] Azure Active Directory (AAD) authentication is supported. +- `ChatCompletions` operation supported (limited fields). +- Streaming for `ChatCompletions` is supported + +## Authentication methods + +### Azure Active Directory + +This authentication method is only supported for Azure OpenAI services. + +```rust +AzureOpenAIClient::new( + endpoint, + Arc::new(DefaultAzureCredentialBuilder::new().build()?), + None, +)? +``` + +### Key Credentials + +This method of authentication is supported both for Azure and non-Azure OpenAI services. + +```rust +OpenAIClient::with_key_credential(secret, None)? +``` From 711aa976bd722f6c449173ab790fda6f7ccf53a1 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 13:21:26 +0200 Subject: [PATCH 47/71] Example docs and comments --- sdk/openai/inference/examples/azure_chat_completions.rs | 1 + sdk/openai/inference/examples/azure_chat_completions_aad.rs | 1 + sdk/openai/inference/examples/azure_chat_completions_stream.rs | 1 + sdk/openai/inference/examples/chat_completions.rs | 1 + sdk/openai/inference/examples/chat_completions_stream.rs | 1 + sdk/openai/inference/src/options/service_version.rs | 1 - 6 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 8199a6b60c..385059132a 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -5,6 +5,7 @@ use azure_openai_inference::{ AzureOpenAIClientOptions, AzureServiceVersion, }; +// This example illustrates how to use Azure OpenAI with key credential authentication to generate a chat completion. #[tokio::main] pub async fn main() -> Result<()> { let endpoint = diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs index a9dfff48e3..1056723df5 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_aad.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs @@ -8,6 +8,7 @@ use azure_openai_inference::{ AzureOpenAIClientOptions, AzureServiceVersion, }; +/// This example illustrates how to use Azure OpenAI Chat Completions with Azure Active Directory authentication. #[tokio::main] async fn main() -> Result<()> { let endpoint = diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs index 5b676c478d..e1f5cfd449 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs @@ -7,6 +7,7 @@ use azure_openai_inference::{ use futures::stream::StreamExt; use std::io::{self, Write}; +/// This example illustrates how to use Azure OpenAI with key credential authentication to stream chat completions. 
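+///
+/// The call returns a `futures::Stream` of chunks; the body of this example pins the
+/// stream and drains it in a loop along these lines (a sketch of the code below):
+///
+/// ```ignore
+/// futures::pin_mut!(response);
+/// while let Some(result) = response.next().await {
+///     // each Ok(chunk) carries choices whose `delta` may hold streamed content
+/// }
+/// ```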
#[tokio::main] async fn main() -> Result<()> { let endpoint = diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs index e0dd04ef33..ec4eac967d 100644 --- a/sdk/openai/inference/examples/chat_completions.rs +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -3,6 +3,7 @@ use azure_openai_inference::{ request::CreateChatCompletionsRequest, }; +/// This example illustrates how to use OpenAI to generate a chat completion. #[tokio::main] pub async fn main() -> azure_core::Result<()> { let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs index 95c0c6245d..f828ee67f6 100644 --- a/sdk/openai/inference/examples/chat_completions_stream.rs +++ b/sdk/openai/inference/examples/chat_completions_stream.rs @@ -6,6 +6,7 @@ use azure_openai_inference::{ use futures::stream::StreamExt; use std::io::{self, Write}; +/// This example illustrates how to use OpenAI to stream chat completions. #[tokio::main] async fn main() -> Result<()> { let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs index 7637709de1..71f36ae464 100644 --- a/sdk/openai/inference/src/options/service_version.rs +++ b/sdk/openai/inference/src/options/service_version.rs @@ -39,7 +39,6 @@ impl ToString for AzureServiceVersion { } } -// Not entirely sure this is a good idea // code lifted from BearerTokenCredentialPolicy #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait)] From 48d7dfab1c801f8923f2780e03bdb5e0042935aa Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 14:25:46 +0200 Subject: [PATCH 48/71] Added comments to crate::options module --- sdk/openai/inference/src/clients/openai_client.rs | 1 + .../src/options/azure_openai_client_options.rs | 11 +++++++++++ .../inference/src/options/openai_client_options.rs | 14 ++++++++++++++ .../inference/src/options/service_version.rs | 5 +++++ 4 files changed, 31 insertions(+) diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs index 93a4a75689..56cccb053e 100644 --- a/sdk/openai/inference/src/clients/openai_client.rs +++ b/sdk/openai/inference/src/clients/openai_client.rs @@ -17,6 +17,7 @@ pub trait OpenAIClientMethods { fn chat_completions_client(&self) -> ChatCompletionsClient; } +/// A client that can be used to interact with the OpenAI API. #[derive(Debug, Clone)] pub struct OpenAIClient { base_url: Url, diff --git a/sdk/openai/inference/src/options/azure_openai_client_options.rs b/sdk/openai/inference/src/options/azure_openai_client_options.rs index 410e8c4fe7..538d60c171 100644 --- a/sdk/openai/inference/src/options/azure_openai_client_options.rs +++ b/sdk/openai/inference/src/options/azure_openai_client_options.rs @@ -2,6 +2,7 @@ use azure_core::ClientOptions; use crate::AzureServiceVersion; +/// Options to be passed to [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient). 
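+///
+/// # Example
+///
+/// Pinning the service API version via the builder (mirroring the Azure examples in
+/// this crate):
+///
+/// ```rust
+/// use azure_openai_inference::{AzureOpenAIClientOptions, AzureServiceVersion};
+///
+/// let options = AzureOpenAIClientOptions::builder()
+///     .with_api_version(AzureServiceVersion::V2023_12_01Preview)
+///     .build();
+/// ```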
 // TODO: I was not able to find ClientOptions as a derive macro
 #[derive(Clone, Debug, Default)]
 pub struct AzureOpenAIClientOptions {
@@ -27,11 +28,21 @@ pub mod builders {
         pub(super) fn new() -> Self {
             Self::default()
         }
+
+        /// Configures the [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient) to use the specified API version.
+        /// If no value is supplied, the latest version will be used as default. See [`AzureServiceVersion::get_latest()`](AzureServiceVersion::get_latest).
         pub fn with_api_version(mut self, api_service_version: AzureServiceVersion) -> Self {
             self.options.api_service_version = api_service_version;
             self
         }

+        /// Builds the [`AzureOpenAIClientOptions`].
+        ///
+        /// # Examples
+        ///
+        /// ```rust
+        /// let options = azure_openai_inference::AzureOpenAIClientOptions::builder().build();
+        /// ```
         pub fn build(&self) -> AzureOpenAIClientOptions {
             self.options.clone()
         }
diff --git a/sdk/openai/inference/src/options/openai_client_options.rs b/sdk/openai/inference/src/options/openai_client_options.rs
index 65eb66e97e..96f99e9355 100644
--- a/sdk/openai/inference/src/options/openai_client_options.rs
+++ b/sdk/openai/inference/src/options/openai_client_options.rs
@@ -1,16 +1,23 @@
 use azure_core::ClientOptions;

+/// Options to be passed to [`OpenAIClient`](crate::clients::OpenAIClient).
+///
+/// Note: There are currently no options to be set.
+/// This struct is a placeholder for future options.
+// TODO: I was not able to find ClientOptions as a derive macro
 #[derive(Clone, Debug, Default)]
 pub struct OpenAIClientOptions {
     pub(crate) client_options: ClientOptions,
 }

 impl OpenAIClientOptions {
+    /// Creates a new [`builders::OpenAIClientOptionsBuilder`].
     pub fn builder() -> builders::OpenAIClientOptionsBuilder {
         builders::OpenAIClientOptionsBuilder::new()
     }
 }

+/// Builder to construct a [`OpenAIClientOptions`].
 pub mod builders {
     use super::*;

@@ -24,6 +31,13 @@ pub mod builders {
             Self::default()
         }

+        /// Builds the [`OpenAIClientOptions`].
+        ///
+        /// # Examples
+        ///
+        /// ```rust
+        /// let options = azure_openai_inference::OpenAIClientOptions::builder().build();
+        /// ```
         pub fn build(&self) -> OpenAIClientOptions {
             self.options.clone()
         }
diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs
index 71f36ae464..c68d7b8e7f 100644
--- a/sdk/openai/inference/src/options/service_version.rs
+++ b/sdk/openai/inference/src/options/service_version.rs
@@ -3,6 +3,10 @@ use std::sync::Arc;

 use azure_core::{Context, Policy, PolicyResult, Request};

+/// The version of the Azure service to use.
+/// This enum is passed to the [`AzureOpenAIClientOptionsBuilder`](crate::builders::AzureOpenAIClientOptionsBuilder) to configure an [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient) with a specific version of the service.
+///
+/// If no version is specified, the latest version will be used. See [`AzureServiceVersion::get_latest()`](AzureServiceVersion::get_latest).
 #[derive(Debug, Clone)]
 pub enum AzureServiceVersion {
     V2023_09_01Preview,
@@ -17,6 +21,7 @@ impl Default for AzureServiceVersion {
 }

 impl AzureServiceVersion {
+    /// Returns the latest supported version of the Azure OpenAI service.
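+    /// Currently this resolves to [`AzureServiceVersion::V2024_07_01Preview`].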
pub fn get_latest() -> AzureServiceVersion { AzureServiceVersion::V2024_07_01Preview } From cf98e0118b943ca2ac42bed4e3847cfed6b36d24 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 15:55:35 +0200 Subject: [PATCH 49/71] Added docs for models and renamed methods --- sdk/openai/inference/Cargo.toml | 3 +- sdk/openai/inference/README.md | 2 + .../examples/azure_chat_completions.rs | 2 +- .../examples/azure_chat_completions_aad.rs | 2 +- .../examples/azure_chat_completions_stream.rs | 2 +- .../inference/examples/chat_completions.rs | 2 +- .../examples/chat_completions_stream.rs | 2 +- .../inference/src/models/chat_completions.rs | 89 ++++++++++++++++--- 8 files changed, 85 insertions(+), 19 deletions(-) diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index ef5b34a876..979983ba91 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "azure_openai_inference" -version = "0.1.0" +version = "1.0.0-beta.1" description = "Rust client SDK for Azure OpenAI Inference" readme = "README.md" authors.workspace = true @@ -16,7 +16,6 @@ workspace = true [dependencies] azure_core = { workspace = true } -tokio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } async-trait = { workspace = true } diff --git a/sdk/openai/inference/README.md b/sdk/openai/inference/README.md index 4da3423fd4..eb9e01c840 100644 --- a/sdk/openai/inference/README.md +++ b/sdk/openai/inference/README.md @@ -4,6 +4,8 @@ This SDK provides Rust types to interact with both OpenAI and Azure OpenAI services. +Note: Currently request and response models have as few fields as possible, leveraging the server side defaults wherever it can. + ### Features All features are showcased in the `example` folder of this crate. The following is a list of what is currently supported: diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index 385059132a..b262fb08b4 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -23,7 +23,7 @@ pub async fn main() -> Result<()> { )? .chat_completions_client(); - let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( + let chat_completions_request = CreateChatCompletionsRequest::with_user_message( "gpt-4-1106-preview", "Tell me a joke about pineapples", ); diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs index 1056723df5..d2140cc519 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_aad.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs @@ -25,7 +25,7 @@ async fn main() -> Result<()> { )? .chat_completions_client(); - let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( + let chat_completions_request = CreateChatCompletionsRequest::with_user_message( "gpt-4-1106-preview", "Tell me a joke about pineapples", ); diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs index e1f5cfd449..9352b0e677 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs @@ -25,7 +25,7 @@ async fn main() -> Result<()> { )? 
.chat_completions_client(); - let chat_completions_request = CreateChatCompletionsRequest::new_stream_with_user_message( + let chat_completions_request = CreateChatCompletionsRequest::with_user_message_and_stream( "gpt-4-1106-preview", "Write me an essay that is at least 200 words long on the nutritional values (or lack thereof) of fast food. Start the essay by stating 'this essay will be x many words long' where x is the number of words in the essay.",); diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs index ec4eac967d..069b12c800 100644 --- a/sdk/openai/inference/examples/chat_completions.rs +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -11,7 +11,7 @@ pub async fn main() -> azure_core::Result<()> { let chat_completions_client = OpenAIClient::with_key_credential(secret, None)?.chat_completions_client(); - let chat_completions_request = CreateChatCompletionsRequest::new_with_user_message( + let chat_completions_request = CreateChatCompletionsRequest::with_user_message( "gpt-3.5-turbo-1106", "Tell me a joke about pineapples", ); diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs index f828ee67f6..cf85577a36 100644 --- a/sdk/openai/inference/examples/chat_completions_stream.rs +++ b/sdk/openai/inference/examples/chat_completions_stream.rs @@ -14,7 +14,7 @@ async fn main() -> Result<()> { let chat_completions_client = OpenAIClient::with_key_credential(secret, None)?.chat_completions_client(); - let chat_completions_request = CreateChatCompletionsRequest::new_stream_with_user_message( + let chat_completions_request = CreateChatCompletionsRequest::with_user_message_and_stream( "gpt-3.5-turbo-1106", "Write me an essay that is at least 200 words long on the nutritional values (or lack thereof) of fast food. Start the essay by stating 'this essay will be x many words long' where x is the number of words in the essay.",); diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs index bc292f7bfc..3dcc2a816b 100644 --- a/sdk/openai/inference/src/models/chat_completions.rs +++ b/sdk/openai/inference/src/models/chat_completions.rs @@ -2,72 +2,115 @@ pub mod request { use serde::Serialize; + /// The configuration information for a chat completions request. + /// Completions support a wide variety of tasks and generate text that continues from or "completes" + /// provided prompt data. #[derive(Serialize, Debug, Clone, Default)] #[non_exhaustive] pub struct CreateChatCompletionsRequest { pub messages: Vec, pub model: String, pub stream: Option, - // pub frequency_penalty: f64, - // pub logit_bias: Option>, - // pub logprobs: Option, - // pub top_logprobs: Option, - // pub max_tokens: Option, } + /// An abstract representation of a chat message as provided in a request. #[derive(Serialize, Debug, Clone, Default)] #[non_exhaustive] pub struct ChatCompletionRequestMessageBase { + /// An optional name for the participant. #[serde(skip)] pub name: Option, + /// The contents of the message. pub content: String, // TODO this should be either a string or ChatCompletionRequestMessageContentPart (a polymorphic type) } + /// A description of the intended purpose of a message within a chat completions interaction. #[derive(Serialize, Debug, Clone)] #[non_exhaustive] #[serde(tag = "role")] pub enum ChatCompletionRequestMessage { + /// The role that instructs or sets the behavior of the assistant." 
#[serde(rename = "system")] System(ChatCompletionRequestMessageBase), + + /// The role that provides input for chat completions. #[serde(rename = "user")] User(ChatCompletionRequestMessageBase), } impl ChatCompletionRequestMessage { - pub fn new_user(content: impl Into) -> Self { + /// Creates a new [`ChatCompletionRequestMessage`] with a single `user` message. + pub fn with_user_role(content: impl Into) -> Self { Self::User(ChatCompletionRequestMessageBase { content: content.into(), name: None, }) } - pub fn new_system(content: impl Into) -> Self { + /// Creates a new [`ChatCompletionRequestMessage`] with a single `system` message. + pub fn with_system_role(content: impl Into) -> Self { Self::System(ChatCompletionRequestMessageBase { content: content.into(), name: None, }) } } + impl CreateChatCompletionsRequest { - pub fn new_with_user_message(model: &str, prompt: &str) -> Self { + /// Creates a new [`CreateChatCompletionsRequest`] with a single `user` message. + /// + /// # Example + /// + /// ```rust + /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message("gpt-3.5-turbo-1106", "Why couldn't the eagles take Frodo directly to mount doom?"); + /// ``` + pub fn with_user_message(model: &str, prompt: &str) -> Self { Self { model: model.to_string(), - messages: vec![ChatCompletionRequestMessage::new_user(prompt)], + messages: vec![ChatCompletionRequestMessage::with_user_role(prompt)], ..Default::default() } } - pub fn new_stream_with_user_message( + /// Creates a new [`CreateChatCompletionsRequest`] with a single `system` message whose response will be streamed. + /// + /// # Example + /// + /// ```rust + /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message_and_stream("gpt-3.5-turbo-1106", "Why couldn't the eagles take Frodo directly to mount doom?"); + /// ``` + pub fn with_user_message_and_stream( model: impl Into, prompt: impl Into, ) -> Self { Self { model: model.into(), - messages: vec![ChatCompletionRequestMessage::new_user(prompt)], + messages: vec![ChatCompletionRequestMessage::with_user_role(prompt)], stream: Some(true), ..Default::default() } } + + /// Creates a new [`CreateChatCompletionsRequest`] with a list of messages. + /// + /// # Example + /// ```rust + /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_messages( + /// "gpt-3.5-turbo-1106", + /// vec![ + /// azure_openai_inference::request::ChatCompletionRequestMessage::with_system_role("You are a good math tutor who explains things briefly."), + /// azure_openai_inference::request::ChatCompletionRequestMessage::with_user_role("What is the value of 'x' in the equation: '2x + 3 = 11'?"), + /// ]); + pub fn with_messages( + model: impl Into, + messages: Vec, + ) -> Self { + Self { + model: model.into(), + messages, + ..Default::default() + } + } } } @@ -76,44 +119,66 @@ pub mod response { use azure_core::Model; use serde::Deserialize; + /// Representation of the response data from a chat completions request. + /// Completions support a wide variety of tasks and generate text that continues from or "completes" + /// provided prompt data. #[derive(Debug, Clone, Deserialize, Model)] #[non_exhaustive] pub struct CreateChatCompletionsResponse { + /// The collection of completions choices associated with this completions response. + /// Generally, `n` choices are generated per provided prompt with a default value of 1. + /// Token limits and other settings may limit the number of choices generated. 
pub choices: Vec, } + /// The representation of a single prompt completion as part of an overall chat completions request. + /// Generally, `n` choices are generated per provided prompt with a default value of 1. + /// Token limits and other settings may limit the number of choices generated. #[derive(Debug, Clone, Deserialize)] #[non_exhaustive] pub struct ChatCompletionChoice { + /// The chat message for a given chat completions prompt. pub message: ChatCompletionResponseMessage, } #[derive(Debug, Clone, Deserialize)] #[non_exhaustive] pub struct ChatCompletionResponseMessage { + /// The content of the message. pub content: Option, + + /// The chat role associated with the message. pub role: String, } // region: --- Streaming + /// Represents a streamed chunk of a chat completion response returned by model, based on the provided input. #[derive(Debug, Clone, Deserialize)] #[non_exhaustive] pub struct CreateChatCompletionsStreamResponse { + /// A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. pub choices: Vec, } + /// A chat completion delta generated by streamed model responses. #[derive(Debug, Clone, Deserialize)] #[non_exhaustive] pub struct ChatCompletionStreamChoice { + /// The delta message content for a streaming response. pub delta: Option, } + /// A chat completion delta generated by streamed model responses. + /// + /// Note: all fields are optional because in a streaming scenario there is no guarantee of what is present in the model. #[derive(Debug, Clone, Deserialize)] #[non_exhaustive] pub struct ChatCompletionStreamResponseMessage { + /// The content of the chunk message. pub content: Option, + + /// The chat role associated with the message. pub role: Option, } - // endregion: Streaming } From 32f59eff1059e33d1901bb12ff1ecb039d1e26b2 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 15:56:51 +0200 Subject: [PATCH 50/71] restricted visibility of EventHandler trait --- sdk/openai/inference/src/helpers/streaming.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index 00c8f3f624..d0e3c5e8c1 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -1,7 +1,7 @@ use azure_core::{Error, Result}; use futures::{Stream, StreamExt}; -pub trait EventStreamer +pub(crate) trait EventStreamer where T: serde::de::DeserializeOwned, { From de711b674f685e07ac65f21757538415606c337c Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 16:12:53 +0200 Subject: [PATCH 51/71] Various visibility restrictions --- .../src/clients/chat_completions_client.rs | 2 +- sdk/openai/inference/src/clients/mod.rs | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index e249aeed40..2a77151053 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -28,7 +28,7 @@ pub struct ChatCompletionsClient { } impl ChatCompletionsClient { - pub(crate) fn new(base_client: Box) -> Self { + pub(super) fn new(base_client: Box) -> Self { Self { base_client } } } diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index 92c49ff879..a36450bdbb 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ 
b/sdk/openai/inference/src/clients/mod.rs @@ -8,7 +8,7 @@ pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods}; pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods}; pub use openai_client::{OpenAIClient, OpenAIClientMethods}; -pub(crate) trait BaseOpenAIClientMethods { +trait BaseOpenAIClientMethods { fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result; fn pipeline(&self) -> &azure_core::Pipeline; @@ -18,17 +18,12 @@ fn new_pipeline( per_call_policies: Vec>, options: azure_core::ClientOptions, ) -> azure_core::Pipeline { - let crate_name = option_env!("CARGO_PKG_NAME"); - let crate_version = option_env!("CARGO_PKG_VERSION"); - // should I be using per_call_policies here too or are they used by default on retries too? - let per_retry_policies = Vec::new(); - azure_core::Pipeline::new( - crate_name, - crate_version, + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), options, per_call_policies, - per_retry_policies, + Vec::new(), ) } From 2c1a3c0c4713d24992545885a5c9bda31cc28e62 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 16:23:09 +0200 Subject: [PATCH 52/71] Added documentation for auth module --- .../inference/src/auth/azure_key_credential.rs | 14 ++++++++++++++ .../inference/src/auth/openai_key_credential.rs | 12 ++++++++++++ 2 files changed, 26 insertions(+) diff --git a/sdk/openai/inference/src/auth/azure_key_credential.rs b/sdk/openai/inference/src/auth/azure_key_credential.rs index d0db604fe4..b7c2d18e76 100644 --- a/sdk/openai/inference/src/auth/azure_key_credential.rs +++ b/sdk/openai/inference/src/auth/azure_key_credential.rs @@ -7,10 +7,24 @@ use azure_core::{ Context, Header, Policy, PolicyResult, Request, }; +/// A key credential for the [AzureOpenAIClient](crate::clients::AzureOpenAIClient). +/// +/// # Example +/// ```no_run +/// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods}; +/// +/// let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); +/// let azure_open_ai_client = AzureOpenAIClient::with_key_credential( +/// "https://my.endpoint/", +/// secret, +/// None, +/// ).unwrap(); +/// ``` #[derive(Debug, Clone)] pub struct AzureKeyCredential(Secret); impl AzureKeyCredential { + /// Create a new [`AzureKeyCredential`]. pub fn new(api_key: impl Into) -> Self { Self(Secret::new(api_key.into())) } diff --git a/sdk/openai/inference/src/auth/openai_key_credential.rs b/sdk/openai/inference/src/auth/openai_key_credential.rs index 2562cb288d..a543214d91 100644 --- a/sdk/openai/inference/src/auth/openai_key_credential.rs +++ b/sdk/openai/inference/src/auth/openai_key_credential.rs @@ -7,6 +7,18 @@ use azure_core::{ Context, Header, Policy, PolicyResult, Request, }; +/// A key credential for the [OpenAIClient](crate::clients::OpenAIClient). 
+/// +/// # Example +/// ```no_run +/// use azure_openai_inference::clients::{OpenAIClient, OpenAIClientMethods}; +/// +/// let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); +/// let open_ai_client = OpenAIClient::with_key_credential( +/// secret, +/// None, +/// ).unwrap(); +/// ``` #[derive(Debug, Clone)] pub struct OpenAIKeyCredential(Secret); From a071db5d5e58e57c7336460c0b1140de2caa252b Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 16:34:44 +0200 Subject: [PATCH 53/71] Documented helpers module --- sdk/openai/inference/src/helpers/streaming.rs | 10 ++++++++++ sdk/openai/inference/src/models/chat_completions.rs | 8 ++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs index d0e3c5e8c1..95ef9b4a3d 100644 --- a/sdk/openai/inference/src/helpers/streaming.rs +++ b/sdk/openai/inference/src/helpers/streaming.rs @@ -1,6 +1,7 @@ use azure_core::{Error, Result}; use futures::{Stream, StreamExt}; +/// A trait used to designate a type into which the streams will be deserialized. pub(crate) trait EventStreamer where T: serde::de::DeserializeOwned, @@ -8,6 +9,15 @@ where fn event_stream(response_body: azure_core::ResponseBody) -> impl Stream>; } +/// A helper function to be used in streaming scenarios. The `response_body`, the input stream +/// is buffered until a `stream_event_delimiter` is found. This constitutes a single event. +/// These series of events are then returned as a stream. +/// +/// # Arguments +/// * `response_body` - The response body stream of an HTTP request. +/// * `stream_event_delimiter` - The delimiter that separates events in the stream. In some cases `\n\n`, in other cases can be `\n\r\n\n`. +/// # Returns +/// The `response_body` stream segmented and streamed into String events demarcated by `stream_event_delimiter`. 
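+///
+/// Illustration (a sketch in test form, mirroring this module's unit tests): an input
+/// of `b"data: piece 1\n\n"` split on `"\n\n"` yields the event `"piece 1"`.
+///
+/// ```ignore
+/// let mut source_stream = futures::stream::iter(vec![
+///     Ok(bytes::Bytes::from_static(b"data: piece 1\n\n")),
+/// ]);
+/// let events = string_chunks(&mut source_stream, "\n\n");
+/// // collecting `events` yields: Ok("piece 1")
+/// ```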
pub(crate) fn string_chunks<'a>( response_body: (impl Stream> + Unpin + 'a), stream_event_delimiter: &'a str, diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs index 3dcc2a816b..d5a2c49c29 100644 --- a/sdk/openai/inference/src/models/chat_completions.rs +++ b/sdk/openai/inference/src/models/chat_completions.rs @@ -62,7 +62,9 @@ pub mod request { /// # Example /// /// ```rust - /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message("gpt-3.5-turbo-1106", "Why couldn't the eagles take Frodo directly to mount doom?"); + /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message( + /// "gpt-3.5-turbo-1106", + /// "Why couldn't the eagles take Frodo directly to mount doom?"); /// ``` pub fn with_user_message(model: &str, prompt: &str) -> Self { Self { @@ -77,7 +79,9 @@ pub mod request { /// # Example /// /// ```rust - /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message_and_stream("gpt-3.5-turbo-1106", "Why couldn't the eagles take Frodo directly to mount doom?"); + /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message_and_stream( + /// "gpt-3.5-turbo-1106", + /// "Why couldn't the eagles take Frodo directly to Mount Doom?"); /// ``` pub fn with_user_message_and_stream( model: impl Into, From cbfa4669a2ada9ca6c3393948a1cfbc138116dd2 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 16:38:11 +0200 Subject: [PATCH 54/71] wip --- sdk/openai/inference/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/inference/README.md b/sdk/openai/inference/README.md index eb9e01c840..92306126f5 100644 --- a/sdk/openai/inference/README.md +++ b/sdk/openai/inference/README.md @@ -4,7 +4,7 @@ This SDK provides Rust types to interact with both OpenAI and Azure OpenAI services. -Note: Currently request and response models have as few fields as possible, leveraging the server side defaults wherever it can. +Note: Currently request and response models have as few fields as possible, leveraging the server side defaults wherever they can. ### Features From ba93ea41bc3457ea04d1de127eaecafed2367941 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 17:11:04 +0200 Subject: [PATCH 55/71] Added docs for azure openai client --- .../src/clients/azure_openai_client.rs | 77 ++++++++++++++----- .../options/azure_openai_client_options.rs | 1 + 2 files changed, 59 insertions(+), 19 deletions(-) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index dd7bb041b6..fc20684e66 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -10,38 +10,53 @@ use azure_core::{BearerTokenCredentialPolicy, Url}; use super::chat_completions_client::ChatCompletionsClient; use super::BaseOpenAIClientMethods; +/// Defines the methods provided by a [`AzureOpenAIClient`] and can be used for mocking. pub trait AzureOpenAIClientMethods { - fn new( - endpoint: impl AsRef, - credentials: Arc, - client_options: Option, - ) -> Result - where - Self: Sized; - - fn with_key_credential( - endpoint: impl AsRef, - secret: impl Into, - client_options: Option, - ) -> Result - where - Self: Sized; - + /// Returns the endpoint [`Url`] of the client. 
     fn endpoint(&self) -> &Url;

+    /// Returns a new instance of the [`ChatCompletionsClient`].
     fn chat_completions_client(&self) -> ChatCompletionsClient;
 }

+/// An Azure OpenAI client.
 #[derive(Debug, Clone)]
 pub struct AzureOpenAIClient {
+    /// The Azure resource endpoint.
     endpoint: Url,
+
+    /// The pipeline for sending requests to the service.
     pipeline: azure_core::Pipeline,
+
+    /// The options for the client.
     #[allow(dead_code)]
     options: AzureOpenAIClientOptions,
 }

-impl AzureOpenAIClientMethods for AzureOpenAIClient {
-    fn new(
+impl AzureOpenAIClient {
+    /// Creates a new [`AzureOpenAIClient`] using a [`TokenCredential`].
+    /// See the example below for Azure Active Directory authentication.
+    ///
+    /// # Parameters
+    /// * `endpoint` - The full URL of the Azure OpenAI resource endpoint.
+    /// * `credential` - An implementation of [`TokenCredential`] used for authentication.
+    /// * `client_options` - Optional configuration for the client. The [`AzureServiceVersion`](crate::options::AzureServiceVersion) can be provided here.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods};
+    /// use azure_identity::DefaultAzureCredentialBuilder;
+    /// use std::sync::Arc;
+    ///
+    /// let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT environment variable");
+    /// let client = AzureOpenAIClient::new(
+    ///     endpoint,
+    ///     Arc::new(DefaultAzureCredentialBuilder::new().build().unwrap()),
+    ///     None,
+    /// ).unwrap();
+    /// ```
+    pub fn new(
         endpoint: impl AsRef<str>,
         credential: Arc<dyn TokenCredential>,
         client_options: Option<AzureOpenAIClientOptions>,
@@ -63,7 +78,26 @@ impl AzureOpenAIClientMethods for AzureOpenAIClient {
         })
     }

-    fn with_key_credential(
+    /// Creates a new [`AzureOpenAIClient`] using a key credential.
+    ///
+    /// # Parameters
+    /// * `endpoint` - The full URL of the Azure OpenAI resource endpoint.
+    /// * `secret` - The key credential used for authentication, passed as a header parameter in the request.
+    /// * `client_options` - Optional configuration for the client. The [`AzureServiceVersion`](crate::options::AzureServiceVersion) can be provided here.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods};
+    ///
+    /// let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT environment variable");
+    /// let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY environment variable");
+    /// let client = AzureOpenAIClient::with_key_credential(
+    ///     endpoint,
+    ///     secret,
+    ///     None,
+    /// ).unwrap();
+    /// ```
+    pub fn with_key_credential(
         endpoint: impl AsRef<str>,
         secret: impl Into<String>,
         client_options: Option<AzureOpenAIClientOptions>,
@@ -84,11 +118,16 @@ impl AzureOpenAIClientMethods for AzureOpenAIClient {
             options,
         })
     }
+}
+
+impl AzureOpenAIClientMethods for AzureOpenAIClient {
+    /// Returns the endpoint [`Url`] of the client.
     fn endpoint(&self) -> &Url {
         &self.endpoint
     }

+    /// Returns a new instance of the [`ChatCompletionsClient`] using an [`AzureOpenAIClient`] underneath.
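+    ///
+    /// A typical call chain, as used in the examples (a sketch; `endpoint` and `secret`
+    /// as in the constructors above):
+    ///
+    /// ```ignore
+    /// let chat_completions_client = AzureOpenAIClient::with_key_credential(endpoint, secret, None)?
+    ///     .chat_completions_client();
+    /// ```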
fn chat_completions_client(&self) -> ChatCompletionsClient { ChatCompletionsClient::new(Box::new(self.clone())) } diff --git a/sdk/openai/inference/src/options/azure_openai_client_options.rs b/sdk/openai/inference/src/options/azure_openai_client_options.rs index 538d60c171..438f2d0543 100644 --- a/sdk/openai/inference/src/options/azure_openai_client_options.rs +++ b/sdk/openai/inference/src/options/azure_openai_client_options.rs @@ -6,6 +6,7 @@ use crate::AzureServiceVersion; // TODO: I was not able to find ClientOptions as a derive macros #[derive(Clone, Debug, Default)] pub struct AzureOpenAIClientOptions { + #[allow(dead_code)] pub(crate) client_options: ClientOptions, pub(crate) api_service_version: AzureServiceVersion, } From 3eb7cac7534e544f221a553fbacd0e509d5872d4 Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 17:15:52 +0200 Subject: [PATCH 56/71] Added docs for ChatCompletionsClient --- .../src/clients/azure_openai_client.rs | 1 - .../src/clients/chat_completions_client.rs | 24 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs index fc20684e66..93310cd12b 100644 --- a/sdk/openai/inference/src/clients/azure_openai_client.rs +++ b/sdk/openai/inference/src/clients/azure_openai_client.rs @@ -121,7 +121,6 @@ impl AzureOpenAIClient { } impl AzureOpenAIClientMethods for AzureOpenAIClient { - /// Returns the endpoint [`Url`] of the client. fn endpoint(&self) -> &Url { &self.endpoint diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs index 2a77151053..02acbbd71a 100644 --- a/sdk/openai/inference/src/clients/chat_completions_client.rs +++ b/sdk/openai/inference/src/clients/chat_completions_client.rs @@ -7,7 +7,13 @@ use crate::{ use azure_core::{Context, Method, Response, Result}; use futures::{Stream, StreamExt}; +/// A [`ChatCompletionsClient`]'s methods. This trait can be used for mocking. pub trait ChatCompletionsClientMethods { + /// Creates a new chat completion. + /// + /// # Arguments + /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used. + /// * `chat_completions_request` - The request specifying the chat completion to be created. #[allow(async_fn_in_trait)] async fn create_chat_completions( &self, @@ -15,6 +21,11 @@ pub trait ChatCompletionsClientMethods { chat_completions_request: &CreateChatCompletionsRequest, ) -> Result>; + /// Creates a new chat completion and returns a streamed response. + /// + /// # Arguments + /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used. + /// * `chat_completions_request` - The request specifying the chat completion to be created. #[allow(async_fn_in_trait)] async fn stream_chat_completions( &self, @@ -23,7 +34,9 @@ pub trait ChatCompletionsClientMethods { ) -> Result>>; } +/// A client for Chat Completions related operations. pub struct ChatCompletionsClient { + /// The underlying HTTP client with an associated pipeline. base_client: Box, } @@ -34,6 +47,11 @@ impl ChatCompletionsClient { } impl ChatCompletionsClientMethods for ChatCompletionsClient { + /// Creates a new chat completion. + /// + /// # Arguments + /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used. 
+ /// * `chat_completions_request` - The request specifying the chat completion to be created. async fn create_chat_completions( &self, deployment_name: impl AsRef, @@ -50,6 +68,11 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { .await } + /// Creates a new chat completion and returns a streamed response. + /// + /// # Arguments + /// * `deployment_name` - The name of the deployment in Azure. In OpenAI it is the model name to be used. + /// * `chat_completions_request` - The request specifying the chat completion to be created. async fn stream_chat_completions( &self, deployment_name: impl AsRef, @@ -71,6 +94,7 @@ impl ChatCompletionsClientMethods for ChatCompletionsClient { } } +/// A placeholder type to provide an implementation for the [`EventStreamer`] trait specifically for chat completions. struct ChatCompletionsStreamHandler; impl EventStreamer for ChatCompletionsStreamHandler { From f22d74e179d521d26a7784d87f859b053f54af2b Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Thu, 19 Sep 2024 17:24:41 +0200 Subject: [PATCH 57/71] clients module documented --- sdk/openai/inference/src/clients/mod.rs | 7 +++++ .../inference/src/clients/openai_client.rs | 29 ++++++++++++------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs index a36450bdbb..fbb8bd1331 100644 --- a/sdk/openai/inference/src/clients/mod.rs +++ b/sdk/openai/inference/src/clients/mod.rs @@ -8,9 +8,16 @@ pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods}; pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods}; pub use openai_client::{OpenAIClient, OpenAIClientMethods}; +/// A trait that defines the common behaviour expected from an [`OpenAIClient`] and an [`AzureOpenAIClient`]. +/// This trait will be used as a boxed types for any clients such as [`ChatCompletionsClient`] so they issue HTTP requests. trait BaseOpenAIClientMethods { + /// Returns the base [`Url`] of the underlying client. + /// + /// # Arguments + /// * `deployment_name` - The name of the deployment in Azure. In an [`OpenAIClient`] this parameter is ignored. fn base_url(&self, deployment_name: Option<&str>) -> azure_core::Result; + /// Returns the [`azure_core::Pipeline`] of the underlying client. fn pipeline(&self) -> &azure_core::Pipeline; } diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs index 56cccb053e..6e0e34ddde 100644 --- a/sdk/openai/inference/src/clients/openai_client.rs +++ b/sdk/openai/inference/src/clients/openai_client.rs @@ -6,18 +6,12 @@ use crate::{auth::OpenAIKeyCredential, OpenAIClientOptions}; use super::{BaseOpenAIClientMethods, ChatCompletionsClient}; +/// Defines the methods provided by a [`OpenAIClient`] and can be used for mocking. pub trait OpenAIClientMethods { - fn with_key_credential( - secret: impl Into, - client_options: Option, - ) -> Result - where - Self: Sized; - fn chat_completions_client(&self) -> ChatCompletionsClient; } -/// A client that can be used to interact with the OpenAI API. +/// An OpenAI client. #[derive(Debug, Clone)] pub struct OpenAIClient { base_url: Url, @@ -26,8 +20,21 @@ pub struct OpenAIClient { options: OpenAIClientOptions, } -impl OpenAIClientMethods for OpenAIClient { - fn with_key_credential( +impl OpenAIClient { + /// Creates a new [`OpenAIClient`] using a secret key. + /// + /// # Parameters + /// * `secret` - The key credential used for authentication. 
+    /// * `client_options` - Optional configuration for the client. Reserved for future use; currently this can always be `None`.
+    ///
+    /// # Example
+    /// ```no_run
+    /// use azure_openai_inference::clients::{OpenAIClient, OpenAIClientMethods};
+    ///
+    /// let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable");
+    /// let client = OpenAIClient::with_key_credential(secret, None).unwrap();
+    /// ```
+    pub fn with_key_credential(
         secret: impl Into<String>,
         client_options: Option<OpenAIClientOptions>,
     ) -> Result<Self> {
@@ -43,7 +50,9 @@ impl OpenAIClientMethods for OpenAIClient {
             options,
         })
     }
+}

+impl OpenAIClientMethods for OpenAIClient {
     fn chat_completions_client(&self) -> ChatCompletionsClient {
         ChatCompletionsClient::new(Box::new(self.clone()))
     }

From a071db5d5e58e57c7336460c0b1140de2caa252b Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 11:32:28 +0200
Subject: [PATCH 58/71] renamed auth module to credentials

---
 sdk/openai/inference/src/clients/azure_openai_client.rs         | 2 +-
 sdk/openai/inference/src/clients/openai_client.rs               | 2 +-
 .../inference/src/{auth => credentials}/azure_key_credential.rs | 0
 sdk/openai/inference/src/{auth => credentials}/mod.rs           | 0
 .../src/{auth => credentials}/openai_key_credential.rs          | 0
 sdk/openai/inference/src/lib.rs                                 | 2 +-
 6 files changed, 3 insertions(+), 3 deletions(-)
 rename sdk/openai/inference/src/{auth => credentials}/azure_key_credential.rs (100%)
 rename sdk/openai/inference/src/{auth => credentials}/mod.rs (100%)
 rename sdk/openai/inference/src/{auth => credentials}/openai_key_credential.rs (100%)

diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index 93310cd12b..3e6e1568a0 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;

-use crate::auth::{AzureKeyCredential, DEFAULT_SCOPE};
+use crate::credentials::{AzureKeyCredential, DEFAULT_SCOPE};
 use crate::options::AzureOpenAIClientOptions;

 use azure_core::auth::TokenCredential;
diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs
index 6e0e34ddde..cf3a51d02f 100644
--- a/sdk/openai/inference/src/clients/openai_client.rs
+++ b/sdk/openai/inference/src/clients/openai_client.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;

 use azure_core::{Policy, Result, Url};

-use crate::{auth::OpenAIKeyCredential, OpenAIClientOptions};
+use crate::{credentials::OpenAIKeyCredential, OpenAIClientOptions};

 use super::{BaseOpenAIClientMethods, ChatCompletionsClient};

diff --git a/sdk/openai/inference/src/auth/azure_key_credential.rs b/sdk/openai/inference/src/credentials/azure_key_credential.rs
similarity index 100%
rename from sdk/openai/inference/src/auth/azure_key_credential.rs
rename to sdk/openai/inference/src/credentials/azure_key_credential.rs
diff --git a/sdk/openai/inference/src/auth/mod.rs b/sdk/openai/inference/src/credentials/mod.rs
similarity index 100%
rename from sdk/openai/inference/src/auth/mod.rs
rename to sdk/openai/inference/src/credentials/mod.rs
diff --git a/sdk/openai/inference/src/auth/openai_key_credential.rs b/sdk/openai/inference/src/credentials/openai_key_credential.rs
similarity index 100%
rename from sdk/openai/inference/src/auth/openai_key_credential.rs
rename to sdk/openai/inference/src/credentials/openai_key_credential.rs
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 6e746a2cca..28387bd322 100644
---
a/sdk/openai/inference/src/lib.rs +++ b/sdk/openai/inference/src/lib.rs @@ -1,4 +1,4 @@ -mod auth; +mod credentials; pub mod clients; mod helpers; mod models; From b2b8f4502e2be0f3d8c139cc495a41dcb6ec681b Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Mon, 30 Sep 2024 11:40:59 +0200 Subject: [PATCH 59/71] Removed usage of azure_core::Result in examples --- .../inference/examples/azure_chat_completions.rs | 12 +++++++----- .../examples/azure_chat_completions_aad.rs | 15 ++++++++------- .../examples/azure_chat_completions_stream.rs | 11 +++++------ sdk/openai/inference/examples/chat_completions.rs | 13 ++++++++----- .../inference/examples/chat_completions_stream.rs | 13 ++++++------- sdk/openai/inference/src/lib.rs | 2 +- 6 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs index b262fb08b4..736f852a1a 100644 --- a/sdk/openai/inference/examples/azure_chat_completions.rs +++ b/sdk/openai/inference/examples/azure_chat_completions.rs @@ -1,4 +1,3 @@ -use azure_core::Result; use azure_openai_inference::{ clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, request::CreateChatCompletionsRequest, @@ -7,7 +6,7 @@ use azure_openai_inference::{ // This example illustrates how to use Azure OpenAI with key credential authentication to generate a chat completion. #[tokio::main] -pub async fn main() -> Result<()> { +pub async fn main() { let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); @@ -20,7 +19,8 @@ pub async fn main() -> Result<()> { .with_api_version(AzureServiceVersion::V2023_12_01Preview) .build(), ), - )? + ) + .unwrap() .chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::with_user_message( @@ -34,12 +34,14 @@ pub async fn main() -> Result<()> { match response { Ok(chat_completions_response) => { - let chat_completions = chat_completions_response.deserialize_body().await?; + let chat_completions = chat_completions_response + .deserialize_body() + .await + .expect("Failed to deserialize response"); println!("{:#?}", &chat_completions); } Err(e) => { println!("Error: {}", e); } }; - Ok(()) } diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs index d2140cc519..09310b03ab 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_aad.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use azure_core::Result; use azure_identity::DefaultAzureCredentialBuilder; use azure_openai_inference::{ clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, @@ -10,19 +9,20 @@ use azure_openai_inference::{ /// This example illustrates how to use Azure OpenAI Chat Completions with Azure Active Directory authentication. #[tokio::main] -async fn main() -> Result<()> { +async fn main() { let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let chat_completions_client = AzureOpenAIClient::new( endpoint, - Arc::new(DefaultAzureCredentialBuilder::new().build()?), + Arc::new(DefaultAzureCredentialBuilder::new().build().unwrap()), Some( AzureOpenAIClientOptions::builder() .with_api_version(AzureServiceVersion::V2023_12_01Preview) .build(), ), - )? 
+ ) + .unwrap() .chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::with_user_message( @@ -36,13 +36,14 @@ async fn main() -> Result<()> { match response { Ok(chat_completions_response) => { - let chat_completions = chat_completions_response.deserialize_body().await?; + let chat_completions = chat_completions_response + .deserialize_body() + .await + .expect("Failed to deserialize response"); println!("{:#?}", &chat_completions); } Err(e) => { println!("Error: {}", e); } }; - - Ok(()) } diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs index 9352b0e677..4fdf66547b 100644 --- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs +++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs @@ -1,4 +1,3 @@ -use azure_core::Result; use azure_openai_inference::{ clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods}, request::CreateChatCompletionsRequest, @@ -9,7 +8,7 @@ use std::io::{self, Write}; /// This example illustrates how to use Azure OpenAI with key credential authentication to stream chat completions. #[tokio::main] -async fn main() -> Result<()> { +async fn main() { let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT env variable"); let secret = std::env::var("AZURE_OPENAI_KEY").expect("Set AZURE_OPENAI_KEY env variable"); @@ -22,7 +21,8 @@ async fn main() -> Result<()> { .with_api_version(AzureServiceVersion::V2023_12_01Preview) .build(), ), - )? + ) + .unwrap() .chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::with_user_message_and_stream( @@ -32,7 +32,8 @@ async fn main() -> Result<()> { let response = chat_completions_client .stream_chat_completions(&chat_completions_request.model, &chat_completions_request) - .await?; + .await + .unwrap(); // this pins the stream to the stack so it is safe to poll it (namely, it won't be dealloacted or moved) futures::pin_mut!(response); @@ -52,6 +53,4 @@ async fn main() -> Result<()> { Err(e) => println!("Error: {:?}", e), } } - - Ok(()) } diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs index 069b12c800..25fbc08477 100644 --- a/sdk/openai/inference/examples/chat_completions.rs +++ b/sdk/openai/inference/examples/chat_completions.rs @@ -5,11 +5,12 @@ use azure_openai_inference::{ /// This example illustrates how to use OpenAI to generate a chat completion. 
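+/// The response is matched and deserialized on the happy path, printing errors otherwise
+/// (a sketch of the match at the end of this example):
+///
+/// ```ignore
+/// match response {
+///     Ok(resp) => println!("{:#?}", resp.deserialize_body().await.expect("Failed to deserialize response")),
+///     Err(e) => println!("Error: {}", e),
+/// }
+/// ```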
#[tokio::main] -pub async fn main() -> azure_core::Result<()> { +pub async fn main() { let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); - let chat_completions_client = - OpenAIClient::with_key_credential(secret, None)?.chat_completions_client(); + let chat_completions_client = OpenAIClient::with_key_credential(secret, None) + .unwrap() + .chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::with_user_message( "gpt-3.5-turbo-1106", @@ -22,12 +23,14 @@ pub async fn main() -> azure_core::Result<()> { match response { Ok(chat_completions_response) => { - let chat_completions = chat_completions_response.deserialize_body().await?; + let chat_completions = chat_completions_response + .deserialize_body() + .await + .expect("Failed to deserialize response"); println!("{:#?}", &chat_completions); } Err(e) => { println!("Error: {}", e); } }; - Ok(()) } diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs index cf85577a36..8f2492e574 100644 --- a/sdk/openai/inference/examples/chat_completions_stream.rs +++ b/sdk/openai/inference/examples/chat_completions_stream.rs @@ -1,4 +1,3 @@ -use azure_core::Result; use azure_openai_inference::{ clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods}, request::CreateChatCompletionsRequest, @@ -8,11 +7,12 @@ use std::io::{self, Write}; /// This example illustrates how to use OpenAI to stream chat completions. #[tokio::main] -async fn main() -> Result<()> { +async fn main() { let secret = std::env::var("OPENAI_KEY").expect("Set OPENAI_KEY env variable"); - let chat_completions_client = - OpenAIClient::with_key_credential(secret, None)?.chat_completions_client(); + let chat_completions_client = OpenAIClient::with_key_credential(secret, None) + .unwrap() + .chat_completions_client(); let chat_completions_request = CreateChatCompletionsRequest::with_user_message_and_stream( "gpt-3.5-turbo-1106", @@ -21,7 +21,8 @@ async fn main() -> Result<()> { let response = chat_completions_client .stream_chat_completions(&chat_completions_request.model, &chat_completions_request) - .await?; + .await + .unwrap(); // this pins the stream to the stack so it is safe to poll it (namely, it won't be dealloacted or moved) futures::pin_mut!(response); @@ -41,6 +42,4 @@ async fn main() -> Result<()> { Err(e) => println!("Error: {:?}", e), } } - - Ok(()) } diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs index 28387bd322..8251ecf5f0 100644 --- a/sdk/openai/inference/src/lib.rs +++ b/sdk/openai/inference/src/lib.rs @@ -1,5 +1,5 @@ -mod credentials; pub mod clients; +mod credentials; mod helpers; mod models; mod options; From e427611de48fff563a76cc0a01cd6b1ed2d2a6fa Mon Sep 17 00:00:00 2001 From: Jose Alvarez Date: Mon, 30 Sep 2024 11:42:02 +0200 Subject: [PATCH 60/71] crate description correction --- sdk/openai/inference/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/openai/inference/Cargo.toml b/sdk/openai/inference/Cargo.toml index 979983ba91..0f1700da99 100644 --- a/sdk/openai/inference/Cargo.toml +++ b/sdk/openai/inference/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "azure_openai_inference" version = "1.0.0-beta.1" -description = "Rust client SDK for Azure OpenAI Inference" +description = "Rust client library for Azure OpenAI Inference" readme = "README.md" authors.workspace = true edition.workspace = true From 3802d5567a89b2f2e116f36067a523cf2cc4434c Mon Sep 17 
From 3802d5567a89b2f2e116f36067a523cf2cc4434c Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 02:42:41 -0700
Subject: [PATCH 61/71] Update sdk/openai/inference/src/lib.rs

Co-authored-by: Heath Stewart
---
 sdk/openai/inference/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 8251ecf5f0..0587fac590 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -4,5 +4,5 @@ mod helpers;
 mod models;
 mod options;
 
-pub use models::*;
+pub use models::{request::*, response};
 pub use options::*;

From 984645be4c2c880931928934be7a86117084ac22 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 11:48:13 +0200
Subject: [PATCH 62/71] request module flattening fixes

---
 sdk/openai/inference/examples/azure_chat_completions.rs        | 3 +--
 sdk/openai/inference/examples/azure_chat_completions_aad.rs    | 3 +--
 sdk/openai/inference/examples/azure_chat_completions_stream.rs | 3 +--
 sdk/openai/inference/examples/chat_completions.rs              | 2 +-
 sdk/openai/inference/examples/chat_completions_stream.rs       | 2 +-
 sdk/openai/inference/src/clients/chat_completions_client.rs    | 2 +-
 6 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index 736f852a1a..f4561e1370 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,7 +1,6 @@
 use azure_openai_inference::{
     clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
-    request::CreateChatCompletionsRequest,
-    AzureOpenAIClientOptions, AzureServiceVersion,
+    AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
 };
 
 // This example illustrates how to use Azure OpenAI with key credential authentication to generate a chat completion.
diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs
index 09310b03ab..d9558e6196 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_aad.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs
@@ -3,8 +3,7 @@ use std::sync::Arc;
 use azure_identity::DefaultAzureCredentialBuilder;
 use azure_openai_inference::{
     clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
-    request::CreateChatCompletionsRequest,
-    AzureOpenAIClientOptions, AzureServiceVersion,
+    AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
 };
 
 /// This example illustrates how to use Azure OpenAI Chat Completions with Azure Active Directory authentication.
diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
index 4fdf66547b..ca5bb5f229 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
@@ -1,7 +1,6 @@
 use azure_openai_inference::{
     clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
-    request::CreateChatCompletionsRequest,
-    AzureOpenAIClientOptions, AzureServiceVersion,
+    AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
 };
 use futures::stream::StreamExt;
 use std::io::{self, Write};
diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs
index 25fbc08477..589e97c296 100644
--- a/sdk/openai/inference/examples/chat_completions.rs
+++ b/sdk/openai/inference/examples/chat_completions.rs
@@ -1,6 +1,6 @@
 use azure_openai_inference::{
     clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods},
-    request::CreateChatCompletionsRequest,
+    CreateChatCompletionsRequest,
 };
 
 /// This example illustrates how to use OpenAI to generate a chat completion.
diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs
index 8f2492e574..b925049be5 100644
--- a/sdk/openai/inference/examples/chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/chat_completions_stream.rs
@@ -1,6 +1,6 @@
 use azure_openai_inference::{
     clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods},
-    request::CreateChatCompletionsRequest,
+    CreateChatCompletionsRequest,
 };
 use futures::stream::StreamExt;
 use std::io::{self, Write};
diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs
index 02acbbd71a..d3ed88e177 100644
--- a/sdk/openai/inference/src/clients/chat_completions_client.rs
+++ b/sdk/openai/inference/src/clients/chat_completions_client.rs
@@ -1,8 +1,8 @@
 use super::{new_json_request, BaseOpenAIClientMethods};
 use crate::{
     helpers::streaming::{string_chunks, EventStreamer},
-    request::CreateChatCompletionsRequest,
     response::{CreateChatCompletionsResponse, CreateChatCompletionsStreamResponse},
+    CreateChatCompletionsRequest,
 };
 use azure_core::{Context, Method, Response, Result};
 use futures::{Stream, StreamExt};
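The import fixes in patch 62 fall out of the `lib.rs` change in patch 61: once request types are re-exported at the crate root with `pub use models::{request::*, response};`, callers drop the `request::` path segment while response types stay behind a named module. A minimal, self-contained sketch of that re-export pattern (toy module and type names, not the crate's actual layout):

mod models {
    pub mod request {
        // A toy request type standing in for something like CreateChatCompletionsRequest.
        #[derive(Debug, Default)]
        pub struct CreateWidgetRequest {
            pub model: String,
        }
    }

    pub mod response {
        #[derive(Debug, Default)]
        pub struct CreateWidgetResponse;
    }
}

// Flatten request types to the crate root, but keep `response` as a named module.
pub use models::{request::*, response};

fn main() {
    // Callers no longer spell out the `models::request::` path...
    let req = CreateWidgetRequest {
        model: "toy-model".into(),
    };
    // ...while response types stay namespaced.
    let res = response::CreateWidgetResponse;
    println!("{req:?} {res:?}");
}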
From a82e4725c18d5044d25c4be733d4c4fdb29cbe8b Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 11:52:23 +0200
Subject: [PATCH 63/71] Added license header to source files

---
 sdk/openai/inference/examples/azure_chat_completions.rs         | 2 ++
 sdk/openai/inference/examples/azure_chat_completions_aad.rs     | 2 ++
 sdk/openai/inference/examples/azure_chat_completions_stream.rs  | 2 ++
 sdk/openai/inference/examples/chat_completions.rs               | 2 ++
 sdk/openai/inference/examples/chat_completions_stream.rs        | 2 ++
 sdk/openai/inference/src/clients/azure_openai_client.rs         | 2 ++
 sdk/openai/inference/src/clients/chat_completions_client.rs     | 2 ++
 sdk/openai/inference/src/clients/mod.rs                         | 2 ++
 sdk/openai/inference/src/clients/openai_client.rs               | 2 ++
 sdk/openai/inference/src/credentials/azure_key_credential.rs    | 2 ++
 sdk/openai/inference/src/credentials/mod.rs                     | 2 ++
 sdk/openai/inference/src/credentials/openai_key_credential.rs   | 2 ++
 sdk/openai/inference/src/helpers/mod.rs                         | 2 ++
 sdk/openai/inference/src/helpers/streaming.rs                   | 2 ++
 sdk/openai/inference/src/lib.rs                                 | 2 ++
 sdk/openai/inference/src/models/chat_completions.rs             | 2 ++
 sdk/openai/inference/src/models/mod.rs                          | 2 ++
 sdk/openai/inference/src/options/azure_openai_client_options.rs | 2 ++
 sdk/openai/inference/src/options/mod.rs                         | 2 ++
 sdk/openai/inference/src/options/openai_client_options.rs       | 2 ++
 sdk/openai/inference/src/options/service_version.rs             | 2 ++
 21 files changed, 42 insertions(+)

diff --git a/sdk/openai/inference/examples/azure_chat_completions.rs b/sdk/openai/inference/examples/azure_chat_completions.rs
index f4561e1370..1e0da9b73a 100644
--- a/sdk/openai/inference/examples/azure_chat_completions.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_openai_inference::{
     clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
     AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs
index d9558e6196..9afab7c0af 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_aad.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use std::sync::Arc;
 
 use azure_identity::DefaultAzureCredentialBuilder;
diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
index ca5bb5f229..aaa2f24797 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_openai_inference::{
     clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
     AzureOpenAIClientOptions, AzureServiceVersion, CreateChatCompletionsRequest,
diff --git a/sdk/openai/inference/examples/chat_completions.rs b/sdk/openai/inference/examples/chat_completions.rs
index 589e97c296..a5e6fe8261 100644
--- a/sdk/openai/inference/examples/chat_completions.rs
+++ b/sdk/openai/inference/examples/chat_completions.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_openai_inference::{
     clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods},
     CreateChatCompletionsRequest,
diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs
index b925049be5..d7ffc6488b 100644
--- a/sdk/openai/inference/examples/chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/chat_completions_stream.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_openai_inference::{
     clients::{ChatCompletionsClientMethods, OpenAIClient, OpenAIClientMethods},
     CreateChatCompletionsRequest,
diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index 3e6e1568a0..6655677e99 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use std::sync::Arc;
 
 use crate::credentials::{AzureKeyCredential, DEFAULT_SCOPE};
diff --git a/sdk/openai/inference/src/clients/chat_completions_client.rs b/sdk/openai/inference/src/clients/chat_completions_client.rs
index d3ed88e177..9d2016c267 100644
--- a/sdk/openai/inference/src/clients/chat_completions_client.rs
+++ b/sdk/openai/inference/src/clients/chat_completions_client.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use super::{new_json_request, BaseOpenAIClientMethods};
 use crate::{
     helpers::streaming::{string_chunks, EventStreamer},
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
index fbb8bd1331..85699ada28 100644
--- a/sdk/openai/inference/src/clients/mod.rs
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 mod azure_openai_client;
 mod chat_completions_client;
 mod openai_client;
diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs
index cf3a51d02f..a8cad42bb1 100644
--- a/sdk/openai/inference/src/clients/openai_client.rs
+++ b/sdk/openai/inference/src/clients/openai_client.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use std::sync::Arc;
 
 use azure_core::{Policy, Result, Url};
diff --git a/sdk/openai/inference/src/credentials/azure_key_credential.rs b/sdk/openai/inference/src/credentials/azure_key_credential.rs
index b7c2d18e76..0593c9cbbe 100644
--- a/sdk/openai/inference/src/credentials/azure_key_credential.rs
+++ b/sdk/openai/inference/src/credentials/azure_key_credential.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use async_trait::async_trait;
 use std::sync::Arc;
 
diff --git a/sdk/openai/inference/src/credentials/mod.rs b/sdk/openai/inference/src/credentials/mod.rs
index 11b63877da..5b81a8355c 100644
--- a/sdk/openai/inference/src/credentials/mod.rs
+++ b/sdk/openai/inference/src/credentials/mod.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 mod azure_key_credential;
 mod openai_key_credential;
 
diff --git a/sdk/openai/inference/src/credentials/openai_key_credential.rs b/sdk/openai/inference/src/credentials/openai_key_credential.rs
index a543214d91..a95cb90eba 100644
--- a/sdk/openai/inference/src/credentials/openai_key_credential.rs
+++ b/sdk/openai/inference/src/credentials/openai_key_credential.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use async_trait::async_trait;
 use std::sync::Arc;
 
diff --git a/sdk/openai/inference/src/helpers/mod.rs b/sdk/openai/inference/src/helpers/mod.rs
index c65f3f1305..2d0184ea05 100644
--- a/sdk/openai/inference/src/helpers/mod.rs
+++ b/sdk/openai/inference/src/helpers/mod.rs
@@ -1 +1,3 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 pub(crate) mod streaming;
diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs
index 95ef9b4a3d..286aced5ab 100644
--- a/sdk/openai/inference/src/helpers/streaming.rs
+++ b/sdk/openai/inference/src/helpers/streaming.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_core::{Error, Result};
 use futures::{Stream, StreamExt};
 
diff --git a/sdk/openai/inference/src/lib.rs b/sdk/openai/inference/src/lib.rs
index 0587fac590..dceeb85b4d 100644
--- a/sdk/openai/inference/src/lib.rs
+++ b/sdk/openai/inference/src/lib.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 pub mod clients;
 mod credentials;
 mod helpers;
diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs
index d5a2c49c29..3f7e439ba9 100644
--- a/sdk/openai/inference/src/models/chat_completions.rs
+++ b/sdk/openai/inference/src/models/chat_completions.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 pub mod request {
     use serde::Serialize;
 
diff --git a/sdk/openai/inference/src/models/mod.rs b/sdk/openai/inference/src/models/mod.rs
index 8ccec0e32c..67a1b7a918 100644
--- a/sdk/openai/inference/src/models/mod.rs
+++ b/sdk/openai/inference/src/models/mod.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 mod chat_completions;
 
 pub use chat_completions::*;
diff --git a/sdk/openai/inference/src/options/azure_openai_client_options.rs b/sdk/openai/inference/src/options/azure_openai_client_options.rs
index 438f2d0543..bcdbef9047 100644
--- a/sdk/openai/inference/src/options/azure_openai_client_options.rs
+++ b/sdk/openai/inference/src/options/azure_openai_client_options.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_core::ClientOptions;
 
 use crate::AzureServiceVersion;
diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs
index bc1cf1eb13..a66a5734dd 100644
--- a/sdk/openai/inference/src/options/mod.rs
+++ b/sdk/openai/inference/src/options/mod.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 mod azure_openai_client_options;
 mod openai_client_options;
 mod service_version;
diff --git a/sdk/openai/inference/src/options/openai_client_options.rs b/sdk/openai/inference/src/options/openai_client_options.rs
index 96f99e9355..b428727dae 100644
--- a/sdk/openai/inference/src/options/openai_client_options.rs
+++ b/sdk/openai/inference/src/options/openai_client_options.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use azure_core::ClientOptions;
 
 /// Options to be passed to [`OpenAIClient`](crate::clients::OpenAIClient).
diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs
index c68d7b8e7f..8cfb335b80 100644
--- a/sdk/openai/inference/src/options/service_version.rs
+++ b/sdk/openai/inference/src/options/service_version.rs
@@ -1,3 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 use async_trait::async_trait;
 use std::sync::Arc;
 

From 687858a512d1c277e7555e92b01cdf670d7a0a96 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 15:13:39 +0200
Subject: [PATCH 64/71] rebase new changes to feature/track2 and changed from auth->credentials mod

---
 sdk/openai/inference/examples/azure_chat_completions_aad.rs   | 6 +++---
 sdk/openai/inference/src/clients/azure_openai_client.rs       | 2 +-
 .../inference/src/credentials/azure_key_credential.rs         | 2 +-
 .../inference/src/credentials/openai_key_credential.rs        | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sdk/openai/inference/examples/azure_chat_completions_aad.rs b/sdk/openai/inference/examples/azure_chat_completions_aad.rs
index 9afab7c0af..4d187d4bfb 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_aad.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_aad.rs
@@ -1,7 +1,5 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
-use std::sync::Arc;
-
 use azure_identity::DefaultAzureCredentialBuilder;
 use azure_openai_inference::{
     clients::{AzureOpenAIClient, AzureOpenAIClientMethods, ChatCompletionsClientMethods},
@@ -16,7 +14,9 @@ async fn main() {
 
     let chat_completions_client = AzureOpenAIClient::new(
         endpoint,
-        Arc::new(DefaultAzureCredentialBuilder::new().build().unwrap()),
+        DefaultAzureCredentialBuilder::new()
+            .build()
+            .expect("Failed to create Azure credential"),
         Some(
             AzureOpenAIClientOptions::builder()
                 .with_api_version(AzureServiceVersion::V2023_12_01Preview)
diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index 6655677e99..60bd4ac7a4 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -5,7 +5,7 @@ use std::sync::Arc;
 
 use crate::credentials::{AzureKeyCredential, DEFAULT_SCOPE};
 use crate::options::AzureOpenAIClientOptions;
-use azure_core::auth::TokenCredential;
+use azure_core::credentials::TokenCredential;
 use azure_core::{self, Policy, Result};
 use azure_core::{BearerTokenCredentialPolicy, Url};
diff --git a/sdk/openai/inference/src/credentials/azure_key_credential.rs b/sdk/openai/inference/src/credentials/azure_key_credential.rs
index 0593c9cbbe..e137e197b7 100644
--- a/sdk/openai/inference/src/credentials/azure_key_credential.rs
+++ b/sdk/openai/inference/src/credentials/azure_key_credential.rs
@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use std::sync::Arc;
 
 use azure_core::{
-    auth::Secret,
+    credentials::Secret,
     headers::{HeaderName, HeaderValue},
     Context, Header, Policy, PolicyResult, Request,
 };
diff --git a/sdk/openai/inference/src/credentials/openai_key_credential.rs b/sdk/openai/inference/src/credentials/openai_key_credential.rs
index a95cb90eba..dce03c83f0 100644
--- a/sdk/openai/inference/src/credentials/openai_key_credential.rs
+++ b/sdk/openai/inference/src/credentials/openai_key_credential.rs
@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use std::sync::Arc;
 
 use azure_core::{
-    auth::Secret,
+    credentials::Secret,
     headers::{HeaderName, HeaderValue, AUTHORIZATION},
     Context, Header, Policy, PolicyResult, Request,
 };
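One side effect visible in patch 64: the AAD example passes the built credential straight to the client, with no `Arc::new` wrapper. One way such an API can work is for the builder to hand back a ready-made shared trait object; a minimal sketch of that shape (a toy `Credential` trait and builder, an assumption for illustration rather than azure_core's actual design):

use std::sync::Arc;

// A toy stand-in for a token credential trait.
trait Credential {
    fn token(&self) -> String;
}

struct StaticCredential(String);

impl Credential for StaticCredential {
    fn token(&self) -> String {
        self.0.clone()
    }
}

// A builder that already returns a shared trait object, so callers don't
// need to wrap the result in Arc::new themselves.
fn build_credential(secret: &str) -> Result<Arc<dyn Credential>, String> {
    if secret.is_empty() {
        return Err("empty secret".into());
    }
    Ok(Arc::new(StaticCredential(secret.to_string())))
}

struct Client {
    credential: Arc<dyn Credential>,
}

impl Client {
    fn new(credential: Arc<dyn Credential>) -> Self {
        Client { credential }
    }
}

fn main() {
    // The builder's output is already an Arc<dyn Credential>, so it can be
    // passed directly, mirroring the shape of the example change above.
    let client = Client::new(build_credential("s3cr3t").expect("Failed to create credential"));
    println!("{}", client.credential.token());
}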
From 611f716ae603b559d1c845ed8d096413dc4706d2 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 15:44:49 +0200
Subject: [PATCH 65/71] Addressed clippy warnings and errors

---
 .../src/credentials/azure_key_credential.rs          |  8 ++++----
 sdk/openai/inference/src/credentials/mod.rs          |  3 +--
 .../src/credentials/openai_key_credential.rs         |  6 +++---
 sdk/openai/inference/src/helpers/streaming.rs        |  2 +-
 sdk/openai/inference/src/options/mod.rs              |  4 ++--
 .../inference/src/options/service_version.rs         | 16 ++++++++--------
 6 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/sdk/openai/inference/src/credentials/azure_key_credential.rs b/sdk/openai/inference/src/credentials/azure_key_credential.rs
index e137e197b7..f39f560d2c 100644
--- a/sdk/openai/inference/src/credentials/azure_key_credential.rs
+++ b/sdk/openai/inference/src/credentials/azure_key_credential.rs
@@ -38,7 +38,7 @@ impl Header for AzureKeyCredential {
     }
 
     fn value(&self) -> HeaderValue {
-        HeaderValue::from_cow(format!("{}", self.0.secret()))
+        HeaderValue::from_cow(self.0.secret().to_string())
     }
 }
 
@@ -57,8 +57,8 @@ impl Policy for AzureKeyCredential {
     }
 }
 
-impl Into<Arc<dyn Policy>> for AzureKeyCredential {
-    fn into(self) -> Arc<dyn Policy> {
-        Arc::new(self)
+impl From<AzureKeyCredential> for Arc<dyn Policy> {
+    fn from(credential: AzureKeyCredential) -> Arc<dyn Policy> {
+        Arc::new(credential)
     }
 }
diff --git a/sdk/openai/inference/src/credentials/mod.rs b/sdk/openai/inference/src/credentials/mod.rs
index 5b81a8355c..5ace38ea0e 100644
--- a/sdk/openai/inference/src/credentials/mod.rs
+++ b/sdk/openai/inference/src/credentials/mod.rs
@@ -6,5 +6,4 @@ mod openai_key_credential;
 pub(crate) use azure_key_credential::*;
 pub(crate) use openai_key_credential::*;
 
-pub(crate) const DEFAULT_SCOPE: [&'static str; 1] =
-    ["https://cognitiveservices.azure.com/.default"];
+pub(crate) const DEFAULT_SCOPE: [&str; 1] = ["https://cognitiveservices.azure.com/.default"];
diff --git a/sdk/openai/inference/src/credentials/openai_key_credential.rs b/sdk/openai/inference/src/credentials/openai_key_credential.rs
index dce03c83f0..7bf09d0c1f 100644
--- a/sdk/openai/inference/src/credentials/openai_key_credential.rs
+++ b/sdk/openai/inference/src/credentials/openai_key_credential.rs
@@ -55,8 +55,8 @@ impl Policy for OpenAIKeyCredential {
     }
 }
 
-impl Into<Arc<dyn Policy>> for OpenAIKeyCredential {
-    fn into(self) -> Arc<dyn Policy> {
-        Arc::new(self)
+impl From<OpenAIKeyCredential> for Arc<dyn Policy> {
+    fn from(credential: OpenAIKeyCredential) -> Arc<dyn Policy> {
+        Arc::new(credential)
     }
 }
diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs
index 286aced5ab..c245530a41 100644
--- a/sdk/openai/inference/src/helpers/streaming.rs
+++ b/sdk/openai/inference/src/helpers/streaming.rs
@@ -56,7 +56,7 @@ pub(crate) fn string_chunks<'a>(
             None
         };
     }
-    if chunk_buffer.len() > 0 {
+    if !chunk_buffer.is_empty() {
         return Some((
             Err(Error::with_message(
                 azure_core::error::ErrorKind::DataConversion,
diff --git a/sdk/openai/inference/src/options/mod.rs b/sdk/openai/inference/src/options/mod.rs
index a66a5734dd..ed2a303a01 100644
--- a/sdk/openai/inference/src/options/mod.rs
+++ b/sdk/openai/inference/src/options/mod.rs
@@ -4,6 +4,6 @@ mod azure_openai_client_options;
 mod openai_client_options;
 mod service_version;
 
-pub use azure_openai_client_options::*;
-pub use openai_client_options::*;
+pub use azure_openai_client_options::{builders::*, AzureOpenAIClientOptions};
+pub use openai_client_options::{builders::*, OpenAIClientOptions};
 pub use service_version::AzureServiceVersion;
diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs
index 8cfb335b80..724e236876 100644
--- a/sdk/openai/inference/src/options/service_version.rs
+++ b/sdk/openai/inference/src/options/service_version.rs
@@ -1,7 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 use async_trait::async_trait;
-use std::sync::Arc;
+use std::{fmt::Display, sync::Arc};
 
 use azure_core::{Context, Policy, PolicyResult, Request};
 
@@ -36,13 +36,13 @@ impl From<AzureServiceVersion> for String {
             AzureServiceVersion::V2023_12_01Preview => "2023-12-01-preview",
             AzureServiceVersion::V2024_07_01Preview => "2024-07-01-preview",
         };
-        return String::from(as_str);
+        String::from(as_str)
     }
 }
 
-impl ToString for AzureServiceVersion {
-    fn to_string(&self) -> String {
-        String::from(self.clone())
+impl Display for AzureServiceVersion {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&String::from(self.clone()))
     }
 }
 
@@ -64,8 +64,8 @@ impl Policy for AzureServiceVersion {
     }
 }
 
-impl Into<Arc<dyn Policy>> for AzureServiceVersion {
-    fn into(self) -> Arc<dyn Policy> {
-        Arc::new(self)
+impl From<AzureServiceVersion> for Arc<dyn Policy> {
+    fn from(version: AzureServiceVersion) -> Arc<dyn Policy> {
+        Arc::new(version)
    }
 }
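The `impl Into` → `impl From` rewrites in patch 65 follow clippy's `from_over_into` lint: implementing `From` also yields the matching `Into` through the standard library's blanket impl, whereas a hand-written `Into` provides only one direction. A self-contained sketch of the pattern with a toy `Policy` trait (not azure_core's):

use std::sync::Arc;

// A toy pipeline-policy trait standing in for azure_core's Policy.
trait Policy {
    fn name(&self) -> &'static str;
}

struct ApiVersionPolicy;

impl Policy for ApiVersionPolicy {
    fn name(&self) -> &'static str {
        "api-version"
    }
}

// Implementing From<T> for Arc<dyn Policy> is the idiomatic direction;
// the reverse (impl Into<Arc<dyn Policy>> for T) would not give From back.
impl From<ApiVersionPolicy> for Arc<dyn Policy> {
    fn from(policy: ApiVersionPolicy) -> Arc<dyn Policy> {
        Arc::new(policy)
    }
}

fn main() {
    // The explicit From impl:
    let a: Arc<dyn Policy> = Arc::from(ApiVersionPolicy);
    // ...and the Into obtained for free via the blanket `impl<T, U: From<T>> Into<U> for T`:
    let b: Arc<dyn Policy> = ApiVersionPolicy.into();
    println!("{} {}", a.name(), b.name());
}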
From 8f82cd16fd61f7ff041282119e2cc5796cc793d4 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 16:15:52 +0200
Subject: [PATCH 66/71] Fixed tests in comments

---
 .../inference/src/clients/azure_openai_client.rs    |  3 +--
 sdk/openai/inference/src/models/chat_completions.rs | 10 +++++-----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index 60bd4ac7a4..e0765ffb5f 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -49,12 +49,11 @@ impl AzureOpenAIClient {
     /// ```no_run
     /// use azure_openai_inference::clients::{AzureOpenAIClient, AzureOpenAIClientMethods};
     /// use azure_identity::DefaultAzureCredentialBuilder;
-    /// use std::sync::Arc;
     ///
     /// let endpoint = std::env::var("AZURE_OPENAI_ENDPOINT").expect("Set AZURE_OPENAI_ENDPOINT environment variable");
     /// let client = AzureOpenAIClient::new(
     ///    endpoint,
-    ///    Arc::new(DefaultAzureCredentialBuilder::new().build().unwrap()),
+    ///    DefaultAzureCredentialBuilder::new().build().unwrap(),
     ///    None,
     /// ).unwrap();
     /// ```
diff --git a/sdk/openai/inference/src/models/chat_completions.rs b/sdk/openai/inference/src/models/chat_completions.rs
index 3f7e439ba9..9d11d55223 100644
--- a/sdk/openai/inference/src/models/chat_completions.rs
+++ b/sdk/openai/inference/src/models/chat_completions.rs
@@ -64,7 +64,7 @@ pub mod request {
     /// # Example
     ///
     /// ```rust
-    /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message(
+    /// let request = azure_openai_inference::CreateChatCompletionsRequest::with_user_message(
    ///     "gpt-3.5-turbo-1106",
     ///     "Why couldn't the eagles take Frodo directly to mount doom?");
     /// ```
@@ -81,7 +81,7 @@ pub mod request {
     /// # Example
     ///
     /// ```rust
-    /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_user_message_and_stream(
+    /// let request = azure_openai_inference::CreateChatCompletionsRequest::with_user_message_and_stream(
     ///     "gpt-3.5-turbo-1106",
     ///     "Why couldn't the eagles take Frodo directly to Mount Doom?");
     /// ```
@@ -101,11 +101,11 @@ pub mod request {
     ///
     /// # Example
     /// ```rust
-    /// let request = azure_openai_inference::request::CreateChatCompletionsRequest::with_messages(
+    /// let request = azure_openai_inference::CreateChatCompletionsRequest::with_messages(
     ///     "gpt-3.5-turbo-1106",
     ///     vec![
-    ///         azure_openai_inference::request::ChatCompletionRequestMessage::with_system_role("You are a good math tutor who explains things briefly."),
-    ///         azure_openai_inference::request::ChatCompletionRequestMessage::with_user_role("What is the value of 'x' in the equation: '2x + 3 = 11'?"),
+    ///         azure_openai_inference::ChatCompletionRequestMessage::with_system_role("You are a good math tutor who explains things briefly."),
+    ///         azure_openai_inference::ChatCompletionRequestMessage::with_user_role("What is the value of 'x' in the equation: '2x + 3 = 11'?"),
     ///     ]);
     pub fn with_messages(
         model: impl Into<String>,

From bc04acd91f2144491b770cf1e34dca88cf1d7acb Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 16:19:06 +0200
Subject: [PATCH 67/71] No longer holding reference to options in the base clients

---
 .../inference/src/clients/azure_openai_client.rs | 16 ++--------------
 .../inference/src/clients/openai_client.rs       |  8 +-------
 2 files changed, 3 insertions(+), 21 deletions(-)

diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index e0765ffb5f..ebef3cf4b5 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -29,10 +29,6 @@ pub struct AzureOpenAIClient {
 
     /// The pipeline for sending requests to the service.
     pipeline: azure_core::Pipeline,
-
-    /// The options for the client.
-    #[allow(dead_code)]
-    options: AzureOpenAIClientOptions,
 }
 
 impl AzureOpenAIClient {
@@ -72,11 +68,7 @@ impl AzureOpenAIClient {
 
         let pipeline = super::new_pipeline(per_call_policies, options.client_options.clone());
 
-        Ok(AzureOpenAIClient {
-            endpoint,
-            pipeline,
-            options,
-        })
+        Ok(AzureOpenAIClient { endpoint, pipeline })
     }
 
     /// Creates a new [`AzureOpenAIClient`] using a key credential
@@ -113,11 +105,7 @@ impl AzureOpenAIClient {
 
         let pipeline = super::new_pipeline(per_call_policies, options.client_options.clone());
 
-        Ok(AzureOpenAIClient {
-            endpoint,
-            pipeline,
-            options,
-        })
+        Ok(AzureOpenAIClient { endpoint, pipeline })
     }
 }
diff --git a/sdk/openai/inference/src/clients/openai_client.rs b/sdk/openai/inference/src/clients/openai_client.rs
index a8cad42bb1..7d0eeed3ce 100644
--- a/sdk/openai/inference/src/clients/openai_client.rs
+++ b/sdk/openai/inference/src/clients/openai_client.rs
@@ -18,8 +18,6 @@ pub trait OpenAIClientMethods {
 pub struct OpenAIClient {
     base_url: Url,
     pipeline: azure_core::Pipeline,
-    #[allow(dead_code)]
-    options: OpenAIClientOptions,
 }
 
 impl OpenAIClient {
@@ -46,11 +44,7 @@ impl OpenAIClient {
 
         let pipeline = super::new_pipeline(vec![auth_policy], options.client_options.clone());
 
-        Ok(OpenAIClient {
-            base_url,
-            pipeline,
-            options,
-        })
+        Ok(OpenAIClient { base_url, pipeline })
    }
 }
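Patch 67 above drops the retained options struct once the pipeline has been built from it, which also removes the `#[allow(dead_code)]` escape hatch. A minimal sketch of consuming configuration at construction time (toy types, not the crate's real pipeline):

#[derive(Clone, Default)]
struct ClientOptions {
    retries: u32,
}

// The "pipeline" derived from the options; the client keeps only this.
struct Pipeline {
    retries: u32,
}

struct Client {
    pipeline: Pipeline,
    // No `options` field: everything the client needs was baked into the
    // pipeline, so there is no dead state to suppress warnings over.
}

impl Client {
    fn new(options: ClientOptions) -> Self {
        let pipeline = Pipeline {
            retries: options.retries,
        };
        // `options` is dropped here; it was only an input to construction.
        Client { pipeline }
    }
}

fn main() {
    let client = Client::new(ClientOptions { retries: 3 });
    println!("pipeline retries: {}", client.pipeline.retries);
}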
From 4240214d7e719266da6c8f31db571241c44ccf4e Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 16:35:32 +0200
Subject: [PATCH 68/71] Broken struct link fixed

---
 sdk/openai/inference/src/options/service_version.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdk/openai/inference/src/options/service_version.rs b/sdk/openai/inference/src/options/service_version.rs
index 724e236876..02df7b5b25 100644
--- a/sdk/openai/inference/src/options/service_version.rs
+++ b/sdk/openai/inference/src/options/service_version.rs
@@ -6,7 +6,7 @@ use std::{fmt::Display, sync::Arc};
 use azure_core::{Context, Policy, PolicyResult, Request};
 
 /// The version of the Azure service to use.
-/// This enum is passed to the [`AzureOpenAIClientOptionsBuilder`](crate::builders::AzureOpenAIClientOptionsBuilder) to configure an [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient) to specify the version of the service to use.
+/// This enum is passed to the [`AzureOpenAIClientOptionsBuilder`](crate::AzureOpenAIClientOptionsBuilder) to configure an [`AzureOpenAIClient`](crate::clients::AzureOpenAIClient) to specify the version of the service to use.
 ///
 /// If no version is specified, the latest version will be used. See [`AzureServiceVersion::get_latest()`](AzureServiceVersion::get_latest).
 #[derive(Debug, Clone)]

From fe004b777d401aa07ef78e957c26677a357cd560 Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 16:43:04 +0200
Subject: [PATCH 69/71] Added openai to word list

---
 .vscode/cspell.json | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.vscode/cspell.json b/.vscode/cspell.json
index a22575878d..76ff9c6890 100644
--- a/.vscode/cspell.json
+++ b/.vscode/cspell.json
@@ -39,6 +39,7 @@
     "iothub",
     "keyvault",
     "msrc",
+    "openai",
     "pageable",
     "pkce",
     "pkcs",

From 314dc75676ba3ab58db89f27729c6bebd307774a Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 16:57:30 +0200
Subject: [PATCH 70/71] More spell checks

---
 .vscode/cspell.json                                      | 1 +
 .../inference/examples/azure_chat_completions_stream.rs | 2 +-
 sdk/openai/inference/examples/chat_completions_stream.rs | 2 +-
 sdk/openai/inference/src/clients/azure_openai_client.rs  | 2 +-
 sdk/openai/inference/src/clients/mod.rs                  | 2 +-
 sdk/openai/inference/src/helpers/streaming.rs            | 6 +++---
 6 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/.vscode/cspell.json b/.vscode/cspell.json
index 76ff9c6890..991ae2753a 100644
--- a/.vscode/cspell.json
+++ b/.vscode/cspell.json
@@ -32,6 +32,7 @@
     "downcasted",
     "downcasting",
     "entra",
+    "endregion",
     "etag",
     "eventhub",
     "eventhubs",
diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
index aaa2f24797..04ffd96fad 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
@@ -36,7 +36,7 @@ async fn main() {
         .await
         .unwrap();
 
-    // this pins the stream to the stack so it is safe to poll it (namely, it won't be dealloacted or moved)
+    // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-alloacted or moved)
     futures::pin_mut!(response);
 
     while let Some(result) = response.next().await {
diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs
index d7ffc6488b..f5de952120 100644
--- a/sdk/openai/inference/examples/chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/chat_completions_stream.rs
@@ -26,7 +26,7 @@ async fn main() {
         .await
         .unwrap();
 
-    // this pins the stream to the stack so it is safe to poll it (namely, it won't be dealloacted or moved)
+    // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-alloacted or moved)
     futures::pin_mut!(response);
 
     while let Some(result) = response.next().await {
diff --git a/sdk/openai/inference/src/clients/azure_openai_client.rs b/sdk/openai/inference/src/clients/azure_openai_client.rs
index ebef3cf4b5..e5e1004b84 100644
--- a/sdk/openai/inference/src/clients/azure_openai_client.rs
+++ b/sdk/openai/inference/src/clients/azure_openai_client.rs
@@ -75,7 +75,7 @@ impl AzureOpenAIClient {
     ///
     /// # Parameters
     /// * `endpoint` - The full URL of the Azure OpenAI resource endpoint.
-    /// * `secret` - The key creadential used for authentication. Passed as header parameter in the request.
+    /// * `secret` - The key credential used for authentication. Passed as header parameter in the request.
     /// * `client_options` - Optional configuration for the client. The [`AzureServiceVersion`](crate::options::AzureServiceVersion) can be provided here.
     ///
     /// # Example
diff --git a/sdk/openai/inference/src/clients/mod.rs b/sdk/openai/inference/src/clients/mod.rs
index 85699ada28..f892407cda 100644
--- a/sdk/openai/inference/src/clients/mod.rs
+++ b/sdk/openai/inference/src/clients/mod.rs
@@ -10,7 +10,7 @@ pub use azure_openai_client::{AzureOpenAIClient, AzureOpenAIClientMethods};
 pub use chat_completions_client::{ChatCompletionsClient, ChatCompletionsClientMethods};
 pub use openai_client::{OpenAIClient, OpenAIClientMethods};
 
-/// A trait that defines the common behaviour expected from an [`OpenAIClient`] and an [`AzureOpenAIClient`].
+/// A trait that defines the common behavior expected from an [`OpenAIClient`] and an [`AzureOpenAIClient`].
 /// This trait will be used as a boxed types for any clients such as [`ChatCompletionsClient`] so they issue HTTP requests.
 trait BaseOpenAIClientMethods {
     /// Returns the base [`Url`] of the underlying client.
diff --git a/sdk/openai/inference/src/helpers/streaming.rs b/sdk/openai/inference/src/helpers/streaming.rs
index c245530a41..9f234acb22 100644
--- a/sdk/openai/inference/src/helpers/streaming.rs
+++ b/sdk/openai/inference/src/helpers/streaming.rs
@@ -45,7 +45,7 @@ pub(crate) fn string_chunks<'a>(
 
             return if let Ok(yielded_value) = std::str::from_utf8(&bytes) {
                 // We strip the "data: " portion of the event. The rest is always JSON and will be deserialized
-                // by a subsquent mapping function for this stream
+                // by a subsequent mapping function for this stream
                 let yielded_value = yielded_value.trim_start_matches("data:").trim();
                 if yielded_value == "[DONE]" {
                     return None;
@@ -95,7 +95,7 @@ pub(crate) fn string_chunks<'a>(
         },
     );
 
-    // We specifically allow the Error::with_messagge(ErrorKind::DataConversion, || "Incomplete chunk")
+    // We specifically allow the Error::with_message(ErrorKind::DataConversion, || "Incomplete chunk")
     // So that we are able to continue pushing bytes to the buffer until we find the next delimiter
     return stream.filter(|it| {
         std::future::ready(
@@ -186,7 +186,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn event_delimeter_split_across_chunks() {
+    async fn event_delimiter_split_across_chunks() {
         let mut source_stream = futures::stream::iter(vec![
             Ok(bytes::Bytes::from_static(b"data: piece 1\n")),
             Ok(bytes::Bytes::from_static(b"\ndata: [DONE]")),
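The comment these spelling patches keep touching describes a real constraint: a stream returned as an opaque `impl Stream` may be `!Unpin`, and `StreamExt::next` requires `Unpin`, so the stream must be pinned before polling. A self-contained sketch of `futures::pin_mut!` doing exactly that (toy local stream, no network):

use futures::{stream, Stream, StreamExt};

// `stream::unfold` holds an async-block future internally, so the returned
// stream is !Unpin; calling `.next()` without pinning would not compile.
fn numbers() -> impl Stream<Item = u32> {
    stream::unfold(0u32, |n| async move {
        if n < 3 {
            // Yield (item, next_state).
            Some((n + 1, n + 1))
        } else {
            None
        }
    })
}

fn main() {
    futures::executor::block_on(async {
        let response = numbers();
        // Pin the stream to the stack so it is safe to poll
        // (it can no longer be moved or dropped while borrowed).
        futures::pin_mut!(response);

        while let Some(n) = response.next().await {
            println!("got {n}");
        }
    });
}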
From d148db8c536296058f9c8d439bf1faab7cee2bda Mon Sep 17 00:00:00 2001
From: Jose Alvarez
Date: Mon, 30 Sep 2024 17:27:45 +0200
Subject: [PATCH 71/71] more spell checks

---
 sdk/openai/inference/examples/azure_chat_completions_stream.rs | 2 +-
 sdk/openai/inference/examples/chat_completions_stream.rs       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sdk/openai/inference/examples/azure_chat_completions_stream.rs b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
index 04ffd96fad..233326b1f3 100644
--- a/sdk/openai/inference/examples/azure_chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/azure_chat_completions_stream.rs
@@ -36,7 +36,7 @@ async fn main() {
         .await
         .unwrap();
 
-    // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-alloacted or moved)
+    // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-allocated or moved)
     futures::pin_mut!(response);
 
     while let Some(result) = response.next().await {
diff --git a/sdk/openai/inference/examples/chat_completions_stream.rs b/sdk/openai/inference/examples/chat_completions_stream.rs
index f5de952120..3058d3e0bb 100644
--- a/sdk/openai/inference/examples/chat_completions_stream.rs
+++ b/sdk/openai/inference/examples/chat_completions_stream.rs
@@ -26,7 +26,7 @@ async fn main() {
         .await
         .unwrap();
 
-    // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-alloacted or moved)
+    // this pins the stream to the stack so it is safe to poll it (namely, it won't be de-allocated or moved)
     futures::pin_mut!(response);
 
     while let Some(result) = response.next().await {