# Entropy_Data_Processing/general_plots.py
#
# 30 lines
# 1.1 KiB
# Python
# Raw Normal View History
#
# 2024-04-25 13:14:19 +00:00
import matplotlib.pyplot as plt
import functions as aux
# Load the pickled results plotted by this script. Both objects are read with
# the project helper ``aux.load_pickle`` and are indexed later in this script as
# ``[dataset]['LeNet'][size]`` (e.g. ``['MNIST']['LeNet'][size]``).
# NOTE(review): presumably nested dicts / DataFrames keyed by dataset and
# model name — confirm against the code that writes these pickles.
eff_df = aux.load_pickle("efficiency_data.pkl")
entropy_data = aux.load_pickle("entropy_data.pkl")
# Parameter names for the two network variants, spelled as "<layer>.<param>".
# NOTE(review): presumably these are the models' state_dict keys — confirm.
# They are not referenced by the plotting code visible in this script.

# Bayesian variant: every layer has mu/rho parameters for weights and bias.
bayes_keys = [
    f"{layer}.{param}"
    for layer in ("conv1", "conv2", "fc1", "fc2", "fc3")
    for param in ("W_mu", "W_rho", "bias_mu", "bias_rho")
]

# Plain LeNet variant: every layer has a single weight tensor and a bias.
lenet_keys = [
    f"{layer}.{param}"
    for layer in ("conv1", "conv2", "fc1", "fc2", "fc3")
    for param in ("weight", "bias")
]
# Overlay, for each size 1..7, the LeNet/MNIST efficiency curve and the
# matching entropy curve on a single set of axes, then show the figure.
#
# Matplotlib's default colour cycle has only 10 entries, so drawing 14 lines
# with default colours would wrap around and give an efficiency curve and its
# entropy counterpart unrelated colours. Instead, each size reuses one colour:
# solid for efficiency, dashed for entropy.
eff_by_size = eff_df['MNIST']['LeNet']  # loop-invariant lookups, hoisted
entropy_by_size = entropy_data['MNIST']['LeNet']
for size in range(1, 8):
    eff_lines = plt.plot(eff_by_size[size],
                         label=f'Efficiency size {size}')
    plt.plot(entropy_by_size[size],
             linestyle='--',
             color=eff_lines[0].get_color(),
             label=f'Entropy size {size}')
plt.legend(loc='upper right')
plt.show()