Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: GaussianMixtureConditional #239

Merged
merged 4 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 125 additions & 4 deletions compressai/cpp_exts/rans/rans_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <algorithm>
#include <array>
Expand Down Expand Up @@ -65,6 +66,26 @@ void assert_cdfs(const std::vector<std::vector<int>> &cdfs,
}
}

/* Convert a 2-D int32 tensor of CDFs (one CDF per row) into a vector of
 * per-row vectors. Row i of the result holds the first cdfs_sizes[i]
 * entries of row i of the tensor.
 *
 * cdfs:       2-D CPU tensor of dtype int32, one row per entry of cdfs_sizes.
 * cdfs_sizes: number of leading entries to keep from each row; every value
 *             must lie in [0, cdfs.size(1)].
 *
 * Preconditions are enforced with TORCH_CHECK rather than assert() so that
 * they are still verified when the extension is built with NDEBUG; a
 * violated precondition raises an error instead of reading invalid memory.
 */
std::vector<std::vector<int32_t>> make_cdfs_vector_from_tensor(
    const torch::Tensor &cdfs, const std::vector<int32_t> &cdfs_sizes) {
  TORCH_CHECK(cdfs.dim() == 2, "cdfs must be a 2-D tensor, but has ",
              cdfs.dim(), " dimension(s)");
  TORCH_CHECK(cdfs.dtype() == torch::kInt32, "cdfs must have dtype int32");
  TORCH_CHECK(cdfs.device().is_cpu(), "cdfs must be a CPU tensor");
  TORCH_CHECK(static_cast<size_t>(cdfs.size(0)) == cdfs_sizes.size(),
              "cdfs has ", cdfs.size(0), " row(s) but cdfs_sizes has ",
              cdfs_sizes.size(), " entries");

  // The pointer walk below assumes densely packed rows; contiguous() returns
  // the tensor itself when that already holds and a packed copy otherwise
  // (e.g. for a transposed or sliced view).
  const torch::Tensor dense = cdfs.contiguous();
  const int64_t row_length = dense.size(1);
  const int32_t *row = dense.data_ptr<int32_t>();

  std::vector<std::vector<int32_t>> result;
  result.reserve(cdfs_sizes.size());

  for (const int32_t cdf_size : cdfs_sizes) {
    TORCH_CHECK(cdf_size >= 0 && cdf_size <= row_length,
                "each entry of cdfs_sizes must be in [0, ", row_length,
                "], got ", cdf_size);
    result.emplace_back(row, row + cdf_size);
    row += row_length;
  }

  return result;
}

/* Support only 16 bits word max */
inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val,
uint32_t nbits) {
Expand Down Expand Up @@ -172,6 +193,16 @@ void BufferedRansEncoder::encode_with_indexes(
}
}

/* Tensor overload: converts `cdfs` to nested vectors with
 * make_cdfs_vector_from_tensor() and forwards every argument to the
 * vector-based encode_with_indexes(). */
void BufferedRansEncoder::encode_with_indexes(
    const std::vector<int32_t> &symbols, const std::vector<int32_t> &indexes,
    const torch::Tensor &cdfs, const std::vector<int32_t> &cdfs_sizes,
    const std::vector<int32_t> &offsets) {
  const std::vector<std::vector<int32_t>> cdf_table =
      make_cdfs_vector_from_tensor(cdfs, cdfs_sizes);
  encode_with_indexes(symbols, indexes, cdf_table, cdfs_sizes, offsets);
}

