-
Notifications
You must be signed in to change notification settings - Fork 3
Description
Hello! I encountered the following error while training the model.
ValueError Traceback (most recent call last)
Cell In[28], line 1
----> 1 loss_dict = model.fit_vae(max_epoch=128)
File ~/anaconda3/envs/liana/lib/python3.10/site-packages/spatialmeta/model/_alignment_module.py:659, in AlignmentModule.fit_vae(self, max_epoch, n_per_batch, kl_weight, n_epochs_kl_warmup, optimizer_parameters, weight_decay, lr, random_seed, validation_split)
656 epoch_total_loss = 0
658 batch_data = self._prepare_batch(X_st, X_sm)
--> 659 H, R, L = self.forward(batch_data)
660 reconstruction_loss_st = L['reconstruction_loss_st']
661 reconstruction_loss_sm = L['reconstruction_loss_sm']
File ~/anaconda3/envs/liana/lib/python3.10/site-packages/spatialmeta/model/_alignment_module.py:517, in AlignmentModule.forward(self, batch_data, reduction, **kwargs)
514 kldiv_loss_st = torch.tensor(0., device=self.device)
515 kldiv_loss_sm = torch.tensor(0., device=self.device)
--> 517 H=self.encode(batch_data)
518 R=self.decode(H, batch_data['st']['lib_size'])
520 if H['st'] is not None:
File ~/anaconda3/envs/liana/lib/python3.10/site-packages/spatialmeta/model/_alignment_module.py:445, in AlignmentModule.encode(self, batch_data, eps)
443 q_mu_st = self.z_mean_st_fc(q_st)
444 q_var_st = torch.exp(self.z_var_st_fc(q_st)) + eps
--> 445 z_st = Normal(q_mu_st, q_var_st.sqrt()).rsample()
446 st_dict = dict(
447 q = q_st,
448 q_mu = q_mu_st,
449 q_var = q_var_st,
450 z = z_st
451 )
453 if batch_data['sm'] is not None:
File ~/anaconda3/envs/liana/lib/python3.10/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
54 else:
55 batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)
File ~/anaconda3/envs/liana/lib/python3.10/site-packages/torch/distributions/distribution.py:68, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
66 valid = constraint.check(value)
67 if not valid.all():
---> 68 raise ValueError(
69 f"Expected parameter {param} "
70 f"({type(value).__name__} of shape {tuple(value.shape)}) "
71 f"of distribution {repr(self)} "
72 f"to satisfy the constraint {repr(constraint)}, "
73 f"but found invalid values:\n{value}"
74 )
75 super().__init__()
ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0',
grad_fn=<…>)  (backward-node name lost when the issue text was rendered)
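From the traceback, the NaNs are already present in `q_mu_st`, the mean computed by `self.z_mean_st_fc(q_st)` in `AlignmentModule.encode`, before the `Normal` distribution is even constructed. Could you please help me identify the cause of this problem? Thank you!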