Skip to content

Commit

Permalink
Rough sketch of Workers AI binding. Relates to #417
Browse files Browse the repository at this point in the history
  • Loading branch information
kflansburg committed Mar 22, 2024
1 parent 3db61a4 commit 0af3e61
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions worker/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ async-trait.workspace = true
bytes = "1.5"
chrono.workspace = true
chrono-tz.workspace = true
flate2 = "1.0"
futures-channel.workspace = true
futures-util.workspace = true
wasm-bindgen.workspace = true
Expand Down
98 changes: 98 additions & 0 deletions worker/src/ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use crate::{Headers, RequestInit, Result};
use flate2::read::GzDecoder;
#[cfg(feature = "http")]
use std::convert::TryInto;

/// Input payload for a text-embedding model run.
///
/// Used as [`ModelKind::Input`] for [`TextEmbeddingModel`].
#[derive(Serialize)]
pub struct TextEmbeddingInput {
/// Texts to embed; serialized as a JSON array under the `text` key.
pub text: Vec<String>,
}

/// Output of a text-embedding model run, deserialized from the response body.
///
/// Used as [`ModelKind::Output`] for [`TextEmbeddingModel`].
#[derive(Deserialize, Debug)]
pub struct TextEmbeddingOutput {
/// Dimensions reported alongside `data`.
/// NOTE(review): presumably `[number of inputs, embedding width]` — confirm
/// against the model's response.
pub shape: Vec<usize>,
/// Embedding values.
/// NOTE(review): presumably one inner vector per input text — confirm.
pub data: Vec<Vec<f64>>,
}

/// Marker type that selects the text-embedding input/output pair when passed
/// as the `M` type parameter of [`Ai::run`].
pub struct TextEmbeddingModel;

// Binds the text-embedding marker type to its request and response payloads.
impl ModelKind for TextEmbeddingModel {
type Input = TextEmbeddingInput;
type Output = TextEmbeddingOutput;
}

/// Describes a family of models by the payload types exchanged with it.
///
/// Implementors are used as the type parameter of [`Ai::run`], which
/// serializes `Input` into the request and deserializes the response into
/// `Output`.
pub trait ModelKind {
/// Request payload; must be serializable.
type Input: Serialize;
/// Response payload; must be deserializable without borrowing from the
/// response buffer.
type Output: DeserializeOwned;
}

/// Object to interact with Workers AI binding.
///
/// Obtain one with `Env::ai`, then call [`Ai::run`] with a [`ModelKind`]
/// marker type, a model identifier and the matching input.
///
/// ```rust
/// let ai = env.ai("AI")?;
///
/// let input = TextEmbeddingInput {
///     text: vec!["This is a test embedding".to_owned()],
/// };
///
/// let result = ai
///     .run::<TextEmbeddingModel>("@cf/baai/bge-base-en-v1.5", input)
///     .await?;
/// ```
pub struct Ai {
// Fetcher for the binding; every model run is sent through it.
inner: crate::Fetcher,
}

/// Options serialized under the `options` key of every [`AiRequest`].
#[derive(Serialize)]
struct AiRequestOptions {
// Serialized as the `debug` flag; `Ai::run` always sends `false`.
debug: bool,
}

/// JSON request body sent by `Ai::run`: the model inputs plus request options.
#[derive(Serialize)]
struct AiRequest<Input: serde::Serialize> {
// Model-specific payload, serialized under the `inputs` key.
inputs: Input,
// Per-request options, serialized under the `options` key.
options: AiRequestOptions,
}

impl Ai {
    /// Wraps the `Fetcher` resolved for the AI binding.
    pub(crate) fn new(inner: crate::Fetcher) -> Self {
        Self { inner }
    }

    /// Runs `model` on `inputs` and returns the deserialized model output.
    ///
    /// The inputs are serialized to JSON together with the request options and
    /// POSTed through the binding's fetcher. The response body is treated as
    /// gzip-compressed JSON and decoded into `M::Output`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be serialized or sent, if the
    /// response body cannot be read, or if the body is not valid gzip data
    /// containing JSON that matches `M::Output`. Decompression failures are
    /// reported as errors instead of panicking.
    pub async fn run<M>(&self, model: &str, inputs: M::Input) -> Result<M::Output>
    where
        M: ModelKind,
    {
        let request = AiRequest::<M::Input> {
            inputs,
            options: AiRequestOptions { debug: false },
        };
        let payload = serde_json::to_string(&request)?;
        let mut init = RequestInit::new();
        init.with_body(Some(payload.into()));
        let mut headers = Headers::new();
        // NOTE(review): `application/json` is a media type, which would
        // normally be sent as `content-type`. The header name is kept as-is
        // because it may be what the binding's wire protocol expects; confirm
        // before changing it.
        headers.append("content-encoding", "application/json")?;
        headers.append("cf-consn-model-id", model)?;
        init.with_headers(headers);
        init.with_method(crate::Method::Post);
        let response = self
            .inner
            .fetch("http://workers-binding.ai/run?version=2", Some(init))
            .await?;

        // With the `http` feature the fetcher yields an `http` response that
        // has to be converted back into a `worker::Response` first.
        #[cfg(feature = "http")]
        let mut resp: crate::Response = response.try_into()?;
        #[cfg(not(feature = "http"))]
        let mut resp = response;

        let data = resp.bytes().await?;
        // Decode straight from the gzip stream. Any I/O error raised while
        // inflating is wrapped in the `serde_json::Error` and propagated by
        // `?`, so a malformed or uncompressed body no longer aborts the Worker.
        // The `BufReader` avoids byte-at-a-time reads from the decompressor.
        let decoder = std::io::BufReader::new(GzDecoder::new(&data[..]));
        let body: M::Output = serde_json::from_reader(decoder)?;
        Ok(body)
    }
}
4 changes: 4 additions & 0 deletions worker/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ impl Env {
}
}

/// Access a Workers AI binding by the name it was given in the Worker's
/// configuration. The binding is looked up as a `Fetcher` and wrapped in
/// [`crate::Ai`].
pub fn ai(&self, binding: &str) -> Result<crate::Ai> {
    let fetcher = self.get_binding::<Fetcher>(binding)?;
    Ok(crate::Ai::new(fetcher))
}

/// Access Secret value bindings added to your Worker via the UI or `wrangler`:
/// <https://developers.cloudflare.com/workers/cli-wrangler/commands#secret>
pub fn secret(&self, binding: &str) -> Result<Secret> {
Expand Down
2 changes: 2 additions & 0 deletions worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub use worker_sys;
pub use worker_sys::{console_debug, console_error, console_log, console_warn};

pub use crate::abort::*;
pub use crate::ai::*;
pub use crate::cache::{Cache, CacheDeletionOutcome, CacheKey};
pub use crate::context::Context;
pub use crate::cors::Cors;
Expand Down Expand Up @@ -104,6 +105,7 @@ pub use crate::streams::*;
pub use crate::websocket::*;

mod abort;
mod ai;
mod cache;
mod cf;
mod context;
Expand Down

0 comments on commit 0af3e61

Please sign in to comment.