JPG压缩防御对抗样例攻击在MNIST数据集上(pytorch)_zzr-程序员宅基地

技术标签: pytorch  JPG压缩  对抗样例攻击  

import torch

import torchvision

from torch.autograd import Variable

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

from torch import nn

from numpy import *

import torch.nn.functional as F

import advertorch.defenses as defenses

import numpy as np

import matplotlib.pyplot as plt

seed = 2014

 

torch.manual_seed(seed)

np.random.seed(seed)  # Numpy module.

random.seed(seed)  # Python random module.

torch.manual_seed(seed)

 

train_dataset = datasets.MNIST(root = 'data/', train = True, 

                               transform = transforms.ToTensor(), download = True)

train_loader = DataLoader(dataset = train_dataset, batch_size = 500, shuffle = True)

 

class Linear_cliassifer(torch.nn.Module):

    def __init__(self) :

        super(Linear_cliassifer, self).__init__()

 

        self.Line1 = torch.nn.Linear(28 * 28, 10)

 

    def forward(self, x):

 

        x = self.Line1(x.view(-1, 28 * 28))

        return x


 

net = Linear_cliassifer()

cost = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)

 

test_loader = torch.utils.data.DataLoader(

    datasets.MNIST('data/', train=False, download=True, transform=transforms.Compose([

            transforms.ToTensor(),

            ])),

        batch_size=1, shuffle=True)

p = 70

epoch = 5

for k in range(epoch):

    sum_loss = 0.0

    train_correct = 0

    for i, data in enumerate(train_loader, 0):

        inputs, labels = data

        optimizer.zero_grad()

        inputs = defenses.JPEGFilter(quality=p)(inputs)

        outputs = net(inputs)

        loss = cost(outputs, labels)

        loss.backward()

        optimizer.step()

 

        print(loss)

        _, id = torch.max(outputs.data, 1) 

        sum_loss += loss.data

        train_correct += torch.sum(id == labels.data)

        #print('[%d,%d] loss:%.03f' % (k + 1, k, sum_loss / len(train_loader)))

    print('        correct:%.03f%%' % (100 * train_correct / len(train_dataset)))



 

def fgsm_attack(image, epsilon, data_grad):

    # Collect the element-wise sign of the data gradient

    sign_data_grad = data_grad.sign()

    # Create the perturbed image by adjusting each pixel of the input image

    perturbed_image = image + epsilon*sign_data_grad

    # Adding clipping to maintain [0,1] range

    perturbed_image = torch.clamp(perturbed_image, 0, 1)

    # Return the perturbed image

    return perturbed_image, epsilon*sign_data_grad



 

def test( model, test_loader, epsilon):

    

    # Accuracy counter

    correct = 0

    adv_examples = []

    ns = []

    # Loop over all examples in test set

    for data, target in test_loader:

 

        # Set requires_grad attribute of tensor. Important for Attack

 

        data = defenses.JPEGFilter(quality=p)(data)

 

        data.requires_grad = True

 

        output = net(data)


 

        init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

 

        # If the initial prediction is wrong, dont bother attacking, just move on

        if init_pred.item() != target.item():

            continue

 

        # Calculate the loss

        loss = F.nll_loss(output, target)

 

        # Zero all existing gradients

        model.zero_grad()

 

        # Calculate gradients of model in backward pass

        loss.backward()

 

        # Collect datagrad

        data_grad = data.grad.data

 

        # Call FGSM Attack

        perturbed_data, n = fgsm_attack(data, epsilon, data_grad)

 

        ns.append(torch.sum(torch.abs(n)).tolist())

        # Re-classify the perturbed image

 

        perturbed_data = defenses.JPEGFilter(quality=p)(perturbed_data)

        output = net(perturbed_data)

 

        # Check for success

        final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

        if final_pred.item() == target.item():

            correct += 1

           

 

    # Calculate final accuracy for this epsilon

    final_acc = correct/float(len(test_loader))

    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))

 

    # Return the accuracy and an adversarial example

    ns = sum(ns)

    return final_acc, adv_examples, ns

 

accuracies = []

examples = []

noise = []

 

epsilons = [0, .05, .1, .15, .2, .25, .3]

# Run test for each epsilon

for eps in epsilons:

    acc, ex, ns = test(net, test_loader, eps)

    accuracies.append(acc)

    examples.append(ex)

    noise.append(ns)

 

print(accuracies)

 

plt.figure(figsize=(5,5))

plt.plot(epsilons, accuracies, "*-")

plt.yticks(np.arange(0, 1.1, step=0.1))

plt.xticks(np.arange(0, .35, step=0.05))

plt.title("Accuracy vs Epsilon")

plt.xlabel("Epsilon")

plt.ylabel("Accuracy")

plt.show()

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_23144435/article/details/107741974

智能推荐

随便推点