# bayesian_resnet_experiment/bayesian_resnet.py
# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trains a Bayesian neural network to classify CIFAR-10 images.
The architecture can be either ResNet [1] or VGG [2].
To run with default arguments:
```
bazel run tensorflow_probability/examples:cifar10_bnn
```
#### References
[1]: He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
"Deep residual learning for image recognition."
_Proceedings of the IEEE_, 2016.
https://arxiv.org/abs/1512.03385
[2]: Simonyan, Karen, and Andrew Zisserman.
"Very deep convolutional networks for large-scale image recognition."
arXiv preprint arXiv:1409.1556 (2014).
https://arxiv.org/pdf/1409.1556.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import warnings
# Dependency imports
import matplotlib
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
from probability.tensorflow_probability.examples.models.bayesian_resnet import bayesian_resnet
from probability.tensorflow_probability.examples.models.bayesian_vgg import bayesian_vgg
tf.compat.v1.disable_eager_execution()
matplotlib.use("Agg")
warnings.simplefilter(action="ignore")
tfd = tfp.distributions
IMAGE_SHAPE = [32, 32, 3]

# Hyperparameters (fixed here in place of command-line flags).
learning_rate = 0.0001        # Adam learning rate.
epochs = 700                  # Number of training epochs.
batch_size = 128              # Training batch size.
data_dir = os.path.join(os.getenv("TEST_TMPDIR", "/tmp"),
                        "bayesian_neural_network/data")
model_dir = os.path.join(os.getenv("TEST_TMPDIR", "/tmp"),
                         "bayesian_neural_network/")
eval_freq = 400               # Frequency (in steps) of held-out evaluation.
num_monte_carlo = 50          # Posterior draws used to form predictions.
architecture = "resnet"       # "resnet" or "vgg".
kernel_posterior_scale_mean = -9.0       # Initial mean of the (untransformed) kernel posterior scale.
kernel_posterior_scale_constraint = 0.2  # Upper bound on the kernel posterior scale.
kl_annealing = 50             # Epochs over which the KL weight anneals from 0 to 1.
subtract_pixel_mean = True    # Subtract the per-pixel training-set mean from inputs.
fake_data = False             # If True, use random stand-in data (see build_fake_data below).
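

# NOTE: main() references build_fake_data() when fake_data is True, but no such
# helper is defined in this file. The definition below is a minimal stand-in
# (an assumption, not part of the original source) that returns small random
# arrays with CIFAR-10 shapes so that code path at least runs.
def build_fake_data(num_examples=10):
  """Build random CIFAR-10-shaped data for smoke testing."""
  x_train = np.random.rand(num_examples, *IMAGE_SHAPE).astype(np.float32)
  y_train = np.random.randint(10, size=(num_examples, 1)).astype(np.int32)
  x_test = np.random.rand(num_examples, *IMAGE_SHAPE).astype(np.float32)
  y_test = np.random.randint(10, size=(num_examples, 1)).astype(np.int32)
  return (x_train, y_train), (x_test, y_test)
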
def build_input_pipeline(x_train, x_test, y_train, y_test,
                         batch_size, valid_size):
  """Build an Iterator switching between train and heldout data."""
  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")
  x_train /= 255
  x_test /= 255
  y_train = y_train.flatten()
  y_test = y_test.flatten()

  if subtract_pixel_mean:
    x_train_mean = np.mean(x_train, axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

  print("x_train shape:" + str(x_train.shape))
  print(str(x_train.shape[0]) + " train samples")
  print(str(x_test.shape[0]) + " test samples")

  # Build an iterator over training batches.
  training_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, np.int32(y_train)))
  training_batches = training_dataset.shuffle(
      50000, reshuffle_each_iteration=True).repeat().batch(batch_size)
  training_iterator = tf.compat.v1.data.make_one_shot_iterator(training_batches)

  # Build an iterator over the heldout set with batch_size=valid_size
  # (main() below passes valid_size=500, so each call yields a 500-image
  # heldout batch).
  heldout_dataset = tf.data.Dataset.from_tensor_slices(
      (x_test, np.int32(y_test)))
  heldout_batches = heldout_dataset.repeat().batch(valid_size)
  heldout_iterator = tf.compat.v1.data.make_one_shot_iterator(heldout_batches)

  # Combine these into a feedable iterator that can switch between training
  # and validation inputs (see the usage note after this function).
  handle = tf.compat.v1.placeholder(tf.string, shape=[])
  feedable_iterator = tf.compat.v1.data.Iterator.from_string_handle(
      handle, training_batches.output_types, training_batches.output_shapes)
  images, labels = feedable_iterator.get_next()
  return images, labels, handle, training_iterator, heldout_iterator
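

# Usage note (mirroring what main() below does): the string handle of a
# concrete iterator is fed through `handle` at run time to choose which
# dataset supplies `images` and `labels`, e.g.
#   train_handle = sess.run(training_iterator.string_handle())
#   sess.run(train_op, feed_dict={handle: train_handle})
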
def main():
  # del argv  # unused
  if tf.io.gfile.exists(model_dir):
    tf.compat.v1.logging.warning(
        "Warning: deleting old log directory at {}".format(model_dir))
    tf.io.gfile.rmtree(model_dir)
  tf.io.gfile.makedirs(model_dir)

  if fake_data:
    (x_train, y_train), (x_test, y_test) = build_fake_data()
  else:
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

  (images, labels, handle,
   training_iterator, heldout_iterator) = build_input_pipeline(
       x_train, x_test, y_train, y_test, batch_size, 500)
  if architecture == "resnet":
    model_fn = bayesian_resnet
  else:
    model_fn = bayesian_vgg

  model = model_fn(
      IMAGE_SHAPE,
      num_classes=10,
      kernel_posterior_scale_mean=kernel_posterior_scale_mean,
      kernel_posterior_scale_constraint=kernel_posterior_scale_constraint)
  logits = model(images)
  labels_distribution = tfd.Categorical(logits=logits)
  # Perform KL annealing. The optimal number of annealing steps
  # depends on the dataset and architecture.
  t = tf.compat.v2.Variable(0.0)
  kl_regularizer = t / (kl_annealing * len(x_train) / batch_size)

  # Compute the -ELBO as the loss. The KL term is annealed from 0 to 1 over
  # the number of epochs given by the kl_annealing constant above.
  log_likelihood = labels_distribution.log_prob(labels)
  neg_log_likelihood = -tf.reduce_mean(input_tensor=log_likelihood)
  kl = sum(model.losses) / len(x_train) * tf.minimum(1.0, kl_regularizer)
  loss = neg_log_likelihood + kl
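
  # In symbols (per step; N = len(x_train), B = batch_size, and
  # T = kl_annealing * N / B annealing steps), the loss above is:
  #   loss = -(1/B) * sum_i log p(y_i | x_i, w)      # neg_log_likelihood
  #          + min(1, t/T) * KL(q(w) || p(w)) / N    # kl
  # where model.losses collects the per-layer KL terms of the variational
  # layers, so sum(model.losses) plays the role of KL(q(w) || p(w)).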

  # Build metrics for evaluation. Predictions are formed from a single forward
  # pass of the probabilistic layers. They are cheap but noisy predictions.
  predictions = tf.argmax(input=logits, axis=1)
  with tf.compat.v1.name_scope("train"):
    train_accuracy, train_accuracy_update_op = tf.compat.v1.metrics.accuracy(
        labels=labels, predictions=predictions)
    opt = tf.compat.v1.train.AdamOptimizer(learning_rate)
    train_op = opt.minimize(loss)
    update_step_op = tf.compat.v1.assign(t, t + 1)

  with tf.compat.v1.name_scope("valid"):
    valid_accuracy, valid_accuracy_update_op = tf.compat.v1.metrics.accuracy(
        labels=labels, predictions=predictions)

  init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                     tf.compat.v1.local_variables_initializer())

  stream_vars_valid = [
      v for v in tf.compat.v1.local_variables() if "valid/" in v.name
  ]
  reset_valid_op = tf.compat.v1.variables_initializer(stream_vars_valid)
  with tf.compat.v1.Session() as sess:
    sess.run(init_op)

    # Run the training loop.
    train_handle = sess.run(training_iterator.string_handle())
    heldout_handle = sess.run(heldout_iterator.string_handle())
    training_steps = int(
        round(epochs * (len(x_train) / batch_size)))
    for step in range(training_steps):
      _ = sess.run([train_op,
                    train_accuracy_update_op,
                    update_step_op],
                   feed_dict={handle: train_handle})

      # Periodically report training progress.
      if step % 100 == 0:
        loss_value, accuracy_value, kl_value = sess.run(
            [loss, train_accuracy, kl], feed_dict={handle: train_handle})
        print(
            "Step: {:>3d} Loss: {:.3f} Accuracy: {:.3f} KL: {:.3f}".format(
                step, loss_value, accuracy_value, kl_value))

      if (step + 1) % eval_freq == 0:
        # Compute log prob of heldout set by averaging draws from the model:
        # p(heldout | train) = int_model p(heldout|model) p(model|train)
        #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
        # where model_i is a draw from the posterior p(model|train).
        # (On newer TFP versions Categorical(logits=...).probs is None;
        # labels_distribution.probs_parameter() is the replacement there.)
        probs = np.asarray([sess.run((labels_distribution.probs),
                                     feed_dict={handle: heldout_handle})
                            for _ in range(num_monte_carlo)])
        mean_probs = np.mean(probs, axis=0)

        _, label_vals = sess.run(
            (images, labels), feed_dict={handle: heldout_handle})
        heldout_lp = np.mean(np.log(mean_probs[np.arange(mean_probs.shape[0]),
                                               label_vals.flatten()]))
        print(" ... Held-out nats: {:.3f}".format(heldout_lp))

        # Calculate validation accuracy over the full test set
        # (20 batches of valid_size=500 covers the 10,000 CIFAR-10 test images).
        for _ in range(20):
          sess.run(
              valid_accuracy_update_op, feed_dict={handle: heldout_handle})
        valid_value = sess.run(
            valid_accuracy, feed_dict={handle: heldout_handle})
        print(
            " ... Validation Accuracy: {:.3f}".format(valid_value))

        sess.run(reset_valid_op)

if __name__ == "__main__":
  # tf.compat.v1.app.run()
  main()