Efficiency-of-Neural-Archit.../main_bayesian.py

178 lines
6.5 KiB
Python
Raw Normal View History

2022-04-16 12:20:44 +00:00
from __future__ import print_function
import os
import data
import utils
import torch
import pickle
import metrics
import argparse
import numpy as np
2023-06-07 06:51:07 +00:00
import amd_sample_draw
2022-04-16 12:20:44 +00:00
import config_bayesian as cfg
from datetime import datetime
from torch.nn import functional as F
from torch.optim import Adam, lr_scheduler
from models.BayesianModels.BayesianLeNet import BBBLeNet
from models.BayesianModels.BayesianAlexNet import BBBAlexNet
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
2023-06-01 08:20:51 +00:00
from stopping_crit import earlyStopping, energyBound, accuracyBound
2022-04-16 12:20:44 +00:00
# CUDA settings
# The experiments were run on the second GPU (index 1). torch.device("cuda:1")
# on a machine that exposes only one GPU fails as soon as a tensor is moved
# there, so fall back to the first GPU in that case, and to the CPU when CUDA
# is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
2022-04-16 12:20:44 +00:00
def getModel(net_type, inputs, outputs, priors, layer_type, activation_type):
    """Instantiate the Bayesian network named by ``net_type``.

    Args:
        net_type: 'lenet', 'alexnet' or '3conv3fc'. Matching is
            case-insensitive, so the spellings shown in the error message
            ('LeNet', 'AlexNet', '3Conv3FC') are accepted too.
        inputs: passed as the second positional constructor argument (the
            value returned by ``data.getDataset``; presumably the number of
            input channels — confirm).
        outputs: passed as the first positional constructor argument
            (presumably the number of classes — confirm).
        priors, layer_type, activation_type: forwarded unchanged to the model.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``net_type`` names none of the supported networks.
    """
    key = str(net_type).lower()
    if key == 'lenet':
        # Only LeNet receives the configured width setting (cfg.wide).
        return BBBLeNet(outputs, inputs, priors, layer_type, activation_type, wide=cfg.wide)
    if key == 'alexnet':
        return BBBAlexNet(outputs, inputs, priors, layer_type, activation_type)
    if key == '3conv3fc':
        return BBB3Conv3FC(outputs, inputs, priors, layer_type, activation_type)
    raise ValueError(
        f'Network should be either [LeNet / AlexNet / 3Conv3FC], got {net_type!r}')
def train_model(net, optimizer, criterion, trainloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    """Run one training epoch with an ensemble of stochastic forward passes per batch.

    For every batch the network is evaluated ``num_ens`` times. The per-pass
    log-softmax outputs are combined with ``utils.logmeanexp`` over the
    ensemble axis, and the per-pass KL terms are averaged. Then
    ``criterion(log_outputs, labels, kl, beta)`` is back-propagated and
    ``optimizer`` takes a step.

    Args:
        net: model whose forward returns ``(logits, kl)`` and which exposes
            ``num_classes``.
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss callable taking ``(log_outputs, labels, kl, beta)``.
        trainloader: sized iterable of ``(inputs, labels)`` batches.
        num_ens: number of forward passes per batch.
        beta_type, epoch, num_epochs: forwarded to ``metrics.get_beta``
            together with the 0-based batch index and the number of batches.

    Returns:
        ``(mean loss per batch, mean batch accuracy, mean KL per batch)``.
    """
    net.train()
    training_loss = 0.0
    accs = []
    kl_list = []
    num_batches = len(trainloader)
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        # One column of log-probabilities per ensemble member; allocate it
        # directly on the target device instead of on the CPU followed by a copy.
        outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens, device=device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1)
        # Average (not sum) the KL over the ensemble members.
        kl = kl / num_ens
        kl_list.append(kl.item())
        log_outputs = utils.logmeanexp(outputs, dim=2)
        beta = metrics.get_beta(batch_idx, num_batches, beta_type, epoch, num_epochs)
        loss = criterion(log_outputs, labels, kl, beta)
        loss.backward()
        optimizer.step()
        accs.append(metrics.acc(log_outputs.detach(), labels))
        training_loss += loss.item()
    return training_loss / num_batches, np.mean(accs), np.mean(kl_list)
