Entropy_Data_Processing/functions.py

88 lines
2.9 KiB
Python

import torch.linalg as alg
import pickle
import torch
def square_matrix(tensor):
    """Zero-pad ``tensor`` so that its trailing two dimensions form a square.

    The zero block is always placed *before* the existing data, so the
    original values end up in the last rows / last columns of the result.

    Behaviour by number of dimensions:

    * 0-D: nothing can be padded; ``None`` is returned.
    * 1-D of length ``n``: an ``(n, n)`` matrix whose last row is the vector
      and whose other rows are zero.
    * 2-D of shape ``(r, c)``: zero rows are stacked on top when ``r < c``,
      zero columns are added on the left when ``r > c``.  A matrix that is
      already square is returned as the very same object (no copy).
    * more than 2-D: the last two dimensions are treated as matrices and
      padded as in the 2-D case; the work is done on a detached copy, so the
      result does not take part in autograd.

    Args:
        tensor: A ``torch.Tensor``.

    Returns:
        The padded tensor (or ``None`` for a 0-D input).
    """
    ndim = tensor.dim()
    if ndim == 0:
        return None

    if ndim == 1:
        # A vector becomes a single-row matrix that is then padded on top.
        matrix = tensor.reshape(1, tensor.size(0))
    elif ndim == 2:
        matrix = tensor
    else:
        # Higher-rank inputs are processed on a detached copy.
        matrix = tensor.detach().clone()

    rows, cols = matrix.shape[-2], matrix.shape[-1]
    if rows == cols:
        return matrix

    batch_shape = tuple(matrix.shape[:-2])
    if rows < cols:
        # Wide matrix: add the missing rows above the data.
        pad_shape = batch_shape + (cols - rows, cols)
        cat_dim = -2
    else:
        # Tall matrix: add the missing columns to the left of the data.
        # (Concatenating along the row axis cannot work here because the
        # column counts of the two operands differ.)
        pad_shape = batch_shape + (rows, rows - cols)
        cat_dim = -1

    # The zero block lives on the input's device so the concatenation is
    # valid for non-CPU tensors.  Its dtype is left at torch's default, so
    # the result dtype is whatever torch.cat derives from the two operands.
    zeros = torch.zeros(pad_shape, device=matrix.device)
    return torch.cat((zeros, matrix), dim=cat_dim)
def neumann_entropy(tensor):
    """Return ``-sum(x * log(x))`` over the real parts of the eigenvalues.

    Only the real part ``x`` of each eigenvalue is used.  Wherever ``log(x)``
    is not finite (``x`` is zero or negative), the logarithm is replaced by
    ``0`` so that eigenvalue contributes nothing to the sum.

    Behaviour by number of dimensions:

    * 0-D: ``None`` is returned.
    * 1-D: the plain integer ``0`` is returned.
    * 2-D: the entropy of that single square matrix.
    * more than 2-D: the trailing two dimensions are treated as square
      matrices and the entropies of *all* of them are summed into one
      scalar, so every slice of the input contributes to the result.

    Args:
        tensor: A ``torch.Tensor`` whose trailing two dimensions are square
            (a requirement of ``torch.linalg.eigvals``).

    Returns:
        A 0-dimensional tensor for inputs with two or more dimensions,
        ``0`` for 1-D input, ``None`` for 0-D input.
    """
    ndim = tensor.dim()
    if ndim == 0:
        return None
    if ndim == 1:
        return 0

    # eigvals works on batches, so one call covers every matrix held in the
    # leading dimensions; its output is complex and only the real part is kept.
    real_parts = alg.eigvals(tensor).real
    log_parts = torch.nan_to_num(
        torch.log(real_parts), nan=0.0, posinf=0.0, neginf=0.0
    )
    return -1 * torch.sum(real_parts * log_parts)
def load_pickle(fpath):
    """Read the file at ``fpath`` in binary mode and return the unpickled object.

    NOTE: unpickling can execute arbitrary code, so this should only be
    pointed at files from a trusted source.
    """
    with open(fpath, "rb") as handle:
        return pickle.load(handle)
def save_pickle(pickle_name, data_dump):
    """Pickle ``data_dump`` into the file ``pickle_name``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.  Nothing is returned.
    """
    with open(pickle_name, 'wb') as sink:
        pickle.dump(data_dump, sink)
def chunks(lst, n):
    """Lazily produce consecutive slices of ``lst``, each at most ``n`` long.

    Slices are taken at offsets ``0, n, 2n, ...``; the final slice is
    shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    total = len(lst)
    yield from (lst[start:start + n] for start in range(0, total, n))
def split(lst, n):
    """Return a generator over ``n`` consecutive slices that cover ``lst``.

    With ``base, extra = divmod(len(lst), n)``, the first ``extra`` slices
    hold ``base + 1`` items and the remaining ones hold ``base`` items.
    """
    base, extra = divmod(len(lst), n)
    # Built up front so an unusable ``n`` is rejected at call time rather
    # than when the generator is first advanced.
    part_indices = range(n)

    def _parts():
        for idx in part_indices:
            begin = idx * base + min(idx, extra)
            end = (idx + 1) * base + min(idx + 1, extra)
            yield lst[begin:end]

    return _parts()