Skip to content

Commit

Permalink
filter nan and infs in kl divergence
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed Feb 28, 2024
1 parent 081290f commit 050000d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion blackbirds/infer/vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def compute_regularisation_loss(
# log_prob_posterior = posterior_estimator.log_prob(z)
log_prob_prior = prior.log_prob(z)
# compute the Monte Carlo estimate of the KL divergence
kl_divergence = (log_prob_posterior - log_prob_prior).mean()
diffs = log_prob_posterior - log_prob_prior
# ignore nan or inf values
diffs = diffs[~torch.isnan(diffs)]
kl_divergence = diffs.mean()
print(f"z : {z}, log_prob_posterior {log_prob_posterior}, log_prob_prior {log_prob_prior}, kl_divergence {kl_divergence}")
return kl_divergence


Expand Down

0 comments on commit 050000d

Please sign in to comment.