# 2024-04-25 13:14:19 +00:00

import torch.linalg as alg

import pickle

import torch

def square_matrix(tensor):
    """Zero-pad ``tensor`` so its (last two) dimensions form a square matrix.

    Zero rows/columns are prepended (top / left) to the smaller dimension;
    the original values are preserved in the bottom-right corner.

    Behavior by rank:
      * 0-d: returned unchanged.
      * 1-d of length n: returned as an (n, n) matrix whose last row is the
        input and all other rows are zero.
      * 2-d (r, c): padded to (m, m), m = max(r, c).
      * >2-d: every trailing (r, c) matrix is padded to (m, m); leading
        (batch) dimensions are preserved.
    """
    tensor_size = tensor.size()

    # Scalar: nothing to pad.
    if len(tensor_size) == 0:
        return tensor

    # Vector of length n -> (n, n) with the vector as the last row.
    if len(tensor_size) == 1:
        n = tensor_size[0]
        temp = torch.zeros([n, n - 1])
        return torch.cat((temp.T, tensor.reshape(1, n)))

    # Matrices (possibly batched): pad the smaller of the last two dims.
    r, c = tensor_size[-2], tensor_size[-1]
    if r == c:
        # Already square. The original >2-d path returned a detached clone;
        # keep that for backward compatibility, while the 2-d path keeps
        # returning the input itself.
        return tensor if len(tensor_size) == 2 else tensor.detach().clone()
    if r > c:
        # BUGFIX: the original concatenated an (r-c, r) block along dim 0
        # against an (r, c) tensor, which raised a size-mismatch error.
        # Zero columns must be prepended along the last dimension instead.
        pad = torch.zeros([*tensor_size[:-1], r - c])
        return torch.cat((pad, tensor), dim=-1)
    # r < c: prepend zero rows along the second-to-last dimension.
    pad = torch.zeros([*tensor_size[:-2], c - r, c])
    return torch.cat((pad, tensor), dim=-2)


def neumann_entropy(tensor):
    """Return the entropy -sum(|e| * log|e|) over the eigenvalues e of ``tensor``.

    For a >2-d tensor the entropies of all trailing (matrix) slices are
    summed into one scalar.

    NOTE(review): the eigenvalues are NOT normalized to sum to 1 (the
    commented-out normalization in the original suggests it was considered),
    so this equals the von Neumann entropy only when the input is already a
    density matrix.
    """
    tensor_size = tensor.size()

    # Scalar input is passed through unchanged (original convention).
    if len(tensor_size) == 0:
        return tensor

    # A vector has no spectrum here; entropy is defined as 0.
    if len(tensor_size) == 1:
        return 0

    # torch.linalg.eigvals batches over leading dimensions, so the 2-d and
    # >2-d cases share one code path.
    # BUGFIX: the original >2-d branch returned inside the first iteration
    # of its loops, so only tensor[0][0]'s entropy was computed; all
    # sub-matrices are now included in the sum.
    e = alg.eigvals(tensor)
    temp_abs = torch.abs(e)
    temp = torch.log(temp_abs).real
    # Zero eigenvalues give log(0) = -inf; by convention 0 * log(0) = 0.
    temp = torch.nan_to_num(temp, nan=0.0, posinf=0.0, neginf=0.0)
    return -1 * torch.sum(temp_abs * temp)


def load_pickle(fpath):
    """Deserialize and return the object stored in the pickle file at *fpath*.

    NOTE: only unpickle files from trusted sources — pickle can execute
    arbitrary code on load.
    """
    with open(fpath, "rb") as handle:
        return pickle.load(handle)


def save_pickle(pickle_name, data_dump):
    """Serialize *data_dump* into the file named *pickle_name* with pickle."""
    with open(pickle_name, 'wb') as handle:
        pickle.dump(data_dump, handle)


def chunks(lst, n):
    """Yield successive n-sized chunks from lst (the last may be shorter)."""
    start = 0
    while start < len(lst):
        yield lst[start:start + n]
        start += n


def split(lst, n):
    """Split lst into n contiguous parts whose lengths differ by at most one.

    The first ``len(lst) % n`` parts each receive one extra element.
    Returns a generator of slices.
    """
    base, extra = divmod(len(lst), n)

    def bound(idx):
        # Start offset of part idx: idx full parts, plus one extra element
        # for each of the first min(idx, extra) parts.
        return idx * base + min(idx, extra)

    return (lst[bound(i):bound(i + 1)] for i in range(n))