Skip to content

Commit

Permalink
Try to generate constant values via jit.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Sep 27, 2024
1 parent 1746d8c commit 5b54d2d
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 166 deletions.
1 change: 1 addition & 0 deletions executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ powdr-ast.workspace = true
powdr-number.workspace = true
powdr-parser-util.workspace = true
powdr-pil-analyzer.workspace = true
powdr-jit-compiler.workspace = true

itertools = "0.13"
log = { version = "0.4.17" }
Expand Down
158 changes: 158 additions & 0 deletions executor/src/constant_evaluator/interpreter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, RwLock},
};

use powdr_ast::{
analyzed::{Analyzed, Expression, FunctionValueDefinition, Symbol, TypedExpression},
parsed::{
types::{ArrayType, Type},
IndexAccess,
},
};
use powdr_number::{BigInt, BigUint, DegreeType, FieldElement};
use powdr_pil_analyzer::evaluator::{self, Definitions, SymbolLookup, Value};
use rayon::iter::{IntoParallelIterator, ParallelIterator};

/// Evaluates the fixed polynomial `name` on all values from 0 to `degree - 1`
/// using an interpreter.
/// If `index` is `Some(i)`, evaluates the `i`-th element of the array.
pub fn generate_values<T: FieldElement>(
analyzed: &Analyzed<T>,
degree: DegreeType,
name: &str,
body: &FunctionValueDefinition,
index: Option<u64>,
) -> Vec<T> {
let symbols = CachedSymbols {
symbols: &analyzed.definitions,
solved_impls: &analyzed.solved_impls,
cache: Arc::new(RwLock::new(Default::default())),
degree,
};
let result = match body {
FunctionValueDefinition::Expression(TypedExpression { e, type_scheme }) => {
if let Some(type_scheme) = type_scheme {
assert!(type_scheme.vars.is_empty());
let ty = &type_scheme.ty;
if ty == &Type::Col {
assert!(index.is_none());
} else if let Type::Array(ArrayType { base, length: _ }) = ty {
assert!(index.is_some());
assert_eq!(base.as_ref(), &Type::Col);
} else {
panic!("Invalid fixed column type: {ty}");
}
};
let index_expr;
let e = if let Some(index) = index {
index_expr = IndexAccess {
array: e.clone().into(),
index: Box::new(BigUint::from(index).into()),
}
.into();
&index_expr
} else {
e
};
let fun = evaluator::evaluate(e, &mut symbols.clone()).unwrap();
(0..degree)
.into_par_iter()
.map(|i| {
evaluator::evaluate_function_call(
fun.clone(),
vec![Arc::new(Value::Integer(BigInt::from(i)))],
&mut symbols.clone(),
)?
.try_to_field_element()
})
.collect::<Result<Vec<_>, _>>()
}
FunctionValueDefinition::Array(values) => {
assert!(index.is_none());
values
.to_repeated_arrays(degree)
.map(|elements| {
let items = elements
.pattern()
.iter()
.map(|v| {
let mut symbols = symbols.clone();
evaluator::evaluate(v, &mut symbols)
.and_then(|v| v.try_to_field_element())
})
.collect::<Result<Vec<_>, _>>()?;

Ok(items
.into_iter()
.cycle()
.take(elements.size() as usize)
.collect::<Vec<_>>())
})
.collect::<Result<Vec<_>, _>>()
.map(|values| {
let values: Vec<T> = values.into_iter().flatten().collect();
assert_eq!(values.len(), degree as usize);
values
})
}
FunctionValueDefinition::TypeDeclaration(_)
| FunctionValueDefinition::TypeConstructor(_, _)
| FunctionValueDefinition::TraitDeclaration(_)
| FunctionValueDefinition::TraitFunction(_, _) => panic!(),
};
match result {
Err(err) => {
eprintln!("Error evaluating fixed polynomial {name}{body}:\n{err}");
panic!("{err}");
}
Ok(v) => v,
}
}

type SymbolCache<'a, T> = HashMap<String, BTreeMap<Option<Vec<Type>>, Arc<Value<'a, T>>>>;

#[derive(Clone)]
pub struct CachedSymbols<'a, T> {
symbols: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition>)>,
solved_impls: &'a HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
cache: Arc<RwLock<SymbolCache<'a, T>>>,
degree: DegreeType,
}

impl<'a, T: FieldElement> SymbolLookup<'a, T> for CachedSymbols<'a, T> {
fn lookup(
&mut self,
name: &'a str,
type_args: &Option<Vec<Type>>,
) -> Result<Arc<Value<'a, T>>, evaluator::EvalError> {
if let Some(v) = self
.cache
.read()
.unwrap()
.get(name)
.and_then(|map| map.get(type_args))
{
return Ok(v.clone());
}
let result = Definitions::lookup_with_symbols(
self.symbols,
self.solved_impls,
name,
type_args,
self,
)?;
self.cache
.write()
.unwrap()
.entry(name.to_string())
.or_default()
.entry(type_args.clone())
.or_insert_with(|| result.clone());
Ok(result)
}

fn degree(&self) -> Result<Arc<Value<'a, T>>, evaluator::EvalError> {
Ok(Value::Integer(self.degree.into()).into())
}
}
63 changes: 63 additions & 0 deletions executor/src/constant_evaluator/jit_compiler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::collections::HashMap;

use powdr_ast::analyzed::{Analyzed, PolyID};
use powdr_number::FieldElement;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::VariablySizedColumn;

/// Tries to just-in-time compile all fixed columns in `analyzed`
/// and then evaluates those functions to return `VariablySizedColumn`s.
/// Ignoreds all columns where the compilation fails.
pub fn generate_values<T: FieldElement>(
analyzed: &Analyzed<T>,
) -> HashMap<(String, PolyID), VariablySizedColumn<T>> {
let fun_map = match powdr_jit_compiler::compile(analyzed, &symbols_to_compile(analyzed)) {
Err(err) => {
log::error!("Failed to compile some constant columns: {}", err);
return HashMap::new();
}
Ok(fun_map) => fun_map,
};

analyzed
.constant_polys_in_source_order()
.into_iter()
.filter_map(|(symbol, _)| {
let fun = fun_map.get(symbol.absolute_name.as_str())?;
Some((symbol, fun))
})
.map(|(symbol, fun)| {
let column_values: Vec<Vec<T>> = symbol
.degree
.unwrap()
.iter()
.map(|degree| {
let values = (0..degree)
.into_par_iter()
.map(|i| {
let result = fun.call(i as u64);
T::from(result)
})
.collect();
values
})
.collect();

(
(symbol.absolute_name.clone(), symbol.into()),
column_values.into(),
)
})
.collect()
}

fn symbols_to_compile<T>(analyzed: &Analyzed<T>) -> Vec<&str> {
analyzed
.constant_polys_in_source_order()
.into_iter()
.filter_map(|(symbol, value)| {
(!symbol.is_array() && value.is_some()).then_some(symbol.absolute_name.as_str())
})
.collect()
}
Loading

0 comments on commit 5b54d2d

Please sign in to comment.