diff --git a/blackbirds/infer/vi.py b/blackbirds/infer/vi.py index 2e98a4a..32cf602 100644 --- a/blackbirds/infer/vi.py +++ b/blackbirds/infer/vi.py @@ -54,7 +54,11 @@ def compute_regularisation_loss( # log_prob_posterior = posterior_estimator.log_prob(z) log_prob_prior = prior.log_prob(z) # compute the Monte Carlo estimate of the KL divergence - kl_divergence = (log_prob_posterior - log_prob_prior).mean() + diffs = log_prob_posterior - log_prob_prior + # ignore nan or inf values + diffs = diffs[~torch.isnan(diffs)] + kl_divergence = diffs.mean() + print(f"z : {z}, log_prob_posterior {log_prob_posterior}, log_prob_prior {log_prob_prior}, kl_divergence {kl_divergence}") return kl_divergence