-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update Arm Compute Library Execution Provider (#22032)
### Description

This PR makes the following updates to the Arm Compute Library execution provider:

- Target Arm Compute Library 24.07
- Add support for the following operators:
  - Conv (FP16)
  - NhwcConv
  - QLinearConv
  - MatMul
  - FusedMatMul
  - MatMulIntegerToFloat
- Optimize memory usage and performance
- Expose the enable_fast_math setting
- Use the main runtime thread pool

### Motivation and Context

These updates improve performance and memory usage, and enable use of a more recent version of Arm Compute Library.

@microsoft-github-policy-service agree company="Arm Ltd"

---------

Signed-off-by: Michael Tyler <[email protected]>
- Loading branch information
1 parent
22437b5
commit 904b850
Showing
34 changed files
with
1,396 additions
and
426 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
# Licensed under the MIT License. | ||
|
||
# Minimum CMake required | ||
|
@@ -132,11 +133,6 @@ option(onnxruntime_USE_DML "Build with DirectML support" OFF) | |
option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF) | ||
option(onnxruntime_USE_WINML "Build with WinML support" OFF) | ||
option(onnxruntime_USE_ACL "Build with ACL support" OFF) | ||
option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) | ||
option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) | ||
option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) | ||
option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF) | ||
option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF) | ||
option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF) | ||
option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON) | ||
option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON) | ||
|
@@ -1207,44 +1203,22 @@ function(onnxruntime_add_include_to_target dst_target) | |
endfunction() | ||
|
||
# ACL | ||
if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308) | ||
if (onnxruntime_USE_ACL) | ||
set(onnxruntime_USE_ACL ON) | ||
if (onnxruntime_USE_ACL_1902) | ||
add_definitions(-DACL_1902=1) | ||
else() | ||
if (onnxruntime_USE_ACL_1908) | ||
add_definitions(-DACL_1908=1) | ||
else() | ||
if (onnxruntime_USE_ACL_2002) | ||
add_definitions(-DACL_2002=1) | ||
else() | ||
if (onnxruntime_USE_ACL_2308) | ||
add_definitions(-DACL_2308=1) | ||
else() | ||
add_definitions(-DACL_1905=1) | ||
endif() | ||
endif() | ||
endif() | ||
endif() | ||
|
||
if (NOT ${onnxruntime_ACL_LIBS} STREQUAL "") | ||
add_library(arm_compute SHARED IMPORTED) | ||
set_target_properties(arm_compute PROPERTIES | ||
IMPORTED_NO_SONAME 1 | ||
IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") | ||
|
||
add_library(arm_compute_core SHARED IMPORTED) | ||
set_target_properties(arm_compute_core PROPERTIES | ||
IMPORTED_NO_SONAME 1 | ||
IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") | ||
|
||
add_library(arm_compute_graph SHARED IMPORTED) | ||
set_target_properties(arm_compute_graph PROPERTIES | ||
IMPORTED_NO_SONAME 1 | ||
IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_graph.so") | ||
endif() | ||
|
||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_core arm_compute_graph) | ||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_graph) | ||
|
||
endif() | ||
|
||
|
@@ -1263,11 +1237,6 @@ if (onnxruntime_USE_ARMNN) | |
IMPORTED_NO_SONAME 1 | ||
IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") | ||
|
||
add_library(arm_compute_core SHARED IMPORTED) | ||
set_target_properties(arm_compute_core PROPERTIES | ||
IMPORTED_NO_SONAME 1 | ||
IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") | ||
|
||
add_library(arm_compute_graph SHARED IMPORTED) | ||
set_target_properties(arm_compute_graph PROPERTIES | ||
IMPORTED_NO_SONAME 1 | ||
|
@@ -1281,7 +1250,7 @@ if (onnxruntime_USE_ARMNN) | |
IMPORTED_LOCATION "${onnxruntime_ARMNN_LIBS}/libarmnn.so") | ||
endif() | ||
|
||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_core arm_compute_graph) | ||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_graph) | ||
endif() | ||
|
||
if (onnxruntime_USE_DNNL) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
// Licensed under the MIT License. | ||
|
||
#include "onnxruntime_c_api.h" | ||
|
@@ -10,7 +11,8 @@ extern "C" { | |
/** | ||
* \param enable_fast_math true: enable fast math mode in ACL. false: disable it. | ||
*/ | ||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) | ||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, | ||
bool enable_fast_math) | ||
ORT_ALL_ARGS_NONNULL; | ||
|
||
#ifdef __cplusplus | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* | ||
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. | ||
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
* Licensed under the MIT License. | ||
*/ | ||
package ai.onnxruntime; | ||
|
@@ -1181,12 +1182,12 @@ public void addDirectML(int deviceId) throws OrtException { | |
/** | ||
* Adds the ARM Compute Library as an execution backend. | ||
* | ||
* @param useArena If true use the arena memory allocator. | ||
* @param enableFastMath Enable fast math mode in ACL. | ||
* @throws OrtException If there was an error in native code. | ||
*/ | ||
public void addACL(boolean useArena) throws OrtException { | ||
public void addACL(boolean enableFastMath) throws OrtException { | ||
checkClosed(); | ||
addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); | ||
addACL(OnnxRuntime.ortApiHandle, nativeHandle, enableFastMath); | ||
} | ||
|
||
/** | ||
|
@@ -1354,7 +1355,8 @@ private native void addTvm(long apiHandle, long nativeHandle, String settings) | |
private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) | ||
throws OrtException; | ||
|
||
private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException; | ||
private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) | ||
throws OrtException; | ||
|
||
private native void addArmNN(long apiHandle, long nativeHandle, int useArena) | ||
throws OrtException; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* | ||
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. | ||
* SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
* Licensed under the MIT License. | ||
*/ | ||
#include <jni.h> | ||
|
@@ -644,12 +645,13 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDir | |
* Signature: (JJZ)V | ||
*/ | ||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addACL | ||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint useArena) { | ||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jboolean enableFastMath) { | ||
(void)jobj; | ||
#ifdef USE_ACL | ||
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle,useArena)); | ||
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle, | ||
OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle, enableFastMath)); | ||
#else | ||
(void)apiHandle;(void)handle;(void)useArena; // Parameters used when ACL is defined. | ||
(void)apiHandle;(void)handle;(void)enableFastMath; // Parameters used when ACL is defined. | ||
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with ACL support."); | ||
#endif | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
// Licensed under the MIT License. | ||
|
||
#include "core/optimizer/graph_transformer_utils.h" | ||
|
@@ -196,6 +197,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers( | |
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; | ||
#ifndef DISABLE_CONTRIB_OPS | ||
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider}; | ||
const InlinedHashSet<std::string_view> cpu_acl_eps = {onnxruntime::kCpuExecutionProvider, | ||
onnxruntime::kAclExecutionProvider}; | ||
#endif | ||
const InlinedHashSet<std::string_view> dml_ep = {onnxruntime::kDmlExecutionProvider}; | ||
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>(); | ||
|
@@ -285,6 +288,11 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers( | |
onnxruntime::kCudaExecutionProvider, | ||
onnxruntime::kRocmExecutionProvider, | ||
onnxruntime::kDmlExecutionProvider}; | ||
const InlinedHashSet<std::string_view> cpu_acl_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, | ||
onnxruntime::kAclExecutionProvider, | ||
onnxruntime::kCudaExecutionProvider, | ||
onnxruntime::kRocmExecutionProvider, | ||
onnxruntime::kDmlExecutionProvider}; | ||
const InlinedHashSet<std::string_view> cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, | ||
onnxruntime::kRocmExecutionProvider, | ||
onnxruntime::kAclExecutionProvider, | ||
|
@@ -296,8 +304,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers( | |
onnxruntime::kAclExecutionProvider, | ||
onnxruntime::kArmNNExecutionProvider, | ||
onnxruntime::kJsExecutionProvider}; | ||
const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, | ||
onnxruntime::kDmlExecutionProvider}; | ||
const InlinedHashSet<std::string_view> cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, | ||
onnxruntime::kDmlExecutionProvider, | ||
onnxruntime::kAclExecutionProvider}; | ||
const int64_t qdq_matmulnbits_accuracy_level = | ||
ParseStringWithClassicLocale<int64_t>( | ||
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, | ||
|
@@ -323,26 +332,26 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers( | |
} | ||
|
||
transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep)); | ||
transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_eps)); | ||
transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_ep)); | ||
transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_dml_acl_eps)); | ||
transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_acl_eps)); | ||
|
||
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_rocm_acl_armnn_js_eps)); | ||
|
||
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps, level)); | ||
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps, level)); | ||
transformers.emplace_back(std::make_unique<GeluFusion>(cpu_acl_cuda_dml_rocm_eps, level)); | ||
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_acl_cuda_dml_rocm_eps, level)); | ||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_acl_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_acl_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<GatherSliceToSplitFusion>(cpu_cuda_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps)); | ||
|
||
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_acl_cuda_dml_rocm_eps)); | ||
|
||
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_acl_cuda_dml_rocm_eps)); | ||
|
||
transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_acl_cuda_dml_rocm_eps)); | ||
|
||
// GeluApproximation has side effects which may change results. It needs to be manually enabled, | ||
// or alternatively the model can be updated offline using a model conversion script | ||
|
@@ -367,7 +376,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers( | |
transformers.emplace_back(std::make_unique<SceLossGradBiasFusion>(cpu_cuda_rocm_eps)); | ||
#endif | ||
|
||
transformers.emplace_back(std::make_unique<MatMulScaleFusion>(cpu_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<MatMulScaleFusion>(cpu_acl_cuda_dml_rocm_eps)); | ||
transformers.emplace_back(std::make_unique<MatMulActivationFusion>(dml_ep)); | ||
|
||
#ifdef MLAS_TARGET_AMD64_IX86 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
// Licensed under the MIT License. | ||
|
||
#include <deque> | ||
|
@@ -183,7 +184,8 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, | |
modified = false; | ||
for (std::unique_ptr<api::NodeRef>& node : api_graph->Nodes()) { | ||
// Skip nodes that are not assigned to the CPU or ACL EP | ||
if (node->GetExecutionProviderType() != kCpuExecutionProvider) { | ||
const auto ep = node->GetExecutionProviderType(); | ||
if ((ep != kCpuExecutionProvider) && (ep != kAclExecutionProvider)) { | ||
continue; | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <[email protected]> | ||
// Licensed under the MIT License. | ||
|
||
#include <memory> | ||
|
@@ -381,9 +382,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( | |
CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, | ||
intra_op_thread_pool, p_buffered_tensors), | ||
apply_context, | ||
// this transformer is compatible with CPU, DML and CUDA EP. | ||
// this transformer is compatible with CPU, DML, ACL and CUDA EP. | ||
// There is further EP control on the rule level. | ||
{kCpuExecutionProvider, kDmlExecutionProvider, kCudaExecutionProvider}} { | ||
{kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} { | ||
} | ||
|
||
} // namespace onnxruntime |
Oops, something went wrong.