From dab6c6e7588c7b44036e7e5287b2b811c3fdb152 Mon Sep 17 00:00:00 2001 From: lovemefan Date: Sun, 8 Dec 2024 22:28:05 +0800 Subject: [PATCH] add silero vad model for sense voice asr --- scripts/convert-pt-to-gguf.py | 27 ++- sense-voice/csrc/CMakeLists.txt | 2 + sense-voice/csrc/common.cc | 22 +++ sense-voice/csrc/common.h | 31 +++- sense-voice/csrc/main.cc | 223 +++++++++++++++++++++++- sense-voice/csrc/sense-voice-decoder.cc | 21 --- sense-voice/csrc/sense-voice-encoder.cc | 27 +-- sense-voice/csrc/sense-voice.cc | 51 +++++- sense-voice/csrc/silero-vad.cc | 202 +++++++++++++++++++++ sense-voice/csrc/silero-vad.h | 73 ++++++++ sense-voice/csrc/third-party/ggml | 2 +- 11 files changed, 621 insertions(+), 60 deletions(-) create mode 100644 sense-voice/csrc/silero-vad.cc create mode 100644 sense-voice/csrc/silero-vad.h diff --git a/scripts/convert-pt-to-gguf.py b/scripts/convert-pt-to-gguf.py index addf83b..6874828 100644 --- a/scripts/convert-pt-to-gguf.py +++ b/scripts/convert-pt-to-gguf.py @@ -51,6 +51,7 @@ def __init__( ) self.model_checkpoint = "model.pt" + self.vad_model_checkpoint = 'silero_vad.pt' self.hparams = Model.load_hparams(self.dir_model) self.gguf_writer = gguf.GGUFWriter( fname_out, @@ -105,9 +106,24 @@ def set_vocab(self): raise NotImplementedError def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + print(f"gguf: loading model part '{self.model_checkpoint}'") + print(f"gguf: loading vad model part '{self.vad_model_checkpoint}'") ctx: ContextManager[Any] + ctx = contextlib.nullcontext( + torch.load( + str(self.dir_model / self.vad_model_checkpoint), + map_location="cpu", + mmap=True, + weights_only=True, + ) + ) + + with ctx as model_part: + for name, data in model_part.items(): + yield name, data + ctx = contextlib.nullcontext( torch.load( str(self.dir_model / self.model_checkpoint), @@ -254,7 +270,7 @@ def write_one_tensor(self, data_torch, name): if data_torch.dtype not in (torch.float16, torch.float32): data_torch = data_torch.to(torch.float32) - _data = data_torch.squeeze().numpy() + _data = data_torch.numpy() # use max to avoid n_dim of single tensor become 0 if len(_data.shape) != 0: data = _data @@ -281,6 +297,15 @@ def write_one_tensor(self, data_torch, name): ): data = data.astype(np.float16) + if self.ftype == 0 and name in [ + '_model.stft.forward_basis_buffer.weight', + '_model.encoder.0.reparam_conv.weight', + '_model.encoder.1.reparam_conv.weight', + '_model.encoder.2.reparam_conv.weight', + '_model.encoder.3.reparam_conv.weight', + '_model.decoder.decoder.2.weight' + ]: + data = data.astype(np.float16) print( f"|{name}| n_dims = {n_dims}| {old_dtype} | {data.dtype} | {data.size}|" diff --git a/sense-voice/csrc/CMakeLists.txt b/sense-voice/csrc/CMakeLists.txt index 0265bd6..b39f0b0 100644 --- a/sense-voice/csrc/CMakeLists.txt +++ b/sense-voice/csrc/CMakeLists.txt @@ -11,6 +11,8 @@ set(SOURCE_FILES sense-voice-decoder.cc sense-voice.h sense-voice.cc + silero-vad.h + silero-vad.cc ) add_library(sense-voice-core STATIC ${SOURCE_FILES}) diff --git a/sense-voice/csrc/common.cc b/sense-voice/csrc/common.cc index b64366f..a45814c 100644 --- a/sense-voice/csrc/common.cc +++ b/sense-voice/csrc/common.cc @@ -96,3 +96,25 @@ struct sense_voice_full_params sense_voice_full_default_params(enum sense_voice_ return result; } + +bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph, + int n_threads) { + + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = 
ggml_backend_sched_get_backend(sched, i);
+        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+
+        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (fn_set_n_threads) {
+            fn_set_n_threads(backend, n_threads);
+        }
+    }
+
+    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
+    ggml_backend_sched_reset(sched);
+    return t;
+}
\ No newline at end of file
diff --git a/sense-voice/csrc/common.h b/sense-voice/csrc/common.h
index c3d5c71..97290f1 100644
--- a/sense-voice/csrc/common.h
+++ b/sense-voice/csrc/common.h
@@ -13,6 +13,7 @@
 #include "sense-voice-frontend.h"
 
+
 #ifdef __GNUC__
 #define SENSEVOICE_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
 #elif defined(_MSC_VER)
@@ -202,6 +203,18 @@ struct sense_voice {
 };
 
+struct silero_vad;
+
+struct silero_vad_model {
+    struct silero_vad *model;
+    // context
+    struct ggml_context *ctx;
+
+    // the model backend data is read-only and can be shared between processors
+    ggml_backend_buffer_t buffer = nullptr;
+};
+
+
 static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "auto", { 0, "auto", } },
     { "zh", { 3, "chinese", } },
@@ -226,7 +239,7 @@ struct sense_voice_hparams {
     int n_decoder_attention_heads = 4;
     int n_decoder_layers = 14;
     int fsmn_kernel_size = 11;
-
+    int n_vad_encoder_layers = 4;
     int n_predictor_dim = 512;
     float predictor_tail_threshold = 0.45;
@@ -362,9 +375,21 @@ struct sense_voice_state {
     std::vector<ggml_backend_t> backends;
 
+    sense_voice_sched sched_vad;
+    sense_voice_sched sched_vad_state;
     sense_voice_sched sched_encode;
     sense_voice_sched sched_decode;
 
+    // hidden state of the vad lstm
+    ggml_context * vad_ctx = nullptr;
+    struct ggml_tensor * vad_lstm_hidden_state;
+    struct ggml_tensor * vad_lstm_context;
+    ggml_backend_buffer_t vad_lstm_hidden_state_buffer = nullptr;
+    ggml_backend_buffer_t vad_lstm_context_buffer = nullptr;
+
+    ggml_cgraph *sense_voice_encoder_graph;
+    ggml_cgraph *sense_voice_decoder_graph;
+
     // result of the encoder
     struct ggml_tensor *encoder_out = nullptr;
 
@@ -443,6 +468,7 @@ struct sense_voice_full_params {
 struct sense_voice_model {
     std::string model_type;
     sense_voice_hparams hparams;
+    silero_vad *vad_model;
     sense_voice *model;
     // context
     struct ggml_context *ctx;
@@ -472,6 +498,7 @@ struct sense_voice_context {
     ggml_type itype = ggml_type::GGML_TYPE_F16;  // intermediate type (FP32 or FP16)
 
+    silero_vad_model vad_model;
     sense_voice_model model;
     sense_voice_vocab vocab;
@@ -483,5 +510,7 @@
 };
 
 struct sense_voice_full_params sense_voice_full_default_params(enum sense_voice_decoding_strategy strategy);
+bool ggml_graph_compute_helper(ggml_backend_sched_t sched, struct ggml_cgraph * graph, int n_threads);
+
 #endif//SENSEVOICE_CPP_COMMON_H
diff --git a/sense-voice/csrc/main.cc b/sense-voice/csrc/main.cc
index 7b8ea0a..5241b62 100644
--- a/sense-voice/csrc/main.cc
+++ b/sense-voice/csrc/main.cc
@@ -3,6 +3,7 @@
 //
 #include "common.h"
 #include "sense-voice.h"
+#include "silero-vad.h"
 #include 
 #include 
 #include 
@@ -10,6 +11,12 @@
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
+#define CHUNK_SIZE 512
+#define CONTEXT_SIZE 576
+#define SENSE_VOICE_VAD_CHUNK_PAD_SIZE 64
+#define VAD_LSTM_STATE_MEMORY_SIZE 2048
+#define VAD_LSTM_STATE_DIM 128
+#define INF 0xffffff // ~4.6 hours in ms, used as the "unlimited" default for --max-speech-duration-ms
 
 // command-line parameters
 struct sense_voice_params {
@@ -34,6 +41,14 @@ struct sense_voice_params {
     float temperature     = 0.0f;
     float temperature_inc = 0.2f;
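+    // vad params: defaults follow the reference silero-vad implementation
+    // (threshold 0.5, min speech 250 ms, min silence 100 ms, 30 ms speech padding)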
+    float threshold = 0.5f;
+    float neg_threshold = 0.35f;
+    int32_t min_speech_duration_ms = 250;
+    int32_t max_speech_duration_ms = INF;
+    int32_t min_silence_duration_ms = 100;
+    int32_t speech_pad_ms = 30;
+
     bool debug_mode = false;
     bool translate = false;
     bool detect_language = false;
@@ -153,6 +168,7 @@ static void sense_voice_print_usage(int /*argc*/, char ** argv, const sense_voic
     fprintf(stderr, "  --prompt PROMPT                [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
     fprintf(stderr, "  -m FNAME,  --model FNAME       [%-7s] model path\n", params.model.c_str());
     fprintf(stderr, "  -f FNAME,  --file FNAME        [%-7s] input WAV file path\n", "");
+    fprintf(stderr, "  -vt N,     --threshold N               [%-7.2f] vad speech probability threshold\n", params.threshold);
+    fprintf(stderr, "  -vnt N,    --neg-threshold N           [%-7.2f] vad silence probability threshold\n", params.neg_threshold);
+    fprintf(stderr, "             --min-speech-duration-ms N  [%-7d] minimum duration of a speech segment in ms\n", params.min_speech_duration_ms);
+    fprintf(stderr, "             --max-speech-duration-ms N  [%-7d] maximum duration of a speech segment in ms\n", params.max_speech_duration_ms);
+    fprintf(stderr, "             --min-silence-duration-ms N [%-7d] minimum silence in ms used to split segments\n", params.min_silence_duration_ms);
+    fprintf(stderr, "             --speech-pad-ms N           [%-7d] padding in ms added to both sides of a segment\n", params.speech_pad_ms);
     fprintf(stderr, "  -oved D,   --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
     fprintf(stderr, "  -ls,       --log-score         [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
     fprintf(stderr, "  -ng,       --no-gpu            [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
@@ -200,6 +216,14 @@ static bool sense_voice_params_parse(int argc, char ** argv, sense_voice_params
         else if (arg == "-d"    || arg == "--duration")        { params.duration_ms     = std::stoi(argv[++i]); }
         else if (arg == "-mc"   || arg == "--max-context")     { params.max_context     = std::stoi(argv[++i]); }
         else if (arg == "-ml"   || arg == "--max-len")         { params.max_len         = std::stoi(argv[++i]); }
+        // vad parameters
+        else if (arg == "-vt"  || arg == "--threshold")        { params.threshold = std::stof(argv[++i]); }
+        else if (arg == "-vnt" || arg == "--neg-threshold")    { params.neg_threshold = std::stof(argv[++i]); }
+        else if (arg == "--min-speech-duration-ms")            { params.min_speech_duration_ms = std::stoi(argv[++i]); }
+        else if (arg == "--max-speech-duration-ms")            { params.max_speech_duration_ms = std::stoi(argv[++i]); }
+        else if (arg == "--min-silence-duration-ms")           { params.min_silence_duration_ms = std::stoi(argv[++i]); }
+        else if (arg == "--speech-pad-ms")                     { params.speech_pad_ms = std::stoi(argv[++i]); }
+
         else if (arg == "-bo"   || arg == "--best-of")         { params.best_of         = std::stoi(argv[++i]); }
         else if (arg == "-bs"   || arg == "--beam-size")       { params.beam_size       = std::stoi(argv[++i]); }
         else if (arg == "-ac"   || arg == "--audio-ctx")       { params.audio_ctx       = std::stoi(argv[++i]); }
@@ -375,13 +399,15 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
 void sense_voice_free(struct sense_voice_context * ctx) {
     if (ctx) {
         ggml_free(ctx->model.ctx);
-
+        ggml_free(ctx->vad_model.ctx);
         ggml_backend_buffer_free(ctx->model.buffer);
+        ggml_backend_buffer_free(ctx->vad_model.buffer);
 
         sense_voice_free_state(ctx->state);
 
         delete ctx->model.model->encoder;
         delete ctx->model.model;
+        delete ctx->vad_model.model;
         delete ctx;
     }
 }
@@ -490,13 +516,202 @@ int main(int argc, char ** argv) {
 
         wparams.no_timestamps    = params.no_timestamps;
 
-        if (sense_voice_full_parallel(ctx, wparams, pcmf32, pcmf32.size(), params.n_processors) != 0) {
-            fprintf(stderr, "%s: failed to process audio\n", argv[0]);
-            return 10;
+
+        int n_pad = 0;
+        std::vector<float> chunk(CONTEXT_SIZE + SENSE_VOICE_VAD_CHUNK_PAD_SIZE, 0);
+
+        // run vad and asr
+        {
+            // init vad lstm state
+            ctx->state->vad_ctx = ggml_init({VAD_LSTM_STATE_MEMORY_SIZE, nullptr, true});
+            ctx->state->vad_lstm_context = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
+            ctx->state->vad_lstm_hidden_state = ggml_new_tensor_1d(ctx->state->vad_ctx, GGML_TYPE_F32, VAD_LSTM_STATE_DIM);
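+            // h and c of the vad lstm persist for the whole stream: they are copied
+            // into the graph before each chunk and copied back out afterwards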
+
+            ctx->state->vad_lstm_context_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
+                                                                            ggml_nbytes(ctx->state->vad_lstm_context) +
+                                                                            ggml_backend_get_alignment(ctx->state->backends[0]));
+            ctx->state->vad_lstm_hidden_state_buffer = ggml_backend_alloc_buffer(ctx->state->backends[0],
+                                                                                 ggml_nbytes(ctx->state->vad_lstm_hidden_state) +
+                                                                                 ggml_backend_get_alignment(ctx->state->backends[0]));
+            auto context_alloc = ggml_tallocr_new(ctx->state->vad_lstm_context_buffer);
+            ggml_tallocr_alloc(&context_alloc, ctx->state->vad_lstm_context);
+
+            auto state_alloc = ggml_tallocr_new(ctx->state->vad_lstm_hidden_state_buffer);
+            ggml_tallocr_alloc(&state_alloc, ctx->state->vad_lstm_hidden_state);
+
+            ggml_set_zero(ctx->state->vad_lstm_context);
+            ggml_set_zero(ctx->state->vad_lstm_hidden_state);
+        }
+
+        int offset = CHUNK_SIZE - CONTEXT_SIZE;
+
+        auto & sched = ctx->state->sched_vad.sched;
+        ggml_cgraph *gf = silero_vad_build_graph(*ctx, *ctx->state);
+
+//        ggml_backend_sched_set_eval_callback(sched, ctx->params.cb_eval, &ctx->params.cb_eval_user_data);
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            // should never happen as we pre-allocate the memory
+            fprintf(stderr, "%s: failed to allocate the vad compute graph\n", argv[0]);
+            return 11;
+        }
+
+        // vad state machine variables
+        bool triggered = false;
+        int32_t temp_end = 0;
+        int32_t prev_end = 0, next_start = 0;
+        int32_t current_speech_start = 0, current_speech_end = 0;
+        int32_t min_speech_samples = sample_rate * params.min_speech_duration_ms / 1000;
+        int32_t speech_pad_samples = sample_rate * params.speech_pad_ms / 1000;
+        int64_t max_speech_samples = (int64_t) sample_rate * params.max_speech_duration_ms / 1000 - CHUNK_SIZE - 2 * speech_pad_samples;
+        int32_t min_silence_samples = sample_rate * params.min_silence_duration_ms / 1000;
+        int32_t min_silence_samples_at_max_speech = sample_rate * 98 / 1000;
+        std::vector<double> speech_segment;
+        for (int i = 0; i < (int) pcmf32.size(); i += CHUNK_SIZE){
+
+            n_pad = CHUNK_SIZE <= pcmf32.size() - i ? 0 : CHUNK_SIZE + i - pcmf32.size();
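+            // each iteration scores one 512-sample chunk together with 64 samples of
+            // left context (576 in total); the tail is then reflection-padded with
+            // SENSE_VOICE_VAD_CHUNK_PAD_SIZE samples to the 640-sample window the vad graph expects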
+
+            for (int j = i + offset; j < i + CHUNK_SIZE; j++) {
+                if (j >= 0 && j < (int) pcmf32.size()){
+                    chunk[j - i - offset] = pcmf32[j] / 32768;
+                } else {
+                    // zero-pad when the first chunk has no left context or the
+                    // last chunk runs out of data on the right
+                    chunk[j - i - offset] = 0;
+                }
+            }
+            // reflection pad on the right
+            for (int j = CONTEXT_SIZE; j < chunk.size(); j++) {
+                chunk[j] = chunk[2 * CONTEXT_SIZE - j - 2];
+            }
+
+            {
+                // set the input
+                {
+                    struct ggml_tensor *data = ggml_graph_get_tensor(gf, "audio_chunk");
+                    ggml_backend_tensor_set(data, chunk.data(), 0, ggml_nbytes(data));
+
+                    struct ggml_tensor *in_lstm_context = ggml_graph_get_tensor(gf, "in_lstm_context");
+                    struct ggml_tensor *in_lstm_hidden_state = ggml_graph_get_tensor(gf, "in_lstm_hidden_state");
+                    ggml_backend_tensor_copy(ctx->state->vad_lstm_context, in_lstm_context);
+                    ggml_backend_tensor_copy(ctx->state->vad_lstm_hidden_state, in_lstm_hidden_state);
+                }
+
+                if (!ggml_graph_compute_helper(sched, gf, params.n_processors)) {
+                    fprintf(stderr, "%s: failed to compute the vad graph\n", argv[0]);
+                    return 12;
+                }
+
+                // save output state
+                {
+                    struct ggml_tensor *lstm_context = ggml_graph_get_tensor(gf, "out_lstm_context");
+                    ggml_backend_tensor_copy(lstm_context, ctx->state->vad_lstm_context);
+                    struct ggml_tensor *lstm_hidden_state = ggml_graph_get_tensor(gf, "out_lstm_hidden_state");
+                    ggml_backend_tensor_copy(lstm_hidden_state, ctx->state->vad_lstm_hidden_state);
+                }
+            }
+
+            {
+                float speech_prob = ((float *)ggml_graph_get_tensor(gf, "logit")->data)[0];
+                if (speech_prob >= params.threshold && temp_end) {
+                    temp_end = 0;
+                    if (next_start < prev_end) next_start = i;
+                }
+
+                if (speech_prob >= params.threshold && !triggered){
+                    triggered = true;
+                    current_speech_start = i;
+                    continue;
+                }
+
+                if (triggered && i - current_speech_start > max_speech_samples) {
+                    if (prev_end){
+                        // a silence candidate was seen earlier: end the segment there and decode it
+                        current_speech_end = prev_end;
+                        speech_segment.clear();
+                        speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
+                        if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                            fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                            return 10;
+                        }
+                        current_speech_end = current_speech_start = 0;
+                        if (next_start < prev_end) {
+                            // speech never resumed after the silence candidate
+                            triggered = false;
+                        } else {
+                            // speech resumed: start the next segment there
+                            current_speech_start = next_start;
+                        }
+                        prev_end = next_start = temp_end = 0;
+                    } else {
+                        // no silence candidate: cut the segment at the current position and decode it
+                        current_speech_end = i;
+                        speech_segment.clear();
+                        speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
+                        if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                            fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                            return 10;
+                        }
+                        current_speech_end = current_speech_start = 0;
+                        prev_end = next_start = temp_end = 0;
+                        triggered = false;
+                        continue;
+                    }
+                }
+
+                if (speech_prob < params.neg_threshold && triggered){
+                    if (temp_end == 0){
+                        temp_end = i;
+                    }
+
+                    if (i - temp_end > min_silence_samples_at_max_speech) {
+                        prev_end = temp_end;
+                    }
+
+                    if (i - temp_end < min_silence_samples) {
+                        continue;
+                    } else {
+                        current_speech_end = temp_end;
+                        if (current_speech_end - current_speech_start > min_speech_samples) {
+                            speech_segment.clear();
+                            speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.begin() + current_speech_end);
+                            printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), current_speech_end / (sample_rate * 1.0));
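+                            // the closed segment is long enough: hand it to the asr pipeline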
+                            if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                                return 10;
+                            }
+                            current_speech_end = current_speech_start = 0;
+                        }
+                        prev_end = next_start = temp_end = 0;
+                        triggered = false;
+                        continue;
+                    }
+                }
+            }
+        }
+
+        // decode the last unfinished speech segment
+        if (triggered && (int) pcmf32.size() - current_speech_start > min_speech_samples){
+            speech_segment.clear();
+            speech_segment.assign(pcmf32.begin() + current_speech_start, pcmf32.end());
+            printf("[%.2f-%.2f] ", current_speech_start / (sample_rate * 1.0), pcmf32.size() / (sample_rate * 1.0));
+            if (sense_voice_full_parallel(ctx, wparams, speech_segment, speech_segment.size(), params.n_processors) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                return 10;
+            }
         }
 
+        SENSE_VOICE_LOG_INFO("\n%s: decoding audio took %f s, rtf is %f.\n\n",
+                             __func__,
+                             (ctx->state->t_encode_us + ctx->state->t_decode_us) / 1e6,
+                             (ctx->state->t_encode_us + ctx->state->t_decode_us) / (1e6 * ctx->state->duration));
     }
 
     sense_voice_free(ctx);
diff --git a/sense-voice/csrc/sense-voice-decoder.cc b/sense-voice/csrc/sense-voice-decoder.cc
index 9ceef6a..c5af65b 100644
--- a/sense-voice/csrc/sense-voice-decoder.cc
+++ b/sense-voice/csrc/sense-voice-decoder.cc
@@ -78,27 +78,6 @@ struct ggml_cgraph *sense_voice_build_graph_ctc_decoder(sense_voice_context &ctx
     return gf;
 }
 
-static bool ggml_graph_compute_helper(
-        ggml_backend_sched_t sched,
-        struct ggml_cgraph * graph,
-        int n_threads) {
-
-    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
-        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
-        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
-        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-
-        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-        if (fn_set_n_threads) {
-            fn_set_n_threads(backend, n_threads);
-        }
-    }
-
-
-    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
-    ggml_backend_sched_reset(sched);
-    return t;
-}
 
 bool sense_voice_decode_internal(sense_voice_context &ctx,
                                  sense_voice_state &state,
diff --git a/sense-voice/csrc/sense-voice-encoder.cc b/sense-voice/csrc/sense-voice-encoder.cc
index e3e9b47..c951ac0 100644
--- a/sense-voice/csrc/sense-voice-encoder.cc
+++ b/sense-voice/csrc/sense-voice-encoder.cc
@@ -70,28 +70,6 @@ struct sense_voice_context_params sense_voice_context_default_params() {
 }
 
-static bool ggml_graph_compute_helper(
-        ggml_backend_sched_t sched,
-        struct ggml_cgraph * graph,
-        int n_threads) {
-
-    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
-        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
-        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
-        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-
-        auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-        if (fn_set_n_threads) {
-            fn_set_n_threads(backend, n_threads);
-        }
-    }
-
-
-    bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
-    ggml_backend_sched_reset(sched);
-    return t;
-}
-
 static struct ggml_tensor *encoder_layer_sanm_forward(const sense_voice_hparams &hparams,
                                                       sense_voice_context &sctx,
                                                       ggml_context *ctx0,
@@ -197,12 +175,11 @@ static struct ggml_tensor *encoder_layer_sanm_forward(const sense_voice_hparams
         struct ggml_tensor * a = layer.e_attn_fsmn_w;
         struct ggml_tensor * b = ggml_cont(ctx0, ggml_transpose(ctx0, V));
 
-        struct ggml_tensor * new_a = ggml_reshape_4d(ctx0, a, a->ne[0], 1, a->ne[1], a->ne[2] * a->ne[3]);;
-        struct ggml_tensor *im2col = ggml_im2col(ctx0, new_a, ggml_reshape_4d(ctx0, b, b->ne[0], 1, b->ne[1], b->ne[2] * b->ne[3]), 1, 0, padding, 0, 1, 0, false, GGML_TYPE_F32);
+        struct ggml_tensor *im2col = ggml_im2col(ctx0, a, ggml_reshape_4d(ctx0, b, b->ne[0], 1, b->ne[1], b->ne[2] * b->ne[3]), 1, 0, padding, 0, 1, 0, false, GGML_TYPE_F32);
 
-        // new_a [n_state, 1, kernel_size], im2col [n_state, length, kernel_size]
-        // result -> [n_state, length, kernel_size] @ [n_state, 1, kernel_size].T = [n_state, length , 1]
-        struct ggml_tensor * result = ggml_mul_mat(ctx0, new_a, im2col);
+        // a [n_state, 1, kernel_size], im2col [n_state, length, kernel_size]
+        // result -> [n_state, length, kernel_size] @ [n_state, 1, kernel_size].T = [n_state, length, 1]
+        struct ggml_tensor * result = ggml_mul_mat(ctx0, a, im2col);
         fsmn_memory = ggml_reshape_2d(ctx0, result, im2col->ne[1], im2col->ne[2]);
     }
     fsmn_memory = ggml_cont(ctx0, ggml_transpose(ctx0, fsmn_memory));
diff --git a/sense-voice/csrc/sense-voice.cc b/sense-voice/csrc/sense-voice.cc
index 9a33483..a59d730 100644
--- a/sense-voice/csrc/sense-voice.cc
+++ b/sense-voice/csrc/sense-voice.cc
@@ -2,13 +2,14 @@
 // Created by lovemefan on 2024/7/19.
 //
 #include "sense-voice.h"
-#include "sense-voice-encoder.h"
-#include "sense-voice-decoder.h"
-#include "sense-voice-cmvn.h"
 #include "common.h"
+#include "sense-voice-cmvn.h"
+#include "sense-voice-decoder.h"
+#include "sense-voice-encoder.h"
+#include "silero-vad.h"
 #include 
-#include 
 #include 
+#include 
 
 #define SENSE_VOICE_MAX_NODES 8192
 #define SENSE_VOICE_MAX_DECODERS 8
@@ -108,6 +109,7 @@ bool sense_voice_model_load(const char *path_model, sense_voice_context &sctx) {
 
     sctx.t_start_ms = t_start_ms;
 
+    auto &vad_model = sctx.vad_model;
     auto &sense_voice = sctx.model;
     auto &vocab = sctx.vocab;
     auto &hparams = sense_voice.hparams;
@@ -308,12 +310,29 @@ bool sense_voice_model_load(const char *path_model, sense_voice_context &sctx) {
 
     {
         // build model
+        vad_model.model = new struct silero_vad;
+        vad_model.model->encoders_layer = std::vector<silero_vad_encoder_layer>(hparams.n_vad_encoder_layers);
+
+        sense_voice.vad_model = vad_model.model;
         sense_voice.model = new struct sense_voice;
         sense_voice.model->encoder = new struct sense_voice_encoder;
         sense_voice.model->encoder->encoders_layer = std::vector(hparams.n_encoder_layers - 1);
         sense_voice.model->encoder->tp_encoders_layer = std::vector(hparams.n_tp_encoder_layers);
-
+        // load vad model tensors
+        {
+            vad_model.model->stft.forward_basis_buffer = sense_voice.tensors["_model.stft.forward_basis_buffer.weight"];
+            for (int i = 0; i < hparams.n_vad_encoder_layers; i++) {
+                vad_model.model->encoders_layer[i].reparam_conv_w = sense_voice.tensors["_model.encoder." + std::to_string(i) + ".reparam_conv.weight"];
+                vad_model.model->encoders_layer[i].reparam_conv_b = sense_voice.tensors["_model.encoder." + std::to_string(i) + ".reparam_conv.bias"];
+            }
+            vad_model.model->decoder.lstm_weight_ih = sense_voice.tensors["_model.decoder.rnn.weight_ih"];
+            vad_model.model->decoder.lstm_weight_hh = sense_voice.tensors["_model.decoder.rnn.weight_hh"];
+            vad_model.model->decoder.lstm_bias_ih = sense_voice.tensors["_model.decoder.rnn.bias_ih"];
+            vad_model.model->decoder.lstm_bias_hh = sense_voice.tensors["_model.decoder.rnn.bias_hh"];
+            vad_model.model->decoder.decoder_conv_w = sense_voice.tensors["_model.decoder.decoder.2.weight"];
+            vad_model.model->decoder.decoder_conv_b = sense_voice.tensors["_model.decoder.decoder.2.bias"];
+        }
 
         // load encoder weights, multi layers of EncoderLayerSANM
         {
@@ -496,7 +515,10 @@ void sense_voice_free_state(struct sense_voice_state * state) {
 
     {
         ggml_free(state->feature.ctx);
+        ggml_free(state->vad_ctx);
         ggml_backend_buffer_free(state->feature.buffer);
+        ggml_backend_buffer_free(state->vad_lstm_hidden_state_buffer);
+        ggml_backend_buffer_free(state->vad_lstm_context_buffer);
         state->feature.n_len_org = 0;
         state->feature.ctx = nullptr;
         state->feature.tensor = nullptr;
@@ -578,6 +600,21 @@ struct sense_voice_state *sense_voice_init_state(sense_voice_context *ctx) {
 
     state->logits_id.reserve(ctx->model.hparams.n_vocab);
 
+    // vad allocator
+    {
+        bool ok = sense_voice_sched_graph_init(
+                state->sched_vad, state->backends,
+                [&]() { return silero_vad_build_graph(*ctx, *state); });
+
+        if (!ok) {
+            SENSE_VOICE_LOG_ERROR("%s: failed to init vad model allocator\n", __func__);
+            sense_voice_free_state(state);
+            return nullptr;
+        }
+
+        SENSE_VOICE_LOG_INFO("%s: compute buffer (vad) = %7.2f MB\n", __func__,
+                             sense_voice_sched_size(state->sched_vad) / 1e6);
+    }
 
     // encoder allocator
     {
@@ -690,7 +727,7 @@ int sense_voice_pcm_to_feature_with_state(struct sense_voice_context * ctx,
 //        state->feature.tensor = ggml_transpose(state->feature.ctx, state->feature.tensor);
     }
 
-    SENSE_VOICE_LOG_INFO("%s: calculate fbank and cmvn takes %.3f ms\n", __func__,
+    SENSE_VOICE_LOG_DEBUG("%s: calculate fbank and cmvn takes %.3f ms\n", __func__,
                          state->t_feature_us / 1000.0);
     return 0;
 }
@@ -749,7 +786,7 @@ int sense_voice_full_with_state(
         return -6;
     }
 
-    SENSE_VOICE_LOG_INFO("\n%s: decoder audio use %f s, rtf is %f. \n\n",
+    SENSE_VOICE_LOG_DEBUG("\n%s: decoder audio use %f s, rtf is %f. \n\n",
                          __func__,
                          (state->t_encode_us + state->t_decode_us) / 1e6,
                          (state->t_encode_us + state->t_decode_us) / (1e6 * state->duration));
diff --git a/sense-voice/csrc/silero-vad.cc b/sense-voice/csrc/silero-vad.cc
new file mode 100644
index 0000000..f6d76ca
--- /dev/null
+++ b/sense-voice/csrc/silero-vad.cc
@@ -0,0 +1,202 @@
+//
+// Created by lovemefan on 2024/11/24.
+//
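+// silero-vad compute graph:
+//   stft (conv1d) -> 4 x reparameterized conv encoder -> lstm cell -> conv1d decoder -> sigmoid
+// the graph scores one 640-sample window and emits a single speech probability ("logit")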
+
+#include "silero-vad.h"
+#define SENSE_VOICE_VAD_MAX_NODES 1024
+#define VAD_CHUNK_SIZE 640
+/*
+  \begin{array}{ll}
+  i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
+  f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
+  g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
+  o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
+  c' = f * c + i * g \\
+  h' = o * \tanh(c') \\
+  \end{array}
+ */
+
+
+ggml_cgraph *silero_vad_build_graph(
+        sense_voice_context &ctx, sense_voice_state &state){
+
+    const auto &model = ctx.vad_model.model;
+    const auto &hparams = ctx.model.hparams;
+
+    struct ggml_init_params params = {
+            /*.mem_size   =*/state.sched_vad.meta.size(),
+            /*.mem_buffer =*/state.sched_vad.meta.data(),
+            /*.no_alloc   =*/true,
+    };
+
+    struct ggml_context *ctx0 = ggml_init(params);
+
+    ggml_cgraph *gf = ggml_new_graph_custom(ctx0, SENSE_VOICE_VAD_MAX_NODES, false);
+
+    ggml_tensor *chunk = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, VAD_CHUNK_SIZE);
+    // 640 samples in: a 576-sample chunk (64 samples of left context + 512 new samples)
+    // plus 64 samples of reflection padding appended by the caller
+    ggml_set_name(chunk, "audio_chunk");
+    ggml_set_input(chunk);
+
+    ggml_tensor *cur;
+    // stft
+    {
+        cur = ggml_conv_1d(ctx0, model->stft.forward_basis_buffer, chunk, 128, 0, 1);
+        // chunk operation by ggml view, equals torch.chunk(x, 2) in pytorch
+        struct ggml_tensor * real_part = ggml_view_2d(ctx0, cur, cur->ne[0], cur->ne[1] / 2, cur->nb[1], 0);
+        ggml_set_name(real_part, "real_part");
+        struct ggml_tensor * imag_part = ggml_view_2d(ctx0, cur, cur->ne[0], cur->ne[1] / 2, cur->nb[1], cur->nb[0] * cur->ne[0] * cur->ne[1] / 2);
+        ggml_set_name(imag_part, "imag_part");
+        // magnitude, equals torch.sqrt(real_part ** 2 + imag_part ** 2)
+        cur = ggml_sqrt(ctx0,
+                        ggml_add(ctx0,
+                                 ggml_mul(ctx0, real_part, real_part),
+                                 ggml_mul(ctx0, imag_part, imag_part)
+                                 )
+                        );
+        ggml_set_name(cur, "magnitude");
+    }
+
+    // encoder: four conv1d + relu blocks (strides 1, 2, 2, 1)
+    {
+        cur = ggml_conv_1d(ctx0, model->encoders_layer[0].reparam_conv_w, cur, 1, 1, 1);
+        cur = ggml_add(ctx0, cur, ggml_transpose(ctx0, model->encoders_layer[0].reparam_conv_b));
+        cur = ggml_relu(ctx0, cur);
+
+        cur = ggml_conv_1d(ctx0, model->encoders_layer[1].reparam_conv_w, cur, 2, 1, 1);
+        cur = ggml_add(ctx0, cur, ggml_transpose(ctx0, model->encoders_layer[1].reparam_conv_b));
+        cur = ggml_relu(ctx0, cur);
+
+        cur = ggml_conv_1d(ctx0, model->encoders_layer[2].reparam_conv_w, cur, 2, 1, 1);
+        cur = ggml_add(ctx0, cur, ggml_transpose(ctx0, model->encoders_layer[2].reparam_conv_b));
+        cur = ggml_relu(ctx0, cur);
+
+        cur = ggml_conv_1d(ctx0, model->encoders_layer[3].reparam_conv_w, cur, 1, 1, 1);
+        cur = ggml_add(ctx0, cur, ggml_transpose(ctx0, model->encoders_layer[3].reparam_conv_b));
+        cur = ggml_relu(ctx0, cur);
+    }
+
+    // decoder
+    {
+        struct ggml_tensor* in_lstm_hidden_state = ggml_new_tensor_1d(ctx0, cur->type, cur->ne[1]);
+        struct ggml_tensor* in_lstm_context = ggml_new_tensor_1d(ctx0, cur->type, cur->ne[1]);
+
+        struct ggml_tensor* out_lstm_hidden_state;
+        struct ggml_tensor* out_lstm_context;
+
+        ggml_set_name(in_lstm_context, "in_lstm_context");
+        ggml_set_name(in_lstm_hidden_state, "in_lstm_hidden_state");
+
+
+        // lstm cell
+        // ref: https://github.com/pytorch/pytorch/blob/1a93b96815b5c87c92e060a6dca51be93d712d09/aten/src/ATen/native/RNN.cpp#L298-L304
+        // gates = x @ self.weight_ih.T + self.bias_ih + hx[0] @ self.weight_hh.T + self.bias_hh
+        // chunked_gates = gates.chunk(4, dim=-1)
+        // ingate = torch.sigmoid(chunked_gates[0])
+        // forgetgate = torch.sigmoid(chunked_gates[1])
+        // cellgate = torch.tanh(chunked_gates[2])
+        // outgate = torch.sigmoid(chunked_gates[3])
+        // cy = forgetgate * hx[1] + ingate * cellgate
+        // hy = outgate * torch.tanh(cy)
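+        // gates is [4 * d, n] in ggml; the four views below slice it at element
+        // offsets 0, d, 2d and 3d (byte offset k * ne0 / 4 * nb0), i.e. gates.chunk(4, dim=-1)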
+
+        struct ggml_tensor *gates = ggml_add(
+                ctx0,
+                ggml_add(ctx0, ggml_mul_mat(ctx0,
+                                            model->decoder.lstm_weight_ih,
+                                            ggml_transpose(ctx0, cur)),
+                         model->decoder.lstm_bias_ih),
+
+                ggml_add(ctx0, ggml_mul_mat(ctx0,
+                                            model->decoder.lstm_weight_hh,
+                                            in_lstm_hidden_state),
+                         model->decoder.lstm_bias_hh));
+        ggml_set_name(gates, "gates");
+
+        struct ggml_tensor * input_gates = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, gates, gates->ne[0] / 4, gates->ne[1], gates->nb[1], 0));
+        struct ggml_tensor * forget_gates = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, gates, gates->ne[0] / 4, gates->ne[1], gates->nb[1], gates->nb[0] / 4 * gates->ne[0]));
+        struct ggml_tensor * cell_gate = ggml_tanh(ctx0, ggml_view_2d(ctx0, gates, gates->ne[0] / 4, gates->ne[1], gates->nb[1], 2 * gates->nb[0] / 4 * gates->ne[0]));
+        struct ggml_tensor * out_gates = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, gates, gates->ne[0] / 4, gates->ne[1], gates->nb[1], 3 * gates->nb[0] / 4 * gates->ne[0]));
+
+        ggml_set_name(input_gates, "input_gates");
+        ggml_set_name(forget_gates, "forget_gates");
+        ggml_set_name(cell_gate, "cell_gate");
+        ggml_set_name(out_gates, "out_gates");
+
+        out_lstm_context = ggml_add(ctx0,
+                                    ggml_mul(ctx0, forget_gates, in_lstm_context),
+                                    ggml_mul(ctx0, input_gates, cell_gate)
+                                    );
+        ggml_set_name(out_lstm_context, "out_lstm_context");
+        ggml_set_output(out_lstm_context);
+
+        out_lstm_hidden_state = ggml_mul(ctx0, out_gates, ggml_tanh(ctx0, out_lstm_context));
+        ggml_set_name(out_lstm_hidden_state, "out_lstm_hidden_state");
+        ggml_set_output(out_lstm_hidden_state);
+
+        cur = ggml_relu(ctx0, out_lstm_hidden_state);
+        cur = ggml_conv_1d(ctx0, model->decoder.decoder_conv_w, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), 1, 0, 1);
+        cur = ggml_add(ctx0, cur, ggml_transpose(ctx0, model->decoder.decoder_conv_b));
+        ggml_set_name(cur, "decoder_out");
+        cur = ggml_sigmoid(ctx0, cur);
+        ggml_set_name(cur, "logit");
+    }
+
+    ggml_set_output(cur);
+    ggml_build_forward_expand(gf, cur);
+    ggml_free(ctx0);
+    return gf;
+}
+
+bool silero_vad_encode_internal(sense_voice_context &ctx,
+                                sense_voice_state &state,
+                                std::vector<float> chunk,
+                                const int n_threads,
+                                const double &speech_prob){
+    {
+        auto & sched = ctx.state->sched_vad.sched;
+        ggml_cgraph *gf = silero_vad_build_graph(ctx, state);
+
+//        ggml_backend_sched_set_eval_callback(sched, ctx->params.cb_eval, &ctx->params.cb_eval_user_data);
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            // should never happen as we pre-allocate the memory
+            return false;
+        }
+
+        // set the input
+        {
+            struct ggml_tensor *data = ggml_graph_get_tensor(gf, "audio_chunk");
+            ggml_backend_tensor_set(data, chunk.data(), 0, ggml_nbytes(data));
+
+            struct ggml_tensor *in_lstm_context = ggml_graph_get_tensor(gf, "in_lstm_context");
+            struct ggml_tensor *in_lstm_hidden_state = ggml_graph_get_tensor(gf, "in_lstm_hidden_state");
+
+            ggml_backend_tensor_copy(state.vad_lstm_context, in_lstm_context);
+            ggml_backend_tensor_copy(state.vad_lstm_hidden_state, in_lstm_hidden_state);
+        }
+
+        if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
+            return false;
+        }
+
+        // save output state
+        {
+            struct ggml_tensor *lstm_context = ggml_graph_get_tensor(gf, "out_lstm_context");
+            ggml_backend_tensor_copy(lstm_context, state.vad_lstm_context);
+            struct ggml_tensor *lstm_hidden_state = ggml_graph_get_tensor(gf, "out_lstm_hidden_state");
+            ggml_backend_tensor_copy(lstm_hidden_state, state.vad_lstm_hidden_state);
+        }
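+        // the updated h and c are copied back into sense_voice_state so the next
+        // chunk of the stream continues from them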
+    }
+    return true;
+}
\ No newline at end of file
diff --git a/sense-voice/csrc/silero-vad.h b/sense-voice/csrc/silero-vad.h
new file mode 100644
index 0000000..2883610
--- /dev/null
+++ b/sense-voice/csrc/silero-vad.h
@@ -0,0 +1,73 @@
+//
+// Created by lovemefan on 2024/11/24.
+//
+
+#ifndef SENSEVOICE_CPP_SILERO_VAD_H
+#define SENSEVOICE_CPP_SILERO_VAD_H
+
+
+#include <vector>
+#include "common.h"
+
+
+struct silero_vad_stft {
+    struct ggml_tensor *forward_basis_buffer;
+};
+
+struct silero_vad_encoder_layer {
+    // conv1d
+    struct ggml_tensor *reparam_conv_w;
+    struct ggml_tensor *reparam_conv_b;
+};
+
+
+struct silero_vad_decoder {
+
+    // lstm cell
+    struct ggml_tensor *lstm_weight_ih;
+    struct ggml_tensor *lstm_bias_ih;
+    struct ggml_tensor *lstm_weight_hh;
+    struct ggml_tensor *lstm_bias_hh;
+
+    // conv1d
+    struct ggml_tensor * decoder_conv_w;
+    struct ggml_tensor * decoder_conv_b;
+
+};
+
+struct silero_vad {
+    silero_vad_stft stft;
+    std::vector<silero_vad_encoder_layer> encoders_layer;
+    silero_vad_decoder decoder;
+};
+
+// Progress callback
+typedef void (*silero_vad_progress_callback)(struct silero_vad_context *ctx,
+                                             struct silero_vad_state *state,
+                                             int progress, void *user_data);
+
+
+struct silero_vad_context_params;
+
+
+SENSEVOICE_API struct ggml_cgraph *silero_vad_build_graph(
+        sense_voice_context &ctx, sense_voice_state &state);
+
+
+SENSEVOICE_API bool silero_vad_encode_internal(sense_voice_context &ctx,
+                                               sense_voice_state &state,
+                                               std::vector<float> chunk,
+                                               const int n_threads,
+                                               const double &speech_prob);
+
+SENSEVOICE_API double silero_vad_with_state(sense_voice_context &ctx,
+                                            sense_voice_state &state,
+                                            std::vector<double> &pcmf32,
+                                            int n_processors);
+
+#endif//SENSEVOICE_CPP_SILERO_VAD_H
diff --git a/sense-voice/csrc/third-party/ggml b/sense-voice/csrc/third-party/ggml
index 6fcbd60..a5960e8 160000
--- a/sense-voice/csrc/third-party/ggml
+++ b/sense-voice/csrc/third-party/ggml
@@ -1 +1 @@
-Subproject commit 6fcbd60bc72ac3f7ad43f78c87e535f2e6206f58
+Subproject commit a5960e80d3e65ce6ff18f90315ab96f63cf9c4cc