py::bytes BufferedRansEncoder::flush() {
Rans64State rans;
Rans64EncInit(&rans);
Expand Down Expand Up @@ -212,6 +243,17 @@ RansEncoder::encode_with_indexes(const std::vector<int32_t> &symbols,
return buffered_rans_enc.flush();
}

/* Tensor overload: converts `cdfs` to nested vectors with
 * make_cdfs_vector_from_tensor() and returns whatever the vector-based
 * encode_with_indexes() produces for the same arguments. */
py::bytes RansEncoder::encode_with_indexes(
    const std::vector<int32_t> &symbols, const std::vector<int32_t> &indexes,
    const torch::Tensor &cdfs, const std::vector<int32_t> &cdfs_sizes,
    const std::vector<int32_t> &offsets) {
  const std::vector<std::vector<int32_t>> cdf_table =
      make_cdfs_vector_from_tensor(cdfs, cdfs_sizes);
  return encode_with_indexes(symbols, indexes, cdf_table, cdfs_sizes, offsets);
}

std::vector<int32_t>
RansDecoder::decode_with_indexes(const std::string &encoded,
const std::vector<int32_t> &indexes,
Expand Down Expand Up @@ -283,6 +325,17 @@ RansDecoder::decode_with_indexes(const std::string &encoded,
return output;
}

/* Tensor overload: converts `cdfs` to nested vectors with
 * make_cdfs_vector_from_tensor() and returns the result of the vector-based
 * decode_with_indexes() for the same arguments. */
std::vector<int32_t> RansDecoder::decode_with_indexes(
    const std::string &encoded, const std::vector<int32_t> &indexes,
    const torch::Tensor &cdfs, const std::vector<int32_t> &cdfs_sizes,
    const std::vector<int32_t> &offsets) {
  const std::vector<std::vector<int32_t>> cdf_table =
      make_cdfs_vector_from_tensor(cdfs, cdfs_sizes);
  return decode_with_indexes(encoded, indexes, cdf_table, cdfs_sizes, offsets);
}

void RansDecoder::set_stream(const std::string &encoded) {
_stream = encoded;
uint32_t *ptr = (uint32_t *)_stream.data();
Expand Down Expand Up @@ -358,24 +411,92 @@ RansDecoder::decode_stream(const std::vector<int32_t> &indexes,
return output;
}

/* Tensor overload: converts `cdfs` to nested vectors with
 * make_cdfs_vector_from_tensor() and returns the result of the vector-based
 * decode_stream() for the same arguments. */
std::vector<int32_t> RansDecoder::decode_stream(
    const std::vector<int32_t> &indexes, const torch::Tensor &cdfs,
    const std::vector<int32_t> &cdfs_sizes,
    const std::vector<int32_t> &offsets) {
  const std::vector<std::vector<int32_t>> cdf_table =
      make_cdfs_vector_from_tensor(cdfs, cdfs_sizes);
  return decode_stream(indexes, cdf_table, cdfs_sizes, offsets);
}

/* Python bindings for the rANS encoder/decoder classes.
 *
 * Every coding method is registered twice under the same Python name: once
 * taking the CDFs as nested vectors and once taking them as a torch::Tensor.
 * Because the C++ methods are overloaded, each registration selects its
 * target explicitly with py::overload_cast; a bare `&Class::method` would be
 * ambiguous and fail to compile.
 */
PYBIND11_MODULE(ans, m) {
  m.attr("__name__") = "compressai.ans";

  m.doc() = "range Asymmetric Numeral System python bindings";

  py::class_<BufferedRansEncoder>(m, "BufferedRansEncoder")
      .def(py::init<>())
      // cdfs as nested vectors
      .def("encode_with_indexes",
           py::overload_cast<
               const std::vector<int32_t> &,
               const std::vector<int32_t> &,
               const std::vector<std::vector<int32_t>> &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&BufferedRansEncoder::encode_with_indexes))
      // cdfs as torch::Tensor
      .def("encode_with_indexes",
           py::overload_cast<
               const std::vector<int32_t> &,
               const std::vector<int32_t> &,
               const torch::Tensor &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&BufferedRansEncoder::encode_with_indexes))
      .def("flush", &BufferedRansEncoder::flush);

  py::class_<RansEncoder>(m, "RansEncoder")
      .def(py::init<>())
      // cdfs as nested vectors
      .def("encode_with_indexes",
           py::overload_cast<
               const std::vector<int32_t> &,
               const std::vector<int32_t> &,
               const std::vector<std::vector<int32_t>> &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&RansEncoder::encode_with_indexes))
      // cdfs as torch::Tensor
      .def("encode_with_indexes",
           py::overload_cast<
               const std::vector<int32_t> &,
               const std::vector<int32_t> &,
               const torch::Tensor &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&RansEncoder::encode_with_indexes));

  py::class_<RansDecoder>(m, "RansDecoder")
      .def(py::init<>())
      .def("set_stream", &RansDecoder::set_stream)
      // cdfs as nested vectors
      .def("decode_stream",
           py::overload_cast<
               const std::vector<int32_t> &,
               const std::vector<std::vector<int32_t>> &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&RansDecoder::decode_stream))
      // cdfs as torch::Tensor
      .def("decode_stream",
           py::overload_cast<
               const std::vector<int32_t> &,
               const torch::Tensor &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&RansDecoder::decode_stream))
      // cdfs as nested vectors
      .def("decode_with_indexes",
           py::overload_cast<
               const std::string &,
               const std::vector<int32_t> &,
               const std::vector<std::vector<int32_t>> &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&RansDecoder::decode_with_indexes),
           "Decode a string to a list of symbols")
      // cdfs as torch::Tensor
      .def("decode_with_indexes",
           py::overload_cast<
               const std::string &,
               const std::vector<int32_t> &,
               const torch::Tensor &,
               const std::vector<int32_t> &,
               const std::vector<int32_t> &
           >(&RansDecoder::decode_with_indexes),
           "Decode a string to a list of symbols");
}
31 changes: 31 additions & 0 deletions compressai/cpp_exts/rans/rans_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include "rans64.h"

