Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 4 additions & 17 deletions src/claude_web_state/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ impl ClaudeWebState {
warn!("Failed to decode image: {}", e);
})
.ok()?;
// choose the file name based on the media type
let file_name = match img.media_type.to_lowercase().as_str() {
// choose the file name based on the media type (extract main type before any params)
let main_type = img.media_type.split(';').next().unwrap_or(&img.media_type);
let file_name = match main_type.to_lowercase().as_str() {
"image/png" => "image.png",
"image/jpeg" => "image.jpg",
"image/jpg" => "image.jpg",
Expand Down Expand Up @@ -170,7 +171,7 @@ fn merge_messages(msgs: Vec<Message>, system: String) -> Option<Merged> {
}
ContentBlock::ImageUrl { image_url } => {
// oai image
if let Some(source) = extract_image_from_url(&image_url.url) {
if let Some(source) = ImageSource::from_data_url(&image_url.url) {
imgs.push(source);
}
None
Expand Down Expand Up @@ -253,17 +254,3 @@ fn merge_system(sys: Value) -> String {
}
}

fn extract_image_from_url(url: &str) -> Option<ImageSource> {
if !url.starts_with("data:") {
return None; // only support data URI
}
let (metadata, base64_data) = url.split_once(',')?;

let (media_type, type_) = metadata.strip_prefix("data:")?.split_once(';')?;

Some(ImageSource {
type_: type_.to_string(),
media_type: media_type.to_string(),
data: base64_data.to_owned(),
})
}
32 changes: 32 additions & 0 deletions src/types/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,38 @@ pub struct ImageSource {
pub data: String,
}

impl ImageSource {
/// Parse a data URI into an ImageSource
/// Supports format: data:<media_type>[;params];base64,<data>
/// e.g., data:image/png;base64,iVBORw0KGgo...
/// e.g., data:image/png;name=foo;base64,iVBORw0KGgo...
pub fn from_data_url(url: &str) -> Option<Self> {
if !url.starts_with("data:") {
return None;
}
let (metadata, base64_data) = url.split_once(',')?;
// reject empty data
if base64_data.is_empty() {
return None;
}
let after_data = metadata.strip_prefix("data:")?;
// find the last ";base64" marker (case-insensitive)
let lower = after_data.to_lowercase();
let base64_pos = lower.rfind(";base64")?;
let media_type = &after_data[..base64_pos];
// reject empty media type
if media_type.is_empty() {
return None;
}

Some(Self {
type_: "base64".to_string(),
media_type: media_type.to_string(),
data: base64_data.to_owned(),
})
}
}

// oai image
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct ImageUrl {
Expand Down
40 changes: 38 additions & 2 deletions src/types/oai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,39 @@ use serde_json::{Value, json};
use tiktoken_rs::o200k_base;

use super::claude::{CreateMessageParams as ClaudeCreateMessageParams, *};
use crate::types::claude::Message;
use crate::types::claude::{ImageSource, Message};

/// Convert OAI ImageUrl to Claude Image format
fn normalize_block(block: ContentBlock) -> Option<ContentBlock> {
match block {
ContentBlock::Text { .. } => Some(block),
ContentBlock::Image { .. } => Some(block),
ContentBlock::ImageUrl { image_url } => {
ImageSource::from_data_url(&image_url.url).map(|source| ContentBlock::Image { source })
}
_ => Some(block),
}
}

/// Normalize all blocks in a message content
/// Returns None if the message becomes empty after filtering
fn normalize_message(msg: Message) -> Option<Message> {
let content = match msg.content {
MessageContent::Blocks { content } => {
let blocks: Vec<_> = content.into_iter().filter_map(normalize_block).collect();
// skip empty messages
if blocks.is_empty() {
return None;
}
MessageContent::Blocks { content: blocks }
}
other => other,
};
Some(Message {
role: msg.role,
content,
})
}

#[derive(Debug, Serialize, Deserialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
Expand All @@ -20,17 +52,21 @@ impl From<CreateMessageParams> for ClaudeCreateMessageParams {
.messages
.into_iter()
.partition(|m| m.role == Role::System);
// normalize system blocks (convert ImageUrl to Image)
let systems = systems
.into_iter()
.map(|m| m.content)
.flat_map(|c| match c {
MessageContent::Text { content } => vec![ContentBlock::Text { text: content }],
MessageContent::Blocks { content } => content,
})
.filter(|b| matches!(b, ContentBlock::Text { .. }))
.filter_map(normalize_block)
.filter(|b| matches!(b, ContentBlock::Text { .. } | ContentBlock::Image { .. }))
.map(|b| json!(b))
.collect::<Vec<_>>();
let system = (!systems.is_empty()).then(|| json!(systems));
// normalize messages (convert ImageUrl to Image, skip empty messages)
let messages = messages.into_iter().filter_map(normalize_message).collect();
Self {
max_tokens: (params.max_tokens.or(params.max_completion_tokens))
.unwrap_or_else(default_max_tokens),
Expand Down
Loading