Skip to content

Commit

Permalink
Refactor asm ast (#1736)
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff authored Sep 24, 2024
1 parent 5739ebc commit 64e7bf1
Show file tree
Hide file tree
Showing 30 changed files with 679 additions and 628 deletions.
64 changes: 37 additions & 27 deletions airgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
use std::collections::BTreeMap;

use powdr_ast::{
asm_analysis::{self, combine_flags, AnalysisASMFile, Item, LinkDefinition},
object::{
Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph, TypeOrExpression,
},
asm_analysis::{self, combine_flags, AnalysisASMFile, LinkDefinition},
object::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph},
parsed::{
asm::{parse_absolute_path, AbsoluteSymbolPath, CallableRef, MachineParams},
Expression, PilStatement,
Expand Down Expand Up @@ -47,15 +45,15 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
main,
entry_points: Default::default(),
objects: [(main_location, Default::default())].into(),
definitions: utility_functions(input),
statements: utility_functions(input),
};
}
// if there is a single machine, treat it as main
1 => (*non_std_non_rom_machines.keys().next().unwrap()).clone(),
// otherwise, use the machine called `MAIN`
_ => {
let p = parse_absolute_path(MAIN_MACHINE);
assert!(input.items.contains_key(&p));
assert!(input.get_machine(&p).is_some());
p
}
};
Expand All @@ -67,7 +65,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
let mut instances = BTreeMap::default();

