Update for 04-17-22

This commit is contained in:
Eduardo Cueto-Mendoza 2022-04-17 17:18:21 +01:00
parent 33eae92941
commit c63d8e6dba
3 changed files with 56 additions and 10 deletions

View File

@ -1,4 +1,6 @@
############### Configuration file for Bayesian ############### ############### Configuration file for Bayesian ###############
import os
layer_type = 'lrt' # 'bbb' or 'lrt' layer_type = 'lrt' # 'bbb' or 'lrt'
activation_type = 'softplus' # 'softplus' or 'relu' activation_type = 'softplus' # 'softplus' or 'relu'
priors={ priors={
@ -16,4 +18,25 @@ batch_size = 256
train_ens = 1 train_ens = 1
valid_ens = 1 valid_ens = 1
beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value
wide = 1
with open("bay", "r") as file:
bay = int(file.read())
if bay == 1:
with open("tmp", "r") as file:
wide = int(file.read())
if os.path.exists("tmp"):
os.remove("tmp")
else:
raise Exception("Tmp file not found")
print("Bayesian configured to run with width: {}".format(wide))
if os.path.exists("bay"):
os.remove("bay")
else:
raise Exception("Bay file not found")

View File

@ -1,7 +1,30 @@
############### Configuration file for Frequentist ############### ############### Configuration file for Frequentist ###############
n_epochs = 200
import os
n_epochs = 100
lr = 0.001 lr = 0.001
num_workers = 4 num_workers = 4
valid_size = 0.2 valid_size = 0.2
batch_size = 256 batch_size = 256
wide = 1
with open("frq", "r") as file:
frq = int(file.read())
if frq == 1:
with open("tmp", "r") as file:
wide = int(file.read())
if os.path.exists("tmp"):
os.remove("tmp")
else:
raise Exception("Tmp file not found")
print("Frequentist configured to run with width: {}".format(wide))
if os.path.exists("frq"):
os.remove("frq")
else:
raise Exception("Frq file not found")

View File

@ -103,16 +103,16 @@ def run(dataset, net_type):
#early_stop.append(valid_acc) #early_stop.append(valid_acc)
#if epoch % 4 == 0 and epoch > 0: #if epoch % 4 == 0 and epoch > 0:
#print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs)) # print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs))
#if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs: # if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs:
#break # break
#early_stop = [] # early_stop = []
if train_acc >= 0.99: #if train_acc >= 0.99:
break # break
#if gpu_sample_draw.total_watt_consumed() > 100000: #if gpu_sample_draw.total_watt_consumed() > 100000:
#break # break
# save model when finished # save model when finished
#if epoch == n_epochs: #if epoch == n_epochs: