138 lines
4.7 KiB
Python
138 lines
4.7 KiB
Python
|
from __future__ import print_function
|
||
|
import os
|
||
|
import data
|
||
|
import torch
|
||
|
#import utils
|
||
|
import pickle
|
||
|
import metrics
|
||
|
import argparse
|
||
|
import numpy as np
|
||
|
import torch.nn as nn
|
||
|
import gpu_sample_draw
|
||
|
from datetime import datetime
|
||
|
import config_frequentist as cfg
|
||
|
from torch.optim import Adam, lr_scheduler
|
||
|
from models.NonBayesianModels.LeNet import LeNet
|
||
|
from models.NonBayesianModels.AlexNet import AlexNet
|
||
|
from models.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
|
||
|
|
||
|
|
||
|
# CUDA settings: use the first CUDA device when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
|
||
|
|
||
|
|
||
|
def getModel(net_type, inputs, outputs, wide=cfg.wide):
    """Construct the non-Bayesian network selected by ``net_type``.

    Args:
        net_type: Exactly one of ``'lenet'``, ``'alexnet'`` or ``'3conv3fc'``
            (the comparison is case-sensitive).
        inputs: Passed to the model constructor as its second positional
            argument.
        outputs: Passed to the model constructor as its first positional
            argument.
        wide: Passed as the third positional argument to ``LeNet`` only; the
            other two constructors are called without it. Defaults to
            ``cfg.wide`` as read when this function is defined.

    Returns:
        A new ``LeNet``, ``AlexNet`` or ``ThreeConvThreeFC`` instance.

    Raises:
        ValueError: If ``net_type`` is not one of the supported names.
    """
    if net_type == 'lenet':
        return LeNet(outputs, inputs, wide)
    if net_type == 'alexnet':
        return AlexNet(outputs, inputs)
    if net_type == '3conv3fc':
        return ThreeConvThreeFC(outputs, inputs)
    # Report the rejected value and the keys that are actually accepted.
    raise ValueError(
        f"Unknown net_type {net_type!r}: "
        "expected one of 'lenet', 'alexnet', '3conv3fc'")
|
||
|
|
||
|
|
||
|
def train_model(net, optimizer, criterion, train_loader):
    """Run one optimisation pass over ``train_loader``.

    For every batch the gradients are reset, the loss is back-propagated and
    the optimizer takes one step.

    Returns:
        A ``(loss_sum, mean_acc)`` tuple. ``loss_sum`` is the sum over batches
        of ``loss.item() * batch_size``; it is NOT divided by the number of
        samples here. ``mean_acc`` is the plain (unweighted) mean of the
        per-batch values returned by ``metrics.acc``.
    """
    net.train()
    loss_sum = 0.0
    batch_accs = []
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_out = net(inputs)
        batch_loss = criterion(batch_out, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size (first dimension).
        loss_sum += batch_loss.item() * inputs.size(0)
        batch_accs.append(metrics.acc(batch_out.detach(), labels))
    return loss_sum, np.mean(batch_accs)
|
||
|
|
||
|
|
||
|
def validate_model(net, criterion, valid_loader):
    """Evaluate ``net`` on ``valid_loader`` without updating any parameters.

    The network is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``, so no autograd graph is built for the forward passes.

    Returns:
        A ``(loss_sum, mean_acc)`` tuple. ``loss_sum`` is the sum over batches
        of ``loss.item() * batch_size``; it is NOT divided by the number of
        samples here. ``mean_acc`` is the plain (unweighted) mean of the
        per-batch values returned by ``metrics.acc``.
    """
    valid_loss = 0.0
    net.eval()
    accs = []
    # Nothing is back-propagated during validation, so gradient tracking is
    # disabled to avoid storing activations for a backward pass.
    with torch.no_grad():
        for batch, target in valid_loader:
            batch, target = batch.to(device), target.to(device)
            output = net(batch)
            loss = criterion(output, target)
            valid_loss += loss.item() * batch.size(0)
            # No .detach() needed: tensors created under no_grad carry no graph.
            accs.append(metrics.acc(output, target))
    return valid_loss, np.mean(accs)
|
||
|
|
||
|
|
||
|
def run(dataset, net_type):
    """Train a frequentist model and pickle its per-epoch metrics.

    Hyper-parameters are read from ``cfg``. Training runs for at most
    ``cfg.n_epochs`` epochs and stops early as soon as the training accuracy
    reaches 0.99. After the loop the collected
    ``[epoch, train_loss, train_acc, valid_loss, valid_acc]`` rows are pickled
    to ``freq_exp_data_<cfg.wide>.pkl`` in the current working directory.

    The checkpoint directory is created, but no model weights are written:
    the ``torch.save`` call is currently commented out.

    Args:
        dataset: Dataset name forwarded to ``data.getDataset``; also used in
            the checkpoint directory path.
        net_type: Network name forwarded to ``getModel``.
    """
    # Hyper Parameter settings
    n_epochs = cfg.n_epochs
    lr = cfg.lr
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    # NOTE(review): test_loader is never used in this function.
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)

    ckpt_dir = f'checkpoints/{dataset}/frequentist'
    # Only referenced by the disabled torch.save call further down.
    ckpt_name = f'{ckpt_dir}/model_{net_type}_{cfg.wide}.pt'

    # exist_ok=True already tolerates an existing directory, so no separate
    # (race-prone) os.path.exists check is needed.
    os.makedirs(ckpt_dir, exist_ok=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(net.parameters(), lr=lr)
    lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
    #valid_loss_min = np.Inf
    #early_stop = []
    #thrs=1e-9
    train_data = []
    for epoch in range(1, n_epochs+1):

        train_loss, train_acc = train_model(net, optimizer, criterion, train_loader)
        valid_loss, valid_acc = validate_model(net, criterion, valid_loader)
        # The scheduler sees the summed (not yet normalised) validation loss.
        lr_sched.step(valid_loss)

        # NOTE(review): this divides by the size of each loader's underlying
        # dataset. If getDataloader splits train/valid with samplers over one
        # shared dataset, these would not be the per-split sample counts;
        # confirm against data.getDataloader.
        train_loss = train_loss/len(train_loader.dataset)
        valid_loss = valid_loss/len(valid_loader.dataset)

        train_data.append([epoch, train_loss, train_acc, valid_loss, valid_acc])
        print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'.format(
            epoch, train_loss, train_acc, valid_loss, valid_acc))

        # Disabled early-stopping heuristic on validation accuracy:
        #early_stop.append(valid_acc)
        #if epoch % 4 == 0 and epoch > 0:
        #print("Value 1: {} >= {}, Value 2: {} >= {}, Value 2: {} >= {}".format(early_stop[0],valid_acc-thrs,early_stop[1],valid_acc-thrs,early_stop[2],valid_acc-thrs))
        #if abs(early_stop[0]) >= valid_acc-thrs and abs(early_stop[1]) >= valid_acc-thrs and abs(early_stop[2]) >= valid_acc-thrs:
        #break
        #early_stop = []

        # Stop as soon as the training accuracy reaches 99%.
        if train_acc >= 0.99:
            break

        #if gpu_sample_draw.total_watt_consumed() > 100000:
        #break

        # save model when finished
        #if epoch == n_epochs:
        #torch.save(net.state_dict(), ckpt_name)

    # Persist the per-epoch metrics collected above.
    with open(f"freq_exp_data_{cfg.wide}.pkl", 'wb') as f:
        pickle.dump(train_data, f)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Print the wall-clock time (HH:MM:SS) before training starts.
    print("Initial Time =", datetime.now().strftime("%H:%M:%S"))

    # Command-line options: which network to build and which dataset to use.
    cli = argparse.ArgumentParser(description="PyTorch Frequentist Model Training")
    cli.add_argument('--net_type', default='lenet', type=str, help='model')
    cli.add_argument('--dataset', default='MNIST', type=str, help='dataset = [MNIST/CIFAR10/CIFAR100]')
    options = cli.parse_args()

    run(options.dataset, options.net_type)

    # Print the wall-clock time again once training has returned.
    print("Final Time =", datetime.now().strftime("%H:%M:%S"))
|
||
|
|