diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d4cf249fe..1e18dbf07 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -1033,7 +1033,7 @@ def postprocess(module, name): original_device = weight.device original_dtype = weight.dtype weight_f64 = weight.to(dtype=torch.float64, device=original_device) - u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False) + u, s, vt = torch.linalg.svd(weight_f64, driver="gesvd", full_matrices=False) if u.shape[1] < lowrank or vt.shape[0] < lowrank: warnings.warn( "The low-rank dimensions do not match the layer dimensions. "