Skip to content

Commit

Permalink
update the pp_api too
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl committed Sep 4, 2024
1 parent 23b2d68 commit 1bde236
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 27 deletions.
6 changes: 3 additions & 3 deletions .pipelines/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ stages:
# compiled as only one operator selected.
- bash: |
set -e -x -u
./build.sh -DOCOS_BUILD_PRESET=ort_genai
./build.sh -DOCOS_ENABLE_C_API=ON
cd out/Linux/RelWithDebInfo
ctest -C RelWithDebInfo --output-on-failure
displayName: Build ort-extensions with API enabled and run tests
Expand Down Expand Up @@ -281,7 +281,7 @@ stages:
# compiled as only one operator selected.
- bash: |
set -e -x -u
./build.sh -DOCOS_BUILD_PRESET=ort_genai
./build.sh -DOCOS_ENABLE_C_API=ON
cd out/Darwin/RelWithDebInfo
ctest -C RelWithDebInfo --output-on-failure
displayName: Build ort-extensions with API enabled and run tests
Expand Down Expand Up @@ -431,7 +431,7 @@ stages:

steps:
- script: |
call .\build.bat -DOCOS_BUILD_PRESET=ort_genai
call .\build.bat -DOCOS_ENABLE_C_API=ON
cd out\Windows
ctest -C RelWithDebInfo --output-on-failure
displayName: Build ort-extensions with API enabled and run tests
Expand Down
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ if(NOT PROJECT_IS_TOP_LEVEL AND ONNXRUNTIME_ROOT)
set(_ONNXRUNTIME_EMBEDDED TRUE)
endif()


if (OCOS_ENABLE_SELECTED_OPLIST OR OCOS_BUILD_PRESET)
disable_all_operators()
if(OCOS_ENABLE_SELECTED_OPLIST)
Expand Down
1 change: 0 additions & 1 deletion cmake/ext_imgcodecs.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

set(_IMGCODEC_ROOT_DIR ${dlib_SOURCE_DIR}/dlib/external)


# ----------------------------------------------------------------------------
# project libpng
#
Expand Down
11 changes: 11 additions & 0 deletions include/ortx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t
*/
extError_t ORTX_API_CALL OrtxGetTensorType(OrtxTensor* tensor, extDataType_t* type);

/**
* @brief Retrieves the size of each element in the given tensor.
*
* This function calculates the size of each element in the specified tensor and stores it in the provided size variable.
*
* @param tensor A pointer to the OrtxTensor object.
* @param size A pointer to a size_t variable to store the size of each element.
* @return An extError_t value indicating the success or failure of the operation.
*/
extError_t ORTX_API_CALL OrtxGetTensorSizeOfElement(OrtxTensor* tensor, size_t* size);

