Skip to content

Commit

Permalink
Add derivative for LLVM:ExpOp (#2220)
Browse files Browse the repository at this point in the history
Co-authored-by: BuildKite <[email protected]>
  • Loading branch information
tyb0807 and BuildKite authored Jan 8, 2025
1 parent 96b8efc commit ac3be7e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 0 deletions.
4 changes: 4 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def TypeOf : Operation</*primal*/0, /*shadow*/0> {

class ComplexInst<string m> : Inst<m, "complex">;
class ArithInst<string m> : Inst<m, "arith">;
class LlvmInst<string m> : Inst<m, "LLVM">;
class MathInst<string m> : Inst<m, "math">;

def AddF : ArithInst<"AddFOp">;
Expand All @@ -133,6 +134,9 @@ def RemF : ArithInst<"RemFOp">;
def CheckedMulF : ArithInst<"MulFOp">;
def CheckedDivF : ArithInst<"DivFOp">;

def LlvmCheckedMulF : LlvmInst<"FMulOp">;
def LlvmExpF : LlvmInst<"ExpOp">;

def CosF : MathInst<"CosOp">;
def SinF : MathInst<"SinOp">;
def ExpF : MathInst<"ExpOp">;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,9 @@ def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>;
def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>;

def : AllocationOp<"LLVM", "AllocaOp">;

def : MLIRDerivative<"LLVM", "ExpOp", (Op $x),
[
(LlvmCheckedMulF (DiffeRet), (LlvmExpF $x))
]
>;
17 changes: 17 additions & 0 deletions enzyme/test/MLIR/ForwardMode/llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ module {
%r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}

func.func @exp(%x: f32) -> f32 {
%0 = llvm.intr.exp(%x) : (f32) -> f32
return %0 : f32
}

func.func @dexp(%x: f32, %dx: f32) -> f32 {
%r = enzyme.fwddiff @exp(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f32, f32) -> f32
return %r : f32
}
}

// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 {
Expand All @@ -29,3 +39,10 @@ module {
// CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr -> f64
// CHECK-NEXT: return %[[i6]] : f64
// CHECK-NEXT: }

// CHECK: func.func private @fwddiffeexp(%[[arg0:.+]]: f32, %[[arg1:.+]]: f32) -> f32 {
// CHECK-NEXT: %[[der:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32
// CHECK-NEXT: %[[res:.+]] = llvm.fmul %[[arg1]], %[[der]] : f32
// CHECK-NEXT: %[[exp:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32
// CHECK-NEXT: return %[[res]] : f32
// CHECK-NEXT: }

0 comments on commit ac3be7e

Please sign in to comment.