Skip to content

Commit

Permalink
More evaluator, but errors
Browse files Browse the repository at this point in the history
  • Loading branch information
gzanitti committed Aug 9, 2024
1 parent 433d110 commit 091bf73
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 28 deletions.
2 changes: 1 addition & 1 deletion ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ pub struct PolynomialReference {
/// Guaranteed to be Some(_) after type checking is completed.
pub type_args: Option<Vec<Type>>,
///
pub resolved_impl_pos: HashMap<String, usize>,
pub resolved_impls: HashMap<String, Box<Expression>>,
}

#[derive(
Expand Down
65 changes: 64 additions & 1 deletion pil-analyzer/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ pub enum Value<'a, T> {
Enum(&'a str, Option<Vec<Arc<Self>>>),
BuiltinFunction(BuiltinFunction),
Expression(AlgebraicExpression<T>),
TraitFunction(TraitFunction<'a, T>),
}

impl<'a, T: FieldElement> From<T> for Value<'a, T> {
Expand Down Expand Up @@ -215,6 +216,12 @@ impl<'a, T: FieldElement> Value<'a, T> {
Value::Enum(name, _) => name.to_string(),
Value::BuiltinFunction(b) => format!("builtin_{b:?}"),
Value::Expression(_) => "expr".to_string(),
Value::TraitFunction(trait_function) => {
format!(
"trait_function<{}>",
trait_function.type_args.values().format(", ")
)
}
}
}

Expand Down Expand Up @@ -370,6 +377,11 @@ impl<'a, T: Display> Display for Value<'a, T> {
}
Value::BuiltinFunction(b) => write!(f, "{b:?}"),
Value::Expression(e) => write!(f, "{e}"),
Value::TraitFunction(trait_function) => write!(
f,
"trait_function<{}>",
trait_function.type_args.values().format(", ")
),
}
}
}
Expand Down Expand Up @@ -400,6 +412,13 @@ impl<'a, T> Closure<'a, T> {
}
}

#[derive(Clone, Debug)]
pub struct TraitFunction<'a, T> {
pub body: Box<Expression>,
pub environment: Vec<Arc<Value<'a, T>>>,
pub type_args: HashMap<String, Type>,
}

pub struct Definitions<'a>(pub &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>);

impl<'a> Definitions<'a> {
Expand Down Expand Up @@ -816,7 +835,38 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> {
}
ta
});
self.symbols.lookup(&poly.name, &type_args)?

let impl_pos = type_args.as_ref().and_then(|type_args| {
let key = type_args.iter().format(",").to_string();
poly.resolved_impls.get(&key).as_ref().copied()
});

match impl_pos {
Some(body) => {
let local_type_args = poly
.type_args
.clone()
.map(|ta| {
ta.into_iter()
.filter_map(|ty| {
let key = ty.to_string();
self.type_args
.get(&key)
.map(|value| (key, value.clone()))
})
.collect::<HashMap<_, _>>()
})
.unwrap();

Value::TraitFunction(TraitFunction {
body: body.clone(),
environment: vec![],
type_args: local_type_args.clone(),
})
.into()
}
None => self.symbols.lookup(&poly.name, &type_args)?,
}
}
}
})
Expand Down Expand Up @@ -985,6 +1035,19 @@ impl<'a, 'b, T: FieldElement, S: SymbolLookup<'a, T>> Evaluator<'a, 'b, T, S> {
self.type_args = type_args.clone();
self.expand(&lambda.body)?;
}
Value::TraitFunction(TraitFunction {
body,
environment,
type_args,
}) => {
self.op_stack.push(Operation::SetEnvironment(
std::mem::take(&mut self.local_vars),
std::mem::take(&mut self.type_args),
));
self.local_vars = vec![];
self.type_args = type_args.clone();
self.expand(&body)?;
}
e => panic!("Expected function but got {e}"),
};
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion pil-analyzer/src/expression_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
name: self.driver.resolve_value_ref(&reference.path),
poly_id: None,
type_args,
resolved_impl_pos: HashMap::new(),
resolved_impls: HashMap::new(),
}
}

Expand Down
37 changes: 16 additions & 21 deletions pil-analyzer/src/traits_processor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use core::panic;
use std::collections::HashMap;

