Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Check for element-wise equality at construction and simplify zero/one…
Browse files Browse the repository at this point in the history
… checks later (#3785)
  • Loading branch information
rkimballn1 authored and diyessi committed Oct 17, 2019
1 parent 7f5ad24 commit 938c2a6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
13 changes: 2 additions & 11 deletions src/ngraph/graph_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,8 @@ bool ngraph::is_equal_to_const_value(std::string const_value, const Output<Node>
{
if (auto rc = dynamic_pointer_cast<ngraph::op::Constant>(reduce_constant.get_node_shared_ptr()))
{
auto cshape = rc->get_shape();
size_t n = shape_size(cshape);
// way to construct a constant of a given type, shape, value
std::vector<std::string> vector_zero{n, const_value};
auto constant_val_op =
std::make_shared<ngraph::op::Constant>(rc->get_element_type(), cshape, vector_zero);

// way to compare elements to const_value
size_t n_bytes = n * rc->get_element_type().size();
NGRAPH_DEBUG << "Comparing " << n_bytes << " bytes";
return !memcmp(constant_val_op->get_data_ptr(), rc->get_data_ptr(), n_bytes);
return (rc->get_all_data_elements_bitwise_identical() &&
rc->convert_value_to_string(0) == const_value);
}
else
{
Expand Down
8 changes: 8 additions & 0 deletions src/ngraph/op/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ namespace ngraph
write_values(values);
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}

/// \brief Constructs a tensor constant
Expand Down Expand Up @@ -128,6 +129,7 @@ namespace ngraph
}
}
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}

/// \brief Constructs a tensor constant with the same initialization value copied across
Expand All @@ -146,6 +148,7 @@ namespace ngraph
host_alignment()));
std::memcpy(m_data->get_ptr(), data, size);
constructor_validate_and_infer_types();
m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
}

virtual ~Constant() override;
Expand Down Expand Up @@ -246,6 +249,10 @@ namespace ngraph

bool is_constant() const override { return true; }
bool are_all_data_elements_bitwise_identical() const;
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;
}
std::string convert_value_to_string(size_t index) const;

protected:
Expand Down Expand Up @@ -343,6 +350,7 @@ namespace ngraph
element::Type m_element_type;
Shape m_shape{};
std::unique_ptr<runtime::AlignedBuffer> m_data;
bool m_all_elements_bitwise_identical;
Constant(const Constant&) = delete;
Constant operator=(const Constant&) = delete;
};
Expand Down

0 comments on commit 938c2a6

Please sign in to comment.