From 24e9eb57a0fa22d749c038a894fc9eaee764b240 Mon Sep 17 00:00:00 2001 From: Shadi Date: Mon, 10 Jun 2024 11:48:17 -0700 Subject: [PATCH] Fixing bitwise_and bug in `.fit` methods --- viprs/model/VIPRS.py | 4 ++-- viprs/model/gridsearch/VIPRSGrid.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/viprs/model/VIPRS.py b/viprs/model/VIPRS.py index e7672bc..6527be9 100644 --- a/viprs/model/VIPRS.py +++ b/viprs/model/VIPRS.py @@ -802,12 +802,12 @@ def fit(self, prev_elbo = self.history['ELBO'][-2] # Check for convergence in the objective + parameters: - if (i > min_iter) & np.isclose(prev_elbo, curr_elbo, atol=f_abs_tol, rtol=0.): + if (i > min_iter) and np.isclose(prev_elbo, curr_elbo, atol=f_abs_tol, rtol=0.): self.optim_result.update(curr_elbo, stop_iteration=True, success=True, message='Objective (ELBO) converged successfully.') - elif (i > min_iter) & max([np.max(np.abs(diff)) for diff in self.eta_diff.values()]) < x_abs_tol: + elif (i > min_iter) and max([np.max(np.abs(diff)) for diff in self.eta_diff.values()]) < x_abs_tol: self.optim_result.update(curr_elbo, stop_iteration=True, success=True, diff --git a/viprs/model/gridsearch/VIPRSGrid.py b/viprs/model/gridsearch/VIPRSGrid.py index 821e0bd..9d80b01 100644 --- a/viprs/model/gridsearch/VIPRSGrid.py +++ b/viprs/model/gridsearch/VIPRSGrid.py @@ -334,13 +334,13 @@ def fit(self, for m in np.where(self.active_models)[0]: - if (i > min_iter) & np.isclose(prev_elbo[m], curr_elbo[m], atol=f_abs_tol, rtol=0.): + if (i > min_iter) and np.isclose(prev_elbo[m], curr_elbo[m], atol=f_abs_tol, rtol=0.): self.active_models[m] = False self.optim_results[m].update(curr_elbo[m], stop_iteration=True, success=True, message='Objective (ELBO) converged successfully.') - elif (i > min_iter) & max([np.max(np.abs(diff[:, m])) + elif (i > min_iter) and max([np.max(np.abs(diff[:, m])) for diff in self.eta_diff.values()]) < x_abs_tol: self.active_models[m] = False self.optim_results[m].update(curr_elbo[m],