Skip to content

Commit 179fd90

Browse files
[Fix] Allow to create SparseAutoModelForCausalLM with trust_remote_code=True (#2349)
* initial commit * better comments --------- Co-authored-by: [email protected] <[email protected]>
1 parent f0a3692 commit 179fd90

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

Diff for: src/sparseml/transformers/sparsification/sparse_model.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,29 @@ def skip(*args, **kwargs):
101101
)
102102

103103
# instantiate compressor from model config
104-
compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
104+
compressor = ModelCompressor.from_pretrained(
105+
pretrained_model_name_or_path, **kwargs
106+
)
105107

106108
# temporarily set the log level to error, to ignore printing out long missing
107109
# and unexpected key error messages (these are EXPECTED for quantized models)
108110
logger = logging.getLogger("transformers.modeling_utils")
109111
restore_log_level = logger.getEffectiveLevel()
110112
logger.setLevel(level=logging.ERROR)
113+
114+
if kwargs.get("trust_remote_code"):
115+
# By artifically aliasing
116+
# class name SparseAutoModelForCausallLM to
117+
# AutoModelForCausalLM we can "trick" the
118+
# `from_pretrained` method into properly
119+
# resolving the logic when
120+
# (has_remote_code and trust_remote_code) == True
121+
cls.__name__ = AutoModelForCausalLM.__name__
122+
111123
model = super(AutoModelForCausalLM, cls).from_pretrained(
112124
pretrained_model_name_or_path, *model_args, **kwargs
113125
)
126+
114127
if model.dtype != model.config.torch_dtype:
115128
_LOGGER.warning(
116129
f"The dtype of the loaded model: {model.dtype} is different "

0 commit comments

Comments
 (0)