Skip to content

Commit

Permalink
Move Gelu and LayerNorm fusion to L1 optimization (#21332)
Browse files Browse the repository at this point in the history
According to #20915, we
move the Gelu and LayerNorm fusion to L1 with a condition on the ONNX
opset the model imports (LayerNorm requires opset 16+ and Gelu requires
opset 20+.) If the opset version doesn't meet the requirements, the
fusion is delayed to L2 optimization since the internal contrib op
doesn't have a requirement for any specific ONNX opset.

---------

Co-authored-by: Scott McKay <[email protected]>
Co-authored-by: Edward Chen <[email protected]>
  • Loading branch information
3 people committed Sep 9, 2024
1 parent de7a02b commit 2cdc05f
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 41 deletions.
18 changes: 17 additions & 1 deletion onnxruntime/core/optimizer/gelu_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ static bool IsSupportedDataType(const Node& node) {
[root]--> Gelu ==>
*/
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// Gelu is an official ONNX operator as of opset 20, so we can fuse in level 1 if it is available
const bool onnx_gelu_available = (onnx_version != version_map.end() && onnx_version->second >= 20);
const bool fuse_in_level_1 = onnx_gelu_available || allow_contrib_op_in_level_1_;
const auto op_domain = fuse_in_level_1 && onnx_gelu_available ? kOnnxDomain : kMSDomain;

if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
// The following check assumes that there is a GeluFusion instance registered in Level1 that may have
// already done this fusion, in which case we don't need to do it again.
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
return Status::OK();
}

const auto compatible_providers = GetCompatibleExecutionProviders();

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

Expand Down Expand Up @@ -162,7 +178,7 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
"Gelu",
"fused Gelu subgraphs ",
gelu_input_defs,
{}, {}, kMSDomain);
{}, {}, op_domain);

// Assign provider to this new node. Provider should be same as the provider for old node.
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
Expand Down
21 changes: 19 additions & 2 deletions onnxruntime/core/optimizer/gelu_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,26 @@ x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
*/
class GeluFusion : public GraphTransformer {
private:
TransformerLevel optimization_level_ = TransformerLevel::Level1;
bool allow_contrib_op_in_level_1_ = false;
std::string GetGeluFusionName(TransformerLevel level) {
switch (level) {
case TransformerLevel::Level1:
return "GeluFusionL1";
case TransformerLevel::Level2:
return "GeluFusionL2";
default:
return "GeluFusion";
}
}

public:
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
GeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
: GraphTransformer(GetGeluFusionName(level), compatible_execution_providers),
optimization_level_(level),
allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
session_options.free_dimension_overrides));

transformers.emplace_back(std::make_unique<GeluFusion>());
transformers.emplace_back(std::make_unique<LayerNormFusion>());

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());

Expand Down Expand Up @@ -325,8 +328,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps));

transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps, level));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,21 @@ data are casted to float/double to calculate for precision, so if there is any C
Such Cast Op can be the input of the sub-graph, or an Cast Op between the Div and Mul nodes.
*/
Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const auto& version_map = graph.DomainToVersionMap();
const auto& onnx_version = version_map.find(kOnnxDomain);
// LayerNorm is an official ONNX operator as of opset 17, so we can fuse in level 1 if it is available
const bool onnx_layernorm_available = (onnx_version != version_map.end() && onnx_version->second >= 17);
const bool fuse_in_level_1 = onnx_layernorm_available || allow_contrib_op_in_level_1_;

if ((optimization_level_ == TransformerLevel::Level1 && !fuse_in_level_1) ||
// The following check assumes that there is a LayerNormFusion instance registered in Level1 that may have
// already done this fusion, in which case we don't need to do it again.
(optimization_level_ == TransformerLevel::Level2 && fuse_in_level_1)) {
return Status::OK();
}

const auto compatible_providers = GetCompatibleExecutionProviders();

GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
Expand Down
21 changes: 19 additions & 2 deletions onnxruntime/core/optimizer/layer_norm_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,26 @@ The formula corresponding to LayerNorm activation subgraph:
*/
class LayerNormFusion : public GraphTransformer {
private:
TransformerLevel optimization_level_ = TransformerLevel::Level1;
bool allow_contrib_op_in_level_1_ = false;
std::string GetLayerNormFusionName(TransformerLevel level) {
switch (level) {
case TransformerLevel::Level1:
return "LayerNormFusionL1";
case TransformerLevel::Level2:
return "LayerNormFusionL2";
default:
return "LayerNormFusion";
}
}

public:
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("LayerNormFusion", compatible_execution_providers) {}
LayerNormFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
TransformerLevel level = TransformerLevel::Level1, bool allow_contrib_op_in_level_1 = false) noexcept
: GraphTransformer(GetLayerNormFusionName(level), compatible_execution_providers),
optimization_level_(level),
allow_contrib_op_in_level_1_(allow_contrib_op_in_level_1) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
Expand Down
58 changes: 52 additions & 6 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4434,7 +4434,11 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
Expand All @@ -4445,14 +4449,40 @@ TEST_F(GraphTransformationTests, GeluFusionTest) {
ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 1);
}

TEST_F(GraphTransformationTests, GeluFusionTest_Opset20) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_opset20.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
Expand All @@ -4470,7 +4500,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
Expand All @@ -4488,7 +4522,11 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
Expand All @@ -4506,8 +4544,12 @@ TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphOutput) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
Expand All @@ -4522,8 +4564,12 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2));
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level1));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
Expand Down
Loading

0 comments on commit 2cdc05f

Please sign in to comment.