Update for 04-17-22

This commit is contained in:
Eduardo Cueto-Mendoza 2022-04-17 17:18:21 +01:00
parent 33eae92941
commit c63d8e6dba
3 changed files with 56 additions and 10 deletions

View File

@ -1,4 +1,6 @@
############### Configuration file for Bayesian ###############
import os
layer_type = 'lrt' # 'bbb' or 'lrt'
activation_type = 'softplus' # 'softplus' or 'relu'
priors={
@ -16,4 +18,25 @@ batch_size = 256
train_ens = 1
valid_ens = 1
beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value
wide = 1
with open("bay", "r") as file:
bay = int(file.read())
if bay == 1:
with open("tmp", "r") as file:
wide = int(file.read())
if os.path.exists("tmp"):
os.remove("tmp")
else:
raise Exception("Tmp file not found")
print("Bayesian configured to run with width: {}".format(wide))
if os.path.exists("bay"):
os.remove("bay")
else:
raise Exception("Bay file not found")

View File

@ -1,7 +1,30 @@
############### Configuration file for Frequentist ###############
n_epochs = 200
import os
n_epochs = 100
lr = 0.001
num_workers = 4
valid_size = 0.2
batch_size = 256
wide = 1
with open("frq", "r") as file:
frq = int(file.read())
if frq == 1:
with open("tmp", "r") as file:
wide = int(file.read())
if os.path.exists("tmp"):
os.remove("tmp")
else:
raise Exception("Tmp file not found")
print("Frequentist configured to run with width: {}".format(wide))
if os.path.exists("frq"):
os.remove("frq")
else:
raise Exception("Frq file not found")

View File

@ -103,16 +103,16 @@ def run(dataset, net_type):
#early_stop.append(valid_acc)
#if epoch % 4 == 0 and epoch > 0:
#print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs))
#if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs:
#break
#early_stop = []
# print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs))
# if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs:
# break
# early_stop = []
if train_acc >= 0.99:
break
#if train_acc >= 0.99:
# break
#if gpu_sample_draw.total_watt_consumed() > 100000:
#break
# break
# save model when finished
#if epoch == n_epochs: