diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index fed415710..a4ee167c3 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -1801,15 +1801,13 @@ mod test { let fe = || fe(); namespace F(4); trait Do { - add1: T, T -> T, - add2: T, T -> Q, + add: T, T -> Q, sub: T, T -> Q, cast: T -> Q, } impl Do { - add1: |a, b| a + b, - add2: |a, b| std::convert::fe(a + b), + add: |a, b| std::convert::fe(a + b), sub: |a, b| std::convert::fe(a - b), cast: |a| std::convert::fe(a), } @@ -1818,7 +1816,7 @@ mod test { v => { let one: int = 1; let two: int = 2; - v + Do::add2(one, two) + v + Do::add(one, two) }, }; diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index a0c36c794..09e600826 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -327,8 +327,14 @@ impl PILAnalyzer { } pub fn traits_resolution(&mut self) { - TraitsProcessor::new(&mut self.definitions, &self.implementations) - .traits_resolution(&mut self.identities) + // TraitsProcessor::new(&mut self.definitions, &self.implementations) + // .traits_resolution(&mut self.identities) + TraitsProcessor::new( + &mut self.definitions, + &mut self.identities, + &self.implementations, + ) + .traits_resolution(); } pub fn condense(self) -> Analyzed { diff --git a/pil-analyzer/src/traits_processor.rs b/pil-analyzer/src/traits_processor.rs index 83fff395d..6dd34ecf8 100644 --- a/pil-analyzer/src/traits_processor.rs +++ b/pil-analyzer/src/traits_processor.rs @@ -1,5 +1,5 @@ use core::panic; -use std::{collections::HashMap, iter::once}; +use std::collections::HashMap; use itertools::Itertools; use powdr_ast::{ @@ -12,24 +12,24 @@ use powdr_ast::{ pub struct TraitsProcessor<'a> { definitions: &'a mut HashMap)>, + identities: &'a mut Vec>>, implementations: &'a HashMap>>, } impl<'a> TraitsProcessor<'a> { pub fn new( definitions: &'a mut HashMap)>, + identities: &'a mut Vec>>, implementations: &'a HashMap>>, ) -> Self { Self { definitions, + identities, implementations, } } - pub fn traits_resolution( - &mut self, - identities: &mut [Identity>], - ) { + pub fn traits_resolution(&mut self) { let keys: Vec = self .definitions .iter() @@ -42,77 +42,122 @@ impl<'a> TraitsProcessor<'a> { .map(|(name, _)| name.clone()) .collect(); - let refs_in_identities = self.collect_refs_in_identities(identities); - for name in keys { - self.resolve_trait(&name, &refs_in_identities); + self.resolve_trait(&name); + } + + self.resolve_trait_identities(); + } + + fn resolve_trait_identities(&mut self) { + let mut updates = Vec::new(); + for (index, identity) in self.identities.iter().enumerate() { + let refs_in_identity = self.collect_refs_in_identity(identity); + for collected_ref in refs_in_identity { + let (trait_name, resolved_impl_pos) = + self.resolve_trait_function(collected_ref.clone()); + updates.push((index, trait_name, collected_ref.1, resolved_impl_pos)); + } + } + + for (index, trait_name, type_args, resolved_impl_pos) in updates { + let identity = &mut self.identities[index]; + TraitsProcessor::update_references_in_identity( + identity, + &trait_name, + &type_args, + &resolved_impl_pos, + ); } } - fn collect_refs_in_identities( + fn collect_refs_in_identity( &self, - identities: &mut [Identity>], + identity: &Identity>, ) -> Vec<(String, Vec)> { - identities + let Identity { + left: + SelectedExpressions { + selector: selector_left, + expressions: expressions_left, + }, + right: + SelectedExpressions { + selector: selector_right, + expressions: expressions_right, + }, + .. + } = identity; + + selector_left .iter() - .flat_map(|identity| { - let Identity { - left: - SelectedExpressions { - selector: selector_left, - expressions: expressions_left, - }, - right: - SelectedExpressions { - selector: selector_right, - expressions: expressions_right, - }, - .. - } = identity; - - selector_left - .iter() - .chain(once(expressions_left.as_ref())) - .chain(selector_right.iter()) - .chain(once(expressions_right.as_ref())) - .flat_map(move |e| { - e.all_children().filter_map(move |e| match e { - Expression::Reference( - _, - Reference::Poly(PolynomialReference { - name, - type_args: Some(types), - .. - }), - ) => Some((name.clone(), types.clone())), - _ => None, - }) - }) + .chain(std::iter::once(expressions_left.as_ref())) + .chain(selector_right.iter()) + .chain(std::iter::once(expressions_right.as_ref())) + .flat_map(move |e| { + e.all_children().filter_map(move |e| match e { + Expression::Reference( + _, + Reference::Poly(PolynomialReference { + name, + type_args: Some(types), + .. + }), + ) => Some((name.clone(), types.clone())), + _ => None, + }) }) .collect() } - fn resolve_trait(&mut self, current: &str, refs_in_identities: &[(String, Vec)]) { + fn update_references_in_identity( + identity: &mut Identity>, + trait_name: &str, + type_args: &[Type], + resolved_impl_pos: &HashMap>, + ) { + let Identity { + left: + SelectedExpressions { + selector: selector_left, + expressions: expressions_left, + }, + right: + SelectedExpressions { + selector: selector_right, + expressions: expressions_right, + }, + .. + } = identity; + + let to_update = selector_left + .iter_mut() + .chain(std::iter::once(expressions_left.as_mut())) + .chain(selector_right.iter_mut()) + .chain(std::iter::once(expressions_right.as_mut())); + + for expr in to_update { + update_reference(trait_name, type_args, expr, resolved_impl_pos); + } + } + + fn resolve_trait(&mut self, current: &str) { let current_def = &self.definitions.get(current).unwrap().1; - let refs_in_def = self.collect_refs(current_def); - let refs = refs_in_def.iter().chain(refs_in_identities.iter()); + let refs_in_def = self.collect_refs_in_def(current_def); - for collected_ref in refs { - let Some((trait_name, resolved_impl_pos)) = - self.resolve_trait_function(collected_ref.clone()) - else { - continue; - }; + for collected_ref in refs_in_def { + let (trait_name, resolved_impl_pos) = + self.resolve_trait_function(collected_ref.clone()); if let Some(FunctionValueDefinition::Expression(TypedExpression { e: expr, .. })) = self.definitions.get_mut(current).unwrap().1.as_mut() { - update_reference(&trait_name, expr, &resolved_impl_pos); + update_reference(&trait_name, &collected_ref.1, expr, &resolved_impl_pos); } } } - fn collect_refs( + fn collect_refs_in_def( &self, current_def: &Option, ) -> Vec<(String, Vec)> { @@ -138,12 +183,12 @@ impl<'a> TraitsProcessor<'a> { fn resolve_trait_function( &self, collected_ref: (String, Vec), - ) -> Option<(String, HashMap>)> { + ) -> (String, HashMap>) { let mut resolved_impl_pos = HashMap::new(); let Some(FunctionValueDefinition::TraitFunction(ref trait_decl, ref trait_fn)) = self.definitions.get(&collected_ref.0).unwrap().1 else { - return None; + return ("".to_string(), resolved_impl_pos); }; if let Some(impls) = self.implementations.get(&trait_decl.name) { @@ -162,20 +207,22 @@ impl<'a> TraitsProcessor<'a> { } } - Some(( + ( format!("{}::{}", trait_decl.name.replace('.', "::"), trait_fn.name), resolved_impl_pos, - )) + ) } } fn update_reference( ref_name: &str, + type_args: &[Type], expr: &mut Expression, resolved_impl_pos: &HashMap>, ) { fn process_expr( ref_name: &str, + type_args: &[Type], c: &mut Expression, resolved_impl_pos: &HashMap>, ) { @@ -184,7 +231,7 @@ fn update_reference( Reference::Poly(PolynomialReference { name, poly_id, - type_args, + type_args: current_type_args, .. }), ) = c @@ -194,8 +241,8 @@ fn update_reference( sr.clone(), Reference::Poly(PolynomialReference { name: name.clone(), - type_args: type_args.clone(), - poly_id: *poly_id, + type_args: current_type_args.clone(), + poly_id: poly_id.clone(), resolved_impls: resolved_impl_pos.clone(), }), ); @@ -203,9 +250,9 @@ fn update_reference( } for child in c.children_mut() { - process_expr(ref_name, child, resolved_impl_pos); + process_expr(ref_name, type_args, child, resolved_impl_pos); } } - process_expr(ref_name, expr, resolved_impl_pos); + process_expr(ref_name, type_args, expr, resolved_impl_pos); } diff --git a/pil-analyzer/tests/types.rs b/pil-analyzer/tests/types.rs index 2b1ed018b..d0e192499 100644 --- a/pil-analyzer/tests/types.rs +++ b/pil-analyzer/tests/types.rs @@ -567,12 +567,12 @@ fn defined_trait() { trait Add { add: T, T -> T, } - impl Add { + impl Add { add: |a, b| a + b, } - let r: int = Add::add(3, 4); + let r: fe = Add::add(3, 4); "; - type_check(input, &[("r", "", "int")]); + type_check(input, &[("r", "", "fe")]); } #[test]