
Added a RemoteModelWrapper class #809

Open · l3ra wants to merge 1 commit into master
Conversation

@l3ra commented Jan 8, 2025

What does this PR do?

Adds a remote model wrapper class based on the TextFooler example.

Summary

This PR adapts the TextFooler example (https://textattack.readthedocs.io/en/master/_modules/textattack/attack_recipes/textfooler_jin_2019.html) to instead target BERT deployed on a remote endpoint (https://x.com/predict). The remotely deployed model is expected to return confidence scores in the same format as the original example (positive and negative scores).

If helpful, I can share the remote model deployment files for BERT. This work is part of a paper on adversarial attacks against remotely deployed models.

Additions

  • Added textattack/models/wrappers/remote_model_wrapper.py

Changes

  • A new model wrapper class in the wrappers directory

Deletions

None

Checklist

  • [ ] The title of your pull request should be a summary of its contribution.
  • [ ] Please write a detailed description of what parts have been newly added and what parts have been modified. Please also explain why certain changes were made.
  • [ ] If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it).
  • [ ] To indicate a work in progress, please mark it as a draft on GitHub.
  • [ ] Make sure existing tests pass.
  • [ ] Add relevant tests. No quality testing = no merge.
  • [ ] All public methods must have informative docstrings that work nicely with sphinx. For new modules/files, please add/modify the appropriate .rst file in TextAttack/docs/apidoc.

@l3ra (Author) commented Jan 15, 2025

@bterrific2008 could you review this please?

@bterrific2008 (Contributor) left a comment

I don't believe I am a maintainer for this package, so I don't believe I can approve or reject your pull request in good faith. However, I don't think the TextAttack package needs a generic RemoteModelWrapper class, for the following reasons:

  • Different remote models will have drastically different inputs and outputs. It might be more advisable for folks to follow your example here and create their own ModelWrapper class as needed.
  • I'm not sure this package should be overtly encouraging people to launch adversarial attacks on remote models. What you do on your own time is fine, but I'm not confident the QData lab is looking to take on that kind of overhead.

That being said, I've left some feedback for the PR. Overall this looks fine, with a few standout issues:

  • Run make format and make lint, as set out in the contributing guidelines.
  • You could better support differing API inputs and outputs by accepting a lambda function that massages the data to fit the desired payload structure.
  • Inherit from the ModelWrapper abstract class.
  • Why do you return the output as a tensor?

Aside from the big question (is it realistic for TextAttack to have a generic RemoteModelWrapper class?) and the need to run the formatters and linters, everything else looks fine; a rough sketch of what these suggestions could look like follows below. Hopefully someone from the QData group can leave something more substantial.
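
To make these suggestions concrete, here is a rough, untested sketch of what a more generic wrapper could look like. The constructor arguments format_payload and parse_response (and their default lambdas) are hypothetical names invented for illustration, not part of this PR; the only TextAttack API assumed is the ModelWrapper base class exported from textattack.models.wrappers. The sketch also posts a JSON body rather than query parameters, which may or may not match the actual endpoint.

import numpy as np
import requests

from textattack.models.wrappers import ModelWrapper


class RemoteModelWrapper(ModelWrapper):
    """Queries a remote classification endpoint with a list of text inputs.

    Args:
        api_url: URL of the remote prediction endpoint.
        format_payload: callable turning one text string into the JSON body the
            endpoint expects (hypothetical parameter, for illustration only).
        parse_response: callable turning the decoded JSON reply into a list of
            class scores (hypothetical parameter, for illustration only).
    """

    def __init__(self, api_url, format_payload=None, parse_response=None):
        self.api_url = api_url
        # Defaults mirror the PR's assumptions: {"text": ...} in, "negative"/"positive" keys out.
        self.format_payload = format_payload or (lambda text: {"text": text})
        self.parse_response = parse_response or (
            lambda result: [result["negative"], result["positive"]]
        )

    def __call__(self, text_input_list):
        predictions = []
        for text in text_input_list:
            # Assumes a JSON API; swap json= for params= if the endpoint expects query parameters.
            response = requests.post(
                self.api_url, json=self.format_payload(text), timeout=10
            )
            if response.status_code != 200:
                raise ValueError(
                    f"API call failed with status {response.status_code}: {response.text}"
                )
            predictions.append(self.parse_response(response.json()))
        # A plain numpy array (as e.g. the sklearn wrapper returns) avoids the tensor cast.
        return np.array(predictions)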

Comment on lines +7 to +10
import requests
import torch
import numpy as np
import transformers

bterrific2008 (Contributor):

NIT: run make format or the black formatter, as mentioned in the contribution guidelines for this repo

"""
def __init__(self, api_url):
self.api_url = api_url
self.model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")

bterrific2008 (Contributor):

Why do you set this? The model variable isn't used elsewhere in this class

import numpy as np
import transformers

class RemoteModelWrapper():

bterrific2008 (Contributor):

This should inherit from the ModelWrapper abstract class if you're looking to use it as a ModelWrapper.

Comment on lines +13 to +18
"""This model wrapper queries a remote model with a list of text inputs.

It sends the input to a remote endpoint provided in api_url.


"""

bterrific2008 (Contributor):

You should specify the args for your class here, following repo convention

Suggested change:

"""This model wrapper queries a remote model with a list of text inputs. It sends the input to a remote endpoint provided in api_url.

Args:
    api_url (:obj:`<TYPE HERE>`): <DESCRIPTION HERE>
"""

        for text in text_input_list:
            params = dict()
            params["text"] = text
            response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload

bterrific2008 (Contributor):

Is this kind of request format guaranteed to work for all endpoints?

For example, OpenAI requires a specific kind of payload. I might suggest accepting a lambda as a parameter when you initialize the wrapper, so callers can massage the data into a viable payload format.
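
For instance, assuming the sketched constructor above, an endpoint with a different payload shape (say one expecting {"inputs": [...]}, a made-up format) could be handled by the caller without changing the wrapper:

# Hypothetical: adapt the request body for an endpoint that expects {"inputs": [...]}.
model_wrapper = RemoteModelWrapper(
    "https://x.com/predict",
    format_payload=lambda text: {"inputs": [text]},
)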

params["text"] = text
response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload
if response.status_code != 200:
print(f"Response content: {response.text}")

bterrific2008 (Contributor):

I recommend using the package's logger instead of print statements, especially when you mean to raise an error. This print statement might not even be necessary, since you raise an error just below anyway.
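
A minimal sketch of that suggestion, assuming TextAttack's shared logger is importable as textattack.shared.logger (a plain logging.getLogger would serve equally well):

import requests

from textattack.shared import logger  # assumed import path for TextAttack's shared logger

response = requests.post("https://x.com/predict", params={"text": "sample input"}, timeout=10)
if response.status_code != 200:
    # Surface the failing body through the logger rather than print(), then raise as before.
    logger.warning(f"Remote model call failed ({response.status_code}): {response.text}")
    raise ValueError(f"API call failed with status {response.status_code}")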

raise ValueError(f"API call failed with status {response.status_code}")
result = response.json()
# Assuming the API returns probabilities for positive and negative
predictions.append([result["negative"], result["positive"]])

bterrific2008 (Contributor):

I see you're making these assumptions, but I'm not sure this response format is common enough for this to be a widely applicable wrapper. To alleviate this, you could accept another lambda to massage the output.
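
Continuing the sketch above, a caller whose endpoint returns a differently shaped reply (say {"scores": {"neg": ..., "pos": ...}}, a made-up example) could supply their own parse_response instead of relying on the hard-coded keys:

# Hypothetical: map a differently shaped JSON reply onto [negative, positive] scores.
model_wrapper = RemoteModelWrapper(
    "https://x.com/predict",
    parse_response=lambda result: [result["scores"]["neg"], result["scores"]["pos"]],
)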

            result = response.json()
            # Assuming the API returns probabilities for positive and negative
            predictions.append([result["negative"], result["positive"]])
        return torch.tensor(predictions)

bterrific2008 (Contributor):

Any reason to cast this as a tensor?

Comment on lines +37 to +63
'''
Example usage:

# Define the remote model API endpoint and tokenizer
api_url = "https://x.com/predict"

model_wrapper = RemoteModelWrapper(api_url)

# Build the attack
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)

# Define dataset and attack arguments
dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

attack_args = textattack.AttackArgs(
    num_examples=100,
    log_to_csv="/textfooler.csv",
    checkpoint_interval=5,
    checkpoint_dir="checkpoints",
    disable_stdout=True
)

# Run the attack
attacker = textattack.Attacker(attack, dataset, attack_args)
attacker.attack_dataset()

'''

bterrific2008 (Contributor):

I would move this to be part of the RemoteModelWrapper docstring and change some of the formatting to more closely align with other parts of the package

Suggested change:
"""
Example usage:
>>> # Define the remote model API endpoint
>>> api_url = "https://example.com"
>>> model_wrapper = RemoteModelWrapper(api_url)
>>> # Build the attack
>>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
>>> # Define dataset and attack arguments
>>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")
>>> attack_args = textattack.AttackArgs(
... num_examples=100,
... log_to_csv="/textfooler.csv",
... checkpoint_interval=5,
... checkpoint_dir="checkpoints",
... disable_stdout=True
... )
>>> # Run the attack
>>> attacker = textattack.Attacker(attack, dataset, attack_args)
>>> attacker.attack_dataset()
"""
