Skip to content

Commit

Permalink
Fix nametoordinal (#2221)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 9, 2025
1 parent 495fde3 commit e8bc187
Showing 1 changed file with 57 additions and 48 deletions.
105 changes: 57 additions & 48 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ struct VariableSetting {
StringMap<std::vector<int>> extractions;

std::tuple<std::string, bool, std::vector<int>>
lookup(StringRef name, const Record *pattern, const Init *resultRoot) {
lookup(StringRef name, const Record *pattern, const Init *resultRoot) const {
auto ord = nameToOrdinal.find(name);
if (ord == nameToOrdinal.end())
PrintFatalError(pattern->getLoc(), Twine("unknown named operand '") +
Expand Down Expand Up @@ -1192,14 +1192,16 @@ void handleUse(
const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
const DagInit *tree,
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition);
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition,
const VariableSetting &nameToOrdinal);

void handleUseArgument(
StringRef name, const Init *arg, bool usesPrimal, bool usesShadow,
const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
const DagInit *tree,
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition) {
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition,
const VariableSetting &nameToOrdinal) {

auto arg2 = dyn_cast<DagInit>(arg);

Expand All @@ -1218,7 +1220,8 @@ void handleUseArgument(
handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse,
name.size() ? foundShadowUse2 : foundShadowUse,
name.size() ? foundDiffRet2 : foundDiffRet,
usesPrimal ? precondition : "", tree, varNameToCondition);
usesPrimal ? precondition : "", tree, varNameToCondition,
nameToOrdinal);

if (name.size()) {
if (foundPrimalUse2.size() &&
Expand Down Expand Up @@ -1306,7 +1309,8 @@ void handleUse(
const DagInit *root, const DagInit *resultTree, std::string &foundPrimalUse,
std::string &foundShadowUse, bool &foundDiffRet, std::string precondition,
const DagInit *tree,
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition) {
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition,
const VariableSetting &nameToOrdinal) {
auto opName = resultTree->getOperator()->getAsString();
auto Def = cast<DefInit>(resultTree->getOperator())->getDef();
if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) {
Expand Down Expand Up @@ -1339,7 +1343,9 @@ void handleUse(
if (numArgs == 3) {
if (isa<UnsetInit>(resultTree->getArg(0)) && resultTree->getArgName(0)) {
auto name = resultTree->getArgName(0)->getAsUnquotedString();
conditionStr = ReplaceAll(conditionStr, "imVal", name);
auto [ord, isVec, ext] = nameToOrdinal.lookup(name, nullptr, nullptr);
assert(!isVec);
conditionStr = ReplaceAll(conditionStr, "imVal", ord);
} else
assert("Requires name for arg");
}
Expand All @@ -1362,7 +1368,7 @@ void handleUse(
auto arg = resultTree->getArg(i);
handleUseArgument(name, arg, true, false, root, resultTree,
foundPrimalUse, foundShadowUse, foundDiffRet,
precondition2, tree, varNameToCondition);
precondition2, tree, varNameToCondition, nameToOrdinal);
}

return;
Expand All @@ -1375,16 +1381,57 @@ void handleUse(
auto name = resultTree->getArgNameStr(argEn.index());
handleUseArgument(name, argEn.value(), usesPrimal, usesShadow, root,
resultTree, foundPrimalUse, foundShadowUse, foundDiffRet,
precondition, tree, varNameToCondition);
precondition, tree, varNameToCondition, nameToOrdinal);
}
}

static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
StringRef origName) {
VariableSetting nameToOrdinal;
std::function<void(const DagInit *, ArrayRef<unsigned>)> insert =
[&](const DagInit *ptree, ArrayRef<unsigned> prev) {
unsigned i = 0;
for (auto tree : ptree->getArgs()) {
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
next.push_back(i);
if (auto dg = dyn_cast<DagInit>(tree))
insert(dg, next);

if (ptree->getArgNameStr(i).size()) {
std::string op;
if (intrinsic != MLIRDerivatives)
op = (origName + ".getOperand(" + Twine(next[0]) + ")").str();
else
op = (origName + "->getOperand(" + Twine(next[0]) + ")").str();
std::vector<int> extractions;
if (prev.size() > 0) {
for (unsigned i = 1; i < next.size(); i++) {
extractions.push_back(next[i]);
}
}
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
extractions);
}
i++;
}
};

insert(tree, {});

if (tree->getNameStr().size())
nameToOrdinal.insert(tree->getNameStr(),
(Twine("(&") + origName + ")").str(), false, {});
return nameToOrdinal;
}

void printDiffUse(
raw_ostream &os, Twine prefix, const ListInit *argOps, StringRef origName,
ActionType intrinsic, const DagInit *tree,
StringMap<std::tuple<std::string, std::string, bool>> &varNameToCondition) {
os << prefix << " // Rule " << *tree << "\n";

VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName);

for (auto argOpEn : enumerate(*argOps)) {
size_t argIdx = argOpEn.index();
if (auto resultRoot = dyn_cast<DagInit>(argOpEn.value())) {
Expand Down Expand Up @@ -1417,7 +1464,8 @@ void printDiffUse(

// hasDiffeRet(resultTree)
handleUse(resultTree, resultTree, foundPrimalUse, foundShadowUse,
foundDiffRet, /*precondition*/ "true", tree, varNameToCondition);
foundDiffRet, /*precondition*/ "true", tree, varNameToCondition,
nameToOrdinal);

os << prefix << " // Arg " << argIdx << " : " << *resultTree << "\n";

Expand Down Expand Up @@ -1587,45 +1635,6 @@ static void emitMLIRReverse(raw_ostream &os, const Record *pattern,
os << " mlir::Value dif = nullptr;\n";
}

static VariableSetting parseVariables(const DagInit *tree, ActionType intrinsic,
StringRef origName) {
VariableSetting nameToOrdinal;
std::function<void(const DagInit *, ArrayRef<unsigned>)> insert =
[&](const DagInit *ptree, ArrayRef<unsigned> prev) {
unsigned i = 0;
for (auto tree : ptree->getArgs()) {
SmallVector<unsigned, 2> next(prev.begin(), prev.end());
next.push_back(i);
if (auto dg = dyn_cast<DagInit>(tree))
insert(dg, next);

if (ptree->getArgNameStr(i).size()) {
std::string op;
if (intrinsic != MLIRDerivatives)
op = (origName + ".getOperand(" + Twine(next[0]) + ")").str();
else
op = (origName + "->getOperand(" + Twine(next[0]) + ")").str();
std::vector<int> extractions;
if (prev.size() > 0) {
for (unsigned i = 1; i < next.size(); i++) {
extractions.push_back(next[i]);
}
}
nameToOrdinal.insert(ptree->getArgNameStr(i), op, false,
extractions);
}
i++;
}
};

insert(tree, {});

if (tree->getNameStr().size())
nameToOrdinal.insert(tree->getNameStr(),
(Twine("(&") + origName + ")").str(), false, {});
return nameToOrdinal;
}

static void emitReverseCommon(raw_ostream &os, const Record *pattern,
const DagInit *tree, ActionType intrinsic,
StringRef origName, const ListInit *argOps) {
Expand Down

0 comments on commit e8bc187

Please sign in to comment.