From 13f0b7878994a84dc7540f37f4a5b7346c0de649 Mon Sep 17 00:00:00 2001 From: Liu Dingming Date: Tue, 17 Sep 2024 09:02:01 +0800 Subject: [PATCH] feature(prost-build): Generate less boxed if nested type is boxed manually --- prost-build/src/code_generator.rs | 2 +- prost-build/src/message_graph.rs | 142 ++++++++++++++++-- tests/src/build.rs | 13 ++ tests/src/lib.rs | 8 + tests/src/nesting_complex.proto | 47 ++++++ .../nesting_complex/boxed/nesting_complex.rs | 56 +++++++ tests/src/nesting_complex/nesting_complex.rs | 56 +++++++ 7 files changed, 310 insertions(+), 14 deletions(-) create mode 100644 tests/src/nesting_complex.proto create mode 100644 tests/src/nesting_complex/boxed/nesting_complex.rs create mode 100644 tests/src/nesting_complex/nesting_complex.rs diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 24b5194e5..3eec41994 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -1075,7 +1075,7 @@ impl<'a> CodeGenerator<'a> { && (fd_type == Type::Message || fd_type == Type::Group) && self .message_graph - .is_nested(field.type_name(), fq_message_name) + .is_directly_nested(field.type_name(), fq_message_name) { return true; } diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index e2bcad918..b63a4b004 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -1,8 +1,8 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; -use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; -use petgraph::Graph; +use petgraph::visit::{EdgeRef, VisitMap}; +use petgraph::{Direction, Graph}; use prost_types::{ field_descriptor_proto::{Label, Type}, @@ -15,9 +15,13 @@ use crate::path::PathMap; /// The goal is to recognize when message types are recursively nested, so /// that fields can be boxed when necessary. pub struct MessageGraph { + /// Map index: HashMap, - graph: Graph, + /// Graph with fq type name as node, field name as edge + graph: Graph, + /// Map messages: HashMap, + /// Manually boxed fields boxed: PathMap<()>, } @@ -71,7 +75,8 @@ impl MessageGraph { for field in &msg.field { if field.r#type() == Type::Message && field.label() != Label::Repeated { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); - self.graph.add_edge(msg_index, field_index, ()); + self.graph + .add_edge(msg_index, field_index, field.name.clone().unwrap()); } } self.messages.insert(msg_name.clone(), msg.clone()); @@ -86,8 +91,9 @@ impl MessageGraph { self.messages.get(message) } - /// Returns true if message type `inner` is nested in message type `outer`. - pub fn is_nested(&self, outer: &str, inner: &str) -> bool { + /// Returns true if message type `inner` is nested in message type `outer`, + /// and no field edge in the chain of dependencies is manually boxed. + pub fn is_directly_nested(&self, outer: &str, inner: &str) -> bool { let outer = match self.index.get(outer) { Some(outer) => *outer, None => return false, @@ -97,7 +103,12 @@ impl MessageGraph { None => return false, }; - has_path_connecting(&self.graph, outer, inner, None) + // Check if `inner` is nested in `outer` and ensure that all edge fields are not boxed manually. + is_connected_with_edge_filter(&self.graph, outer, inner, |node, field_name| { + self.boxed + .get_first_field(&self.graph[node], field_name) + .is_none() + }) } /// Returns `true` if this message can automatically derive Copy trait. @@ -123,11 +134,11 @@ impl MessageGraph { false } else if field.r#type() == Type::Message { // nested and boxed messages cannot derive Copy - if self.is_nested(field.type_name(), fq_message_name) - || self - .boxed - .get_first_field(fq_message_name, field.name()) - .is_some() + if self + .boxed + .get_first_field(fq_message_name, field.name()) + .is_some() + || self.is_directly_nested(field.type_name(), fq_message_name) { false } else { @@ -154,3 +165,108 @@ impl MessageGraph { } } } + +/// Check two nodes is connected with edge filter +fn is_connected_with_edge_filter( + graph: &Graph, + start: NodeIndex, + end: NodeIndex, + mut is_good_edge: F, +) -> bool +where + F: FnMut(NodeIndex, &E) -> bool, +{ + fn visitor( + graph: &Graph, + start: NodeIndex, + end: NodeIndex, + is_good_edge: &mut F, + visited: &mut HashSet, + ) -> bool + where + F: FnMut(NodeIndex, &E) -> bool, + { + if start == end { + return true; + } + visited.visit(start); + for edge in graph.edges_directed(start, Direction::Outgoing) { + // if the edge doesn't pass the filter, skip it + if !is_good_edge(start, edge.weight()) { + continue; + } + let target = edge.target(); + if visited.is_visited(&target) { + continue; + } + if visitor(graph, target, end, is_good_edge, visited) { + return true; + } + } + false + } + let mut visited = HashSet::new(); + visitor(graph, start, end, &mut is_good_edge, &mut visited) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connected() { + let mut graph = Graph::new(); + let n1 = graph.add_node(1); + let n2 = graph.add_node(2); + let n3 = graph.add_node(3); + let n4 = graph.add_node(4); + let n5 = graph.add_node(5); + let n6 = graph.add_node(6); + let n7 = graph.add_node(7); + let n8 = graph.add_node(8); + graph.add_edge(n1, n2, 1.); + graph.add_edge(n2, n3, 2.); + graph.add_edge(n3, n4, 3.); + graph.add_edge(n4, n5, 4.); + graph.add_edge(n5, n6, 5.); + graph.add_edge(n6, n7, 6.); + graph.add_edge(n7, n8, 7.); + graph.add_edge(n8, n1, 8.); + assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| { + dbg!(edge); + true + }),); + assert!(is_connected_with_edge_filter(&graph, n2, n1, |_, edge| { + dbg!(edge); + edge < &8.5 + }),); + assert!(!is_connected_with_edge_filter(&graph, n2, n1, |_, edge| { + dbg!(edge); + edge < &7.5 + }),); + } + + #[test] + fn test_connected_multi_circle() { + let mut graph = Graph::new(); + let n0 = graph.add_node(0); + let n1 = graph.add_node(1); + let n2 = graph.add_node(2); + let n3 = graph.add_node(3); + let n4 = graph.add_node(4); + graph.add_edge(n0, n1, 0.); + graph.add_edge(n1, n2, 1.); + graph.add_edge(n2, n3, 2.); + graph.add_edge(n3, n0, 3.); + graph.add_edge(n1, n4, 1.5); + graph.add_edge(n4, n0, 2.5); + assert!(is_connected_with_edge_filter(&graph, n1, n0, |_, edge| { + dbg!(edge); + edge < &2.8 + }),); + assert!(!is_connected_with_edge_filter(&graph, n1, n0, |_, edge| { + dbg!(edge); + edge < &2.1 + }),); + } +} diff --git a/tests/src/build.rs b/tests/src/build.rs index b707fb270..2a4e5a09f 100644 --- a/tests/src/build.rs +++ b/tests/src/build.rs @@ -151,6 +151,19 @@ fn main() { std::fs::create_dir_all(&out_path).unwrap(); + prost_build::Config::new() + .out_dir(src.join("nesting_complex/boxed")) + .boxed("Foo.bar") + .boxed("BazB.baz_c") + .boxed("BakC.bak_d") + .compile_protos(&[src.join("nesting_complex.proto")], includes) + .unwrap(); + + prost_build::Config::new() + .out_dir(src.join("nesting_complex/")) + .compile_protos(&[src.join("nesting_complex.proto")], includes) + .unwrap(); + prost_build::Config::new() .bytes(["."]) .out_dir(out_path) diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 614baa969..990b4790a 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -134,6 +134,14 @@ pub mod proto3 { } } +pub mod nesting_complex_boxed { + include!("nesting_complex/boxed/nesting_complex.rs"); +} + +pub mod nesting_complex { + include!("nesting_complex/nesting_complex.rs"); +} + pub mod invalid { pub mod doctest { include!(concat!(env!("OUT_DIR"), "/invalid.doctest.rs")); diff --git a/tests/src/nesting_complex.proto b/tests/src/nesting_complex.proto new file mode 100644 index 000000000..19b6b7a85 --- /dev/null +++ b/tests/src/nesting_complex.proto @@ -0,0 +1,47 @@ +syntax = "proto2"; + +package nesting_complex; + +// ----- Directly nested +message Foo { + optional Bar bar = 1; +} + +message Bar { + optional Foo foo = 1; +} + +// ----- Transitively nested +message BazA { + optional BazB baz_b = 1; +} + +message BazB { + optional BazC baz_c = 1; +} + +message BazC { + optional BazA baz_a = 1; +} + +// ----- Transitively nested in two chain +message BakA { + optional BakB bak_b = 1; +} + +message BakB { + optional BakC bak_c = 1; + optional BakE bak_e = 2; +} + +message BakC { + optional BakD bak_d = 1; +} + +message BakD { + optional BakA bak_a = 1; +} + +message BakE { + optional BakA bak_a = 1; +} diff --git a/tests/src/nesting_complex/boxed/nesting_complex.rs b/tests/src/nesting_complex/boxed/nesting_complex.rs new file mode 100644 index 000000000..3a3a48ed1 --- /dev/null +++ b/tests/src/nesting_complex/boxed/nesting_complex.rs @@ -0,0 +1,56 @@ +// This file is @generated by prost-build. +/// ----- Directly nested +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Foo { + #[prost(message, optional, boxed, tag = "1")] + pub bar: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Bar { + #[prost(message, optional, tag = "1")] + pub foo: ::core::option::Option, +} +/// ----- Transitively nested +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BazA { + #[prost(message, optional, tag = "1")] + pub baz_b: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BazB { + #[prost(message, optional, boxed, tag = "1")] + pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BazC { + #[prost(message, optional, tag = "1")] + pub baz_a: ::core::option::Option, +} +/// ----- Transitively nested in two chain +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakA { + #[prost(message, optional, boxed, tag = "1")] + pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakB { + #[prost(message, optional, tag = "1")] + pub bak_c: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakC { + #[prost(message, optional, boxed, tag = "1")] + pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakD { + #[prost(message, optional, tag = "1")] + pub bak_a: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakE { + #[prost(message, optional, boxed, tag = "1")] + pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box>, +} diff --git a/tests/src/nesting_complex/nesting_complex.rs b/tests/src/nesting_complex/nesting_complex.rs new file mode 100644 index 000000000..d1ce1b125 --- /dev/null +++ b/tests/src/nesting_complex/nesting_complex.rs @@ -0,0 +1,56 @@ +// This file is @generated by prost-build. +/// ----- Directly nested +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Foo { + #[prost(message, optional, boxed, tag = "1")] + pub bar: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Bar { + #[prost(message, optional, boxed, tag = "1")] + pub foo: ::core::option::Option<::prost::alloc::boxed::Box>, +} +/// ----- Transitively nested +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BazA { + #[prost(message, optional, boxed, tag = "1")] + pub baz_b: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BazB { + #[prost(message, optional, boxed, tag = "1")] + pub baz_c: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BazC { + #[prost(message, optional, boxed, tag = "1")] + pub baz_a: ::core::option::Option<::prost::alloc::boxed::Box>, +} +/// ----- Transitively nested in two chain +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakA { + #[prost(message, optional, boxed, tag = "1")] + pub bak_b: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakB { + #[prost(message, optional, boxed, tag = "1")] + pub bak_c: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub bak_e: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakC { + #[prost(message, optional, boxed, tag = "1")] + pub bak_d: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakD { + #[prost(message, optional, boxed, tag = "1")] + pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BakE { + #[prost(message, optional, boxed, tag = "1")] + pub bak_a: ::core::option::Option<::prost::alloc::boxed::Box>, +}