Update for 04-17-22
This commit is contained in:
parent
33eae92941
commit
c63d8e6dba
|
@ -1,4 +1,6 @@
|
||||||
############### Configuration file for Bayesian ###############
|
############### Configuration file for Bayesian ###############
|
||||||
|
|
||||||
|
import os
|
||||||
layer_type = 'lrt' # 'bbb' or 'lrt'
|
layer_type = 'lrt' # 'bbb' or 'lrt'
|
||||||
activation_type = 'softplus' # 'softplus' or 'relu'
|
activation_type = 'softplus' # 'softplus' or 'relu'
|
||||||
priors={
|
priors={
|
||||||
|
@ -16,4 +18,25 @@ batch_size = 256
|
||||||
train_ens = 1
|
train_ens = 1
|
||||||
valid_ens = 1
|
valid_ens = 1
|
||||||
beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value
|
beta_type = 0.1 # 'Blundell', 'Standard', etc. Use float for const value
|
||||||
wide = 1
|
|
||||||
|
|
||||||
|
with open("bay", "r") as file:
|
||||||
|
bay = int(file.read())
|
||||||
|
|
||||||
|
if bay == 1:
|
||||||
|
with open("tmp", "r") as file:
|
||||||
|
wide = int(file.read())
|
||||||
|
|
||||||
|
if os.path.exists("tmp"):
|
||||||
|
os.remove("tmp")
|
||||||
|
else:
|
||||||
|
raise Exception("Tmp file not found")
|
||||||
|
|
||||||
|
print("Bayesian configured to run with width: {}".format(wide))
|
||||||
|
|
||||||
|
|
||||||
|
if os.path.exists("bay"):
|
||||||
|
os.remove("bay")
|
||||||
|
else:
|
||||||
|
raise Exception("Bay file not found")
|
||||||
|
|
|
@ -1,7 +1,30 @@
|
||||||
############### Configuration file for Frequentist ###############
|
############### Configuration file for Frequentist ###############
|
||||||
n_epochs = 200
|
|
||||||
|
import os
|
||||||
|
n_epochs = 100
|
||||||
lr = 0.001
|
lr = 0.001
|
||||||
num_workers = 4
|
num_workers = 4
|
||||||
valid_size = 0.2
|
valid_size = 0.2
|
||||||
batch_size = 256
|
batch_size = 256
|
||||||
wide = 1
|
|
||||||
|
with open("frq", "r") as file:
|
||||||
|
frq = int(file.read())
|
||||||
|
|
||||||
|
if frq == 1:
|
||||||
|
with open("tmp", "r") as file:
|
||||||
|
wide = int(file.read())
|
||||||
|
|
||||||
|
if os.path.exists("tmp"):
|
||||||
|
os.remove("tmp")
|
||||||
|
else:
|
||||||
|
raise Exception("Tmp file not found")
|
||||||
|
|
||||||
|
print("Frequentist configured to run with width: {}".format(wide))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if os.path.exists("frq"):
|
||||||
|
os.remove("frq")
|
||||||
|
else:
|
||||||
|
raise Exception("Frq file not found")
|
||||||
|
|
||||||
|
|
|
@ -103,16 +103,16 @@ def run(dataset, net_type):
|
||||||
|
|
||||||
#early_stop.append(valid_acc)
|
#early_stop.append(valid_acc)
|
||||||
#if epoch % 4 == 0 and epoch > 0:
|
#if epoch % 4 == 0 and epoch > 0:
|
||||||
#print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs))
|
# print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs))
|
||||||
#if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs:
|
# if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs:
|
||||||
#break
|
# break
|
||||||
#early_stop = []
|
# early_stop = []
|
||||||
|
|
||||||
if train_acc >= 0.99:
|
#if train_acc >= 0.99:
|
||||||
break
|
# break
|
||||||
|
|
||||||
#if gpu_sample_draw.total_watt_consumed() > 100000:
|
#if gpu_sample_draw.total_watt_consumed() > 100000:
|
||||||
#break
|
# break
|
||||||
|
|
||||||
# save model when finished
|
# save model when finished
|
||||||
#if epoch == n_epochs:
|
#if epoch == n_epochs:
|
||||||
|
|
Loading…
Reference in New Issue