
Online weight update [WIP] #2119

Closed
wants to merge 31 commits
a3a57c2
update parameters online in model runner
zhaochenyang20 Nov 18, 2024
d9c8ca9
fix the rank config in init_process_group
zhaochenyang20 Nov 21, 2024
9944fdb
add get_parameter_by_name api for unit test of weight updates online
zhaochenyang20 Nov 21, 2024
eceed0d
test custom process group function
zhaochenyang20 Nov 22, 2024
032d909
revert update weights from disk
zhaochenyang20 Nov 22, 2024
ea6572a
[WIP] deadlock detected in init group for parameters
zhaochenyang20 Nov 22, 2024
3c155f8
revert srt/utils
zhaochenyang20 Nov 22, 2024
d566b98
fix init function name
zhaochenyang20 Nov 22, 2024
0045f98
fix deadlock of init process group
zhaochenyang20 Nov 22, 2024
7874f39
Merge branch 'main' into online-weight-update
zhaochenyang20 Nov 22, 2024
05cf5a7
Successfully initialized parameter update group between huggingface a…
zhaochenyang20 Nov 22, 2024
3596b01
Merge branch 'main' into online-weight-update
zhaochenyang20 Nov 22, 2024
779c2a0
[WIP] broadcast error
zhaochenyang20 Nov 23, 2024
8c8710e
[WIP] device and world size error
zhaochenyang20 Nov 23, 2024
88814d4
[WIP] merge router into weight update
zhaochenyang20 Nov 24, 2024
f3f87a4
[WIP] merge fix docs
zhaochenyang20 Nov 24, 2024
70755d8
[WIP] group conflicts
zhaochenyang20 Nov 24, 2024
13864e6
Merge branch 'main' into online-weight-update
zhaochenyang20 Nov 25, 2024
fd8ae0d
[WIP] get_max_total_num_tokens
zhaochenyang20 Nov 25, 2024
4e92954
Merge branch 'main' into online-weight-update
zhaochenyang20 Nov 25, 2024
bd7178b
[WIP] tp group failed to communicate
zhaochenyang20 Nov 26, 2024
59eab7e
finish get weights by parameter name
zhaochenyang20 Nov 26, 2024
fcdb8ca
Merge branch 'main' into online-weight-update
zhaochenyang20 Nov 26, 2024
a13f76a
failed to clean cache
zhaochenyang20 Nov 27, 2024
ea7b941
[WIP] failed to load parameters
zhaochenyang20 Nov 27, 2024
b6ffa9e
success to update weights in engine
zhaochenyang20 Nov 27, 2024
e2d1324
success in broadcast and read weights
zhaochenyang20 Nov 27, 2024
84b32f7
init weight-update-test-2-gpu
zhaochenyang20 Nov 27, 2024
918c47b
remove print
zhaochenyang20 Nov 27, 2024
c293a6a
fix tp in all parameter read
zhaochenyang20 Nov 27, 2024
f438931
Merge branch 'main' into online-weight-update
zhaochenyang20 Nov 28, 2024
18 changes: 18 additions & 0 deletions .github/workflows/pr-test.yml
@@ -263,6 +263,24 @@ jobs:
cd test/srt
python3 test_moe_eval_accuracy_large.py

weight-update-test-2-gpu:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 2-gpu-runner
steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Install dependencies
run: |
bash scripts/ci_install_dependency.sh

- name: Test weight update (TP, DP = 2,1 or 1,2)
timeout-minutes: 20
run: |
cd test/srt
python3 test_update_parameter_from_distributed.py
python3 test_get_parameter_by_name.py

finish:
needs: [
unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, unit-test-backend-part-4,
2 changes: 1 addition & 1 deletion 3rdparty/amd/profiling/PROFILING.md
@@ -421,5 +421,5 @@ index 62d1ff9..6ecd78c 100644
3. Modify the included server.sh by removing "loadTracer.sh" before python command and launch script ./server.sh in one terminal inside the docker container.

4. Similar to step 6 in RPD profiling section, but remove the last 2 lines in client.sh, which converted rpd file into csv and json files. Run modified client.sh for PyTorch profiling.
=======
-------
- [Torch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
32 changes: 21 additions & 11 deletions docs/backend/native_api.ipynb
@@ -14,7 +14,7 @@
"- `/health`\n",
"- `/health_generate`\n",
"- `/flush_cache`\n",
"- `/update_weights`\n",
"- `/update_weights_from_disk`\n",
"- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n",
"\n",
@@ -98,7 +98,7 @@
"print_highlight(response_json)\n",
"assert response_json[\"model_path\"] == \"meta-llama/Llama-3.2-1B-Instruct\"\n",
"assert response_json[\"is_generation\"] is True\n",
"assert response_json.keys() == {\"model_path\", \"is_generation\"}"
"assert list(response_json.keys()) == [\"model_path\", \"tokenizer_path\", \"is_generation\"]"
]
},
{
@@ -144,8 +144,7 @@
"source": [
"url = \"http://localhost:30010/health_generate\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
"response = requests.get(url)"
]
},
{
@@ -156,8 +155,7 @@
"source": [
"url = \"http://localhost:30010/health\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
"response = requests.get(url)"
]
},
{
@@ -187,9 +185,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Update Weights\n",
"## Update Weights From Disk\n",
"\n",
"Update model weights without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size."
"Update model weights from disk without restarting the server. Use for continuous evaluation during training. Only applicable for models with the same architecture and parameter size."
]
},
{
@@ -200,7 +198,7 @@
"source": [
"# successful update with same architecture and size\n",
"\n",
"url = \"http://localhost:30010/update_weights\"\n",
"url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-1B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -218,7 +216,7 @@
"source": [
"# failed update with different parameter size\n",
"\n",
"url = \"http://localhost:30010/update_weights\"\n",
"url = \"http://localhost:30010/update_weights_from_disk\"\n",
"data = {\"model_path\": \"meta-llama/Llama-3.2-3B\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
@@ -340,7 +338,19 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"terminate_process(reward_process)"
]
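The notebook rename above keeps the request body of the old `/update_weights` endpoint unchanged. A minimal client-side sketch of building that payload (the helper name is illustrative, not part of this PR; the payload would be POSTed to `http://<host>:<port>/update_weights_from_disk`, e.g. with `requests.post(url, json=payload)`):

```python
def build_update_weights_request(model_path, load_format=None):
    """Build the JSON payload for the /update_weights_from_disk endpoint.

    Only the fields defined in UpdateWeightFromDistReqInput are included;
    load_format is optional and omitted when not set.
    """
    payload = {"model_path": model_path}
    if load_format is not None:
        payload["load_format"] = load_format
    return payload


# Same payload as the notebook's successful-update cell.
payload = build_update_weights_request("meta-llama/Llama-3.2-1B")
```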
51 changes: 49 additions & 2 deletions python/sglang/srt/managers/io_struct.py
@@ -352,19 +352,66 @@ class FlushCacheReq:


@dataclass
class UpdateWeightReqInput:
class UpdateWeightFromDistReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
class UpdateWeightFromDistReqOutput:
success: bool
message: str


@dataclass
class UpdateParameterFromDistributedReqInput:
name: str
dtype: str
shape: List[int]
empty_cache: bool


@dataclass
class UpdateParameterFromDistributedReqOutput:
success: bool
message: str


@dataclass
class InitParameterUpdateGroupReqInput:
# The master address
master_address: str
# The master port
master_port: int
# The rank offset
rank_offset: int
# The world size
world_size: int
# The group name
group_name: str
# The backend
backend: str = "nccl"


@dataclass
class InitParameterUpdateGroupReqOutput:
success: bool
message: str


@dataclass
class GetParameterByNameReqInput:
name: str
truncate_size: int = 100


@dataclass
class GetParameterByNameReqOutput:
parameter: list


@dataclass
class AbortReq:
# The request id
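The `rank_offset` field in `InitParameterUpdateGroupReqInput` lets a training process and the inference engine join one process group: the trainer holds the low ranks and each engine TP rank maps to `rank_offset + tp_rank`. A small sketch of one plausible layout (the class and names here are illustrative, not part of this PR):

```python
from dataclasses import dataclass


@dataclass
class GroupLayout:
    trainer_ranks: int   # ranks held by the training process(es)
    engine_tp_size: int  # TP ranks on the inference engine

    @property
    def world_size(self) -> int:
        # Every trainer rank plus every engine TP rank joins the group.
        return self.trainer_ranks + self.engine_tp_size

    def engine_rank(self, tp_rank: int) -> int:
        # rank_offset shifts engine ranks past the trainer's ranks.
        rank_offset = self.trainer_ranks
        return rank_offset + tp_rank


# One trainer rank plus a TP=2 engine gives world_size 3;
# engine TP rank 1 becomes global rank 2 in the shared group.
layout = GroupLayout(trainer_ranks=1, engine_tp_size=2)
```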
56 changes: 49 additions & 7 deletions python/sglang/srt/managers/scheduler.py
@@ -38,13 +38,19 @@
FlushCacheReq,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
GetParameterByNameReqInput,
GetParameterByNameReqOutput,
InitParameterUpdateGroupReqInput,
InitParameterUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
UpdateParameterFromDistributedReqInput,
UpdateParameterFromDistributedReqOutput,
UpdateWeightFromDistReqInput,
UpdateWeightFromDistReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -500,10 +506,25 @@ def process_input_requests(self, recv_reqs: List):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
elif isinstance(recv_req, UpdateWeightFromDistReqInput):
success, message = self.update_weights_from_disk(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightReqOutput(success, message)
UpdateWeightFromDistReqOutput(success, message)
)
elif isinstance(recv_req, GetParameterByNameReqInput):
parameter = self.get_weights_by_parameter_name(recv_req)
self.send_to_tokenizer.send_pyobj(
GetParameterByNameReqOutput(parameter)
)
elif isinstance(recv_req, InitParameterUpdateGroupReqInput):
success, message = self.init_parameter_update_group(recv_req)
self.send_to_tokenizer.send_pyobj(
InitParameterUpdateGroupReqOutput(success, message)
)
elif isinstance(recv_req, UpdateParameterFromDistributedReqInput):
success, message = self.update_parameter_from_distributed(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateParameterFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
@@ -1353,16 +1374,37 @@ def abort_request(self, recv_req: AbortReq):
self.tree_cache.cache_finished_req(req)
break

def update_weights(self, recv_req: UpdateWeightReqInput):
def update_weights_from_disk(self, recv_req: UpdateWeightFromDistReqInput):
"""In-place update of the weights."""
success, message = self.tp_worker.update_weights(recv_req)
success, message = self.tp_worker.update_weights_from_disk(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message

def init_parameter_update_group(self, recv_req: InitParameterUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_parameter_update_group(recv_req)
return success, message

def update_parameter_from_distributed(
self, recv_req: UpdateParameterFromDistributedReqInput
):
"""Update the online model parameter."""
success, message = self.tp_worker.update_parameter_from_distributed(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message

def get_weights_by_parameter_name(self, recv_req: GetParameterByNameReqInput):
parameter = self.tp_worker.get_weights_by_parameter_name(recv_req)
return parameter

def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
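The scheduler changes above all follow one pattern: branch on the request's dataclass type, call the matching `tp_worker` method, and send a `(success, message)` result back to the tokenizer. A stripped-down sketch of that dispatch (the request classes here are stand-ins for the ones in `io_struct.py`, not the real API):

```python
from dataclasses import dataclass


@dataclass
class InitGroupReq:
    group_name: str


@dataclass
class UpdateParamReq:
    name: str


def dispatch(req):
    """Mirror of process_input_requests: branch on the request type and
    return the (success, message) pair the scheduler would send back."""
    if isinstance(req, InitGroupReq):
        return True, f"initialized group {req.group_name}"
    if isinstance(req, UpdateParamReq):
        # The real scheduler also flushes the radix cache on success.
        return True, f"updated parameter {req.name}"
    raise ValueError(f"unsupported request: {type(req).__name__}")
```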