Skip to content

Commit

Permalink
More tests needed
Browse files Browse the repository at this point in the history
  • Loading branch information
gzanitti committed Aug 13, 2024
1 parent 1688aed commit a0d736b
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 74 deletions.
8 changes: 3 additions & 5 deletions pil-analyzer/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1801,15 +1801,13 @@ mod test {
let fe = || fe();
namespace F(4);
trait Do<T, Q> {
add1: T, T -> T,
add2: T, T -> Q,
add: T, T -> Q,
sub: T, T -> Q,
cast: T -> Q,
}
impl Do<int, fe> {
add1: |a, b| a + b,
add2: |a, b| std::convert::fe(a + b),
add: |a, b| std::convert::fe(a + b),
sub: |a, b| std::convert::fe(a - b),
cast: |a| std::convert::fe(a),
}
Expand All @@ -1818,7 +1816,7 @@ mod test {
v => {
let one: int = 1;
let two: int = 2;
v + Do::add2(one, two)
v + Do::add(one, two)
},
};
Expand Down
10 changes: 8 additions & 2 deletions pil-analyzer/src/pil_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,14 @@ impl PILAnalyzer {
}

pub fn traits_resolution(&mut self) {
TraitsProcessor::new(&mut self.definitions, &self.implementations)
.traits_resolution(&mut self.identities)
// TraitsProcessor::new(&mut self.definitions, &self.implementations)
// .traits_resolution(&mut self.identities)
TraitsProcessor::new(
&mut self.definitions,
&mut self.identities,
&self.implementations,
)
.traits_resolution();
}

pub fn condense<T: FieldElement>(self) -> Analyzed<T> {
Expand Down
175 changes: 111 additions & 64 deletions pil-analyzer/src/traits_processor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::panic;
use std::{collections::HashMap, iter::once};
use std::collections::HashMap;

use itertools::Itertools;
use powdr_ast::{
Expand All @@ -12,24 +12,24 @@ use powdr_ast::{

pub struct TraitsProcessor<'a> {
definitions: &'a mut HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
identities: &'a mut Vec<Identity<SelectedExpressions<Expression>>>,
implementations: &'a HashMap<String, Vec<TraitImplementation<Expression>>>,
}

impl<'a> TraitsProcessor<'a> {
pub fn new(
definitions: &'a mut HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
identities: &'a mut Vec<Identity<SelectedExpressions<Expression>>>,
implementations: &'a HashMap<String, Vec<TraitImplementation<Expression>>>,
) -> Self {
Self {
definitions,
identities,
implementations,
}
}

pub fn traits_resolution(
&mut self,
identities: &mut [Identity<SelectedExpressions<Expression>>],
) {
pub fn traits_resolution(&mut self) {
let keys: Vec<String> = self
.definitions
.iter()
Expand All @@ -42,77 +42,122 @@ impl<'a> TraitsProcessor<'a> {
.map(|(name, _)| name.clone())
.collect();

let refs_in_identities = self.collect_refs_in_identities(identities);

for name in keys {
self.resolve_trait(&name, &refs_in_identities);
self.resolve_trait(&name);
}

self.resolve_trait_identities();
}

fn resolve_trait_identities(&mut self) {
let mut updates = Vec::new();
for (index, identity) in self.identities.iter().enumerate() {
let refs_in_identity = self.collect_refs_in_identity(identity);
for collected_ref in refs_in_identity {
let (trait_name, resolved_impl_pos) =
self.resolve_trait_function(collected_ref.clone());
updates.push((index, trait_name, collected_ref.1, resolved_impl_pos));
}
}

for (index, trait_name, type_args, resolved_impl_pos) in updates {
let identity = &mut self.identities[index];
TraitsProcessor::update_references_in_identity(
identity,
&trait_name,
&type_args,
&resolved_impl_pos,
);
}
}

fn collect_refs_in_identities(
fn collect_refs_in_identity(
&self,
identities: &mut [Identity<SelectedExpressions<Expression>>],
identity: &Identity<SelectedExpressions<Expression>>,
) -> Vec<(String, Vec<Type>)> {
identities
let Identity {
left:
SelectedExpressions {
selector: selector_left,
expressions: expressions_left,
},
right:
SelectedExpressions {
selector: selector_right,
expressions: expressions_right,
},
..
} = identity;

selector_left
.iter()
.flat_map(|identity| {
let Identity {
left:
SelectedExpressions {
selector: selector_left,
expressions: expressions_left,
},
right:
SelectedExpressions {
selector: selector_right,
expressions: expressions_right,
},
..
} = identity;

selector_left
.iter()
.chain(once(expressions_left.as_ref()))
.chain(selector_right.iter())
.chain(once(expressions_right.as_ref()))
.flat_map(move |e| {
e.all_children().filter_map(move |e| match e {
Expression::Reference(
_,
Reference::Poly(PolynomialReference {
name,
type_args: Some(types),
..
}),
) => Some((name.clone(), types.clone())),
_ => None,
})
})
.chain(std::iter::once(expressions_left.as_ref()))
.chain(selector_right.iter())
.chain(std::iter::once(expressions_right.as_ref()))
.flat_map(move |e| {
e.all_children().filter_map(move |e| match e {
Expression::Reference(
_,
Reference::Poly(PolynomialReference {
name,
type_args: Some(types),
..
}),
) => Some((name.clone(), types.clone())),
_ => None,
})
})
.collect()
}

fn resolve_trait(&mut self, current: &str, refs_in_identities: &[(String, Vec<Type>)]) {
fn update_references_in_identity(
identity: &mut Identity<SelectedExpressions<Expression>>,
trait_name: &str,
type_args: &[Type],
resolved_impl_pos: &HashMap<String, Box<Expression>>,
) {
let Identity {
left:
SelectedExpressions {
selector: selector_left,
expressions: expressions_left,
},
right:
SelectedExpressions {
selector: selector_right,
expressions: expressions_right,
},
..
} = identity;

let to_update = selector_left
.iter_mut()
.chain(std::iter::once(expressions_left.as_mut()))
.chain(selector_right.iter_mut())
.chain(std::iter::once(expressions_right.as_mut()));

for expr in to_update {
update_reference(trait_name, type_args, expr, resolved_impl_pos);
}
}

fn resolve_trait(&mut self, current: &str) {
let current_def = &self.definitions.get(current).unwrap().1;
let refs_in_def = self.collect_refs(current_def);
let refs = refs_in_def.iter().chain(refs_in_identities.iter());
let refs_in_def = self.collect_refs_in_def(current_def);

for collected_ref in refs {
let Some((trait_name, resolved_impl_pos)) =
self.resolve_trait_function(collected_ref.clone())
else {
continue;
};
for collected_ref in refs_in_def {
let (trait_name, resolved_impl_pos) =
self.resolve_trait_function(collected_ref.clone());

if let Some(FunctionValueDefinition::Expression(TypedExpression { e: expr, .. })) =
self.definitions.get_mut(current).unwrap().1.as_mut()
{
update_reference(&trait_name, expr, &resolved_impl_pos);
update_reference(&trait_name, &collected_ref.1, expr, &resolved_impl_pos);
}
}
}

fn collect_refs(
fn collect_refs_in_def(
&self,
current_def: &Option<FunctionValueDefinition>,
) -> Vec<(String, Vec<Type>)> {
Expand All @@ -138,12 +183,12 @@ impl<'a> TraitsProcessor<'a> {
fn resolve_trait_function(
&self,
collected_ref: (String, Vec<Type>),
) -> Option<(String, HashMap<String, Box<Expression>>)> {
) -> (String, HashMap<String, Box<Expression>>) {
let mut resolved_impl_pos = HashMap::new();
let Some(FunctionValueDefinition::TraitFunction(ref trait_decl, ref trait_fn)) =
self.definitions.get(&collected_ref.0).unwrap().1
else {
return None;
return ("".to_string(), resolved_impl_pos);
};

if let Some(impls) = self.implementations.get(&trait_decl.name) {
Expand All @@ -162,20 +207,22 @@ impl<'a> TraitsProcessor<'a> {
}
}

Some((
(
format!("{}::{}", trait_decl.name.replace('.', "::"), trait_fn.name),
resolved_impl_pos,
))
)
}
}

fn update_reference(
ref_name: &str,
type_args: &[Type],
expr: &mut Expression,
resolved_impl_pos: &HashMap<String, Box<Expression>>,
) {
fn process_expr(
ref_name: &str,
type_args: &[Type],
c: &mut Expression,
resolved_impl_pos: &HashMap<String, Box<Expression>>,
) {
Expand All @@ -184,7 +231,7 @@ fn update_reference(
Reference::Poly(PolynomialReference {
name,
poly_id,
type_args,
type_args: current_type_args,
..
}),
) = c
Expand All @@ -194,18 +241,18 @@ fn update_reference(
sr.clone(),
Reference::Poly(PolynomialReference {
name: name.clone(),
type_args: type_args.clone(),
poly_id: *poly_id,
type_args: current_type_args.clone(),
poly_id: poly_id.clone(),
resolved_impls: resolved_impl_pos.clone(),
}),
);
}
}

for child in c.children_mut() {
process_expr(ref_name, child, resolved_impl_pos);
process_expr(ref_name, type_args, child, resolved_impl_pos);
}
}

process_expr(ref_name, expr, resolved_impl_pos);
process_expr(ref_name, type_args, expr, resolved_impl_pos);
}
6 changes: 3 additions & 3 deletions pil-analyzer/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,12 @@ fn defined_trait() {
trait Add<T> {
add: T, T -> T,
}
impl Add<int> {
impl Add<fe> {
add: |a, b| a + b,
}
let r: int = Add::add(3, 4);
let r: fe = Add::add(3, 4);
";
type_check(input, &[("r", "", "int")]);
type_check(input, &[("r", "", "fe")]);
}

#[test]
Expand Down

0 comments on commit a0d736b

Please sign in to comment.