diff --git a/BHM/pytorch/bhmtorch_cuda.py b/BHM/pytorch/bhmtorch_cuda.py index d19067a..a60b62b 100755 --- a/BHM/pytorch/bhmtorch_cuda.py +++ b/BHM/pytorch/bhmtorch_cuda.py @@ -145,7 +145,7 @@ def predictSampling(self, Xq, nSamples=50): mu_a = Xq.mm(w).squeeze() probs = pt.sigmoid(mu_a) - mean = pt.std(probs, dim=1).squeeze() + mean = pt.mean(probs, dim=1).squeeze() std = pt.std(probs, dim=1).squeeze() return mean, std \ No newline at end of file diff --git a/BHM_Online_Learning/bhmtorch_cpu.py b/BHM_Online_Learning/bhmtorch_cpu.py index bfe82a2..5d5de84 100644 --- a/BHM_Online_Learning/bhmtorch_cpu.py +++ b/BHM_Online_Learning/bhmtorch_cpu.py @@ -154,7 +154,7 @@ def predictSampling(self, Xq, nSamples=50): mu_a = Xq.mm(w).squeeze() probs = pt.sigmoid(mu_a) - mean = pt.std(probs, dim=1).squeeze() + mean = pt.mean(probs, dim=1).squeeze() std = pt.std(probs, dim=1).squeeze() return mean, std @@ -289,7 +289,7 @@ def predictSampling(self, Xq, nSamples=50): mu_a = Xq.mm(w).squeeze() probs = pt.sigmoid(mu_a) - mean = pt.std(probs, dim=1).squeeze() + mean = pt.mean(probs, dim=1).squeeze() std = pt.std(probs, dim=1).squeeze() return mean, std