Parallelize tokenization for /classify_batch and remove block allocator for non-causal LMs (#609)
tgaddair authored Sep 18, 2024
1 parent a6b60e9 commit 9919ae1
Showing 3 changed files with 32 additions and 8 deletions.
20 changes: 14 additions & 6 deletions router/src/infer.rs
@@ -161,6 +161,7 @@ impl Infer {
         speculate: u32,
         preloaded_adapters: Vec<PreloadedAdapter>,
         prefix_caching: bool,
+        is_causal_lm: bool,
     ) -> Self {
         let adapter_event = Arc::new(AdapterEvent {
             batching_task: Notify::new(),
@@ -178,6 +179,7 @@ impl Infer {
             speculate,
             max_batch_total_tokens,
             prefix_caching,
+            is_causal_lm,
         );

         // Initialize with base model adapter (empty) mapping to index 0
@@ -729,13 +731,19 @@ impl Infer {
             .map(|(id, input)| (id as u64, input.clone()))
             .collect();

-        for (id, r_inputs) in request.inputs.iter().enumerate() {
-            let inputs = r_inputs.to_string().clone();
-            let (tokenized_inputs, input_length) = self
-                .validation
-                .validate_input(r_inputs.to_string(), None, Some(1))
-                .await?;
+        // Call validate_input on every input in the request and await the results
+        let futures: Vec<_> = request
+            .inputs
+            .iter()
+            .map(|input| self.validation.validate_input(input.clone(), None, Some(1)))
+            .collect();
+
+        let all_tokenized_inputs = try_join_all(futures).await?;
+
+        for ((id, r_inputs), (tokenized_inputs, input_length)) in
+            request.inputs.iter().enumerate().zip(all_tokenized_inputs)
+        {
+            let inputs = r_inputs.to_string().clone();
             let valid_request = ValidClassifyRequest {
                 inputs,
                 tokenized_inputs,
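The infer.rs change replaces an await-per-input loop with a fan-out/fan-in over the whole batch. Below is a minimal, self-contained sketch of the same pattern, assuming the `tokio` and `futures` crates; the stub `validate_input` stands in for the router's `Validation::validate_input`, and its body and error type are illustrative assumptions:

```rust
use futures::future::try_join_all;

// Stand-in for the router's `validation.validate_input(...)`; it returns the
// tokenized input and its length, like the call in `classify_batch`.
async fn validate_input(input: String) -> Result<(Vec<u32>, usize), String> {
    if input.is_empty() {
        return Err("empty input".to_string());
    }
    let tokens: Vec<u32> = input.bytes().map(u32::from).collect();
    let length = tokens.len();
    Ok((tokens, length))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let inputs = vec!["good movie".to_string(), "terrible plot".to_string()];

    // Fan out: build one future per input instead of awaiting inside the loop.
    let futures: Vec<_> = inputs
        .iter()
        .map(|input| validate_input(input.clone()))
        .collect();

    // Fan in: drive all futures concurrently; the first error aborts the
    // whole batch, matching the single `?` on `try_join_all` in the diff.
    let all_tokenized_inputs = try_join_all(futures).await?;

    // Zip results back onto the inputs, as the new loop in `classify_batch` does.
    for ((id, input), (tokens, length)) in inputs.iter().enumerate().zip(all_tokenized_inputs) {
        println!("{id}: {input:?} -> {length} tokens ({tokens:?})");
    }
    Ok(())
}
```

`try_join_all` polls every future concurrently and short-circuits on the first error; in the router, each `validate_input` call hands tokenization off to a worker, so the batch tokenizes in parallel rather than serially.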
8 changes: 7 additions & 1 deletion router/src/scheduler.rs
@@ -44,6 +44,7 @@ impl AdapterScheduler {
         speculate: u32,
         max_batch_total_tokens: u32,
         prefix_caching: bool,
+        is_causal_lm: bool,
     ) -> Self {
         let (sender, receiver) = flume::unbounded();

@@ -60,6 +61,7 @@
             speculate,
             max_batch_total_tokens,
             prefix_caching,
+            is_causal_lm,
         ));

         Self { sender }
@@ -124,6 +126,7 @@ async fn adapter_scheduler_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     prefix_caching: bool,
+    is_causal_lm: bool,
 ) {
     let mut state = AdapterSchedulerState::new(
         client,
@@ -135,6 +138,7 @@
         speculate,
         max_batch_total_tokens,
         prefix_caching,
+        is_causal_lm,
     );

     while let Ok(cmd) = receiver.recv_async().await {
@@ -209,14 +213,16 @@ impl AdapterSchedulerState {
         speculate: u32,
         max_batch_total_tokens: u32,
         prefix_caching: bool,
+        is_causal_lm: bool,
     ) -> Self {
         let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new(
             max_active_adapters,
             adapter_cycle_time_s,
         )));
         let loader = AdapterLoader::new(client.clone());

-        let block_allocator = (!requires_padding).then(|| {
+        // Only causal LMs require the block allocator, due to paged attention
+        let block_allocator = (!requires_padding && is_causal_lm).then(|| {
             BlockAllocator::new(
                 max_batch_total_tokens,
                 block_size,
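The scheduler now builds its `BlockAllocator` only when the model both skips padding and decodes autoregressively, since paged attention is what the block accounting exists for. A small sketch of that gating with `bool::then`, using a stub allocator type whose fields and constructor are assumptions for illustration:

```rust
// Stub standing in for the router's BlockAllocator.
#[derive(Debug)]
struct BlockAllocator {
    max_batch_total_tokens: u32,
    block_size: u32,
}

impl BlockAllocator {
    fn new(max_batch_total_tokens: u32, block_size: u32) -> Self {
        Self { max_batch_total_tokens, block_size }
    }
}

fn build_block_allocator(
    requires_padding: bool,
    is_causal_lm: bool,
    max_batch_total_tokens: u32,
    block_size: u32,
) -> Option<BlockAllocator> {
    // `bool::then` yields Some(...) only when the condition holds, so
    // non-causal models (e.g. sequence classifiers) get None and skip
    // paged-attention block accounting entirely.
    (!requires_padding && is_causal_lm)
        .then(|| BlockAllocator::new(max_batch_total_tokens, block_size))
}

fn main() {
    // Non-causal LM (the new case in this commit): no allocator.
    assert!(build_block_allocator(false, false, 16384, 16).is_none());
    // Padded model: no allocator, as before the commit.
    assert!(build_block_allocator(true, true, 16384, 16).is_none());
    // Causal LM without padding: allocator is built.
    println!("causal path -> {:?}", build_block_allocator(false, true, 16384, 16));
}
```

Because `bool::then` yields an `Option`, downstream code that already handled `None` for padded models handles non-causal models the same way, with no new branching.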
12 changes: 11 additions & 1 deletion router/src/server.rs
@@ -1134,12 +1134,21 @@ pub async fn run(
         generation_health.clone(),
         shard_info.clone(),
     );
+
+    // For non-causal LMs, the max batch total tokens is equal to the max batch prefill tokens
+    let is_causal_lm = shard_info.supports_generation;
+    let effective_max_batch_total_tokens = if is_causal_lm {
+        max_batch_total_tokens
+    } else {
+        max_batch_prefill_tokens
+    };
+
     let infer = Infer::new(
         client.clone(),
         validation,
         waiting_served_ratio,
         max_batch_prefill_tokens,
-        max_batch_total_tokens,
+        effective_max_batch_total_tokens,
         max_waiting_tokens,
         max_concurrent_requests,
         max_active_adapters,
@@ -1154,6 +1163,7 @@
         shard_info.speculate,
         shard_info.preloaded_adapters,
         prefix_caching,
+        is_causal_lm,
     );

     // Duration buckets
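In server.rs, `shard_info.supports_generation` marks the model as a causal LM; a non-causal model (e.g. a sequence classifier) never enters a decode phase, so its total-token budget collapses to the prefill budget before being handed to `Infer::new`. The same selection written as a pure function (the helper and its name are ours, for illustration; the parameter names come from the diff):

```rust
// Sketch of the budget selection in `run`; the logic mirrors the diff above.
fn effective_max_batch_total_tokens(
    supports_generation: bool,
    max_batch_prefill_tokens: u32,
    max_batch_total_tokens: u32,
) -> u32 {
    if supports_generation {
        // Causal LM: prefill and decode tokens both count against the budget.
        max_batch_total_tokens
    } else {
        // Non-causal LM: only a prefill pass runs, so the prefill budget is
        // the whole budget.
        max_batch_prefill_tokens
    }
}

fn main() {
    assert_eq!(effective_max_batch_total_tokens(true, 4096, 16384), 16384);
    assert_eq!(effective_max_batch_total_tokens(false, 4096, 16384), 4096);
    println!("non-causal models are capped at the prefill budget");
}
```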
