Skip to content

Commit

Permalink
Fix blas pointer info (#2036)
Browse files Browse the repository at this point in the history
* Change to pointer

* blas fix

* fix multi

* TRAA

* fix store checker

* ix

* no jlinstsimpl
  • Loading branch information
wsmoses authored Aug 11, 2024
1 parent 0367300 commit fb0fa7b
Show file tree
Hide file tree
Showing 17 changed files with 243 additions and 218 deletions.
8 changes: 5 additions & 3 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,6 @@ bool attributeKnownFunctions(llvm::Function &F) {
}
}

changed |= attributeTablegen(F);

if (F.getName() ==
"_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_createERmm") {
changed = true;
Expand Down Expand Up @@ -377,6 +375,7 @@ bool attributeKnownFunctions(llvm::Function &F) {
AttributeList::FunctionIndex,
Attribute::get(F.getContext(), "enzyme_no_escaping_allocation"));
}
changed |= attributeTablegen(F);
return changed;
}

Expand Down Expand Up @@ -2963,9 +2962,12 @@ class EnzymeBase {
bool run(Module &M) {
Logic.clear();

for (Function &F : make_early_inc_range(M)) {
attributeKnownFunctions(F);
}

bool changed = false;
for (Function &F : M) {
attributeKnownFunctions(F);
if (F.empty())
continue;
for (BasicBlock &BB : F) {
Expand Down
24 changes: 18 additions & 6 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ struct CacheAnalysis {
n == "jl_get_ptls_states")
return false;
}
if (auto objli = dyn_cast<LoadInst>(obj)) {
auto obj2 = getBaseObject(objli->getOperand(0));
if (auto obj_op = dyn_cast<CallInst>(obj2)) {
auto n = getFuncNameFromCall(obj_op);
if (n == "julia.get_pgcstack" || n == "julia.ptls_states" ||
n == "jl_get_ptls_states")
return false;
}
}

// Openmp bound and local thread id are unchanging
// definitionally cacheable.
Expand Down Expand Up @@ -344,7 +353,7 @@ struct CacheAnalysis {
}
}

if (!overwritesToMemoryReadBy(AA, TLI, SE, OrigLI, OrigDT, &li,
if (!overwritesToMemoryReadBy(&TR, AA, TLI, SE, OrigLI, OrigDT, &li,
inst2)) {
return false;
}
Expand All @@ -367,7 +376,7 @@ struct CacheAnalysis {
return false;
}

if (!writesToMemoryReadBy(AA, TLI, &li, mid)) {
if (!writesToMemoryReadBy(&TR, AA, TLI, &li, mid)) {
return false;
}

Expand Down Expand Up @@ -464,6 +473,9 @@ struct CacheAnalysis {
if (funcName == "julia.write_barrier_binding")
return {};

if (funcName == "julia.safepoint")
return {};

if (funcName == "enzyme_zerotype")
return {};

Expand Down Expand Up @@ -1016,7 +1028,7 @@ void calculateUnusedValuesInFunction(
}

if (writesToMemoryReadBy(
*gutils->OrigAA, TLI,
&gutils->TR, *gutils->OrigAA, TLI,
/*maybeReader*/ const_cast<MemTransferInst *>(mti),
/*maybeWriter*/ I)) {
foundStore = true;
Expand Down Expand Up @@ -1176,7 +1188,7 @@ void calculateUnusedStoresInFunction(

// if (I == &MTI) return;
if (writesToMemoryReadBy(
*gutils->OrigAA, TLI,
&gutils->TR, *gutils->OrigAA, TLI,
/*maybeReader*/ const_cast<MemTransferInst *>(mti),
/*maybeWriter*/ I)) {
foundStore = true;
Expand Down Expand Up @@ -1576,7 +1588,7 @@ bool legalCombinedForwardReverse(
auto consider = [&](Instruction *user) {
if (!user->mayReadFromMemory())
return false;
if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI,
if (writesToMemoryReadBy(&gutils->TR, *gutils->OrigAA, gutils->TLI,
/*maybeReader*/ user,
/*maybeWriter*/ inst)) {

Expand Down Expand Up @@ -1609,7 +1621,7 @@ bool legalCombinedForwardReverse(
if (!post->mayWriteToMemory())
return false;

if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI,
if (writesToMemoryReadBy(&gutils->TR, *gutils->OrigAA, gutils->TLI,
/*maybeReader*/ inst,
/*maybeWriter*/ post)) {
if (EnzymePrintPerf) {
Expand Down
30 changes: 15 additions & 15 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3977,7 +3977,7 @@ bool GradientUtils::legalRecompute(const Value *val,
const_cast<Instruction *>(orig), [&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(
*OrigAA, TLI,
&TR, *OrigAA, TLI,
/*maybeReader*/ const_cast<Instruction *>(orig),
/*maybeWriter*/ I)) {
failed = true;
Expand Down Expand Up @@ -4008,7 +4008,7 @@ bool GradientUtils::legalRecompute(const Value *val,
const_cast<Instruction *>(orig), [&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(
*OrigAA, TLI,
&TR, *OrigAA, TLI,
/*maybeReader*/ const_cast<Instruction *>(orig),
/*maybeWriter*/ I)) {
failed = true;
Expand Down Expand Up @@ -6651,7 +6651,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
allInstructionsBetween(
*OrigLI, orig2, origInst, [&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(*OrigAA, TLI,
writesToMemoryReadBy(&TR, *OrigAA, TLI,
/*maybeReader*/ origInst,
/*maybeWriter*/ I)) {
failed = true;
Expand Down Expand Up @@ -6729,7 +6729,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
*OrigLI, SI, origInst, [&](Instruction *potentialAlias) {
if (!potentialAlias->mayWriteToMemory())
return false;
if (!writesToMemoryReadBy(*OrigAA, TLI, origInst,
if (!writesToMemoryReadBy(&TR, *OrigAA, TLI, origInst,
potentialAlias))
return false;

Expand All @@ -6747,8 +6747,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
if (mid == SI)
return false;

if (!writesToMemoryReadBy(*OrigAA, TLI, origInst,
mid)) {
if (!writesToMemoryReadBy(&TR, *OrigAA, TLI,
origInst, mid)) {
return false;
}
lastStore = false;
Expand Down Expand Up @@ -6947,7 +6947,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
*OrigLI, &*origTerm, origInst,
[&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(*OrigAA, TLI,
writesToMemoryReadBy(&TR, *OrigAA, TLI,
/*maybeReader*/ tmpload,
/*maybeWriter*/ I)) {
failed = true;
Expand All @@ -6973,7 +6973,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
*OrigLI, &*origTerm, origInst,
[&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(*OrigAA, TLI,
writesToMemoryReadBy(&TR, *OrigAA, TLI,
/*maybeReader*/ tmpload,
/*maybeWriter*/ I)) {
failed = true;
Expand Down Expand Up @@ -7178,7 +7178,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
allInstructionsBetween(
*OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(*OrigAA, TLI,
writesToMemoryReadBy(&TR, *OrigAA, TLI,
/*maybeReader*/ origInst,
/*maybeWriter*/ I)) {
failed = true;
Expand Down Expand Up @@ -7207,7 +7207,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
allInstructionsBetween(
*OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool {
if (I->mayWriteToMemory() &&
writesToMemoryReadBy(*OrigAA, TLI,
writesToMemoryReadBy(&TR, *OrigAA, TLI,
/*maybeReader*/ origInst,
/*maybeWriter*/ I)) {
failed = true;
Expand Down Expand Up @@ -8916,8 +8916,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) {
SmallVector<Instruction *, 2> results;
mayExecuteAfter(results, LI, storingOps, outer);
for (auto res : results) {
if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI,
res, outer)) {
if (overwritesToMemoryReadBy(&TR, *OrigAA, TLI, SE, *OrigLI, *OrigDT,
LI, res, outer)) {
EmitWarning("NotPromotable", *LI,
" Could not promote shadow allocation ", *V,
" due to pointer load ", *LI,
Expand Down Expand Up @@ -8971,8 +8971,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) {
SmallVector<Instruction *, 2> results;
mayExecuteAfter(results, LI, storingOps, outer);
for (auto res : results) {
if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, res,
outer)) {
if (overwritesToMemoryReadBy(&TR, *OrigAA, TLI, SE, *OrigLI, *OrigDT, LI,
res, outer)) {
EmitWarning("NotPromotable", *LI, " Could not promote allocation ", *V,
" due to load ", *LI,
" which does not postdominates store ", *res);
Expand All @@ -8987,7 +8987,7 @@ void GradientUtils::computeForwardingProperties(Instruction *V) {
SmallVector<Instruction *, 2> results;
mayExecuteAfter(results, LI.loadCall, storingOps, outer);
for (auto res : results) {
if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT,
if (overwritesToMemoryReadBy(&TR, *OrigAA, TLI, SE, *OrigLI, *OrigDT,
LI.loadCall, res, outer)) {
EmitWarning("NotPromotable", *LI.loadCall,
" Could not promote allocation ", *V,
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/JLInstSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ bool jlInstSimplify(llvm::Function &F, TargetLibraryInfo &TLI,
return /*earlyBreak*/ false;

for (auto LI : {llhs, lrhs})
if (writesToMemoryReadBy(AA, TLI,
if (writesToMemoryReadBy(nullptr, AA, TLI,
/*maybeReader*/ LI,
/*maybeWriter*/ I)) {
overwritten = true;
Expand Down
52 changes: 47 additions & 5 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,27 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal,
call->setDebugLoc(loc);
}

/// Translate this routine's `floatType` prefix into the matching LLVM type.
///
/// Mapping (prefix is accepted in either letter case):
///   "d" / "D" -> double
///   "s" / "S" -> float
///   "c" / "C" -> 2-element, non-scalable vector of float
///   "z" / "Z" -> 2-element, non-scalable vector of double
/// NOTE(review): the "c"/"z" vector forms presumably model the BLAS complex
/// single/double precisions -- confirm against the callers.
///
/// \param ctx  LLVM context in which the type is created.
/// \return the mapped type; for an unrecognised prefix this asserts in debug
///         builds and returns nullptr when assertions are compiled out.
Type *BlasInfo::fpType(LLVMContext &ctx) const {
  if (floatType == "d" || floatType == "D")
    return Type::getDoubleTy(ctx);
  if (floatType == "s" || floatType == "S")
    return Type::getFloatTy(ctx);
  if (floatType == "c" || floatType == "C")
    return VectorType::get(Type::getFloatTy(ctx), 2, false);
  if (floatType == "z" || floatType == "Z")
    return VectorType::get(Type::getDoubleTy(ctx), 2, false);

  assert(false && "Unreachable");
  // With NDEBUG the assert above expands to nothing; without an explicit
  // return here control would flow off the end of a non-void function, which
  // is undefined behaviour.
  return nullptr;
}

/// Integer type selected by the `is64` flag: a 64-bit integer when it is set,
/// a 32-bit integer otherwise.
///
/// \param ctx  LLVM context in which the type is created.
IntegerType *BlasInfo::intType(LLVMContext &ctx) const {
  const unsigned bitWidth = is64 ? 64 : 32;
  return IntegerType::get(ctx, bitWidth);
}

/// Create function for type that is equivalent to memcpy but adds to
/// destination rather than a direct copy; dst, src, numelems
Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
Expand Down Expand Up @@ -2149,14 +2170,14 @@ bool overwritesToMemoryReadByLoop(
return true;
}

bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
ScalarEvolution &SE, llvm::LoopInfo &LI,
llvm::DominatorTree &DT,
bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
llvm::TargetLibraryInfo &TLI, ScalarEvolution &SE,
llvm::LoopInfo &LI, llvm::DominatorTree &DT,
llvm::Instruction *maybeReader,
llvm::Instruction *maybeWriter,
llvm::Loop *scope) {
using namespace llvm;
if (!writesToMemoryReadBy(AA, TLI, maybeReader, maybeWriter))
if (!writesToMemoryReadBy(TR, AA, TLI, maybeReader, maybeWriter))
return false;
const SCEV *LoadBegin = SE.getCouldNotCompute();
const SCEV *LoadEnd = SE.getCouldNotCompute();
Expand Down Expand Up @@ -2251,7 +2272,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}

/// Return whether maybeReader can read from memory written to by maybeWriter
bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
llvm::TargetLibraryInfo &TLI,
llvm::Instruction *maybeReader,
llvm::Instruction *maybeWriter) {
assert(maybeReader->getParent()->getParent() ==
Expand Down Expand Up @@ -2466,6 +2488,26 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
assert(maybeReader->mayReadFromMemory());

if (auto li = dyn_cast<LoadInst>(maybeReader)) {
if (TR) {
auto TT = TR->query(li)[{-1}];
if (TT != BaseType::Unknown && TT != BaseType::Anything) {
if (auto si = dyn_cast<StoreInst>(maybeWriter)) {
auto TT2 = TR->query(si->getValueOperand())[{-1}];
if (TT2 != BaseType::Unknown && TT2 != BaseType::Anything) {
if (TT != TT2)
return false;
}
auto &dl = li->getParent()->getParent()->getParent()->getDataLayout();
auto len =
(dl.getTypeSizeInBits(si->getValueOperand()->getType()) + 7) / 8;
TT2 = TR->query(si->getPointerOperand()).Lookup(len, dl)[{-1}];
if (TT2 != BaseType::Unknown && TT2 != BaseType::Anything) {
if (TT != TT2)
return false;
}
}
}
}
return isModSet(AA.getModRefInfo(maybeWriter, MemoryLocation::get(li)));
}
if (auto rmw = dyn_cast<AtomicRMWInst>(maybeReader)) {
Expand Down
11 changes: 9 additions & 2 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@

#include "TypeAnalysis/ConcreteType.h"

class TypeResults;

namespace llvm {
class ScalarEvolution;
}
Expand Down Expand Up @@ -674,6 +676,9 @@ struct BlasInfo {
std::string suffix;
std::string function;
bool is64;

llvm::Type *fpType(llvm::LLVMContext &ctx) const;
llvm::IntegerType *intType(llvm::LLVMContext &ctx) const;
};

#if LLVM_VERSION_MAJOR >= 16
Expand Down Expand Up @@ -968,7 +973,8 @@ void mayExecuteAfter(llvm::SmallVectorImpl<llvm::Instruction *> &results,
const llvm::Loop *region);

/// Return whether maybeReader can read from memory written to by maybeWriter
bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
bool writesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
llvm::TargetLibraryInfo &TLI,
llvm::Instruction *maybeReader,
llvm::Instruction *maybeWriter);

Expand All @@ -983,7 +989,8 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
// load A[i-1]
// store A[i] = ...
// }
bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
bool overwritesToMemoryReadBy(const TypeResults *TR, llvm::AAResults &AA,
llvm::TargetLibraryInfo &TLI,
llvm::ScalarEvolution &SE, llvm::LoopInfo &LI,
llvm::DominatorTree &DT,
llvm::Instruction *maybeReader,
Expand Down
Loading

0 comments on commit fb0fa7b

Please sign in to comment.