diff --git a/sinq/optimize.py b/sinq/optimize.py index 732e22d..8c9ecb6 100644 --- a/sinq/optimize.py +++ b/sinq/optimize.py @@ -86,9 +86,25 @@ def optimize_weights_proximal_legacy( assert axis==1, 'only supports axis 1 right now' if tiling_mode == '1D': - q, s1, s2, z= tiled_quant_rectangle(W_f.reshape(shape), min_max, tile, method, awq_scale) + try: + q, s1, s2, z= tiled_quant_rectangle(W_f.reshape(shape), min_max, tile, method, awq_scale) + except AssertionError as e: + if 'block must divide W' in str(e): + print(f"Warning: Skipping quantization for layer with incompatible shape (block must divide W). This layer will remain in high precision.") + # Return None to signal that this layer should not be quantized + return None + else: + raise elif tiling_mode == '2D': - q, s1, s2, z= tiled_quant_square(W_f.reshape(shape), min_max, tile, method, awq_scale) + try: + q, s1, s2, z= tiled_quant_square(W_f.reshape(shape), min_max, tile, method, awq_scale) + except AssertionError as e: + if 'block must divide W' in str(e): + print(f"Warning: Skipping quantization for layer with incompatible shape (block must divide W). This layer will remain in high precision.") + # Return None to signal that this layer should not be quantized + return None + else: + raise torch.cuda.empty_cache() @@ -98,3 +114,5 @@ def optimize_weights_proximal_legacy( # Default: fast with early stopping optimize_weights_proximal = optimize_weights_proximal_legacy + + diff --git a/sinq/patch_model.py b/sinq/patch_model.py index d553ec4..76626fe 100644 --- a/sinq/patch_model.py +++ b/sinq/patch_model.py @@ -64,6 +64,7 @@ def is_leaf_module(module) -> bool: # Get the linear_tag from a modul name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj +# Now more universal to include components like acoustic_connector, semantic_tokenizer, etc. 
def name_to_linear_tag(name: str) -> str: return ".".join( [ @@ -418,6 +419,9 @@ def quantize_model( compute_dtype: torch.dtype = float16, device: Union[str, list, dict] = "cuda", use_unpack_kernel: bool = True, + skip_tensors: list = None, # Allow tensors to be skipped + custom_model_class=None, # Allow custom model class for saving/loading compatibility + verbose: bool = False, # Control verbose output ): # Check if the model was already quantized if getattr(model, "sinq_quantized", False): @@ -453,21 +457,48 @@ def quantize_model( # Same quant_config for all layers patch_params = {k: quant_config for k in model.linear_tags} + # Apply skip_tensors list - set quantization config to None for specified tensor patterns + if skip_tensors is not None: + print(f"Skipping quantization for tensors matching patterns: {skip_tensors}") + for linear_tag in model.linear_tags: + for skip_pattern in skip_tensors: + if skip_pattern in linear_tag: + patch_params[linear_tag] = None + print(f" Skipping: {linear_tag}") + break + # Get list of all nodes in order all_nodes = get_all_children_from_model(model, []) # ordered nodes try: - # Extract block names: This is following Hugging Face models. - num_blocks = ( - len(model.model.layers) - if hasattr(model, "model") - else len(model.layers) - ) - all_blocks = ["model.layers." + str(i) for i in range(num_blocks)] - except Exception: + # Extract block names: Handle different model architectures + if hasattr(model, "model") and hasattr(model.model, "layers"): + num_blocks = len(model.model.layers) + all_blocks = ["model.layers." + str(i) for i in range(num_blocks)] + elif hasattr(model, "transformer") and hasattr(model.transformer, "h"): + # GPT-2 style + num_blocks = len(model.transformer.h) + all_blocks = ["transformer.h." + str(i) for i in range(num_blocks)] + elif hasattr(model, "layers"): + # Direct layers attribute (BLOOM, GPT-NeoX, etc.) + num_blocks = len(model.layers) + all_blocks = ["layers." 
+ str(i) for i in range(num_blocks)] + else: + # For universal models with various components (like VibeVoice), + # create blocks from all major components + all_blocks = [] + for name in all_nodes: + # Get the top-level component (e.g., "model.language_model", "model.acoustic_tokenizer") + parts = name.split(".") + if len(parts) >= 3: # model.component.subcomponent + component = ".".join(parts[:2]) # model.component + if component not in all_blocks: + all_blocks.append(component) + + if not all_blocks: + raise AttributeError("Could not find model layers or components") + except Exception as e: all_blocks = None - print( - "Default model structure not supported. Make sure you feed device as dictionary as {name_block: device}" - ) + print(f"Default model structure not supported: {e}. Make sure you feed device as dictionary.") if isinstance( device, dict @@ -520,20 +551,61 @@ def _patch_linear(linear_layer, quant_config): # print(linear_layer.name) # the layer's name if quant_config is not None: - if 'awq' in quant_config['weight_quant_params']['method']: - layer_activations = activations.get(linear_layer.name, None) + # Check if this layer should be skipped based on user-provided skip_tensors + if skip_tensors is not None: + if any(skip_pattern in linear_layer.name for skip_pattern in skip_tensors): + if verbose: + print(f"[SKIP QUANT] Layer {linear_layer.name}: Explicitly skipped via skip_tensors list") + out_module = linear_layer.to(device=current_device, dtype=compute_dtype) + else: + try: + if 'awq' in quant_config['weight_quant_params']['method']: + layer_activations = activations.get(linear_layer.name, None) + else: + layer_activations = None + out_module = SINQLinear( + linear_layer, + quant_config, + compute_dtype=compute_dtype, + device=current_device, + use_unpack_kernel = use_unpack_kernel, + layer_activations = layer_activations, + custom_model_class = current_custom_model_class + ) + except ValueError as e: + # Check if this is our special skip signal 
+ if "QUANT_SKIP_LAYERS:" in str(e): + if verbose: + print(f"[SKIP QUANT] Layer {linear_layer.name}: {str(e).split('QUANT_SKIP_LAYERS: ')[1]}") + out_module = linear_layer.to(device=current_device, dtype=compute_dtype) + else: + # Re-raise other ValueError exceptions + raise else: - layer_activations = None - out_module = SINQLinear( - linear_layer, - quant_config, - compute_dtype=compute_dtype, - device=current_device, - use_unpack_kernel = use_unpack_kernel, - layer_activations = layer_activations - ) - else: - out_module = linear_layer.to(device=current_device, dtype=compute_dtype) + # No skip list provided, quantize all layers + try: + if 'awq' in quant_config['weight_quant_params']['method']: + layer_activations = activations.get(linear_layer.name, None) + else: + layer_activations = None + out_module = SINQLinear( + linear_layer, + quant_config, + compute_dtype=compute_dtype, + device=current_device, + use_unpack_kernel = use_unpack_kernel, + layer_activations = layer_activations, + custom_model_class = current_custom_model_class + ) + except ValueError as e: + # Check if this is our special skip signal + if "QUANT_SKIP_LAYERS:" in str(e): + if verbose: + print(f"[SKIP QUANT] Layer {linear_layer.name}: {str(e).split('QUANT_SKIP_LAYERS: ')[1]}") + out_module = linear_layer.to(device=current_device, dtype=compute_dtype) + else: + # Re-raise other ValueError exceptions + raise out_module.device = current_device return out_module @@ -584,6 +656,11 @@ def serialize_weights(cls, model, verbose: bool = False) -> dict: ignore_keys = cls.get_ignore_layers(model) actually_tied = _detect_tied_leaves(model) + # Debug: print what's being ignored + if verbose: + print(f"Ignoring keys: {ignore_keys}") + print(f"Tied leaves: {actually_tied}") + def _is_leaf(m: nn.Module) -> bool: return len(m._modules) == 0 @@ -641,7 +718,7 @@ def save_quantized(cls, model, tokenizer, save_dir: str, verbose: bool = False, cls.save_weights(weights, save_dir) @classmethod - def 
save_quantized_safetensors(cls, model, tokenizer, save_dir: str, filename: str = "model.safetensors", verbose: bool = False, max_shard_size="4GB", write_tokenizer: bool = True): + def save_quantized_safetensors(cls, model, tokenizer, save_dir: str, filename: str = "model.safetensors", verbose: bool = False, max_shard_size="4GB", write_tokenizer: bool = True, custom_model_class=None): """ Sharded-only: writes multiple *.safetensors shards + a HF-style index file. Non-tensor meta goes to 'model.safetensors.index.json.meta.json'. @@ -670,6 +747,7 @@ def from_quantized( compute_dtype: torch.dtype = float16, device="cuda", cache_dir: Union[str, None] = "", + custom_model_class=None, # Allow custom model class for non-standard architectures **kwargs, ): # Local folder only for now (Hub comes next) @@ -680,7 +758,7 @@ def from_quantized( save_dir = save_dir_or_hub # Recreate empty model from config (meta tensors) - model = cls.create_model(save_dir, kwargs) + model = cls.create_model(save_dir, kwargs, custom_model_class) model.save_dir = save_dir cls.setup_model(model) @@ -873,10 +951,14 @@ def _extract_meta(leaf: str, meta_obj, prefix: str = ""): raise RuntimeError("Sharding produced zero shards.") # Write shards + build index json (HF-style) + metadata = {"total_size": int(sum(shard_sizes))} + + # Add custom model class information if provided + # if custom_model_class is not None: + # metadata["custom_model_class"] = f"{custom_model_class.__module__}.{custom_model_class.__qualname__}" + index = { - "metadata": { - "total_size": int(sum(shard_sizes)), - }, + "metadata": metadata, "weight_map": {} # tensor_key -> shard filename } @@ -898,6 +980,47 @@ def _extract_meta(leaf: str, meta_obj, prefix: str = ""): with open(sidecar_path, "w", encoding="utf-8") as f: json.dump(sidecar, f) + @classmethod + def _match_weights_to_model(cls, weights: dict, model) -> dict: + """ + Match saved weight keys to model module names, handling prefix mismatches. 
+ """ + # Get all parameterized leaf names from the model + model_leaves = {} + for name, module in model.named_modules(): + if len(module._modules) == 0: + has_params = any(True for _ in module.parameters(recurse=False)) or \ + any(b is not None for b in module.buffers(recurse=False)) + if has_params: + model_leaves[name] = module + + # If all model leaves are in weights, return as-is + if all(name in weights for name in model_leaves.keys()): + return weights + + # Try to match with/without 'model.' prefix + matched_weights = {} + for model_name in model_leaves.keys(): + # Try exact match first + if model_name in weights: + matched_weights[model_name] = weights[model_name] + continue + + # Try adding 'model.' prefix + if f"model.{model_name}" in weights: + matched_weights[model_name] = weights[f"model.{model_name}"] + continue + + # Try removing 'model.' prefix from model_name + if model_name.startswith("model."): + stripped = model_name[6:] # Remove 'model.' + if stripped in weights: + matched_weights[model_name] = weights[stripped] + continue + + return matched_weights + + @classmethod def from_quantized_safetensors( cls, @@ -911,6 +1034,7 @@ def from_quantized_safetensors( local_files_only: bool = False, token: Union[str, bool, None] = None, # bool True -> use cached auth allow_patterns: Union[list, None] = None, + custom_model_class=None, # Allow custom model class for non-standard architectures **kwargs, ): """ @@ -961,12 +1085,15 @@ def from_quantized_safetensors( local_dir_use_symlinks=True, ) - model = cls.create_model(save_dir, kwargs) + model = cls.create_model(save_dir, kwargs, custom_model_class) model.save_dir = save_dir cls.setup_model(model) weights = cls.load_weights_safetensors(save_dir, map_location=device, filename=filename) - + + # Match weights to model structure (handles prefix mismatches) + weights = cls._match_weights_to_model(weights, model) + DYNAMIC_TIED = _detect_tied_leaves(model) # ---- Preflight (same as from_quantized) ---- @@ 
-1154,30 +1281,84 @@ def _maybe_int(x, default=None): # Create empty model from config @classmethod - def create_model(cls, save_dir, kwargs): + def create_model(cls, save_dir, kwargs, custom_model_class=None): model_kwargs = {} for key in ["attn_implementation"]: if key in kwargs: model_kwargs[key] = kwargs[key] config = transformers.AutoConfig.from_pretrained(save_dir) - - auto_class = transformers.AutoModel + + # Use custom model class if provided, otherwise use AutoModel detection + if custom_model_class is not None: + auto_class = custom_model_class + else: + auto_class = transformers.AutoModel # Todo: add support for other auto models archs = config.architectures if len(archs) == 1: - if ("CausalLM" in archs[0]): + if ("CausalLM" in archs[0]) and (custom_model_class is None): auto_class = transformers.AutoModelForCausalLM - elif ("SequenceClassification" in archs[0]): + elif ("SequenceClassification" in archs[0]) and (custom_model_class is None): auto_class = transformers.AutoModelForSequenceClassification + if custom_model_class is not None: + # For custom model classes, use only init_empty_weights to avoid loading full model + model = None + + # Try to create empty model using init_empty_weights context + try: + with init_empty_weights(): + # Force all parameters to be on meta device during initialization + original_no_init = transformers.modeling_utils.no_init_weights + transformers.modeling_utils.no_init_weights = True - with init_empty_weights(): - model = auto_class.from_config(config, **model_kwargs) + # Try different initialization methods + if hasattr(custom_model_class, 'from_config'): + model = custom_model_class.from_config(config, **model_kwargs) + else: + # Create model instance directly + model = custom_model_class(config) + + # Restore original setting + transformers.modeling_utils.no_init_weights = original_no_init + + except Exception as e: + print(f"init_empty_weights failed: {e}") + # Fallback: create model but immediately move all weights 
to meta + try: + if hasattr(custom_model_class, 'from_config'): + model = custom_model_class.from_config(config, **model_kwargs) + else: + model = custom_model_class(config) + + # Move all parameters and buffers to meta device immediately + for name, param in model.named_parameters(): + if param is not None: + param.data = torch.empty_like(param.data, device="meta") + + for name, buffer in model.named_buffers(): + if buffer is not None: + buffer.data = torch.empty_like(buffer.data, device="meta") + + # Force garbage collection and cache clearing + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + except Exception as e2: + print(f"Fallback initialization also failed: {e2}") + raise RuntimeError(f"Could not create empty model: {e}, {e2}") + else: + with init_empty_weights(): + model = auto_class.from_config(config, **model_kwargs) return model # Auto class used for HF models if no architecture was manually setup class AutoSINQHFModel(BaseSINQHFModel, BasePatch): - pass \ No newline at end of file + + pass + diff --git a/sinq/quantizer.py b/sinq/quantizer.py index 5d93d99..a8bf199 100644 --- a/sinq/quantizer.py +++ b/sinq/quantizer.py @@ -183,7 +183,7 @@ def quantize( zero = torch.round(zero) # Use SINQ on weights - W_q, scale, zero, scale2, awq_scale = Quantizer.optimize_weights( + opt_result = Quantizer.optimize_weights( tensor=W, layer_activations=layer_activations, scale=scale, @@ -196,6 +196,14 @@ def quantize( method=method ) + # Check if optimization returned None (meaning quantization should be skipped) + if opt_result is None: + # Return a signal that this layer should not be quantized + # This will be handled by the quantize_model function + raise ValueError(f"QUANT_SKIP_LAYERS: Layer has incompatible dimensions for quantization with tiling_mode={tiling_mode}") + + W_q, scale, zero, scale2, awq_scale = opt_result + if 'quantAux' in method: scale = rtn8(scale) zero = rtn8(zero, tile=torch.numel(zero)) if not (zero is 
None) else zero @@ -338,4 +346,4 @@ def dequantize(cls, W_q: Tensor, meta: dict, use_unpack_kernel: bool = False) -> if torch.any(torch.isnan(W_r)): raise RuntimeError("NaN detected in dequantized weights") - return W_r.to(compute_dtype) \ No newline at end of file + return W_r.to(compute_dtype) diff --git a/sinq/sinqlinear.py b/sinq/sinqlinear.py index e3a4d84..f84debc 100644 --- a/sinq/sinqlinear.py +++ b/sinq/sinqlinear.py @@ -32,9 +32,13 @@ def __init__( ): super().__init__() - qc = quant_config['weight_quant_params'] + # Handle case where quant_config is None (loading pre-quantized weights) + if quant_config is None: + qc = None + else: + qc = quant_config['weight_quant_params'] - if ('nogemlite' not in qc['method'].lower()) and qc['nbits'] == 4 and qc['tiling_mode'] == '1D' and has_gemlite: + if qc is not None and ('nogemlite' not in qc['method'].lower()) and qc['nbits'] == 4 and qc['tiling_mode'] == '1D' and has_gemlite: self.use_gemlite = True else: self.use_gemlite = False @@ -381,3 +385,4 @@ def sinq_base_quant_config( # Alias: follow similar Auto-GPTQ naming BaseQuantizeConfig = sinq_base_quant_config +