diff --git a/quail/analysis/clustering.py b/quail/analysis/clustering.py index ae6db08..f89274b 100644 --- a/quail/analysis/clustering.py +++ b/quail/analysis/clustering.py @@ -72,7 +72,11 @@ def _get_weight_exact(egg, feature, distdict, permute, n_perms): dists = distmat[pres.index(c),:] di = dists[pres.index(n)] dists_filt = np.array([dist for idx, dist in enumerate(dists) if idx not in past_idxs]) - ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) + + if len(np.unique(dists_filt)) == 1: + ranks.append(0.5) + else: + ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) past_idxs.append(pres.index(c)) past_words.append(c) return np.nanmean(ranks) @@ -97,7 +101,10 @@ def _get_weight_best(egg, feature, distdict, permute, n_perms, distance): dists = distmat[cdx, :] di = dists[ndx] dists_filt = np.array([dist for idx, dist in enumerate(dists)]) - ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) + if len(np.unique(dists_filt)) == 1: + ranks.append(0.5) + else: + ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0] + 1) / len(dists_filt)) return np.nanmean(ranks) def _get_weight_smooth(egg, feature, distdict, permute, n_perms, distance): @@ -120,7 +127,10 @@ def _get_weight_smooth(egg, feature, distdict, permute, n_perms, distance): dists = distmat[cdx, :] di = dists[ndx] dists_filt = np.array([dist for idx, dist in enumerate(dists)]) - ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0]+1) / len(dists_filt)) + if len(np.unique(dists_filt)) == 1: + ranks.append(0.5) + else: + ranks.append(np.mean(np.where(np.sort(dists_filt)[::-1] == di)[0] + 1) / len(dists_filt)) return np.nanmean(ranks) def get_distmat(egg, feature, distdict):