From 81b26853c186c02a5a107ef2970dae6e9938f1ca Mon Sep 17 00:00:00 2001 From: Henry Berger Date: Thu, 27 Mar 2025 13:13:11 -0400 Subject: [PATCH] Fix bug with prediction of means from sampling --- BHM/pytorch/bhmtorch_cuda.py | 2 +- BHM_Online_Learning/bhmtorch_cpu.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/BHM/pytorch/bhmtorch_cuda.py b/BHM/pytorch/bhmtorch_cuda.py index d19067a..a60b62b 100755 --- a/BHM/pytorch/bhmtorch_cuda.py +++ b/BHM/pytorch/bhmtorch_cuda.py @@ -145,7 +145,7 @@ def predictSampling(self, Xq, nSamples=50): mu_a = Xq.mm(w).squeeze() probs = pt.sigmoid(mu_a) - mean = pt.std(probs, dim=1).squeeze() + mean = pt.mean(probs, dim=1).squeeze() std = pt.std(probs, dim=1).squeeze() return mean, std \ No newline at end of file diff --git a/BHM_Online_Learning/bhmtorch_cpu.py b/BHM_Online_Learning/bhmtorch_cpu.py index bfe82a2..5d5de84 100644 --- a/BHM_Online_Learning/bhmtorch_cpu.py +++ b/BHM_Online_Learning/bhmtorch_cpu.py @@ -154,7 +154,7 @@ def predictSampling(self, Xq, nSamples=50): mu_a = Xq.mm(w).squeeze() probs = pt.sigmoid(mu_a) - mean = pt.std(probs, dim=1).squeeze() + mean = pt.mean(probs, dim=1).squeeze() std = pt.std(probs, dim=1).squeeze() return mean, std @@ -289,7 +289,7 @@ def predictSampling(self, Xq, nSamples=50): mu_a = Xq.mm(w).squeeze() probs = pt.sigmoid(mu_a) - mean = pt.std(probs, dim=1).squeeze() + mean = pt.mean(probs, dim=1).squeeze() std = pt.std(probs, dim=1).squeeze() return mean, std