
Store nn.Parameter in entropy_models.py in nn.ParameterList #284

Merged: 5 commits into InterDigitalInc:master on May 3, 2024

Conversation

mmuckley (Contributor)

This PR proposes to store the parameters in entropy_models.py in an nn.ParameterList instead of the current string-based lookup. The primary reason is to make EntropyBottleneck more friendly for torch.compile: the current implementation fails to compile for certain backends (in my own experience, dynamo). The failure seems to stem from relying on Python strings and class attributes to access the parameters, whereas the new implementation makes the parameters explicit at the PyTorch level, which helps the compiler.
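A minimal sketch of the general pattern (not the actual entropy_models.py code; class names, attribute names, and shapes below are illustrative):

```python
import torch
import torch.nn as nn


# Illustrative only: parameters registered under constructed attribute names
# must be fetched with getattr at runtime, which compile backends can find
# hard to trace.
class StringLookup(nn.Module):
    def __init__(self, num_filters=3):
        super().__init__()
        self.num_filters = num_filters
        for i in range(num_filters):
            setattr(self, f"_matrix{i}", nn.Parameter(torch.zeros(1, 3, 3)))

    def forward(self, x):
        for i in range(self.num_filters):
            x = x + getattr(self, f"_matrix{i}")
        return x


# The same parameters stored in an nn.ParameterList: access is a plain
# indexed iteration that stays visible at the PyTorch level.
class ParamListLookup(nn.Module):
    def __init__(self, num_filters=3):
        super().__init__()
        self.matrices = nn.ParameterList(
            nn.Parameter(torch.zeros(1, 3, 3)) for _ in range(num_filters)
        )

    def forward(self, x):
        for m in self.matrices:
            x = x + m
        return x
```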

A major drawback of merging this PR is that it breaks backward compatibility. I've included some state_dict adjustments that allow loading old checkpoints, but I understand this may not be ideal. Also, new checkpoints would not be loadable by older versions of compressai.
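A hedged sketch of the kind of adjustment involved (the key names here are hypothetical, matching the illustrative classes above rather than the real entropy_models.py keys): legacy checkpoints can keep loading if their old string-attribute keys are renamed to the ParameterList layout before load_state_dict.

```python
import torch


def remap_legacy_keys(state_dict):
    # Hypothetical key names: map e.g. "_matrix0" -> "matrices.0" so that a
    # checkpoint saved with the old string-attribute layout still loads into
    # the ParameterList-based module.
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith("_matrix"):
            index = key[len("_matrix"):]
            remapped[f"matrices.{index}"] = value
        else:
            remapped[key] = value
    return remapped


# model.load_state_dict(remap_legacy_keys(torch.load("old_checkpoint.pth")))
```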

The PR also includes a compile test for verifying the implementation.
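For illustration, a hedged sketch of what such a check could look like (not the test added by the PR itself; the function name is made up, and it assumes the module returns an (output, likelihoods) pair as EntropyBottleneck.forward does):

```python
import torch


def check_compiles(entropy_bottleneck, x):
    # Run the module eagerly, then compiled, and require matching outputs.
    eager_out, eager_lik = entropy_bottleneck(x)
    compiled = torch.compile(entropy_bottleneck)
    compiled_out, compiled_lik = compiled(x)
    assert torch.allclose(eager_out, compiled_out, atol=1e-5)
    assert torch.allclose(eager_lik, compiled_lik, atol=1e-5)
```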

Happy to see this merged or closed, depending on maintainer preference.

fracape (Collaborator) commented May 1, 2024

Thank you @mmuckley for the PR!
Checking a couple of things on our side. I'll probably tag a new version after merging to flag the minor backward-compatibility issue (new checkpoints won't load with older versions of the package). Since this is a research project, I think it's worth it, and breaking in that direction (new checkpoints, old package) is totally fine.

YodaEmbedding (Contributor) reviewed the lines in entropy_models.py that build the permutation index:

```python
perm = torch.cat(
    (
        torch.tensor([1, 0], dtype=torch.long, device=x.device),
        torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
    )
)
```

Another possibility:

```python
perm = torch.tensor(
    [1, 0, *range(2, x.ndim)], dtype=torch.long, device=x.device
)
```

fracape (Collaborator):

LGTM, might boil down to the same thing under the hood :)

mmuckley (Contributor, author):

Hello @YodaEmbedding, thanks for the suggestion. If it's okay, I would like to argue for the current implementation: the suggested modification relies on a Python-level range iterator and unpacking it into a list, which can lead to a lot of Python calls under the hood.

When working with frameworks like torch.jit and torch.compile, I often find that these kinds of constructs can be difficult for the compiler, as the shakiest parts of those libraries are around understanding Python. Keeping everything as PyTorch calls seems to make the compilers behave more stably.
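One way to probe this concern (an assumption about how one might check it, not something from the PR) is to compile both variants with fullgraph=True, which raises if dynamo cannot capture the function in a single graph:

```python
import torch


def perm_tensor_only(x):
    # PR version: the permutation is built entirely from PyTorch calls.
    return torch.cat(
        (
            torch.tensor([1, 0], dtype=torch.long, device=x.device),
            torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
        )
    )


def perm_with_unpacking(x):
    # Suggested alternative: uses a Python range and list unpacking.
    return torch.tensor(
        [1, 0, *range(2, x.ndim)], dtype=torch.long, device=x.device
    )


x = torch.randn(4, 8, 16, 16)
for fn in (perm_tensor_only, perm_with_unpacking):
    torch.compile(fn, fullgraph=True)(x)  # raises on any graph break
```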

YodaEmbedding (Contributor):

Sounds good.

YodaEmbedding (Contributor) commented May 3, 2024

The PR looks good; I tested the current implementation with torch.compile, and it works on my machine with PyTorch 2.3.0.


On a semi-related side note, the current ELIC implementation fails with torch.compile, though perhaps that will resolve itself with future torch versions.

fracape (Collaborator) commented May 3, 2024

> The PR looks good; I tested the current implementation with torch.compile, and it works on my machine with PyTorch 2.3.0.
>
> On a semi-related side note, the current ELIC implementation fails with torch.compile, though perhaps that will resolve itself with future torch versions.

I temporarily limited the supported torch version to <2.3 this week. I'm getting slightly different results in eval_model for video, which breaks the CI that compares against expected results produced with earlier versions.

fracape merged commit 721fdee into InterDigitalInc:master on May 3, 2024
5 checks passed