diff --git a/vae_keras.py b/vae_keras.py index ca1f4ef..a3328ec 100644 --- a/vae_keras.py +++ b/vae_keras.py @@ -80,7 +80,7 @@ def sampling(args): x_test_encoded = encoder.predict(x_test, batch_size=batch_size) plt.figure(figsize=(6, 6)) -plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test_) +plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test) plt.colorbar() plt.show()