@@ -1,11 +1,11 @@
 /// HTTP Server logic
 use crate::http::types::{
     DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
-    EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, OpenAICompatEmbedding,
+    EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, InputType, OpenAICompatEmbedding,
     OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
     PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
-    Sequence, SimpleToken, SparseValue, TokenizeRequest, TokenizeResponse, VertexPrediction,
-    VertexRequest, VertexResponse,
+    Sequence, SimpleToken, SparseValue, TokenizeInput, TokenizeRequest, TokenizeResponse,
+    VertexPrediction, VertexRequest, VertexResponse,
 };
 use crate::{
     shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -474,7 +474,7 @@ async fn embed(
         Input::Single(input) => {
             metrics::increment_counter!("te_request_count", "method" => "single");

-            let compute_chars = input.chars().count();
+            let compute_chars = input.count_chars();

             let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
             let response = infer
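Throughout the embed handlers, here and in the matching hunks for `embed_sparse`, `embed_all`, and `openai_embed` below, `input.chars().count()` becomes `input.count_chars()`: with the new `InputType` import, `input` is no longer a plain `String` but may also carry pre-tokenized ids. A minimal sketch of what the type could look like, inferred from its use in this diff (the real definition lives in `crate::http::types`, and its exact derives and variants may differ):

```rust
use serde::Deserialize;
use utoipa::ToSchema;

/// Sketch: either raw text or pre-tokenized ids. Untagged, so `"hello"`
/// and `[101, 102]` both deserialize from the same JSON field.
#[derive(Deserialize, ToSchema, Debug)]
#[serde(untagged)]
pub(crate) enum InputType {
    String(String),
    Ids(Vec<u32>),
}

impl InputType {
    /// Unified "size" metric for the `te_request_*` counters:
    /// character count for text, id count for pre-tokenized input.
    pub(crate) fn count_chars(&self) -> usize {
        match self {
            InputType::String(s) => s.chars().count(),
            InputType::Ids(ids) => ids.len(),
        }
    }
}
```

Keeping the size metric behind one method lets every handler stay agnostic to which wire format the client chose.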
@@ -529,7 +529,7 @@ async fn embed(
             let mut compute_chars = 0;

             for input in inputs {
-                compute_chars += input.chars().count();
+                compute_chars += input.count_chars();

                 let local_infer = infer.clone();
                 futures.push(async move {
@@ -630,7 +630,7 @@ async fn embed_sparse(
         Input::Single(input) => {
             metrics::increment_counter!("te_request_count", "method" => "single");

-            let compute_chars = input.chars().count();
+            let compute_chars = input.count_chars();

             let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
             let response = infer
@@ -685,7 +685,7 @@ async fn embed_sparse(
             let mut compute_chars = 0;

             for input in inputs {
-                compute_chars += input.chars().count();
+                compute_chars += input.count_chars();

                 let local_infer = infer.clone();
                 futures.push(async move {
@@ -778,7 +778,7 @@ async fn embed_all(
         Input::Single(input) => {
             metrics::increment_counter!("te_request_count", "method" => "single");

-            let compute_chars = input.chars().count();
+            let compute_chars = input.count_chars();

            let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
             let response = infer
@@ -833,7 +833,7 @@ async fn embed_all(
             let mut compute_chars = 0;

             for input in inputs {
-                compute_chars += input.chars().count();
+                compute_chars += input.count_chars();

                 let local_infer = infer.clone();
                 futures.push(async move {
@@ -892,7 +892,7 @@ async fn embed_all(
 #[utoipa::path(
 post,
 tag = "Text Embeddings Inference",
-path = "/embeddings",
+path = "/v1/embeddings",
 request_body = OpenAICompatRequest,
 responses(
 (status = 200, description = "Embeddings", body = OpenAICompatResponse),
@@ -923,7 +923,7 @@ async fn openai_embed(
         Input::Single(input) => {
             metrics::increment_counter!("te_request_count", "method" => "single");

-            let compute_chars = input.chars().count();
+            let compute_chars = input.count_chars();

             let permit = infer.try_acquire_permit().map_err(ErrorResponse::from)?;
             let response = infer
@@ -982,7 +982,7 @@ async fn openai_embed(
             let mut compute_chars = 0;

             for input in inputs {
-                compute_chars += input.chars().count();
+                compute_chars += input.count_chars();

                 let local_infer = infer.clone();
                 futures.push(async move {
@@ -1107,8 +1107,10 @@ async fn tokenize(
     };

     let tokens = match req.inputs {
-        Input::Single(input) => vec![tokenize_inner(input, req.add_special_tokens, infer.0).await?],
-        Input::Batch(inputs) => {
+        TokenizeInput::Single(input) => {
+            vec![tokenize_inner(input, req.add_special_tokens, infer.0).await?]
+        }
+        TokenizeInput::Batch(inputs) => {
             if inputs.is_empty() {
                 let message = "`inputs` cannot be empty".to_string();
                 tracing::error!("{message}");
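`tokenize` now matches on a dedicated `TokenizeInput` instead of the shared `Input`, and the single-input arm gains a block body. A plausible shape for the enum, assuming the tokenize endpoint mirrors the single/batch split of `Input` (the variant payloads are an assumption, not shown in this diff):

```rust
use serde::Deserialize;
use utoipa::ToSchema;

/// Sketch: single/batch wrapper for /tokenize, decoupled from `Input`
/// so the two endpoints can evolve their payloads independently.
#[derive(Deserialize, ToSchema)]
#[serde(untagged)]
pub(crate) enum TokenizeInput {
    Single(String),
    Batch(Vec<String>),
}
```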
@@ -1369,9 +1371,11 @@ pub async fn run(
             EmbedResponse,
             ErrorResponse,
             OpenAICompatErrorResponse,
+            TokenizeInput,
             TokenizeRequest,
             TokenizeResponse,
             SimpleToken,
+            InputType,
             InputIds,
             DecodeRequest,
             DecodeResponse,
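Registering `TokenizeInput` and `InputType` under `components(schemas(...))` keeps the generated OpenAPI spec in sync with the new request shapes. Under the sketched `InputType` above, the two accepted JSON forms would behave like this (a hypothetical test, not part of the commit):

```rust
#[cfg(test)]
mod input_type_tests {
    use super::InputType;

    #[test]
    fn accepts_text_and_token_ids() {
        // Raw text: counted in characters.
        let text: InputType = serde_json::from_str(r#""Deep Learning""#).unwrap();
        assert_eq!(text.count_chars(), 13);

        // Pre-tokenized ids: counted per id.
        let ids: InputType = serde_json::from_str("[101, 7592, 102]").unwrap();
        assert_eq!(ids.count_chars(), 3);
    }
}
```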
@@ -1448,6 +1452,7 @@ pub async fn run(
         .route("/decode", post(decode))
         // OpenAI compat route
         .route("/embeddings", post(openai_embed))
+        .route("/v1/embeddings", post(openai_embed))
         // Vertex compat route
         .route("/vertex", post(vertex_compatibility))
         // Base Health route
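With this route, the OpenAI-compatible handler is reachable both at the legacy `/embeddings` path and at `/v1/embeddings`, the prefix OpenAI clients default to, matching the updated `utoipa::path` annotation above. A quick smoke test against a locally running server might look like this (hypothetical host, port, and payload; assumes `reqwest` with the `json` feature, `serde_json`, and `tokio` as dependencies):

```rust
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Same handler behind both paths, so either URL should return an
    // OpenAICompatResponse with a `data[0].embedding` array.
    let response = reqwest::Client::new()
        .post("http://localhost:8080/v1/embeddings")
        .json(&serde_json::json!({
            "input": "What is Deep Learning?",
            "model": "tei" // field accepted for OpenAI-client compatibility
        }))
        .send()
        .await?;
    println!("{}", response.text().await?);
    Ok(())
}
```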