diff --git a/OpenAttack/attackers/uat/__init__.py b/OpenAttack/attackers/uat/__init__.py index 72a4a89..d44a96d 100644 --- a/OpenAttack/attackers/uat/__init__.py +++ b/OpenAttack/attackers/uat/__init__.py @@ -53,8 +53,22 @@ def __init__(self, def set_triggers(self, victim : Classifier, - dataset : datasets.Dataset,): - self.triggers = self.get_triggers(victim, dataset, self.tokenizer) + dataset : datasets.Dataset, + epoch : int = 5, + batch_size : int = 5, + trigger_len : int = 3, + beam_size : int = 5, + lang = None): + + self.triggers = self.get_triggers(victim, + dataset, + self.tokenizer, + epoch=epoch, + batch_size=batch_size, + trigger_len=trigger_len, + beam_size=beam_size, + lang=lang + ) def attack(self, victim: Classifier, sentence : str, goal : ClassifierGoal): trigger_sent = self.tokenizer.detokenize( self.triggers + self.tokenizer.tokenize(sentence, pos_tagging=False) )