Skip to content

Commit

Permalink
fix: gmm on cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
lin-toto committed Jul 8, 2023
1 parent fe01d2d commit a7f9024
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions compressai/entropy_models/entropy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def _build_cdf(self, scales, means, weights, abs_max):
num_latents = scales.size(1)
num_samples = abs_max * 2 + 1
TINY = 1e-10
device = scales.device

scales = scales.clamp_(0.11, 256)
means += abs_max
Expand All @@ -729,7 +730,7 @@ def _build_cdf(self, scales, means, weights, abs_max):
means_ = means.unsqueeze(-1).expand(-1, -1, num_samples)
weights_ = weights.unsqueeze(-1).expand(-1, -1, num_samples)

samples = torch.arange(num_samples).unsqueeze(0).expand(num_latents, -1)
samples = torch.arange(num_samples).to(device).unsqueeze(0).expand(num_latents, -1)

pmf = torch.zeros_like(samples).float()
for k in range(self.K):
Expand Down Expand Up @@ -758,7 +759,7 @@ def _build_cdf(self, scales, means, weights, abs_max):
pmf_real_zero_indices = (pmf_quantized == 0).nonzero().transpose(0, 1)
pmf_quantized[pmf_real_zero_indices[0], pmf_real_zero_indices[1]] += 1

pmf_real_steal_indices = torch.cat((torch.arange(num_latents).unsqueeze(-1),
pmf_real_steal_indices = torch.cat((torch.arange(num_latents).to(device).unsqueeze(-1),
pmf_first_stealable_indices.unsqueeze(-1)),
dim=1).transpose(0, 1)
pmf_quantized[pmf_real_steal_indices[0], pmf_real_steal_indices[1]] -= pmf_zero_count
Expand Down

0 comments on commit a7f9024

Please sign in to comment.