
Use HF safetensors by default #1046

Closed
wants to merge 8 commits into from
10 changes: 6 additions & 4 deletions garak/generators/huggingface.py
@@ -168,8 +168,8 @@ def _load_client(self):
if "use_fp8" in _config.plugins.generators.OptimumPipeline:
self.use_fp8 = True

-pipline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
-self.generator = pipeline("text-generation", **pipline_kwargs)
+pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
+self.generator = pipeline("text-generation", **pipeline_kwargs)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
if _config.loaded:
@@ -196,8 +196,8 @@ def _load_client(self):

# Note that with pipeline, in order to access the tokenizer, model, or device, you must get the attribute
# directly from self.generator instead of from the ConversationalPipeline object itself.
-pipline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
-self.generator = pipeline("conversational", **pipline_kwargs)
+pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
+self.generator = pipeline("conversational", **pipeline_kwargs)
self.conversation = Conversation()
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
@@ -443,6 +443,8 @@ def _load_client(self):
if _config.run.seed is not None:
transformers.set_seed(_config.run.seed)

+print(dir(_config.system))

Comment on lines +446 to +447

Collaborator: Errant debugging statement?

Collaborator Author: extremely
model_kwargs = self._gather_hf_params(
hf_constructor=transformers.AutoConfig.from_pretrained
) # will defer to device_map if device map was `auto` may not match self.device
3 changes: 3 additions & 0 deletions garak/resources/api/huggingface.py
@@ -85,6 +85,9 @@ def _gather_hf_params(self, hf_constructor: Callable):
):
args["trust_remote_code"] = False

+if "use_safetensors" not in params:
+    args["use_safetensors"] = True

Comment on lines +88 to +90

Collaborator:
A default not supplied by DEFAULT_PARAMS['hf_args'] should only be injected if the hf_constructor is expected to support or pass through the option:

The following methods are currently passed as hf_constructor:

transformers.pipeline
transformers.AutoConfig.from_pretrained
transformers.AutoModelForSeq2SeqLM.from_pretrained
transformers.LlavaNextForConditionalGeneration.from_pretrained

Unfortunately, all the above methods mask use_safetensors inside an optional **kwargs parameter.

This gathering pattern may not work for this option.

Collaborator Author:
Yeah, that's exactly where I ended up with this also. Inspection doesn't immediately reveal the parameter, and support is not universal. There are lists of which models support this, but I prefer something that responds to the code at hand. More research needed.
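The inspection problem both reviewers describe can be illustrated with a minimal sketch. The `from_pretrained` below is a hypothetical stand-in, not garak or transformers code; it mimics the signature shape of constructors like `transformers.pipeline` or `AutoConfig.from_pretrained`, where `use_safetensors` is swallowed by a `**kwargs` catch-all and so never appears as a named parameter:

```python
import inspect

# Hypothetical stand-in for HF constructors that accept use_safetensors
# only via an opaque **kwargs catch-all.
def from_pretrained(name, *, revision=None, **kwargs):
    return kwargs

params = inspect.signature(from_pretrained).parameters

# Signature inspection cannot confirm support for use_safetensors,
# because it is not a named parameter...
assert "use_safetensors" not in params

# ...all that is detectable is the presence of a VAR_KEYWORD catch-all,
# which says nothing about which keywords the callee actually honors.
has_var_kwargs = any(
    p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
)
print(has_var_kwargs)  # True
```

This is why "responding to the code at hand" is hard here: a `**kwargs` parameter advertises nothing about the specific options it accepts.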

return args

def _select_hf_device(self):
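In isolation, the conditional-default pattern this PR adds to `_gather_hf_params` can be sketched as follows (a hypothetical standalone `gather_hf_params`, not the garak method, which also handles trust_remote_code and other options):

```python
def gather_hf_params(params: dict) -> dict:
    """Hypothetical sketch of the conditional-default pattern: copy the
    caller-supplied params, then inject use_safetensors=True only when
    the caller has not set it themselves."""
    args = dict(params)
    if "use_safetensors" not in params:
        args["use_safetensors"] = True
    return args

print(gather_hf_params({}))  # {'use_safetensors': True}
print(gather_hf_params({"use_safetensors": False}))  # caller's choice wins
```

The default is injected only when the key is absent, so an explicit `use_safetensors: False` from user config still takes precedence.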