2024-05-10 09:59:24 +00:00
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import os
|
|
|
|
import pickle
|
|
|
|
from datetime import datetime
|
2025-01-15 10:26:48 +00:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2024-05-10 09:59:24 +00:00
|
|
|
from torch.nn import functional as F
|
|
|
|
from torch.optim import Adam, lr_scheduler
|
2025-01-15 10:26:48 +00:00
|
|
|
|
|
|
|
import data
|
|
|
|
import metrics
|
|
|
|
import utils
|
2024-05-10 09:59:24 +00:00
|
|
|
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
|
2025-01-15 10:26:48 +00:00
|
|
|
from models.BayesianModels.BayesianAlexNet import BBBAlexNet
|
|
|
|
from models.BayesianModels.BayesianLeNet import BBBLeNet
|
2025-01-29 11:26:17 +00:00
|
|
|
from stopping_crit import accuracy_bound, e_stop, efficiency_stop, energy_bound
|
2024-05-10 09:59:24 +00:00
|
|
|
|
2025-01-15 10:26:48 +00:00
|
|
|
# Load the run configuration from configuration.pkl. The file may hold several
# pickled objects written back-to-back; they are read until EOF and the LAST
# one becomes ``cfg``.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only ever point this at a trusted, locally generated configuration.pkl.
_NO_CONFIG = object()  # sentinel: distinguishes "nothing loaded" from a pickled None
cfg = _NO_CONFIG
with open("configuration.pkl", "rb") as cfg_file:
    while True:
        try:
            cfg = pickle.load(cfg_file)
        except EOFError:
            break
if cfg is _NO_CONFIG:
    # Fail here with a clear message instead of a NameError deep inside run().
    raise RuntimeError("configuration.pkl contains no pickled configuration object")
|
|
|
|
|
|
|
|
|
|
|
|
# CUDA settings: run on the first visible GPU when CUDA is available,
# otherwise fall back to the CPU. Used by every .to(device) call below.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
def getModel(net_type, inputs, outputs, priors, layer_type, activation_type):
    """Instantiate the Bayesian network selected by ``net_type``.

    Args:
        net_type: One of ``"lenet"``, ``"alexnet"`` or ``"3conv3fc"``
            (compared case-sensitively).
        inputs: Passed as the second positional constructor argument
            (presumably the number of input channels from data.getDataset —
            confirm against the model classes).
        outputs: Passed as the first positional constructor argument
            (presumably the number of classes — confirm).
        priors: Prior settings forwarded unchanged to the model.
        layer_type: Bayesian layer variant forwarded unchanged to the model.
        activation_type: Activation name forwarded unchanged to the model.

    Returns:
        The constructed (not yet device-placed) model.

    Raises:
        ValueError: If ``net_type`` is not one of the supported names.
    """
    if net_type == "lenet":
        # Only the LeNet branch receives the configured width (cfg["model"]["size"]).
        return BBBLeNet(
            outputs,
            inputs,
            priors,
            layer_type,
            activation_type,
            wide=cfg["model"]["size"],
        )
    if net_type == "alexnet":
        return BBBAlexNet(outputs, inputs, priors, layer_type, activation_type)
    if net_type == "3conv3fc":
        return BBB3Conv3FC(outputs, inputs, priors, layer_type, activation_type)
    raise ValueError(
        f"Network should be either [LeNet / AlexNet / 3Conv3FC], got {net_type!r}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(
    net,
    optimizer,
    criterion,
    trainloader,
    num_ens=1,
    beta_type=0.1,
    epoch=None,
    num_epochs=None,
):
    """Run one training epoch of a Bayesian network with ensemble sampling.

    For every mini-batch:
    - The network is evaluated ``num_ens`` times.
    - The per-pass log-softmax outputs are combined with ``utils.logmeanexp``
      over the ensemble axis.
    - The KL terms are averaged over the passes.
    - ``criterion(log_outputs, labels, kl, beta)`` is back-propagated.

    Args:
        net: Model whose forward returns ``(logits, kl)`` and which exposes
            ``num_classes``.
        optimizer: Optimizer over ``net``'s parameters.
        criterion: Loss called as ``criterion(log_outputs, labels, kl, beta)``.
        trainloader: Iterable of ``(inputs, labels)`` batches supporting ``len``.
        num_ens: Number of forward passes per batch.
        beta_type: Forwarded to ``metrics.get_beta`` to weight the KL term.
        epoch: Forwarded to ``metrics.get_beta`` to weight the KL term.
        num_epochs: Forwarded to ``metrics.get_beta`` to weight the KL term.

    Returns:
        ``(mean_loss, mean_accuracy, mean_kl)`` averaged over batches.
    """
    net.train()
    training_loss = 0.0
    accs = []
    kl_list = []
    num_batches = len(trainloader)

    for batch_idx, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)

        # One log-probability slice per ensemble member; allocated directly on
        # the target device rather than on the CPU and then copied.
        outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens, device=device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1)

        kl = kl / num_ens
        kl_list.append(kl.item())
        log_outputs = utils.logmeanexp(outputs, dim=2)

        # batch_idx is 0-based (same value the original computed as i - 1
        # with enumerate(..., 1)).
        beta = metrics.get_beta(batch_idx, num_batches, beta_type, epoch, num_epochs)
        loss = criterion(log_outputs, labels, kl, beta)
        loss.backward()
        optimizer.step()

        accs.append(metrics.acc(log_outputs.detach(), labels))
        # .item() accumulates in a Python float64; summing numpy float32
        # scalars can stay in float32 under NumPy 2's promotion rules.
        training_loss += loss.item()

    return training_loss / num_batches, np.mean(accs), np.mean(kl_list)
