Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory access violations in the CPU float16 min and max operators #22135

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {

ProcessBroadcastSpanFuncs funcs{
[](BroadcastHelper& per_iter_bh) {
auto num_elements = per_iter_bh.NumOutputElements();
auto num_elements = per_iter_bh.EigenInput1<MLFloat16>().rows();

const auto* input_1 = reinterpret_cast<const Eigen::half*>(per_iter_bh.EigenInput1<MLFloat16>().data());
ConstEigenVectorArrayMap<Eigen::half> input_1_vec_map(input_1, num_elements);
Expand All @@ -763,7 +763,7 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
}
},
[](BroadcastHelper& per_iter_bh) {
auto num_elements = per_iter_bh.NumOutputElements();
auto num_elements = per_iter_bh.EigenInput0<MLFloat16>().rows();

const auto* input_0 = reinterpret_cast<const Eigen::half*>(per_iter_bh.EigenInput0<MLFloat16>().data());
ConstEigenVectorArrayMap<Eigen::half> input_0_vec_map(input_0, num_elements);
Expand All @@ -778,7 +778,7 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) {
}
},
[](BroadcastHelper& per_iter_bh) {
auto num_elements = per_iter_bh.NumOutputElements();
auto num_elements = per_iter_bh.EigenInput0<MLFloat16>().rows();

const auto* input_0 = reinterpret_cast<const Eigen::half*>(per_iter_bh.EigenInput0<MLFloat16>().data());
ConstEigenVectorArrayMap<Eigen::half> input_0_vec_map(input_0, num_elements);
Expand Down
98 changes: 98 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1787,6 +1787,54 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Min_12_MLFloat16_MatrixVector) {
OpTester test("Min", 12);
test.AddInput<MLFloat16>("data_0", {3, 3},
MakeMLFloat16({1.0f, 1.0f, 1.0f,
-0.5f, 0.0f, -2.0f,
0.5f, 0.0f, 2.0f}));
test.AddInput<MLFloat16>("data_1", {3, 1},
MakeMLFloat16({0.0f, -1.0f, 1.0f}));
test.AddOutput<MLFloat16>("min", {3, 3},
MakeMLFloat16({0.0f, 0.0f, 0.0f,
-1.0f, -1.0f, -2.0f,
0.5f, 0.0f, 1.0f}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Min_12_MLFloat16_VectorMatrix) {
OpTester test("Min", 12);
test.AddInput<MLFloat16>("data_0", {3, 1},
MakeMLFloat16({0.0f, -1.0f, 1.0f}));
test.AddInput<MLFloat16>("data_1", {3, 4},
MakeMLFloat16({1.0f, 1.0f, 1.0f, -1.0f,
-0.5f, 0.0f, -2.0f, -1.25f,
0.5f, 0.0f, 2.0f, 1.5f}));
test.AddOutput<MLFloat16>("min", {3, 4},
MakeMLFloat16({0.0f, 0.0f, 0.0f, -1.0f,
-1.0f, -1.0f, -2.0f, -1.25f,
0.5f, 0.0f, 1.0f, 1.0f}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Max_6) {
OpTester test("Max", 6);
std::vector<int64_t> dims{3, 3};
Expand Down Expand Up @@ -2137,6 +2185,56 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Max_12_MLFloat16_MatrixVector) {
OpTester test("Max", 12);
test.AddInput<MLFloat16>("data_0", {4, 3},
MakeMLFloat16({1.0f, 1.0f, 1.0f,
-0.5f, 0.0f, -2.0f,
0.0f, 0.5f, 0.75f,
0.5f, 0.0f, 2.0f}));
test.AddInput<MLFloat16>("data_1", {4, 1},
MakeMLFloat16({0.0f, -1.0f, 0.5f, 1.0f}));
test.AddOutput<MLFloat16>("max", {4, 3},
MakeMLFloat16({1.0f, 1.0f, 1.0f,
-0.5f, 0.0f, -1.0f,
0.5f, 0.5f, 0.75f,
1.0f, 1.0f, 2.0f}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Max_12_MLFloat16_VectorMatrix) {
OpTester test("Max", 12);
test.AddInput<MLFloat16>("data_0", {3, 1},
MakeMLFloat16({0.0f, -1.0f, 1.0f}));
test.AddInput<MLFloat16>("data_1", {3, 3},
MakeMLFloat16({1.0f, 1.0f, 1.0f,
-0.5f, 0.0f, -2.0f,
0.5f, 0.0f, 2.0f}));
test.AddOutput<MLFloat16>("max", {3, 3},
MakeMLFloat16({1.0f, 1.0f, 1.0f,
-0.5f, 0.0f, -1.0f,
1.0f, 1.0f, 2.0f}));
if (nullptr != DefaultCpuExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (nullptr != DefaultCudaExecutionProvider()) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}

TEST(MathOpTest, Not) {
OpTester test("Not");
std::vector<int64_t> dims{2};
Expand Down
Loading