/** \brief Get the data from the tensor
*
* \param tensor The tensor object
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime_extensions/pp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,16 @@ def __init__(self, processor_json):
self.processor = create_processor(processor_json)

def pre_process(self, images):
if isinstance(images, str):
images = [images]
if isinstance(images, list):
images = load_images(images)
return image_pre_process(self.processor, images)

@staticmethod
def to_numpy(result):
return tensor_result_get_at(result, 0)

def __del__(self):
if delete_object and self.processor:
delete_object(self.processor)
Expand Down
15 changes: 7 additions & 8 deletions pyop/py_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,12 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
const int64_t* shape{};
size_t num_dims;
const void* data{};
size_t elem_size = 0;
if (tensor_type == extDataType_t::kOrtxInt64 || tensor_type == extDataType_t::kOrtxFloat) {
size_t elem_size = 1;
if (tensor_type == extDataType_t::kOrtxInt64 ||
tensor_type == extDataType_t::kOrtxFloat ||
tensor_type == extDataType_t::kOrtxUint8) {
OrtxGetTensorData(tensor, reinterpret_cast<const void**>(&data), &shape, &num_dims);
elem_size = 4;
if (tensor_type == extDataType_t::kOrtxInt64) {
elem_size = 8;
}
} else if (tensor_type == extDataType_t::kOrtxUnknownType) {
throw std::runtime_error("Failed to get tensor type");
OrtxGetTensorSizeOfElement(tensor, &elem_size);
} else if (tensor_type == extDataType_t::kOrtxUnknownType) {
throw std::runtime_error("unsupported tensor type");
}
Expand All @@ -108,6 +105,8 @@ void AddGlobalMethodsCApi(pybind11::module& m) {
obj = py::array_t<float>(npy_dims);
} else if (tensor_type == extDataType_t::kOrtxInt64) {
obj = py::array_t<int64_t>(npy_dims);
} else if (tensor_type == extDataType_t::kOrtxUint8) {
obj = py::array_t<uint8_t>(npy_dims);
}

void* out_ptr = obj.mutable_data();
Expand Down
27 changes: 26 additions & 1 deletion shared/api/c_api_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t

auto tensor_ptr = std::make_unique<TensorObject>();
tensor_ptr->SetTensor(ts);
tensor_ptr->SetTensorType(result_ptr->GetTensorType(index));
*tensor = static_cast<OrtxTensor*>(tensor_ptr.release());
return extError_t();
}
Expand All @@ -124,6 +123,24 @@ extError_t ORTX_API_CALL OrtxGetTensorType(OrtxTensor* tensor, extDataType_t* ty
return extError_t();
}

extError_t ORTX_API_CALL OrtxGetTensorSizeOfElement(OrtxTensor* tensor, size_t* size) {
if (tensor == nullptr || size == nullptr) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}

auto tensor_impl = static_cast<TensorObject*>(tensor);
if (tensor_impl->ortx_kind() != extObjectKind_t::kOrtxKindTensor) {
ReturnableStatus::last_error_message_ = "Invalid argument";
return kOrtxErrorInvalidArgument;
}

auto tb = tensor_impl->GetTensor();
assert(tb != nullptr);
*size = tb->SizeInBytes() / tb->NumberOfElement();
return extError_t();
}

extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
size_t* num_dims) {
if (tensor == nullptr) {
Expand Down Expand Up @@ -158,3 +175,11 @@ extError_t ORTX_API_CALL OrtxGetTensorDataFloat(OrtxTensor* tensor, const float*
*data = reinterpret_cast<const float*>(data_ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
return err;
}

extError_t ORTX_API_CALL OrtxGetTensorDataUint8(OrtxTensor* tensor, const uint8_t** data, const int64_t** shape,
size_t* num_dims) {
const void* data_ptr{};
auto err = OrtxGetTensorData(tensor, &data_ptr, shape, num_dims);
*data = reinterpret_cast<const uint8_t*>(data_ptr); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
return err;
}
65 changes: 53 additions & 12 deletions shared/api/c_api_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,56 @@ class TensorObject : public OrtxObjectImpl {
~TensorObject() override = default;

void SetTensor(ortc::TensorBase* tensor) { tensor_ = tensor; }
void SetTensorType(extDataType_t type) { tensor_type_ = type; }

[[nodiscard]] extDataType_t GetTensorType() const { return tensor_type_; }
static extDataType_t GetDataType(ONNXTensorElementDataType dt) {
if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
return extDataType_t::kOrtxFloat;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
return extDataType_t::kOrtxUint8;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
return extDataType_t::kOrtxInt8;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) {
return extDataType_t::kOrtxUint16;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) {
return extDataType_t::kOrtxInt16;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
return extDataType_t::kOrtxInt32;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
return extDataType_t::kOrtxInt64;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
return extDataType_t::kOrtxString;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) {
return extDataType_t::kOrtxBool;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
return extDataType_t::kOrtxFloat16;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
return extDataType_t::kOrtxDouble;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) {
return extDataType_t::kOrtxUint32;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) {
return extDataType_t::kOrtxUint64;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64) {
return extDataType_t::kOrtxComplex64;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128) {
return extDataType_t::kOrtxComplex128;
} else if (dt == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
return extDataType_t::kOrtxBFloat16;
} else {
return extDataType_t::kOrtxUnknownType;
}
}

[[nodiscard]] extDataType_t GetTensorType() const {
if (tensor_ == nullptr) {
return extDataType_t::kOrtxUnknownType;
}
return GetDataType(tensor_->Type());
}

[[nodiscard]] ortc::TensorBase* GetTensor() const { return tensor_; }

private:
ortc::TensorBase* tensor_{};
extDataType_t tensor_type_{extDataType_t::kOrtxUnknownType};
};

class TensorResult : public OrtxObjectImpl {
Expand All @@ -116,10 +157,10 @@ class TensorResult : public OrtxObjectImpl {
~TensorResult() override = default;

void SetTensors(std::vector<std::unique_ptr<ortc::TensorBase>>&& tensors) { tensors_ = std::move(tensors); }
void SetTensorTypes(const std::vector<extDataType_t>& types) { tensor_types_ = types; }
// void SetTensorTypes(const std::vector<extDataType_t>& types) { tensor_types_ = types; }
[[nodiscard]] size_t NumTensors() const { return tensors_.size(); }

[[nodiscard]] const std::vector<extDataType_t>& tensor_types() const { return tensor_types_; }
// [[nodiscard]] const std::vector<extDataType_t>& tensor_types() const { return tensor_types_; }

[[nodiscard]] const std::vector<std::unique_ptr<ortc::TensorBase>>& tensors() const { return tensors_; }

Expand All @@ -139,16 +180,16 @@ class TensorResult : public OrtxObjectImpl {
return nullptr;
}

extDataType_t GetTensorType(size_t i) const {
if (i < tensor_types_.size()) {
return tensor_types_[i];
}
return extDataType_t::kOrtxUnknownType;
}
// extDataType_t GetTensorType(size_t i) const {
// if (i < tensor_types_.size()) {
// return tensor_types_[i];
// }
// return extDataType_t::kOrtxUnknownType;
// }

private:
std::vector<std::unique_ptr<ortc::TensorBase>> tensors_;
std::vector<extDataType_t> tensor_types_;
// std::vector<extDataType_t> tensor_types_;
};

struct ReturnableStatus {
Expand Down
2 changes: 1 addition & 1 deletion shared/api/image_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ OrtxStatus ImageProcessor::PreProcess(ort_extensions::span<ImageRawData> image_d
operations_.back()->ResetTensors(allocator_);
if (status.IsOk()) {
r.SetTensors(std::move(img_result));
r.SetTensorTypes({kOrtxFloat, kOrtxInt64, kOrtxInt64});
// r.SetTensorTypes({kOrtxFloat, kOrtxInt64, kOrtxInt64});
}

return status;
Expand Down
40 changes: 40 additions & 0 deletions test/data/processor/image_to_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import tempfile
from PIL import Image

from onnxruntime_extensions.pp_api import ImageProcessor

img_proc = ImageProcessor(R"""
{
"processor": {
"name": "image_processing",
"transforms": [
{
"operation": {
"name": "decode_image",
"type": "DecodeImage",
"attrs": {
"color_space": "BGR"
}
}
},
{
"operation": {
"name": "convert_to_rgb",
"type": "ConvertRGB"
}
}
]
}
}""")

result = img_proc.pre_process(os.path.dirname(__file__) + "/standard_s.jpg")
np_img = img_proc.to_numpy(result)
print(np_img.shape, np_img.dtype)

# can save the image back to disk
img_rgb = np_img[0]
img_bgr = img_rgb[..., ::-1]
output_name = tempfile.gettempdir() + "/standard_s_bgr.jpg"
Image.fromarray(img_bgr).save(output_name)
print(output_name)

0 comments on commit 1bde236

Please sign in to comment.