-
Notifications
You must be signed in to change notification settings - Fork 130
Open
Description
import OpenAttack
import numpy as np
import datasets
import nltk
from nltk.sentiment.vader import SentimentIntensityAnalyzer
class MyClassifier(OpenAttack.Classifier):
    """Binary victim classifier backed by NLTK's VADER sentiment analyzer.

    Column 0 of the probability output is derived from VADER's ``neg``
    score and column 1 from its ``pos`` score.
    """

    # NLTK resource checked before downloading, so an already-installed
    # lexicon is reused without touching the network.
    _VADER_RESOURCE = "sentiment/vader_lexicon.zip"

    def __init__(self):
        try:
            nltk.data.find(self._VADER_RESOURCE)
        except LookupError:
            # Only download when the lexicon is missing locally. An
            # unconditional download needs a working HTTPS/TLS connection on
            # every run and fails (e.g. with an expired-certificate error)
            # even when the data is already cached.
            nltk.download("vader_lexicon")
        self.model = SentimentIntensityAnalyzer()

    def get_pred(self, input_):
        """Return the index of the most probable class for each sentence.

        :param input_: iterable of sentences (strings).
        :returns: 1-D integer array with one class index (0 or 1) per sentence.
        """
        return self.get_prob(input_).argmax(axis=1)

    def get_prob(self, input_):
        """Return per-sentence class probabilities.

        :param input_: iterable of sentences (strings).
        :returns: float array of shape ``(len(input_), 2)`` whose rows are
            ``[1 - p, p]``, where ``p = pos / (neg + pos)`` from VADER.
        """
        rows = []
        for sent in input_:
            scores = self.model.polarity_scores(sent)
            # The tiny epsilons keep the ratio defined when VADER reports
            # neither positive nor negative mass (pos == neg == 0 -> p = 0.5).
            prob = (scores["pos"] + 1e-6) / (scores["neg"] + scores["pos"] + 2e-6)
            rows.append([1 - prob, prob])
        # reshape keeps the result 2-D even for empty input, so that
        # get_pred's argmax(axis=1) does not raise on an empty batch.
        return np.array(rows, dtype=float).reshape(-1, 2)
def main():
def dataset_mapping(x):
return {
"x": x["sentence"],
"y": 1 if x["label"] > 0.5 else 0,
}
dataset = datasets.load_dataset("sst", split="train[:20]").map(function=dataset_mapping)
victim = MyClassifier()
attacker = OpenAttack.attackers.PWWSAttacker()
attack_eval = OpenAttack.AttackEval(attacker, victim)
eval = attack_eval.eval(dataset, visualize=True)
print(eval)
# Run the attack only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
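Running the script above fails with the following TLS certificate verification error while downloading a remote resource: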
SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels