From e4c35fe554cf4a5ef0bf4e0df96f6c49e5e59fd5 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 29 Feb 2024 15:41:35 +0100 Subject: [PATCH 01/57] Experimental pil to rust compiler. --- executor/src/constant_evaluator/compiler.rs | 251 ++++++++++++++++++++ executor/src/constant_evaluator/mod.rs | 14 +- 2 files changed, 263 insertions(+), 2 deletions(-) create mode 100644 executor/src/constant_evaluator/compiler.rs diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs new file mode 100644 index 000000000..41a5b7d11 --- /dev/null +++ b/executor/src/constant_evaluator/compiler.rs @@ -0,0 +1,251 @@ +use std::{collections::HashMap, io::Write}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, + Reference, SymbolKind, + }, + parsed::{ + types::{ArrayType, FunctionType, Type, TypeScheme}, + ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, + LambdaExpression, Number, UnaryOperation, + }, +}; +use powdr_number::FieldElement; + +use super::VariablySizedColumn; + +const PREAMBLE: &str = r#" +#![allow(unused_parens)] +use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use std::io::{BufWriter, Write}; +use std::fs::File; + +#[derive(MontConfig)] +#[modulus = "18446744069414584321"] +#[generator = "7"] +pub struct GoldilocksBaseFieldConfig; +pub type FieldElement = Fp64>; +"#; + +pub fn generate_fixed_cols( + analyzed: &Analyzed, +) -> HashMap)> { + let definitions = process_definitions(analyzed); + let degree = analyzed.degree(); + // TODO also eval other cols + let main_func = format!( + " +fn main() {{ + let data = (0..{degree}) + .into_par_iter() + .map(|i| {{ + main_inv(num_bigint::BigInt::from(i)) + }}) + .collect::>(); + let mut writer = BufWriter::new(File::create(\"./constants.bin\").unwrap()); + for i in 0..{degree} {{ + writer + .write_all(&BigInt::from(data[i]).to_bytes_le()) + .unwrap(); + }} +}} +" + ); + let result = format!("{PREAMBLE}\n{definitions}\n{main_func}\n"); + // write result to a temp file + let mut file = std::fs::File::create("/tmp/te/src/main.rs").unwrap(); + file.write_all(result.as_bytes()).unwrap(); + Default::default() +} + +pub fn process_definitions(analyzed: &Analyzed) -> String { + let mut result = String::new(); + for (name, (sym, value)) in &analyzed.definitions { + if name == "std::check::panic" { + result.push_str("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }"); + } else if name == "std::field::modulus" { + result.push_str("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }"); + } else if name == "std::convert::fe" { + result.push_str("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}"); + } else if let Some(FunctionValueDefinition::Expression(value)) = value { + println!("Processing {name} = {}", value.e); + let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { + TypeScheme { + vars: Default::default(), + ty: Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + }), + } + } else { + value.type_scheme.clone().unwrap() + }; + match &type_scheme { + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + } => { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = + &value.e + else { + todo!("value of fun: {}", value.e) + }; + result.push_str(&format!( + "fn {}<{}>({}) -> {} {{ {} }}\n", + escape(name), + vars, + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(t))) + .format(", "), + map_type(return_type), + format_expr(body) + )); + } + _ => { + result.push_str(&format!( + "const {}: {} = {};\n", + escape(name), + map_type(&value.type_scheme.as_ref().unwrap().ty), + format_expr(&value.e) + )); + } + } + } + } + + result +} + +fn format_expr(e: &Expression) -> String { + match e { + Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), + Expression::Reference( + _, + Reference::Poly(PolynomialReference { + name, + poly_id: _, + type_args: _, + }), + ) => escape(name), // TOOD use type args if needed. + Expression::Number( + _, + Number { + value, + type_: Some(type_), + }, + ) => match type_ { + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + Type::TypeVar(t) => format!("{t}::from({value}_u64)"), + _ => unreachable!(), + }, + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) => { + format!( + "({})({})", + format_expr(function), + arguments + .iter() + .map(format_expr) + .map(|x| format!("{x}.clone()")) + .collect::>() + .join(", ") + ) + } + Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { + let left = format_expr(left); + let right = format_expr(right); + match op { + BinaryOperator::ShiftLeft => { + format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + } + _ => format!("(({left}).clone() {op} ({right}).clone())"), + } + } + Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { + format!("({op} ({}).clone())", format_expr(expr)) + } + Expression::IndexAccess(_, IndexAccess { array, index }) => { + format!( + "{}[usize::try_from({}).unwrap()].clone()", + format_expr(array), + format_expr(index) + ) + } + Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { + // let params = if *params == vec!["r".to_string()] { + // // Hack because rust needs the type + // vec!["r: Vec".to_string()] + // } else { + // params.clone() + // }; + format!( + "|{}| {{ {} }}", + params.iter().format(", "), + format_expr(body) + ) + } + Expression::IfExpression( + _, + IfExpression { + condition, + body, + else_body, + }, + ) => { + format!( + "if {} {{ {} }} else {{ {} }}", + format_expr(condition), + format_expr(body), + format_expr(else_body) + ) + } + Expression::ArrayLiteral(_, ArrayLiteral { items }) => { + format!( + "vec![{}]", + items.iter().map(format_expr).collect::>().join(", ") + ) + } + Expression::String(_, s) => format!("{s:?}"), // TODO does this quote properly? + Expression::Tuple(_, items) => format!( + "({})", + items.iter().map(format_expr).collect::>().join(", ") + ), + _ => panic!("Implement {e}"), + } +} + +fn escape(s: &str) -> String { + s.replace('.', "_").replace("::", "_") +} + +fn map_type(ty: &Type) -> String { + match ty { + Type::Bottom | Type::Bool => format!("{ty}"), + Type::Int => "num_bigint::BigInt".to_string(), + Type::Fe => "FieldElement".to_string(), + Type::String => "String".to_string(), + Type::Col => unreachable!(), + Type::Expr => "Expr".to_string(), + Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), + Type::Tuple(_) => todo!(), + Type::Function(ft) => todo!("Type {ft}"), + Type::TypeVar(tv) => tv.to_string(), + Type::NamedType(_path, _type_args) => todo!(), + } +} diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 40256195a..f1bb9a2c9 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -6,7 +6,7 @@ use std::{ pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; use itertools::Itertools; use powdr_ast::{ - analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression}, + analyzed::{Analyzed, Expression, FunctionValueDefinition, PolyID, Symbol, TypedExpression}, parsed::{ types::{ArrayType, Type}, IndexAccess, @@ -16,6 +16,9 @@ use powdr_number::{BigInt, BigUint, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +// TODO this is probabyl not the right place. +mod compiler; + mod data_structures; /// Generates the fixed column values for all fixed columns that are defined @@ -24,12 +27,19 @@ mod data_structures; /// Arrays of columns are flattened, the name of the `i`th array element /// is `name[i]`. pub fn generate(analyzed: &Analyzed) -> Vec<(String, VariablySizedColumn)> { - let mut fixed_cols = HashMap::new(); + // TODO to do this properly, we should try to compile as much as possible + // and only evaulato if it fails. Still, compilation should be done in one run. + + let mut fixed_cols: HashMap)> = + compiler::generate_fixed_cols(analyzed); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { // For arrays, generate values for each index, // for non-arrays, set index to None. for (index, (name, id)) in poly.array_elements().enumerate() { + if fixed_cols.contains_key(&name) { + continue; + } let index = poly.is_array().then_some(index as u64); let range = poly.degree.unwrap(); let values = range From 686bdc6bfaf6ed19376a374f0d53ee9ce13f6c61 Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 12:41:00 +0000 Subject: [PATCH 02/57] Compile and dlopen. --- executor/Cargo.toml | 2 + executor/src/constant_evaluator/compiler.rs | 475 ++++++++++++-------- 2 files changed, 292 insertions(+), 185 deletions(-) diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 0079cce9a..330309791 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -14,7 +14,9 @@ powdr-parser-util.workspace = true powdr-pil-analyzer.workspace = true itertools = "0.13" +libc = "0.2.0" log = { version = "0.4.17" } +mktemp = "0.5.0" rayon = "1.7.0" bit-vec = "0.6.3" num-traits = "0.2.15" diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 41a5b7d11..f5ee7ca46 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -1,4 +1,12 @@ -use std::{collections::HashMap, io::Write}; +use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + fs::File, + io::Write, + process::Command, + sync::Arc, +}; use itertools::Itertools; use powdr_ast::{ @@ -7,6 +15,7 @@ use powdr_ast::{ Reference, SymbolKind, }, parsed::{ + display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, LambdaExpression, Number, UnaryOperation, @@ -19,9 +28,6 @@ use super::VariablySizedColumn; const PREAMBLE: &str = r#" #![allow(unused_parens)] use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; -use rayon::prelude::{IntoParallelIterator, ParallelIterator}; -use std::io::{BufWriter, Write}; -use std::fs::File; #[derive(MontConfig)] #[modulus = "18446744069414584321"] @@ -30,203 +36,302 @@ pub struct GoldilocksBaseFieldConfig; pub type FieldElement = Fp64>; "#; +// TODO this is the old impl of goldilocks + +const CARGO_TOML: &str = r#" +[package] +name = "powdr_constants" +version = "0.1.0" +edition = "2021" + +[dependencies] +ark-ff = "0.4.2" +"#; + +// TODO crate type dylib? + pub fn generate_fixed_cols( analyzed: &Analyzed, ) -> HashMap)> { - let definitions = process_definitions(analyzed); - let degree = analyzed.degree(); - // TODO also eval other cols - let main_func = format!( - " -fn main() {{ - let data = (0..{degree}) - .into_par_iter() - .map(|i| {{ - main_inv(num_bigint::BigInt::from(i)) - }}) - .collect::>(); - let mut writer = BufWriter::new(File::create(\"./constants.bin\").unwrap()); - for i in 0..{degree} {{ - writer - .write_all(&BigInt::from(data[i]).to_bytes_le()) - .unwrap(); - }} -}} -" - ); - let result = format!("{PREAMBLE}\n{definitions}\n{main_func}\n"); - // write result to a temp file - let mut file = std::fs::File::create("/tmp/te/src/main.rs").unwrap(); - file.write_all(result.as_bytes()).unwrap(); - Default::default() -} + let mut compiler = Compiler::new(analyzed); + for (sym, _) in &analyzed.constant_polys_in_source_order() { + compiler.request_symbol(&sym.absolute_name); + } + let code = format!("{PREAMBLE}\n{}\n", compiler.compiled_symbols()); -pub fn process_definitions(analyzed: &Analyzed) -> String { - let mut result = String::new(); - for (name, (sym, value)) in &analyzed.definitions { - if name == "std::check::panic" { - result.push_str("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }"); - } else if name == "std::field::modulus" { - result.push_str("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }"); - } else if name == "std::convert::fe" { - result.push_str("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}"); - } else if let Some(FunctionValueDefinition::Expression(value)) = value { - println!("Processing {name} = {}", value.e); - let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { - TypeScheme { - vars: Default::default(), - ty: Type::Function(FunctionType { - params: vec![Type::Int], - value: Box::new(Type::Fe), - }), - } - } else { - value.type_scheme.clone().unwrap() - }; - match &type_scheme { - TypeScheme { - vars, - ty: - Type::Function(FunctionType { - params: param_types, - value: return_type, - }), - } => { - let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = - &value.e - else { - todo!("value of fun: {}", value.e) - }; - result.push_str(&format!( - "fn {}<{}>({}) -> {} {{ {} }}\n", - escape(name), - vars, - params - .iter() - .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(t))) - .format(", "), - map_type(return_type), - format_expr(body) - )); - } - _ => { - result.push_str(&format!( - "const {}: {} = {};\n", - escape(name), - map_type(&value.type_scheme.as_ref().unwrap().ty), - format_expr(&value.e) - )); - } + let dir = mktemp::Temp::new_dir().unwrap(); + std::fs::create_dir(dir.as_path().join("src")).unwrap(); + std::fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); + Command::new("cargo") + .arg("build") + .arg("--release") + .current_dir(dir.as_path()) + .output() + .unwrap(); + + unsafe { + let lib_path = CString::new( + dir.as_path() + .join("target") + .join("release") + .join("libpowdr_constants.so") + .to_str() + .unwrap(), + ) + .unwrap(); + let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); + if lib.is_null() { + panic!("Failed to load library: {:?}", lib_path); + } + for (sym, poly_id) in analyzed.constant_polys_in_source_order() { + let sym = escape(&sym.absolute_name); + let sym = CString::new(sym).unwrap(); + let sym = dlsym(lib, sym.as_ptr()); + if sym.is_null() { + println!("Failed to load symbol: {:?}", sym); + continue; } + println!("Loaded symbol: {:?}", sym); + // let sym = sym as *const VariablySizedColumn; + // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); } } + todo!() +} - result +struct Compiler<'a, T> { + analyzed: &'a Analyzed, + queue: Vec, + requested: HashSet, + failed: HashMap, + symbols: HashMap, } -fn format_expr(e: &Expression) -> String { - match e { - Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), - Expression::Reference( - _, - Reference::Poly(PolynomialReference { - name, - poly_id: _, - type_args: _, - }), - ) => escape(name), // TOOD use type args if needed. - Expression::Number( - _, - Number { - value, - type_: Some(type_), - }, - ) => match type_ { - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - Type::TypeVar(t) => format!("{t}::from({value}_u64)"), - _ => unreachable!(), - }, - Expression::FunctionCall( - _, - FunctionCall { - function, - arguments, - }, - ) => { - format!( - "({})({})", - format_expr(function), - arguments - .iter() - .map(format_expr) - .map(|x| format!("{x}.clone()")) - .collect::>() - .join(", ") - ) +impl<'a, T> Compiler<'a, T> { + pub fn new(analyzed: &'a Analyzed) -> Self { + Self { + analyzed, + queue: Default::default(), + requested: Default::default(), + failed: Default::default(), + symbols: Default::default(), } - Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { - let left = format_expr(left); - let right = format_expr(right); - match op { - BinaryOperator::ShiftLeft => { - format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") - } - _ => format!("(({left}).clone() {op} ({right}).clone())"), - } + } + + pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { + if let Some(err) = self.failed.get(name) { + return Err(err.clone()); } - Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { - format!("({op} ({}).clone())", format_expr(expr)) + if self.requested.contains(name) { + return Ok(()); } - Expression::IndexAccess(_, IndexAccess { array, index }) => { - format!( - "{}[usize::try_from({}).unwrap()].clone()", - format_expr(array), - format_expr(index) - ) + self.requested.insert(name.to_string()); + match self.generate_code(name) { + Ok(code) => { + self.symbols.insert(name.to_string(), code); + Ok(()) + } + Err(err) => { + let err = format!("Failed to compile {name}: {err}"); + self.failed.insert(name.to_string(), err.clone()); + Err(err) + } } - Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { - // let params = if *params == vec!["r".to_string()] { - // // Hack because rust needs the type - // vec!["r: Vec".to_string()] - // } else { - // params.clone() - // }; - format!( - "|{}| {{ {} }}", - params.iter().format(", "), - format_expr(body) - ) + } + + pub fn compiled_symbols(self) -> String { + self.symbols + .into_iter() + .map(|(name, code)| code) + .format("\n\n") + .to_string() + } + + fn generate_code(&mut self, symbol: &str) -> Result { + if symbol == "std::check::panic" { + // TODO should this really panic? + return Ok("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }".to_string()); + } else if symbol == "std::field::modulus" { + // TODO depends on T + return Ok("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }" + .to_string()); + } else if symbol == "std::convert::fe" { + return Ok("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + .to_string()); } - Expression::IfExpression( - _, - IfExpression { - condition, - body, - else_body, + + let Some((sym, Some(FunctionValueDefinition::Expression(value)))) = + self.analyzed.definitions.get(symbol) + else { + return Err(format!( + "No definition for {symbol}, or not a generic symbol" + )); + }; + println!("Processing {symbol} = {}", value.e); + Ok(match &value.type_scheme.as_ref().unwrap() { + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + } => { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = + &value.e + else { + return Err(format!( + "Expected lambda expression for {symbol}, got {}", + value.e + )); + }; + format!( + "fn {}<{}>({}) -> {} {{ {} }}\n", + escape(symbol), + vars, + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(t))) + .format(", "), + map_type(return_type), + self.format_expr(body)? + ) + } + _ => format!( + "const {}: {} = {};\n", + escape(symbol), + map_type(&value.type_scheme.as_ref().unwrap().ty), + self.format_expr(&value.e)? + ), + }) + } + + fn format_expr(&mut self, e: &Expression) -> Result { + Ok(match e { + Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), + Expression::Reference( + _, + Reference::Poly(PolynomialReference { + name, + poly_id: _, + type_args, + }), + ) => { + self.request_symbol(name)?; + format!( + "{}{}", + escape(name), + // TODO do all type args work here? + type_args + .as_ref() + .map(|ta| format!("::{}", format_type_args(&ta))) + .unwrap_or_default() + ) + } + Expression::Number( + _, + Number { + value, + type_: Some(type_), + }, + ) => match type_ { + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + Type::TypeVar(t) => format!("{t}::from({value}_u64)"), + _ => unreachable!(), }, - ) => { - format!( - "if {} {{ {} }} else {{ {} }}", - format_expr(condition), - format_expr(body), - format_expr(else_body) - ) - } - Expression::ArrayLiteral(_, ArrayLiteral { items }) => { - format!( - "vec![{}]", - items.iter().map(format_expr).collect::>().join(", ") - ) - } - Expression::String(_, s) => format!("{s:?}"), // TODO does this quote properly? - Expression::Tuple(_, items) => format!( - "({})", - items.iter().map(format_expr).collect::>().join(", ") - ), - _ => panic!("Implement {e}"), + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) => { + format!( + "({})({})", + self.format_expr(function)?, + arguments + .iter() + .map(|a| self.format_expr(a)) + .collect::, _>>()? + .into_iter() + // TODO these should all be refs + .map(|x| format!("{x}.clone()")) + .collect::>() + .join(", ") + ) + } + Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { + let left = self.format_expr(left)?; + let right = self.format_expr(right)?; + match op { + BinaryOperator::ShiftLeft => { + format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + } + _ => format!("(({left}).clone() {op} ({right}).clone())"), + } + } + Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { + format!("({op} ({}).clone())", self.format_expr(expr)?) + } + Expression::IndexAccess(_, IndexAccess { array, index }) => { + format!( + "{}[usize::try_from({}).unwrap()].clone()", + self.format_expr(array)?, + self.format_expr(index)? + ) + } + Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { + // let params = if *params == vec!["r".to_string()] { + // // Hack because rust needs the type + // vec!["r: Vec".to_string()] + // } else { + // params.clone() + // }; + format!( + "|{}| {{ {} }}", + params.iter().format(", "), + self.format_expr(body)? + ) + } + Expression::IfExpression( + _, + IfExpression { + condition, + body, + else_body, + }, + ) => { + format!( + "if {} {{ {} }} else {{ {} }}", + self.format_expr(condition)?, + self.format_expr(body)?, + self.format_expr(else_body)? + ) + } + Expression::ArrayLiteral(_, ArrayLiteral { items }) => { + format!( + "vec![{}]", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ) + } + Expression::String(_, s) => quote(s), + Expression::Tuple(_, items) => format!( + "({})", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ), + _ => return Err(format!("Implement {e}")), + }) } } From b2f8d8845bc42e9ef7d5ba6b31744e0d94dfd75c Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:01:51 +0000 Subject: [PATCH 03/57] oeu --- executor/src/constant_evaluator/compiler.rs | 68 +++++++++++++++++---- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index f5ee7ca46..19369983b 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -2,7 +2,7 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; use std::{ collections::{HashMap, HashSet}, ffi::CString, - fs::File, + fs::{self, File}, io::Write, process::Command, sync::Arc, @@ -17,8 +17,8 @@ use powdr_ast::{ parsed::{ display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, - ArrayLiteral, BinaryOperation, BinaryOperator, FunctionCall, IfExpression, IndexAccess, - LambdaExpression, Number, UnaryOperation, + ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, + IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; use powdr_number::FieldElement; @@ -44,8 +44,13 @@ name = "powdr_constants" version = "0.1.0" edition = "2021" +[lib] +crate-type = ["dylib"] + [dependencies] ark-ff = "0.4.2" +num-bigint = { version = "0.4.3", features = ["serde"] } +num-traits = "0.2.15" "#; // TODO crate type dylib? @@ -55,19 +60,28 @@ pub fn generate_fixed_cols( ) -> HashMap)> { let mut compiler = Compiler::new(analyzed); for (sym, _) in &analyzed.constant_polys_in_source_order() { - compiler.request_symbol(&sym.absolute_name); + // ignore err + if let Err(e) = compiler.request_symbol(&sym.absolute_name) { + println!("Failed to compile {}: {e}", &sym.absolute_name); + } } let code = format!("{PREAMBLE}\n{}\n", compiler.compiled_symbols()); + println!("Compiled code:\n{code}"); let dir = mktemp::Temp::new_dir().unwrap(); - std::fs::create_dir(dir.as_path().join("src")).unwrap(); - std::fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); - Command::new("cargo") + fs::write(dir.as_path().join("Cargo.toml"), CARGO_TOML).unwrap(); + fs::create_dir(dir.as_path().join("src")).unwrap(); + fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); + let out = Command::new("cargo") .arg("build") .arg("--release") .current_dir(dir.as_path()) .output() .unwrap(); + out.stderr.iter().for_each(|b| print!("{}", *b as char)); + if !out.status.success() { + panic!("Failed to compile."); + } unsafe { let lib_path = CString::new( @@ -96,12 +110,11 @@ pub fn generate_fixed_cols( // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); } } - todo!() + Default::default() } struct Compiler<'a, T> { analyzed: &'a Analyzed, - queue: Vec, requested: HashSet, failed: HashMap, symbols: HashMap, @@ -111,7 +124,6 @@ impl<'a, T> Compiler<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, - queue: Default::default(), requested: Default::default(), failed: Default::default(), symbols: Default::default(), @@ -129,6 +141,7 @@ impl<'a, T> Compiler<'a, T> { match self.generate_code(name) { Ok(code) => { self.symbols.insert(name.to_string(), code); + println!("Generated code for {name}"); Ok(()) } Err(err) => { @@ -168,7 +181,18 @@ impl<'a, T> Compiler<'a, T> { )); }; println!("Processing {symbol} = {}", value.e); - Ok(match &value.type_scheme.as_ref().unwrap() { + let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { + TypeScheme { + vars: Default::default(), + ty: Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + }), + } + } else { + value.type_scheme.clone().unwrap() + }; + Ok(match type_scheme { TypeScheme { vars, ty: @@ -192,9 +216,9 @@ impl<'a, T> Compiler<'a, T> { params .iter() .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(t))) + .map(|(p, t)| format!("{}: {}", p, map_type(&t))) .format(", "), - map_type(return_type), + map_type(return_type.as_ref()), self.format_expr(body)? ) } @@ -330,9 +354,27 @@ impl<'a, T> Compiler<'a, T> { .collect::, _>>()? .join(", ") ), + Expression::BlockExpression(_, BlockExpression { statements, expr }) => { + format!( + "{{\n{}\n{}\n}}", + statements + .iter() + .map(|s| self.format_statement(s)) + .collect::, _>>()? + .join("\n"), + expr.as_ref() + .map(|e| self.format_expr(e.as_ref())) + .transpose()? + .unwrap_or_default() + ) + } _ => return Err(format!("Implement {e}")), }) } + + fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { + Err(format!("Implement {s}")) + } } fn escape(s: &str) -> String { From 392afc1e496b3afc287d0e590e2ef09f6837b219 Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:02:02 +0000 Subject: [PATCH 04/57] oeu --- executor/src/constant_evaluator/compiler.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 19369983b..3b452fadb 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -110,6 +110,7 @@ pub fn generate_fixed_cols( // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); } } + panic!(); Default::default() } From 6b2029a9eca606e6b6c0bbca0c25d050fa36074f Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:25:29 +0000 Subject: [PATCH 05/57] loaded sym --- executor/src/constant_evaluator/compiler.rs | 50 ++++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 3b452fadb..87ec77c2d 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -2,8 +2,9 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; use std::{ collections::{HashMap, HashSet}, ffi::CString, - fs::{self, File}, + fs::{self, create_dir, File}, io::Write, + path, process::Command, sync::Arc, }; @@ -32,8 +33,8 @@ use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; #[derive(MontConfig)] #[modulus = "18446744069414584321"] #[generator = "7"] -pub struct GoldilocksBaseFieldConfig; -pub type FieldElement = Fp64>; +struct GoldilocksBaseFieldConfig; +type FieldElement = Fp64>; "#; // TODO this is the old impl of goldilocks @@ -59,23 +60,45 @@ pub fn generate_fixed_cols( analyzed: &Analyzed, ) -> HashMap)> { let mut compiler = Compiler::new(analyzed); + let mut glue = String::new(); for (sym, _) in &analyzed.constant_polys_in_source_order() { // ignore err if let Err(e) = compiler.request_symbol(&sym.absolute_name) { println!("Failed to compile {}: {e}", &sym.absolute_name); } } - let code = format!("{PREAMBLE}\n{}\n", compiler.compiled_symbols()); + for (sym, _) in &analyzed.constant_polys_in_source_order() { + // TODO escape? + if compiler.is_compiled(&sym.absolute_name) { + // TODO it is a rust function, can we use a more complex type as well? + // TODO only works for goldilocks + glue.push_str(&format!( + r#" + #[no_mangle] + pub extern fn extern_{}(i: u64) -> u64 {{ + {}(num_bigint::BigInt::from(i)).into_bigint().0[0] + }} + "#, + escape(&sym.absolute_name), + escape(&sym.absolute_name), + )); + } + } + + let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); println!("Compiled code:\n{code}"); - let dir = mktemp::Temp::new_dir().unwrap(); - fs::write(dir.as_path().join("Cargo.toml"), CARGO_TOML).unwrap(); - fs::create_dir(dir.as_path().join("src")).unwrap(); - fs::write(dir.as_path().join("src").join("lib.rs"), code).unwrap(); + //let dir = mktemp::Temp::new_dir().unwrap(); + let _ = fs::remove_dir_all("/tmp/powdr_constants"); + fs::create_dir("/tmp/powdr_constants").unwrap(); + let dir = path::Path::new("/tmp/powdr_constants"); + fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); + fs::create_dir(dir.join("src")).unwrap(); + fs::write(dir.join("src").join("lib.rs"), code).unwrap(); let out = Command::new("cargo") .arg("build") .arg("--release") - .current_dir(dir.as_path()) + .current_dir(dir) .output() .unwrap(); out.stderr.iter().for_each(|b| print!("{}", *b as char)); @@ -85,8 +108,7 @@ pub fn generate_fixed_cols( unsafe { let lib_path = CString::new( - dir.as_path() - .join("target") + dir.join("target") .join("release") .join("libpowdr_constants.so") .to_str() @@ -98,7 +120,7 @@ pub fn generate_fixed_cols( panic!("Failed to load library: {:?}", lib_path); } for (sym, poly_id) in analyzed.constant_polys_in_source_order() { - let sym = escape(&sym.absolute_name); + let sym = format!("extern_{}", escape(&sym.absolute_name)); let sym = CString::new(sym).unwrap(); let sym = dlsym(lib, sym.as_ptr()); if sym.is_null() { @@ -153,6 +175,10 @@ impl<'a, T> Compiler<'a, T> { } } + pub fn is_compiled(&self, name: &str) -> bool { + self.symbols.contains_key(name) + } + pub fn compiled_symbols(self) -> String { self.symbols .into_iter() From 7ee9afc1824d604c560dd507996b73db0265a94f Mon Sep 17 00:00:00 2001 From: chriseth Date: Sat, 3 Aug 2024 13:34:35 +0000 Subject: [PATCH 06/57] log time --- executor/src/constant_evaluator/compiler.rs | 39 +++++++++++++++++---- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/executor/src/constant_evaluator/compiler.rs b/executor/src/constant_evaluator/compiler.rs index 87ec77c2d..d50977922 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/executor/src/constant_evaluator/compiler.rs @@ -1,4 +1,5 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::{ collections::{HashMap, HashSet}, ffi::CString, @@ -7,6 +8,7 @@ use std::{ path, process::Command, sync::Arc, + time::Instant, }; use itertools::Itertools; @@ -24,6 +26,8 @@ use powdr_ast::{ }; use powdr_number::FieldElement; +use crate::constant_evaluator::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; + use super::VariablySizedColumn; const PREAMBLE: &str = r#" @@ -106,6 +110,7 @@ pub fn generate_fixed_cols( panic!("Failed to compile."); } + let mut columns = HashMap::new(); unsafe { let lib_path = CString::new( dir.join("target") @@ -119,8 +124,9 @@ pub fn generate_fixed_cols( if lib.is_null() { panic!("Failed to load library: {:?}", lib_path); } - for (sym, poly_id) in analyzed.constant_polys_in_source_order() { - let sym = format!("extern_{}", escape(&sym.absolute_name)); + let start = Instant::now(); + for (poly, value) in analyzed.constant_polys_in_source_order() { + let sym = format!("extern_{}", escape(&poly.absolute_name)); let sym = CString::new(sym).unwrap(); let sym = dlsym(lib, sym.as_ptr()); if sym.is_null() { @@ -128,12 +134,33 @@ pub fn generate_fixed_cols( continue; } println!("Loaded symbol: {:?}", sym); - // let sym = sym as *const VariablySizedColumn; - // cols.insert(sym.absolute_name.clone(), (poly_id, (*sym).clone())); + let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); + let degrees = if let Some(degree) = poly.degree { + vec![degree] + } else { + (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) + .map(|degree_log| 1 << degree_log) + .collect::>() + }; + + let col_values = degrees + .into_iter() + .map(|degree| { + (0..degree) + .into_par_iter() + .map(|i| T::from(fun(i as u64))) + .collect::>() + }) + .collect::>() + .into(); + columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); } + log::info!( + "Fixed column generation (without compilation and loading time) took {}s", + start.elapsed().as_secs_f32() + ); } - panic!(); - Default::default() + columns } struct Compiler<'a, T> { From 44e80b8c6d6272c586c1de5971ac31661e736631 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 08:48:46 +0000 Subject: [PATCH 07/57] new crate. --- Cargo.toml | 2 + executor/Cargo.toml | 3 +- executor/src/constant_evaluator/mod.rs | 9 +- jit-compiler/Cargo.toml | 20 ++ .../src}/compiler.rs | 238 ++++++++---------- jit-compiler/src/lib.rs | 2 + jit-compiler/src/loader.rs | 130 ++++++++++ 7 files changed, 267 insertions(+), 137 deletions(-) create mode 100644 jit-compiler/Cargo.toml rename {executor/src/constant_evaluator => jit-compiler/src}/compiler.rs (69%) create mode 100644 jit-compiler/src/lib.rs create mode 100644 jit-compiler/src/loader.rs diff --git a/Cargo.toml b/Cargo.toml index 93a492e8f..d76db81f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "cli", "cli-rs", "executor", + "jit-compiler", "riscv", "parser-util", "pil-analyzer", @@ -48,6 +49,7 @@ powdr-analysis = { path = "./analysis", version = "0.1.0-alpha.2" } powdr-backend = { path = "./backend", version = "0.1.0-alpha.2" } powdr-executor = { path = "./executor", version = "0.1.0-alpha.2" } powdr-importer = { path = "./importer", version = "0.1.0-alpha.2" } +powdr-jit-compiler = { path = "./jit-compiler", version = "0.1.0-alpha.2" } powdr-linker = { path = "./linker", version = "0.1.0-alpha.2" } powdr-number = { path = "./number", version = "0.1.0-alpha.2" } powdr-parser = { path = "./parser", version = "0.1.0-alpha.2" } diff --git a/executor/Cargo.toml b/executor/Cargo.toml index 330309791..c0a89ef69 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -12,11 +12,10 @@ powdr-ast.workspace = true powdr-number.workspace = true powdr-parser-util.workspace = true powdr-pil-analyzer.workspace = true +powdr-jit-compiler.workspace = true itertools = "0.13" -libc = "0.2.0" log = { version = "0.4.17" } -mktemp = "0.5.0" rayon = "1.7.0" bit-vec = "0.6.3" num-traits = "0.2.15" diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index f1bb9a2c9..9d79f1d50 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -16,9 +16,6 @@ use powdr_number::{BigInt, BigUint, DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; -// TODO this is probabyl not the right place. -mod compiler; - mod data_structures; /// Generates the fixed column values for all fixed columns that are defined @@ -30,16 +27,12 @@ pub fn generate(analyzed: &Analyzed) -> Vec<(String, Variabl // TODO to do this properly, we should try to compile as much as possible // and only evaulato if it fails. Still, compilation should be done in one run. - let mut fixed_cols: HashMap)> = - compiler::generate_fixed_cols(analyzed); + let mut fixed_cols = HashMap::new(); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { // For arrays, generate values for each index, // for non-arrays, set index to None. for (index, (name, id)) in poly.array_elements().enumerate() { - if fixed_cols.contains_key(&name) { - continue; - } let index = poly.is_array().then_some(index as u64); let range = poly.degree.unwrap(); let values = range diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml new file mode 100644 index 000000000..8650abad4 --- /dev/null +++ b/jit-compiler/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "powdr-jit-compiler" +description = "powdr just-in-time compiler" +version = { workspace = true } +edition = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +powdr-ast.workspace = true +powdr-number.workspace = true +powdr-parser.workspace = true + +libc = "0.2.0" +mktemp = "0.5.0" +itertools = "0.13" + +[lints.clippy] +uninlined_format_args = "deny" diff --git a/executor/src/constant_evaluator/compiler.rs b/jit-compiler/src/compiler.rs similarity index 69% rename from executor/src/constant_evaluator/compiler.rs rename to jit-compiler/src/compiler.rs index d50977922..7f6c51b79 100644 --- a/executor/src/constant_evaluator/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,5 +1,4 @@ use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::{ collections::{HashMap, HashSet}, ffi::CString, @@ -24,28 +23,18 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::FieldElement; - -use crate::constant_evaluator::{MAX_DEGREE_LOG, MIN_DEGREE_LOG}; - -use super::VariablySizedColumn; +use powdr_number::{FieldElement, LargeInt}; const PREAMBLE: &str = r#" #![allow(unused_parens)] -use ark_ff::{BigInt, BigInteger, Fp64, MontBackend, MontConfig, PrimeField}; -#[derive(MontConfig)] -#[modulus = "18446744069414584321"] -#[generator = "7"] -struct GoldilocksBaseFieldConfig; -type FieldElement = Fp64>; "#; // TODO this is the old impl of goldilocks const CARGO_TOML: &str = r#" [package] -name = "powdr_constants" +name = "powdr_jit_compiled" version = "0.1.0" edition = "2021" @@ -53,115 +42,116 @@ edition = "2021" crate-type = ["dylib"] [dependencies] -ark-ff = "0.4.2" -num-bigint = { version = "0.4.3", features = ["serde"] } +// TODO version? +powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +num-bigint = { version = "0.4.3" } num-traits = "0.2.15" "#; // TODO crate type dylib? -pub fn generate_fixed_cols( - analyzed: &Analyzed, -) -> HashMap)> { - let mut compiler = Compiler::new(analyzed); - let mut glue = String::new(); - for (sym, _) in &analyzed.constant_polys_in_source_order() { - // ignore err - if let Err(e) = compiler.request_symbol(&sym.absolute_name) { - println!("Failed to compile {}: {e}", &sym.absolute_name); - } - } - for (sym, _) in &analyzed.constant_polys_in_source_order() { - // TODO escape? - if compiler.is_compiled(&sym.absolute_name) { - // TODO it is a rust function, can we use a more complex type as well? - // TODO only works for goldilocks - glue.push_str(&format!( - r#" - #[no_mangle] - pub extern fn extern_{}(i: u64) -> u64 {{ - {}(num_bigint::BigInt::from(i)).into_bigint().0[0] - }} - "#, - escape(&sym.absolute_name), - escape(&sym.absolute_name), - )); - } - } +// pub fn generate_fixed_cols( +// analyzed: &Analyzed, +// ) -> HashMap)> { +// let mut compiler = Compiler::new(analyzed); +// let mut glue = String::new(); +// for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // ignore err +// if let Err(e) = compiler.request_symbol(&sym.absolute_name) { +// println!("Failed to compile {}: {e}", &sym.absolute_name); +// } +// } +// for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // TODO escape? +// if compiler.is_compiled(&sym.absolute_name) { +// // TODO it is a rust function, can we use a more complex type as well? +// // TODO only works for goldilocks +// glue.push_str(&format!( +// r#" +// #[no_mangle] +// pub extern fn extern_{}(i: u64) -> u64 {{ +// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] +// }} +// "#, +// escape(&sym.absolute_name), +// escape(&sym.absolute_name), +// )); +// } +// } - let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); - println!("Compiled code:\n{code}"); +// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); +// println!("Compiled code:\n{code}"); - //let dir = mktemp::Temp::new_dir().unwrap(); - let _ = fs::remove_dir_all("/tmp/powdr_constants"); - fs::create_dir("/tmp/powdr_constants").unwrap(); - let dir = path::Path::new("/tmp/powdr_constants"); - fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); - fs::create_dir(dir.join("src")).unwrap(); - fs::write(dir.join("src").join("lib.rs"), code).unwrap(); - let out = Command::new("cargo") - .arg("build") - .arg("--release") - .current_dir(dir) - .output() - .unwrap(); - out.stderr.iter().for_each(|b| print!("{}", *b as char)); - if !out.status.success() { - panic!("Failed to compile."); - } +// //let dir = mktemp::Temp::new_dir().unwrap(); +// let _ = fs::remove_dir_all("/tmp/powdr_constants"); +// fs::create_dir("/tmp/powdr_constants").unwrap(); +// let dir = path::Path::new("/tmp/powdr_constants"); +// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); +// fs::create_dir(dir.join("src")).unwrap(); +// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); +// let out = Command::new("cargo") +// .arg("build") +// .arg("--release") +// .current_dir(dir) +// .output() +// .unwrap(); +// out.stderr.iter().for_each(|b| print!("{}", *b as char)); +// if !out.status.success() { +// panic!("Failed to compile."); +// } - let mut columns = HashMap::new(); - unsafe { - let lib_path = CString::new( - dir.join("target") - .join("release") - .join("libpowdr_constants.so") - .to_str() - .unwrap(), - ) - .unwrap(); - let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); - if lib.is_null() { - panic!("Failed to load library: {:?}", lib_path); - } - let start = Instant::now(); - for (poly, value) in analyzed.constant_polys_in_source_order() { - let sym = format!("extern_{}", escape(&poly.absolute_name)); - let sym = CString::new(sym).unwrap(); - let sym = dlsym(lib, sym.as_ptr()); - if sym.is_null() { - println!("Failed to load symbol: {:?}", sym); - continue; - } - println!("Loaded symbol: {:?}", sym); - let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); - let degrees = if let Some(degree) = poly.degree { - vec![degree] - } else { - (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) - .map(|degree_log| 1 << degree_log) - .collect::>() - }; +// let mut columns = HashMap::new(); +// unsafe { +// let lib_path = CString::new( +// dir.join("target") +// .join("release") +// .join("libpowdr_constants.so") +// .to_str() +// .unwrap(), +// ) +// .unwrap(); +// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); +// if lib.is_null() { +// panic!("Failed to load library: {:?}", lib_path); +// } +// let start = Instant::now(); +// for (poly, value) in analyzed.constant_polys_in_source_order() { +// let sym = format!("extern_{}", escape(&poly.absolute_name)); +// let sym = CString::new(sym).unwrap(); +// let sym = dlsym(lib, sym.as_ptr()); +// if sym.is_null() { +// println!("Failed to load symbol: {:?}", sym); +// continue; +// } +// println!("Loaded symbol: {:?}", sym); +// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); +// let degrees = if let Some(degree) = poly.degree { +// vec![degree] +// } else { +// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) +// .map(|degree_log| 1 << degree_log) +// .collect::>() +// }; - let col_values = degrees - .into_iter() - .map(|degree| { - (0..degree) - .into_par_iter() - .map(|i| T::from(fun(i as u64))) - .collect::>() - }) - .collect::>() - .into(); - columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); - } - log::info!( - "Fixed column generation (without compilation and loading time) took {}s", - start.elapsed().as_secs_f32() - ); - } - columns -} +// let col_values = degrees +// .into_iter() +// .map(|degree| { +// (0..degree) +// .into_par_iter() +// .map(|i| T::from(fun(i as u64))) +// .collect::>() +// }) +// .collect::>() +// .into(); +// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); +// } +// log::info!( +// "Fixed column generation (without compilation and loading time) took {}s", +// start.elapsed().as_secs_f32() +// ); +// } +// columns +// } struct Compiler<'a, T> { analyzed: &'a Analyzed, @@ -170,7 +160,7 @@ struct Compiler<'a, T> { symbols: HashMap, } -impl<'a, T> Compiler<'a, T> { +impl<'a, T: FieldElement> Compiler<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, @@ -219,9 +209,8 @@ impl<'a, T> Compiler<'a, T> { // TODO should this really panic? return Ok("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }".to_string()); } else if symbol == "std::field::modulus" { - // TODO depends on T - return Ok("fn std_field_modulus() -> num_bigint::BigInt { num_bigint::BigInt::from(18446744069414584321_u64) }" - .to_string()); + let modulus = T::modulus(); + return Ok(format!("fn std_field_modulus() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")); } else if symbol == "std::convert::fe" { return Ok("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()); @@ -235,6 +224,7 @@ impl<'a, T> Compiler<'a, T> { )); }; println!("Processing {symbol} = {}", value.e); + // TODO assert type scheme is there? let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { TypeScheme { vars: Default::default(), @@ -288,14 +278,7 @@ impl<'a, T> Compiler<'a, T> { fn format_expr(&mut self, e: &Expression) -> Result { Ok(match e { Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), - Expression::Reference( - _, - Reference::Poly(PolynomialReference { - name, - poly_id: _, - type_args, - }), - ) => { + Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { self.request_symbol(name)?; format!( "{}{}", @@ -314,6 +297,7 @@ impl<'a, T> Compiler<'a, T> { type_: Some(type_), }, ) => match type_ { + // TODO value does not need to be u64 Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), Type::Fe => format!("FieldElement::from({value}_u64)"), Type::Expr => format!("Expr::from({value}_u64)"), @@ -335,7 +319,7 @@ impl<'a, T> Compiler<'a, T> { .map(|a| self.format_expr(a)) .collect::, _>>()? .into_iter() - // TODO these should all be refs + // TODO these should all be refs -> turn all types to arc .map(|x| format!("{x}.clone()")) .collect::>() .join(", ") @@ -441,12 +425,12 @@ fn map_type(ty: &Type) -> String { Type::Int => "num_bigint::BigInt".to_string(), Type::Fe => "FieldElement".to_string(), Type::String => "String".to_string(), - Type::Col => unreachable!(), Type::Expr => "Expr".to_string(), Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), Type::Tuple(_) => todo!(), Type::Function(ft) => todo!("Type {ft}"), Type::TypeVar(tv) => tv.to_string(), Type::NamedType(_path, _type_args) => todo!(), + Type::Col | Type::Inter => unreachable!(), } } diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs new file mode 100644 index 000000000..7dd2003c5 --- /dev/null +++ b/jit-compiler/src/lib.rs @@ -0,0 +1,2 @@ +mod compiler; +mod loader; diff --git a/jit-compiler/src/loader.rs b/jit-compiler/src/loader.rs new file mode 100644 index 000000000..690bc5467 --- /dev/null +++ b/jit-compiler/src/loader.rs @@ -0,0 +1,130 @@ +// use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +// use rayon::iter::{IntoParallelIterator, ParallelIterator}; +// use std::{ +// collections::{HashMap, HashSet}, +// ffi::CString, +// fs::{self, create_dir, File}, +// io::Write, +// path, +// process::Command, +// sync::Arc, +// time::Instant, +// }; + +// use itertools::Itertools; +// use powdr_ast::{ +// analyzed::{ +// Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, +// Reference, SymbolKind, +// }, +// parsed::{ +// display::{format_type_args, quote}, +// types::{ArrayType, FunctionType, Type, TypeScheme}, +// ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, +// IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, +// }, +// }; +// use powdr_number::FieldElement; + +// // pub fn generate_fixed_cols( +// // analyzed: &Analyzed, +// // ) -> HashMap)> { +// // let mut compiler = Compiler::new(analyzed); +// // let mut glue = String::new(); +// // for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // // ignore err +// // if let Err(e) = compiler.request_symbol(&sym.absolute_name) { +// // println!("Failed to compile {}: {e}", &sym.absolute_name); +// // } +// // } +// // for (sym, _) in &analyzed.constant_polys_in_source_order() { +// // // TODO escape? +// // if compiler.is_compiled(&sym.absolute_name) { +// // // TODO it is a rust function, can we use a more complex type as well? +// // // TODO only works for goldilocks +// // glue.push_str(&format!( +// // r#" +// #[no_mangle] +// pub extern fn extern_{}(i: u64) -> u64 {{ +// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] +// }} +// "#, +// escape(&sym.absolute_name), +// escape(&sym.absolute_name), +// )); +// } +// } + +// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); +// println!("Compiled code:\n{code}"); + +// //let dir = mktemp::Temp::new_dir().unwrap(); +// let _ = fs::remove_dir_all("/tmp/powdr_constants"); +// fs::create_dir("/tmp/powdr_constants").unwrap(); +// let dir = path::Path::new("/tmp/powdr_constants"); +// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); +// fs::create_dir(dir.join("src")).unwrap(); +// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); +// let out = Command::new("cargo") +// .arg("build") +// .arg("--release") +// .current_dir(dir) +// .output() +// .unwrap(); +// out.stderr.iter().for_each(|b| print!("{}", *b as char)); +// if !out.status.success() { +// panic!("Failed to compile."); +// } + +// let mut columns = HashMap::new(); +// unsafe { +// let lib_path = CString::new( +// dir.join("target") +// .join("release") +// .join("libpowdr_constants.so") +// .to_str() +// .unwrap(), +// ) +// .unwrap(); +// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); +// if lib.is_null() { +// panic!("Failed to load library: {:?}", lib_path); +// } +// let start = Instant::now(); +// for (poly, value) in analyzed.constant_polys_in_source_order() { +// let sym = format!("extern_{}", escape(&poly.absolute_name)); +// let sym = CString::new(sym).unwrap(); +// let sym = dlsym(lib, sym.as_ptr()); +// if sym.is_null() { +// println!("Failed to load symbol: {:?}", sym); +// continue; +// } +// println!("Loaded symbol: {:?}", sym); +// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); +// let degrees = if let Some(degree) = poly.degree { +// vec![degree] +// } else { +// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) +// .map(|degree_log| 1 << degree_log) +// .collect::>() +// }; + +// let col_values = degrees +// .into_iter() +// .map(|degree| { +// (0..degree) +// .into_par_iter() +// .map(|i| T::from(fun(i as u64))) +// .collect::>() +// }) +// .collect::>() +// .into(); +// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); +// } +// log::info!( +// "Fixed column generation (without compilation and loading time) took {}s", +// start.elapsed().as_secs_f32() +// ); +// } +// columns +// } From 65e5f44b3ad1ae917c5f985166a8482e347e2b0b Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 08:49:47 +0000 Subject: [PATCH 08/57] fix --- executor/src/constant_evaluator/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index 9d79f1d50..40256195a 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -6,7 +6,7 @@ use std::{ pub use data_structures::{get_uniquely_sized, get_uniquely_sized_cloned, VariablySizedColumn}; use itertools::Itertools; use powdr_ast::{ - analyzed::{Analyzed, Expression, FunctionValueDefinition, PolyID, Symbol, TypedExpression}, + analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression}, parsed::{ types::{ArrayType, Type}, IndexAccess, @@ -24,9 +24,6 @@ mod data_structures; /// Arrays of columns are flattened, the name of the `i`th array element /// is `name[i]`. pub fn generate(analyzed: &Analyzed) -> Vec<(String, VariablySizedColumn)> { - // TODO to do this properly, we should try to compile as much as possible - // and only evaulato if it fails. Still, compilation should be done in one run. - let mut fixed_cols = HashMap::new(); for (poly, value) in analyzed.constant_polys_in_source_order() { if let Some(value) = value { From ec40700d61d8a2cb9b6169e79319e6b4faf60a83 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 09:35:54 +0000 Subject: [PATCH 09/57] work --- executor/Cargo.toml | 1 - jit-compiler/Cargo.toml | 5 + jit-compiler/src/compiler.rs | 437 +--------------------------------- jit-compiler/src/lib.rs | 5 +- jit-compiler/tests/codegen.rs | 38 +++ 5 files changed, 47 insertions(+), 439 deletions(-) create mode 100644 jit-compiler/tests/codegen.rs diff --git a/executor/Cargo.toml b/executor/Cargo.toml index c0a89ef69..0079cce9a 100644 --- a/executor/Cargo.toml +++ b/executor/Cargo.toml @@ -12,7 +12,6 @@ powdr-ast.workspace = true powdr-number.workspace = true powdr-parser-util.workspace = true powdr-pil-analyzer.workspace = true -powdr-jit-compiler.workspace = true itertools = "0.13" log = { version = "0.4.17" } diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 8650abad4..94da05758 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -16,5 +16,10 @@ libc = "0.2.0" mktemp = "0.5.0" itertools = "0.13" +[dev-dependencies] +powdr-pil-analyzer.workspace = true +pretty_assertions = "1.4.0" + + [lints.clippy] uninlined_format_args = "deny" diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 7f6c51b79..b9f5e2def 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,436 +1 @@ -use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -use std::{ - collections::{HashMap, HashSet}, - ffi::CString, - fs::{self, create_dir, File}, - io::Write, - path, - process::Command, - sync::Arc, - time::Instant, -}; - -use itertools::Itertools; -use powdr_ast::{ - analyzed::{ - Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, - Reference, SymbolKind, - }, - parsed::{ - display::{format_type_args, quote}, - types::{ArrayType, FunctionType, Type, TypeScheme}, - ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, - IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, - }, -}; -use powdr_number::{FieldElement, LargeInt}; - -const PREAMBLE: &str = r#" -#![allow(unused_parens)] - -"#; - -// TODO this is the old impl of goldilocks - -const CARGO_TOML: &str = r#" -[package] -name = "powdr_jit_compiled" -version = "0.1.0" -edition = "2021" - -[lib] -crate-type = ["dylib"] - -[dependencies] -// TODO version? -powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } -num-bigint = { version = "0.4.3" } -num-traits = "0.2.15" -"#; - -// TODO crate type dylib? - -// pub fn generate_fixed_cols( -// analyzed: &Analyzed, -// ) -> HashMap)> { -// let mut compiler = Compiler::new(analyzed); -// let mut glue = String::new(); -// for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // ignore err -// if let Err(e) = compiler.request_symbol(&sym.absolute_name) { -// println!("Failed to compile {}: {e}", &sym.absolute_name); -// } -// } -// for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // TODO escape? -// if compiler.is_compiled(&sym.absolute_name) { -// // TODO it is a rust function, can we use a more complex type as well? -// // TODO only works for goldilocks -// glue.push_str(&format!( -// r#" -// #[no_mangle] -// pub extern fn extern_{}(i: u64) -> u64 {{ -// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] -// }} -// "#, -// escape(&sym.absolute_name), -// escape(&sym.absolute_name), -// )); -// } -// } - -// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); -// println!("Compiled code:\n{code}"); - -// //let dir = mktemp::Temp::new_dir().unwrap(); -// let _ = fs::remove_dir_all("/tmp/powdr_constants"); -// fs::create_dir("/tmp/powdr_constants").unwrap(); -// let dir = path::Path::new("/tmp/powdr_constants"); -// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); -// fs::create_dir(dir.join("src")).unwrap(); -// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); -// let out = Command::new("cargo") -// .arg("build") -// .arg("--release") -// .current_dir(dir) -// .output() -// .unwrap(); -// out.stderr.iter().for_each(|b| print!("{}", *b as char)); -// if !out.status.success() { -// panic!("Failed to compile."); -// } - -// let mut columns = HashMap::new(); -// unsafe { -// let lib_path = CString::new( -// dir.join("target") -// .join("release") -// .join("libpowdr_constants.so") -// .to_str() -// .unwrap(), -// ) -// .unwrap(); -// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); -// if lib.is_null() { -// panic!("Failed to load library: {:?}", lib_path); -// } -// let start = Instant::now(); -// for (poly, value) in analyzed.constant_polys_in_source_order() { -// let sym = format!("extern_{}", escape(&poly.absolute_name)); -// let sym = CString::new(sym).unwrap(); -// let sym = dlsym(lib, sym.as_ptr()); -// if sym.is_null() { -// println!("Failed to load symbol: {:?}", sym); -// continue; -// } -// println!("Loaded symbol: {:?}", sym); -// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); -// let degrees = if let Some(degree) = poly.degree { -// vec![degree] -// } else { -// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) -// .map(|degree_log| 1 << degree_log) -// .collect::>() -// }; - -// let col_values = degrees -// .into_iter() -// .map(|degree| { -// (0..degree) -// .into_par_iter() -// .map(|i| T::from(fun(i as u64))) -// .collect::>() -// }) -// .collect::>() -// .into(); -// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); -// } -// log::info!( -// "Fixed column generation (without compilation and loading time) took {}s", -// start.elapsed().as_secs_f32() -// ); -// } -// columns -// } - -struct Compiler<'a, T> { - analyzed: &'a Analyzed, - requested: HashSet, - failed: HashMap, - symbols: HashMap, -} - -impl<'a, T: FieldElement> Compiler<'a, T> { - pub fn new(analyzed: &'a Analyzed) -> Self { - Self { - analyzed, - requested: Default::default(), - failed: Default::default(), - symbols: Default::default(), - } - } - - pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { - if let Some(err) = self.failed.get(name) { - return Err(err.clone()); - } - if self.requested.contains(name) { - return Ok(()); - } - self.requested.insert(name.to_string()); - match self.generate_code(name) { - Ok(code) => { - self.symbols.insert(name.to_string(), code); - println!("Generated code for {name}"); - Ok(()) - } - Err(err) => { - let err = format!("Failed to compile {name}: {err}"); - self.failed.insert(name.to_string(), err.clone()); - Err(err) - } - } - } - - pub fn is_compiled(&self, name: &str) -> bool { - self.symbols.contains_key(name) - } - - pub fn compiled_symbols(self) -> String { - self.symbols - .into_iter() - .map(|(name, code)| code) - .format("\n\n") - .to_string() - } - - fn generate_code(&mut self, symbol: &str) -> Result { - if symbol == "std::check::panic" { - // TODO should this really panic? - return Ok("fn std_check_panic(s: &str) -> ! { panic!(\"{s}\"); }".to_string()); - } else if symbol == "std::field::modulus" { - let modulus = T::modulus(); - return Ok(format!("fn std_field_modulus() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")); - } else if symbol == "std::convert::fe" { - return Ok("fn std_convert_fe(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" - .to_string()); - } - - let Some((sym, Some(FunctionValueDefinition::Expression(value)))) = - self.analyzed.definitions.get(symbol) - else { - return Err(format!( - "No definition for {symbol}, or not a generic symbol" - )); - }; - println!("Processing {symbol} = {}", value.e); - // TODO assert type scheme is there? - let type_scheme = if sym.kind == SymbolKind::Poly(PolynomialType::Constant) { - TypeScheme { - vars: Default::default(), - ty: Type::Function(FunctionType { - params: vec![Type::Int], - value: Box::new(Type::Fe), - }), - } - } else { - value.type_scheme.clone().unwrap() - }; - Ok(match type_scheme { - TypeScheme { - vars, - ty: - Type::Function(FunctionType { - params: param_types, - value: return_type, - }), - } => { - let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = - &value.e - else { - return Err(format!( - "Expected lambda expression for {symbol}, got {}", - value.e - )); - }; - format!( - "fn {}<{}>({}) -> {} {{ {} }}\n", - escape(symbol), - vars, - params - .iter() - .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(&t))) - .format(", "), - map_type(return_type.as_ref()), - self.format_expr(body)? - ) - } - _ => format!( - "const {}: {} = {};\n", - escape(symbol), - map_type(&value.type_scheme.as_ref().unwrap().ty), - self.format_expr(&value.e)? - ), - }) - } - - fn format_expr(&mut self, e: &Expression) -> Result { - Ok(match e { - Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), - Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { - self.request_symbol(name)?; - format!( - "{}{}", - escape(name), - // TODO do all type args work here? - type_args - .as_ref() - .map(|ta| format!("::{}", format_type_args(&ta))) - .unwrap_or_default() - ) - } - Expression::Number( - _, - Number { - value, - type_: Some(type_), - }, - ) => match type_ { - // TODO value does not need to be u64 - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - Type::TypeVar(t) => format!("{t}::from({value}_u64)"), - _ => unreachable!(), - }, - Expression::FunctionCall( - _, - FunctionCall { - function, - arguments, - }, - ) => { - format!( - "({})({})", - self.format_expr(function)?, - arguments - .iter() - .map(|a| self.format_expr(a)) - .collect::, _>>()? - .into_iter() - // TODO these should all be refs -> turn all types to arc - .map(|x| format!("{x}.clone()")) - .collect::>() - .join(", ") - ) - } - Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { - let left = self.format_expr(left)?; - let right = self.format_expr(right)?; - match op { - BinaryOperator::ShiftLeft => { - format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") - } - _ => format!("(({left}).clone() {op} ({right}).clone())"), - } - } - Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { - format!("({op} ({}).clone())", self.format_expr(expr)?) - } - Expression::IndexAccess(_, IndexAccess { array, index }) => { - format!( - "{}[usize::try_from({}).unwrap()].clone()", - self.format_expr(array)?, - self.format_expr(index)? - ) - } - Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { - // let params = if *params == vec!["r".to_string()] { - // // Hack because rust needs the type - // vec!["r: Vec".to_string()] - // } else { - // params.clone() - // }; - format!( - "|{}| {{ {} }}", - params.iter().format(", "), - self.format_expr(body)? - ) - } - Expression::IfExpression( - _, - IfExpression { - condition, - body, - else_body, - }, - ) => { - format!( - "if {} {{ {} }} else {{ {} }}", - self.format_expr(condition)?, - self.format_expr(body)?, - self.format_expr(else_body)? - ) - } - Expression::ArrayLiteral(_, ArrayLiteral { items }) => { - format!( - "vec![{}]", - items - .iter() - .map(|i| self.format_expr(i)) - .collect::, _>>()? - .join(", ") - ) - } - Expression::String(_, s) => quote(s), - Expression::Tuple(_, items) => format!( - "({})", - items - .iter() - .map(|i| self.format_expr(i)) - .collect::, _>>()? - .join(", ") - ), - Expression::BlockExpression(_, BlockExpression { statements, expr }) => { - format!( - "{{\n{}\n{}\n}}", - statements - .iter() - .map(|s| self.format_statement(s)) - .collect::, _>>()? - .join("\n"), - expr.as_ref() - .map(|e| self.format_expr(e.as_ref())) - .transpose()? - .unwrap_or_default() - ) - } - _ => return Err(format!("Implement {e}")), - }) - } - - fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { - Err(format!("Implement {s}")) - } -} - -fn escape(s: &str) -> String { - s.replace('.', "_").replace("::", "_") -} - -fn map_type(ty: &Type) -> String { - match ty { - Type::Bottom | Type::Bool => format!("{ty}"), - Type::Int => "num_bigint::BigInt".to_string(), - Type::Fe => "FieldElement".to_string(), - Type::String => "String".to_string(), - Type::Expr => "Expr".to_string(), - Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), - Type::Tuple(_) => todo!(), - Type::Function(ft) => todo!("Type {ft}"), - Type::TypeVar(tv) => tv.to_string(), - Type::NamedType(_path, _type_args) => todo!(), - Type::Col | Type::Inter => unreachable!(), - } -} +// TODO run cargo and stuff diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 7dd2003c5..1ca0c6654 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,2 +1,3 @@ -mod compiler; -mod loader; +pub mod codegen; +pub mod compiler; +pub mod loader; diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs new file mode 100644 index 000000000..9505ef5f6 --- /dev/null +++ b/jit-compiler/tests/codegen.rs @@ -0,0 +1,38 @@ +use powdr_jit_compiler::codegen::Compiler; +use powdr_number::GoldilocksField; +use powdr_pil_analyzer::analyze_string; + +use pretty_assertions::assert_eq; + +fn compile(input: &str, syms: &[&str]) -> String { + let analyzed = analyze_string::(input); + let mut compiler = Compiler::new(&analyzed); + for s in syms { + compiler.request_symbol(s).unwrap(); + } + compiler.compiled_symbols() +} + +#[test] +fn empty_code() { + let result = compile("", &[]); + assert_eq!(result, ""); +} + +#[test] +fn simple_fun() { + let result = compile("let c: int -> int = |i| i;", &["c"]); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + ); +} + +#[test] +fn constant() { + let result = compile("let c: int -> int = |i| i; let d = c(20);", &["c", "d"]); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + ); +} From 0be7ac0685e40a3942d49c42e67d96f9af723ac5 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 09:36:09 +0000 Subject: [PATCH 10/57] work --- jit-compiler/src/codegen.rs | 304 ++++++++++++++++++++++++++++++++++++ 1 file changed, 304 insertions(+) create mode 100644 jit-compiler/src/codegen.rs diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs new file mode 100644 index 000000000..24727e724 --- /dev/null +++ b/jit-compiler/src/codegen.rs @@ -0,0 +1,304 @@ +use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + fs::{self, create_dir, File}, + io::Write, + path, + process::Command, + sync::Arc, + time::Instant, +}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, + Reference, SymbolKind, + }, + parsed::{ + display::{format_type_args, quote}, + types::{ArrayType, FunctionType, Type, TypeScheme}, + ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, + IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, + }, +}; +use powdr_number::{FieldElement, LargeInt}; + +pub struct Compiler<'a, T> { + analyzed: &'a Analyzed, + requested: HashSet, + failed: HashMap, + symbols: HashMap, +} + +impl<'a, T: FieldElement> Compiler<'a, T> { + pub fn new(analyzed: &'a Analyzed) -> Self { + Self { + analyzed, + requested: Default::default(), + failed: Default::default(), + symbols: Default::default(), + } + } + + pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { + if let Some(err) = self.failed.get(name) { + return Err(err.clone()); + } + if self.requested.contains(name) { + return Ok(()); + } + self.requested.insert(name.to_string()); + match self.generate_code(name) { + Ok(code) => { + self.symbols.insert(name.to_string(), code); + Ok(()) + } + Err(err) => { + let err = format!("Failed to compile {name}: {err}"); + self.failed.insert(name.to_string(), err.clone()); + Err(err) + } + } + } + + pub fn is_compiled(&self, name: &str) -> bool { + self.symbols.contains_key(name) + } + + pub fn compiled_symbols(self) -> String { + self.symbols + .into_iter() + .map(|(_, code)| code) + .format("\n\n") + .to_string() + } + + fn generate_code(&mut self, symbol: &str) -> Result { + if let Some(code) = self.try_generate_builtin(symbol) { + return Ok(code); + } + + let Some((_, Some(FunctionValueDefinition::Expression(value)))) = + self.analyzed.definitions.get(symbol) + else { + return Err(format!( + "No definition for {symbol}, or not a generic symbol" + )); + }; + + let type_scheme = value.type_scheme.clone().unwrap(); + + Ok(match type_scheme { + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + } => { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = + &value.e + else { + return Err(format!( + "Expected lambda expression for {symbol}, got {}", + value.e + )); + }; + assert!(vars.is_empty()); + format!( + "fn {}({}) -> {} {{ {} }}\n", + escape_symbol(symbol), + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(&t))) + .format(", "), + map_type(return_type.as_ref()), + self.format_expr(body)? + ) + } + _ => format!( + "const {}: {} = {};\n", + escape_symbol(symbol), + map_type(&value.type_scheme.as_ref().unwrap().ty), + self.format_expr(&value.e)? + ), + }) + } + + fn try_generate_builtin(&self, symbol: &str) -> Option { + let code = match symbol { + "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), + "std::field::modulus" => { + let modulus = T::modulus(); + Some(format!("() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")) + } + "std::convert::fe" => Some("(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + .to_string()), + _ => None, + }?; + Some(format!("fn {}{code}", escape_symbol(symbol))) + } + + fn format_expr(&mut self, e: &Expression) -> Result { + Ok(match e { + Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), + Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { + self.request_symbol(name)?; + format!( + "{}{}", + escape_symbol(name), + // TODO do all type args work here? + type_args + .as_ref() + .map(|ta| format!("::{}", format_type_args(&ta))) + .unwrap_or_default() + ) + } + Expression::Number( + _, + Number { + value, + type_: Some(type_), + }, + ) => match type_ { + // TODO value does not need to be u64 + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + Type::TypeVar(t) => format!("{t}::from({value}_u64)"), + _ => unreachable!(), + }, + Expression::FunctionCall( + _, + FunctionCall { + function, + arguments, + }, + ) => { + format!( + "({})({})", + self.format_expr(function)?, + arguments + .iter() + .map(|a| self.format_expr(a)) + .collect::, _>>()? + .into_iter() + // TODO these should all be refs -> turn all types to arc + .map(|x| format!("{x}.clone()")) + .collect::>() + .join(", ") + ) + } + Expression::BinaryOperation(_, BinaryOperation { left, op, right }) => { + let left = self.format_expr(left)?; + let right = self.format_expr(right)?; + match op { + BinaryOperator::ShiftLeft => { + format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + } + _ => format!("(({left}).clone() {op} ({right}).clone())"), + } + } + Expression::UnaryOperation(_, UnaryOperation { op, expr }) => { + format!("({op} ({}).clone())", self.format_expr(expr)?) + } + Expression::IndexAccess(_, IndexAccess { array, index }) => { + format!( + "{}[usize::try_from({}).unwrap()].clone()", + self.format_expr(array)?, + self.format_expr(index)? + ) + } + Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { + // let params = if *params == vec!["r".to_string()] { + // // Hack because rust needs the type + // vec!["r: Vec".to_string()] + // } else { + // params.clone() + // }; + format!( + "|{}| {{ {} }}", + params.iter().format(", "), + self.format_expr(body)? + ) + } + Expression::IfExpression( + _, + IfExpression { + condition, + body, + else_body, + }, + ) => { + format!( + "if {} {{ {} }} else {{ {} }}", + self.format_expr(condition)?, + self.format_expr(body)?, + self.format_expr(else_body)? + ) + } + Expression::ArrayLiteral(_, ArrayLiteral { items }) => { + format!( + "vec![{}]", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ) + } + Expression::String(_, s) => quote(s), + Expression::Tuple(_, items) => format!( + "({})", + items + .iter() + .map(|i| self.format_expr(i)) + .collect::, _>>()? + .join(", ") + ), + Expression::BlockExpression(_, BlockExpression { statements, expr }) => { + format!( + "{{\n{}\n{}\n}}", + statements + .iter() + .map(|s| self.format_statement(s)) + .collect::, _>>()? + .join("\n"), + expr.as_ref() + .map(|e| self.format_expr(e.as_ref())) + .transpose()? + .unwrap_or_default() + ) + } + _ => return Err(format!("Implement {e}")), + }) + } + + fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { + Err(format!("Implement {s}")) + } +} + +fn escape_symbol(s: &str) -> String { + s.replace('.', "_").replace("::", "_") +} + +fn map_type(ty: &Type) -> String { + match ty { + Type::Bottom | Type::Bool => format!("{ty}"), + Type::Int => "num_bigint::BigInt".to_string(), + Type::Fe => "FieldElement".to_string(), + Type::String => "String".to_string(), + Type::Expr => "Expr".to_string(), + Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), + Type::Tuple(_) => todo!(), + Type::Function(ft) => todo!("Type {ft}"), + Type::TypeVar(tv) => tv.to_string(), + Type::NamedType(_path, _type_args) => todo!(), + Type::Col | Type::Inter => unreachable!(), + } +} From 6aa9f4509558b42a247ff72314900f7c062dced9 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 09:58:26 +0000 Subject: [PATCH 11/57] wor --- jit-compiler/src/codegen.rs | 53 ++++++++++++----------------------- jit-compiler/tests/codegen.rs | 12 ++++++-- 2 files changed, 27 insertions(+), 38 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 24727e724..0e2ea33eb 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,21 +1,8 @@ -use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -use std::{ - collections::{HashMap, HashSet}, - ffi::CString, - fs::{self, create_dir, File}, - io::Write, - path, - process::Command, - sync::Arc, - time::Instant, -}; +use std::collections::{HashMap, HashSet}; use itertools::Itertools; use powdr_ast::{ - analyzed::{ - Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, - Reference, SymbolKind, - }, + analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ display::{format_type_args, quote}, types::{ArrayType, FunctionType, Type, TypeScheme}, @@ -23,7 +10,7 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::{FieldElement, LargeInt}; +use powdr_number::FieldElement; pub struct Compiler<'a, T> { analyzed: &'a Analyzed, @@ -70,8 +57,9 @@ impl<'a, T: FieldElement> Compiler<'a, T> { pub fn compiled_symbols(self) -> String { self.symbols .into_iter() + .sorted() .map(|(_, code)| code) - .format("\n\n") + .format("\n") .to_string() } @@ -148,13 +136,12 @@ impl<'a, T: FieldElement> Compiler<'a, T> { Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { self.request_symbol(name)?; + let ta = type_args.as_ref().unwrap(); format!( "{}{}", escape_symbol(name), - // TODO do all type args work here? - type_args - .as_ref() - .map(|ta| format!("::{}", format_type_args(&ta))) + (!ta.is_empty()) + .then(|| format!("::{}", format_type_args(ta))) .unwrap_or_default() ) } @@ -164,14 +151,15 @@ impl<'a, T: FieldElement> Compiler<'a, T> { value, type_: Some(type_), }, - ) => match type_ { - // TODO value does not need to be u64 - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - Type::TypeVar(t) => format!("{t}::from({value}_u64)"), - _ => unreachable!(), - }, + ) => { + let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); + match type_ { + Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + _ => unreachable!(), + } + } Expression::FunctionCall( _, FunctionCall { @@ -214,12 +202,6 @@ impl<'a, T: FieldElement> Compiler<'a, T> { ) } Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) => { - // let params = if *params == vec!["r".to_string()] { - // // Hack because rust needs the type - // vec!["r: Vec".to_string()] - // } else { - // params.clone() - // }; format!( "|{}| {{ {} }}", params.iter().format(", "), @@ -284,6 +266,7 @@ impl<'a, T: FieldElement> Compiler<'a, T> { } fn escape_symbol(s: &str) -> String { + // TODO better escaping s.replace('.', "_").replace("::", "_") } diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs index 9505ef5f6..8187ef102 100644 --- a/jit-compiler/tests/codegen.rs +++ b/jit-compiler/tests/codegen.rs @@ -29,10 +29,16 @@ fn simple_fun() { } #[test] -fn constant() { - let result = compile("let c: int -> int = |i| i; let d = c(20);", &["c", "d"]); +fn fun_calls() { + let result = compile( + "let c: int -> int = |i| i + 20; let d = |k| c(k * 20);", + &["c", "d"], + ); assert_eq!( result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } + +fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } +" ); } From 01272e89818d74c6856749c82a001c8663386521 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 12:46:19 +0000 Subject: [PATCH 12/57] loading --- jit-compiler/src/codegen.rs | 14 ++--- jit-compiler/src/compiler.rs | 112 +++++++++++++++++++++++++++++++++- jit-compiler/src/lib.rs | 2 + jit-compiler/src/loader.rs | 2 +- jit-compiler/tests/codegen.rs | 4 +- 5 files changed, 123 insertions(+), 11 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 0e2ea33eb..96f2d15be 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -12,14 +12,14 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -pub struct Compiler<'a, T> { +pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, requested: HashSet, failed: HashMap, symbols: HashMap, } -impl<'a, T: FieldElement> Compiler<'a, T> { +impl<'a, T: FieldElement> CodeGenerator<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, @@ -122,9 +122,9 @@ impl<'a, T: FieldElement> Compiler<'a, T> { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { let modulus = T::modulus(); - Some(format!("() -> num_bigint::BigInt {{ num_bigint::BigInt::from(\"{modulus}\") }}")) + Some(format!("() -> powdr_number::BigInt {{ powdr_number::BigInt::from(\"{modulus}\") }}")) } - "std::convert::fe" => Some("(n: num_bigint::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + "std::convert::fe" => Some("(n: powdr_number::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), _ => None, }?; @@ -154,7 +154,7 @@ impl<'a, T: FieldElement> Compiler<'a, T> { ) => { let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); match type_ { - Type::Int => format!("num_bigint::BigInt::from({value}_u64)"), + Type::Int => format!("powdr_number::BigInt::from({value}_u64)"), Type::Fe => format!("FieldElement::from({value}_u64)"), Type::Expr => format!("Expr::from({value}_u64)"), _ => unreachable!(), @@ -265,7 +265,7 @@ impl<'a, T: FieldElement> Compiler<'a, T> { } } -fn escape_symbol(s: &str) -> String { +pub fn escape_symbol(s: &str) -> String { // TODO better escaping s.replace('.', "_").replace("::", "_") } @@ -273,7 +273,7 @@ fn escape_symbol(s: &str) -> String { fn map_type(ty: &Type) -> String { match ty { Type::Bottom | Type::Bool => format!("{ty}"), - Type::Int => "num_bigint::BigInt".to_string(), + Type::Int => "powdr_number::BigInt".to_string(), Type::Fe => "FieldElement".to_string(), Type::String => "String".to_string(), Type::Expr => "Expr".to_string(), diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index b9f5e2def..f6c182683 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1 +1,111 @@ -// TODO run cargo and stuff +use libc::{c_void, dlopen, dlsym, RTLD_NOW}; +use std::{ + collections::{HashMap}, + ffi::CString, + fs::{self}, + path, + process::Command, +}; + +use itertools::Itertools; +use powdr_ast::{ + analyzed::{ + Analyzed, + }, +}; +use powdr_number::FieldElement; + +use crate::codegen::{escape_symbol, CodeGenerator}; + +// TODO make this depend on T + +const PREAMBLE: &str = r#" +#![allow(unused_parens)] +type FieldElement = powdr_number::goldilocks::GoldilocksField; +"#; + +// TODO this is the old impl of goldilocks + +const CARGO_TOML: &str = r#" +[package] +name = "powdr_jit_compiled" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["dylib"] + +[dependencies] +powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +"#; + +pub fn compile( + analyzed: &Analyzed, + symbols: &[&str], +) -> Result u64>, String> { + let mut codegen = CodeGenerator::new(analyzed); + let mut glue = String::new(); + for sym in symbols { + codegen.request_symbol(sym)?; + // TODO verify that the type is `int -> int`. + // TODO we should use big int instead of u64 + let name = escape_symbol(sym); + glue.push_str(&format!( + r#" + #[no_mangle] + pub extern fn extern_{name}(i: u64) -> u64 {{ + {name}(powdr_number::BigInt::from(i)).into_bigint().0[0] + }} + "# + )); + } + + let code = format!("{PREAMBLE}\n{}\n{glue}\n", codegen.compiled_symbols()); + println!("Compiled code:\n{code}"); + + // TODO for testing, keep the dir the same + //let dir = mktemp::Temp::new_dir().unwrap(); + let _ = fs::remove_dir_all("/tmp/powdr_constants"); + fs::create_dir("/tmp/powdr_constants").unwrap(); + let dir = path::Path::new("/tmp/powdr_constants"); + fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); + fs::create_dir(dir.join("src")).unwrap(); + fs::write(dir.join("src").join("lib.rs"), code).unwrap(); + let out = Command::new("cargo") + .arg("build") + .arg("--release") + .current_dir(dir) + .output() + .unwrap(); + out.stderr.iter().for_each(|b| print!("{}", *b as char)); + if !out.status.success() { + panic!("Failed to compile."); + } + + let lib_path = CString::new( + dir.join("target") + .join("release") + .join("libpowdr_constants.so") + .to_str() + .unwrap(), + ) + .unwrap(); + + let lib = unsafe { dlopen(lib_path.as_ptr(), RTLD_NOW) }; + if lib.is_null() { + panic!("Failed to load library: {lib_path:?}"); + } + let mut result = HashMap::new(); + for sym in symbols { + let sym = format!("extern_{}", escape_symbol(sym)); + let sym_cstr = CString::new(sym.clone()).unwrap(); + let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; + if fun_ptr.is_null() { + return Err(format!("Failed to load symbol: {fun_ptr:?}")); + } + println!("Loaded symbol: {fun_ptr:?}"); + let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; + result.insert(sym, fun); + } + Ok(result) +} diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 1ca0c6654..3c347edfa 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,3 +1,5 @@ pub mod codegen; pub mod compiler; pub mod loader; + +//let n = num_bigint::BigUint::from_bytes_le(&n.to_le_bytes()); diff --git a/jit-compiler/src/loader.rs b/jit-compiler/src/loader.rs index 690bc5467..ac0b0042b 100644 --- a/jit-compiler/src/loader.rs +++ b/jit-compiler/src/loader.rs @@ -101,7 +101,7 @@ // } // println!("Loaded symbol: {:?}", sym); // let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); -// let degrees = if let Some(degree) = poly.degree { +// let degrees = if let Some(degree) = poly.degraee { // vec![degree] // } else { // (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs index 8187ef102..2344cb07b 100644 --- a/jit-compiler/tests/codegen.rs +++ b/jit-compiler/tests/codegen.rs @@ -1,4 +1,4 @@ -use powdr_jit_compiler::codegen::Compiler; +use powdr_jit_compiler::codegen::CodeGenerator; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; @@ -6,7 +6,7 @@ use pretty_assertions::assert_eq; fn compile(input: &str, syms: &[&str]) -> String { let analyzed = analyze_string::(input); - let mut compiler = Compiler::new(&analyzed); + let mut compiler = CodeGenerator::new(&analyzed); for s in syms { compiler.request_symbol(s).unwrap(); } From 249f3c879b12ce9c412a87fb44612ec5a987f2a6 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 12:50:35 +0000 Subject: [PATCH 13/57] fix --- jit-compiler/src/compiler.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index f6c182683..4efc93afa 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,6 +1,6 @@ use libc::{c_void, dlopen, dlsym, RTLD_NOW}; use std::{ - collections::{HashMap}, + collections::HashMap, ffi::CString, fs::{self}, path, @@ -8,11 +8,7 @@ use std::{ }; use itertools::Itertools; -use powdr_ast::{ - analyzed::{ - Analyzed, - }, -}; +use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; use crate::codegen::{escape_symbol, CodeGenerator}; @@ -21,7 +17,7 @@ use crate::codegen::{escape_symbol, CodeGenerator}; const PREAMBLE: &str = r#" #![allow(unused_parens)] -type FieldElement = powdr_number::goldilocks::GoldilocksField; +//type FieldElement = powdr_number::GoldilocksField; "#; // TODO this is the old impl of goldilocks @@ -54,7 +50,7 @@ pub fn compile( r#" #[no_mangle] pub extern fn extern_{name}(i: u64) -> u64 {{ - {name}(powdr_number::BigInt::from(i)).into_bigint().0[0] + u64::try_from({name}(powdr_number::BigInt::from(i))).unwrap() }} "# )); @@ -85,7 +81,7 @@ pub fn compile( let lib_path = CString::new( dir.join("target") .join("release") - .join("libpowdr_constants.so") + .join("libpowdr_jit_compiled.so") .to_str() .unwrap(), ) @@ -97,15 +93,15 @@ pub fn compile( } let mut result = HashMap::new(); for sym in symbols { - let sym = format!("extern_{}", escape_symbol(sym)); - let sym_cstr = CString::new(sym.clone()).unwrap(); + let extern_sym = format!("extern_{}", escape_symbol(sym)); + let sym_cstr = CString::new(extern_sym).unwrap(); let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; if fun_ptr.is_null() { return Err(format!("Failed to load symbol: {fun_ptr:?}")); } println!("Loaded symbol: {fun_ptr:?}"); let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; - result.insert(sym, fun); + result.insert(sym.to_string(), fun); } Ok(result) } From 1570fb08ef297d8486474ba2b701eababcf6260b Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 12:53:16 +0000 Subject: [PATCH 14/57] sqrt --- jit-compiler/src/compiler.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 4efc93afa..c996b749b 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -7,7 +7,6 @@ use std::{ process::Command, }; -use itertools::Itertools; use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; From 81dca2bf2383f37a7591e7ac7e6aad78306308a2 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 16:13:47 +0000 Subject: [PATCH 15/57] add benchmark --- jit-compiler/src/lib.rs | 3 +++ pil-analyzer/tests/types.rs | 2 +- pipeline/Cargo.toml | 1 + pipeline/benches/evaluator_benchmark.rs | 34 ++++++++++++++++++++++++- 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 3c347edfa..8eb456054 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,5 +1,8 @@ +// TODO make them non-pub? pub mod codegen; pub mod compiler; pub mod loader; //let n = num_bigint::BigUint::from_bytes_le(&n.to_le_bytes()); + +pub use compiler::compile; diff --git a/pil-analyzer/tests/types.rs b/pil-analyzer/tests/types.rs index 253ed4e34..a149de289 100644 --- a/pil-analyzer/tests/types.rs +++ b/pil-analyzer/tests/types.rs @@ -704,7 +704,7 @@ fn trait_user_defined_enum_wrong_type() { } let n: int = 7; - let r1 = Convert::convert(n); + let r1: int = Convert::convert(n); "; type_check(input, &[]); } diff --git a/pipeline/Cargo.toml b/pipeline/Cargo.toml index 2e1f3bc28..67a89ec50 100644 --- a/pipeline/Cargo.toml +++ b/pipeline/Cargo.toml @@ -41,6 +41,7 @@ num-traits = "0.2.15" test-log = "0.2.12" env_logger = "0.10.0" criterion = { version = "0.4", features = ["html_reports"] } +powdr-jit-compiler.workspace = true [package.metadata.cargo-udeps.ignore] development = ["env_logger"] diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index f5d2eb483..76d027cc2 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -114,5 +114,37 @@ fn evaluator_benchmark(c: &mut Criterion) { group.finish(); } -criterion_group!(benches_pil, evaluator_benchmark); +fn jit_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("jit-benchmark"); + + let sqrt_analyzed: Analyzed = { + let code = " + let sqrt: int -> int = |x| sqrt_rec(x, x); + let sqrt_rec: int, int -> int = |y, x| + if y * y <= x && (y + 1) * (y + 1) > x { + y + } else { + sqrt_rec((y + x / y) / 2, x) + }; + " + .to_string(); + let mut pipeline = Pipeline::default().from_asm_string(code, None); + pipeline.compute_analyzed_pil().unwrap().clone() + }; + + let sqrt_fun = powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; + + for x in [879882356, 1882356, 1187956, 56] { + group.bench_with_input(format!("sqrt_{x}"), &x, |b, &x| { + b.iter(|| { + let y = (x as u64) * 112655675; + sqrt_fun(y); + }); + }); + } + + group.finish(); +} + +criterion_group!(benches_pil, evaluator_benchmark, jit_benchmark); criterion_main!(benches_pil); From 641f1f2432375a718209752751eafa768e22de05 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 12 Sep 2024 16:33:27 +0000 Subject: [PATCH 16/57] forgot test file --- jit-compiler/tests/execution.rs | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 jit-compiler/tests/execution.rs diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs new file mode 100644 index 000000000..dbb1b0f4e --- /dev/null +++ b/jit-compiler/tests/execution.rs @@ -0,0 +1,41 @@ +use powdr_jit_compiler::compiler; +use powdr_number::GoldilocksField; +use powdr_pil_analyzer::analyze_string; + +fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { + let analyzed = analyze_string::(input); + compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] +} + +#[test] +fn identity_function() { + let f = compile("let c: int -> int = |i| i;", "c"); + + assert_eq!(f(10), 10); +} + +#[test] +fn sqrt() { + let f = compile( + " + let sqrt_rec: int, int -> int = |y, x| + if y * y <= x && (y + 1) * (y + 1) > x { + y + } else { + sqrt_rec((y + x / y) / 2, x) + }; + + let sqrt: int -> int = |x| sqrt_rec(x, x);", + "sqrt", + ); + + for i in 0..100000 { + f(879882356 * 112655675); + // assert_eq!(f(9), 3); + // assert_eq!(f(100), 10); + // assert_eq!(f(8), 2); + // assert_eq!(f(101), 10); + // assert_eq!(f(99), 9); + // assert_eq!(f(0), 0); + } +} From 65c6c049bcd9c6d97b27cdf32e1e2a5ec4c32e18 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:16:00 +0000 Subject: [PATCH 17/57] clean. --- jit-compiler/src/codegen.rs | 53 ++++++++++++- jit-compiler/src/compiler.rs | 113 ++++++++++++++++----------- jit-compiler/src/lib.rs | 31 ++++++-- jit-compiler/src/loader.rs | 130 -------------------------------- jit-compiler/tests/codegen.rs | 44 ----------- jit-compiler/tests/execution.rs | 18 ++--- 6 files changed, 150 insertions(+), 239 deletions(-) delete mode 100644 jit-compiler/src/loader.rs delete mode 100644 jit-compiler/tests/codegen.rs diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 96f2d15be..841c6c557 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -50,10 +50,6 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } - pub fn is_compiled(&self, name: &str) -> bool { - self.symbols.contains_key(name) - } - pub fn compiled_symbols(self) -> String { self.symbols .into_iter() @@ -285,3 +281,52 @@ fn map_type(ty: &Type) -> String { Type::Col | Type::Inter => unreachable!(), } } + +#[cfg(test)] +mod test { + use powdr_number::GoldilocksField; + use powdr_pil_analyzer::analyze_string; + + use pretty_assertions::assert_eq; + + use super::CodeGenerator; + + fn compile(input: &str, syms: &[&str]) -> String { + let analyzed = analyze_string::(input); + let mut compiler = CodeGenerator::new(&analyzed); + for s in syms { + compiler.request_symbol(s).unwrap(); + } + compiler.compiled_symbols() + } + + #[test] + fn empty_code() { + let result = compile("", &[]); + assert_eq!(result, ""); + } + + #[test] + fn simple_fun() { + let result = compile("let c: int -> int = |i| i;", &["c"]); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + ); + } + + #[test] + fn fun_calls() { + let result = compile( + "let c: int -> int = |i| i + 20; let d = |k| c(k * 20);", + &["c", "d"], + ); + assert_eq!( + result, + "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } + +fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } +" + ); + } +} diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index c996b749b..60ae6e706 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,16 +1,25 @@ use libc::{c_void, dlopen, dlsym, RTLD_NOW}; +use mktemp::Temp; use std::{ collections::HashMap, ffi::CString, fs::{self}, - path, process::Command, }; -use powdr_ast::analyzed::Analyzed; +use powdr_ast::{ + analyzed::Analyzed, + parsed::{ + display::format_type_scheme_around_name, + types::{FunctionType, Type, TypeScheme}, + }, +}; use powdr_number::FieldElement; -use crate::codegen::{escape_symbol, CodeGenerator}; +use crate::{ + codegen::{escape_symbol, CodeGenerator}, + SymbolMap, +}; // TODO make this depend on T @@ -19,88 +28,104 @@ const PREAMBLE: &str = r#" //type FieldElement = powdr_number::GoldilocksField; "#; -// TODO this is the old impl of goldilocks - -const CARGO_TOML: &str = r#" -[package] -name = "powdr_jit_compiled" -version = "0.1.0" -edition = "2021" - -[lib] -crate-type = ["dylib"] - -[dependencies] -powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } -"#; - -pub fn compile( +pub fn create_full_code( analyzed: &Analyzed, symbols: &[&str], -) -> Result u64>, String> { +) -> Result { let mut codegen = CodeGenerator::new(analyzed); let mut glue = String::new(); + let int_int_fun: TypeScheme = Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Int), + }) + .into(); for sym in symbols { + let ty = analyzed.type_of_symbol(sym); + if &ty != &int_int_fun { + return Err(format!( + "Only (int -> int) functions are supported, but requested {}", + format_type_scheme_around_name(sym, &Some(ty)), + )); + } codegen.request_symbol(sym)?; - // TODO verify that the type is `int -> int`. // TODO we should use big int instead of u64 let name = escape_symbol(sym); glue.push_str(&format!( r#" #[no_mangle] - pub extern fn extern_{name}(i: u64) -> u64 {{ + pub extern fn {}(i: u64) -> u64 {{ u64::try_from({name}(powdr_number::BigInt::from(i))).unwrap() }} - "# + "#, + extern_symbol_name(sym) )); } - let code = format!("{PREAMBLE}\n{}\n{glue}\n", codegen.compiled_symbols()); - println!("Compiled code:\n{code}"); + Ok(format!( + "{PREAMBLE}\n{}\n{glue}\n", + codegen.compiled_symbols() + )) +} - // TODO for testing, keep the dir the same - //let dir = mktemp::Temp::new_dir().unwrap(); - let _ = fs::remove_dir_all("/tmp/powdr_constants"); - fs::create_dir("/tmp/powdr_constants").unwrap(); - let dir = path::Path::new("/tmp/powdr_constants"); +const CARGO_TOML: &str = r#" +[package] +name = "powdr_jit_compiled" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["dylib"] + +[dependencies] +powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +"#; + +/// Compiles the given code and returns the path to the +/// temporary directory containing the compiled library +/// and the path to the compiled library. +pub fn call_cargo(code: &str) -> (Temp, String) { + let dir = mktemp::Temp::new_dir().unwrap(); fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); fs::write(dir.join("src").join("lib.rs"), code).unwrap(); let out = Command::new("cargo") .arg("build") .arg("--release") - .current_dir(dir) + .current_dir(dir.clone()) .output() .unwrap(); out.stderr.iter().for_each(|b| print!("{}", *b as char)); if !out.status.success() { panic!("Failed to compile."); } + let lib_path = dir + .join("target") + .join("release") + .join("libpowdr_jit_compiled.so"); + (dir, lib_path.to_str().unwrap().to_string()) +} - let lib_path = CString::new( - dir.join("target") - .join("release") - .join("libpowdr_jit_compiled.so") - .to_str() - .unwrap(), - ) - .unwrap(); - - let lib = unsafe { dlopen(lib_path.as_ptr(), RTLD_NOW) }; +/// Loads the given library and creates funtion pointers for the given symbols. +pub fn load_library(path: &str, symbols: &[&str]) -> Result { + let c_path = CString::new(path).unwrap(); + let lib = unsafe { dlopen(c_path.as_ptr(), RTLD_NOW) }; if lib.is_null() { - panic!("Failed to load library: {lib_path:?}"); + return Err(format!("Failed to load library: {path:?}")); } let mut result = HashMap::new(); for sym in symbols { - let extern_sym = format!("extern_{}", escape_symbol(sym)); + let extern_sym = extern_symbol_name(sym); let sym_cstr = CString::new(extern_sym).unwrap(); let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; if fun_ptr.is_null() { return Err(format!("Failed to load symbol: {fun_ptr:?}")); } - println!("Loaded symbol: {fun_ptr:?}"); let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; result.insert(sym.to_string(), fun); } Ok(result) } + +fn extern_symbol_name(sym: &str) -> String { + format!("extern_{}", escape_symbol(sym)) +} diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 8eb456054..cb49f3133 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,8 +1,27 @@ -// TODO make them non-pub? -pub mod codegen; -pub mod compiler; -pub mod loader; +mod codegen; +mod compiler; -//let n = num_bigint::BigUint::from_bytes_le(&n.to_le_bytes()); +use std::collections::HashMap; -pub use compiler::compile; +use compiler::{call_cargo, create_full_code, load_library}; +use powdr_ast::analyzed::Analyzed; +use powdr_number::FieldElement; + +pub type SymbolMap = HashMap u64>; + +/// Compiles the given symbols (and their dependencies) and returns them as a map +/// from symbol name to function pointer. +/// Only functions of type (int -> int) are supported for now. +pub fn compile( + analyzed: &Analyzed, + symbols: &[&str], +) -> Result { + let code = create_full_code(analyzed, symbols)?; + + let (dir, lib_path) = call_cargo(&code); + + let result = load_library(&lib_path, symbols); + + drop(dir); + result +} diff --git a/jit-compiler/src/loader.rs b/jit-compiler/src/loader.rs deleted file mode 100644 index ac0b0042b..000000000 --- a/jit-compiler/src/loader.rs +++ /dev/null @@ -1,130 +0,0 @@ -// use libc::{c_void, dlclose, dlopen, dlsym, RTLD_NOW}; -// use rayon::iter::{IntoParallelIterator, ParallelIterator}; -// use std::{ -// collections::{HashMap, HashSet}, -// ffi::CString, -// fs::{self, create_dir, File}, -// io::Write, -// path, -// process::Command, -// sync::Arc, -// time::Instant, -// }; - -// use itertools::Itertools; -// use powdr_ast::{ -// analyzed::{ -// Analyzed, Expression, FunctionValueDefinition, PolyID, PolynomialReference, PolynomialType, -// Reference, SymbolKind, -// }, -// parsed::{ -// display::{format_type_args, quote}, -// types::{ArrayType, FunctionType, Type, TypeScheme}, -// ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, -// IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, -// }, -// }; -// use powdr_number::FieldElement; - -// // pub fn generate_fixed_cols( -// // analyzed: &Analyzed, -// // ) -> HashMap)> { -// // let mut compiler = Compiler::new(analyzed); -// // let mut glue = String::new(); -// // for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // // ignore err -// // if let Err(e) = compiler.request_symbol(&sym.absolute_name) { -// // println!("Failed to compile {}: {e}", &sym.absolute_name); -// // } -// // } -// // for (sym, _) in &analyzed.constant_polys_in_source_order() { -// // // TODO escape? -// // if compiler.is_compiled(&sym.absolute_name) { -// // // TODO it is a rust function, can we use a more complex type as well? -// // // TODO only works for goldilocks -// // glue.push_str(&format!( -// // r#" -// #[no_mangle] -// pub extern fn extern_{}(i: u64) -> u64 {{ -// {}(num_bigint::BigInt::from(i)).into_bigint().0[0] -// }} -// "#, -// escape(&sym.absolute_name), -// escape(&sym.absolute_name), -// )); -// } -// } - -// let code = format!("{PREAMBLE}\n{}\n{glue}\n", compiler.compiled_symbols()); -// println!("Compiled code:\n{code}"); - -// //let dir = mktemp::Temp::new_dir().unwrap(); -// let _ = fs::remove_dir_all("/tmp/powdr_constants"); -// fs::create_dir("/tmp/powdr_constants").unwrap(); -// let dir = path::Path::new("/tmp/powdr_constants"); -// fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); -// fs::create_dir(dir.join("src")).unwrap(); -// fs::write(dir.join("src").join("lib.rs"), code).unwrap(); -// let out = Command::new("cargo") -// .arg("build") -// .arg("--release") -// .current_dir(dir) -// .output() -// .unwrap(); -// out.stderr.iter().for_each(|b| print!("{}", *b as char)); -// if !out.status.success() { -// panic!("Failed to compile."); -// } - -// let mut columns = HashMap::new(); -// unsafe { -// let lib_path = CString::new( -// dir.join("target") -// .join("release") -// .join("libpowdr_constants.so") -// .to_str() -// .unwrap(), -// ) -// .unwrap(); -// let lib = dlopen(lib_path.as_ptr(), RTLD_NOW); -// if lib.is_null() { -// panic!("Failed to load library: {:?}", lib_path); -// } -// let start = Instant::now(); -// for (poly, value) in analyzed.constant_polys_in_source_order() { -// let sym = format!("extern_{}", escape(&poly.absolute_name)); -// let sym = CString::new(sym).unwrap(); -// let sym = dlsym(lib, sym.as_ptr()); -// if sym.is_null() { -// println!("Failed to load symbol: {:?}", sym); -// continue; -// } -// println!("Loaded symbol: {:?}", sym); -// let fun = std::mem::transmute::<*mut c_void, fn(u64) -> u64>(sym); -// let degrees = if let Some(degree) = poly.degraee { -// vec![degree] -// } else { -// (MIN_DEGREE_LOG..=MAX_DEGREE_LOG) -// .map(|degree_log| 1 << degree_log) -// .collect::>() -// }; - -// let col_values = degrees -// .into_iter() -// .map(|degree| { -// (0..degree) -// .into_par_iter() -// .map(|i| T::from(fun(i as u64))) -// .collect::>() -// }) -// .collect::>() -// .into(); -// columns.insert(poly.absolute_name.clone(), (poly.into(), col_values)); -// } -// log::info!( -// "Fixed column generation (without compilation and loading time) took {}s", -// start.elapsed().as_secs_f32() -// ); -// } -// columns -// } diff --git a/jit-compiler/tests/codegen.rs b/jit-compiler/tests/codegen.rs deleted file mode 100644 index 2344cb07b..000000000 --- a/jit-compiler/tests/codegen.rs +++ /dev/null @@ -1,44 +0,0 @@ -use powdr_jit_compiler::codegen::CodeGenerator; -use powdr_number::GoldilocksField; -use powdr_pil_analyzer::analyze_string; - -use pretty_assertions::assert_eq; - -fn compile(input: &str, syms: &[&str]) -> String { - let analyzed = analyze_string::(input); - let mut compiler = CodeGenerator::new(&analyzed); - for s in syms { - compiler.request_symbol(s).unwrap(); - } - compiler.compiled_symbols() -} - -#[test] -fn empty_code() { - let result = compile("", &[]); - assert_eq!(result, ""); -} - -#[test] -fn simple_fun() { - let result = compile("let c: int -> int = |i| i;", &["c"]); - assert_eq!( - result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" - ); -} - -#[test] -fn fun_calls() { - let result = compile( - "let c: int -> int = |i| i + 20; let d = |k| c(k * 20);", - &["c", "d"], - ); - assert_eq!( - result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } - -fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } -" - ); -} diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index dbb1b0f4e..697cb0391 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -1,10 +1,9 @@ -use powdr_jit_compiler::compiler; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { let analyzed = analyze_string::(input); - compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] + powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] } #[test] @@ -29,13 +28,10 @@ fn sqrt() { "sqrt", ); - for i in 0..100000 { - f(879882356 * 112655675); - // assert_eq!(f(9), 3); - // assert_eq!(f(100), 10); - // assert_eq!(f(8), 2); - // assert_eq!(f(101), 10); - // assert_eq!(f(99), 9); - // assert_eq!(f(0), 0); - } + assert_eq!(f(9), 3); + assert_eq!(f(100), 10); + assert_eq!(f(8), 2); + assert_eq!(f(101), 10); + assert_eq!(f(99), 9); + assert_eq!(f(0), 0); } From 1578b1e5d32c8c064f337c814dca41ac466b6dc6 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:30:38 +0000 Subject: [PATCH 18/57] fix --- jit-compiler/src/codegen.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 841c6c557..aaa2a533d 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -311,7 +311,7 @@ mod test { let result = compile("let c: int -> int = |i| i;", &["c"]); assert_eq!( result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { i }\n" + "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { i }\n" ); } @@ -323,9 +323,9 @@ mod test { ); assert_eq!( result, - "fn c(i: num_bigint::BigInt) -> num_bigint::BigInt { ((i).clone() + (num_bigint::BigInt::from(20_u64)).clone()) } + "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { ((i).clone() + (powdr_number::BigInt::from(20_u64)).clone()) } -fn d(k: num_bigint::BigInt) -> num_bigint::BigInt { (c)(((k).clone() * (num_bigint::BigInt::from(20_u64)).clone()).clone()) } +fn d(k: powdr_number::BigInt) -> powdr_number::BigInt { (c)(((k).clone() * (powdr_number::BigInt::from(20_u64)).clone()).clone()) } " ); } From ffa6187715458ef401ae139b35495ec50bea4354 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:40:21 +0000 Subject: [PATCH 19/57] Some logging. --- jit-compiler/Cargo.toml | 4 ++++ jit-compiler/src/lib.rs | 7 ++++++- jit-compiler/tests/execution.rs | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 94da05758..f3330c1d9 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -13,12 +13,16 @@ powdr-number.workspace = true powdr-parser.workspace = true libc = "0.2.0" +log = "0.4.18" mktemp = "0.5.0" itertools = "0.13" [dev-dependencies] powdr-pil-analyzer.workspace = true pretty_assertions = "1.4.0" +test-log = "0.2.12" +env_logger = "0.10.0" + [lints.clippy] diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index cb49f3133..2102aea54 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,7 +1,7 @@ mod codegen; mod compiler; -use std::collections::HashMap; +use std::{collections::HashMap, fs}; use compiler::{call_cargo, create_full_code, load_library}; use powdr_ast::analyzed::Analyzed; @@ -16,11 +16,16 @@ pub fn compile( analyzed: &Analyzed, symbols: &[&str], ) -> Result { + log::info!("JIT-compiling {} symbols...", symbols.len()); let code = create_full_code(analyzed, symbols)?; let (dir, lib_path) = call_cargo(&code); + let metadata = fs::metadata(&lib_path).unwrap(); + + log::info!("Loading library with size {}...", metadata.len()); let result = load_library(&lib_path, symbols); + log::info!("Done."); drop(dir); result diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 697cb0391..ea705635c 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -1,3 +1,5 @@ +use test_log::test; + use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; From 51cae14c50f4e8da4863789528ec062738db500f Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:45:28 +0000 Subject: [PATCH 20/57] size in mb. --- jit-compiler/src/lib.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 2102aea54..b88cc02e4 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -22,7 +22,10 @@ pub fn compile( let (dir, lib_path) = call_cargo(&code); let metadata = fs::metadata(&lib_path).unwrap(); - log::info!("Loading library with size {}...", metadata.len()); + log::info!( + "Loading library with size {} MB...", + metadata.len() as f64 / 1000000.0 + ); let result = load_library(&lib_path, symbols); log::info!("Done."); From 22de29a751afd97ed2f538fcc6d3100d424694bc Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 18:48:44 +0200 Subject: [PATCH 21/57] Update pil-analyzer/tests/types.rs --- pil-analyzer/tests/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pil-analyzer/tests/types.rs b/pil-analyzer/tests/types.rs index a149de289..253ed4e34 100644 --- a/pil-analyzer/tests/types.rs +++ b/pil-analyzer/tests/types.rs @@ -704,7 +704,7 @@ fn trait_user_defined_enum_wrong_type() { } let n: int = 7; - let r1: int = Convert::convert(n); + let r1 = Convert::convert(n); "; type_check(input, &[]); } From 2682d279e3cdaa9b2e5f83947ded06e133cf81e3 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 16:52:40 +0000 Subject: [PATCH 22/57] Use ibig. --- jit-compiler/src/codegen.rs | 14 +++++++------- jit-compiler/src/compiler.rs | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index aaa2a533d..f05edd414 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -118,9 +118,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { let modulus = T::modulus(); - Some(format!("() -> powdr_number::BigInt {{ powdr_number::BigInt::from(\"{modulus}\") }}")) + Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")) } - "std::convert::fe" => Some("(n: powdr_number::BigInt) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), _ => None, }?; @@ -150,7 +150,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { ) => { let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); match type_ { - Type::Int => format!("powdr_number::BigInt::from({value}_u64)"), + Type::Int => format!("ibig::IBig::from({value}_u64)"), Type::Fe => format!("FieldElement::from({value}_u64)"), Type::Expr => format!("Expr::from({value}_u64)"), _ => unreachable!(), @@ -269,7 +269,7 @@ pub fn escape_symbol(s: &str) -> String { fn map_type(ty: &Type) -> String { match ty { Type::Bottom | Type::Bool => format!("{ty}"), - Type::Int => "powdr_number::BigInt".to_string(), + Type::Int => "ibig::IBig".to_string(), Type::Fe => "FieldElement".to_string(), Type::String => "String".to_string(), Type::Expr => "Expr".to_string(), @@ -311,7 +311,7 @@ mod test { let result = compile("let c: int -> int = |i| i;", &["c"]); assert_eq!( result, - "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { i }\n" + "fn c(i: ibig::IBig) -> ibig::IBig { i }\n" ); } @@ -323,9 +323,9 @@ mod test { ); assert_eq!( result, - "fn c(i: powdr_number::BigInt) -> powdr_number::BigInt { ((i).clone() + (powdr_number::BigInt::from(20_u64)).clone()) } + "fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) } -fn d(k: powdr_number::BigInt) -> powdr_number::BigInt { (c)(((k).clone() * (powdr_number::BigInt::from(20_u64)).clone()).clone()) } +fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) } " ); } diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 60ae6e706..321d39eb8 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -54,7 +54,7 @@ pub fn create_full_code( r#" #[no_mangle] pub extern fn {}(i: u64) -> u64 {{ - u64::try_from({name}(powdr_number::BigInt::from(i))).unwrap() + u64::try_from({name}(ibig::IBig::from(i))).unwrap() }} "#, extern_symbol_name(sym) @@ -77,7 +77,7 @@ edition = "2021" crate-type = ["dylib"] [dependencies] -powdr-number = { git = "https://github.com/powdr-labs/powdr.git" } +ibig = { version = "0.3.6", features = [] } "#; /// Compiles the given code and returns the path to the From e9d291096af77b529a71bf97a22bb3e0d9a49b95 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 13 Sep 2024 17:09:22 +0000 Subject: [PATCH 23/57] Use native cpu. --- jit-compiler/src/compiler.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 321d39eb8..fef3e5e89 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -89,6 +89,7 @@ pub fn call_cargo(code: &str) -> (Temp, String) { fs::create_dir(dir.join("src")).unwrap(); fs::write(dir.join("src").join("lib.rs"), code).unwrap(); let out = Command::new("cargo") + .env("RUSTFLAGS", "-C target-cpu=native") .arg("build") .arg("--release") .current_dir(dir.clone()) From 541d8dc9c6fd1f24e88c3698370605f1ae3b4724 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 19 Sep 2024 13:56:55 +0000 Subject: [PATCH 24/57] clippy --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index fef3e5e89..c98766160 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -41,7 +41,7 @@ pub fn create_full_code( .into(); for sym in symbols { let ty = analyzed.type_of_symbol(sym); - if &ty != &int_int_fun { + if ty != int_int_fun { return Err(format!( "Only (int -> int) functions are supported, but requested {}", format_type_scheme_around_name(sym, &Some(ty)), From 5aea6b2f394618c50bd70bc8944a72773c4dc0de Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 13:51:09 +0000 Subject: [PATCH 25/57] merge fix. --- jit-compiler/src/codegen.rs | 7 ++----- jit-compiler/tests/execution.rs | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index f05edd414..dc49df562 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -292,7 +292,7 @@ mod test { use super::CodeGenerator; fn compile(input: &str, syms: &[&str]) -> String { - let analyzed = analyze_string::(input); + let analyzed = analyze_string::(input).unwrap(); let mut compiler = CodeGenerator::new(&analyzed); for s in syms { compiler.request_symbol(s).unwrap(); @@ -309,10 +309,7 @@ mod test { #[test] fn simple_fun() { let result = compile("let c: int -> int = |i| i;", &["c"]); - assert_eq!( - result, - "fn c(i: ibig::IBig) -> ibig::IBig { i }\n" - ); + assert_eq!(result, "fn c(i: ibig::IBig) -> ibig::IBig { i }\n"); } #[test] diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index ea705635c..b20b00db2 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -4,7 +4,7 @@ use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { - let analyzed = analyze_string::(input); + let analyzed = analyze_string::(input).unwrap(); powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] } From 1a6db99908ac40e3ec0e2258b41698a9bb587431 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 23 Sep 2024 15:05:16 +0000 Subject: [PATCH 26/57] Remove ibig features. --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index c98766160..d1f48f5ff 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -77,7 +77,7 @@ edition = "2021" crate-type = ["dylib"] [dependencies] -ibig = { version = "0.3.6", features = [] } +ibig = { version = "0.3.6", features = [], default-features = false } "#; /// Compiles the given code and returns the path to the From cdf3ba99aa80d355934989b443082087f51a60c1 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 00:33:56 +0000 Subject: [PATCH 27/57] Nicer error messages. --- jit-compiler/src/codegen.rs | 54 +++++++++++++++++++++++------------- jit-compiler/src/compiler.rs | 13 +++++---- jit-compiler/src/lib.rs | 2 +- 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index dc49df562..cad73cc92 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -83,26 +83,19 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { value: return_type, }), } => { - let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = - &value.e - else { - return Err(format!( - "Expected lambda expression for {symbol}, got {}", - value.e - )); - }; assert!(vars.is_empty()); - format!( - "fn {}({}) -> {} {{ {} }}\n", - escape_symbol(symbol), - params - .iter() - .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(&t))) - .format(", "), - map_type(return_type.as_ref()), - self.format_expr(body)? - ) + self.try_format_function(symbol, ¶m_types, return_type.as_ref(), &value.e)? + } + TypeScheme { + vars, + ty: Type::Col, + } => { + assert!(vars.is_empty()); + // TODO we assume it is an int -> int function. + // The type inference algorithm should store the derived type. + // Alternatively, we insert a trait conversion function and store the type + // in the trait vars. + self.try_format_function(symbol, &[Type::Int], &Type::Int, &value.e)? } _ => format!( "const {}: {} = {};\n", @@ -113,6 +106,29 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { }) } + fn try_format_function( + &mut self, + name: &str, + param_types: &[Type], + return_type: &Type, + expr: &Expression, + ) -> Result { + let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = expr else { + return Err(format!("Expected lambda expression for {name}, got {expr}",)); + }; + Ok(format!( + "fn {}({}) -> {} {{ {} }}\n", + escape_symbol(name), + params + .iter() + .zip(param_types) + .map(|(p, t)| format!("{}: {}", p, map_type(&t))) + .format(", "), + map_type(return_type), + self.format_expr(body)? + )) + } + fn try_generate_builtin(&self, symbol: &str) -> Option { let code = match symbol { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index d1f48f5ff..d9ce98550 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -5,6 +5,7 @@ use std::{ ffi::CString, fs::{self}, process::Command, + str::from_utf8, }; use powdr_ast::{ @@ -41,9 +42,9 @@ pub fn create_full_code( .into(); for sym in symbols { let ty = analyzed.type_of_symbol(sym); - if ty != int_int_fun { + if ty != int_int_fun && ty.ty != Type::Col { return Err(format!( - "Only (int -> int) functions are supported, but requested {}", + "Only (int -> int) functions and columns are supported, but requested {}", format_type_scheme_around_name(sym, &Some(ty)), )); } @@ -83,7 +84,7 @@ ibig = { version = "0.3.6", features = [], default-features = false } /// Compiles the given code and returns the path to the /// temporary directory containing the compiled library /// and the path to the compiled library. -pub fn call_cargo(code: &str) -> (Temp, String) { +pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { let dir = mktemp::Temp::new_dir().unwrap(); fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); @@ -95,15 +96,15 @@ pub fn call_cargo(code: &str) -> (Temp, String) { .current_dir(dir.clone()) .output() .unwrap(); - out.stderr.iter().for_each(|b| print!("{}", *b as char)); if !out.status.success() { - panic!("Failed to compile."); + let stderr = from_utf8(&out.stderr).unwrap_or("UTF-8 error in error message."); + return Err(format!("Failed to compile: {stderr}.")); } let lib_path = dir .join("target") .join("release") .join("libpowdr_jit_compiled.so"); - (dir, lib_path.to_str().unwrap().to_string()) + Ok((dir, lib_path.to_str().unwrap().to_string())) } /// Loads the given library and creates funtion pointers for the given symbols. diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index b88cc02e4..8d107fafd 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -19,7 +19,7 @@ pub fn compile( log::info!("JIT-compiling {} symbols...", symbols.len()); let code = create_full_code(analyzed, symbols)?; - let (dir, lib_path) = call_cargo(&code); + let (dir, lib_path) = call_cargo(&code)?; let metadata = fs::metadata(&lib_path).unwrap(); log::info!( From 7b0dcebf4c64d0f7c248d61af9b8abb0c0993292 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 12:23:03 +0000 Subject: [PATCH 28/57] Partial compile. --- jit-compiler/src/codegen.rs | 19 ++++++++++++------- jit-compiler/src/compiler.rs | 12 ++++-------- jit-compiler/src/lib.rs | 33 ++++++++++++++++++++++++++------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index cad73cc92..e8f2066fe 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -293,7 +293,12 @@ fn map_type(ty: &Type) -> String { Type::Tuple(_) => todo!(), Type::Function(ft) => todo!("Type {ft}"), Type::TypeVar(tv) => tv.to_string(), - Type::NamedType(_path, _type_args) => todo!(), + Type::NamedType(path, type_args) => { + if type_args.is_some() { + unimplemented!() + } + escape_symbol(&path.to_string()) + } Type::Col | Type::Inter => unreachable!(), } } @@ -335,11 +340,11 @@ mod test { &["c", "d"], ); assert_eq!( - result, - "fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) } - -fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) } -" - ); + result, + "fn c(i: ibig::IBig) -> ibig::IBig { ((i).clone() + (ibig::IBig::from(20_u64)).clone()) }\n\ + \n\ + fn d(k: ibig::IBig) -> ibig::IBig { (c)(((k).clone() * (ibig::IBig::from(20_u64)).clone()).clone()) }\n\ + " + ); } } diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index d9ce98550..4ab9a30ba 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -29,11 +29,10 @@ const PREAMBLE: &str = r#" //type FieldElement = powdr_number::GoldilocksField; "#; -pub fn create_full_code( - analyzed: &Analyzed, +pub fn generate_glue_code( symbols: &[&str], + analyzed: &Analyzed, ) -> Result { - let mut codegen = CodeGenerator::new(analyzed); let mut glue = String::new(); let int_int_fun: TypeScheme = Type::Function(FunctionType { params: vec![Type::Int], @@ -48,7 +47,7 @@ pub fn create_full_code( format_type_scheme_around_name(sym, &Some(ty)), )); } - codegen.request_symbol(sym)?; + // TODO we should use big int instead of u64 let name = escape_symbol(sym); glue.push_str(&format!( @@ -62,10 +61,7 @@ pub fn create_full_code( )); } - Ok(format!( - "{PREAMBLE}\n{}\n{glue}\n", - codegen.compiled_symbols() - )) + Ok(format!("{PREAMBLE}\n{glue}\n",)) } const CARGO_TOML: &str = r#" diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 8d107fafd..8c05ba67b 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -3,7 +3,8 @@ mod compiler; use std::{collections::HashMap, fs}; -use compiler::{call_cargo, create_full_code, load_library}; +use codegen::CodeGenerator; +use compiler::{call_cargo, generate_glue_code, load_library}; use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; @@ -14,12 +15,30 @@ pub type SymbolMap = HashMap u64>; /// Only functions of type (int -> int) are supported for now. pub fn compile( analyzed: &Analyzed, - symbols: &[&str], + requested_symbols: &[&str], ) -> Result { - log::info!("JIT-compiling {} symbols...", symbols.len()); - let code = create_full_code(analyzed, symbols)?; - - let (dir, lib_path) = call_cargo(&code)?; + log::info!("JIT-compiling {} symbols...", requested_symbols.len()); + + let mut codegen = CodeGenerator::new(analyzed); + let successful_symbols = requested_symbols + .into_iter() + .filter_map(|&sym| { + if let Err(e) = codegen.request_symbol(sym) { + log::warn!("Unable to generate code for symbol {sym}: {e}"); + None + } else { + Some(sym) + } + }) + .collect::>(); + + if successful_symbols.is_empty() { + return Ok(Default::default()); + }; + + let glue_code = generate_glue_code(&successful_symbols, analyzed)?; + + let (dir, lib_path) = call_cargo(&format!("{glue_code}\n{}\n", codegen.compiled_symbols()))?; let metadata = fs::metadata(&lib_path).unwrap(); log::info!( @@ -27,7 +46,7 @@ pub fn compile( metadata.len() as f64 / 1000000.0 ); - let result = load_library(&lib_path, symbols); + let result = load_library(&lib_path, &successful_symbols); log::info!("Done."); drop(dir); From 76e63890536c1966ffe7bf1a280446bde23dbf63 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 12:38:20 +0000 Subject: [PATCH 29/57] clippy --- jit-compiler/Cargo.toml | 2 -- jit-compiler/src/codegen.rs | 2 +- jit-compiler/src/compiler.rs | 2 +- jit-compiler/src/lib.rs | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index f3330c1d9..2ae4e3c8a 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -23,7 +23,5 @@ pretty_assertions = "1.4.0" test-log = "0.2.12" env_logger = "0.10.0" - - [lints.clippy] uninlined_format_args = "deny" diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index e8f2066fe..368ea1c26 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -122,7 +122,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { params .iter() .zip(param_types) - .map(|(p, t)| format!("{}: {}", p, map_type(&t))) + .map(|(p, t)| format!("{p}: {}", map_type(t))) .format(", "), map_type(return_type), self.format_expr(body)? diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 4ab9a30ba..fd5504f62 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -18,7 +18,7 @@ use powdr_ast::{ use powdr_number::FieldElement; use crate::{ - codegen::{escape_symbol, CodeGenerator}, + codegen::{escape_symbol}, SymbolMap, }; diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 8c05ba67b..27d7f7ad3 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -21,7 +21,7 @@ pub fn compile( let mut codegen = CodeGenerator::new(analyzed); let successful_symbols = requested_symbols - .into_iter() + .iter() .filter_map(|&sym| { if let Err(e) = codegen.request_symbol(sym) { log::warn!("Unable to generate code for symbol {sym}: {e}"); From 00b56e6ae45594304bf85724e383b0b17f84176b Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 12:44:43 +0000 Subject: [PATCH 30/57] fmt --- jit-compiler/src/compiler.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index fd5504f62..bafa3dea9 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -17,10 +17,7 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -use crate::{ - codegen::{escape_symbol}, - SymbolMap, -}; +use crate::{codegen::escape_symbol, SymbolMap}; // TODO make this depend on T From 3544d576a37dec02b572752023bb67cec82dd845 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 15:39:09 +0200 Subject: [PATCH 31/57] Update jit-compiler/src/lib.rs --- jit-compiler/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 27d7f7ad3..4656bb6f2 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -42,7 +42,7 @@ pub fn compile( let metadata = fs::metadata(&lib_path).unwrap(); log::info!( - "Loading library with size {} MB...", + "Loading library of size {} MB...", metadata.len() as f64 / 1000000.0 ); From 5cf259cc20ad77bee86682b7c54ec4e072f62ca4 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:19:31 +0000 Subject: [PATCH 32/57] Portability. --- jit-compiler/src/compiler.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index bafa3dea9..1ecfd885b 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -68,7 +68,7 @@ version = "0.1.0" edition = "2021" [lib] -crate-type = ["dylib"] +crate-type = ["cdylib"] [dependencies] ibig = { version = "0.3.6", features = [], default-features = false } @@ -93,10 +93,17 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { let stderr = from_utf8(&out.stderr).unwrap_or("UTF-8 error in error message."); return Err(format!("Failed to compile: {stderr}.")); } + let extension = if cfg!(target_os = "windows") { + "dll" + } else if cfg!(target_os = "macos") { + "dylib" + } else { + "so" + }; let lib_path = dir .join("target") .join("release") - .join("libpowdr_jit_compiled.so"); + .join(&format!("libpowdr_jit_compiled.{extension}")); Ok((dir, lib_path.to_str().unwrap().to_string())) } From e928a376cb8818425d5e459bef6eb638deccde88 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 18:21:00 +0200 Subject: [PATCH 33/57] Update jit-compiler/tests/execution.rs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gastón Zanitti --- jit-compiler/tests/execution.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index b20b00db2..89d4dcd75 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -14,7 +14,11 @@ fn identity_function() { assert_eq!(f(10), 10); } - +#[test] +#[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] +fn invalid_function() { + let _ = compile("let c: int -> bool = |i| true;", "c"); +} #[test] fn sqrt() { let f = compile( From 86b0a2df7e3b1da550d7cd8e98e8dacf4803e685 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:22:54 +0000 Subject: [PATCH 34/57] fix error message. --- jit-compiler/src/compiler.rs | 2 +- jit-compiler/tests/execution.rs | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 1ecfd885b..27127d5e0 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -40,7 +40,7 @@ pub fn generate_glue_code( let ty = analyzed.type_of_symbol(sym); if ty != int_int_fun && ty.ty != Type::Col { return Err(format!( - "Only (int -> int) functions and columns are supported, but requested {}", + "Only (int -> int) functions and columns are supported, but requested{}", format_type_scheme_around_name(sym, &Some(ty)), )); } diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 89d4dcd75..e85c4e8fe 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -14,11 +14,7 @@ fn identity_function() { assert_eq!(f(10), 10); } -#[test] -#[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] -fn invalid_function() { - let _ = compile("let c: int -> bool = |i| true;", "c"); -} + #[test] fn sqrt() { let f = compile( @@ -41,3 +37,9 @@ fn sqrt() { assert_eq!(f(99), 9); assert_eq!(f(0), 0); } + +#[test] +#[should_panic = "Only (int -> int) functions and columns are supported, but requested c: int -> bool"] +fn invalid_function() { + let _ = compile("let c: int -> bool = |i| true;", "c"); +} From 3327e520862b10d7d10239233e2369b51f4dd683 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 16:42:02 +0000 Subject: [PATCH 35/57] clippy --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 27127d5e0..1e9f3ebf9 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -103,7 +103,7 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { let lib_path = dir .join("target") .join("release") - .join(&format!("libpowdr_jit_compiled.{extension}")); + .join(format!("libpowdr_jit_compiled.{extension}")); Ok((dir, lib_path.to_str().unwrap().to_string())) } From 2a456531b23a15320e34b2439527a7d0ee980d05 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 10:44:36 +0000 Subject: [PATCH 36/57] Use extern c. --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 1e9f3ebf9..fa61efc2d 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -50,7 +50,7 @@ pub fn generate_glue_code( glue.push_str(&format!( r#" #[no_mangle] - pub extern fn {}(i: u64) -> u64 {{ + pub extern "C" fn {}(i: u64) -> u64 {{ u64::try_from({name}(ibig::IBig::from(i))).unwrap() }} "#, From 5647eb6553ca050cbcbe73c5ef053c2be8ea0af7 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 13:40:49 +0000 Subject: [PATCH 37/57] Use libloading. --- jit-compiler/Cargo.toml | 2 +- jit-compiler/src/compiler.rs | 44 +++++++++++++------------ jit-compiler/src/lib.rs | 22 ++++++++++--- jit-compiler/tests/execution.rs | 19 ++++++----- pipeline/benches/evaluator_benchmark.rs | 4 +-- 5 files changed, 54 insertions(+), 37 deletions(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 2ae4e3c8a..5be80611b 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -12,10 +12,10 @@ powdr-ast.workspace = true powdr-number.workspace = true powdr-parser.workspace = true -libc = "0.2.0" log = "0.4.18" mktemp = "0.5.0" itertools = "0.13" +libloading = "0.8" [dev-dependencies] powdr-pil-analyzer.workspace = true diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index fa61efc2d..51faccd46 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -1,11 +1,10 @@ -use libc::{c_void, dlopen, dlsym, RTLD_NOW}; use mktemp::Temp; use std::{ collections::HashMap, - ffi::CString, fs::{self}, process::Command, str::from_utf8, + sync::Arc, }; use powdr_ast::{ @@ -17,7 +16,7 @@ use powdr_ast::{ }; use powdr_number::FieldElement; -use crate::{codegen::escape_symbol, SymbolMap}; +use crate::{codegen::escape_symbol, LoadedFunction}; // TODO make this depend on T @@ -108,24 +107,27 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { } /// Loads the given library and creates funtion pointers for the given symbols. -pub fn load_library(path: &str, symbols: &[&str]) -> Result { - let c_path = CString::new(path).unwrap(); - let lib = unsafe { dlopen(c_path.as_ptr(), RTLD_NOW) }; - if lib.is_null() { - return Err(format!("Failed to load library: {path:?}")); - } - let mut result = HashMap::new(); - for sym in symbols { - let extern_sym = extern_symbol_name(sym); - let sym_cstr = CString::new(extern_sym).unwrap(); - let fun_ptr = unsafe { dlsym(lib, sym_cstr.as_ptr()) }; - if fun_ptr.is_null() { - return Err(format!("Failed to load symbol: {fun_ptr:?}")); - } - let fun = unsafe { std::mem::transmute::<*mut c_void, fn(u64) -> u64>(fun_ptr) }; - result.insert(sym.to_string(), fun); - } - Ok(result) +pub fn load_library( + path: &str, + symbols: &[&str], +) -> Result, String> { + let library = Arc::new( + unsafe { libloading::Library::new(path) } + .map_err(|e| format!("Error loading library at {path}: {e}"))?, + ); + symbols + .iter() + .map(|&sym| { + let extern_sym = extern_symbol_name(sym); + let function = *unsafe { library.get:: u64>(extern_sym.as_bytes()) } + .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; + let fun = LoadedFunction { + library: library.clone(), + function, + }; + Ok((sym.to_string(), fun)) + }) + .collect::>() } fn extern_symbol_name(sym: &str) -> String { diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 4656bb6f2..b6c12b60a 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -1,22 +1,36 @@ mod codegen; mod compiler; -use std::{collections::HashMap, fs}; +use std::{collections::HashMap, fs, sync::Arc}; use codegen::CodeGenerator; use compiler::{call_cargo, generate_glue_code, load_library}; + use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; -pub type SymbolMap = HashMap u64>; +/// Wrapper around a dynamically loaded function. +/// Prevents the dynamically loaded library to be unloaded while the function is still in use. +#[derive(Clone)] +pub struct LoadedFunction { + #[allow(dead_code)] + library: Arc, + function: fn(u64) -> u64, +} + +impl LoadedFunction { + pub fn call(&self, arg: u64) -> u64 { + (self.function)(arg) + } +} /// Compiles the given symbols (and their dependencies) and returns them as a map -/// from symbol name to function pointer. +/// from symbol name to function. /// Only functions of type (int -> int) are supported for now. pub fn compile( analyzed: &Analyzed, requested_symbols: &[&str], -) -> Result { +) -> Result, String> { log::info!("JIT-compiling {} symbols...", requested_symbols.len()); let mut codegen = CodeGenerator::new(analyzed); diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index e85c4e8fe..f6d80c7be 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -1,18 +1,19 @@ +use powdr_jit_compiler::LoadedFunction; use test_log::test; use powdr_number::GoldilocksField; use powdr_pil_analyzer::analyze_string; -fn compile(input: &str, symbol: &str) -> fn(u64) -> u64 { +fn compile(input: &str, symbol: &str) -> LoadedFunction { let analyzed = analyze_string::(input).unwrap(); - powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol] + powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol].clone() } #[test] fn identity_function() { let f = compile("let c: int -> int = |i| i;", "c"); - assert_eq!(f(10), 10); + assert_eq!(f.call(10), 10); } #[test] @@ -30,12 +31,12 @@ fn sqrt() { "sqrt", ); - assert_eq!(f(9), 3); - assert_eq!(f(100), 10); - assert_eq!(f(8), 2); - assert_eq!(f(101), 10); - assert_eq!(f(99), 9); - assert_eq!(f(0), 0); + assert_eq!(f.call(9), 3); + assert_eq!(f.call(100), 10); + assert_eq!(f.call(8), 2); + assert_eq!(f.call(101), 10); + assert_eq!(f.call(99), 9); + assert_eq!(f.call(0), 0); } #[test] diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index 76d027cc2..b0d67fee1 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -132,13 +132,13 @@ fn jit_benchmark(c: &mut Criterion) { pipeline.compute_analyzed_pil().unwrap().clone() }; - let sqrt_fun = powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; + let sqrt_fun = &powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; for x in [879882356, 1882356, 1187956, 56] { group.bench_with_input(format!("sqrt_{x}"), &x, |b, &x| { b.iter(|| { let y = (x as u64) * 112655675; - sqrt_fun(y); + sqrt_fun.call(y); }); }); } From 540671d8889cf8183a5a88a41ae1a08aa977bc9c Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:43:08 +0200 Subject: [PATCH 38/57] Update jit-compiler/src/compiler.rs Co-authored-by: Georg Wiese --- jit-compiler/src/compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 51faccd46..4c506876f 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -106,7 +106,7 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { Ok((dir, lib_path.to_str().unwrap().to_string())) } -/// Loads the given library and creates funtion pointers for the given symbols. +/// Loads the given library and creates function pointers for the given symbols. pub fn load_library( path: &str, symbols: &[&str], From 2937860c3fe2abcd09970655ade71c9ee7d489a0 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 13:55:50 +0000 Subject: [PATCH 39/57] Extract sqrt code. --- pipeline/benches/evaluator_benchmark.rs | 34 +++++++++---------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index b0d67fee1..a55522dc6 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -9,6 +9,16 @@ use powdr_pipeline::test_util::{evaluate_function, evaluate_integer_function, st use criterion::{criterion_group, criterion_main, Criterion}; +const SQRT_CODE: &str = " + let sqrt: int -> int = |x| sqrt_rec(x, x); + let sqrt_rec: int, int -> int = |y, x| + if y * y <= x && (y + 1) * (y + 1) > x { + y + } else { + sqrt_rec((y + x / y) / 2, x) + }; +"; + fn evaluator_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("evaluator-benchmark"); @@ -67,17 +77,7 @@ fn evaluator_benchmark(c: &mut Criterion) { }); let sqrt_analyzed: Analyzed = { - let code = " - let sqrt: int -> int = |x| sqrt_rec(x, x); - let sqrt_rec: int, int -> int = |y, x| - if y * y <= x && (y + 1) * (y + 1) > x { - y - } else { - sqrt_rec((y + x / y) / 2, x) - }; - " - .to_string(); - let mut pipeline = Pipeline::default().from_asm_string(code, None); + let mut pipeline = Pipeline::default().from_asm_string(SQRT_CODE.to_string(), None); pipeline.compute_analyzed_pil().unwrap().clone() }; @@ -118,17 +118,7 @@ fn jit_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("jit-benchmark"); let sqrt_analyzed: Analyzed = { - let code = " - let sqrt: int -> int = |x| sqrt_rec(x, x); - let sqrt_rec: int, int -> int = |y, x| - if y * y <= x && (y + 1) * (y + 1) > x { - y - } else { - sqrt_rec((y + x / y) / 2, x) - }; - " - .to_string(); - let mut pipeline = Pipeline::default().from_asm_string(code, None); + let mut pipeline = Pipeline::default().from_asm_string(SQRT_CODE.to_string(), None); pipeline.compute_analyzed_pil().unwrap().clone() }; From 681937aa094104e5190cac281139d67e62034d96 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:03:51 +0000 Subject: [PATCH 40/57] Remove drop. --- jit-compiler/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index b6c12b60a..77216bb84 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -62,7 +62,5 @@ pub fn compile( let result = load_library(&lib_path, &successful_symbols); log::info!("Done."); - - drop(dir); result } From 6a96b72831d6c6588a40922f8067039519e3e019 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:05:11 +0000 Subject: [PATCH 41/57] Add release - we need the variable. --- jit-compiler/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 77216bb84..d530345ba 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -62,5 +62,7 @@ pub fn compile( let result = load_library(&lib_path, &successful_symbols); log::info!("Done."); + + dir.release(); result } From 8db0eba442031659f1b6353605f25097d7e909bb Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 15:08:55 +0000 Subject: [PATCH 42/57] Encapsulate temp dir in struct. --- jit-compiler/src/compiler.rs | 14 ++++++++++++-- jit-compiler/src/lib.rs | 8 +++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 4c506876f..c864d3ba3 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -73,10 +73,17 @@ crate-type = ["cdylib"] ibig = { version = "0.3.6", features = [], default-features = false } "#; +pub struct PathInTempDir { + #[allow(dead_code)] + dir: Temp, + /// The absolute path + pub path: String, +} + /// Compiles the given code and returns the path to the /// temporary directory containing the compiled library /// and the path to the compiled library. -pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { +pub fn call_cargo(code: &str) -> Result { let dir = mktemp::Temp::new_dir().unwrap(); fs::write(dir.join("Cargo.toml"), CARGO_TOML).unwrap(); fs::create_dir(dir.join("src")).unwrap(); @@ -103,7 +110,10 @@ pub fn call_cargo(code: &str) -> Result<(Temp, String), String> { .join("target") .join("release") .join(format!("libpowdr_jit_compiled.{extension}")); - Ok((dir, lib_path.to_str().unwrap().to_string())) + Ok(PathInTempDir { + dir, + path: lib_path.to_str().unwrap().to_string(), + }) } /// Loads the given library and creates function pointers for the given symbols. diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index d530345ba..a8679a962 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -52,17 +52,15 @@ pub fn compile( let glue_code = generate_glue_code(&successful_symbols, analyzed)?; - let (dir, lib_path) = call_cargo(&format!("{glue_code}\n{}\n", codegen.compiled_symbols()))?; - let metadata = fs::metadata(&lib_path).unwrap(); + let lib_file = call_cargo(&format!("{glue_code}\n{}\n", codegen.compiled_symbols()))?; + let metadata = fs::metadata(&lib_file.path).unwrap(); log::info!( "Loading library of size {} MB...", metadata.len() as f64 / 1000000.0 ); - let result = load_library(&lib_path, &successful_symbols); + let result = load_library(&lib_file.path, &successful_symbols); log::info!("Done."); - - dir.release(); result } From b47df417bff1a16faa105191b8e7489b1b7184de Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 16:03:25 +0000 Subject: [PATCH 43/57] Simplify compiler state. --- jit-compiler/src/codegen.rs | 55 ++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 368ea1c26..df0164d5b 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use itertools::Itertools; use powdr_ast::{ @@ -14,45 +14,56 @@ use powdr_number::FieldElement; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, - requested: HashSet, - failed: HashMap, - symbols: HashMap, + /// Symbols mapping to either their code or an error message explaining + /// why they could not be compiled. + /// While the code is still being generated, this contains `None`. + symbols: HashMap, String>>, } impl<'a, T: FieldElement> CodeGenerator<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { analyzed, - requested: Default::default(), - failed: Default::default(), symbols: Default::default(), } } + /// Request a symbol to be compiled. The code can later be retrieved + /// via `compiled_symbols`. + /// In the error case, `self` can still be used to compile other symbols. pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { - if let Some(err) = self.failed.get(name) { - return Err(err.clone()); - } - if self.requested.contains(name) { - return Ok(()); - } - self.requested.insert(name.to_string()); - match self.generate_code(name) { - Ok(code) => { - self.symbols.insert(name.to_string(), code); - Ok(()) - } - Err(err) => { - let err = format!("Failed to compile {name}: {err}"); - self.failed.insert(name.to_string(), err.clone()); - Err(err) + match self.symbols.get(name) { + Some(Ok(_)) => Ok(()), + Some(Err(e)) => Err(e.clone()), + None => { + let name = name.to_string(); + self.symbols.insert(name.clone(), Ok(None)); + let to_insert; + let to_return; + match self.generate_code(&name) { + Ok(code) => { + to_insert = Ok(Some(code)); + to_return = Ok(()); + } + Err(err) => { + to_insert = Err(err.clone()); + to_return = Err(err); + } + } + self.symbols.insert(name, to_insert); + to_return } } } + /// Returns the concatenation of all successfully compiled symbols. pub fn compiled_symbols(self) -> String { self.symbols .into_iter() + .filter_map(|(s, r)| match r { + Ok(Some(code)) => Some((s, code)), + _ => None, + }) .sorted() .map(|(_, code)| code) .format("\n") From bdc2d38b83e973bbf513369f73a7662347c91eb9 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 16:04:22 +0000 Subject: [PATCH 44/57] Error message. --- jit-compiler/src/codegen.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index df0164d5b..99b9367e7 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -284,7 +284,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } fn format_statement(&mut self, s: &StatementInsideBlock) -> Result { - Err(format!("Implement {s}")) + Err(format!( + "Compiling statements inside blocks is not yet implemented: {s}" + )) } } From 8c23119a0ee6626726ce12fe93309ac83de4c638 Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 19:32:42 +0000 Subject: [PATCH 45/57] use unsafe extern C fn --- jit-compiler/src/compiler.rs | 5 +++-- jit-compiler/src/lib.rs | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index c864d3ba3..a92476b61 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -129,8 +129,9 @@ pub fn load_library( .iter() .map(|&sym| { let extern_sym = extern_symbol_name(sym); - let function = *unsafe { library.get:: u64>(extern_sym.as_bytes()) } - .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; + let function = + *unsafe { library.get:: u64>(extern_sym.as_bytes()) } + .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; let fun = LoadedFunction { library: library.clone(), function, diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index a8679a962..465663e46 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -15,12 +15,12 @@ use powdr_number::FieldElement; pub struct LoadedFunction { #[allow(dead_code)] library: Arc, - function: fn(u64) -> u64, + function: unsafe extern "C" fn(u64) -> u64, } impl LoadedFunction { pub fn call(&self, arg: u64) -> u64 { - (self.function)(arg) + unsafe { (self.function)(arg) } } } From 9501ef9b9525765eae70911c23738de89779602e Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 25 Sep 2024 19:33:34 +0000 Subject: [PATCH 46/57] use mebibytes. --- jit-compiler/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index 465663e46..fc72d3789 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -57,7 +57,7 @@ pub fn compile( log::info!( "Loading library of size {} MB...", - metadata.len() as f64 / 1000000.0 + metadata.len() as f64 / (1024.0 * 1024.0) ); let result = load_library(&lib_file.path, &successful_symbols); From 13305d1195c4f7b33f02897f12adc98dd0abe8a7 Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 24 Sep 2024 14:24:34 +0000 Subject: [PATCH 47/57] Support functions as expressions. --- jit-compiler/src/codegen.rs | 155 +++++++++++++++++++++----------- jit-compiler/src/compiler.rs | 1 + jit-compiler/tests/execution.rs | 23 +++++ 3 files changed, 125 insertions(+), 54 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 99b9367e7..bbc1c9019 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use powdr_ast::{ analyzed::{Analyzed, Expression, FunctionValueDefinition, PolynomialReference, Reference}, parsed::{ - display::{format_type_args, quote}, + display::quote, types::{ArrayType, FunctionType, Type, TypeScheme}, ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression, IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, @@ -28,32 +28,35 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } - /// Request a symbol to be compiled. The code can later be retrieved - /// via `compiled_symbols`. + /// Requests code for the symbol to be generated and returns the + /// code that returns the value of the symbol, which is either + /// the escaped name of the symbol or a dereferenced name of a lazy static. + /// The code can later be retrieved via `compiled_symbols`. /// In the error case, `self` can still be used to compile other symbols. - pub fn request_symbol(&mut self, name: &str) -> Result<(), String> { + pub fn request_symbol(&mut self, name: &str) -> Result { match self.symbols.get(name) { - Some(Ok(_)) => Ok(()), - Some(Err(e)) => Err(e.clone()), + Some(Err(e)) => return Err(e.clone()), + Some(_) => {} None => { let name = name.to_string(); self.symbols.insert(name.clone(), Ok(None)); - let to_insert; - let to_return; match self.generate_code(&name) { Ok(code) => { - to_insert = Ok(Some(code)); - to_return = Ok(()); + self.symbols.insert(name.clone(), Ok(Some(code))); } Err(err) => { - to_insert = Err(err.clone()); - to_return = Err(err); + self.symbols.insert(name.clone(), Err(err.clone())); + return Err(err); } } - self.symbols.insert(name, to_insert); - to_return } } + let reference = if self.symbol_needs_deref(name) { + format!("(*{})", escape_symbol(name)) + } else { + escape_symbol(name) + }; + Ok(reference) } /// Returns the concatenation of all successfully compiled symbols. @@ -71,7 +74,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } fn generate_code(&mut self, symbol: &str) -> Result { - if let Some(code) = self.try_generate_builtin(symbol) { + if let Some(code) = try_generate_builtin::(symbol) { return Ok(code); } @@ -85,22 +88,28 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let type_scheme = value.type_scheme.clone().unwrap(); - Ok(match type_scheme { - TypeScheme { - vars, - ty: - Type::Function(FunctionType { - params: param_types, - value: return_type, - }), - } => { + Ok(match (&value.e, type_scheme) { + ( + Expression::LambdaExpression(..), + TypeScheme { + vars, + ty: + Type::Function(FunctionType { + params: param_types, + value: return_type, + }), + }, + ) => { assert!(vars.is_empty()); self.try_format_function(symbol, ¶m_types, return_type.as_ref(), &value.e)? } - TypeScheme { - vars, - ty: Type::Col, - } => { + ( + Expression::LambdaExpression(..), + TypeScheme { + vars, + ty: Type::Col, + }, + ) => { assert!(vars.is_empty()); // TODO we assume it is an int -> int function. // The type inference algorithm should store the derived type. @@ -108,12 +117,26 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { // in the trait vars. self.try_format_function(symbol, &[Type::Int], &Type::Int, &value.e)? } - _ => format!( - "const {}: {} = {};\n", - escape_symbol(symbol), - map_type(&value.type_scheme.as_ref().unwrap().ty), - self.format_expr(&value.e)? - ), + _ => { + let type_scheme = value.type_scheme.as_ref().unwrap(); + assert!(type_scheme.vars.is_empty()); + let ty = if type_scheme.ty == Type::Col { + Type::Function(FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + }) + } else { + type_scheme.ty.clone() + }; + format!( + "lazy_static::lazy_static! {{\n\ + static ref {}: {} = {};\n\ + }}\n", + escape_symbol(symbol), + map_type(&ty), + self.format_expr(&value.e)? + ) + } }) } @@ -125,7 +148,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { expr: &Expression, ) -> Result { let Expression::LambdaExpression(_, LambdaExpression { params, body, .. }) = expr else { - return Err(format!("Expected lambda expression for {name}, got {expr}",)); + panic!(); }; Ok(format!( "fn {}({}) -> {} {{ {} }}\n", @@ -140,31 +163,16 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { )) } - fn try_generate_builtin(&self, symbol: &str) -> Option { - let code = match symbol { - "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), - "std::field::modulus" => { - let modulus = T::modulus(); - Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")) - } - "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" - .to_string()), - _ => None, - }?; - Some(format!("fn {}{code}", escape_symbol(symbol))) - } - fn format_expr(&mut self, e: &Expression) -> Result { Ok(match e { Expression::Reference(_, Reference::LocalVar(_id, name)) => name.clone(), Expression::Reference(_, Reference::Poly(PolynomialReference { name, type_args })) => { - self.request_symbol(name)?; + let reference = self.request_symbol(name)?; let ta = type_args.as_ref().unwrap(); format!( - "{}{}", - escape_symbol(name), + "{reference}{}", (!ta.is_empty()) - .then(|| format!("::{}", format_type_args(ta))) + .then(|| format!("::<{}>", ta.iter().map(map_type).join(", "))) .unwrap_or_default() ) } @@ -288,6 +296,19 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { "Compiling statements inside blocks is not yet implemented: {s}" )) } + + /// Returns true if a reference to the symbol needs a deref operation or not. + fn symbol_needs_deref(&self, symbol: &str) -> bool { + if is_builtin(symbol) { + return false; + } + let (_, def) = self.analyzed.definitions.get(symbol).as_ref().unwrap(); + if let Some(FunctionValueDefinition::Expression(typed_expr)) = def { + !matches!(typed_expr.e, Expression::LambdaExpression(..)) + } else { + false + } + } } pub fn escape_symbol(s: &str) -> String { @@ -304,7 +325,11 @@ fn map_type(ty: &Type) -> String { Type::Expr => "Expr".to_string(), Type::Array(ArrayType { base, length: _ }) => format!("Vec<{}>", map_type(base)), Type::Tuple(_) => todo!(), - Type::Function(ft) => todo!("Type {ft}"), + Type::Function(ft) => format!( + "fn({}) -> {}", + ft.params.iter().map(map_type).join(", "), + map_type(&ft.value) + ), Type::TypeVar(tv) => tv.to_string(), Type::NamedType(path, type_args) => { if type_args.is_some() { @@ -316,6 +341,28 @@ fn map_type(ty: &Type) -> String { } } +fn is_builtin(symbol: &str) -> bool { + matches!( + symbol, + "std::check::panic" | "std::field::modulus" | "std::convert::fe" + ) +} + +fn try_generate_builtin(symbol: &str) -> Option { + let code = match symbol { + "std::array::len" => "(a: Vec) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string(), + "std::check::panic" => "(s: &str) -> ! { panic!(\"{s}\"); }".to_string(), + "std::field::modulus" => { + let modulus = T::modulus(); + format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}") + } + "std::convert::fe" => "(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" + .to_string(), + _ => return None, + }; + Some(format!("fn {}{code}", escape_symbol(symbol))) +} + #[cfg(test)] mod test { use powdr_number::GoldilocksField; diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index a92476b61..8e4eaf7e6 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -71,6 +71,7 @@ crate-type = ["cdylib"] [dependencies] ibig = { version = "0.3.6", features = [], default-features = false } +lazy_static = "1.4.0" "#; pub struct PathInTempDir { diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index f6d80c7be..4aa373bcf 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -44,3 +44,26 @@ fn sqrt() { fn invalid_function() { let _ = compile("let c: int -> bool = |i| true;", "c"); } + +#[test] +fn assigned_functions() { + let input = r#" + namespace std::array; + let len = 8; + namespace main; + let a: int -> int = |i| i + 1; + let b: int -> int = |i| i + 2; + let t: bool = "" == ""; + let c = if t { a } else { b }; + let d = |i| c(i); + "#; + let c = compile(input, "main::c"); + + assert_eq!(c.call(0), 1); + assert_eq!(c.call(1), 2); + assert_eq!(c.call(2), 3); + assert_eq!(c.call(3), 4); + + let d = compile(input, "main::d"); + assert_eq!(d.call(0), 1); +} From f5f84f31df0c252770e11f46d980c41f998e5440 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 10:05:08 +0000 Subject: [PATCH 48/57] Extract magic numbers. --- pipeline/benches/evaluator_benchmark.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pipeline/benches/evaluator_benchmark.rs b/pipeline/benches/evaluator_benchmark.rs index a55522dc6..fa012682e 100644 --- a/pipeline/benches/evaluator_benchmark.rs +++ b/pipeline/benches/evaluator_benchmark.rs @@ -19,6 +19,14 @@ const SQRT_CODE: &str = " }; "; +/// Just some numbers to test the sqrt function on. +fn sqrt_inputs() -> Vec<(String, u64)> { + [879882356, 1882356, 1187956, 56] + .into_iter() + .map(|x| (x.to_string(), (x as u64) * 112655675_u64)) + .collect() +} + fn evaluator_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("evaluator-benchmark"); @@ -81,11 +89,10 @@ fn evaluator_benchmark(c: &mut Criterion) { pipeline.compute_analyzed_pil().unwrap().clone() }; - for x in [879882356, 1882356, 1187956, 56] { - group.bench_with_input(format!("sqrt_{x}"), &x, |b, &x| { + for (name, val) in sqrt_inputs() { + group.bench_with_input(format!("sqrt_{name}"), &val, |b, val| { b.iter(|| { - let y = BigInt::from(x) * BigInt::from(112655675); - evaluate_integer_function(&sqrt_analyzed, "sqrt", vec![y.clone()]); + evaluate_integer_function(&sqrt_analyzed, "sqrt", vec![BigInt::from(*val)]); }); }); } @@ -124,11 +131,10 @@ fn jit_benchmark(c: &mut Criterion) { let sqrt_fun = &powdr_jit_compiler::compile(&sqrt_analyzed, &["sqrt"]).unwrap()["sqrt"]; - for x in [879882356, 1882356, 1187956, 56] { - group.bench_with_input(format!("sqrt_{x}"), &x, |b, &x| { + for (name, val) in sqrt_inputs() { + group.bench_with_input(format!("sqrt_{name}"), &val, |b, val| { b.iter(|| { - let y = (x as u64) * 112655675; - sqrt_fun.call(y); + sqrt_fun.call(*val); }); }); } From bd515f120d62c72c69af1f13e6f52cbba7b05b42 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 26 Sep 2024 12:39:44 +0000 Subject: [PATCH 49/57] Make function safe. --- jit-compiler/src/compiler.rs | 2 +- jit-compiler/src/lib.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index a92476b61..f5bf2d56a 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -130,7 +130,7 @@ pub fn load_library( .map(|&sym| { let extern_sym = extern_symbol_name(sym); let function = - *unsafe { library.get:: u64>(extern_sym.as_bytes()) } + *unsafe { library.get:: u64>(extern_sym.as_bytes()) } .map_err(|e| format!("Error accessing symbol {sym}: {e}"))?; let fun = LoadedFunction { library: library.clone(), diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index fc72d3789..56ce8afd3 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -15,12 +15,12 @@ use powdr_number::FieldElement; pub struct LoadedFunction { #[allow(dead_code)] library: Arc, - function: unsafe extern "C" fn(u64) -> u64, + function: extern "C" fn(u64) -> u64, } impl LoadedFunction { pub fn call(&self, arg: u64) -> u64 { - unsafe { (self.function)(arg) } + (self.function)(arg) } } From 3e9c60eaba173ccd065f0cfa0a8d71642d189ead Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 27 Sep 2024 08:30:28 +0000 Subject: [PATCH 50/57] Implement field element type. --- jit-compiler/src/codegen.rs | 6 +++--- jit-compiler/src/compiler.rs | 30 +++++++++++++++++++++++------- jit-compiler/tests/execution.rs | 27 ++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index bbc1c9019..a4a981118 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -111,11 +111,11 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { }, ) => { assert!(vars.is_empty()); - // TODO we assume it is an int -> int function. + // TODO we assume it is an int -> fe function. // The type inference algorithm should store the derived type. // Alternatively, we insert a trait conversion function and store the type // in the trait vars. - self.try_format_function(symbol, &[Type::Int], &Type::Int, &value.e)? + self.try_format_function(symbol, &[Type::Int], &Type::Fe, &value.e)? } _ => { let type_scheme = value.type_scheme.as_ref().unwrap(); @@ -344,7 +344,7 @@ fn map_type(ty: &Type) -> String { fn is_builtin(symbol: &str) -> bool { matches!( symbol, - "std::check::panic" | "std::field::modulus" | "std::convert::fe" + "std::array::len" | "std::check::panic" | "std::field::modulus" | "std::convert::fe" ) } diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 5c3a72b4d..9bc67b64f 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -18,17 +18,16 @@ use powdr_number::FieldElement; use crate::{codegen::escape_symbol, LoadedFunction}; -// TODO make this depend on T - -const PREAMBLE: &str = r#" -#![allow(unused_parens)] -//type FieldElement = powdr_number::GoldilocksField; -"#; - pub fn generate_glue_code( symbols: &[&str], analyzed: &Analyzed, ) -> Result { + if T::BITS > 64 { + return Err(format!( + "Fields with more than 64 bits not supported, but requested {}", + T::BITS, + )); + } let mut glue = String::new(); let int_int_fun: TypeScheme = Type::Function(FunctionType { params: vec![Type::Int], @@ -60,6 +59,23 @@ pub fn generate_glue_code( Ok(format!("{PREAMBLE}\n{glue}\n",)) } +const PREAMBLE: &str = r#" +#![allow(unused_parens)] + +#[derive(Clone, Copy)] +struct FieldElement(u64); +impl From for FieldElement { + fn from(x: u64) -> Self { + FieldElement(x) + } +} +impl From for u64 { + fn from(x: FieldElement) -> u64 { + x.0 + } +} +"#; + const CARGO_TOML: &str = r#" [package] name = "powdr_jit_compiled" diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index 4aa373bcf..e58b43f8f 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -6,7 +6,13 @@ use powdr_pil_analyzer::analyze_string; fn compile(input: &str, symbol: &str) -> LoadedFunction { let analyzed = analyze_string::(input).unwrap(); - powdr_jit_compiler::compile(&analyzed, &[symbol]).unwrap()[symbol].clone() + powdr_jit_compiler::compile(&analyzed, &[symbol]) + .map_err(|e| { + eprintln!("Error jit-compiling:\n{e}"); + e + }) + .unwrap()[symbol] + .clone() } #[test] @@ -67,3 +73,22 @@ fn assigned_functions() { let d = compile(input, "main::d"); assert_eq!(d.call(0), 1); } + +#[test] +fn simple_field() { + let f = compile( + " + namespace std::array; + let len = 8; + namespace main; + let a: fe[] = [1, 2, 3]; + let q: col = |i| a[i % std::array::len(a)]; + ", + "main::q", + ); + + assert_eq!(f.call(0), 1); + assert_eq!(f.call(1), 2); + assert_eq!(f.call(2), 3); + assert_eq!(f.call(3), 1); +} From e3dc58901a56eb2ae9cf2c63ccc5d02b230787bf Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 27 Sep 2024 08:57:02 +0000 Subject: [PATCH 51/57] Re-organize builtins. --- jit-compiler/Cargo.toml | 1 + jit-compiler/src/codegen.rs | 63 ++++++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/jit-compiler/Cargo.toml b/jit-compiler/Cargo.toml index 5be80611b..386a68ac6 100644 --- a/jit-compiler/Cargo.toml +++ b/jit-compiler/Cargo.toml @@ -16,6 +16,7 @@ log = "0.4.18" mktemp = "0.5.0" itertools = "0.13" libloading = "0.8" +lazy_static = "1.4.0" [dev-dependencies] powdr-pil-analyzer.workspace = true diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index a4a981118..1dda475d0 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -1,4 +1,6 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::OnceLock}; + +use lazy_static::lazy_static; use itertools::Itertools; use powdr_ast::{ @@ -10,7 +12,7 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::FieldElement; +use powdr_number::{FieldElement, LargeInt}; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, @@ -75,7 +77,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { fn generate_code(&mut self, symbol: &str) -> Result { if let Some(code) = try_generate_builtin::(symbol) { - return Ok(code); + return Ok(code.clone()); } let Some((_, Some(FunctionValueDefinition::Expression(value)))) = @@ -299,7 +301,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { /// Returns true if a reference to the symbol needs a deref operation or not. fn symbol_needs_deref(&self, symbol: &str) -> bool { - if is_builtin(symbol) { + if is_builtin::(symbol) { return false; } let (_, def) = self.analyzed.definitions.get(symbol).as_ref().unwrap(); @@ -341,26 +343,43 @@ fn map_type(ty: &Type) -> String { } } -fn is_builtin(symbol: &str) -> bool { - matches!( - symbol, - "std::array::len" | "std::check::panic" | "std::field::modulus" | "std::convert::fe" - ) +fn get_builtins() -> &'static HashMap { + static BUILTINS: OnceLock> = OnceLock::new(); + BUILTINS.get_or_init(|| { + [ + ( + "std::array::len", + "(a: Vec) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string(), + ), + ( + "std::check::panic", + "(s: &str) -> ! { panic!(\"{s}\"); }".to_string(), + ), + ( + "std::field::modulus", + format!( + "() -> ibig::IBig {{ {} }}", + format_number(T::modulus().to_arbitrary_integer()) + ), + ), + ] + .into_iter() + .map(|(name, code)| { + ( + name.to_string(), + format!("fn {}{code}", escape_symbol(name)), + ) + }) + .collect() + }) } -fn try_generate_builtin(symbol: &str) -> Option { - let code = match symbol { - "std::array::len" => "(a: Vec) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string(), - "std::check::panic" => "(s: &str) -> ! { panic!(\"{s}\"); }".to_string(), - "std::field::modulus" => { - let modulus = T::modulus(); - format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}") - } - "std::convert::fe" => "(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" - .to_string(), - _ => return None, - }; - Some(format!("fn {}{code}", escape_symbol(symbol))) +fn is_builtin(symbol: &str) -> bool { + get_builtins::().contains_key(symbol) +} + +fn try_generate_builtin(symbol: &str) -> Option<&String> { + get_builtins::().get(symbol) } #[cfg(test)] From d1ee4bd9c1afdb7df04a3c673f1c76414557798d Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 27 Sep 2024 09:08:16 +0000 Subject: [PATCH 52/57] Format numbers properly. --- jit-compiler/src/codegen.rs | 43 ++++++++++++++++++++++++--------- jit-compiler/tests/execution.rs | 7 ++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 99b9367e7..f23f39feb 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -10,7 +10,7 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::FieldElement; +use powdr_number::{BigUint, FieldElement, LargeInt}; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, @@ -20,6 +20,11 @@ pub struct CodeGenerator<'a, T> { symbols: HashMap, String>>, } +pub fn escape_symbol(s: &str) -> String { + // TODO better escaping + s.replace('.', "_").replace("::", "_") +} + impl<'a, T: FieldElement> CodeGenerator<'a, T> { pub fn new(analyzed: &'a Analyzed) -> Self { Self { @@ -144,8 +149,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let code = match symbol { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { - let modulus = T::modulus(); - Some(format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}")) + Some(format!("() -> ibig::IBig {{ {} }}", format_number(&T::modulus().to_arbitrary_integer()))) } "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), @@ -175,12 +179,15 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { type_: Some(type_), }, ) => { - let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); - match type_ { - Type::Int => format!("ibig::IBig::from({value}_u64)"), - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - _ => unreachable!(), + if *type_ == Type::Int { + format_number(value) + } else { + let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); + match type_ { + Type::Fe => format!("FieldElement::from({value}_u64)"), + Type::Expr => format!("Expr::from({value}_u64)"), + _ => unreachable!(), + } } } Expression::FunctionCall( @@ -211,6 +218,9 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { BinaryOperator::ShiftLeft => { format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") } + BinaryOperator::ShiftRight => { + format!("(({left}).clone() >> usize::try_from(({right}).clone()).unwrap())") + } _ => format!("(({left}).clone() {op} ({right}).clone())"), } } @@ -290,9 +300,18 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } -pub fn escape_symbol(s: &str) -> String { - // TODO better escaping - s.replace('.', "_").replace("::", "_") +fn format_number(n: &BigUint) -> String { + if let Ok(n) = u64::try_from(n) { + format!("ibig::IBig::from({n}_u64)") + } else { + format!( + "ibig::IBig::from(ibig::UBig::from_le_bytes(&[{}]))", + n.to_le_bytes() + .iter() + .map(|b| format!("{b}_u8")) + .format(", ") + ) + } } fn map_type(ty: &Type) -> String { diff --git a/jit-compiler/tests/execution.rs b/jit-compiler/tests/execution.rs index f6d80c7be..2d57a9675 100644 --- a/jit-compiler/tests/execution.rs +++ b/jit-compiler/tests/execution.rs @@ -44,3 +44,10 @@ fn sqrt() { fn invalid_function() { let _ = compile("let c: int -> bool = |i| true;", "c"); } + +#[test] +fn gigantic_number() { + let f = compile("let c: int -> int = |i| (i * 0x1000000000000000000000000000000000000000000000000000000000000000000000000000000000) >> (81 * 4);", "c"); + + assert_eq!(f.call(10), 10); +} From 2e371c8492590163dc4820b648ccff157e18bc34 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 27 Sep 2024 12:24:53 +0000 Subject: [PATCH 53/57] Rename function and use single match. --- jit-compiler/src/codegen.rs | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index f23f39feb..614d9cffe 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -149,7 +149,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let code = match symbol { "std::check::panic" => Some("(s: &str) -> ! { panic!(\"{s}\"); }".to_string()), "std::field::modulus" => { - Some(format!("() -> ibig::IBig {{ {} }}", format_number(&T::modulus().to_arbitrary_integer()))) + Some(format!("() -> ibig::IBig {{ {} }}", format_unsigned_integer(&T::modulus().to_arbitrary_integer()))) } "std::convert::fe" => Some("(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string()), @@ -178,18 +178,20 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { value, type_: Some(type_), }, - ) => { - if *type_ == Type::Int { - format_number(value) - } else { - let value = u64::try_from(value).unwrap_or_else(|_| unimplemented!()); - match type_ { - Type::Fe => format!("FieldElement::from({value}_u64)"), - Type::Expr => format!("Expr::from({value}_u64)"), - _ => unreachable!(), - } + ) => match type_ { + Type::Int => format_unsigned_integer(value), + Type::Fe => { + let val = u64::try_from(value) + .map_err(|_| "Large numbers for fe not yet implemented.".to_string())?; + format!("FieldElement::from({val}_u64)",) } - } + Type::Expr => { + let val = u64::try_from(value) + .map_err(|_| "Large numbers for fe not yet implemented.".to_string())?; + format!("Expr::from({val}_u64)") + } + _ => unreachable!(), + }, Expression::FunctionCall( _, FunctionCall { @@ -300,7 +302,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } } -fn format_number(n: &BigUint) -> String { +fn format_unsigned_integer(n: &BigUint) -> String { if let Ok(n) = u64::try_from(n) { format!("ibig::IBig::from({n}_u64)") } else { From 05be1a4c46039de79bf62f26ea017654cb387777 Mon Sep 17 00:00:00 2001 From: chriseth Date: Fri, 27 Sep 2024 12:25:27 +0000 Subject: [PATCH 54/57] typo --- jit-compiler/src/codegen.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 614d9cffe..81126e296 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -187,7 +187,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { } Type::Expr => { let val = u64::try_from(value) - .map_err(|_| "Large numbers for fe not yet implemented.".to_string())?; + .map_err(|_| "Large numbers for expr not yet implemented.".to_string())?; format!("Expr::from({val}_u64)") } _ => unreachable!(), From 8cd38333ad244985a37a1f579d453b1fa438ad8b Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 30 Sep 2024 10:48:23 +0000 Subject: [PATCH 55/57] fix shift --- jit-compiler/src/codegen.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 936e3102b..97691018f 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -10,7 +10,7 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::{BigUint, FieldElement, LargeInt}; +use powdr_number::{BigUint, FieldElement}; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, @@ -210,7 +210,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { let right = self.format_expr(right)?; match op { BinaryOperator::ShiftLeft => { - format!("(({left}).clone() << u32::try_from(({right}).clone()).unwrap())") + format!("(({left}).clone() << usize::try_from(({right}).clone()).unwrap())") } BinaryOperator::ShiftRight => { format!("(({left}).clone() >> usize::try_from(({right}).clone()).unwrap())") From a00d4f40aeac87fbcd14c6193097cbfdb8dbeb49 Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 30 Sep 2024 10:57:51 +0000 Subject: [PATCH 56/57] fix merge --- jit-compiler/src/codegen.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 97691018f..13b567780 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -10,7 +10,7 @@ use powdr_ast::{ IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation, }, }; -use powdr_number::{BigUint, FieldElement}; +use powdr_number::{BigUint, FieldElement, LargeInt}; pub struct CodeGenerator<'a, T> { analyzed: &'a Analyzed, @@ -292,6 +292,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> { "Compiling statements inside blocks is not yet implemented: {s}" )) } + /// Returns a string expression evaluating to the value of the symbol. /// This is either the escaped name of the symbol or a deref operator /// applied to it. @@ -365,8 +366,7 @@ fn try_generate_builtin(symbol: &str) -> Option { "std::array::len" => "(a: Vec) -> ibig::IBig { ibig::IBig::from(a.len()) }".to_string(), "std::check::panic" => "(s: &str) -> ! { panic!(\"{s}\"); }".to_string(), "std::field::modulus" => { - let modulus = T::modulus(); - format!("() -> ibig::IBig {{ ibig::IBig::from(\"{modulus}\") }}") + format!("() -> ibig::IBig {{ {} }}", format_unsigned_integer(&T::modulus().to_arbitrary_integer())) } "std::convert::fe" => "(n: ibig::IBig) -> FieldElement {\n ::BigInt::try_from(n.to_biguint().unwrap()).unwrap().into()\n}" .to_string(), From 3e4926f42c6a686f4fd4da9f78a5b28be5b496ea Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 30 Sep 2024 13:08:22 +0000 Subject: [PATCH 57/57] remove parts from other pr --- jit-compiler/src/compiler.rs | 30 +++++++----------------------- 1 file changed, 7 insertions(+), 23 deletions(-) diff --git a/jit-compiler/src/compiler.rs b/jit-compiler/src/compiler.rs index 9bc67b64f..5c3a72b4d 100644 --- a/jit-compiler/src/compiler.rs +++ b/jit-compiler/src/compiler.rs @@ -18,16 +18,17 @@ use powdr_number::FieldElement; use crate::{codegen::escape_symbol, LoadedFunction}; +// TODO make this depend on T + +const PREAMBLE: &str = r#" +#![allow(unused_parens)] +//type FieldElement = powdr_number::GoldilocksField; +"#; + pub fn generate_glue_code( symbols: &[&str], analyzed: &Analyzed, ) -> Result { - if T::BITS > 64 { - return Err(format!( - "Fields with more than 64 bits not supported, but requested {}", - T::BITS, - )); - } let mut glue = String::new(); let int_int_fun: TypeScheme = Type::Function(FunctionType { params: vec![Type::Int], @@ -59,23 +60,6 @@ pub fn generate_glue_code( Ok(format!("{PREAMBLE}\n{glue}\n",)) } -const PREAMBLE: &str = r#" -#![allow(unused_parens)] - -#[derive(Clone, Copy)] -struct FieldElement(u64); -impl From for FieldElement { - fn from(x: u64) -> Self { - FieldElement(x) - } -} -impl From for u64 { - fn from(x: FieldElement) -> u64 { - x.0 - } -} -"#; - const CARGO_TOML: &str = r#" [package] name = "powdr_jit_compiled"