Performance optimized interleaved mode JetStream server #122
Conversation
Optimized TTFT and optimized output token throughput conflict with each other. Can we expose a parameter for tuning the trade-off between the two?
@@ -316,6 +317,12 @@ def __init__(
            queue.Queue(8)
            for _ in self._generate_engines
        ]
        self._prefill_detokenize_backlogs = [
            # We don't let detokenization accumulate more than 8 steps to avoid
Can you elaborate on why there is a synchronization issue after 8 steps?
We set it to 8, the same bound used for the decode detokenize backlog. A queue that is too large or too small causes performance issues.
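The effect of the bound can be illustrated with a minimal producer/consumer sketch (the names `detokenize_backlog`, `generate_loop`, and `detokenize_loop` are illustrative, not the actual JetStream code): a `queue.Queue(8)` makes `put()` block once 8 un-detokenized steps accumulate, so the generate loop cannot run unboundedly ahead of the detokenize thread.

```python
import queue
import threading

# Illustrative bound: mirrors the queue.Queue(8) backlog in the diff above.
MAX_BACKLOG_STEPS = 8

detokenize_backlog = queue.Queue(MAX_BACKLOG_STEPS)

def generate_loop(num_steps, produced):
    # Producer: each decode step enqueues its token buffer for detokenization.
    for step in range(num_steps):
        # put() blocks once MAX_BACKLOG_STEPS items are pending, exerting
        # backpressure on the generate loop.
        detokenize_backlog.put(f"step-{step}")
        produced.append(step)
    detokenize_backlog.put(None)  # sentinel: no more steps

def detokenize_loop(consumed):
    # Consumer: drains steps in order and "detokenizes" them.
    while True:
        item = detokenize_backlog.get()
        if item is None:
            break
        consumed.append(item)

produced, consumed = [], []
t = threading.Thread(target=detokenize_loop, args=(consumed,))
t.start()
generate_loop(20, produced)
t.join()
print(len(consumed))  # 20
```

With a much larger bound the consumer can fall far behind (hurting latency); with a much smaller one the producer stalls on nearly every step (hurting throughput), which is the tuning tension described above.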
Thanks, is this PR ready to submit?
@@ -795,6 +792,44 @@ def _detokenize_thread(self, idx: int): | |||
slot, active_request = data | |||
my_live_requests[slot] = active_request | |||
|
|||
def _prefill_detokenize_thread(self, idx: int): |
I thought we already had a prefill detokenize thread. Does current JetStream (before this PR) always return the prefill token (first token) only after the first decode step?
We only had a single detokenize thread that combined prefill detokenization and decode detokenization. The problem is that jax.block_until_ready() blocks the thread while waiting for the prefill token or decode token to finish its async copy to host, so putting both in one thread makes TTFT slow. JetStream returns the prefill token in the prefill thread (right after the prefill step generates the first token).
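The split described above can be sketched as follows. This is a simplified model, not the actual JetStream implementation: `fake_block_until_ready` stands in for `jax.block_until_ready()` (simulated with a sleep representing the device-to-host copy), and the queue and function names are hypothetical. The point is that a dedicated prefill detokenize thread keeps first tokens from waiting behind slow decode-token copies.

```python
import queue
import threading
import time

def fake_block_until_ready(seconds):
    # Stand-in for jax.block_until_ready(): blocks the calling thread
    # until the (simulated) async copy to host completes.
    time.sleep(seconds)

prefill_backlog = queue.Queue(8)  # first tokens from prefill steps
decode_backlog = queue.Queue(8)   # tokens from decode steps

first_token_times = {}

def prefill_detokenize_thread():
    # Dedicated thread: a slow decode copy can no longer delay TTFT,
    # because first tokens never queue behind decode tokens here.
    while True:
        item = prefill_backlog.get()
        if item is None:
            break
        request_id, copy_time = item
        fake_block_until_ready(copy_time)
        first_token_times[request_id] = time.monotonic()

def decode_detokenize_thread():
    while True:
        item = decode_backlog.get()
        if item is None:
            break
        _, copy_time = item
        fake_block_until_ready(copy_time)

start = time.monotonic()
threads = [
    threading.Thread(target=prefill_detokenize_thread),
    threading.Thread(target=decode_detokenize_thread),
]
for t in threads:
    t.start()

# A slow decode copy (0.2 s) and a fast prefill copy (0.01 s) arrive together.
decode_backlog.put(("req-0", 0.2))
prefill_backlog.put(("req-1", 0.01))

prefill_backlog.put(None)
decode_backlog.put(None)
for t in threads:
    t.join()

ttft = first_token_times["req-1"] - start
print(ttft < 0.1)  # prefill token was not stuck behind the 0.2 s decode copy
```

In a single combined thread, the same prefill token would have waited for the 0.2 s decode copy to clear first, roughly doubling its TTFT in this toy scenario.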
Sounds good, thanks for sharing the insights!
Currently, we prioritize prefills in interleaved mode and apply correct JAX blocking for the async copy to host to reduce wasted wait time. One more optimization to do: ensure the result is returned immediately once the return channel has the result (from the orchestrator).
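One way to get that immediate-return behavior is to have the caller block on the return channel rather than poll it on a timer. The sketch below is hypothetical (a plain `queue.Queue` standing in for the orchestrator's return channel, with made-up names); it shows why a blocking `get()` delivers the result the moment it lands, while a polling loop adds up to one polling interval of extra latency.

```python
import queue
import threading
import time

return_channel = queue.Queue()

def orchestrator():
    # Simulate the orchestrator producing a result after 50 ms.
    time.sleep(0.05)
    return_channel.put("result")

threading.Thread(target=orchestrator).start()

start = time.monotonic()
# Blocking get() wakes as soon as the result lands in the channel;
# polling with e.g. time.sleep(0.1) between checks could add ~100 ms.
result = return_channel.get()
elapsed = time.monotonic() - start

print(result)
print(elapsed < 0.1)  # returned promptly, with no polling delay
```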