Skip to content

Commit

Permalink
fix(schedule_heuristic): refactored heursitic to use new literals for…
Browse files Browse the repository at this point in the history
… decision points
  • Loading branch information
danbryce committed Aug 30, 2017
1 parent cff941f commit e8ca181
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 37 deletions.
3 changes: 2 additions & 1 deletion src/opensmt/egraph/EgraphStore.C
Original file line number Diff line number Diff line change
Expand Up @@ -665,14 +665,15 @@ Enode * Egraph::mkFun( const char * name, Enode * args )
//
ostringstream ss;
ss << name;
std::cout << "[" << name << "]" << std::endl;
for ( Enode * l = args ; !l->isEnil( ) ; l = l->getCdr( ) )
{
ss << " ";
l->getCar( )->getLastSort( )->print( ss, false );
}

Enode * e = lookupSymbol( ss.str( ).c_str( ) );
if ( e == nullptr ) opensmt_error2( "undeclared function symbol ", ss.str( ).c_str( ) );
if ( e == nullptr ) opensmt_error2( "undeclared function symbol [", ss.str( ).c_str( ) );

Enode * ret = cons( e, args );

Expand Down
83 changes: 48 additions & 35 deletions src/opensmt/heuristics/schedule_heuristic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ namespace dreal {
m_config = &c;
m_is_initialized = true;

m_depth = 10;
num_choices_per_happening = 20; //num actions * 2
m_depth = 13;
num_choices_per_happening = 36; //num actions * 2

for(int i = 0; i < m_depth+1; i++){
at_time_enodes.push_back(new vector<Enode*>(num_choices_per_happening, NULL));
Expand Down Expand Up @@ -124,7 +124,7 @@ void schedule_heuristic::inform(Enode * e) {
return;
m_atoms.insert(e);

// DREAL_LOG_INFO << "schedule_heuristic::inform(): " << e << endl;
DREAL_LOG_INFO << "schedule_heuristic::inform(): " << e << endl;
// if (!e->isTAtom() && !e->isNot()) {
// unordered_set<Enode *> const & vars = e->get_vars();
// //unordered_set<Enode *> const & consts = e->get_constants();
Expand Down Expand Up @@ -210,44 +210,55 @@ void schedule_heuristic::inform(Enode * e) {
// } else
if (e->isEq()) {
unordered_set<Enode *> const & vars = e->get_vars();
unordered_set<Enode *> const & consts = e->get_constants();
// unordered_set<Enode *> const & consts = e->get_constants();
Enode* time_point = NULL;
Enode* happening = NULL;
for (auto const & v : vars) {
stringstream ss;
ss << v;
string var = ss.str();
if (var.find("time") != string::npos) {
int time = atoi(var.substr(var.find_last_of("_")+1).c_str());
for (auto const & c : consts) {
stringstream css;
css << c;
int cs = atoi(css.str().c_str());
if (cs == 0) { // only care about assinging time if wait is possible
DREAL_LOG_INFO << "time time = " << time << endl;
time_enodes[time] = e;
}
}
} else if (var.find("duract") == 0 && var.find("_at") != string::npos){
DREAL_LOG_INFO << "var = " << var;
for (auto const & c : consts) {
stringstream css;
css << c;
int time = atoi(css.str().c_str());
DREAL_LOG_INFO << "time = " << time << endl;
auto at = at_id.find(var);
if(at == at_id.end()){
at_id[var] = num_acts++;
at = at_id.find(var);
}

DREAL_LOG_INFO << "index = " << (*at).second;
(*at_time_enodes[time])[(*at).second] = e;
at_enodes.insert(e);
DREAL_LOG_INFO << "Got = " << (*at_time_enodes[time])[(*at).second];

}
if (var.find("duract") == 0 ){
time_point = v;
} else if (var.find("happening") == 0 ){
happening = v;
}
}

if (time_point != NULL && happening != NULL){
stringstream ss;
ss << happening;
string happening_str = ss.str();

stringstream ss1;
ss1 << time_point;
string time_point_str = ss1.str();


int time_pos = happening_str.find_last_of("_")+1;
int time = atoi(happening_str.substr(time_pos).c_str());

DREAL_LOG_INFO << "time = " << time << endl;

string name = time_point_str;

DREAL_LOG_INFO << "name = " << name << endl;


auto at = at_id.find(name);
if(at == at_id.end()){
at_id[name] = num_acts++;
at = at_id.find(name);
}

DREAL_LOG_INFO << "index = " << (*at).second;
(*at_time_enodes[time])[(*at).second] = e;
at_enodes.insert(e);
DREAL_LOG_INFO << "Got = " << (*at_time_enodes[time])[(*at).second];

}
}
//DREAL_LOG_INFO << "schedule_heuristic::inform(): " << e << endl;

}

int schedule_heuristic::getChoiceIndex(Enode *e) {
Expand Down Expand Up @@ -417,11 +428,12 @@ void schedule_heuristic::inform(Enode * e) {
DREAL_LOG_DEBUG << "Suggesting: " << decision->first << " " << decision->second->back();
for(int time = 0; time <= m_depth; time++){
Enode* decision_enode = (*at_time_enodes[time])[decision->first];
if(!decision_enode) continue;
DREAL_LOG_DEBUG << "Suggesting: " << decision_enode << " = " << (time == decision->second->back());
if (time == decision->second->back()){
m_suggestions.push_back(new std::pair<Enode *, bool>(decision_enode, true));
} else {
m_suggestions.push_back(new std::pair<Enode *, bool>(decision_enode, false));
//m_suggestions.push_back(new std::pair<Enode *, bool>(decision_enode, false));
}
}
}
Expand Down Expand Up @@ -493,6 +505,7 @@ std::pair<Enode*, bool>* schedule_heuristic::on_stack(Enode* act) {
std::vector<int>* decisions = new vector<int>();
for (int i = m_depth-1; i >= 0; i--) {
Enode* act_at_step = (*at_time_enodes[i])[act];
if(!act_at_step) continue;
pair<Enode*, bool>* on = on_stack(act_at_step);
if (on == NULL) { //no decisions for act on stack
decisions->push_back(i);
Expand Down
4 changes: 3 additions & 1 deletion src/opensmt/heuristics/schedule_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ class schedule_heuristic : public heuristic {
bool expand_path(bool first_expansion);
void displayDecisions();

std::map<std::string, int> at_id; //not used
std::map<std::string, int> at_id;
std::map<Enode*, std::string> at_names;
std::vector<std::vector<Enode*>*> at_time_enodes;

std::set<Enode*> at_enodes;
std::vector<int>* get_possible_decisions(int act);
std::pair<Enode*, bool>* on_stack(Enode* act);
Expand Down
6 changes: 6 additions & 0 deletions src/util/mcts_expander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ int get_step(std::string var, bool & at_start) {
int last_underscore = var.find_last_of("_");
// DREAL_LOG_INFO << first_underscore << " " << last_underscore;
int step = -1;

try{

if (first_underscore == last_underscore) { // is a time_x variable
step = std::stoi(var.substr(first_underscore + 1));
at_start = false;
Expand All @@ -271,6 +274,9 @@ int get_step(std::string var, bool & at_start) {
else
at_start = true;
}
} catch (std::exception e) {
return 0;
}
return step;
}

Expand Down

0 comments on commit e8ca181

Please sign in to comment.