BadNets



BadNetsAttack

 BadNetsAttack (trigger:<built-in function array>, target_class:str, *args,
                test_only=False, poison_fraction=0.1)

https://arxiv.org/abs/1708.06733

Parameter        Type   Default  Details
trigger          array           Expected to be the same shape as the input
target_class     str             The target class for inputs with the trigger
args
test_only        bool   False
poison_fraction  float  0.1
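
The parameters above map onto the poisoning step described in the BadNets paper: a fraction of the training samples is stamped with the trigger and relabeled to the target class. The following is a minimal sketch of that idea in plain Python, not the implementation of `BadNetsAttack`. The helper names `stamp` and `poison_dataset`, the pixel-wise-maximum blending rule, and the toy 4x4 images are all assumptions made for illustration.

```python
import random

def stamp(image, trigger):
    """Overlay the trigger on an image via pixel-wise maximum (assumed blending rule)."""
    return [[max(p, t) for p, t in zip(img_row, trg_row)]
            for img_row, trg_row in zip(image, trigger)]

def poison_dataset(images, labels, trigger, target_class, poison_fraction=0.1, seed=0):
    """Return copies of the data with a random fraction stamped and relabeled."""
    rng = random.Random(seed)
    n_poison = int(len(images) * poison_fraction)
    poisoned_idx = set(rng.sample(range(len(images)), n_poison))
    new_images, new_labels = [], []
    for i, (img, lbl) in enumerate(zip(images, labels)):
        if i in poisoned_idx:
            # Poisoned sample: add the trigger and switch the label to the target.
            new_images.append(stamp(img, trigger))
            new_labels.append(target_class)
        else:
            # Clean sample: copied unchanged.
            new_images.append([row[:] for row in img])
            new_labels.append(lbl)
    return new_images, new_labels, poisoned_idx

# Toy data (hypothetical): twenty all-black 4x4 "images" with labels '1'..'9'.
images = [[[0] * 4 for _ in range(4)] for _ in range(20)]
labels = [str(1 + i % 9) for i in range(20)]
toy_trigger = [[0, 0, 0, 0],
               [0, 0, 0, 0],
               [0, 0, 255, 0],
               [0, 0, 0, 255]]

p_images, p_labels, idx = poison_dataset(images, labels, toy_trigger, '0')
print(len(idx))  # 2 samples poisoned out of 20 at poison_fraction=0.1
```

With `poison_fraction=0.1`, roughly one training sample in ten carries the trigger and the target label; the remaining samples are left untouched, which is why benign accuracy can stay high after poisoning.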
import torch
import torch.nn.functional as F

from similarity_learning.all import *
mnist = MNIST()
classifier = MLP(10)
patch = torch.tensor([[1, 0, 1],
                      [0, 1, 0],
                      [1, 0, 1]]).int()*255
trigger = F.pad(patch, (25, 0, 25, 0)).numpy()
PILImageBW.create(trigger).show()
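
To make the geometry of the trigger explicit: `F.pad(patch, (25, 0, 25, 0))` adds 25 columns on the left and 25 rows on top, so the 3x3 patch ends up in the bottom-right corner of a 28x28 array, matching the MNIST input shape. The sketch below reproduces that padding with plain Python lists; the helper name `pad_top_left` is hypothetical and exists only for this illustration.

```python
def pad_top_left(patch, left, top, fill=0):
    """Pad a 2D list with `top` rows above and `left` columns to the left."""
    width = left + len(patch[0])
    rows = [[fill] * width for _ in range(top)]
    rows += [[fill] * left + list(row) for row in patch]
    return rows

patch = [[255, 0, 255],
         [0, 255, 0],
         [255, 0, 255]]
trigger = pad_top_left(patch, left=25, top=25)

print(len(trigger), len(trigger[0]))                 # 28 28
print([row[-3:] for row in trigger[-3:]] == patch)   # True
```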

Training Without Poison (Baseline)

learn = Learner(mnist.dls(), classifier, metrics=accuracy, cbs=BadNetsAttack(trigger, target_class='0', test_only=True))
learn.fit_one_cycle(1)
epoch  train_loss  valid_loss  accuracy  time
0      0.115392    0.093313    0.970600  00:17

Benign Accuracy

learn.show_results()

acc = learn.validate()[1]
test(acc, .95, ge)
print(as_percentage(acc))
97.06%

Attack Success

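Without poisoning, we expect the baseline attack success rate (ASR) to be low. Note, however, that attack success and benign accuracy coincide on samples from the target class: a triggered sample that already belongs to the target class counts as a success whenever it is classified correctly. Without poisoning, the ASR is therefore expected to be about \(1/C\), where \(C\) is the number of classes (assuming a balanced test set and a perfectly accurate classifier).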
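
The \(1/C\) baseline can be checked with a small sketch. It assumes that ASR is measured as the fraction of triggered test samples predicted as the target class, and it models an unpoisoned, perfectly accurate classifier that ignores the trigger. The function name `attack_success_rate` and the synthetic balanced labels are hypothetical.

```python
def attack_success_rate(predictions, target_class):
    """Fraction of triggered samples that the model assigns to the target class."""
    return sum(p == target_class for p in predictions) / len(predictions)

num_classes = 10   # C
per_class = 100    # balanced test set: the same number of samples per class

true_labels = [str(c) for c in range(num_classes) for _ in range(per_class)]

# An unpoisoned, perfectly accurate classifier ignores the trigger,
# so its predictions on triggered inputs equal the true labels.
predictions = list(true_labels)

asr = attack_success_rate(predictions, '0')
print(asr)  # 0.1, i.e. 1/C
```

Only the samples that genuinely belong to the target class are counted as successes, which is exactly the overlap between attack success and benign accuracy.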

learn.show_results(2)

Because MNIST is not perfectly balanced and the model is not perfectly accurate, the baseline ASR can deviate slightly from \(10\%\); the check below only requires it to stay at or below \(15\%\):

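asr = learn.validate(2)[1]  # index 1 is the metric; dataset index 2 is used here for the triggered test set
test(asr, .15, le)  # baseline ASR should stay at or below 15%
print(as_percentage(asr))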
9.95%

Training With Poison

# Note: this reuses the `classifier` object already trained in the baseline run above.
learn = Learner(mnist.dls(), classifier, metrics=accuracy, cbs=BadNetsAttack(trigger, target_class='0'))
learn.fit_one_cycle(1)
epoch  train_loss  valid_loss  accuracy  time
0      0.067391    0.073537    0.977600  00:20

Here’s what the poisoned data looks like:

learn.dls.train_ds.poison.dl().show_batch()

Benign Accuracy

learn.show_results()

acc = learn.validate()[1]
test(acc, .95, ge)
print(as_percentage(acc))
97.76%

Attack Success

learn.show_results(2)

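asr = learn.validate(2)[1]  # index 1 is the metric; dataset index 2 is used here for the triggered test set
test(asr, .95, ge)  # after poisoning, ASR should reach at least 95%
print(as_percentage(asr))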
100.0%