feat(router): Add max_waiting_tokens
OlivierDehaene committed Oct 21, 2022
Parent: 895a341 · Commit: c837893
Showing 7 changed files with 121 additions and 79 deletions.
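
In short: the time-based max_waiting_time criterion (a number of seconds after which undersized batches were admitted anyway) is replaced by max_waiting_tokens, a counter of decode steps since requests were last concatenated into the running batch. Once the counter reaches max_waiting_tokens, the router onboards whatever is queued regardless of batch size; before that, only batches meeting a minimum size are added. Short, hedged Rust sketches are placed after several of the file diffs below to illustrate the new behaviour.
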
launcher/src/main.rs: 6 additions, 6 deletions
@@ -28,8 +28,8 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server", long, env)]
@@ -41,7 +41,7 @@ struct Args {
}

fn main() -> ExitCode {
- tracing_subscriber::fmt::init();
+ tracing_subscriber::fmt().compact().with_ansi(false).init();

// Pattern match configuration
let Args {
@@ -51,7 +51,7 @@ fn main() -> ExitCode {
max_concurrent_requests,
max_input_length,
max_batch_size,
- max_waiting_time,
+ max_waiting_tokens,
port,
shard_uds_path,
master_addr,
@@ -148,8 +148,8 @@ fn main() -> ExitCode {
&max_input_length.to_string(),
"--max-batch-size",
&max_batch_size.to_string(),
"--max-waiting-time",
&max_waiting_time.to_string(),
"--max-waiting-tokens",
&max_waiting_tokens.to_string(),
"--port",
&port.to_string(),
"--master-shard-uds-path",
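
A minimal sketch of what the launcher-side change amounts to: the value parsed by clap is forwarded verbatim to the router process as --max-waiting-tokens, next to the other flags. Only the flag names and the defaults (32, 20, 3000) come from the diff above; the binary name and the helper itself are illustrative assumptions.

    use std::process::{Child, Command};

    // Hypothetical helper mirroring the launcher's argument forwarding.
    // The binary name "text-generation-router" is an assumption, not taken from this diff.
    fn spawn_router(max_batch_size: usize, max_waiting_tokens: usize, port: u16) -> std::io::Result<Child> {
        Command::new("text-generation-router")
            .arg("--max-batch-size")
            .arg(max_batch_size.to_string())
            .arg("--max-waiting-tokens")
            .arg(max_waiting_tokens.to_string())
            .arg("--port")
            .arg(port.to_string())
            .spawn()
    }

    fn main() -> std::io::Result<()> {
        // Defaults taken from the clap attributes above.
        let mut router = spawn_router(32, 20, 3000)?;
        router.wait()?;
        Ok(())
    }
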
router/src/batcher.rs: 36 additions, 23 deletions
@@ -5,7 +5,6 @@ use axum::http::StatusCode;
use bloom_inference_client::{Batch, ClientError, GeneratedText, ShardedClient};
use std::future::Future;
use std::sync::Arc;
- use std::time::Duration;
use thiserror::Error;
use tokio::sync::{oneshot, Notify};
use tokio::time::Instant;
@@ -30,7 +29,7 @@ impl Batcher {
pub(crate) fn new(
client: ShardedClient,
max_batch_size: usize,
- max_waiting_time: Duration,
+ max_waiting_tokens: usize,
) -> Self {
// Batcher shared state
let db = Db::new();
@@ -41,7 +40,7 @@
// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
max_batch_size,
- max_waiting_time,
+ max_waiting_tokens,
client,
db.clone(),
shared.clone(),
@@ -55,7 +54,7 @@
&self,
input_length: usize,
request: GenerateRequest,
- ) -> Result<String, InferError> {
+ ) -> Result<InferResponse, InferError> {
// One shot channel to communicate with the background batching task
let (response_tx, response_rx) = oneshot::channel();

@@ -65,6 +64,7 @@
response_tx,
input_length,
time: Instant::now(),
+ batch_time: None,
});

// Notify the background task that we have a new entry in the database that needs
Expand All @@ -87,7 +87,7 @@ impl Batcher {
#[instrument(skip(client, db, shared))]
async fn batching_task(
max_batch_size: usize,
- max_waiting_time: Duration,
+ max_waiting_tokens: usize,
client: ShardedClient,
db: Db,
shared: Arc<Shared>,
@@ -103,8 +103,10 @@ async fn batching_task(
// Get the next batch from the DB
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the DB
- if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size, None) {
+ let mut waiting_tokens = 0;
+ if let Some((request_ids, batch)) = db.next_batch(None, max_batch_size) {
let mut cached_batch = wrap_future(client.generate(batch), request_ids, &db).await;
+ waiting_tokens += 1;

// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
@@ -116,10 +118,20 @@

// If the current batch is too small, we try to add more requests to it
if batch_size <= limit_min_batch_size {
- // Get the next batch from the DB that meet our minimum size criteria
+ let min_size = match waiting_tokens {
+ // If we didn't onboard any new requests since >= max_waiting_tokens, we try
+ // to add a new batch even though its size might be small
+ _ if waiting_tokens >= max_waiting_tokens => None,
+ // Minimum size criteria
+ _ => Some(limit_min_batch_size as usize),
+ };
+
+ // Try to get a new batch
if let Some((new_request_ids, new_batch)) =
- db.next_batch(Some(limit_min_batch_size as usize), max_batch_size, None)
+ db.next_batch(min_size, max_batch_size)
{
+ // Reset waiting counter
+ waiting_tokens = 0;
// Generate one token for this new batch to have the attention past in cache
let new_cached_batch =
wrap_future(client.generate(new_batch), new_request_ids, &db).await;
@@ -129,24 +141,11 @@
batches.push(new_cached_batch);
}
}
- // If we don't have enough requests to meet the minimum size criteria, we
- // try to get the next batch from the DB that have been waiting over
- // the max_waiting_time
- else if let Some((new_request_ids, new_batch)) =
- db.next_batch(None, max_batch_size, Some(max_waiting_time))
- {
- let new_cached_batch =
- wrap_future(client.generate(new_batch), new_request_ids, &db).await;
- // Extend current batch with the new batch
- if let Some(new_cached_batch) = new_cached_batch {
- request_ids.extend(new_cached_batch.requests.iter().map(|req| req.id));
- batches.push(new_cached_batch);
- }
- }
}

cached_batch =
wrap_future(client.generate_with_cache(batches), request_ids, &db).await;
+ waiting_tokens += 1;
}
}
}
@@ -188,11 +187,25 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
let entry = db
.remove(&output.request.unwrap().id)
.expect("ID not found in db. This is a bug.");
+ let response = InferResponse {
+ output: output.output,
+ queued: entry.time,
+ start: entry.batch_time.unwrap(), // unwrap is always valid
+ end: Instant::now(),
+ };
// unwrap_or is valid here as we don't care if the receiver is gone.
- entry.response_tx.send(Ok(output.output)).unwrap_or(());
+ entry.response_tx.send(Ok(response)).unwrap_or(());
});
}

+ #[derive(Debug)]
+ pub(crate) struct InferResponse {
+ pub(crate) output: String,
+ pub(crate) queued: Instant,
+ pub(crate) start: Instant,
+ pub(crate) end: Instant,
+ }
+
#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
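
The heart of the change is the decision of when an undersized running batch may be extended. Here is a small, self-contained sketch of that heuristic; the standalone function, its name, and the example values are illustrative, while the waiting_tokens / max_waiting_tokens comparison and the Option<usize> minimum-size contract come from the diff above.

    /// Minimum size a queued batch must reach before it is concatenated to the
    /// running batch. `None` means "take whatever is waiting".
    fn min_size_for_new_batch(
        waiting_tokens: usize,
        max_waiting_tokens: usize,
        limit_min_batch_size: usize,
    ) -> Option<usize> {
        if waiting_tokens >= max_waiting_tokens {
            // The running requests already waited max_waiting_tokens decode steps
            // without onboarding anyone: accept any batch size, even a single request.
            None
        } else {
            // Otherwise only pay the extra prefill pass if enough requests are queued.
            Some(limit_min_batch_size)
        }
    }

    fn main() {
        // With max_waiting_tokens = 20 and an illustrative minimum of 16 requests:
        assert_eq!(min_size_for_new_batch(3, 20, 16), Some(16)); // keep waiting
        assert_eq!(min_size_for_new_batch(20, 20, 16), None);    // onboard anything queued
    }

Roughly, a larger --max-waiting-tokens favours throughput (fewer small prefill passes interrupting decoding of the running batch), while a smaller value bounds how long newly queued requests can wait before being onboarded.
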
router/src/db.rs: 10 additions, 18 deletions
@@ -1,10 +1,10 @@
+ use crate::InferResponse;
/// This code is massively inspired by Tokio mini-redis
use crate::{GenerateParameters, GenerateRequest};
use bloom_inference_client::{Batch, ClientError, LogitsWarperParameters, Request};
use parking_lot::Mutex;
use std::collections::BTreeMap;
use std::sync::Arc;
- use std::time::Duration;
use tokio::sync::oneshot::Sender;
use tokio::time::Instant;

@@ -14,11 +14,13 @@ pub(crate) struct Entry {
/// Request
pub request: GenerateRequest,
/// Response sender to communicate between the Batcher and the batching_task
- pub response_tx: Sender<Result<String, ClientError>>,
+ pub response_tx: Sender<Result<InferResponse, ClientError>>,
/// Number of tokens in the input
pub input_length: usize,
/// Instant when this entry was created
pub time: Instant,
+ /// Instant when this entry was added to a batch
+ pub batch_time: Option<Instant>,
}

/// Request Database
@@ -51,11 +53,7 @@ struct State {

impl State {
/// Get the next requests
- fn next_requests(
- &self,
- max_size: usize,
- min_waiting_time: Option<Duration>,
- ) -> Option<(Vec<u64>, Vec<Request>)> {
+ fn next_requests(&self, max_size: usize) -> Option<(Vec<u64>, Vec<Request>)> {
// Iterates for max_size over the BTreemap starting from next_batch_start_id
let mut requests = Vec::new();
let mut ids = Vec::new();
@@ -67,15 +65,6 @@ impl State {
// Take max_size
.take(max_size)
{
- if let Some(min_waiting_time) = min_waiting_time {
- // Only take entries that waited for at least min_waiting_time
- if entry.time.elapsed() < min_waiting_time {
- // Since entries are ordered, we already know that all following entries won't
- // satisfy the condition
- break;
- }
- }
-
requests.push(Request {
id: *id,
inputs: entry.request.inputs.clone(),
@@ -134,19 +123,22 @@ impl Db {
&self,
min_size: Option<usize>,
max_size: usize,
- min_waiting_time: Option<Duration>,
) -> Option<(Vec<u64>, Batch)> {
// Acquire lock
let mut state = self.shared.state.lock();

// Get requests from the database
- if let Some((ids, requests)) = state.next_requests(max_size, min_waiting_time) {
+ if let Some((ids, requests)) = state.next_requests(max_size) {
if let Some(min_size) = min_size {
// If min_size is set, only return a batch if there are enough requests
if requests.len() < min_size {
return None;
}
}
+ ids.iter().for_each(|id| {
+ // Set batch_time for each request
+ state.entries.get_mut(id).unwrap().batch_time = Some(Instant::now());
+ });

// Batch size
let size = requests.len();
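
The new batch_time stamp, set when an entry is pulled into a batch, pairs with the existing creation time to let InferResponse report queueing and inference durations separately. An illustrative sketch (std Instant is used here for self-containment; the router itself uses tokio::time::Instant, and the field names mirror InferResponse):

    use std::time::{Duration, Instant};

    // Mirrors the three timestamps carried by InferResponse in this commit.
    struct Timing {
        queued: Instant, // Entry.time: request entered the Db
        start: Instant,  // Entry.batch_time: request was put into a batch
        end: Instant,    // response produced
    }

    impl Timing {
        fn queue_time(&self) -> Duration {
            self.start - self.queued
        }
        fn inference_time(&self) -> Duration {
            self.end - self.start
        }
    }

    fn main() {
        let queued = Instant::now();
        let start = queued + Duration::from_millis(5);
        let end = start + Duration::from_millis(40);
        let t = Timing { queued, start, end };
        println!("queued {:?}, inference {:?}", t.queue_time(), t.inference_time());
    }
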
router/src/lib.rs: 1 addition, 3 deletions
@@ -4,7 +4,7 @@ mod db;
pub mod server;
mod validation;

- use batcher::Batcher;
+ use batcher::{Batcher, InferResponse};
use db::{Db, Entry};
use serde::{Deserialize, Serialize};
use validation::Validation;
@@ -64,5 +64,3 @@ pub(crate) struct GenerateRequest {
pub(crate) struct GeneratedText {
pub generated_text: String,
}

- pub(crate) type GenerateResponse = Vec<GeneratedText>;
router/src/main.rs: 6 additions, 9 deletions
@@ -2,7 +2,6 @@
use bloom_inference_client::ShardedClient;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
- use std::time::Duration;
use text_generation_router::server;
use tokenizers::Tokenizer;

@@ -16,8 +15,8 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "32", long, env)]
max_batch_size: usize,
#[clap(default_value = "5", long, env)]
max_waiting_time: u64,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/bloom-inference-0", long, env)]
@@ -36,19 +35,19 @@ fn main() -> Result<(), std::io::Error> {
max_concurrent_requests,
max_input_length,
max_batch_size,
- max_waiting_time,
+ max_waiting_tokens,
port,
master_shard_uds_path,
tokenizer_name,
validation_workers,
} = args;

+ tracing_subscriber::fmt().compact().with_ansi(false).init();
+
if validation_workers == 1 {
panic!("validation_workers must be > 0");
}

- let max_waiting_time = Duration::from_secs(max_waiting_time);
-
// Download and instantiate tokenizer
// This will only be used to validate payloads
//
@@ -61,8 +60,6 @@
.build()
.unwrap()
.block_on(async {
- tracing_subscriber::fmt::init();
-
// Instantiate sharded client from the master unix socket
let sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
@@ -82,7 +79,7 @@
max_concurrent_requests,
max_input_length,
max_batch_size,
- max_waiting_time,
+ max_waiting_tokens,
sharded_client,
tokenizer,
validation_workers,
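
A side change shared with the launcher: the router now installs the compact, ANSI-free tracing formatter once, before the Tokio runtime is built, instead of calling tracing_subscriber::fmt::init() inside block_on. A sketch of that ordering; the runtime options below are assumed for illustration, and installing a second global subscriber would make init() panic:

    fn main() {
        // Install the global subscriber exactly once, before any async code runs.
        tracing_subscriber::fmt().compact().with_ansi(false).init();

        // Runtime construction abridged; options are illustrative.
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(async {
                tracing::info!("router starting");
                // ... connect to the shards and serve, as in the diff above ...
            });
    }
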
Diffs for the remaining two changed files are not shown here.