diff --git a/vae_keras_celeba.py b/vae_keras_celeba.py index b9e1dca..e4dad28 100644 --- a/vae_keras_celeba.py +++ b/vae_keras_celeba.py @@ -132,6 +132,10 @@ def __init__(self): self.losses = [] if not os.path.exists('samples'): os.mkdir('samples') + with open('./encoder_architecture.json', 'w') as f: + f.write(encoder.to_json()) + with open('./decoder_architecture.json', 'w') as f: + f.write(decoder.to_json()) def on_epoch_end(self, epoch, logs=None): path = 'samples/test_%s.png' % epoch sample(path) @@ -139,6 +143,7 @@ def on_epoch_end(self, epoch, logs=None): if logs['loss'] <= self.lowest: self.lowest = logs['loss'] encoder.save_weights('./best_encoder.weights') + decoder.save_weights('./best_decoder.weights') evaluator = Evaluate()