51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
|
import numpy as np
|
||
|
import seaborn as sns
|
||
|
import matplotlib.pyplot as plt
|
||
|
import torch
|
||
|
|
||
|
import gmm
|
||
|
|
||
|
def create_synthetic_data(num_gaussians, num_features, num_samples, means, vars):
    """Draw samples from a mixture of axis-aligned Gaussians.

    Component ``g`` is a multivariate normal with mean ``means[g]`` and a
    diagonal covariance matrix whose diagonal is ``vars[g]``.  Samples are
    returned grouped by component: all samples of component 0 first, then
    component 1, and so on.

    Args:
        num_gaussians: Number of mixture components to sample from.
        num_features: Dimensionality of every sample.
        num_samples: Total number of samples to draw.  When it is not a
            multiple of ``num_gaussians``, the first
            ``num_samples % num_gaussians`` components receive one extra
            sample so that exactly ``num_samples`` rows are returned.
        means: Per-component means; ``means[g]`` has ``num_features`` entries.
        vars: Per-component, per-feature variances, laid out like ``means``.

    Returns:
        A float32 tensor of shape ``(num_samples, num_features)``.

    Raises:
        ValueError: If ``num_gaussians`` is not positive, ``num_samples`` is
            negative, fewer than ``num_gaussians`` means/variances are given,
            or any used mean/variance row does not have ``num_features``
            entries.
    """
    # Explicit exceptions instead of ``assert`` so the checks survive ``-O``.
    if num_gaussians <= 0:
        raise ValueError(f"num_gaussians must be positive, got {num_gaussians}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    if len(means) < num_gaussians or len(vars) < num_gaussians:
        raise ValueError(
            f"need at least {num_gaussians} rows in means and vars, "
            f"got {len(means)} and {len(vars)}"
        )
    for g in range(num_gaussians):
        if len(means[g]) != num_features or len(vars[g]) != num_features:
            raise ValueError(
                f"means[{g}] and vars[{g}] must each have {num_features} entries"
            )

    # Spread any remainder over the first components so no sample is lost.
    per_component, leftover = divmod(num_samples, num_gaussians)

    chunks = []
    for g in range(num_gaussians):
        count = per_component + (1 if g < leftover else 0)
        if count == 0:
            continue

        loc = torch.as_tensor(means[g], dtype=torch.float32)
        variances = torch.as_tensor(vars[g], dtype=torch.float32)
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            loc=loc, covariance_matrix=torch.diag(variances))

        # One batched draw of shape (count, num_features) per component.
        chunks.append(dist.sample((count,)))

    if not chunks:
        # num_samples == 0: torch.cat rejects an empty list.
        return torch.empty((0, num_features), dtype=torch.float32)
    return torch.cat(chunks, dim=0)
|
||
|
|
||
|
|
||
|
def plot_data(data, y=None):
    """Show a scatter plot of ``data`` for two seconds, then close it.

    Args:
        data: Samples to plot, one row per sample.  The feature columns are
            passed as the positional arguments of ``plt.scatter``, so rows
            are expected to hold an (x, y) pair.
        y: Optional per-sample labels.  When omitted every sample is drawn in
            black.  Otherwise samples are coloured by label: 0 is blue, 1 is
            red, 2 is green, further non-negative integer labels take the
            following palette entries (wrapping around), and any other label
            is drawn in gray so that no sample is silently left out.
    """
    points = np.asarray(data)

    if y is None:
        if points.size:
            plt.scatter(*points.T, color='black')
    else:
        labels = np.asarray(y).reshape(-1)

        # Pair samples and labels up to the shorter of the two, as zip() would.
        paired = min(len(points), len(labels))
        points, labels = points[:paired], labels[:paired]

        palette = ('blue', 'red', 'green', 'orange',
                   'purple', 'brown', 'cyan', 'magenta')

        # One scatter call per distinct label rather than one per sample.
        for label in np.unique(labels):
            value = label.item()
            is_index = (isinstance(value, (int, float))
                        and value >= 0
                        and float(value).is_integer())
            colour = palette[int(value) % len(palette)] if is_index else 'gray'
            plt.scatter(*points[labels == label].T, color=colour)

    # Non-blocking show plus a timed pause: the figure closes on its own.
    plt.show(block=False)
    plt.pause(2)
    plt.close()
|
||
|
|
||
|
|
||
|
# Ground-truth mixture used to generate the demo data: one row per Gaussian,
# one column per feature.
means = [[1, 4], [5, 5], [2, -1]]  # per-component feature means
vars = [[0.1, 0.1], [0.05, 0.4], [0.5, 0.2]]  # per-component feature variances

# Draw 600 two-dimensional samples from the three components and show them
# without labels.
data = create_synthetic_data(
    num_gaussians=3,
    num_features=2,
    num_samples=600,
    means=means,
    vars=vars,
)  # shape: (total_samples, num_features)
plot_data(data)

# Fit the mixture model to the samples; the constructor arguments are
# (num_gaussians, num_features).
model = gmm.GaussianMixture(3, 2)
model.fit(data)

# Plot the same samples again, this time coloured by the model's predictions.
y = model.predict(data)
plot_data(data, y=y)
|
||
|
|