Hello!
First, thank you for your great work on this framework!
While experimenting with the code and adapting it for a slightly different task, I've found two potential issues that could affect stability and training performance.
1. Potential NaN from Division by Zero in InterLoss
Location: `losses.py`, class `InterLoss`

In the `forward_relatvie_rot` function, several vectors are normalized by dividing by their L2 norm. However, no small epsilon is added to the denominator for numerical stability, so a zero-length difference vector leads to a division by zero and produces NaN values in the loss:
```python
def forward_relatvie_rot(self):
    r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
    across = self.pred_g_joints[..., r_hip, :] - self.pred_g_joints[..., l_hip, :]
    across = across / across.norm(dim=-1, keepdim=True)  # no eps in the denominator
    across_gt = self.tgt_g_joints[..., r_hip, :] - self.tgt_g_joints[..., l_hip, :]
    across_gt = across_gt / across_gt.norm(dim=-1, keepdim=True)

    y_axis = torch.zeros_like(across)
    y_axis[..., 1] = 1

    forward = torch.cross(y_axis, across, axis=-1)
    forward = forward / forward.norm(dim=-1, keepdim=True)  # same issue here
    forward_gt = torch.cross(y_axis, across_gt, axis=-1)
    forward_gt = forward_gt / forward_gt.norm(dim=-1, keepdim=True)

    pred_relative_rot = qbetween(forward[..., 0, :], forward[..., 1, :])
    tgt_relative_rot = qbetween(forward_gt[..., 0, :], forward_gt[..., 1, :])

    self.losses["RO"] = self.mix_masked_mse(pred_relative_rot[..., [0, 2]],
                                            tgt_relative_rot[..., [0, 2]],
                                            self.mask[..., 0, :],
                                            self.timestep_mask) * self.weights["RO"]
```

While this might not happen under the original paper's setup, it can occur in other scenarios.
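One possible mitigation, sketched below, is to clamp each norm away from zero before dividing; the helper name `safe_normalize` and the `eps=1e-8` value are my own choices, not from the repo:

```python
import torch

def safe_normalize(v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Clamp the L2 norm away from zero so a degenerate
    # (zero-length) vector cannot produce NaN.
    return v / v.norm(dim=-1, keepdim=True).clamp(min=eps)

# e.g. inside forward_relatvie_rot:
# across = safe_normalize(across)
# forward = safe_normalize(torch.cross(y_axis, across, dim=-1))
```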
2. Suboptimal Normalization in `global_std.npy`
Location: `gaussian_diffusion.py`, function `training_losses`

The provided `global_std.npy` file contains some very small std values. Since normalization divides by these values, the affected features get amplified: after `self.normalizer.forward(target)` is applied in `training_losses`, the mean and std of the data can come out larger than before normalization, rather than close to 0 and 1.
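As a quick check, here is a minimal sketch of how one could inspect the file and floor the small entries (the file path and the `1e-4` threshold are assumptions on my side, not values from the repo):

```python
import numpy as np

# Inspect the provided normalization statistics (path assumed).
std = np.load("global_std.npy")
print("10 smallest std entries:", np.sort(std.ravel())[:10])

# Possible mitigation: floor the std so no feature is divided by a
# near-zero value during normalization (threshold is an assumption).
std_floored = np.maximum(std, 1e-4)
np.save("global_std_floored.npy", std_floored)
```

The commented-out prints in `training_losses` below are how I observed the inflated statistics: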
```python
def training_losses(self, model, mask, t_bar, cond_mask, *args, **kwargs):
    target = kwargs["x_start"]
    B, T = target.shape[:-1]
    target = target.reshape(B, T, 2, -1)
    # print(f"target.max: {target.max().item()}")
    # print(f"target.min: {target.min().item()}")
    # print(f"target.mean: {target.mean().item()}")
    # print(f"target.std: {target.std().item()}")
    mask = mask.reshape(B, T, -1, 1)
    target = self.normalizer.forward(target)
    # print(f"target_normalized.max: {target.max().item()}")
    # print(f"target_normalized.min: {target.min().item()}")
    # print(f"target_normalized.mean: {target.mean().item()}")
    # print(f"target_normalized.std: {target.std().item()}")
```

I hope these observations are helpful. Anyway, thanks again for the excellent work!