diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 1a0ac9ac744..d9b77735753 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1166,9 +1166,27 @@ void handleUse( bool usesShadow = Def->getValueAsBit("usesShadow"); bool usesCustom = Def->getValueAsBit("usesCustom"); - // We don't handle any custom primal/shadow - (void)usesCustom; - assert(!usesCustom); + // This only concerns instances of StaticSelect for now + if (usesCustom) { + auto numArgs = resultTree->getNumArgs(); + + for (int i = numArgs == 3; i < numArgs; ++i) { + std::string foundPrimalUse2 = ""; + std::string foundShadowUse2 = ""; + + bool foundDiffRet2 = false; + + auto name = resultTree->getArgNameStr(i); + auto arg = resultTree->getArg(i); + auto arg2 = dyn_cast(arg); + handleUse(root, arg2, name.size() ? foundPrimalUse2 : foundPrimalUse, + name.size() ? foundShadowUse2 : foundShadowUse, + name.size() ? foundDiffRet2 : foundDiffRet, + usesPrimal ? precondition : "", tree, varNameToCondition); + } + + return; + } for (auto argEn : llvm::enumerate(resultTree->getArgs())) { auto name = resultTree->getArgNameStr(argEn.index());