use itertools::Itertools;
use powdr_ast::{
analyzed::{
Expression, FunctionValueDefinition, Identity, PolynomialReference, Reference, Symbol,
Expand Down Expand Up @@ -101,12 +102,12 @@ impl<'a> TraitsProcessor<'a> {
&mut self,
current: &str,
collected_ref: (String, Vec<Type>),
) -> (String, HashMap<String, usize>) {
) -> (String, HashMap<String, Box<Expression>>) {
let mut resolved_impl_pos = HashMap::new();
let Some(FunctionValueDefinition::TraitFunction(ref trait_decl, ref mut trait_fn)) =
self.definitions.get_mut(&collected_ref.0).unwrap().1
else {
return ("".to_string(), resolved_impl_pos); //TODO GZ: Error?
return ("".to_string(), resolved_impl_pos); //TODO GZ: Error
};

if let Some(impls) = self.implementations.get(&trait_decl.name) {
Expand All @@ -121,17 +122,19 @@ impl<'a> TraitsProcessor<'a> {
);
};
trait_decl.function_by_name(&trait_fn.name);
let type_vars = trait_decl.type_vars.clone();
// let type_vars = trait_decl.type_vars.clone();
let collected_types = collected_ref.1.clone();
let substitutions: HashMap<_, _> = type_vars
.into_iter()
.zip(collected_types.into_iter())
.collect();
trait_fn.ty.substitute_type_vars(&substitutions); // TODO GZ: avoid mutation?
// let substitutions: HashMap<_, _> = type_vars
// .into_iter()
// .zip(collected_types.into_iter())
// .collect();
//trait_fn.ty.substitute_type_vars(&substitutions); // TODO GZ: avoid mutation?

//check first!

resolved_impl_pos.insert(trait_decl.name.clone(), i);
resolved_impl_pos.insert(
collected_types.iter().format(",").to_string(),
impl_fn.body.clone(),
);
}
}

Expand All @@ -140,28 +143,20 @@ impl<'a> TraitsProcessor<'a> {
resolved_impl_pos,
)
}

fn split_trait_and_function(&self, full_name: &str) -> (String, String) {
// TODO GZ: we probably have a better way to do this (SymbolPath insteand of String)
let mut parts: Vec<&str> = full_name.rsplitn(2, "::").collect();
let trait_name = parts.pop().unwrap_or("").to_string();
let fname = parts.pop().unwrap_or("").to_string();
(trait_name, fname)
}
}

// TODO GZ: Is it really needed to go children_mut deep?
fn update_reference(
ref_name: &str,
type_args: &Vec<Type>,
expr: &mut Expression,
resolved_impl_pos: &HashMap<String, usize>,
resolved_impl_pos: &HashMap<String, Box<Expression>>,
) {
fn process_expr(
ref_name: &str,
type_args: &Vec<Type>,
c: &mut Expression,
resolved_impl_pos: &HashMap<String, usize>,
resolved_impl_pos: &HashMap<String, Box<Expression>>,
) {
if let Expression::Reference(
sr,
Expand All @@ -180,7 +175,7 @@ fn update_reference(
name: name.clone(),
type_args: type_args.clone(),
poly_id: poly_id.clone(),
resolved_impl_pos: resolved_impl_pos.clone(),
resolved_impls: resolved_impl_pos.clone(),
}),
);
}
Expand Down
4 changes: 2 additions & 2 deletions pil-analyzer/src/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ impl TypeChecker {
name,
poly_id: _,
type_args,
resolved_impl_pos: _,
resolved_impls: _,
}),
) => {
for ty in type_args.as_mut().unwrap() {
Expand Down Expand Up @@ -511,7 +511,7 @@ impl TypeChecker {
name,
poly_id: _,
type_args,
resolved_impl_pos: _,
resolved_impls: _,
}),
) => {
let (ty, args) = self.instantiate_scheme(self.declared_types[name].1.clone());
Expand Down
4 changes: 2 additions & 2 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl ReferencedSymbols for Expression {
name,
type_args,
poly_id: _,
resolved_impl_pos: _,
resolved_impls: _,
}),
) => Some(
type_args
Expand Down Expand Up @@ -462,7 +462,7 @@ fn substitute_polynomial_references<T: FieldElement>(
name: _,
poly_id: Some(poly_id),
type_args: _,
resolved_impl_pos: _,
resolved_impls: _,
}),
) = e
{
Expand Down

0 comments on commit 091bf73

Please sign in to comment.