import random

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler


class AddNoNoise(object):
    """Identity transform: return the input tensor unchanged.

    ``mean`` and ``std`` are accepted and stored only so the constructor has
    the same ``(mean, std)`` signature as the other noise transforms; they
    are not used when the transform is applied.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor

    def __repr__(self):
        # Parenthesised like the other transforms' reprs; plain concatenation
        # produced the run-together "AddNoNoiseNo noise".
        return self.__class__.__name__ + "(No noise)"


class AddGaussianNoise(object):
    """Add i.i.d. Gaussian noise with the given mean and std to a tensor.

    The noise is drawn with ``torch.randn`` using the tensor's shape, scaled
    by ``std`` and shifted by ``mean``.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        scaled = torch.randn(tensor.size()) * self.std
        # Add the scaled noise first and the offset second, keeping the same
        # floating-point evaluation order as a single chained expression.
        return tensor + scaled + self.mean

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"


class AddRaleighNoise(object):
    """Add Gaussian noise whose mean and spread match a Rayleigh distribution.

    Uses the Rayleigh parameterisation with location ``a`` and scale ``b``:
    mean = a + sqrt(pi * b / 4) and variance = b * (4 - pi) / 4. The samples
    actually drawn are Gaussian with that mean and standard deviation, not
    true Rayleigh samples. ``b`` is assumed non-negative; a negative ``b``
    makes ``np.sqrt`` return NaN.
    """

    def __init__(self, a=0.0, b=0.0):
        variance = (b * (4 - np.pi)) / 4
        # ``std`` multiplies a unit normal in __call__, so it must be the
        # square root of the variance, not the variance itself.
        self.std = np.sqrt(variance)
        self.mean = a + np.sqrt((np.pi * b) / 4)

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


class AddErlangNoise(object):
    """Add Gaussian noise whose mean and spread match an Erlang distribution.

    With rate ``a`` and shape ``b``, the Erlang (gamma) distribution has
    mean = b / a and variance = b / a**2, i.e. std = sqrt(b) / a. The samples
    actually drawn are Gaussian with those moments, not true Erlang samples.
    ``a == 0`` disables the noise (and avoids dividing by zero). ``b`` is
    assumed non-negative; a negative ``b`` makes ``np.sqrt`` return NaN.
    """

    def __init__(self, a=0.0, b=0.0):
        if a == 0.0:
            self.std = 0.0
            self.mean = 0.0
        else:
            self.std = np.sqrt(b) / a
            self.mean = b / a

    def __call__(self, tensor):
        if self.mean == 0.0 and self.std == 0.0:
            # No noise configured: hand the input back untouched rather than
            # multiplying by zero, which would erase the image.
            return tensor
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


class AddExponentialNoise(object):
    """Add Gaussian noise whose mean and spread match an exponential law.

    The exponential distribution with rate ``a`` has mean = 1 / a and
    variance = 1 / a**2, i.e. std = 1 / a. The samples actually drawn are
    Gaussian with those moments, not true exponential samples. ``a == 0``
    disables the noise. ``b`` is accepted only for signature parity with the
    other two-parameter noise transforms and is not used.
    """

    def __init__(self, a=0.0, b=0):
        if a == 0.0:
            # Both attributes must exist so __call__ and __repr__ work.
            self.std = 0.0
            self.mean = 0.0
        else:
            self.std = 1 / a
            self.mean = 1 / a

    def __call__(self, tensor):
        if self.mean == 0.0 and self.std == 0.0:
            # No noise configured: return the input instead of zeroing it.
            return tensor
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


