Compare commits
1 Commits
main
...
entropy-te
Author | SHA1 | Date |
---|---|---|
Eduardo Cueto-Mendoza | 931cb79e42 |
|
@ -1,4 +1,4 @@
|
||||||
Copyright (c) 2024 TastyPancakes.
|
Copyright (c) 2024 Eduardo Cueto-Mendoza.
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,7 @@
|
||||||
# Energy efficiency comparison
|
# Energy efficiency comparison
|
||||||
|
|
||||||
This experiment compares a Frequentist CNN model against a Bayesian CNN model
|
This experiment compares a Frequentist CNN model against a Bayesian CNN model
|
||||||
|
|
||||||
|
## Example run command
|
||||||
|
|
||||||
|
python run_service.py -f1 -s -e && sleep 60 && python run_service.py -f2 -s -e && sleep 60 && python run_service.py -f3 -s -e && sleep 60 && python run_service.py -f4 -s -e && sleep 60 && python run_service.py -f5 -s -e && sleep 60 & python run_service.py -f6 -s -e && sleep 60 && python run_service.py -f7 -s -e && sleep 60 && python run_service.py -b1 -s -e && sleep 60 && python run_service.py -b2 -s -e && sleep 60 && python run_service.py -b3 -s -e && sleep 60 && python run_service.py -b4 -s -e && sleep 60 && python run_service.py -b5 -s -e && sleep 60 && python run_service.py -b6 -s -e && sleep 60 && python run_service.py -b7 -s -e && sleep 60
|
||||||
|
|
|
@ -22,6 +22,6 @@ def makeArguments(arguments: ArgumentParser) -> dict:
|
||||||
help="Save model")
|
help="Save model")
|
||||||
all_args.add_argument('--net_type', default='lenet', type=str,
|
all_args.add_argument('--net_type', default='lenet', type=str,
|
||||||
help='model = [lenet/AlexNet/3Conv3FC]')
|
help='model = [lenet/AlexNet/3Conv3FC]')
|
||||||
all_args.add_argument('--dataset', default='CIFAR10', type=str,
|
all_args.add_argument('--dataset', default='MNIST', type=str,
|
||||||
help='dataset = [MNIST/CIFAR10/CIFAR100]')
|
help='dataset = [MNIST/CIFAR10/CIFAR100]')
|
||||||
return vars(all_args.parse_args())
|
return vars(all_args.parse_args())
|
||||||
|
|
|
@ -4,7 +4,7 @@ import os
|
||||||
import data
|
import data
|
||||||
import utils
|
import utils
|
||||||
import torch
|
import torch
|
||||||
import pickle
|
# import pickle
|
||||||
import metrics
|
import metrics
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -15,12 +15,41 @@ from models.BayesianModels.BayesianAlexNet import BBBAlexNet
|
||||||
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
|
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
|
||||||
from stopping_crit import earlyStopping, energyBound, accuracyBound
|
from stopping_crit import earlyStopping, energyBound, accuracyBound
|
||||||
|
|
||||||
with (open("configuration.pkl", "rb")) as file:
|
# with (open("configuration.pkl", "rb")) as file:
|
||||||
while True:
|
# while True:
|
||||||
try:
|
# try:
|
||||||
cfg = pickle.load(file)
|
# cfg = pickle.load(file)
|
||||||
except EOFError:
|
# except EOFError:
|
||||||
break
|
# break
|
||||||
|
|
||||||
|
cfg = {
|
||||||
|
"model": {"net_type": "lenet", "type": "bayes", "size": 1,
|
||||||
|
"layer_type": "lrt", "activation_type": "softplus",
|
||||||
|
"priors": {
|
||||||
|
'prior_mu': 0,
|
||||||
|
'prior_sigma': 0.1,
|
||||||
|
'posterior_mu_initial': (0, 0.1), # (mean,std) normal_
|
||||||
|
'posterior_rho_initial': (-5, 0.1), # (mean,std) normal_
|
||||||
|
},
|
||||||
|
"n_epochs": 100,
|
||||||
|
"sens": 1e-9,
|
||||||
|
"energy_thrs": 100000,
|
||||||
|
"acc_thrs": 0.99,
|
||||||
|
"lr": 0.001,
|
||||||
|
"num_workers": 4,
|
||||||
|
"valid_size": 0.2,
|
||||||
|
"batch_size": 256,
|
||||||
|
"train_ens": 1,
|
||||||
|
"valid_ens": 1,
|
||||||
|
"beta_type": 0.1, # 'Blundell','Standard',etc.
|
||||||
|
# Use float for const value
|
||||||
|
},
|
||||||
|
#"data": "CIFAR10",
|
||||||
|
"data": "MNIST",
|
||||||
|
"stopping_crit": 1,
|
||||||
|
"save": 1,
|
||||||
|
"pickle_path": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# CUDA settings
|
# CUDA settings
|
||||||
|
@ -126,8 +155,7 @@ def run(dataset, net_type):
|
||||||
activation_type).to(device)
|
activation_type).to(device)
|
||||||
|
|
||||||
ckpt_dir = f'checkpoints/{dataset}/bayesian'
|
ckpt_dir = f'checkpoints/{dataset}/bayesian'
|
||||||
ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}\
|
ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}_{activation_type}_{cfg["model"]["size"]}'
|
||||||
_{activation_type}_{cfg["model"]["size"]}.pt'
|
|
||||||
|
|
||||||
if not os.path.exists(ckpt_dir):
|
if not os.path.exists(ckpt_dir):
|
||||||
os.makedirs(ckpt_dir, exist_ok=True)
|
os.makedirs(ckpt_dir, exist_ok=True)
|
||||||
|
@ -178,18 +206,23 @@ def run(dataset, net_type):
|
||||||
break
|
break
|
||||||
elif stp == 4:
|
elif stp == 4:
|
||||||
print('Using accuracy bound')
|
print('Using accuracy bound')
|
||||||
if accuracyBound(train_acc, cfg.acc_thrs) == 1:
|
if accuracyBound(train_acc, cfg["model"]["acc_thrs"]) == 1:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print('Training for {} epochs'.format(cfg["model"]["n_epochs"]))
|
print('Training for {} epochs'.format(cfg["model"]["n_epochs"]))
|
||||||
|
|
||||||
if sav == 1:
|
if sav == 1:
|
||||||
# save model when finished
|
# save model when finished
|
||||||
if epoch == cfg.n_epochs-1:
|
# if epoch == cfg["model"]["n_epochs"]-1:
|
||||||
torch.save(net.state_dict(), ckpt_name)
|
torch.save({
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_state_dict': net.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': train_loss,
|
||||||
|
}, ckpt_name + '_epoch_{}.pt'.format(epoch))
|
||||||
|
|
||||||
with open("bayes_exp_data_"+str(cfg["model"]["size"])+".pkl", 'wb') as f:
|
# with open("bayes_exp_data_"+str(cfg["model"]["size"])+".pkl", 'wb') as f:
|
||||||
pickle.dump(train_data, f)
|
# pickle.dump(train_data, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import print_function
|
||||||
import os
|
import os
|
||||||
import data
|
import data
|
||||||
import torch
|
import torch
|
||||||
import pickle
|
# import pickle
|
||||||
import metrics
|
import metrics
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -13,12 +13,41 @@ from models.NonBayesianModels.AlexNet import AlexNet
|
||||||
from stopping_crit import earlyStopping, energyBound, accuracyBound
|
from stopping_crit import earlyStopping, energyBound, accuracyBound
|
||||||
from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
|
from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
|
||||||
|
|
||||||
with (open("configuration.pkl", "rb")) as file:
|
# with (open("configuration.pkl", "rb")) as file:
|
||||||
while True:
|
# while True:
|
||||||
try:
|
# try:
|
||||||
cfg = pickle.load(file)
|
# cfg = pickle.load(file)
|
||||||
except EOFError:
|
# except EOFError:
|
||||||
break
|
# break
|
||||||
|
|
||||||
|
cfg = {
|
||||||
|
"model": {"net_type": "lenet", "type": "freq", "size": 1,
|
||||||
|
"layer_type": "lrt", "activation_type": "softplus",
|
||||||
|
"priors": {
|
||||||
|
'prior_mu': 0,
|
||||||
|
'prior_sigma': 0.1,
|
||||||
|
'posterior_mu_initial': (0, 0.1), # (mean,std) normal_
|
||||||
|
'posterior_rho_initial': (-5, 0.1), # (mean,std) normal_
|
||||||
|
},
|
||||||
|
"n_epochs": 100,
|
||||||
|
"sens": 1e-9,
|
||||||
|
"energy_thrs": 100000,
|
||||||
|
"acc_thrs": 0.99,
|
||||||
|
"lr": 0.001,
|
||||||
|
"num_workers": 4,
|
||||||
|
"valid_size": 0.2,
|
||||||
|
"batch_size": 256,
|
||||||
|
"train_ens": 1,
|
||||||
|
"valid_ens": 1,
|
||||||
|
"beta_type": 0.1, # 'Blundell','Standard',etc.
|
||||||
|
# Use float for const value
|
||||||
|
},
|
||||||
|
#"data": "CIFAR10",
|
||||||
|
"data": "MNIST",
|
||||||
|
"stopping_crit": 1,
|
||||||
|
"save": 1,
|
||||||
|
"pickle_path": None,
|
||||||
|
}
|
||||||
|
|
||||||
# CUDA settings
|
# CUDA settings
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
@ -80,8 +109,7 @@ def run(dataset, net_type):
|
||||||
net = getModel(net_type, inputs, outputs).to(device)
|
net = getModel(net_type, inputs, outputs).to(device)
|
||||||
|
|
||||||
ckpt_dir = f'checkpoints/{dataset}/frequentist'
|
ckpt_dir = f'checkpoints/{dataset}/frequentist'
|
||||||
ckpt_name = f'checkpoints/{dataset}/frequentist/model\
|
ckpt_name = f'checkpoints/{dataset}/frequentist/model_{net_type}_{cfg["model"]["size"]}'
|
||||||
_{net_type}_{cfg["model"]["size"]}.pt'
|
|
||||||
|
|
||||||
if not os.path.exists(ckpt_dir):
|
if not os.path.exists(ckpt_dir):
|
||||||
os.makedirs(ckpt_dir, exist_ok=True)
|
os.makedirs(ckpt_dir, exist_ok=True)
|
||||||
|
@ -132,11 +160,17 @@ def run(dataset, net_type):
|
||||||
|
|
||||||
if sav == 1:
|
if sav == 1:
|
||||||
# save model when finished
|
# save model when finished
|
||||||
if epoch == n_epochs:
|
# if epoch == n_epochs:
|
||||||
torch.save(net.state_dict(), ckpt_name)
|
# torch.save(net.state_dict(), ckpt_name)
|
||||||
|
torch.save({
|
||||||
|
'epoch': epoch,
|
||||||
|
'model_state_dict': net.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'loss': train_loss,
|
||||||
|
}, ckpt_name + '_epoch_{}.pt'.format(epoch))
|
||||||
|
|
||||||
with open("freq_exp_data_"+str(cfg["model"]["size"])+".pkl", 'wb') as f:
|
# with open("freq_exp_data_"+str(cfg["model"]["size"])+".pkl", 'wb') as f:
|
||||||
pickle.dump(train_data, f)
|
# pickle.dump(train_data, f)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -24,7 +24,7 @@ cfg = {
|
||||||
},
|
},
|
||||||
"n_epochs": 100,
|
"n_epochs": 100,
|
||||||
"sens": 1e-9,
|
"sens": 1e-9,
|
||||||
"energy_thrs": 10000,
|
"energy_thrs": 100000,
|
||||||
"acc_thrs": 0.99,
|
"acc_thrs": 0.99,
|
||||||
"lr": 0.001,
|
"lr": 0.001,
|
||||||
"num_workers": 4,
|
"num_workers": 4,
|
||||||
|
|
Loading…
Reference in New Issue