Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions sinq/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,25 @@ def optimize_weights_proximal_legacy(

assert axis==1, 'only supports axis 1 right now'
if tiling_mode == '1D':
q, s1, s2, z= tiled_quant_rectangle(W_f.reshape(shape), min_max, tile, method, awq_scale)
try:
q, s1, s2, z= tiled_quant_rectangle(W_f.reshape(shape), min_max, tile, method, awq_scale)
except AssertionError as e:
if 'block must divide W' in str(e):
print(f"Warning: Skipping quantization for layer with incompatible shape (block must divide W). This layer will remain in high precision.")
# Return None to signal that this layer should not be quantized
return None
else:
raise
elif tiling_mode == '2D':
q, s1, s2, z= tiled_quant_square(W_f.reshape(shape), min_max, tile, method, awq_scale)
try:
q, s1, s2, z= tiled_quant_square(W_f.reshape(shape), min_max, tile, method, awq_scale)
except AssertionError as e:
if 'block must divide W' in str(e):
print(f"Warning: Skipping quantization for layer with incompatible shape (block must divide W). This layer will remain in high precision.")
# Return None to signal that this layer should not be quantized
return None
else:
raise

torch.cuda.empty_cache()

Expand All @@ -98,3 +114,5 @@ def optimize_weights_proximal_legacy(

# Default: fast with early stopping
optimize_weights_proximal = optimize_weights_proximal_legacy


Loading