Skip to content

Commit

Permalink
Refactor tokenizer manager (#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Oct 31, 2024
1 parent f7102fb commit 438526a
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,22 +549,18 @@ async def get_memory_pool_size(self):
self.create_handle_loop()

req = GetMemPoolSizeReq()
ret = None

self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()

if self.server_args.dp_size == 1:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
res = await self.mem_pool_size
ret = res.size

return res.size
else: # self.server_args.dp_size > 1
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
self.mem_pool_size_tmp = []
res = await self.mem_pool_size
ret = [r.size for r in res]

return ret
return ret

async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
Expand All @@ -578,29 +574,21 @@ async def update_weights(

if not self.model_update_lock.locked():

if self.server_args.dp_size == 1:
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()

if self.server_args.dp_size == 1:
result = await self.model_update_result
if result.success:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message

else: # self.server_args.dp_size > 1

# There will be dp_size number of response from the detokenizer
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
return result.success, result.message
else: # self.server_args.dp_size > 1
self.model_update_tmp = []
result = await self.model_update_result

Expand All @@ -611,8 +599,7 @@ async def update_weights(
self.model_path = obj.model_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)

return all_success, all_message
return all_success, all_message

else:
return False, "Another update is in progress. Please try again later."
Expand Down

0 comments on commit 438526a

Please sign in to comment.