Skip to content

Commit

Permalink
Add Gradient for Atan (#23172)
Browse files Browse the repository at this point in the history
  • Loading branch information
cocotdf authored Jan 9, 2025
1 parent d0c7438 commit 16a246d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 0 deletions.
12 changes: 12 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2227,5 +2227,17 @@ IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) {
SrcNodeAttributes())};
}

IMPLEMENT_GRADIENT_BUILDER(GetAtanGradient) {
// dl/dx = dl/dy * (1/(1+x^2))
NodeDef one_const_node = OneConstantNode(IElemType(0));
ArgDef one = one_const_node.output_args[0];
std::vector<NodeDef> result;
result.push_back(one_const_node);
result.push_back(NodeDef("Mul", {I(0), I(0)}, {IA("Square_I0")}));
result.push_back(NodeDef("Add", {IA("Square_I0"), one}, {IA("One_Plus_Square_I0")}));
result.push_back(NodeDef("Div", {GO(0), IA("One_Plus_Square_I0")}, {GI(0)}));
return result;
}

} // namespace training
} // namespace onnxruntime
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ DECLARE_GRADIENT_BUILDER(GetReciprocalGradient)
DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient)
DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient)
DECLARE_GRADIENT_BUILDER(GetResizeGradient)
DECLARE_GRADIENT_BUILDER(GetAtanGradient)

DECLARE_GRADIENT_BUILDER(GetExternalGradient)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient);
REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient);
REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient);
REGISTER_GRADIENT_BUILDER("Atan", GetAtanGradient);

REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient);
};
Expand Down
2 changes: 2 additions & 0 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3352,6 +3352,8 @@ TEST(GradientCheckerTest, ResizeGrad) {

#endif // USE_CUDA

TEST(GradientCheckerTest, AtanGrad) { UnaryOpGradientTest("Atan"); }

} // namespace test
} // namespace onnxruntime

Expand Down

0 comments on commit 16a246d

Please sign in to comment.