-
Notifications
You must be signed in to change notification settings - Fork 405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added a remotemodelwrapper class #809
base: master
Are you sure you want to change the base?
Conversation
@bterrific2008 could you review this please? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't believe I am a maintainer for this package, so I don't believe I can approve or disapprove your package in good faith. However, I don't think the TextAttack package needs a generic RemoteModelWrapper
function for the following reasons:
- different remote models will have drastically different inputs and outputs. It might be more advisable for folks to follow your example here and create their own
ModelWrapper
class as needed - not sure if this package should be overtly encouraging people to launch adversarial attacks on remote models. What you do on your own time is fine, but I'm not confident if the QData lab is looking to handle that kind of overhead
That being said, I've left some feedback for the PR. Overall this looks fine with a few standout issues:
- run
make format
andmake lint
, as set out in the contributing guidelines - you could better support differing API inputs and outputs by accepting a lambda function that can massage the data to fit the desired payload structure
- inherit from the ModelWrapper abstract class
- why do you return the output as a tensor?
Minus the big issue (is it realistic for TextAttack to have a generic RemoteModelWrapper class) and needing to run the formatters and linters, everything else looks fine. Hopefully someone from the QData group can leave something more substantial
import requests | ||
import torch | ||
import numpy as np | ||
import transformers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: run make format
or the black
formatter, as mentioned in the contribution guidelines for this repo
""" | ||
def __init__(self, api_url): | ||
self.api_url = api_url | ||
self.model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you set this? The model variable isn't used elsewhere in this class
import numpy as np | ||
import transformers | ||
|
||
class RemoteModelWrapper(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should inherit from the ModelWrapper abstract class if you're looking to be using this as a ModelWrapper
"""This model wrapper queries a remote model with a list of text inputs. | ||
|
||
It sends the input to a remote endpoint provided in api_url. | ||
|
||
|
||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should specify the args for your class here, following repo convention
"""This model wrapper queries a remote model with a list of text inputs. | |
It sends the input to a remote endpoint provided in api_url. | |
""" | |
"""This model wrapper queries a remote model with a list of text inputs. It sends the input to a remote endpoint provided in api_url. | |
Args: | |
api_url (:obj:`<TYPE HERE>`): <DESCRIPTION HERE> | |
""" |
for text in text_input_list: | ||
params = dict() | ||
params["text"] = text | ||
response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this kind of request format guaranteed to work for all endpoints?
For example, OpenAI requires a specific kind of payload. I might suggest adding a parameter when you initialize the wrapper you accept a lambda as a param to massage the data into a viable payload format
params["text"] = text | ||
response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload | ||
if response.status_code != 200: | ||
print(f"Response content: {response.text}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recommend using the package's logger instead of making print statements, especially when you mean to throw an error. This print statement might not even be necessary since you throw the error below anyways
raise ValueError(f"API call failed with status {response.status_code}") | ||
result = response.json() | ||
# Assuming the API returns probabilities for positive and negative | ||
predictions.append([result["negative"], result["positive"]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see you're making these assumptions, but I'm not sure if this is so common as to be widely applicable to make this a good wrapper function. To alleviate this, you could add another lambda to massage the output
result = response.json() | ||
# Assuming the API returns probabilities for positive and negative | ||
predictions.append([result["negative"], result["positive"]]) | ||
return torch.tensor(predictions) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason to cast this as a tensor?
''' | ||
Example usage: | ||
|
||
# Define the remote model API endpoint and tokenizer | ||
api_url = "https://x.com/predict" | ||
|
||
model_wrapper = RemoteModelWrapper(api_url) | ||
|
||
# Build the attack | ||
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) | ||
|
||
# Define dataset and attack arguments | ||
dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") | ||
|
||
attack_args = textattack.AttackArgs( | ||
num_examples=100, | ||
log_to_csv="/textfooler.csv", | ||
checkpoint_interval=5, | ||
checkpoint_dir="checkpoints", | ||
disable_stdout=True | ||
) | ||
|
||
# Run the attack | ||
attacker = textattack.Attacker(attack, dataset, attack_args) | ||
attacker.attack_dataset() | ||
|
||
''' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move this to be part of the RemoteModelWrapper docstring and change some of the formatting to more closely align with other parts of the package
''' | |
Example usage: | |
# Define the remote model API endpoint and tokenizer | |
api_url = "https://x.com/predict" | |
model_wrapper = RemoteModelWrapper(api_url) | |
# Build the attack | |
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) | |
# Define dataset and attack arguments | |
dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") | |
attack_args = textattack.AttackArgs( | |
num_examples=100, | |
log_to_csv="/textfooler.csv", | |
checkpoint_interval=5, | |
checkpoint_dir="checkpoints", | |
disable_stdout=True | |
) | |
# Run the attack | |
attacker = textattack.Attacker(attack, dataset, attack_args) | |
attacker.attack_dataset() | |
''' | |
""" | |
Example usage: | |
>>> # Define the remote model API endpoint | |
>>> api_url = "https://example.com" | |
>>> model_wrapper = RemoteModelWrapper(api_url) | |
>>> # Build the attack | |
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) | |
>>> # Define dataset and attack arguments | |
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") | |
>>> attack_args = textattack.AttackArgs( | |
... num_examples=100, | |
... log_to_csv="/textfooler.csv", | |
... checkpoint_interval=5, | |
... checkpoint_dir="checkpoints", | |
... disable_stdout=True | |
... ) | |
>>> # Run the attack | |
>>> attacker = textattack.Attacker(attack, dataset, attack_args) | |
>>> attacker.attack_dataset() | |
""" |
What does this PR do?
Adds a remote model wrapper class based on the TextFooler example.
Summary
This PR adapts the Textfooler example https://textattack.readthedocs.io/en/master/_modules/textattack/attack_recipes/textfooler_jin_2019.html to instead target BERT deployed on a remote endpoint (https://x.com/predict). The remotely deployed model should bring back confidence scores in the same format as the original example (positive, negative scores).
If helpful, can share the remote model deployment files for BERT. This is part of a paper on adversarial attacks against remotely deployed models.
Additions
Changes
Deletions
None
Checklist
.rst
file inTextAttack/docs/apidoc
.'