diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index a42e596ea..b648b21c1 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -223,10 +223,19 @@ class DatasetArguments(CustomDatasetArguments): quantization_aware_calibration: bool = field( default=True, metadata={ - "help": "Whether to enable quantization-aware calibration in the pipeline. " - "When True, quantization is applied during forward pass in calibration. " - "When False, quantization is disabled during forward pass in calibration. " - "Default is set to True." + "help": "Only relevant for the sequential pipeline. " + "If True, quantization is applied during forward pass in calibration. " + "If False, quantization is disabled during forward pass in calibration. " + "Default is True." + }, + ) + propagate_error: bool = field( + default=True, + metadata={ + "help": "Only relevant for the sequential pipeline. If True, use quantized " + "layer outputs as the inputs to the next sequential layer. If False, use " + "unquantized layer outputs as the inputs to the next sequential layer. " + "Default is True" }, ) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index e95ffa915..bffd90c64 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -95,22 +95,27 @@ def __call__( # reduce memory movement by keeping modules onloaded with disable_offloading(): # do a preliminary pass to trigger modifier hooks - for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): - inputs = activations.fetch(batch_idx, subgraph.input_names) - subgraph.forward(model, **inputs) + for b_idx in tqdm(range(len(dataloader)), desc=calib_desc): + inputs = activations.fetch(b_idx, subgraph.input_names) + outputs = subgraph.forward(model, **inputs) - LifecycleCallbacks.sequential_epoch_end(subgraph) + if not dataset_args.propagate_error: + activations.update(b_idx, outputs) + activations.delete(b_idx, subgraph.consumed_names) - # this pass does not trigger modifier hooks - # and is only used for capturing outputs of newly compressed modules - with HooksMixin.disable_hooks(): - for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc): - inputs = activations.fetch(batch_idx, subgraph.input_names) - output = subgraph.forward(model, **inputs) + LifecycleCallbacks.sequential_epoch_end(subgraph) - if subgraph_index < num_subgraphs - 1: - activations.update(batch_idx, output) - activations.delete(batch_idx, subgraph.consumed_names) + if dataset_args.propagate_error: + # this pass does not trigger modifier hooks + # and is only used for capturing outputs of compressed modules + with HooksMixin.disable_hooks(): + for b_idx in tqdm(range(len(dataloader)), desc=prop_desc): + inputs = activations.fetch(b_idx, subgraph.input_names) + outputs = subgraph.forward(model, **inputs) + + if dataset_args.propagate_error: + activations.update(b_idx, outputs) + activations.delete(b_idx, subgraph.consumed_names) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end()