Skip to content

Commit

Permalink
[WebNN EP] Use both MLOperandDescriptor.dimensions and MLOperandDescr…
Browse files Browse the repository at this point in the history
…iptor.shape (#22121)

The spec renames MLOperandDescriptor.dimensions to
MLOperandDescriptor.shape, in order to support older Chromium versions,
we will keep both in WebNN EP for a while.

Fixed #22120
  • Loading branch information
Honry committed Sep 19, 2024
1 parent 944d873 commit e33b08e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val desc = emscripten::val::object();
desc.set("dataType", "uint8");
desc.set("dimensions", emscripten::val::array(dims));
desc.set("shape", emscripten::val::array(dims));
const auto num_elements = narrow<uint32_t>(Product(mask_shape));
emscripten::val ones_buffer = emscripten::val::global("Uint8Array").new_(num_elements);
ones_buffer.call<void>("fill", 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
emscripten::val dims = emscripten::val::array();
dims.call<void>("push", rank);
desc.set("dimensions", dims);
desc.set("shape", dims);
emscripten::val shape_buffer = emscripten::val::global("BigInt64Array").new_(emscripten::val::array(input_shape));
emscripten::val shape_constant = model_builder.GetBuilder().call<emscripten::val>("constant", desc, shape_buffer);

Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ Status ModelBuilder::RegisterInitializers() {
[](int64_t dim) -> int32_t { return SafeInt<int32_t>(dim); });

emscripten::val desc = emscripten::val::object();
// TODO: @Honry, remove all MLOperandDescriptor.dimensions usage in the future.
// MLOperandDescriptor.dimensions is deprecated in WebNN API, we need to keep it
// in WebNN EP for a while to support older Chromium versions.
desc.set("dimensions", emscripten::val::array(dims));
desc.set("shape", emscripten::val::array(dims));
auto data_type = tensor.data_type();
emscripten::val operand = emscripten::val::object();
if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) {
Expand Down Expand Up @@ -203,6 +207,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
emscripten::val desc = emscripten::val::object();

desc.set("dimensions", emscripten::val::array(dims));
desc.set("shape", emscripten::val::array(dims));

int32_t data_type;
{ // type
Expand Down Expand Up @@ -303,6 +308,7 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer(
}

desc.set("dimensions", emscripten::val::array(shape));
desc.set("shape", emscripten::val::array(shape));
emscripten::val operand = emscripten::val::object();
// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
// buffers in JS side. Simply create a copy to fix it.
Expand Down Expand Up @@ -361,6 +367,7 @@ const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) {
emscripten::val desc = emscripten::val::object();
emscripten::val dims = emscripten::val::array();
desc.set("dimensions", dims);
desc.set("shape", dims);
emscripten::val zero_buffer = emscripten::val::undefined();
if (!SetWebnnDataType(desc, data_type)) {
ORT_THROW("Unsupported data type: " + std::to_string(data_type));
Expand Down

0 comments on commit e33b08e

Please sign in to comment.