while let Some((location, ty, args)) = queue.pop() {
let machine = input.items.get(&ty).unwrap().try_to_machine().unwrap();
let machine = &input.get_machine(&ty).unwrap();

queue.extend(machine.submachines.iter().map(|def| {
(
Expand Down Expand Up @@ -127,9 +125,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
}
}

let Item::Machine(main_ty) = input.items.get(&main_ty).unwrap() else {
panic!()
};
let main_ty = &input.get_machine(&main_ty).unwrap();

let main = powdr_ast::object::Machine {
location: main_location,
Expand All @@ -150,7 +146,7 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
main,
entry_points,
objects,
definitions: utility_functions(input),
statements: utility_functions(input),
}
}

Expand Down Expand Up @@ -178,14 +174,32 @@ fn resolve_submachine_arg(
}
}

fn utility_functions(asm_file: AnalysisASMFile) -> BTreeMap<AbsoluteSymbolPath, TypeOrExpression> {
fn utility_functions(asm_file: AnalysisASMFile) -> BTreeMap<AbsoluteSymbolPath, Vec<PilStatement>> {
asm_file
.items
.modules
.into_iter()
.filter_map(|(n, v)| match v {
Item::Expression(e) => Some((n, TypeOrExpression::Expression(e))),
Item::TypeDeclaration(type_decl) => Some((n, TypeOrExpression::Type(type_decl))),
_ => None,
.map(|(module_path, module)| {
(
module_path,
module
.into_inner()
.1
.into_iter()
.filter(|s| match s {
PilStatement::EnumDeclaration(..) | PilStatement::LetStatement(..) => true,
PilStatement::Include(..) => false,
PilStatement::Namespace(..) => false,
PilStatement::PolynomialDefinition(..) => false,
PilStatement::PublicDeclaration(..) => false,
PilStatement::PolynomialConstantDeclaration(..) => false,
PilStatement::PolynomialConstantDefinition(..) => false,
PilStatement::PolynomialCommitDeclaration(..) => false,
PilStatement::TraitImplementation(..) => false,
PilStatement::TraitDeclaration(..) => false,
PilStatement::Expression(..) => false,
})
.collect(),
)
})
.collect()
}
Expand All @@ -205,7 +219,7 @@ struct ASMPILConverter<'a> {
/// Current machine instance
location: &'a Location,
/// Input definitions and machines.
items: &'a BTreeMap<AbsoluteSymbolPath, Item>,
input: &'a AnalysisASMFile,
/// Pil statements generated for the machine
pil: Vec<PilStatement>,
/// Submachine instances accessible to the machine (includes those passed as a parameter)
Expand All @@ -224,7 +238,7 @@ impl<'a> ASMPILConverter<'a> {
Self {
instances,
location,
items: &input.items,
input,
pil: Default::default(),
submachines: Default::default(),
incoming_permutations,
Expand All @@ -247,9 +261,7 @@ impl<'a> ASMPILConverter<'a> {
fn convert_machine_inner(mut self) -> Object {
let (ty, args) = self.instances.get(self.location).as_ref().unwrap();
// TODO: This clone doubles the current memory usage
let Item::Machine(input) = self.items.get(ty).unwrap().clone() else {
panic!();
};
let input = self.input.get_machine(ty).unwrap().clone();

let degree = input.degree;

Expand Down Expand Up @@ -331,9 +343,7 @@ impl<'a> ASMPILConverter<'a> {
panic!("could not find submachine named `{instance}` in machine `{ty}`");
});
// get the machine type from the machine map
let Item::Machine(instance_ty) = self.items.get(&instance.ty).unwrap() else {
panic!();
};
let instance_ty = &self.input.get_machine(&instance.ty).unwrap();

// check that the operation exists and that it has the same number of inputs/outputs as the link
let operation = instance_ty
Expand Down Expand Up @@ -513,8 +523,8 @@ impl<'a> ASMPILConverter<'a> {

for (param, value) in params.iter().zip(values) {
let ty = AbsoluteSymbolPath::default().join(param.ty.clone().unwrap());
match self.items.get(&ty) {
Some(Item::Machine(_)) => self.submachines.push(SubmachineRef {
match self.input.get_machine(&ty) {
Some(_) => self.submachines.push(SubmachineRef {
location: value.clone(),
name: param.name.clone(),
ty,
Expand Down
39 changes: 14 additions & 25 deletions analysis/src/machine_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use powdr_ast::{
asm_analysis::{
AnalysisASMFile, AssignmentStatement, CallableSymbolDefinitions, DebugDirective,
FunctionBody, FunctionStatements, FunctionSymbol, InstructionDefinitionStatement,
InstructionStatement, Item, LabelStatement, LinkDefinition, Machine, MachineDegree,
InstructionStatement, LabelStatement, LinkDefinition, Machine, MachineDegree, Module,
OperationSymbol, RegisterDeclarationStatement, RegisterTy, Return, SubmachineDeclaration,
},
parsed::{
Expand All @@ -23,10 +23,8 @@ use powdr_ast::{
/// Also transfers generic PIL definitions but does not verify anything about them.
pub fn check(file: ASMProgram) -> Result<AnalysisASMFile, Vec<String>> {
let ctx = AbsoluteSymbolPath::default();
let machines = TypeChecker::default().check_module(file.main, &ctx)?;
Ok(AnalysisASMFile {
items: machines.into_iter().collect(),
})
let modules = TypeChecker::default().check_module(file.main, &ctx)?;
Ok(AnalysisASMFile { modules })
}

#[derive(Default)]
Expand Down Expand Up @@ -312,10 +310,11 @@ impl TypeChecker {
&mut self,
module: ASMModule,
ctx: &AbsoluteSymbolPath,
) -> Result<BTreeMap<AbsoluteSymbolPath, Item>, Vec<String>> {
) -> Result<BTreeMap<AbsoluteSymbolPath, Module>, Vec<String>> {
let mut errors = vec![];

let mut res: BTreeMap<AbsoluteSymbolPath, Item> = BTreeMap::default();
let mut checked_module = Module::default();
let mut res = BTreeMap::default();

for m in module.statements {
match m {
Expand All @@ -327,7 +326,7 @@ impl TypeChecker {
errors.extend(e);
}
Ok(machine) => {
res.insert(ctx.with_part(&name), Item::Machine(machine));
checked_module.push_machine(name, machine);
}
};
}
Expand All @@ -343,6 +342,8 @@ impl TypeChecker {
asm::Module::Local(m) => m,
};

checked_module.push_module(name);

match self.check_module(m, &ctx) {
Err(err) => {
errors.extend(err);
Expand All @@ -352,29 +353,17 @@ impl TypeChecker {
}
};
}
asm::SymbolValue::Expression(e) => {
res.insert(ctx.clone().with_part(&name), Item::Expression(e));
}
asm::SymbolValue::TypeDeclaration(enum_decl) => {
res.insert(
ctx.clone().with_part(&name),
Item::TypeDeclaration(enum_decl),
);
}
asm::SymbolValue::TraitDeclaration(trait_decl) => {
res.insert(
ctx.clone().with_part(&name),
Item::TraitDeclaration(trait_decl),
);
}
}
}
ModuleStatement::TraitImplementation(trait_impl) => {
res.insert(ctx.clone(), Item::TraitImplementation(trait_impl));
ModuleStatement::PilStatement(s) => {
checked_module.push_pil_statement(s);
}
}
}

// add this module to the map of modules found inside it
res.insert(ctx.clone(), checked_module);

if !errors.is_empty() {
Err(errors)
} else {
Expand Down
13 changes: 3 additions & 10 deletions analysis/src/vm/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
use itertools::Itertools;
use powdr_ast::{
asm_analysis::{
AnalysisASMFile, BatchMetadata, FunctionStatement, Incompatible, IncompatibleSet, Item,
Machine,
AnalysisASMFile, BatchMetadata, FunctionStatement, Incompatible, IncompatibleSet, Machine,
},
parsed::asm::AbsoluteSymbolPath,
};
Expand Down Expand Up @@ -132,14 +131,8 @@ impl RomBatcher {
}

pub fn batch(&mut self, mut asm_file: AnalysisASMFile) -> AnalysisASMFile {
for (name, machine) in asm_file.items.iter_mut().filter_map(|(n, m)| match m {
Item::Machine(m) => Some((n, m)),
Item::Expression(_)
| Item::TypeDeclaration(_)
| Item::TraitDeclaration(_)
| Item::TraitImplementation(_) => None,
}) {
self.extract_batches(name, machine);
for (name, machine) in asm_file.machines_mut() {
self.extract_batches(&name, machine);
}

asm_file
Expand Down
42 changes: 13 additions & 29 deletions analysis/src/vm/inference.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,29 @@
//! Infer assignment registers in asm statements

use powdr_ast::{
asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Item, Machine},
asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Machine},
parsed::asm::AssignmentRegister,
};

pub fn infer(file: AnalysisASMFile) -> Result<AnalysisASMFile, Vec<String>> {
pub fn infer(mut file: AnalysisASMFile) -> Result<AnalysisASMFile, Vec<String>> {
let mut errors = vec![];

let items = file
.items
.into_iter()
.filter_map(|(name, m)| match m {
Item::Machine(m) => match infer_machine(m) {
Ok(m) => Some((name, Item::Machine(m))),
Err(e) => {
errors.extend(e);
None
}
},
Item::Expression(e) => Some((name, Item::Expression(e))),
Item::TypeDeclaration(enum_decl) => Some((name, Item::TypeDeclaration(enum_decl))),
Item::TraitImplementation(trait_impl) => {
Some((name, Item::TraitImplementation(trait_impl)))
file.machines_mut()
.for_each(|(_, m)| match infer_machine(m) {
Ok(()) => {}
Err(e) => {
errors.extend(e);
}
Item::TraitDeclaration(trait_decl) => Some((name, Item::TraitDeclaration(trait_decl))),
})
.collect();
});

if !errors.is_empty() {
Err(errors)
} else {
Ok(AnalysisASMFile { items })
Ok(file)
}
}

fn infer_machine(mut machine: Machine) -> Result<Machine, Vec<String>> {
fn infer_machine(machine: &mut Machine) -> Result<(), Vec<String>> {
let mut errors = vec![];

for f in machine.callable.functions_mut() {
Expand Down Expand Up @@ -96,7 +84,7 @@ fn infer_machine(mut machine: Machine) -> Result<Machine, Vec<String>> {
if !errors.is_empty() {
Err(errors)
} else {
Ok(machine)
Ok(())
}
}

Expand Down Expand Up @@ -127,9 +115,7 @@ mod tests {

let file = infer_str(file).unwrap();

let machine = &file.items[&parse_absolute_path("::Machine")]
.try_to_machine()
.unwrap();
let machine = &file.get_machine(&parse_absolute_path("::Machine")).unwrap();
if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = machine
.functions()
.next()
Expand Down Expand Up @@ -168,9 +154,7 @@ mod tests {

let file = infer_str(file).unwrap();

let machine = &file.items[&parse_absolute_path("::Machine")]
.try_to_machine()
.unwrap();
let machine = &file.get_machine(&parse_absolute_path("::Machine")).unwrap();
if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = &machine
.functions()
.next()
Expand Down
Loading

0 comments on commit 64e7bf1

Please sign in to comment.