diff --git a/models/transition.py b/models/transition.py index 3278913..628af19 100644 --- a/models/transition.py +++ b/models/transition.py @@ -122,9 +122,9 @@ def f_save(step): T.save_weights(checkpoint_dir+"/T_weights.keras", True) def sampler(z, x): - video = np.zeros((128, 80, 160, 3)) + video = np.zeros((batch_size*2, 80, 160, 3)) print "Sampling..." - for i in range(128): + for i in range(batch_size*2): print i x = x.reshape((-1, 80, 160, 3)) # code = E.predict(x, batch_size=batch_size*(time+1))[0]