
Commit caebdb1

Yasha Singh authored and facebook-github-bot committed
Fix long CPU->GPU synchronization during Gradient clipping (#3318)
Summary:
Pull Request resolved: #3318

1. Recent changes from D79128843 introduced a sync point in `clipping.py`, which was visible in the trace.
2. The code created CPU tensors and moved them **synchronously** to CUDA devices, causing long wait times in training that showed up as `CudaStreamSynchronization` in the trace.
3. This caused QPS degradation in the CTX FM model I was actively optimizing, and in most models that enable optimizer gradient clipping in their yaml config, including OmniFM.
4. This fix bumps QPS by around 5% while keeping NE unimpacted.

Reviewed By: wz337

Differential Revision: D80959986

fbshipit-source-id: 55b0ae4165cabe4d5ce66ad442814868d408a1ac
1 parent 31cf49f commit caebdb1

File tree

1 file changed: +4 −3 lines


torchrec/optim/clipping.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -201,7 +201,7 @@ def _compute_total_norm(
         """
 
         ## compute the norm |W|^p corresponding to all sharded params W
-        sharded_grad_norm: torch.Tensor = torch.tensor(0.0)
+        sharded_grad_norm: torch.Tensor = torch.tensor(0.0, pin_memory=True)
         combine_norm_operator = torch.maximum if norm_type == torch.inf else torch.add
 
         # We need to move sharded_grad_norm to the same device as the first shard so that we can do addition (or take max)
@@ -216,7 +216,8 @@ def _compute_total_norm(
                 process_groups=pgs,
             )
             sharded_grad_norm = combine_norm_operator(
-                sharded_grad_norm.to(current_shard_norm.device), current_shard_norm
+                sharded_grad_norm.to(current_shard_norm.device, non_blocking=True),
+                current_shard_norm,
             )
         # compute |W|^p corresponding to all replicate params W
         # Similar to the case above, we move replicate_grad_norm to the same device as sharded_grad_norm so that we can do addition.
@@ -226,7 +227,7 @@ def _compute_total_norm(
             )
             if replicate_grads
             else torch.tensor(0.0)
-        ).to(sharded_grad_norm.device)
+        ).to(sharded_grad_norm.device, non_blocking=True)
 
         # In the p-norm case, we are given norms |W_sharded|^p and |W_replicate|^p. To compute the total norm, we need to
         # sum them and take the p-th root. In the inf-norm case, we are given max(|W_sharded|) and max(|W_replicate|).
```
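For context, a minimal sketch of the pattern this diff applies (illustrative names, not torchrec code; assumes a CUDA device is available): pinning the host-side scalar and passing `non_blocking=True` lets the CPU→GPU copy be enqueued on the CUDA stream instead of blocking the host until the transfer completes.

```python
import torch

# Hypothetical illustration of the async host-to-device pattern in this fix.
if torch.cuda.is_available():
    device = torch.device("cuda")
    shard_norm = torch.rand((), device=device)  # stand-in for a per-shard grad norm

    # Before: a pageable CPU scalar copied synchronously, which can show up as a
    # long CudaStreamSynchronization wait in the trace.
    total = torch.tensor(0.0).to(device)

    # After: pin the host memory and request a non-blocking copy; the transfer is
    # enqueued on the current stream and the host thread keeps running.
    total = torch.tensor(0.0, pin_memory=True).to(device, non_blocking=True)

    # Later GPU ops on the same stream (add / maximum, as in _compute_total_norm)
    # are ordered after the copy, so no explicit synchronization is needed here.
    total = torch.add(total, shard_norm)
```

Note that `non_blocking=True` only yields a truly asynchronous copy when the source host memory is pinned, which is why the fix also adds `pin_memory=True` when creating the CPU scalar.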