Expand Down Expand Up @@ -60,6 +61,17 @@ class BufferedRansEncoder {
const std::vector<std::vector<int32_t>> &cdfs,
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

/* NOTE: In the case of GMM, the type conversion from Tensor to vector is
 * better done in C++; otherwise it takes up ~80% of the execution time.
 * That is why an overload accepting the cdfs as a torch::Tensor is provided.
 */
void encode_with_indexes(const std::vector<int32_t> &symbols,
const std::vector<int32_t> &indexes,
const torch::Tensor &cdfs,
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

py::bytes flush();

private:
Expand All @@ -80,6 +92,12 @@ class RansEncoder {
const std::vector<std::vector<int32_t>> &cdfs,
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

py::bytes encode_with_indexes(const std::vector<int32_t> &symbols,
const std::vector<int32_t> &indexes,
const torch::Tensor &cdfs,
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);
};

class RansDecoder {
Expand All @@ -98,6 +116,13 @@ class RansDecoder {
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

std::vector<int32_t>
decode_with_indexes(const std::string &encoded,
const std::vector<int32_t> &indexes,
const torch::Tensor &cdfs,
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

void set_stream(const std::string &stream);

std::vector<int32_t>
Expand All @@ -106,6 +131,12 @@ class RansDecoder {
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

std::vector<int32_t>
decode_stream(const std::vector<int32_t> &indexes,
const torch::Tensor &cdfs,
const std::vector<int32_t> &cdfs_sizes,
const std::vector<int32_t> &offsets);

private:
Rans64State _rans;
std::string _stream;
Expand Down
8 changes: 7 additions & 1 deletion compressai/entropy_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,16 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from .entropy_models import EntropyBottleneck, EntropyModel, GaussianConditional
from .entropy_models import (
EntropyBottleneck,
EntropyModel,
GaussianConditional,
GaussianMixtureConditional,
)

# Names exported by `from <this package> import *`.
__all__ = [
"EntropyModel",
"EntropyBottleneck",
"GaussianConditional",
"GaussianMixtureConditional",
]
Loading