Skip to content

Commit

Permalink
Change architecture around assertion stack, preprocessing and cnfization
Browse files Browse the repository at this point in the history
Maintaining the assertion stack should be the responsibility of the
MainSolver, not the Theory.
Related to this change of responsibilities, we also change the
preprocessing, so that it operates only on the current frame.
(This should probably be changed so that simplifications from lower
frames can be applied to higher frames as well, but we need to test if
this would really be useful.)

Another change is a more separate Cnfizer, which now does not need to
know about the PartitioningManager or the CoreSolver. Instead, it just
reports new clauses to a callback that needs to be set up beforehand.

The larger goal here is to eventually remove separate Theory and Logic classes and
have a single class responsible for managing terms (former Logic) and
a single class for preprocessing (former Theory). These could be
customized using a plugin architecture.
  • Loading branch information
blishko committed Feb 24, 2024
1 parent b95e90d commit 09a8a1a
Show file tree
Hide file tree
Showing 29 changed files with 572 additions and 1,184 deletions.
1 change: 0 additions & 1 deletion src/api/Interpret.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ void Interpret::interp(ASTNode& n) {
}
initializeLogic(logic_type);
main_solver = createMainSolver(logic_name);
main_solver->initialize();
notify_success();
}
break;
Expand Down
196 changes: 145 additions & 51 deletions src/api/MainSolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,88 +42,96 @@ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

namespace opensmt { bool stop; }

void MainSolver::initialize() {
smt_solver->initialize();
opensmt::pair<CRef, CRef> iorefs{CRef_Undef, CRef_Undef};
smt_solver->addOriginalSMTClause({term_mapper->getOrCreateLit(logic.getTerm_true())}, iorefs);
if (iorefs.first != CRef_Undef) { pmanager.addClauseClassMask(iorefs.first, 1); }

smt_solver->addOriginalSMTClause({~term_mapper->getOrCreateLit(logic.getTerm_false())}, iorefs);
if (iorefs.first != CRef_Undef) { pmanager.addClauseClassMask(iorefs.first, 1); }
}

void
MainSolver::push()
{
bool alreadyUnsat = isLastFrameUnsat();
frames.push(pfstore.alloc());
frames.push();
frameTerms.push(newFrameTerm(frames.last().getId()));
if (alreadyUnsat) { rememberLastFrameUnsat(); }
}

