diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 8ba2c6170b96c..57ae8c354abb7 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -87,7 +87,6 @@ Status CreateNodeArgs(const std::vector& names, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, QnnModelLookupTable& qnn_models) { ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node."); NodeAttrHelper node_helper(main_context_node); @@ -97,7 +96,6 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), main_context_node.Name(), - logger, qnn_models); } @@ -147,7 +145,6 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), main_context_node.Name(), - logger, qnn_models); } @@ -158,7 +155,7 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) { ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!"); Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, - logger, qnn_models); + qnn_models); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 4ff7618b486e2..f308a7456d46c 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -49,7 +49,6 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - const logging::Logger& logger, QnnModelLookupTable& qnn_models); Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index dde70fdcbdaa6..db5c2c5cb32ba 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -608,7 +608,6 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - const logging::Logger& logger, QnnModelLookupTable& qnn_models) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || @@ -665,12 +664,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t if (1 == graph_count) { // in case the EPContext node is generated from script // the graph name from the context binary may not match the EPContext node name - auto qnn_model = std::make_unique(logger, this); + auto qnn_model = std::make_unique(this); ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[0], context)); qnn_models.emplace(node_name, std::move(qnn_model)); } else { for (uint32_t i = 0; i < graph_count; ++i) { - auto qnn_model = std::make_unique(logger, this); + auto qnn_model = std::make_unique(this); ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[i], context)); qnn_models.emplace(graphs_info[i].graphInfoV1.graphName, std::move(qnn_model)); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index d1a3b46a8fc55..b80f1374fcdc7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -91,7 +91,6 @@ class QnnBackendManager { Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - const logging::Logger& logger, std::unordered_map>& qnn_models); Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index a09b1daa81726..f322456e0c8f0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -17,7 +17,7 @@ namespace onnxruntime { namespace qnn { -bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) { +bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger) { bool rt = true; graph_info_ = std::make_unique(model_wrapper.GetQnnGraph(), @@ -25,7 +25,7 @@ bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) { std::move(model_wrapper.GetGraphInputTensorWrappers()), std::move(model_wrapper.GetGraphOutputTensorWrappers())); if (graph_info_ == nullptr) { - LOGS(logger_, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo."; + LOGS(logger, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo."; return false; } @@ -33,16 +33,19 @@ bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper) { } Status QnnModel::SetGraphInputOutputInfo(const GraphViewer& graph_viewer, - const onnxruntime::Node& fused_node) { + const onnxruntime::Node& fused_node, + const logging::Logger& logger) { auto graph_initializers = graph_viewer.GetAllInitializedTensors(); for (auto graph_ini : graph_initializers) { initializer_inputs_.emplace(graph_ini.first); } auto input_defs = fused_node.InputDefs(); - ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, model_input_index_map_, true)); + ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(input_defs, input_names_, inputs_info_, + model_input_index_map_, logger, true)); auto output_defs = fused_node.OutputDefs(); - ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, model_output_index_map_)); + ORT_RETURN_IF_ERROR(ParseGraphInputOrOutput(output_defs, output_names_, outputs_info_, + model_output_index_map_, logger)); return Status::OK(); } @@ -51,6 +54,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainer& input_output_names, std::unordered_map& input_output_info_table, std::unordered_map& input_output_index_map, + const logging::Logger& logger, bool is_input) { for (size_t i = 0, end = input_output_defs.size(), index = 0; i < end; ++i) { const auto& name = input_output_defs[i]->Name(); @@ -60,7 +64,7 @@ Status QnnModel::ParseGraphInputOrOutput(ConstPointerContainerShape(); // consider use qnn_model_wrapper.GetOnnxShape ORT_RETURN_IF(shape_proto == nullptr, "shape_proto cannot be null for output: ", name); @@ -91,8 +95,9 @@ const NodeUnit& QnnModel::GetNodeUnit(const Node* node, Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const logging::Logger& logger, const QnnGraph_Config_t** graph_configs) { - LOGS(logger_, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name(); + LOGS(logger, VERBOSE) << "ComposeGraph Graph name: " << graph_viewer.Name(); // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is // valid throughout the lifetime of the ModelBuilder @@ -102,9 +107,9 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); - ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node)); + ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); - QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_, + QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger, qnn_backend_manager_->GetQnnInterface(), qnn_backend_manager_->GetQnnBackendHandle(), model_input_index_map_, @@ -121,65 +126,65 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, qnn_node_groups.reserve(node_unit_holder.size()); ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map, - node_unit_holder.size(), logger_)); + node_unit_holder.size(), logger)); for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { - Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_); + Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger); if (!status.IsOK()) { - LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " - << status.ErrorMessage() << std::endl; + LOGS(logger, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " + << status.ErrorMessage() << std::endl; return status; } } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); - rt = GetGraphInfoFromModel(qnn_model_wrapper); + rt = GetGraphInfoFromModel(qnn_model_wrapper, logger); if (!rt) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetGraphInfoFromModel failed."); } - LOGS(logger_, VERBOSE) << "GetGraphInfoFromModel completed."; + LOGS(logger, VERBOSE) << "GetGraphInfoFromModel completed."; return Status::OK(); } -Status QnnModel::FinalizeGraphs() { - LOGS(logger_, VERBOSE) << "FinalizeGraphs started."; +Status QnnModel::FinalizeGraphs(const logging::Logger& logger) { + LOGS(logger, VERBOSE) << "FinalizeGraphs started."; Qnn_ErrorHandle_t status = qnn_backend_manager_->GetQnnInterface().graphFinalize(graph_info_->Graph(), qnn_backend_manager_->GetQnnProfileHandle(), nullptr); if (QNN_GRAPH_NO_ERROR != status) { - LOGS(logger_, ERROR) << "Failed to finalize QNN graph. Error code: " << status; + LOGS(logger, ERROR) << "Failed to finalize QNN graph. Error code: " << status; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to finalize QNN graph."); } ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); - LOGS(logger_, VERBOSE) << "FinalizeGraphs completed."; + LOGS(logger, VERBOSE) << "FinalizeGraphs completed."; return Status::OK(); } -Status QnnModel::SetupQnnInputOutput() { - LOGS(logger_, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name(); +Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { + LOGS(logger, VERBOSE) << "Setting up QNN input/output for graph: " << graph_info_->Name(); auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors()); if (Status::OK() != result) { - LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); + LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!"); } result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false); if (Status::OK() != result) { - LOGS(logger_, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); + LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name(); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!"); } return Status::OK(); } -Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { - LOGS(logger_, VERBOSE) << "QnnModel::ExecuteGraphs"; +Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger) { + LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs"; const size_t num_inputs = context.GetInputCount(); const size_t num_outputs = context.GetOutputCount(); ORT_RETURN_IF_NOT(qnn_input_infos_.size() <= num_inputs, "Inconsistent input sizes"); @@ -198,12 +203,12 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { qnn_inputs.reserve(qnn_input_infos_.size()); for (const auto& qnn_input_info : qnn_input_infos_) { - LOGS(logger_, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() - << " index = " << qnn_input_info.ort_index; + LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() + << " index = " << qnn_input_info.ort_index; auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); auto ort_tensor_size = TensorDataSize(ort_input_tensor); - LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size - << "Ort tensor size: " << ort_tensor_size; + LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size + << "Ort tensor size: " << ort_tensor_size; ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size."); @@ -217,13 +222,13 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { for (auto& qnn_output_info : qnn_output_infos_) { const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); - LOGS(logger_, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; + LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; const auto& ort_output_info = GetOutputInfo(model_output_name); const std::vector& output_shape = ort_output_info->shape_; auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); auto ort_tensor_size = TensorDataSize(ort_output_tensor); - LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size - << "Ort tensor size: " << ort_tensor_size; + LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size + << "Ort tensor size: " << ort_tensor_size; ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, "ORT Tensor data size does not match QNN tensor data size"); @@ -232,7 +237,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { const_cast(ort_output_tensor.GetTensorData()), qnn_output_info.tensor_byte_size); } - LOGS(logger_, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); + LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; @@ -257,7 +262,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { if (QNN_COMMON_ERROR_SYSTEM_COMMUNICATION == execute_status) { auto error_message = "NPU crashed. SSR detected. Caused QNN graph execute error. Error code: "; - LOGS(logger_, ERROR) << error_message << execute_status; + LOGS(logger, ERROR) << error_message << execute_status; return ORT_MAKE_STATUS(ONNXRUNTIME, ENGINE_ERROR, error_message, execute_status); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 1416d9ba92671..83cf8f9f08fb0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -25,10 +25,8 @@ struct QnnTensorInfo { class QnnModel { public: - QnnModel(const logging::Logger& logger, - QnnBackendManager* qnn_backend_manager) - : logger_(logger), - qnn_backend_manager_(qnn_backend_manager) { + QnnModel(QnnBackendManager* qnn_backend_manager) + : qnn_backend_manager_(qnn_backend_manager) { qnn_backend_type_ = qnn_backend_manager_->GetQnnBackendType(); } @@ -37,13 +35,14 @@ class QnnModel { Status ComposeGraph(const GraphViewer& graph_viewer, const onnxruntime::Node& fused_node, + const logging::Logger& logger, const QnnGraph_Config_t** graph_configs = nullptr); - Status FinalizeGraphs(); + Status FinalizeGraphs(const logging::Logger& logger); - Status SetupQnnInputOutput(); + Status SetupQnnInputOutput(const logging::Logger& logger); - Status ExecuteGraph(const Ort::KernelContext& context); + Status ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger); const OnnxTensorInfo* GetOutputInfo(const std::string& name) const { auto it = outputs_info_.find(name); @@ -55,11 +54,13 @@ class QnnModel { } Status SetGraphInputOutputInfo(const GraphViewer& graph_viewer, - const onnxruntime::Node& fused_node); + const onnxruntime::Node& fused_node, + const logging::Logger& logger); Status ParseGraphInputOrOutput(ConstPointerContainer>& input_output_defs, std::vector& input_output_names, std::unordered_map& input_output_info_table, std::unordered_map& input_output_index, + const logging::Logger& logger, bool is_input = false); const std::unordered_set& GetInitializerInputs() const { return initializer_inputs_; } @@ -107,7 +108,7 @@ class QnnModel { private: const NodeUnit& GetNodeUnit(const Node* node, const std::unordered_map& node_unit_map) const; - bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper); + bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger); Status GetQnnTensorDataLength(const std::vector& dims, Qnn_DataType_t data_type, @@ -125,7 +126,6 @@ class QnnModel { } private: - const logging::Logger& logger_; std::unique_ptr graph_info_; QnnBackendManager* qnn_backend_manager_ = nullptr; // , initializer inputs are excluded, keep the input index here diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index f2991df3b1b8e..698ceaea7c3b7 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -789,10 +789,10 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod ORT_UNUSED_PARAMETER(state); }; - compute_info.compute_func = [](FunctionState state, const OrtApi*, OrtKernelContext* context) { + compute_info.compute_func = [&logger](FunctionState state, const OrtApi*, OrtKernelContext* context) { Ort::KernelContext ctx(context); qnn::QnnModel* model = reinterpret_cast(state); - Status result = model->ExecuteGraph(ctx); + Status result = model->ExecuteGraph(ctx, logger); return result; }; @@ -843,16 +843,15 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector qnn_model = std::make_unique(logger, - qnn_backend_manager_.get()); + std::unique_ptr qnn_model = std::make_unique(qnn_backend_manager_.get()); qnn::QnnConfigsBuilder graph_configs_builder(QNN_GRAPH_CONFIG_INIT, QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); InitQnnGraphConfigs(graph_configs_builder); - ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnConfigs())); - ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs()); - ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, logger, graph_configs_builder.GetQnnConfigs())); + ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs(logger)); + ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name(); qnn_models_.emplace(fused_node.Name(), std::move(qnn_model)); @@ -894,8 +893,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused std::string key = ep_context_node->Name(); auto qnn_model_shared = SharedContext::GetInstance().GetSharedQnnModel(key); ORT_RETURN_IF(nullptr == qnn_model_shared, "Graph: " + key + " not found from shared EP contexts."); - ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node)); - ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput()); + ORT_RETURN_IF_ERROR(qnn_model_shared->SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); + ORT_RETURN_IF_ERROR(qnn_model_shared->SetupQnnInputOutput(logger)); qnn_models_shared_.emplace(graph_meta_id, qnn_model_shared); use_shared_model_ = true; ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); @@ -929,8 +928,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused std::string key = ep_context_node->Name(); ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name not exist in table qnn_models."); auto qnn_model = std::move(qnn_models[key]); - ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); - ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); + ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput(logger)); // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] // the name here must be same with context->node_name in compute_info diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index d293a0d9c96c1..a3f0ed55b83f2 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -976,9 +976,14 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { UpdateEpContextModel(ctx_model_paths_to_update, last_qnn_ctx_binary_file_name, DefaultLoggingManager().DefaultLogger()); - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - so.AppendExecutionProvider("QNN", provider_options); + Ort::SessionOptions so1; + so1.SetLogId("so1"); + so1.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so1.AppendExecutionProvider("QNN", provider_options); + Ort::SessionOptions so2; + so2.SetLogId("so2"); + so2.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so2.AppendExecutionProvider("QNN", provider_options); EXPECT_TRUE(2 == ctx_model_paths.size()); #ifdef _WIN32 @@ -988,8 +993,8 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { std::string ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); std::string ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); #endif - Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so); - Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so); + Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so1); + Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so2); std::vector input_names; std::vector output_names;