@@ -101,16 +101,29 @@ def skip(*args, **kwargs):
         )
 
         # instantiate compressor from model config
-        compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
+        compressor = ModelCompressor.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
 
         # temporarily set the log level to error, to ignore printing out long missing
         # and unexpected key error messages (these are EXPECTED for quantized models)
         logger = logging.getLogger("transformers.modeling_utils")
         restore_log_level = logger.getEffectiveLevel()
         logger.setLevel(level=logging.ERROR)
+
+        if kwargs.get("trust_remote_code"):
+            # By artificially aliasing the
+            # class name SparseAutoModelForCausalLM to
+            # AutoModelForCausalLM we can "trick" the
+            # `from_pretrained` method into properly
+            # resolving the logic when
+            # (has_remote_code and trust_remote_code) == True
+            cls.__name__ = AutoModelForCausalLM.__name__
+
         model = super(AutoModelForCausalLM, cls).from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
+
         if model.dtype != model.config.torch_dtype:
             _LOGGER.warning(
                 f"The dtype of the loaded model: {model.dtype} is different "