import pickle

import torch
import torch.linalg as alg


def _pad_last_two_dims_to_square(tensor):
    """Zero-pad the last two dimensions of ``tensor`` so they have equal size.

    Zeros are always *prepended*: extra rows go on top when the matrices are
    wider than tall, extra columns go on the left when they are taller than
    wide. Leading (batch) dimensions are left untouched. A tensor whose last
    two dimensions already match is returned as-is (same object).
    """
    rows, cols = tensor.shape[-2], tensor.shape[-1]
    if rows == cols:
        return tensor
    batch_shape = list(tensor.shape[:-2])
    if rows < cols:
        # Wider than tall: stack (cols - rows) zero rows above the data.
        filler = torch.zeros(
            batch_shape + [cols - rows, cols], device=tensor.device
        )
        return torch.cat((filler, tensor), dim=-2)
    # Taller than wide: put (rows - cols) zero columns to the left of the data.
    filler = torch.zeros(
        batch_shape + [rows, rows - cols], device=tensor.device
    )
    return torch.cat((filler, tensor), dim=-1)


def square_matrix(tensor):
    """Return ``tensor`` zero-padded so that its matrices are square.

    * 1-D input of length ``n``: returns an ``n x n`` matrix whose last row is
      the input and whose other rows are zero.
    * 2-D input: zero rows are prepended (wide input) or zero columns are
      prepended (tall input) until the matrix is square. An already square
      matrix is returned unchanged (the same object, not a copy).
    * Input with more than two dimensions: the last two dimensions are padded
      the same way for every matrix in the batch. The result is detached from
      the autograd graph; an already square input yields a detached clone.
    * 0-D input is not handled and yields ``None``.

    The padding zeros use torch's default dtype, so the dtype of a padded
    result follows ``torch.cat``'s type promotion.
    """
    ndim = tensor.dim()
    if ndim == 0:
        return None
    if ndim == 1:
        length = tensor.size(0)
        filler = torch.zeros([length - 1, length], device=tensor.device)
        return torch.cat((filler, tensor.reshape(1, length)))
    if ndim == 2:
        return _pad_last_two_dims_to_square(tensor)
    if tensor.shape[-2] == tensor.shape[-1]:
        return tensor.detach().clone()
    # torch.cat allocates new storage, so detaching without cloning is enough.
    return _pad_last_two_dims_to_square(tensor.detach())


def _eigenvalue_entropy(matrices):
    """Return ``-sum(l * log(l))`` over the real parts ``l`` of all eigenvalues.

    ``matrices`` may be a single square matrix or a batch of them; the sum
    runs over every eigenvalue of every matrix, giving a 0-dim tensor.
    """
    eigenvalues = alg.eigvals(matrices).real
    # log() yields -inf for 0 and nan for negative values; both are replaced
    # by 0 so those eigenvalues contribute nothing to the sum.
    log_eigenvalues = torch.nan_to_num(
        torch.log(eigenvalues), nan=0.0, posinf=0.0, neginf=0.0
    )
    return -1 * torch.sum(eigenvalues * log_eigenvalues)


def neumann_entropy(tensor):
    """Compute an eigenvalue-based entropy ``-sum(l * log(l))`` of ``tensor``.

    * 1-D input: returns the integer ``0``.
    * 2-D (square) input: returns a 0-dim tensor computed from the real parts
      of the matrix eigenvalues.
    * Input with more than two dimensions: treated as a batch of square
      matrices in the last two dimensions; returns the 0-dim total over all
      matrices in the batch.
    * 0-D input is not handled and yields ``None``.

    Eigenvalues whose real part is zero or negative contribute nothing.
    """
    ndim = tensor.dim()
    if ndim == 0:
        return None
    if ndim == 1:
        return 0
    return _eigenvalue_entropy(tensor)


def load_pickle(fpath):
    """Load and return the object pickled in the file at ``fpath``.

    SECURITY: unpickling can execute arbitrary code. Only use this on files
    that come from a trusted source.
    """
    with open(fpath, "rb") as f:
        data = pickle.load(f)
    return data


def save_pickle(pickle_name, data_dump):
    """Pickle ``data_dump`` into the file ``pickle_name`` (overwriting it)."""
    with open(pickle_name, 'wb') as f:
        pickle.dump(data_dump, f)


def chunks(lst, n):
    """Yield successive n-sized chunks from lst.

    The final chunk is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    for i in range(0, len(lst), n):
        yield lst[i:i + n]


def split(lst, n):
    """Split ``lst`` into ``n`` contiguous parts of nearly equal length.

    Returns a generator over the parts. When ``len(lst)`` is not divisible by
    ``n``, the first ``len(lst) % n`` parts hold one extra element. Raises
    ``ZeroDivisionError`` at call time when ``n`` is 0.
    """
    k, m = divmod(len(lst), n)
    return (
        lst[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
        for i in range(n)
    )