Skip to content

Commit

Permalink
ActivityAnalysis: consider atomicrmw (#2205)
Browse files Browse the repository at this point in the history
* ActivityAnalysis: consider atomicrmw

* fix

* fix
  • Loading branch information
wsmoses authored Dec 23, 2024
1 parent 79ac406 commit 1ca3ceb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
22 changes: 21 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,26 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(TmpOrig);
}
}
} else if (auto RMW = dyn_cast<AtomicRMWInst>(TmpOrig)) {
if (directions == UP) {
if (isConstantValue(TR, RMW->getPointerOperand())) {
InsertConstantValue(TR, Val);
return true;
}
} else {
if (UpHypothesis->isConstantValue(TR, RMW->getPointerOperand())) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
}
}
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(Val);
if (TmpOrig != Val) {
ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(
TmpOrig);
}
}
} else if (auto op = dyn_cast<CallInst>(TmpOrig)) {
if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") ||
op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex,
Expand Down Expand Up @@ -1940,7 +1960,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
isRefSet(AARes)) {
if (EnzymePrintActivity)
llvm::errs() << "potential active load: " << *I << "\n";
if (isa<LoadInst>(I) || isNVLoad(I)) {
if (isa<LoadInst>(I) || isNVLoad(I) || isa<AtomicRMWInst>(I)) {
// If the ref'ing value is a load check if the loaded value is
// active
if (!Hypothesis->isConstantValue(TR, I)) {
Expand Down
10 changes: 10 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3822,6 +3822,9 @@ bool GradientUtils::legalRecompute(const Value *val,
}
}

if (isa<AtomicRMWInst>(val))
return false;

if (auto phi = dyn_cast<PHINode>(val)) {
if (auto uiv = hasUninverted(val)) {
if (auto dli = dyn_cast_or_null<LoadInst>(uiv)) {
Expand All @@ -3835,6 +3838,13 @@ bool GradientUtils::legalRecompute(const Value *val,
}
}

auto found = fictiousPHIs.find(const_cast<llvm::PHINode *>(phi));
if (found != fictiousPHIs.end()) {
auto orig = found->second;
if (isa<AtomicRMWInst>(orig))
return false;
}

if (phi->getNumIncomingValues() == 0) {
llvm::errs() << *oldFunc << "\n";
llvm::errs() << *newFunc << "\n";
Expand Down

0 comments on commit 1ca3ceb

Please sign in to comment.