redo_torch_only
ujohn33 committed Jan 12, 2025
1 parent 1db916b commit 37cedaf
Showing 1 changed file with 7 additions and 8 deletions.
lightgbmlss/distributions/distribution_utils.py
@@ -502,14 +502,13 @@ def compute_gradients_and_hessians(self,
         if self.quantile_clipping:
             # Clip Gradients and Hessians
             # Ensure gradients and Hessians are detached before computing quantiles
-            # Convert list of tensors to NumPy arrays
-            grad_np = np.concatenate([g.detach().numpy() for g in grad])
-            hess_np = np.concatenate([h.detach().numpy() for h in hess])
-
-            grad_min = np.quantile(grad_np, self.clip_value)
-            grad_max = np.quantile(grad_np, 1 - self.clip_value)
-            hess_min = np.quantile(hess_np, self.clip_value)
-            hess_max = np.quantile(hess_np, 1 - self.clip_value)
+            grad_tensor = torch.cat([g.detach() for g in grad])
+            hess_tensor = torch.cat([h.detach() for h in hess])
+
+            grad_min = torch.quantile(grad_tensor, self.clip_value)
+            grad_max = torch.quantile(grad_tensor, 1 - self.clip_value)
+            hess_min = torch.quantile(hess_tensor, self.clip_value)
+            hess_max = torch.quantile(hess_tensor, 1 - self.clip_value)
 
             # Clip Gradients and Hessians
             grad = [torch.clamp(g, min=grad_min, max=grad_max) for g in grad]
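
For context, here is a minimal standalone sketch of the torch-only quantile clipping this commit switches to; the function name, the list-of-tensors layout, and the clip_value default are illustrative assumptions, not the library's API.

import torch

def quantile_clip(grad, hess, clip_value=0.05):
    """Clip lists of gradient/Hessian tensors to their empirical quantiles.

    grad, hess: lists of 1-D torch tensors (e.g. one per distributional parameter).
    clip_value: lower-tail quantile; the upper bound is the (1 - clip_value) quantile.
    """
    # Pool all values so the quantiles are computed over the full set,
    # detaching first so no autograd graph is kept alive.
    grad_tensor = torch.cat([g.detach() for g in grad])
    hess_tensor = torch.cat([h.detach() for h in hess])

    grad_min = torch.quantile(grad_tensor, clip_value)
    grad_max = torch.quantile(grad_tensor, 1 - clip_value)
    hess_min = torch.quantile(hess_tensor, clip_value)
    hess_max = torch.quantile(hess_tensor, 1 - clip_value)

    # Clamp every tensor to the shared bounds.
    grad = [torch.clamp(g, min=grad_min, max=grad_max) for g in grad]
    hess = [torch.clamp(h, min=hess_min, max=hess_max) for h in hess]
    return grad, hess

# Example usage with dummy values
grads = [torch.randn(1000) for _ in range(2)]
hessians = [torch.randn(1000).abs() for _ in range(2)]
grads, hessians = quantile_clip(grads, hessians, clip_value=0.05)

Staying in torch avoids the tensor-to-NumPy round trip of the previous version and keeps the clipping on whatever device the gradients already live on.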
