use pipeline explicit truncation, use tokenizer decode in Model
* When using a pipeline with `chat` support, defer to the library so that
  responses come back in the canonical chat list-of-dict form (see the sketch
  below).

* Model responses may carry the chat-template-expanded prompt as a prefix;
  decode via the tokenizer so the template-mutated prompt can be removed.
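For reference, the canonical chat form mentioned above is the list-of-dict message format used by `transformers` chat templating; a minimal sketch (roles and text are illustrative placeholders):

```python
# Canonical chat format: a list of {"role": ..., "content": ...} dicts.
chat_prompt = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarise this report in one sentence."},
]
```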
jmartin-tech committed Dec 17, 2024
1 parent 9ec5794 commit 99c760a
Showing 1 changed file with 30 additions and 25 deletions.
garak/generators/huggingface.py: 30 additions & 25 deletions
@@ -79,6 +79,9 @@ def _load_client(self):
set_seed(_config.run.seed)

pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
pipeline_kwargs["truncation"] = (
True # this is forced to maintain existing pipeline expectations
)
self.generator = pipeline("text-generation", **pipeline_kwargs)
if self.generator.tokenizer is None:
# account for possible model without a stored tokenizer
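The forced truncation above amounts to constructing the pipeline with `truncation=True` so over-long prompts are cut to the model's maximum input length. A minimal sketch, assuming an illustrative model (in garak the kwargs are actually assembled by `_gather_hf_params`):

```python
from transformers import pipeline

# "gpt2" is a placeholder; garak passes the configured model via pipeline_kwargs.
generator = pipeline(
    "text-generation",
    model="gpt2",
    truncation=True,  # forced so long prompts are truncated rather than rejected
)
print(generator("A very long prompt ...", max_new_tokens=20)[0]["generated_text"])
```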
@@ -113,27 +116,16 @@ def _call_model(
warnings.simplefilter("ignore", category=UserWarning)
try:
with torch.no_grad():
# workaround for pipeline to truncate the input

# according to docs https://huggingface.co/docs/transformers/main/en/chat_templating
# chat template should be automatically utilized if the pipeline tokenizer has support
# and a properly formatted list[dict] is supplied
if self.use_chat:
formatted_prompt = self.generator.tokenizer.apply_chat_template(
self._format_chat_prompt(prompt),
tokenize=False,
add_generation_prompt=True,
)
formatted_prompt = self._format_chat_prompt(prompt)
else:
formatted_prompt = prompt

encoded_prompt = self.generator.tokenizer(
formatted_prompt, truncation=True
)
truncated_prompt = self.generator.tokenizer.decode(
encoded_prompt["input_ids"], skip_special_tokens=True
)
raw_output = self.generator(
truncated_prompt,
formatted_prompt,
pad_token_id=self.generator.tokenizer.eos_token_id,
max_new_tokens=self.max_tokens,
num_return_sequences=generations_this_call,
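This hunk relies on the behaviour described in the linked chat templating docs: when the tokenizer carries a chat template, the text-generation pipeline accepts the message list directly, applies the template itself, and returns `generated_text` as the full message list. A hedged sketch of the old and new paths (the model name is an assumed chat-capable checkpoint):

```python
from transformers import pipeline

generator = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")
messages = [{"role": "user", "content": "Say hello."}]

# Old path: render the template by hand and pass a plain string.
prompt_text = generator.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# New path: pass the message list straight to the pipeline.
result = generator(messages, max_new_tokens=20)
reply = result[0]["generated_text"][-1]["content"]  # last message is the model's reply
```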
@@ -148,11 +140,15 @@ def _call_model(
i["generated_text"] for i in raw_output
] # generator returns 10 outputs by default in __init__

if not self.deprefix_prompt and not self.use_chat:
return outputs
if self.use_chat:
text_outputs = [_o[-1]["content"].strip() for _o in outputs]
else:
text_outputs = outputs

if not self.deprefix_prompt:
return text_outputs
else:
# consider using formatted_prompt in removal as a `list` or `str`
return [re.sub("^" + re.escape(formatted_prompt), "", _o) for _o in outputs]
return [re.sub("^" + re.escape(prompt), "", _o) for _o in text_outputs]


class OptimumPipeline(Pipeline, HFCompatible):
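The new output handling in the hunk above reduces to roughly the following; the function and example data here are illustrative, with names mirroring the diff:

```python
import re

def postprocess(outputs, prompt, use_chat=True, deprefix_prompt=True):
    """Keep the assistant text from chat-style outputs and optionally strip the echoed prompt."""
    if use_chat:
        # Each chat-mode output is a full message list; keep the final (assistant) message.
        texts = [o[-1]["content"].strip() for o in outputs]
    else:
        texts = outputs
    if deprefix_prompt:
        texts = [re.sub("^" + re.escape(prompt), "", t) for t in texts]
    return texts

# Fabricated chat-mode output for illustration:
example = [[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hi there"}]]
print(postprocess(example, prompt="hi"))  # prints [' there'] once the echoed "hi" prefix is stripped
```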
@@ -520,7 +516,7 @@ def _call_model(
if self.top_k is not None:
self.generation_config.top_k = self.top_k

text_output = []
raw_text_output = []
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
with torch.no_grad():
@@ -537,6 +533,10 @@
formatted_prompt, truncation=True, return_tensors="pt"
).to(self.device)

prefix_prompt = self.tokenizer.decode(
inputs["input_ids"][0], skip_special_tokens=True
)

try:
outputs = self.model.generate(
**inputs, generation_config=self.generation_config
@@ -549,17 +549,22 @@
return returnval
else:
raise e
text_output = self.tokenizer.batch_decode(
raw_text_output = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, device=self.device
)

if not self.deprefix_prompt and not self.use_chat:
if self.use_chat:
text_output = [
re.sub("^" + re.escape(prefix_prompt), "", i).strip()
for i in raw_text_output
]
else:
text_output = raw_text_output

if not self.deprefix_prompt:
return text_output
else:
# consider using formatted_prompt in removal as a `list` or `str`
return [
re.sub("^" + re.escape(formatted_prompt), "", i) for i in text_output
]
return [re.sub("^" + re.escape(prefix_prompt), "", i) for i in text_output]


class LLaVA(Generator, HFCompatible):
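In the `Model` path, the prompt is now round-tripped through the tokenizer so the prefix that gets stripped is exactly the text the model saw after truncation. A condensed, self-contained sketch with an assumed small checkpoint:

```python
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative placeholder for the configured model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Tell me a story about a robot."
inputs = tokenizer(prompt, truncation=True, return_tensors="pt")

# Decode the (possibly truncated) input ids: this is the exact prefix a
# decoder-only model will echo at the start of its output.
prefix_prompt = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)

with torch.no_grad():
    output_ids = model.generate(
        **inputs, max_new_tokens=20, pad_token_id=tokenizer.eos_token_id
    )

decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
completions = [re.sub("^" + re.escape(prefix_prompt), "", d) for d in decoded]
print(completions)
```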
