From e495561320276ac6382d3cb0867c701f8ecf1ab5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ewa=20Tusie=C5=84?= Date: Mon, 22 Jul 2019 17:14:52 +0200 Subject: [PATCH] [Py] Added elu operator to Python API. (#3236) * Added elu operator to Python API. * Added missing file. * Specified elu function description. * Expand docstring * [Py] Added test with scalar for elu operator. * Bugfix * [Py] Changed input type in elu test. * Update test_ops_binary.py * [Py] Syntax bugfix. * [Py] Added elu operator to list in documentation. --- .../python_api/_autosummary/ngraph.ops.rst | 1 + python/ngraph/__init__.py | 1 + python/ngraph/impl/op/__init__.py | 1 + python/ngraph/ops.py | 22 +++++- python/pyngraph/ops/elu.cpp | 30 ++++++++ python/pyngraph/ops/elu.hpp | 23 +++++++ python/pyngraph/ops/fused/elu.cpp | 30 ++++++++ python/pyngraph/ops/regmodule_pyngraph_op.cpp | 1 + python/pyngraph/ops/regmodule_pyngraph_op.hpp | 1 + python/setup.py | 1 + python/test/ngraph/test_ops_fused.py | 69 +++++++++++++++++++ src/ngraph/op/fused/elu.cpp | 2 +- 12 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 python/pyngraph/ops/elu.cpp create mode 100644 python/pyngraph/ops/elu.hpp create mode 100644 python/pyngraph/ops/fused/elu.cpp create mode 100644 python/test/ngraph/test_ops_fused.py diff --git a/doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst b/doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst index f4f81a8ed98..8fccd40baa5 100644 --- a/doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst +++ b/doc/sphinx/source/python_api/_autosummary/ngraph.ops.rst @@ -32,6 +32,7 @@ ngraph.ops cosh divide dot + elu equal exp floor diff --git a/python/ngraph/__init__.py b/python/ngraph/__init__.py index 9d696250861..71b41c6e427 100644 --- a/python/ngraph/__init__.py +++ b/python/ngraph/__init__.py @@ -45,6 +45,7 @@ from ngraph.ops import cosh from ngraph.ops import divide from ngraph.ops import dot +from ngraph.ops import elu from ngraph.ops import equal from ngraph.ops import exp from ngraph.ops import floor diff --git a/python/ngraph/impl/op/__init__.py b/python/ngraph/impl/op/__init__.py index 762a0dc1b6e..17a6f69f492 100644 --- a/python/ngraph/impl/op/__init__.py +++ b/python/ngraph/impl/op/__init__.py @@ -69,6 +69,7 @@ from _pyngraph.op import Cosh from _pyngraph.op import Divide from _pyngraph.op import Dot +from _pyngraph.op import Elu from _pyngraph.op import Equal from _pyngraph.op import Exp from _pyngraph.op import Floor diff --git a/python/ngraph/ops.py b/python/ngraph/ops.py index 1a53709b026..e543dbe36a2 100644 --- a/python/ngraph/ops.py +++ b/python/ngraph/ops.py @@ -22,7 +22,7 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \ BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Concat, Constant, Convert, \ - Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \ + Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \ GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \ Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \ Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \ @@ -35,7 +35,7 @@ from ngraph.utils.input_validation import assert_list_of_ints from ngraph.utils.reduction import get_reduction_axes from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \ - NodeInput, ScalarData + NodeInput, ScalarData, as_node from ngraph.utils.types import get_element_type @@ -60,6 +60,24 @@ def constant(value, dtype=None, name=None): # type: (NumericData, NumericType, return make_constant_node(value, dtype) +@nameable_op +def elu(data, alpha, name=None): # type: (NodeInput, NodeInput, str) -> Node + """Perform Exponential Linear Unit operation element-wise on data from input node. + + Computes exponential linear: alpha * (exp(data) - 1) if < 0, data otherwise. + + For more information refer to: + `Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) + `_ + + :param data: Input tensor. One of: input node, array or scalar. + :param alpha: Multiplier for negative values. One of: input node or scalar value. + :param name: Optional output node name. + :return: The new node performing an ELU operation on its input data element-wise. + """ + return Elu(as_node(data), as_node(alpha)) + + # Unary ops @unary_op def absolute(node, name=None): # type: (NodeInput, str) -> Node diff --git a/python/pyngraph/ops/elu.cpp b/python/pyngraph/ops/elu.cpp new file mode 100644 index 00000000000..f5fe11c40bc --- /dev/null +++ b/python/pyngraph/ops/elu.cpp @@ -0,0 +1,30 @@ +//***************************************************************************** +// Copyright 2017-2019 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include + +#include "ngraph/op/fused/elu.hpp" +#include "pyngraph/ops/elu.hpp" + +namespace py = pybind11; + +void regclass_pyngraph_op_Elu(py::module m) +{ + py::class_, ngraph::op::Op> elu(m, "Elu"); + elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu"; + elu.def(py::init&, const std::shared_ptr&>()); +} diff --git a/python/pyngraph/ops/elu.hpp b/python/pyngraph/ops/elu.hpp new file mode 100644 index 00000000000..2e93b81d0ae --- /dev/null +++ b/python/pyngraph/ops/elu.hpp @@ -0,0 +1,23 @@ +//***************************************************************************** +// Copyright 2017-2019 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_pyngraph_op_Elu(py::module m); diff --git a/python/pyngraph/ops/fused/elu.cpp b/python/pyngraph/ops/fused/elu.cpp new file mode 100644 index 00000000000..c8bf21b5f8d --- /dev/null +++ b/python/pyngraph/ops/fused/elu.cpp @@ -0,0 +1,30 @@ +//***************************************************************************** +// Copyright 2017-2019 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include +#include + +#include "ngraph/op/fused/elu.hpp" +#include "pyngraph/ops/fused/elu.hpp" + +namespace py = pybind11; + +void regclass_pyngraph_op_Elu(py::module m) +{ + py::class_, ngraph::op::Op> elu(m, "Elu"); + elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu"; + elu.def(py::init&, const std::shared_ptr&>()); +} diff --git a/python/pyngraph/ops/regmodule_pyngraph_op.cpp b/python/pyngraph/ops/regmodule_pyngraph_op.cpp index 346bc415dae..2ba9fb1ac2c 100644 --- a/python/pyngraph/ops/regmodule_pyngraph_op.cpp +++ b/python/pyngraph/ops/regmodule_pyngraph_op.cpp @@ -49,6 +49,7 @@ void regmodule_pyngraph_op(py::module m_op) regclass_pyngraph_op_Cosh(m_op); regclass_pyngraph_op_Divide(m_op); regclass_pyngraph_op_Dot(m_op); + regclass_pyngraph_op_Elu(m_op); regclass_pyngraph_op_Equal(m_op); regclass_pyngraph_op_Exp(m_op); regclass_pyngraph_op_Floor(m_op); diff --git a/python/pyngraph/ops/regmodule_pyngraph_op.hpp b/python/pyngraph/ops/regmodule_pyngraph_op.hpp index 1b22ed762e3..8ec77426d5c 100644 --- a/python/pyngraph/ops/regmodule_pyngraph_op.hpp +++ b/python/pyngraph/ops/regmodule_pyngraph_op.hpp @@ -39,6 +39,7 @@ #include "pyngraph/ops/cosh.hpp" #include "pyngraph/ops/divide.hpp" #include "pyngraph/ops/dot.hpp" +#include "pyngraph/ops/elu.hpp" #include "pyngraph/ops/equal.hpp" #include "pyngraph/ops/exp.hpp" #include "pyngraph/ops/floor.hpp" diff --git a/python/setup.py b/python/setup.py index da76f218487..32dec747374 100644 --- a/python/setup.py +++ b/python/setup.py @@ -179,6 +179,7 @@ def cpp_flag(compiler): 'pyngraph/ops/ceiling.cpp', 'pyngraph/ops/divide.cpp', 'pyngraph/ops/dot.cpp', + 'pyngraph/ops/elu.cpp', 'pyngraph/ops/equal.cpp', 'pyngraph/ops/exp.cpp', 'pyngraph/ops/floor.cpp', diff --git a/python/test/ngraph/test_ops_fused.py b/python/test/ngraph/test_ops_fused.py new file mode 100644 index 00000000000..ca0d1b3c4c9 --- /dev/null +++ b/python/test/ngraph/test_ops_fused.py @@ -0,0 +1,69 @@ +# ****************************************************************************** +# Copyright 2017-2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ****************************************************************************** +import numpy as np + +import ngraph as ng +from test.ngraph.util import get_runtime + + +def test_elu_operator(): + runtime = get_runtime() + + data_shape = [2, 2] + alpha_shape = [2] + parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32) + parameter_alpha = ng.parameter(alpha_shape, name='Alpha', dtype=np.float32) + + model = ng.elu(parameter_data, parameter_alpha) + computation = runtime.computation(model, parameter_data, parameter_alpha) + + value_data = np.array([[-5, 1], [-2, 3]], dtype=np.float32) + value_alpha = np.array([3, 3], dtype=np.float32) + + result = computation(value_data, value_alpha) + expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32) + assert np.allclose(result, expected) + + +def test_elu_operator_with_scalar_and_array(): + runtime = get_runtime() + + data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32) + alpha_value = np.float32(3) + + model = ng.elu(data_value, alpha_value) + computation = runtime.computation(model) + + result = computation() + expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32) + assert np.allclose(result, expected) + + +def test_elu_operator_with_scalar(): + runtime = get_runtime() + + data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32) + alpha_value = np.float32(3) + + data_shape = [2, 2] + parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32) + + model = ng.elu(parameter_data, alpha_value) + computation = runtime.computation(model, parameter_data) + + result = computation(data_value) + expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32) + assert np.allclose(result, expected) diff --git a/src/ngraph/op/fused/elu.cpp b/src/ngraph/op/fused/elu.cpp index 1629d2d5628..380d320f034 100644 --- a/src/ngraph/op/fused/elu.cpp +++ b/src/ngraph/op/fused/elu.cpp @@ -39,7 +39,7 @@ NodeVector op::Elu::decompose_op() const auto data = get_argument(0); auto alpha_node = get_argument(1); - alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape()); + alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data->get_shape()); shared_ptr zero_node = builder::make_constant(data->get_element_type(), data->get_shape(), 0);