bool
MainSolver::pop()
{
if (frames.size() > 1) {
if (frames.frameCount() > 1) {
if (config.produce_inter() > 0) {
auto toPop = frames.last();
auto& partitionsToInvalidate = pfstore[toPop].formulas;
auto const & partitionsToInvalidate = frames.last().formulas;
ipartitions_t mask = 0;
for (int i = 0; i < partitionsToInvalidate.size(); ++i) {
PTRef part = partitionsToInvalidate[i];
auto index = pmanager.getPartitionIndex(part);
for (PTRef partition : partitionsToInvalidate) {
auto index = pmanager.getPartitionIndex(partition);
assert(index != -1);
opensmt::setbit(mask, static_cast<unsigned int>(index));
}
pmanager.invalidatePartitions(mask);
}
frames.pop();
if (!isLastFrameUnsat()) {
firstNotSimplifiedFrame = std::min(firstNotSimplifiedFrame, frames.frameCount());
if (not isLastFrameUnsat()) {
getSMTSolver().restoreOK();
}
return true;
}
return false;
}

PTRef MainSolver::newFrameTerm(MainSolver::FrameId frameId) {
assert(frameId.id != 0);
auto name = std::string(Logic::s_framev_prefix) + std::to_string(frameId.id);
PTRef frameTerm = logic.mkBoolVar(name.c_str());
Lit l = term_mapper->getOrCreateLit(frameTerm);
term_mapper->setFrozen(var(l));
smt_solver->addAssumptionVar(var(l));
return frameTerm;
}

void MainSolver::insertFormula(PTRef root)
void MainSolver::insertFormula(PTRef fla)
{
if (logic.getSortRef(root) != logic.getSort_bool()) {
throw OsmtApiException("Top-level assertion sort must be Bool, got " + logic.printSort(logic.getSortRef(root)));
if (logic.getSortRef(fla) != logic.getSort_bool()) {
throw OsmtApiException("Top-level assertion sort must be Bool, got " + logic.printSort(logic.getSortRef(fla)));
}

root = logic.conjoinExtras(root);
root = IteHandler(logic, getPartitionManager().getNofPartitions()).rewrite(root);
// TODO: Move this to preprocessing of the formulas
fla = IteHandler(logic, getPartitionManager().getNofPartitions()).rewrite(fla);

if (getConfig().produce_inter()) {
// MB: Important for HiFrog! partition index is the index of the formula in an virtual array of inserted formulas,
// thus we need the old value of count. TODO: Find a good interface for this so it cannot be broken this easily
unsigned int partition_index = inserted_formulas_count++;
pmanager.assignTopLevelPartitionIndex(partition_index, root);
assert(pmanager.getPartitionIndex(root) != -1);
pmanager.assignTopLevelPartitionIndex(partition_index, fla);
assert(pmanager.getPartitionIndex(fla) != -1);
}
else {
++inserted_formulas_count;
}

PushFrame& lastFrame = pfstore[frames.last()];
lastFrame.push(root);
lastFrame.root = PTRef_Undef;
// New formula has been added to the last frame. If the frame has been simplified before, we need to do it again
frames.setSimplifiedUntil(std::min(frames.getSimplifiedUntil(), frames.size() - 1));
frames.add(fla);
firstNotSimplifiedFrame = std::min(firstNotSimplifiedFrame, frames.frameCount() - 1);
}

sstat MainSolver::simplifyFormulas()
{
status = s_Undef;

vec<PTRef> coll_f;
bool keepPartitionsSeparate = getConfig().produce_inter();
// Process (and simplify) not yet processed frames. Stop processing if solver is in UNSAT state already
for (std::size_t i = frames.getSimplifiedUntil(); i < frames.size() && status != s_False; i++) {
getTheory().simplify(frames.getFrameReferences(), pmanager, i);
frames.setSimplifiedUntil(i + 1);
const PushFrame & frame = pfstore[frames.getFrameReference(i)];

for (std::size_t i = firstNotSimplifiedFrame; i < frames.frameCount() && status != s_False; i++) {
if (keepPartitionsSeparate) {
vec<PTRef> const & flas = frame.formulas;
if (flas.size() == 0 or std::all_of(flas.begin(), flas.end(), [&](PTRef fla) { return fla == logic.getTerm_true(); })) {
pmanager.assignTopLevelPartitionIndex(0, logic.getTerm_true());
status = giveToSolver(logic.getTerm_true(), frame.getId());
vec<PTRef> frameFormulas = getTheory().simplifyIndividually(frames[i].formulas, pmanager, i == 0);
firstNotSimplifiedFrame = i + 1;
if (frameFormulas.size() == 0 or std::all_of(frameFormulas.begin(), frameFormulas.end(), [&](PTRef fla) { return fla == logic.getTerm_true(); })) {
continue;
}
for (int j = 0; j < flas.size() && status != s_False; ++j) {
PTRef fla = flas[j];
for (PTRef fla : frameFormulas) {
if (fla == logic.getTerm_true()) { continue; }
assert(pmanager.getPartitionIndex(fla) != -1);
// Optimize the dag for cnfization
Expand All @@ -134,25 +142,27 @@ sstat MainSolver::simplifyFormulas()
}
assert(pmanager.getPartitionIndex(fla) != -1);
pmanager.propagatePartitionMask(fla);
status = giveToSolver(fla, frame.getId());
status = giveToSolver(fla, frames[i].getId());
if (status == s_False) { break; }
}
} else {
PTRef root = frame.root;
if (logic.isFalse(root)) {
giveToSolver(getLogic().getTerm_false(), frame.getId());
PTRef frameFormula = getTheory().simplifyTogether(frames[i].formulas, i == 0);
firstNotSimplifiedFrame = i + 1;
if (logic.isFalse(frameFormula)) {
giveToSolver(getLogic().getTerm_false(), frames[i].getId());
status = s_False;
break;
}
// Optimize the dag for cnfization
if (logic.isBooleanOperator(root)) {
root = rewriteMaxArity(root);
if (logic.isBooleanOperator(frameFormula)) {
frameFormula = rewriteMaxArity(frameFormula);
}
root_instance.setRoot(root);
status = giveToSolver(root, frame.getId());
status = giveToSolver(frameFormula, frames[i].getId());
}
}
if (status == s_False) {
rememberUnsatFrame(frames.getSimplifiedUntil() - 1);
assert(firstNotSimplifiedFrame > 0);
rememberUnsatFrame(firstNotSimplifiedFrame - 1);
}
return status;
}
Expand All @@ -177,48 +187,102 @@ std::unique_ptr<Model> MainSolver::getModel() {
if (status != s_True) { throw OsmtApiException("Model cannot be created if solver is not in SAT state"); }

ModelBuilder modelBuilder {logic};
ts.solver.fillBooleanVars(modelBuilder);
smt_solver->fillBooleanVars(modelBuilder);
thandler->fillTheoryFunctions(modelBuilder);

return modelBuilder.build();
}

lbool MainSolver::getTermValue(PTRef tr) const {
if (logic.getSortRef(tr) != logic.getSort_bool()) { return l_Undef; }
if (not term_mapper->hasLit(tr)) { return l_Undef; }

Lit l = term_mapper->getLit(tr);
auto val = smt_solver->modelValue(l);
assert(val != l_Undef);
return val;
}

std::unique_ptr<InterpolationContext> MainSolver::getInterpolationContext() {
if (status != s_False) { throw OsmtApiException("Interpolation context cannot be created if solver is not in UNSAT state"); }
return std::make_unique<InterpolationContext>(
config, *theory, *term_mapper, getSMTSolver().getProof(), pmanager
);
}

PTRef MainSolver::currentRootInstance() const {
vec<PTRef> assertions;
for (auto i = 0u; i < frames.frameCount(); i++) {
auto const & frameAssertions = frames[i].formulas;
for (PTRef assertion : frameAssertions) {
assertions.push(assertion);
}
}
return logic.mkAnd(std::move(assertions));
}

void MainSolver::printFramesAsQuery() const
{
char* base_name = config.dump_query_name();
if (base_name == NULL)
getTheory().printFramesAsQuery(frames.getFrameReferences(), std::cout);
printFramesAsQuery(std::cout);
else {
char* s_file_name;
int chars_written = asprintf(&s_file_name, "%s-%d.smt2", base_name, check_called);
(void)chars_written;
std::ofstream stream;
stream.open(s_file_name);
getTheory().printFramesAsQuery(frames.getFrameReferences(), stream);
printFramesAsQuery(stream);
stream.close();
free(s_file_name);
}
free(base_name);
}

sstat MainSolver::giveToSolver(PTRef root, MainSolver::FrameId push_id) {

struct ClauseCallBack : public Cnfizer::ClauseCallBack {
std::vector<vec<Lit>> clauses;
void operator()(vec<Lit> && c) override {
clauses.push_back(std::move(c));
}
};
ClauseCallBack callBack;
ts.setClauseCallBack(&callBack);
ts.Cnfizer::cnfize(root, push_id.id);
bool keepPartitionsSeparate = getConfig().produce_inter();
Lit frameLit = push_id.id == 0 ? Lit{} : term_mapper->getOrCreateLit(frameTerms[push_id.id]);
int partitionIndex = keepPartitionsSeparate ? pmanager.getPartitionIndex(root) : -1;
for (auto & clause : callBack.clauses) {
if (push_id.id != 0) {
clause.push(frameLit);
}
opensmt::pair<CRef, CRef> iorefs{CRef_Undef, CRef_Undef};
bool res = smt_solver->addOriginalSMTClause(std::move(clause), iorefs);
if (keepPartitionsSeparate) {
CRef ref = iorefs.first;
if (ref != CRef_Undef) {
ipartitions_t parts = 0;
assert(partitionIndex != -1);
opensmt::setbit(parts, static_cast<unsigned int>(partitionIndex));
pmanager.addClauseClassMask(ref, parts);
}
}
if (not res) { return s_False; }
}
return s_Undef;
}

sstat MainSolver::check()
{
check_called ++;
++check_called;
if (config.timeQueries()) {
printf("; %s query time so far: %f\n", solver_name.c_str(), query_timer.getTime());
opensmt::StopWatch sw(query_timer);
}
if (isLastFrameUnsat()) {
return s_False;
}
initialize();
sstat rval = simplifyFormulas();

if (config.dump_query())
Expand All @@ -245,19 +309,49 @@ sstat MainSolver::solve()
return s_False;
}

// FIXME: Find a better way to deal with Bools in UF
for (PTRef tr : logic.propFormulasAppearingInUF) {
Lit l = term_mapper->getOrCreateLit(tr);
smt_solver->addVar(var(l));
}

vec<FrameId> en_frames;
for (std::size_t i = 0; i < frames.size(); i++) {
const PushFrame& frame = pfstore[frames.getFrameReference(i)];
en_frames.push(frame.getId());
for (std::size_t i = 0; i < frames.frameCount(); ++i) {
en_frames.push(frames[i].getId());
}
status = sstat(solve_(en_frames));
status = solve_(en_frames);

if (status == s_True && config.produce_models())
thandler->computeModel();
smt_solver->clearSearch();
return status;
}

sstat MainSolver::solve_(vec<FrameId> const & enabledFrames) {
assert(frameTerms.size() > 0 and frameTerms[0] == logic.getTerm_true());
vec<Lit> assumps;
// Initialize so that by default frames are disabled
for (PTRef tr : frameTerms) {
assumps.push(term_mapper->getOrCreateLit(tr));
}

// Enable the terms which are listed in enabledFrames
// At this point assumps has the same size as frame_terms and the
// elements are in the same order. We simply invert the
// corresponding literals
uint32_t prevId = UINT32_MAX;
for (FrameId fid : enabledFrames) {
assumps[fid.id] = ~assumps[fid.id];
smt_solver->mapEnabledFrameIdToVar(var(assumps[fid.id]), fid.id, prevId);
}
// Drop the assumption variable for the base frame (it is at the first place)
for (int i = 1; i < assumps.size(); ++i) {
assumps[i-1] = assumps[i];
}
assumps.pop();
return smt_solver->solve(assumps, !config.isIncremental(), config.isIncremental());
}

std::unique_ptr<SimpSMTSolver> MainSolver::createInnerSolver(SMTConfig & config, THandler & thandler) {
if (config.sat_pure_lookahead()) {
return std::make_unique<LookaheadSMTSolver>(config, thandler);
Expand Down
Loading

0 comments on commit 09a8a1a

Please sign in to comment.