
Commit

feat(server): auto max_batch_total_tokens for flash att models (huggi…
OlivierDehaene authored Jul 19, 2023
1 parent 5e6ddfd commit fe80f53
Showing 10 changed files with 159 additions and 94 deletions.
launcher/src/main.rs (28 additions, 26 deletions)

@@ -184,8 +184,8 @@ struct Args {
     /// depends on other parameters like if you're using quantization, flash attention
     /// or the model implementation, text-generation-inference cannot infer this number
     /// automatically.
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
 
     /// This setting defines how many tokens can be passed before forcing the waiting
     /// queries to be put on the batch (if the size of the batch allows for it).
@@ -369,12 +369,6 @@ fn shard_manager(
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
-    // Use cuda allocator. It leads to less memory fragmentation
-    envs.push((
-        "PYTORCH_CUDA_ALLOC_CONF".into(),
-        "backend:cudaMallocAsync".into(),
-    ));
-
     // Torch Distributed Env vars
     envs.push(("RANK".into(), rank.to_string().into()));
     envs.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@@ -428,7 +422,7 @@ fn shard_manager(
     }
 
     // Start process
-    tracing::info!("Starting shard {rank}");
+    tracing::info!("Starting shard");
     let mut p = match Command::new("text-generation-server")
         .args(shard_args)
         .envs(envs)
@@ -493,17 +487,17 @@ fn shard_manager(
         if shutdown.load(Ordering::SeqCst) {
             p.kill().unwrap();
             let _ = p.wait();
-            tracing::info!("Shard {rank} terminated");
+            tracing::info!("Shard terminated");
             return;
         }
 
         // Shard is ready
         if uds.exists() && !ready {
-            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
+            tracing::info!("Shard ready in {:?}", start_time.elapsed());
             status_sender.send(ShardStatus::Ready).unwrap();
             ready = true;
         } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-            tracing::info!("Waiting for shard {rank} to be ready...");
+            tracing::info!("Waiting for shard to be ready...");
             wait_time = Instant::now();
         }
         sleep(Duration::from_millis(100));
@@ -860,8 +854,6 @@ fn spawn_webserver(
         args.max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
         args.max_batch_prefill_tokens.to_string(),
-        "--max-batch-total-tokens".to_string(),
-        args.max_batch_total_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -878,6 +870,12 @@
         args.model_id,
     ];
 
+    // Model optional max batch total tokens
+    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
+        router_args.push("--max-batch-total-tokens".to_string());
+        router_args.push(max_batch_total_tokens.to_string());
+    }
+
     // Model optional revision
     if let Some(ref revision) = args.revision {
         router_args.push("--revision".to_string());
@@ -1036,18 +1034,7 @@ fn main() -> Result<(), LauncherError> {
             args.max_batch_prefill_tokens, args.max_input_length
         )));
     }
-    if args.max_batch_prefill_tokens > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_batch_prefill_tokens, args.max_batch_total_tokens
-        )));
-    }
-    if args.max_total_tokens as u32 > args.max_batch_total_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-            args.max_total_tokens, args.max_batch_total_tokens
-        )));
-    }
 
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
            "`validation_workers` must be > 0".to_string(),
@@ -1065,6 +1052,21 @@
         tracing::info!("Sharding model on {num_shard} processes");
     }
 
+    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
+        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_batch_prefill_tokens, max_batch_total_tokens
+            )));
+        }
+        if args.max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(LauncherError::ArgumentValidation(format!(
+                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                args.max_total_tokens, max_batch_total_tokens
+            )));
+        }
+    }
+
     // Signal handler
     let running = Arc::new(AtomicBool::new(true));
     let r = running.clone();
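For readers skimming the diff, here is a minimal, self-contained sketch of the launcher's new validation behaviour; the function name and signature are hypothetical and not part of the commit. The batch-level checks now run only when `--max-batch-total-tokens` is actually supplied:

// Hypothetical standalone version of the launcher's new validation logic:
// the batch-level checks apply only when `max_batch_total_tokens` was given,
// since flash-attention models now infer it at warmup.
fn validate(
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: Option<u32>,
) -> Result<(), String> {
    if let Some(max_batch_total_tokens) = max_batch_total_tokens {
        if max_batch_prefill_tokens > max_batch_total_tokens {
            return Err(format!(
                "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_batch_prefill_tokens, max_batch_total_tokens
            ));
        }
        if max_total_tokens as u32 > max_batch_total_tokens {
            return Err(format!(
                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
                max_total_tokens, max_batch_total_tokens
            ));
        }
    }
    Ok(())
}

fn main() {
    // No flag: nothing to check, warmup will pick the value.
    assert!(validate(2048, 4096, None).is_ok());
    // Flag smaller than the prefill budget: rejected, matching the launcher errors above.
    assert!(validate(2048, 4096, Some(1024)).is_err());
}

When the flag is omitted, the launcher also no longer passes `--max-batch-total-tokens` to the router at all, as the `spawn_webserver` hunk above shows.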
proto/generate.proto (4 additions, 3 deletions)

@@ -198,9 +198,10 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    /// Maximum number of tokens that the client will send
-    uint32 max_total_tokens = 2;
 }
 
 /// Empty response
-message WarmupResponse {}
+message WarmupResponse {
+    /// Maximum number of tokens supported by the model
+    optional uint32 max_supported_total_tokens = 1;
+}
router/client/src/client.rs (4 additions, 9 deletions)

@@ -103,8 +103,7 @@ impl Client {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
 
@@ -143,13 +142,9 @@
             max_tokens: 0,
         };
 
-        let request = tonic::Request::new(WarmupRequest {
-            batch: Some(batch),
-            max_total_tokens,
-        })
-        .inject_context();
-        self.stub.warmup(request).await?.into_inner();
-        Ok(())
+        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let response = self.stub.warmup(request).await?.into_inner();
+        Ok(response.max_supported_total_tokens)
     }
 
     /// Generate one token for each request in the given batch
router/client/src/sharded_client.rs (2 additions, 5 deletions)

@@ -95,14 +95,11 @@ impl ShardedClient {
         &mut self,
         max_input_length: u32,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
-    ) -> Result<()> {
+    ) -> Result<Option<u32>> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
-            .map(|client| {
-                Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens))
-            })
+            .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens)))
             .collect();
         // all shards return the same message
         join_all(futures).await.pop().unwrap()
router/src/infer.rs (1 addition, 1 deletion)

@@ -53,7 +53,7 @@ impl Infer {
         generation_health: Arc<AtomicBool>,
     ) -> Self {
         // Infer shared state
-        let queue = Queue::new(requires_padding);
+        let queue = Queue::new(requires_padding, 16);
         let shared = Arc::new(Shared {
             batching_task: Notify::new(),
         });
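The new second argument to `Queue::new` is the constant 16. Assuming it is a block size used when budgeting tokens for a paged KV cache (an inference on my part, not stated in this hunk), the effect can be sketched with a hypothetical helper:

// Hypothetical illustration of block-based token budgeting: if KV-cache memory is
// allocated in whole blocks, each request's token count is rounded up to a multiple
// of the block size before being charged against the batch budget.
fn padded_tokens(input_tokens: u32, max_new_tokens: u32, block_size: u32) -> u32 {
    let total = input_tokens + max_new_tokens;
    // Round up to the next multiple of `block_size`.
    ((total + block_size - 1) / block_size) * block_size
}

fn main() {
    // With a block size of 16, a 70-token request reserves 80 tokens of budget.
    assert_eq!(padded_tokens(50, 20, 16), 80);
    println!("budget = {}", padded_tokens(50, 20, 16));
}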
router/src/main.rs (41 additions, 16 deletions)

@@ -37,8 +37,8 @@ struct Args {
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
     max_batch_prefill_tokens: u32,
-    #[clap(default_value = "16000", long, env)]
-    max_batch_total_tokens: u32,
+    #[clap(long, env)]
+    max_batch_total_tokens: Option<u32>,
     #[clap(default_value = "20", long, env)]
     max_waiting_tokens: usize,
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -110,18 +110,22 @@ fn main() -> Result<(), RouterError> {
     if max_input_length as u32 > max_batch_prefill_tokens {
         return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
     }
-    if max_batch_prefill_tokens > max_batch_total_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
-    }
-    if max_total_tokens as u32 > max_batch_total_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
-    }
 
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
         ));
     }
 
+    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {
+        if max_batch_prefill_tokens > *max_batch_total_tokens {
+            return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}")));
+        }
+        if max_total_tokens as u32 > *max_batch_total_tokens {
+            return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}")));
+        }
+    }
+
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
     // Finally, convert to AllowOrigin
@@ -210,14 +214,35 @@ fn main() -> Result<(), RouterError> {
 
     // Warmup model
     tracing::info!("Warming up model");
-    sharded_client
-        .warmup(
-            max_input_length as u32,
-            max_batch_prefill_tokens,
-            max_batch_total_tokens,
-        )
+    let max_supported_batch_total_tokens = match sharded_client
+        .warmup(max_input_length as u32, max_batch_prefill_tokens)
         .await
-        .map_err(RouterError::Warmup)?;
+        .map_err(RouterError::Warmup)?
+    {
+        // Older models do not support automatic max-batch-total-tokens
+        None => {
+            let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)),
+            );
+            tracing::warn!("Model does not support automatic max batch total tokens");
+            max_batch_total_tokens
+        }
+        // Flash attention models return their max supported total tokens
+        Some(max_supported_batch_total_tokens) => {
+            // Warn if user added his own max-batch-total-tokens as we will ignore it
+            if max_batch_total_tokens.is_some() {
+                tracing::warn!(
+                    "`--max-batch-total-tokens` is deprecated for Flash \
+                    Attention models."
+                );
+                tracing::warn!(
+                    "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
+                );
+            }
+            max_supported_batch_total_tokens
+        }
+    };
+    tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}");
     tracing::info!("Connected");
 
     let addr = match hostname.parse() {
@@ -240,7 +265,7 @@ fn main() -> Result<(), RouterError> {
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,
-        max_batch_total_tokens,
+        max_supported_batch_total_tokens,
         max_waiting_tokens,
         sharded_client,
         tokenizer,
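The warmup `match` above condenses to the following standalone sketch (the function name is hypothetical), showing how the router picks the effective `max_batch_total_tokens`:

// Hypothetical distillation of the router's new resolution logic: prefer the value
// the shards report from warmup (flash-attention models), otherwise fall back to the
// CLI flag or a 16000-token floor bounded below by the other limits.
fn resolve_max_batch_total_tokens(
    warmup_result: Option<u32>, // WarmupResponse.max_supported_total_tokens
    cli_flag: Option<u32>,      // --max-batch-total-tokens, now optional
    max_total_tokens: u32,
    max_batch_prefill_tokens: u32,
) -> u32 {
    match warmup_result {
        // Older (non flash-attention) models: use the flag or a conservative default.
        None => cli_flag.unwrap_or(16000.max(max_total_tokens.max(max_batch_prefill_tokens))),
        // Flash-attention models: the inferred value wins, even if a flag was passed.
        Some(inferred) => inferred,
    }
}

fn main() {
    // Flag ignored when the model reports its own capacity.
    assert_eq!(resolve_max_batch_total_tokens(Some(81920), Some(16000), 2048, 4096), 81920);
    // Older model, no flag: falls back to the 16000 floor.
    assert_eq!(resolve_max_batch_total_tokens(None, None, 2048, 4096), 16000);
}

Passing `--max-batch-total-tokens` therefore still matters for older model implementations, while for flash-attention models the value reported by warmup takes precedence.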
(Diffs for the remaining 4 changed files are not shown.)
