diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index aec7ea37f..a359b3e07 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -29,18 +29,26 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } - pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { + /// Requests code for the symbol to be generated and returns the + /// code that returns the value of the symbol, which is either + /// the escaped name of the symbol or a dereferenced name of a lazy static. + pub fn request_symbol(&mut self, name: &str) -> Result { if let Some(err) = self.failed.get(name) { return Err(err.clone()); } + let reference = if self.symbol_needs_deref(name) { + format!("(*{})", escape_symbol(name)) + } else { + escape_symbol(name) + }; if self.requested.contains(name) { - return Ok(()); + return Ok(reference); } self.requested.insert(name.to_string()); match self.generate_code(name) { Ok(code) => { self.symbols.insert(name.to_string(), code); - Ok(()) + Ok(reference) } Err(err) => { let err = format!("Failed to compile {name}: {err}"); @@ -60,7 +68,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } fn generate_code(&mut self, symbol: &str) -> Result { - if let Some(code) = self.try_generate_builtin(symbol) { + if let Some(code) = try_generate_builtin::(symbol) { return Ok(code); } @@ -74,22 +82,28 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let type_scheme = value.type_scheme.clone().unwrap(); - Ok(match type_scheme { - TypeScheme { - vars, - ty: - Type::Function(FunctionType { - params: param_types, - value: return_type, - }), - } => { + Ok(match (&value.e, type_scheme) { + ( + Expression::LambdaExpression(..), + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + }, + ) => { assert!(vars.is_empty()); self.try_format_function(symbol, ¶m_types, return_type.as_ref(), &value.e)? } - TypeScheme { - vars, - ty: Type::Col, - } => { + ( + Expression::LambdaExpression(..), + TypeScheme { + vars, + ty: Type::Col, + }, + ) => { assert!(vars.is_empty()); // TODO we assume it is an int -> int function. // The type inference algorithm should store the derived type. @@ -97,14 +111,26 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { // in the trait vars. self.try_format_function(symbol, &[Type::Int], &Type::Fe, &value.e)? } - _ => format!( - "lazy_static::lazy_static! {{\n\ - static ref {}: {} = {};\n\ - }}\n", - escape_symbol(symbol), - map_type(&value.type_scheme.as_ref().unwrap().ty), - self.format_expr(&value.e)? - ), + _ => { + let type_scheme = value.type_scheme.as_ref().unwrap(); + assert!(type_scheme.vars.is_empty()); + let ty = if type_scheme.ty == Type::Col { + Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + }) + } else { + type_scheme.ty.clone() + }; + format!( + "lazy_static::lazy_static! {{\n\ + static ref {}: {} = {};\n\ + }}\n", + escape_symbol(symbol), + map_type(&ty), + self.format_expr(&value.e)? + ) + } }) } @@ -116,7 +142,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { expr: &Expression, ) -> Result { let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = expr else { - return Err(format!("Expected lambda expression for {name}, got {expr}",)); + panic!(); }; Ok(format!( "fn {}({}) -> {} {{ {} }}\n", @@ -131,31 +157,14 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { )) } - fn try_generate_builtin(&self, symbol: &str) -> Option { - let code = match symbol { - "std::array::len" => "(a: Vec) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string(), - "std::check::panic" => "(s: &str) -> ! { panic!(\"{s}\"); }".to_string(), - "std::field::modulus" => { - let modulus = T::modulus(); - format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}") - } - "std::convert::fe" => "(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" - .to_string(), - _ => return None, - }; - - Some(format!("fn {}{code}", escape_symbol(symbol))) - } - fn format_expr(&mut self, e: &Expression) -> Result { Ok(match e { Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { - self.request_symbol(name)?; + let reference = self.request_symbol(name)?; let ta = type_args.as_ref().unwrap(); format!( - "{}{}", - escape_symbol(name), + "{reference}{}", (!ta.is_empty()) .then(|| format!("::<{}>", ta.iter().map(map_type).join(", "))) .unwrap_or_default() @@ -279,6 +288,19 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { Err(format!("Implement {s}")) } + + /// Returns true if a reference to the symbol needs a deref operation or not. + fn symbol_needs_deref(&self, symbol: &str) -> bool { + if is_builtin(symbol) { + return false; + } + let (_, def) = self.analyzed.definitions.get(symbol).as_ref().unwrap(); + if let Some(FunctionValueDefinition::Expression(typed_expr)) = def { + matches!(typed_expr.e, Expression::LambdaExpression(..)) + } else { + false + } + } } pub fn escape_symbol(s: &str) -> String { @@ -295,7 +317,11 @@ fn map_type(ty: &Type) -> String { Type::Expr => "Expr".to_string(), Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), Type::Tuple(_) => todo!(), - Type::Function(ft) => todo!("Type {ft}"), + Type::Function(ft) => format!( + "fn({}) -> {}", + ft.params.iter().map(map_type).join(", "), + map_type(&ft.value) + ), Type::TypeVar(tv) => tv.to_string(), Type::NamedType(path, type_args) => { if type_args.is_some() { @@ -307,6 +333,28 @@ fn map_type(ty: &Type) -> String { } } +fn is_builtin(symbol: &str) -> bool { + matches!( + symbol, + "std::array::len" | "std::check::panic" | "std::field::modulus" | "std::convert::fe" + ) +} + +fn try_generate_builtin(symbol: &str) -> Option { + let code = match symbol { + "std::array::len" => Some("(a: Vec) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string()), + "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), + "std::field::modulus" => { + let modulus = T::modulus(); + Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")) + } + "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + .to_string()), + _ => None, + }?; + Some(format!("fn {}{code}", escape_symbol(symbol))) +} + #[cfg(test)] mod test { use powdr_number::GoldilocksField; diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index debc2ae66..3ce6115f0 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -44,6 +44,29 @@ fn sqrt() { assert_eq!(f(0), 0); } +#[test] +fn assigned_functions() { + let input = r#" + namespace std::array; + let len = 8; + namespace main; + let a: int -> int = |i| i + 1; + let b: int -> int = |i| i + 2; + let t: bool = "" == ""; + let c = if t { a } else { b }; + let d = |i| c(i); + "#; + let c = compile(input, "main::c"); + + assert_eq!(c(0), 1); + assert_eq!(c(1), 2); + assert_eq!(c(2), 3); + assert_eq!(c(3), 4); + + let d = compile(input, "main::d"); + assert_eq!(d(0), 1); +} + #[test] fn simple_field() { let f = compile(