|
2024-05-10 09:59:24 +00:00
|
|
|
|
|
|
|
|
2025-01-29 11:26:17 +00:00
|
|
|
def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    """Calculate ensemble accuracy and NLL Loss over ``validloader``.

    Each batch is evaluated ``num_ens`` times. The log-softmax outputs are
    combined with ``utils.logmeanexp`` and the KL terms are averaged, exactly
    as in ``train_model``, so training and validation losses are on the same
    scale.

    Args:
        net: Model whose forward returns ``(logits, kl)`` and which exposes
            ``num_classes``.
        criterion: Loss called as ``criterion(log_outputs, labels, kl, beta)``.
        validloader: Iterable of ``(inputs, labels)`` batches supporting ``len``.
        num_ens: Number of forward passes per batch.
        beta_type: Forwarded to ``metrics.get_beta``.
        epoch: Forwarded to ``metrics.get_beta``.
        num_epochs: Forwarded to ``metrics.get_beta``.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over batches.
    """
    # NOTE(review): net.train() (not net.eval()) is kept from the original;
    # presumably the Bayesian layers only sample weights in training mode,
    # which the ensemble relies on — confirm before switching to eval().
    net.train()
    valid_loss = 0.0
    accs = []
    num_batches = len(validloader)

    # No gradients are needed for validation; skipping graph construction
    # saves memory and time without changing any computed value.
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(validloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens, device=device)
            kl = 0.0
            for j in range(num_ens):
                net_out, _kl = net(inputs)
                kl += _kl
                outputs[:, :, j] = F.log_softmax(net_out, dim=1)

            # Average KL over ensemble members, matching train_model
            # (previously the sum was used, inflating the loss by num_ens).
            kl = kl / num_ens
            log_outputs = utils.logmeanexp(outputs, dim=2)

            # 0-based batch index, matching train_model. The original passed
            # i - 1 with a 0-based enumerate, i.e. -1 for the first batch.
            beta = metrics.get_beta(batch_idx, num_batches, beta_type, epoch, num_epochs)
            valid_loss += criterion(log_outputs, labels, kl, beta).item()
            accs.append(metrics.acc(log_outputs, labels))

    return valid_loss / num_batches, np.mean(accs)
|
2024-05-10 09:59:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _should_stop(stp, early_stop, dataset, net, batch_size, epoch, train_acc, valid_acc):
    """Return True when the stopping criterion selected by ``stp`` says to halt.

    Criteria:
    - 2: early stopping on validation accuracy.
    - 3: energy bound.
    - 4: training-accuracy bound (0.99 for "MNIST", 0.50 otherwise).
    - 5: efficiency stopping.
    - Anything else: train for the full number of epochs.
    """
    if stp == 2:
        return e_stop(early_stop, valid_acc, epoch + 1, 2, cfg["model"]["sens"]) == 1
    if stp == 3:
        return energy_bound(cfg["model"]["energy_thrs"]) == 1
    if stp == 4:
        target_acc = 0.99 if dataset == "MNIST" else 0.50
        return accuracy_bound(train_acc, target_acc) == 1
    if stp == 5:
        return efficiency_stop(net, train_acc, batch_size, 0.002) == 1
    print(f"Training for {cfg['model']['n_epochs']} epochs")
    return False