class AddUniformNoise(object):
    """Add Gaussian noise whose mean and spread match a uniform U(a, b).

    U(a, b) has mean = (a + b) / 2 and variance = (b - a)**2 / 12, i.e.
    std = |b - a| / sqrt(12). The samples actually drawn are Gaussian with
    those moments, not true uniform samples. With the defaults
    ``a == b == 0`` both moments are zero and the input is returned as is.
    """

    def __init__(self, a=0.0, b=0.0):
        # Both moments depend on a and b together (a == 0 is a valid lower
        # bound), and ``std`` scales a unit normal, so take the square root
        # of the variance.
        self.std = abs(b - a) / np.sqrt(12)
        self.mean = (b + a) / 2

    def __call__(self, tensor):
        if self.mean == 0.0 and self.std == 0.0:
            # Nothing to add: return the input rather than ``tensor * 0``.
            return tensor
        return tensor + (torch.randn(tensor.size()) * self.std + self.mean)

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


class AddImpulseNoise(object):
    """Scale the whole tensor by a randomly chosen factor derived from ``a``.

    NOTE(review): this does not set individual pixels (salt-and-pepper);
    the entire tensor is multiplied by one factor. Up to two independent
    ``random.gauss`` draws are made, so the factor is ``+a`` with
    probability 1/2, ``-a`` with 1/4 and ``0`` with 1/4. Confirm this
    distribution is intended.
    """

    def __init__(self, a=0.0, b=0):
        # ``b`` is accepted for signature parity with the other noise
        # transforms but is not used.
        self.value = a

    def __call__(self, tensor):
        if random.gauss(0, 1) > 0:
            factor = self.value
        elif random.gauss(0, 1) < 0:
            factor = -1 * self.value
        else:
            factor = 0.0
        return tensor * factor

    def __repr__(self):
        return f"{type(self).__name__}(a={self.value})"


class CustomDataset(Dataset):
    """Map-style dataset over parallel ``data`` and ``labels`` sequences.

    If ``transform`` is truthy it is applied to each sample (never to the
    label) when the item is fetched.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        # Length is taken from ``labels``; ``data`` is presumably the same
        # length — not checked here.
        return len(self.labels)

    def __getitem__(self, idx):
        sample, label = self.data[idx], self.labels[idx]
        if not self.transform:
            return sample, label
        return self.transform(sample), label


def extract_classes(dataset, classes):
    """Return the ``(data, targets)`` entries whose target is in ``classes``.

    ``dataset.targets`` must be a tensor (a boolean mask is built with
    ``torch.zeros_like``), and ``dataset.data`` must accept that mask as an
    index. Original order is preserved; an empty ``classes`` selects nothing.
    """
    labels = dataset.targets
    mask = torch.zeros_like(labels, dtype=torch.bool)
    for wanted in classes:
        mask |= labels.eq(wanted)
    return dataset.data[mask], labels[mask]


# Digits kept by each SplitMNIST task. Targets are shifted down by the
# smallest kept digit so labels start at 0; add that offset back to map a
# prediction to the original digit.
_SPLIT_MNIST_CLASSES = {
    "SplitMNIST-2.1": [0, 1, 2, 3, 4],
    "SplitMNIST-2.2": [5, 6, 7, 8, 9],
    "SplitMNIST-5.1": [0, 1],
    "SplitMNIST-5.2": [2, 3],
    "SplitMNIST-5.3": [4, 5],
    "SplitMNIST-5.4": [6, 7],
    "SplitMNIST-5.5": [8, 9],
}

# Number of output classes for the full (unsplit) datasets.
_FULL_DATASET_CLASSES = {"CIFAR10": 10, "CIFAR100": 100, "MNIST": 10}


def _split_mnist(trainset, testset, classes, transform):
    """Restrict MNIST train/test sets to ``classes``, relabelled from 0."""
    offset = min(classes)

    def subset(source):
        data, targets = extract_classes(source, classes)
        return CustomDataset(data, targets - offset, transform=transform)

    return subset(trainset), subset(testset)


def getDataset(dataset, noise=None, mean=0.0, std=0.0):
    """Build train/test datasets, optionally corrupted with synthetic noise.

    Args:
        dataset: "CIFAR10", "CIFAR100", "MNIST" or one of the SplitMNIST
            keys ("SplitMNIST-2.1" ... "SplitMNIST-5.5").
        noise: None for no noise, or one of "gaussian", "raleigh", "erlang",
            "exponential", "uniform", "impulse".
        mean, std: passed positionally as the first two constructor
            arguments of the selected noise transform (for the non-Gaussian
            transforms these are its ``a`` and ``b`` parameters).

    Returns:
        ``(trainset, testset, inputs, num_classes)`` where ``inputs`` is the
        number of image channels.

    Raises:
        ValueError: if ``dataset`` or ``noise`` is not recognised.

    torchvision datasets are created with ``root="./data"`` and
    ``download=True``. The noise transform is applied to both splits.
    """
    noise_types = {
        None: AddNoNoise,
        "gaussian": AddGaussianNoise,
        "raleigh": AddRaleighNoise,
        "erlang": AddErlangNoise,
        "exponential": AddExponentialNoise,
        "uniform": AddUniformNoise,
        "impulse": AddImpulseNoise,
    }
    # Validate up front: previously an unknown noise failed with
    # "'NoneType' object is not callable" and an unknown dataset with an
    # UnboundLocalError at the return statement.
    if noise not in noise_types:
        raise ValueError(f"Unknown noise type: {noise!r}")
    if dataset not in _FULL_DATASET_CLASSES and dataset not in _SPLIT_MNIST_CLASSES:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    noise_type = noise_types[noise]
    print(f"{noise_type} noise added")

    # SplitMNIST stores raw uint8 tensors, so they go through PIL first.
    transform_split_mnist = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            noise_type(mean, std),
        ]
    )

    transform_mnist = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            noise_type(mean, std),
        ]
    )

    transform_cifar_train = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            noise_type(mean, std),
        ]
    )

    # Evaluation should be deterministic, so the test split gets no random
    # horizontal flip (only the noise under study).
    transform_cifar_test = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            noise_type(mean, std),
        ]
    )

    if dataset in ("CIFAR10", "CIFAR100"):
        cifar_cls = {
            "CIFAR10": torchvision.datasets.CIFAR10,
            "CIFAR100": torchvision.datasets.CIFAR100,
        }[dataset]
        trainset = cifar_cls(
            root="./data", train=True, download=True, transform=transform_cifar_train
        )
        testset = cifar_cls(
            root="./data", train=False, download=True, transform=transform_cifar_test
        )
        return trainset, testset, 3, _FULL_DATASET_CLASSES[dataset]

    trainset = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform_mnist
    )
    testset = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform_mnist
    )
    if dataset == "MNIST":
        return trainset, testset, 1, _FULL_DATASET_CLASSES[dataset]

    classes = _SPLIT_MNIST_CLASSES[dataset]
    trainset, testset = _split_mnist(trainset, testset, classes, transform_split_mnist)
    return trainset, testset, 1, len(classes)


def getDataloader(trainset, testset, valid_size, batch_size, num_workers):
    """Create train, validation and test DataLoaders.

    A random permutation of ``trainset`` indices (drawn with the global NumPy
    RNG) is cut at ``floor(valid_size * len(trainset))``. The leading part
    feeds the validation loader and the remainder feeds the training loader.
    Both loaders sample from ``trainset`` via ``SubsetRandomSampler``. The
    test loader wraps ``testset`` with no sampler.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    total = len(trainset)
    permutation = list(range(total))
    np.random.shuffle(permutation)
    cut = int(np.floor(valid_size * total))

    train_sampler = SubsetRandomSampler(permutation[cut:])
    valid_sampler = SubsetRandomSampler(permutation[:cut])

    common = dict(batch_size=batch_size, num_workers=num_workers)
    return (
        torch.utils.data.DataLoader(trainset, sampler=train_sampler, **common),
        torch.utils.data.DataLoader(trainset, sampler=valid_sampler, **common),
        torch.utils.data.DataLoader(testset, **common),
    )