-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Try to generate constant values via jit.
- Loading branch information
Showing
4 changed files
with
241 additions
and
166 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
use std::{ | ||
collections::{BTreeMap, HashMap}, | ||
sync::{Arc, RwLock}, | ||
}; | ||
|
||
use powdr_ast::{ | ||
analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression}, | ||
parsed::{ | ||
types::{ArrayType, Type}, | ||
IndexAccess, | ||
}, | ||
}; | ||
use powdr_number::{BigInt, BigUint, DegreeType, FieldElement}; | ||
use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value}; | ||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; | ||
|
||
/// Evaluates the fixed polynomial `name` on all values from 0 to `degree - 1` | ||
/// using an interpreter. | ||
/// If `index` is `Some(i)`, evaluates the `i`-th element of the array. | ||
pub fn generate_values<T: FieldElement>( | ||
analyzed: &Analyzed<T>, | ||
degree: DegreeType, | ||
name: &str, | ||
body: &FunctionValueDefinition, | ||
index: Option<u64>, | ||
) -> Vec<T> { | ||
let symbols = CachedSymbols { | ||
symbols: &analyzed.definitions, | ||
solved_impls: &analyzed.solved_impls, | ||
cache: Arc::new(RwLock::new(Default::default())), | ||
degree, | ||
}; | ||
let result = match body { | ||
FunctionValueDefinition::Expression(TypedExpression { e, type_scheme }) => { | ||
if let Some(type_scheme) = type_scheme { | ||
assert!(type_scheme.vars.is_empty()); | ||
let ty = &type_scheme.ty; | ||
if ty == &Type::Col { | ||
assert!(index.is_none()); | ||
} else if let Type::Array(ArrayType { base, length: _ }) = ty { | ||
assert!(index.is_some()); | ||
assert_eq!(base.as_ref(), &Type::Col); | ||
} else { | ||
panic!("Invalid fixed column type: {ty}"); | ||
} | ||
}; | ||
let index_expr; | ||
let e = if let Some(index) = index { | ||
index_expr = IndexAccess { | ||
array: e.clone().into(), | ||
index: Box::new(BigUint::from(index).into()), | ||
} | ||
.into(); | ||
&index_expr | ||
} else { | ||
e | ||
}; | ||
let fun = evaluator::evaluate(e, &mut symbols.clone()).unwrap(); | ||
(0..degree) | ||
.into_par_iter() | ||
.map(|i| { | ||
evaluator::evaluate_function_call( | ||
fun.clone(), | ||
vec![Arc::new(Value::Integer(BigInt::from(i)))], | ||
&mut symbols.clone(), | ||
)? | ||
.try_to_field_element() | ||
}) | ||
.collect::<Result<Vec<_>, _>>() | ||
} | ||
FunctionValueDefinition::Array(values) => { | ||
assert!(index.is_none()); | ||
values | ||
.to_repeated_arrays(degree) | ||
.map(|elements| { | ||
let items = elements | ||
.pattern() | ||
.iter() | ||
.map(|v| { | ||
let mut symbols = symbols.clone(); | ||
evaluator::evaluate(v, &mut symbols) | ||
.and_then(|v| v.try_to_field_element()) | ||
}) | ||
.collect::<Result<Vec<_>, _>>()?; | ||
|
||
Ok(items | ||
.into_iter() | ||
.cycle() | ||
.take(elements.size() as usize) | ||
.collect::<Vec<_>>()) | ||
}) | ||
.collect::<Result<Vec<_>, _>>() | ||
.map(|values| { | ||
let values: Vec<T> = values.into_iter().flatten().collect(); | ||
assert_eq!(values.len(), degree as usize); | ||
values | ||
}) | ||
} | ||
FunctionValueDefinition::TypeDeclaration(_) | ||
| FunctionValueDefinition::TypeConstructor(_, _) | ||
| FunctionValueDefinition::TraitDeclaration(_) | ||
| FunctionValueDefinition::TraitFunction(_, _) => panic!(), | ||
}; | ||
match result { | ||
Err(err) => { | ||
eprintln!("Error evaluating fixed polynomial {name}{body}:\n{err}"); | ||
panic!("{err}"); | ||
} | ||
Ok(v) => v, | ||
} | ||
} | ||
|
||
type SymbolCache<'a, T> = HashMap<String, BTreeMap<Option<Vec<Type>>, Arc<Value<'a, T>>>>; | ||
|
||
#[derive(Clone)] | ||
pub struct CachedSymbols<'a, T> { | ||
symbols: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>, | ||
solved_impls: &'a HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>, | ||
cache: Arc<RwLock<SymbolCache<'a, T>>>, | ||
degree: DegreeType, | ||
} | ||
|
||
impl<'a, T: FieldElement> SymbolLookup<'a, T> for CachedSymbols<'a, T> { | ||
fn lookup( | ||
&mut self, | ||
name: &'a str, | ||
type_args: &Option<Vec<Type>>, | ||
) -> Result<Arc<Value<'a, T>>, evaluator::EvalError> { | ||
if let Some(v) = self | ||
.cache | ||
.read() | ||
.unwrap() | ||
.get(name) | ||
.and_then(|map| map.get(type_args)) | ||
{ | ||
return Ok(v.clone()); | ||
} | ||
let result = Definitions::lookup_with_symbols( | ||
self.symbols, | ||
self.solved_impls, | ||
name, | ||
type_args, | ||
self, | ||
)?; | ||
self.cache | ||
.write() | ||
.unwrap() | ||
.entry(name.to_string()) | ||
.or_default() | ||
.entry(type_args.clone()) | ||
.or_insert_with(|| result.clone()); | ||
Ok(result) | ||
} | ||
|
||
fn degree(&self) -> Result<Arc<Value<'a, T>>, evaluator::EvalError> { | ||
Ok(Value::Integer(self.degree.into()).into()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
use std::collections::HashMap; | ||
|
||
use powdr_ast::analyzed::{Analyzed, PolyID}; | ||
use powdr_number::FieldElement; | ||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; | ||
|
||
use super::VariablySizedColumn; | ||
|
||
/// Tries to just-in-time compile all fixed columns in `analyzed` | ||
/// and then evaluates those functions to return `VariablySizedColumn`s. | ||
/// Ignoreds all columns where the compilation fails. | ||
pub fn generate_values<T: FieldElement>( | ||
analyzed: &Analyzed<T>, | ||
) -> HashMap<(String, PolyID), VariablySizedColumn<T>> { | ||
let fun_map = match powdr_jit_compiler::compile(analyzed, &symbols_to_compile(analyzed)) { | ||
Err(err) => { | ||
log::error!("Failed to compile some constant columns: {}", err); | ||
return HashMap::new(); | ||
} | ||
Ok(fun_map) => fun_map, | ||
}; | ||
|
||
analyzed | ||
.constant_polys_in_source_order() | ||
.into_iter() | ||
.filter_map(|(symbol, _)| { | ||
let fun = fun_map.get(symbol.absolute_name.as_str())?; | ||
Some((symbol, fun)) | ||
}) | ||
.map(|(symbol, fun)| { | ||
let column_values: Vec<Vec<T>> = symbol | ||
.degree | ||
.unwrap() | ||
.iter() | ||
.map(|degree| { | ||
let values = (0..degree) | ||
.into_par_iter() | ||
.map(|i| { | ||
let result = fun.call(i as u64); | ||
T::from(result) | ||
}) | ||
.collect(); | ||
values | ||
}) | ||
.collect(); | ||
|
||
( | ||
(symbol.absolute_name.clone(), symbol.into()), | ||
column_values.into(), | ||
) | ||
}) | ||
.collect() | ||
} | ||
|
||
fn symbols_to_compile<T>(analyzed: &Analyzed<T>) -> Vec<&str> { | ||
analyzed | ||
.constant_polys_in_source_order() | ||
.into_iter() | ||
.filter_map(|(symbol, value)| { | ||
(!symbol.is_array() && value.is_some()).then_some(symbol.absolute_name.as_str()) | ||
}) | ||
.collect() | ||
} |
Oops, something went wrong.