def run(dataset, net_type):
    """Train a Bayesian network on ``dataset`` as configured by the global ``cfg``.

    Steps:
    - Build the noisy dataset and its loaders.
    - Create the model via ``getModel``.
    - Train for up to ``cfg["model"]["n_epochs"]`` epochs with an ELBO loss
      and a ReduceLROnPlateau schedule on the validation loss.
    - When ``cfg["save"] == 1``, checkpoint after every epoch.
    - Apply the stopping criterion ``cfg["stopping_crit"]``.
    - Pickle the per-epoch history to ``bayes_exp_data_<size>_<dataset>.pkl``.

    Args:
        dataset: Dataset name passed to ``data.getDataset``. It is also used in
            the checkpoint and result paths.
        net_type: Architecture key understood by ``getModel``.
    """
    model_cfg = cfg["model"]

    # Noise applied to dataset
    noise_type = cfg["noise_type"]
    mean = 0.5
    std = 0.5

    # Hyper Parameter settings
    layer_type = model_cfg["layer_type"]
    activation_type = model_cfg["activation_type"]
    priors = model_cfg["priors"]

    train_ens = model_cfg["train_ens"]
    valid_ens = model_cfg["valid_ens"]
    n_epochs = model_cfg["n_epochs"]
    lr_start = model_cfg["lr"]
    num_workers = model_cfg["num_workers"]
    valid_size = model_cfg["valid_size"]
    batch_size = model_cfg["batch_size"]
    beta_type = model_cfg["beta_type"]

    trainset, testset, inputs, outputs = data.getDataset(dataset, noise_type, mean=mean, std=std)
    train_loader, valid_loader, _test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers
    )
    net = getModel(net_type, inputs, outputs, priors, layer_type, activation_type).to(device)

    # Single source of truth for the checkpoint directory (reused below).
    ckpt_dir = f"checkpoints/{dataset}/bayesian"
    os.makedirs(ckpt_dir, exist_ok=True)

    stp = cfg["stopping_crit"]
    sav = cfg["save"]

    criterion = metrics.ELBO(len(trainset)).to(device)
    optimizer = Adam(net.parameters(), lr=lr_start)
    lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)

    early_stop = []  # history list handed to e_stop on every epoch (stp == 2)
    train_data = []  # one [epoch, train_loss, train_acc, valid_loss, valid_acc] row per epoch
    for epoch in range(n_epochs):  # loop over the dataset multiple times
        train_loss, train_acc, train_kl = train_model(
            net,
            optimizer,
            criterion,
            train_loader,
            num_ens=train_ens,
            beta_type=beta_type,
            epoch=epoch,
            num_epochs=n_epochs,
        )
        valid_loss, valid_acc = validate_model(
            net,
            criterion,
            valid_loader,
            num_ens=valid_ens,
            beta_type=beta_type,
            epoch=epoch,
            num_epochs=n_epochs,
        )
        lr_sched.step(valid_loss)

        train_data.append([epoch, train_loss, train_acc, valid_loss, valid_acc])
        print(
            "Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} "
            "\tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} "
            "\ttrain_kl_div: {:.4f}".format(
                epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl
            )
        )

        if sav == 1:
            ckpt_name = (
                f"{ckpt_dir}/model_{net_type}_{layer_type}_{activation_type}"
                f"_{model_cfg['size']}_epoch_{epoch}_noise_{noise_type}.pt"
            )
            torch.save(net.state_dict(), ckpt_name)

        if _should_stop(stp, early_stop, dataset, net, batch_size, epoch, train_acc, valid_acc):
            break

    with open(f"bayes_exp_data_{model_cfg['size']}_{dataset}.pkl", "wb") as f:
        pickle.dump(train_data, f)
|
|
|
|
|
|
|
|
|
2025-01-15 10:26:48 +00:00
|
|
|
if __name__ == "__main__":
    # Script entry point: time-stamp the run and train the configured model.
    print("Initial Time =", datetime.now().strftime("%H:%M:%S"))
    # Inner single quotes: reusing the enclosing double quote inside an
    # f-string replacement field is a SyntaxError before Python 3.12 (PEP 701).
    print(f"Using bayesian model of size: {cfg['model']['size']}")
    run(cfg["data"], cfg["model"]["net_type"])
    print("Final Time =", datetime.now().strftime("%H:%M:%S"))
|