diff --git a/config_bayesian.py b/config_bayesian.py index 03cd9dd..07887cb 100755 --- a/config_bayesian.py +++ b/config_bayesian.py @@ -1,4 +1,6 @@ ############### Configuration file for Bayesian ############### + +import os layer_type = 'lrt' # 'bbb' or 'lrt' activation_type = 'softplus' # 'softplus' or 'relu' priors={ @@ -16,4 +18,25 @@ batch_size = 256 train_ens = 1 valid_ens = 1 beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value -wide = 1 + + +with open("bay", "r") as file: + bay = int(file.read()) + +if bay == 1: + with open("tmp", "r") as file: + wide = int(file.read()) + + if os.path.exists("tmp"): + os.remove("tmp") + else: + raise Exception("Tmp file not found") + + print("Bayesian configured to run with width: {}".format(wide)) + + +if os.path.exists("bay"): + os.remove("bay") +else: + raise Exception("Bay file not found") + \ No newline at end of file diff --git a/config_frequentist.py b/config_frequentist.py index f8242bd..431a594 100755 --- a/config_frequentist.py +++ b/config_frequentist.py @@ -1,7 +1,30 @@ ############### Configuration file for Frequentist ############### -n_epochs = 200 + +import os +n_epochs = 100 lr = 0.001 num_workers = 4 valid_size = 0.2 batch_size = 256 -wide = 1 + +with open("frq", "r") as file: + frq = int(file.read()) + +if frq == 1: + with open("tmp", "r") as file: + wide = int(file.read()) + + if os.path.exists("tmp"): + os.remove("tmp") + else: + raise Exception("Tmp file not found") + + print("Frequentist configured to run with width: {}".format(wide)) + + + +if os.path.exists("frq"): + os.remove("frq") +else: + raise Exception("Frq file not found") + diff --git a/main_frequentist.py b/main_frequentist.py index b834065..f5da58a 100755 --- a/main_frequentist.py +++ b/main_frequentist.py @@ -103,16 +103,16 @@ def run(dataset, net_type): #early_stop.append(valid_acc) #if epoch % 4 == 0 and epoch > 0: - #print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs)) - #if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs: - #break - #early_stop = [] + # print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs)) + # if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs: + # break + # early_stop = [] - if train_acc >= 0.99: - break + #if train_acc >= 0.99: + # break #if gpu_sample_draw.total_watt_consumed() > 100000: - #break + # break # save model when finished #if epoch == n_epochs: