# bayesian_resnet_experiment/train_resnet.py
#
# 68 lines
# 3.0 KiB
# Python

from __future__ import absolute_import, division, print_function
import tensorflow as tf
from tensorflow.python.keras.engine import training
from tensorflow.python.ops.gradients_impl import gradients
from data_acquisition import get_data
from resnet import resnet_18
import hyper
import math
if __name__ == '__main__':
    # GPU settings.
    # Memory growth can only be changed before the GPUs have been initialised,
    # so this runs before the model is built and before the data is loaded.
    # NOTE(review): if importing `hyper` already creates TensorFlow state
    # (optimizer / metric variables), the runtime may be initialised before
    # this point; in that case the RuntimeError below is reported instead of
    # aborting the run. Confirm against hyper.py.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            print("Could not enable GPU memory growth: {}".format(err))

    model = resnet_18()
    model.build(input_shape=(None, hyper.IMAGE_SHAPE[0], hyper.IMAGE_SHAPE[1], hyper.IMAGE_SHAPE[2]))
    model.summary()

    train_dataset, train_count, test_dataset, test_count = get_data()

    # Loop-invariant: only used for the "step: i/N" progress display.
    steps_per_epoch = math.ceil(train_count / hyper.BATCH_SIZE)

    @tf.function
    def train_step(images, labels):
        """Run one optimisation step on a batch and update the training metrics.

        Does a forward pass with ``training=True``, computes the loss with
        ``hyper.scce_loss``, applies the gradients of that loss with respect to
        the model's trainable variables through ``hyper.adadelta``, and then
        feeds ``hyper.mean_loss`` and ``hyper.sccea_loss``.
        """
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = hyper.scce_loss(y_true=labels, y_pred=predictions)
        # Named `grads` so the module-level `gradients` import is not shadowed.
        grads = tape.gradient(loss, model.trainable_variables)
        hyper.adadelta.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))
        hyper.mean_loss(loss)
        hyper.sccea_loss(labels, predictions)

    @tf.function
    def valid_step(images, labels):
        """Evaluate one batch with ``training=False`` and update the test metrics."""
        predictions = model(images, training=False)
        v_loss = hyper.scce_loss(labels, predictions)
        hyper.test_loss(v_loss)
        hyper.test_accuracy(labels, predictions)

    for epoch in range(hyper.EPOCHS):
        # Reset every metric so the figures printed below describe this epoch
        # only. The validation metrics must be reset as well, otherwise they
        # keep averaging over all previous epochs.
        hyper.mean_loss.reset_states()
        hyper.sccea_loss.reset_states()
        hyper.test_loss.reset_states()
        hyper.test_accuracy.reset_states()

        for step, (images, labels) in enumerate(train_dataset, start=1):
            train_step(images, labels)
            print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(
                epoch + 1,
                hyper.EPOCHS,
                step,
                steps_per_epoch,
                hyper.mean_loss.result(),
                hyper.sccea_loss.result()))

        for valid_images, valid_labels in test_dataset:
            valid_step(valid_images, valid_labels)

        print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, "
              "valid loss: {:.5f}, valid accuracy: {:.5f}".format(
                  epoch + 1,
                  hyper.EPOCHS,
                  hyper.mean_loss.result(),
                  hyper.sccea_loss.result(),
                  hyper.test_loss.result(),
                  hyper.test_accuracy.result()))