From 89a04d54fcd3b667ef0ea5474709ffb86845f081 Mon Sep 17 00:00:00 2001 From: l3ra Date: Mon, 30 Dec 2024 15:37:19 +0000 Subject: [PATCH 1/3] Added a remotemodelwrapper class --- .../models/wrappers/remote_model_wrapper.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 textattack/models/wrappers/remote_model_wrapper.py diff --git a/textattack/models/wrappers/remote_model_wrapper.py b/textattack/models/wrappers/remote_model_wrapper.py new file mode 100644 index 00000000..0a0dc046 --- /dev/null +++ b/textattack/models/wrappers/remote_model_wrapper.py @@ -0,0 +1,63 @@ +""" +RemoteModelWrapper class +-------------------------- + +""" + +import requests +import torch +import numpy as np +import transformers + +class RemoteModelWrapper(): + """This model wrapper queries a remote model with a list of text inputs. + + It sends the input to a remote endpoint provided in api_url. + + + """ + def __init__(self, api_url): + self.api_url = api_url + self.model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb") + + def __call__(self, text_input_list): + predictions = [] + for text in text_input_list: + params = dict() + params["text"] = text + response = requests.post(self.api_url, params=params, timeout=10) # Use POST with JSON payload + if response.status_code != 200: + print(f"Response content: {response.text}") + raise ValueError(f"API call failed with status {response.status_code}") + result = response.json() + # Assuming the API returns probabilities for positive and negative + predictions.append([result["negative"], result["positive"]]) + return torch.tensor(predictions) + +''' +Example usage: + +# Define the remote model API endpoint and tokenizer +api_url = "https://x.com/predict" + +model_wrapper = RemoteModelWrapper(api_url) + +# Build the attack +attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) + +# Define dataset and attack arguments +dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") + +attack_args = textattack.AttackArgs( + num_examples=100, + log_to_csv="/textfooler.csv", + checkpoint_interval=5, + checkpoint_dir="checkpoints", + disable_stdout=True +) + +# Run the attack +attacker = textattack.Attacker(attack, dataset, attack_args) +attacker.attack_dataset() + +''' \ No newline at end of file From 52fe0784d95ead321fc25c9d19632cfafb81d397 Mon Sep 17 00:00:00 2001 From: Lera Leonteva Date: Mon, 20 Jan 2025 09:46:48 +0000 Subject: [PATCH 2/3] Update textattack/models/wrappers/remote_model_wrapper.py Co-authored-by: Bryan Tor --- .../models/wrappers/remote_model_wrapper.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/textattack/models/wrappers/remote_model_wrapper.py b/textattack/models/wrappers/remote_model_wrapper.py index 0a0dc046..ba38dcdd 100644 --- a/textattack/models/wrappers/remote_model_wrapper.py +++ b/textattack/models/wrappers/remote_model_wrapper.py @@ -34,30 +34,29 @@ def __call__(self, text_input_list): predictions.append([result["negative"], result["positive"]]) return torch.tensor(predictions) -''' +""" Example usage: -# Define the remote model API endpoint and tokenizer -api_url = "https://x.com/predict" - -model_wrapper = RemoteModelWrapper(api_url) + >>> # Define the remote model API endpoint + >>> api_url = "https://example.com" -# Build the attack -attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) + >>> model_wrapper = RemoteModelWrapper(api_url) -# Define dataset and attack arguments -dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") + >>> # Build the attack + >>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper) -attack_args = textattack.AttackArgs( - num_examples=100, - log_to_csv="/textfooler.csv", - checkpoint_interval=5, - checkpoint_dir="checkpoints", - disable_stdout=True -) + >>> # Define dataset and attack arguments + >>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test") -# Run the attack -attacker = textattack.Attacker(attack, dataset, attack_args) -attacker.attack_dataset() + >>> attack_args = textattack.AttackArgs( + ... num_examples=100, + ... log_to_csv="/textfooler.csv", + ... checkpoint_interval=5, + ... checkpoint_dir="checkpoints", + ... disable_stdout=True + ... ) -''' \ No newline at end of file + >>> # Run the attack + >>> attacker = textattack.Attacker(attack, dataset, attack_args) + >>> attacker.attack_dataset() +""" \ No newline at end of file From 7ccfdade87952c12d7505e880186c9f8e8ab3097 Mon Sep 17 00:00:00 2001 From: Lera Leonteva Date: Mon, 20 Jan 2025 09:47:43 +0000 Subject: [PATCH 3/3] Update textattack/models/wrappers/remote_model_wrapper.py Co-authored-by: Bryan Tor --- textattack/models/wrappers/remote_model_wrapper.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/textattack/models/wrappers/remote_model_wrapper.py b/textattack/models/wrappers/remote_model_wrapper.py index ba38dcdd..c7c4b977 100644 --- a/textattack/models/wrappers/remote_model_wrapper.py +++ b/textattack/models/wrappers/remote_model_wrapper.py @@ -10,11 +10,10 @@ import transformers class RemoteModelWrapper(): - """This model wrapper queries a remote model with a list of text inputs. - - It sends the input to a remote endpoint provided in api_url. - + """This model wrapper queries a remote model with a list of text inputs. It sends the input to a remote endpoint provided in api_url. + Args: + api_url (:obj:``): """ def __init__(self, api_url): self.api_url = api_url