Skip to content

Commit

Permalink
Fix Build Error (#23299)
Browse files Browse the repository at this point in the history
Fix build error.
  • Loading branch information
centwang authored Jan 9, 2025
1 parent 4134cd9 commit 3b1a900
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4057,6 +4057,8 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
}

TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
const auto& logger = DefaultLoggingManager().DefaultLogger();

// Relu is redundant.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
Expand Down Expand Up @@ -4097,7 +4099,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
domain_to_version[kOnnxDomain] = 18;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_test_case(helper);
Expand Down Expand Up @@ -4125,7 +4127,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
// Check if SelectorManager get a conv qdq group selection as expected
{
QDQ::SelectorManager selector_mgr;
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer);
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger);
ASSERT_FALSE(result.empty());
const auto& qdq_group = result.at(0);
ASSERT_EQ(std::vector<NodeIndex>({0, 1, 2}), qdq_group.dq_nodes);
Expand All @@ -4141,7 +4143,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;

std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger);

// We should get a single QDQ Node unit in the result
ASSERT_EQ(1, node_unit_holder.size());
Expand Down Expand Up @@ -4203,7 +4205,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
domain_to_version[kOnnxDomain] = 18;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_test_case(helper);
Expand All @@ -4226,7 +4228,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
// Check if SelectorManager get a conv qdq group selection as expected
{
QDQ::SelectorManager selector_mgr;
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer);
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger);
ASSERT_TRUE(result.empty());
}

Expand All @@ -4235,7 +4237,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
{
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger);
ASSERT_EQ(6, node_unit_holder.size());
ASSERT_EQ(6, node_unit_map.size());
}
Expand All @@ -4244,6 +4246,8 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Conv_Relu) {
}

TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {
const auto& logger = DefaultLoggingManager().DefaultLogger();

// Clip is redundant.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
Expand Down Expand Up @@ -4278,7 +4282,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {
domain_to_version[kOnnxDomain] = 18;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_test_case(helper);
Expand All @@ -4305,7 +4309,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {
// Check if SelectorManager get a add qdq group selection as expected
{
QDQ::SelectorManager selector_mgr;
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer);
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger);
ASSERT_FALSE(result.empty());
const auto& qdq_group = result.at(0);
ASSERT_EQ(std::vector<NodeIndex>({0, 1}), qdq_group.dq_nodes);
Expand All @@ -4321,7 +4325,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;

std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger);

// We should get a single QDQ Node unit in the result
ASSERT_EQ(1, node_unit_holder.size());
Expand Down Expand Up @@ -4375,7 +4379,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {
domain_to_version[kOnnxDomain] = 18;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_test_case(helper);
Expand All @@ -4395,7 +4399,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {

{
QDQ::SelectorManager selector_mgr;
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer);
const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger);
ASSERT_TRUE(result.empty());
}

Expand All @@ -4405,7 +4409,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test_Add_Clip) {
// Get all the NodeUnits in the graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer);
std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger);
ASSERT_EQ(5, node_unit_holder.size());
ASSERT_EQ(5, node_unit_map.size());
}
Expand Down

0 comments on commit 3b1a900

Please sign in to comment.