Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not drop QDQ around linear Resize (fixes #21319) #22089

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@
const std::string drop_action_name{"drop"};
const std::string drop_action_no_int16_name{"drop_no_int16_support"};
const std::string drop_action_no_int16_and_positive_scale_name{"drop_no_int16_support_and_positive_scale"};
const std::string drop_action_resize_nearest_name{"drop_resize_nearest"};
NTO::NodeLocation dq{NTO::NodeType::kInput, 0};
NTO::NodeLocation q{NTO::NodeType::kOutput, 0};

Check warning on line 47 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:47:2: "NTO" is a misspelling of "NOT"

Check warning on line 47 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:47:23: "NTO" is a misspelling of "NOT"

Check warning on line 48 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:48:2: "NTO" is a misspelling of "NOT"

Check warning on line 48 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "NTO" is a misspelling of "NOT" Raw Output: ./onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:48:22: "NTO" is a misspelling of "NOT"
// Move DQ input 0 to target input 0.
// Move Q output 0 to target output 0.
std::vector<NodeAndMoveInfo> moves{
Expand All @@ -55,6 +56,8 @@
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action_no_int16_and_positive_scale = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action_resize_nearest = std::make_unique<MergeIntoTargetFixed>(
std::vector<NodeAndMoveInfo>(moves)); // Copy before std::move(moves)
std::unique_ptr<Action> drop_action = std::make_unique<MergeIntoTargetFixed>(std::move(moves));

#if !defined(ORT_MINIMAL_BUILD)
Expand All @@ -67,14 +70,11 @@
// And cannot eliminate the QDQ for MaxPool if the scale is not positive, as a negative
// scale will change the ordering of the elements between quantized & de-quantized values.
std::vector<const char*> providers = {kCpuExecutionProvider, kDmlExecutionProvider};
std::unique_ptr<NodeSelector> selector_no_16bit = std::make_unique<QDQ::DropQDQNodesSelector>(false,
false,
true,
providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_name,
std::unique_ptr<NodeSelector> selector_resize_nearest = std::make_unique<QDQ::DropQDQNodesResizeNearestSelector>(false, false, true, providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_resize_nearest_name,

Check warning on line 74 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc:74: Lines should be <= 120 characters long [whitespace/line_length] [2]
{{"Resize", {}}},
std::move(selector_no_16bit),
std::move(drop_action_no_int16));
std::move(selector_resize_nearest),
std::move(drop_action_resize_nearest));

std::unique_ptr<NodeSelector> selector_no_16bit_and_positive_scale =
std::make_unique<QDQ::DropQDQNodesSelector>(false, true, false, providers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ bool DropQDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath());
}

bool DropQDQNodeResizeNearestSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!DropQDQNodeGroupSelector::Check(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}

const onnx::AttributeProto* mode = graph_utils::GetNodeAttribute(node, "mode");
// default mode is 'nearest'
return mode == nullptr || mode->s() == "nearest";
}

bool DropDQNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,27 @@
bool allow_nonpositive_scale = true)
: allow_16bit_(allow_16bit), allow_4bit_(allow_4bit), allow_nonpositive_scale_(allow_nonpositive_scale) {}

private:
protected:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;

private:
bool allow_16bit_;
bool allow_4bit_;
bool allow_nonpositive_scale_;
};

class DropQDQNodeResizeNearestSelector : public DropQDQNodeGroupSelector {
public:
using DropQDQNodeGroupSelector::DropQDQNodeGroupSelector;

private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
};

// Single DQ -> node.
class DropDQNodeGroupSelector : public NodeGroupSelector {
public:
Expand Down Expand Up @@ -309,6 +320,14 @@
compatible_providers) {}
};

class DropQDQNodesResizeNearestSelector : public BaseSelector {
public:
explicit DropQDQNodesResizeNearestSelector(bool allow_16bit = false, bool allow_4bit = false, bool allow_nonpositive_scale = true,

Check warning on line 325 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h:325: Lines should be <= 120 characters long [whitespace/line_length] [2]
gsl::span<const char*> compatible_providers = {})
: BaseSelector(std::make_unique<DropQDQNodeResizeNearestSelector>(allow_16bit, allow_4bit, allow_nonpositive_scale),

Check warning on line 327 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h:327: Lines should be <= 120 characters long [whitespace/line_length] [2]
compatible_providers) {}
};

class DropDQNodesSelector : public BaseSelector {
public:
explicit DropDQNodesSelector(bool allow_16bit = false,
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,33 @@ TEST(QDQTransformerTests, Resize) {
test_case({2, 13, 12, 37}, rand_gen.Uniform<int64_t>(std::vector<int64_t>{4}, 1, 16), true /*use_contrib_qdq*/);
}

TEST(QDQTransformerTests, ResizeLinearNoFusion) {
auto test_case = [&](bool use_contrib_qdq = false) {
auto check_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["Resize"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1);
};

TransformerTester(BuildQDQResizeTestCase({1, 64, 64, 3},
{1, 32, 32, 3},
"linear", // mode
"half_pixel", // coordinate_transformation_mode
"round_prefer_floor", // nearest_mode
false, // add_dq_output_float
use_contrib_qdq),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2);
};

RandomValueGenerator rand_gen{optional<RandomValueGenerator::RandomSeedType>{2345}};
test_case();
test_case(true /*use_contrib_qdq*/);
}

TEST(QDQTransformerTests, Resize_No_Fusion) {
auto test_case = [&](const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& sizes_shape,
Expand Down
Loading