-
Notifications
You must be signed in to change notification settings - Fork 3
Open
Description
Need to add early stopping to the it_sol() function.
def it_sol(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001, robust=False,
           min_iter=30, max_iter=None):
    """Iteratively refine the empirical Bayes location and scale estimates.

    Alternates a location update (``postmean``) with a scale update
    (``postvar``) until the largest relative change of either estimate
    between two consecutive iterations is no longer greater than ``conv``.

    Early stopping: once more than ``min_iter`` iterations have completed,
    an increase of the relative change compared with the previous iteration
    is treated as a failure to converge. A warning is printed and the
    estimates from *before* that increase are returned.

    Parameters
    ----------
    sdat : numpy.ndarray
        2-D data with one row per estimate (all reductions run along
        ``axis=1``). NaN entries are treated as missing: they are excluded
        from the per-row count ``n`` and from the non-robust residual sum.
    g_hat, d_hat : numpy.ndarray
        Starting location / scale estimates, one entry per row of ``sdat``.
        ``g_hat`` is also passed to every ``postmean`` call.
        Presumably per-feature batch gamma/delta estimates; confirm with caller.
    g_bar, t2, a, b
        Hyperparameters forwarded unchanged to ``postmean`` / ``postvar``.
    conv : float
        Convergence threshold on the maximum relative change.
    robust : bool
        If True, the residual spread is ``n * biweight_midvar(...)`` instead
        of the sum of squared residuals.
    min_iter : int
        The divergence check only applies once more than this many
        iterations have completed. Default 30 matches the original
        hard-coded value.
    max_iter : int or None
        Optional hard cap on the number of iterations. ``None`` (default)
        means no cap, as before.

    Returns
    -------
    tuple
        ``(g, d)``: the final location and scale estimates. If the loop
        body never runs (e.g. ``conv >= 1`` or ``max_iter == 0``), these are
        copies of ``g_hat`` and ``d_hat``.
    """
    # Per-row count of observed (non-NaN) values.
    n = (1 - np.isnan(sdat)).sum(axis=1)
    g_old = g_hat.copy()
    d_old = d_hat.copy()
    change = 1
    change_old = 1
    count = 0
    while change > conv:
        if max_iter is not None and count >= max_iter:
            print(f'[neuroCombat WARNING] Empirical Bayes step stopped after reaching max_iter={max_iter} iterations without converging.')
            break

        g_new = np.array(postmean(g_hat, g_bar, n, d_old, t2))
        # Column vector so the location broadcasts across each row of sdat.
        g_col = g_new.reshape((g_new.shape[0], 1))
        if robust:
            sum2 = n * biweight_midvar(sdat, center=g_col, axis=1)
        else:
            # nansum: NaN entries are missing values, already excluded from n.
            sum2 = np.nansum((sdat - g_col) ** 2, axis=1)
        d_new = postvar(sum2, n, a, b)

        # NOTE(review): the denominators are signed, so rows where g_old < 0
        # give negative ratios and never drive `change` up. Confirm this is
        # intended before switching to np.abs(g_old).
        change = max(
            np.max(np.abs(g_new - g_old) / g_old),
            np.max(np.abs(d_new - d_old) / d_old),
        )

        if count > min_iter and change > change_old:
            print(f'[neuroCombat WARNING] Empirical Bayes step failed to converge after {min_iter} iterations, using estimate before change between iterations increases.')
            # Return the estimates from before the change increased, as the
            # warning states.
            return g_old, d_old

        g_old = g_new
        d_old = d_new
        change_old = change
        count += 1

    # After a normal exit g_old/d_old hold the latest accepted estimates.
    return g_old, d_old
Metadata
Metadata
Assignees
Labels
No labels