Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT: proper function types. #1866

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions asm-to-pil/src/vm_to_constrained.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,7 @@ impl<T: FieldElement> VMConverter<T> {
kind: FunctionKind::Query,
params: vec![Pattern::Variable(SourceRef::unknown(), "__i".to_string())],
body: Box::new(call_to_handle_query.into()),
param_types: vec![],
};

statements.push(PilStatement::Expression(
Expand Down
2 changes: 2 additions & 0 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,8 @@ pub struct LambdaExpression<E = Expression<NamespacedPolynomialReference>> {
pub kind: FunctionKind,
pub params: Vec<Pattern>,
pub body: Box<E>,
/// Type of the parameters, filled in during type inference.
pub param_types: Vec<Type>,
}

impl<Ref> From<LambdaExpression<Expression<Ref>>> for Expression<Ref> {
Expand Down
1 change: 1 addition & 0 deletions importer/src/path_canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ fn check_expression(
kind: _,
params,
body,
..
},
) => {
// Add the local variables, ignore collisions.
Expand Down
131 changes: 77 additions & 54 deletions jit-compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use powdr_ast::{
parsed::{
display::quote,
types::{ArrayType, FunctionType, Type, TypeScheme},
visitor::AllChildren,
ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression,
IndexAccess, LambdaExpression, MatchArm, MatchExpression, Number, Pattern,
StatementInsideBlock, UnaryOperation,
Expand Down Expand Up @@ -40,7 +41,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
/// On failure, returns an error string.
/// After a failure, `self` can still be used to request other symbols.
/// The code can later be retrieved via `generated_code`.
pub fn request_symbol(&mut self, name: &str) -> Result<String, String> {
pub fn request_symbol(&mut self, name: &str, type_args: &[Type]) -> Result<String, String> {
// For now, code generation is generic, only the reference uses the type args.
// If that changes at some point, we need to store the type args in the symbol map as well.
match self.symbols.get(name) {
Some(Err(e)) => return Err(e.clone()),
Some(_) => {}
Expand All @@ -58,7 +61,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}
}
}
Ok(self.symbol_reference(name))
Ok(self.symbol_reference(name, type_args))
}

/// Returns the concatenation of all successfully compiled symbols.
Expand Down Expand Up @@ -114,7 +117,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}}\n",
escape_symbol(symbol),
map_type(&ty),
self.format_expr(&value.e)?
self.format_expr(&value.e, 0)?
)
}
})
Expand All @@ -139,31 +142,23 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
_ => return Err(format!("Expected function type, got {ty}")),
};

let var_height = params.iter().map(|p| p.variables().count()).sum::<usize>();

Ok(format!(
"fn {}({}) -> {} {{ {} }}\n",
"fn {}(({}): ({})) -> {} {{ {} }}\n",
escape_symbol(name),
params
.iter()
.zip(param_types)
.map(|(p, t)| format!("{p}: {}", map_type(t)))
.format(", "),
params.iter().format(", "),
param_types.iter().map(map_type).format(", "),
map_type(return_type),
self.format_expr(body)?
self.format_expr(body, var_height)?
))
}

