error in 'model.fit_vae' #2

@peipp410

Hello! I encountered the following error while training the model.


ValueError Traceback (most recent call last)
Cell In[28], line 1
----> 1 loss_dict = model.fit_vae(max_epoch=128)

File ~/anaconda3/envs/liana/lib/python3.10/site-packages/spatialmeta/model/_alignment_module.py:659, in AlignmentModule.fit_vae(self, max_epoch, n_per_batch, kl_weight, n_epochs_kl_warmup, optimizer_parameters, weight_decay, lr, random_seed, validation_split)
656 epoch_total_loss = 0
658 batch_data = self._prepare_batch(X_st, X_sm)
--> 659 H, R, L = self.forward(batch_data)
660 reconstruction_loss_st = L['reconstruction_loss_st']
661 reconstruction_loss_sm = L['reconstruction_loss_sm']

File ~/anaconda3/envs/liana/lib/python3.10/site-packages/spatialmeta/model/_alignment_module.py:517, in AlignmentModule.forward(self, batch_data, reduction, **kwargs)
514 kldiv_loss_st = torch.tensor(0., device=self.device)
515 kldiv_loss_sm = torch.tensor(0., device=self.device)
--> 517 H=self.encode(batch_data)
518 R=self.decode(H, batch_data['st']['lib_size'])
520 if H['st'] is not None:

File ~/anaconda3/envs/liana/lib/python3.10/site-packages/spatialmeta/model/_alignment_module.py:445, in AlignmentModule.encode(self, batch_data, eps)
443 q_mu_st = self.z_mean_st_fc(q_st)
444 q_var_st = torch.exp(self.z_var_st_fc(q_st)) + eps
--> 445 z_st = Normal(q_mu_st, q_var_st.sqrt()).rsample()
446 st_dict = dict(
447 q = q_st,
448 q_mu = q_mu_st,
449 q_var = q_var_st,
450 z = z_st
451 )
453 if batch_data['sm'] is not None:

File ~/anaconda3/envs/liana/lib/python3.10/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
54 else:
55 batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File ~/anaconda3/envs/liana/lib/python3.10/site-packages/torch/distributions/distribution.py:68, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
66 valid = constraint.check(value)
67 if not valid.all():
---> 68 raise ValueError(
69 f"Expected parameter {param} "
70 f"({type(value).__name__} of shape {tuple(value.shape)}) "
71 f"of distribution {repr(self)} "
72 f"to satisfy the constraint {repr(constraint)}, "
73 f"but found invalid values:\n{value}"
74 )
75 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 10)) of distribution Normal(loc: torch.Size([128, 10]), scale: torch.Size([128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0',
grad_fn=)
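
The traceback shows that `q_mu_st` is already all NaN when `Normal` is constructed, but not where the NaNs first appear. One possible first check, assuming they come from the input matrices rather than from diverging weights (the traceback alone cannot confirm either), is to count non-finite entries in the data passed to the model. The sketch below uses only NumPy. `summarize_nonfinite` is a hypothetical helper, not part of spatialmeta, and the toy `X_st` and `X_sm` arrays are stand-ins for the real ST and SM matrices.

```python
import numpy as np


def summarize_nonfinite(name, matrix):
    """Count NaN and infinite entries in a dense array-like.

    Hypothetical helper for illustration; not part of spatialmeta.
    """
    arr = np.asarray(matrix, dtype=np.float64)
    n_nan = int(np.isnan(arr).sum())
    n_inf = int(np.isinf(arr).sum())
    print(f"{name}: shape={arr.shape}, nan={n_nan}, inf={n_inf}")
    return n_nan, n_inf


# Toy stand-ins for the real inputs (assumed to be dense numeric matrices).
X_st = np.array([[1.0, 2.0], [np.nan, 4.0]])
X_sm = np.array([[0.5, np.inf], [1.5, 2.5]])

st_counts = summarize_nonfinite("X_st", X_st)
sm_counts = summarize_nonfinite("X_sm", X_sm)
```

If both inputs turn out to be finite, the NaNs more likely arise during optimization, for example from exploding gradients or a learning rate that is too high.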

Could you please help me identify where the problem is? Thank you!
