Skip to content

Commit

Permalink
Merge branch 'functions_as_expressions' into target_test
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Sep 24, 2024
2 parents 2e68102 + c40ca57 commit 1008a55
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 46 deletions.
140 changes: 94 additions & 46 deletions jit-compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,26 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}
}

pub fn request_symbol(&mut self, name: &str) -> Result<(), String> {
/// Requests code for the symbol to be generated and returns the
/// code that returns the value of the symbol, which is either
/// the escaped name of the symbol or a dereferenced name of a lazy static.
pub fn request_symbol(&mut self, name: &str) -> Result<String, String> {
if let Some(err) = self.failed.get(name) {
return Err(err.clone());
}
let reference = if self.symbol_needs_deref(name) {
format!("(*{})", escape_symbol(name))
} else {
escape_symbol(name)
};
if self.requested.contains(name) {
return Ok(());
return Ok(reference);
}
self.requested.insert(name.to_string());
match self.generate_code(name) {
Ok(code) => {
self.symbols.insert(name.to_string(), code);
Ok(())
Ok(reference)
}
Err(err) => {
let err = format!("Failed to compile {name}: {err}");
Expand All @@ -60,7 +68,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}

fn generate_code(&mut self, symbol: &str) -> Result<String, String> {
if let Some(code) = self.try_generate_builtin(symbol) {
if let Some(code) = try_generate_builtin::<T>(symbol) {
return Ok(code);
}

Expand All @@ -74,37 +82,55 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {

let type_scheme = value.type_scheme.clone().unwrap();

Ok(match type_scheme {
TypeScheme {
vars,
ty:
Type::Function(FunctionType {
params: param_types,
value: return_type,
}),
} => {
Ok(match (&value.e, type_scheme) {
(
Expression::LambdaExpression(..),
TypeScheme {
vars,
ty:
Type::Function(FunctionType {
params: param_types,
value: return_type,
}),
},
) => {
assert!(vars.is_empty());
self.try_format_function(symbol, &param_types, return_type.as_ref(), &value.e)?
}
TypeScheme {
vars,
ty: Type::Col,
} => {
(
Expression::LambdaExpression(..),
TypeScheme {
vars,
ty: Type::Col,
},
) => {
assert!(vars.is_empty());
// TODO we assume it is an int -> int function.
// The type inference algorithm should store the derived type.
// Alternatively, we insert a trait conversion function and store the type
// in the trait vars.
self.try_format_function(symbol, &[Type::Int], &Type::Fe, &value.e)?
}
_ => format!(
"lazy_static::lazy_static! {{\n\
static ref {}: {} = {};\n\
}}\n",
escape_symbol(symbol),
map_type(&value.type_scheme.as_ref().unwrap().ty),
self.format_expr(&value.e)?
),
_ => {
let type_scheme = value.type_scheme.as_ref().unwrap();
assert!(type_scheme.vars.is_empty());
let ty = if type_scheme.ty == Type::Col {
Type::Function(FunctionType {
params: vec![Type::Int],
value: Box::new(Type::Fe),
})
} else {
type_scheme.ty.clone()
};
format!(
"lazy_static::lazy_static! {{\n\
static ref {}: {} = {};\n\
}}\n",
escape_symbol(symbol),
map_type(&ty),
self.format_expr(&value.e)?
)
}
})
}

Expand All @@ -116,7 +142,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
expr: &Expression,
) -> Result<String, String> {
let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = expr else {
return Err(format!("Expected lambda expression for {name}, got {expr}",));
panic!();
};
Ok(format!(
"fn {}({}) -> {} {{ {} }}\n",
Expand All @@ -131,31 +157,14 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
))
}

fn try_generate_builtin(&self, symbol: &str) -> Option<String> {
let code = match symbol {
"std::array::len" => "<T>(a: Vec<T>) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string(),
"std::check::panic" => "(s: &str) -> ! { panic!(\"{s}\"); }".to_string(),
"std::field::modulus" => {
let modulus = T::modulus();
format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")
}
"std::convert::fe" => "(n: ibig::IBig) -> FieldElement {\n <FieldElement as PrimeField>::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}"
.to_string(),
_ => return None,
};

Some(format!("fn {}{code}", escape_symbol(symbol)))
}

fn format_expr(&mut self, e: &Expression) -> Result<String, String> {
Ok(match e {
Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(),
Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => {
self.request_symbol(name)?;
let reference = self.request_symbol(name)?;
let ta = type_args.as_ref().unwrap();
format!(
"{}{}",
escape_symbol(name),
"{reference}{}",
(!ta.is_empty())
.then(|| format!("::<{}>", ta.iter().map(map_type).join(", ")))
.unwrap_or_default()
Expand Down Expand Up @@ -279,6 +288,19 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
fn format_statement(&mut self, s: &StatementInsideBlock<Expression>) -> Result<String, String> {
Err(format!("Implement {s}"))
}

/// Returns true if a reference to the symbol needs a deref operation or not.
fn symbol_needs_deref(&self, symbol: &str) -> bool {
if is_builtin(symbol) {
return false;
}
let (_, def) = self.analyzed.definitions.get(symbol).as_ref().unwrap();
if let Some(FunctionValueDefinition::Expression(typed_expr)) = def {
matches!(typed_expr.e, Expression::LambdaExpression(..))
} else {
false
}
}
}

pub fn escape_symbol(s: &str) -> String {
Expand All @@ -295,7 +317,11 @@ fn map_type(ty: &Type) -> String {
Type::Expr => "Expr".to_string(),
Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)),
Type::Tuple(_) => todo!(),
Type::Function(ft) => todo!("Type {ft}"),
Type::Function(ft) => format!(
"fn({}) -> {}",
ft.params.iter().map(map_type).join(", "),
map_type(&ft.value)
),
Type::TypeVar(tv) => tv.to_string(),
Type::NamedType(path, type_args) => {
if type_args.is_some() {
Expand All @@ -307,6 +333,28 @@ fn map_type(ty: &Type) -> String {
}
}

fn is_builtin(symbol: &str) -> bool {
matches!(
symbol,
"std::array::len" | "std::check::panic" | "std::field::modulus" | "std::convert::fe"
)
}

fn try_generate_builtin<T: FieldElement>(symbol: &str) -> Option<String> {
let code = match symbol {
"std::array::len" => Some("<T>(a: Vec<T>) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string()),
"std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()),
"std::field::modulus" => {
let modulus = T::modulus();
Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}"))
}
"std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n <FieldElement as PrimeField>::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}"
.to_string()),
_ => None,
}?;
Some(format!("fn {}{code}", escape_symbol(symbol)))
}

#[cfg(test)]
mod test {
use powdr_number::GoldilocksField;
Expand Down
23 changes: 23 additions & 0 deletions jit-compiler/tests/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,29 @@ fn sqrt() {
assert_eq!(f(0), 0);
}

#[test]
fn assigned_functions() {
let input = r#"
namespace std::array;
let len = 8;
namespace main;
let a: int -> int = |i| i + 1;
let b: int -> int = |i| i + 2;
let t: bool = "" == "";
let c = if t { a } else { b };
let d = |i| c(i);
"#;
let c = compile(input, "main::c");

assert_eq!(c(0), 1);
assert_eq!(c(1), 2);
assert_eq!(c(2), 3);
assert_eq!(c(3), 4);

let d = compile(input, "main::d");
assert_eq!(d(0), 1);
}

#[test]
fn simple_field() {
let f = compile(
Expand Down

0 comments on commit 1008a55

Please sign in to comment.