fn format_expr(&mut self, e: &Expression) -> Result<String, String> {
fn format_expr(&mut self, e: &Expression, var_height: usize) -> Result<String, String> {
Ok(match e {
Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(),
Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => {
let reference = self.request_symbol(name)?;
let ta = type_args.as_ref().unwrap();
format!(
"{reference}{}",
(!ta.is_empty())
.then(|| format!("::<{}>", ta.iter().map(map_type).join(", ")))
.unwrap_or_default()
)
self.request_symbol(name, type_args.as_ref().unwrap())?
}
Expression::Number(
_,
Expand Down Expand Up @@ -193,11 +188,11 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
},
) => {
format!(
"({})({})",
self.format_expr(function)?,
"({}).call(({}))",
self.format_expr(function, var_height)?,
arguments
.iter()
.map(|a| self.format_expr(a))
.map(|a| self.format_expr(a, var_height))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
// TODO these should all be refs -> turn all types to arc
Expand All @@ -207,8 +202,8 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
)
}
Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => {
let left = self.format_expr(left)?;
let right = self.format_expr(right)?;
let left = self.format_expr(left, var_height)?;
let right = self.format_expr(right, var_height)?;
match op {
BinaryOperator::ShiftLeft => {
format!("(({left}).clone() << usize::try_from(({right}).clone()).unwrap())")
Expand All @@ -220,20 +215,44 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}
}
Expression::UnaryOperation(_, UnaryOperation { op, expr }) => {
format!("({op} ({}).clone())", self.format_expr(expr)?)
format!("({op} ({}).clone())", self.format_expr(expr, var_height)?)
}
Expression::IndexAccess(_, IndexAccess { array, index }) => {
format!(
"{}[usize::try_from({}).unwrap()].clone()",
self.format_expr(array)?,
self.format_expr(index)?
self.format_expr(array, var_height)?,
self.format_expr(index, var_height)?
)
}
Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => {
Expression::LambdaExpression(
_,
LambdaExpression {
params,
body,
param_types,
..
},
) => {
// Number of new variables introduced in the parameters.
let new_vars = params.iter().map(|p| p.variables().count()).sum::<usize>();
// We create clones of the captured variables so that we can move them into the closure.
let captured_vars = body
.all_children()
.filter_map(|e| {
if let Expression::Reference(_, Reference::LocalVar(id, name)) = e {
(*id < var_height as u64).then_some(name)
} else {
None
}
})
.unique()
.map(|v| format!("let {v} = {v}.clone();"))
.format("\n");
format!(
"|{}| {{ {} }}",
"Callable::Closure(std::sync::Arc::new({{\n{captured_vars}\nmove |({}): ({})| {{ {} }}\n}}))",
params.iter().format(", "),
self.format_expr(body)?
param_types.iter().map(map_type).format(", "),
self.format_expr(body, var_height + new_vars)?
)
}
Expression::IfExpression(
Expand All @@ -246,17 +265,17 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
) => {
format!(
"if {} {{ {} }} else {{ {} }}",
self.format_expr(condition)?,
self.format_expr(body)?,
self.format_expr(else_body)?
self.format_expr(condition, var_height)?,
self.format_expr(body, var_height)?,
self.format_expr(else_body, var_height)?
)
}
Expression::ArrayLiteral(_, ArrayLiteral { items }) => {
format!(
"vec![{}]",
items
.iter()
.map(|i| self.format_expr(i))
.map(|i| self.format_expr(i, var_height))
.collect::<Result<Vec<_>, _>>()?
.join(", ")
)
Expand All @@ -266,7 +285,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
"({})",
items
.iter()
.map(|i| Ok(format!("({}.clone())", self.format_expr(i)?)))
.map(|i| Ok(format!("({}.clone())", self.format_expr(i, var_height)?)))
.collect::<Result<Vec<_>, String>>()?
.join(", ")
),
Expand All @@ -279,7 +298,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
.collect::<Result<Vec<_>, _>>()?
.join("\n"),
expr.as_ref()
.map(|e| self.format_expr(e.as_ref()))
.map(|e| self.format_expr(e.as_ref(), var_height))
.transpose()?
.unwrap_or_default()
)
Expand All @@ -293,13 +312,14 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
let var_name = "scrutinee__";
format!(
"{{\nlet {var_name} = ({}).clone();\n{}\n}}\n",
self.format_expr(scrutinee)?,
self.format_expr(scrutinee, var_height)?,
arms.iter()
.map(|MatchArm { pattern, value }| {
let new_vars = pattern.variables().count();
let (bound_vars, arm_test) = check_pattern(var_name, pattern)?;
Ok(format!(
"if let Some({bound_vars}) = ({arm_test}) {{\n{}\n}}",
self.format_expr(value)?,
self.format_expr(value, var_height + new_vars)?,
))
})
.chain(std::iter::once(Ok("{ panic!(\"No match\"); }".to_string())))
Expand All @@ -320,21 +340,24 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
/// Returns a string expression evaluating to the value of the symbol.
/// This is either the escaped name of the symbol or a deref operator
/// applied to it.
fn symbol_reference(&self, symbol: &str) -> String {
let needs_deref = if is_builtin::<T>(symbol) {
false
fn symbol_reference(&self, symbol: &str, type_args: &[Type]) -> String {
let type_args = if type_args.is_empty() {
"".to_string()
} else {
let (_, def) = self.analyzed.definitions.get(symbol).as_ref().unwrap();
if let Some(FunctionValueDefinition::Expression(typed_expr)) = def {
!matches!(typed_expr.e, Expression::LambdaExpression(..))
format!("::<{}>", type_args.iter().map(map_type).join(", "))
};
if is_builtin::<T>(symbol) {
return format!("Callable::Fn({}{type_args})", escape_symbol(symbol));
}
let (_, def) = self.analyzed.definitions.get(symbol).as_ref().unwrap();
if let Some(FunctionValueDefinition::Expression(typed_expr)) = def {
if matches!(typed_expr.e, Expression::LambdaExpression(..)) {
format!("Callable::Fn({}{type_args})", escape_symbol(symbol))
} else {
false
format!("(*{}{type_args})", escape_symbol(symbol))
}
};
if needs_deref {
format!("(*{})", escape_symbol(symbol))
} else {
escape_symbol(symbol)
format!("(*{}{type_args})", escape_symbol(symbol))
}
}
}
Expand Down Expand Up @@ -458,7 +481,7 @@ fn map_type(ty: &Type) -> String {
Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)),
Type::Tuple(_) => todo!(),
Type::Function(ft) => format!(
"fn({}) -> {}",
"Callable<({}), {}>",
ft.params.iter().map(map_type).join(", "),
map_type(&ft.value)
),
Expand Down Expand Up @@ -533,7 +556,7 @@ mod test {
let analyzed = analyze_string::<GoldilocksField>(input).unwrap();
let mut compiler = CodeGenerator::new(&analyzed);
for s in syms {
compiler.request_symbol(s).unwrap();
compiler.request_symbol(s, &[]).unwrap();
}
compiler.generated_code()
}
Expand All @@ -547,7 +570,7 @@ mod test {
#[test]
fn simple_fun() {
let result = compile("let c: int -> int = |i| i;", &["c"]);
assert_eq!(result, "fn c(i: ibig::IBig) -> ibig::IBig { i }\n");
assert_eq!(result, "fn c((i): (ibig::IBig)) -> ibig::IBig { i }\n");
}

#[test]
Expand All @@ -558,9 +581,9 @@ mod test {
);
assert_eq!(
result,
"fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) }\n\
"fn c((i): (ibig::IBig)) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) }\n\
\n\
fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) }\n\
fn d((k): (ibig::IBig)) -> ibig::IBig { (Callable::Fn(c)).call((((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone())) }\n\
"
);
}
Expand Down
27 changes: 27 additions & 0 deletions jit-compiler/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub fn generate_glue_code<T: FieldElement>(
}

// TODO we should use big int instead of u64
// TODO we actually don't know how we have to access this symbol. It might be a closure,
// which needs to be accessed differently.
let name = escape_symbol(sym);
glue.push_str(&format!(
r#"
Expand Down Expand Up @@ -86,6 +88,22 @@ impl From<FieldElement> for ibig::IBig {
ibig::IBig::from(x.0)
}
}

#[derive(Clone)]
enum Callable<Args, Ret> {
Fn(fn(Args) -> Ret),
Closure(std::sync::Arc<dyn Fn(Args) -> Ret + Send + Sync>),
}
impl<Args, Ret> Callable<Args, Ret> {
#[inline(always)]
fn call(&self, args: Args) -> Ret {
match self {
Callable::Fn(f) => f(args),
Callable::Closure(f) => f(args),
}
}
}

"#;

const CARGO_TOML: &str = r#"
Expand Down Expand Up @@ -117,6 +135,15 @@ pub fn call_cargo(code: &str) -> Result<PathInTempDir, String> {
fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap();
fs::create_dir(dir.join("src")).unwrap();
fs::write(dir.join("src").join("lib.rs"), code).unwrap();
Command::new("cargo")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To remove later.

.arg("fmt")
.current_dir(dir.clone())
.output()
.unwrap();
println!(
"{}",
fs::read_to_string(dir.join("src").join("lib.rs")).unwrap()
);
let out = Command::new("cargo")
.env("RUSTFLAGS", "-C target-cpu=native")
.arg("build")
Expand Down
2 changes: 1 addition & 1 deletion jit-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub fn compile<T: FieldElement>(
let successful_symbols = requested_symbols
.iter()
.filter_map(|&sym| {
if let Err(e) = codegen.request_symbol(sym) {
if let Err(e) = codegen.request_symbol(sym, &[]) {
log::warn!("Unable to generate code for symbol {sym}: {e}");
None
} else {
Expand Down
Loading