def validate_model(net, criterion, validloader, num_ens=1, beta_type=0.1, epoch=None, num_epochs=None):
    """Calculate ensemble accuracy and the ``criterion`` loss on the validation set.

    Each batch is evaluated ``num_ens`` times. Log-softmax outputs are combined
    with ``utils.logmeanexp`` and the KL terms are averaged over the ensemble,
    exactly as in ``train_model``, so validation and training losses are on
    the same scale. No gradients are tracked.

    Args:
        net: model whose forward returns ``(logits, kl)`` and which exposes
            ``num_classes``.
        criterion: loss callable taking ``(log_outputs, labels, kl, beta)``.
        validloader: sized iterable of ``(inputs, labels)`` batches.
        num_ens: number of forward passes per batch.
        beta_type, epoch, num_epochs: forwarded to ``metrics.get_beta``
            together with the 0-based batch index and the number of batches.

    Returns:
        ``(mean loss per batch, mean batch accuracy)``.
    """
    # Kept in train mode deliberately, as in the original code. Presumably the
    # Bayesian layers only sample weights in train mode, which the ensemble
    # needs; confirm against the layer implementations.
    net.train()
    valid_loss = 0.0
    accs = []
    num_batches = len(validloader)
    # Validation never back-propagates: skip building autograd graphs.
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(validloader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens, device=device)
            kl = 0.0
            for j in range(num_ens):
                net_out, _kl = net(inputs)
                kl += _kl
                outputs[:, :, j] = F.log_softmax(net_out, dim=1)
            # Average the KL over ensemble members, matching train_model.
            kl = kl / num_ens
            log_outputs = utils.logmeanexp(outputs, dim=2)
            # enumerate starts at 0 here, so batch_idx already is the 0-based
            # index that train_model passes (previously i-1 gave -1 first).
            beta = metrics.get_beta(batch_idx, num_batches, beta_type, epoch, num_epochs)
            valid_loss += criterion(log_outputs, labels, kl, beta).item()
            accs.append(metrics.acc(log_outputs, labels))
    return valid_loss / num_batches, np.mean(accs)
def _read_int_flag(path):
    """Return the integer stored in the text file at ``path`` (surrounding whitespace allowed)."""
    with open(path, "r") as file:
        return int(file.read())


def _should_stop(stp, early_stop, train_acc, epoch):
    """Return True when the stopping criterion selected by ``stp`` says to stop.

    stp == 2: ``earlyStopping`` on training accuracy (state kept in ``early_stop``);
    stp == 3: ``energyBound``; stp == 4: ``accuracyBound``;
    any other value never stops early.
    """
    if stp == 2:
        return earlyStopping(early_stop, train_acc, epoch, cfg.sens) == 1
    if stp == 3:
        return energyBound(cfg.energy_thrs) == 1
    if stp == 4:
        return accuracyBound(cfg.acc_thrs) == 1
    return False


def run(dataset, net_type):
    """Train a Bayesian network on ``dataset`` and record per-epoch metrics.

    Hyper-parameters come from ``config_bayesian`` (``cfg``). Two integer
    control files in the working directory select the behaviour:

    * ``stp`` — stopping criterion (2: earlyStopping on training accuracy,
      3: energyBound, 4: accuracyBound, anything else: all ``cfg.n_epochs``).
    * ``sav`` — when 1, the model's ``state_dict`` is saved once training ends
      (whether by completing all epochs or by a stopping criterion) to
      ``checkpoints/<dataset>/bayesian/model_<net_type>_<layer_type>_<activation_type>_<wide>.pt``.

    Per-epoch rows ``[epoch, train_loss, train_acc, valid_loss, valid_acc]``
    are pickled to ``bayes_exp_data_<wide>.pkl``.
    """
    # Hyper Parameter settings
    layer_type = cfg.layer_type
    activation_type = cfg.activation_type
    priors = cfg.priors
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size
    beta_type = cfg.beta_type

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs, priors, layer_type, activation_type).to(device)

    ckpt_dir = f'checkpoints/{dataset}/bayesian'
    ckpt_name = f'{ckpt_dir}/model_{net_type}_{layer_type}_{activation_type}_{cfg.wide}.pt'
    os.makedirs(ckpt_dir, exist_ok=True)

    # Stopping criterion and save flag are read from control files in the CWD.
    stp = _read_int_flag("stp")
    sav = _read_int_flag("sav")

    criterion = metrics.ELBO(len(trainset)).to(device)
    optimizer = Adam(net.parameters(), lr=lr_start)
    lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)

    if stp not in (2, 3, 4):
        # Announce the fixed-length schedule once rather than on every epoch.
        print('Training for {} epochs'.format(n_epochs))

    early_stop = []  # state list handed to earlyStopping on every epoch
    train_data = []
    for epoch in range(n_epochs):  # loop over the dataset multiple times
        train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
        valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens, beta_type=beta_type, epoch=epoch, num_epochs=n_epochs)
        lr_sched.step(valid_loss)
        train_data.append([epoch, train_loss, train_acc, valid_loss, valid_acc])
        print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
            epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))
        if _should_stop(stp, early_stop, train_acc, epoch):
            break

    # Save when training has finished. Previously this was gated on reaching
    # the last epoch, so early-stopped runs never produced a checkpoint.
    if sav == 1:
        torch.save(net.state_dict(), ckpt_name)

    with open("bayes_exp_data_" + str(cfg.wide) + ".pkl", 'wb') as f:
        pickle.dump(train_data, f)
if __name__ == '__main__':
    def _print_clock(label):
        """Print ``label`` followed by the current wall-clock time as HH:MM:SS."""
        print(label, datetime.now().strftime("%H:%M:%S"))

    _print_clock("Initial Time =")

    cli = argparse.ArgumentParser(description="PyTorch Bayesian Model Training")
    cli.add_argument('--net_type', default='lenet', type=str, help='model')
    cli.add_argument('--dataset', default='CIFAR10', type=str,
                     help='dataset = [MNIST/CIFAR10/CIFAR100]')
    cli_args = cli.parse_args()

    run(cli_args.dataset, cli_args.net_type)

    _print_clock("Final Time =")