Merge pull request #30 from lovemefan/develop
use backend registry
lovemefan authored Nov 24, 2024
2 parents e9fbbd7 + 40d0d46 commit 9b7cb22
Showing 9 changed files with 58 additions and 250 deletions.
15 changes: 7 additions & 8 deletions .github/workflows/build.yaml
@@ -410,14 +410,16 @@ jobs:

     strategy:
       matrix:
-        build: [Release]
-        arch: [x64]
-        cublas: [ON]
-        sdl2: [ON]
-        cuda-toolkit: [12.2.0, 11.8.0]
+        build: [ Release ]
+        arch: [ x64 ]
+        cublas: [ ON ]
+        sdl2: [ ON ]
+        cuda-toolkit: [ 12.2.0, 11.8.0 ]
         include:
           - arch: x64
            s2arc: x64
+          - sdl2: ON
+            s2ver: 2.28.5

     steps:
       - name: Clone

@@ -450,9 +452,6 @@ jobs:

             -Include cudart64_*,cublas64_*,cublasLt64_*
             -Destination build/bin/${{ matrix.build }}
-      - name: Copy SDL2.dll
-        if: matrix.sdl2 == 'ON'
-        run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }}
       - name: Upload binaries
         if: matrix.sdl2 == 'ON'
2 changes: 1 addition & 1 deletion README-EN.md
@@ -15,7 +15,7 @@ This project is based on the [ggml](https://github.com/ggerganov/ggml) framework

 1. Based on ggml, it does not rely on other third-party libraries and is committed to edge deployment.
 2. Feature extraction references the [kaldi-native-fbank](https://github.com/csukuangfj/kaldi-native-fbank) library, supporting multi-threaded feature extraction.
-3. Flash attention decoding can be used (The speed has not improved 🤔 weird,need help).
+3. Support Flash attention decoding
 4. Support Q3, Q4, Q5, Q6, Q8 quantization.

 ### 1.1 Future Plans
2 changes: 1 addition & 1 deletion README.md
@@ -12,7 +12,7 @@

 1. Based on ggml, with no other third-party dependencies, aiming at on-device deployment
 2. Feature extraction references the [kaldi-native-fbank](https://github.com/csukuangfj/kaldi-native-fbank) library, supporting multi-threaded feature extraction.
-3. Flash attention decoding can be used (no obvious speed improvement 🤔 not sure why)
+3. Support flash attention decoding
 4. Support Q3, Q4, Q5, Q6, Q8 quantization

 ### 1.1 Future Plans
7 changes: 1 addition & 6 deletions sense-voice/csrc/common.h
@@ -9,12 +9,7 @@
 #include <string>
 #include <map>
 #include <set>
-#include <ggml.h>
-#include <ggml-alloc.h>
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
+#include <ggml-cpu.h>
 
 #include "sense-voice-frontend.h"

3 changes: 0 additions & 3 deletions sense-voice/csrc/main.cc
@@ -98,15 +98,12 @@ const char * sense_voice_print_system_info(void) {
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
s += "COREML = " + std::to_string(sense_voice_has_coreml()) + " | ";
s += "OPENVINO = " + std::to_string(sense_voice_has_openvino());

40 changes: 7 additions & 33 deletions sense-voice/csrc/sense-voice-decoder.cc
@@ -3,32 +3,6 @@
 //
 
 #include "sense-voice-decoder.h"
-#include <ggml.h>
-#include "ggml-alloc.h"
-#include "ggml-backend.h"
-#ifdef GGML_USE_METAL
-#include "ggml-metal.h"
-#endif
-
-#ifdef GGML_USE_METAL
-#include "ggml-metal.h"
-#endif
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
-
-#ifdef GGML_USE_SYCL
-#include "ggml-sycl.h"
-#endif
-
-#ifdef GGML_USE_BLAS
-#include "ggml-blas.h"
-#endif
-
-#ifdef GGML_USE_VULKAN
-#include "ggml-vulkan.h"
-#endif
 
 #define SENSEVOICE_DECODER_MAX_NODES 8

@@ -111,16 +85,16 @@ static bool ggml_graph_compute_helper(

     for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
         ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
-        if (ggml_backend_is_cpu(backend)) {
-            ggml_backend_cpu_set_n_threads(backend, n_threads);
-        }
-#ifdef GGML_USE_BLAS
-        if (ggml_backend_is_blas(backend)) {
-            ggml_backend_blas_set_n_threads(backend, n_threads);
-        }
-#endif
+        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (fn_set_n_threads) {
+            fn_set_n_threads(backend, n_threads);
+        }
     }
 
     bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
     ggml_backend_sched_reset(sched);
     return t;
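
The rewritten helper no longer special-cases CPU and BLAS: it asks each backend's registry entry for an optional "ggml_backend_set_n_threads" function and calls it only when the backend provides one, so new backends get thread control with no extra #ifdefs. The same pattern as a standalone sketch (hypothetical helper name, assuming the registry API in ggml-backend.h):

    #include "ggml-backend.h"

    // Set the thread count on every backend that supports it; the
    // proc-address lookup returns nullptr for backends without the hook.
    static void sched_set_n_threads(ggml_backend_sched_t sched, int n_threads) {
        for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
            ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
            ggml_backend_dev_t dev = ggml_backend_get_device(backend);
            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
            auto * fn = reg ? (ggml_backend_set_n_threads_t)
                    ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads") : nullptr;
            if (fn) {
                fn(backend, n_threads);   // e.g. the CPU backend honours this
            }
        }
    }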
86 changes: 7 additions & 79 deletions sense-voice/csrc/sense-voice-encoder.cc
@@ -3,30 +3,7 @@
 //
 
 #include "sense-voice-encoder.h"
-
-#include <cmath>
-
-#ifdef GGML_USE_METAL
-#include "ggml-metal.h"
-#endif
-
-#ifdef GGML_USE_CUDA
-#include "ggml-cuda.h"
-#endif
-
-#ifdef GGML_USE_SYCL
-#include "ggml-sycl.h"
-#endif
-
-#ifdef GGML_USE_BLAS
-#include "ggml-blas.h"
-#endif
-
-#ifdef GGML_USE_VULKAN
-#include "ggml-vulkan.h"
-#endif
-
 
 #include <cassert>
 #include <map>
 #include <string>
@@ -93,53 +70,6 @@ struct sense_voice_context_params sense_voice_context_default_params() {
     return result;
 }
 
-static ggml_backend_t sense_voice_backend_init(
-        const sense_voice_context_params &params) {
-    ggml_backend_t backend_gpu = nullptr;
-
-    // initialize the backends
-#ifdef GGML_USE_CUDA
-    if (params.use_gpu) {
-        SENSE_VOICE_LOG_INFO("%s: using CUDA backend\n", __func__);
-        backend_gpu = ggml_backend_cuda_init(params.gpu_device);
-        if (!backend_gpu) {
-            SENSE_VOICE_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
-        }
-    }
-#endif
-
-#ifdef GGML_USE_METAL
-    if (params.use_gpu) {
-        SENSEVOICE_LOG_INFO("%s: using Metal backend\n", __func__);
-        backend_gpu = ggml_backend_metal_init();
-        if (!backend_gpu) {
-            SENSEVOICE_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
-        } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) {
-            SENSEVOICE_LOG_ERROR(
-                "%s: Metal GPU does not support family 7 - falling back to CPU\n",
-                __func__);
-            ggml_backend_free(backend_gpu);
-            backend_gpu = nullptr;
-        }
-    }
-#endif
-
-#ifdef GGML_USE_SYCL
-    if (params.use_gpu) {
-        SENSE_VOICE_LOG_INFO("%s: using SYCL backend\n", __func__);
-        backend_gpu = ggml_backend_sycl_init(params.gpu_device);
-        if (!backend_gpu) {
-            SENSE_VOICE_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
-        }
-    }
-#endif
-
-    if (backend_gpu) {
-        return backend_gpu;
-    }
-    return ggml_backend_cpu_init();
-}
-
 
 static bool ggml_graph_compute_helper(
     ggml_backend_sched_t sched,
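
The deleted #ifdef ladder is exactly what the backend registry replaces: a device of the desired type can now be requested generically. A minimal sketch of the registry-era idiom, assuming ggml_backend_init_by_type from ggml-backend.h (helper name illustrative; the commit may wire this up differently elsewhere):

    #include "ggml-backend.h"

    // Prefer any registered GPU backend (CUDA, Metal, SYCL, Vulkan, ...);
    // fall back to the CPU backend. No per-backend #ifdefs needed.
    static ggml_backend_t backend_init_via_registry(bool use_gpu) {
        ggml_backend_t backend = nullptr;
        if (use_gpu) {
            backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr);
        }
        if (!backend) {
            backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
        }
        return backend;
    }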
@@ -148,16 +78,16 @@ static bool ggml_graph_compute_helper(

     for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
         ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
-        if (ggml_backend_is_cpu(backend)) {
-            ggml_backend_cpu_set_n_threads(backend, n_threads);
-        }
-#ifdef GGML_USE_BLAS
-        if (ggml_backend_is_blas(backend)) {
-            ggml_backend_blas_set_n_threads(backend, n_threads);
-        }
-#endif
+        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (fn_set_n_threads) {
+            fn_set_n_threads(backend, n_threads);
+        }
     }
 
     bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
     ggml_backend_sched_reset(sched);
     return t;
@@ -448,8 +378,6 @@ bool sense_voice_encode_internal(sense_voice_context &ctx,
                                  const int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    const auto &model = ctx.model;
-
 
     // encoder
     {
