diff --git a/pywhispercpp/constants.py b/pywhispercpp/constants.py index aa170ce..f56a3e9 100644 --- a/pywhispercpp/constants.py +++ b/pywhispercpp/constants.py @@ -171,11 +171,11 @@ 'default': 0 }, 'initial_prompt': { - 'type': str, - 'description': "Initial prompt, these are prepended to any existing text context from a previous call", - 'options': None, - 'default': None - }, + 'type': str, + 'description': "Initial prompt, these are prepended to any existing text context from a previous call", + 'options': None, + 'default': None + }, 'prompt_tokens': { 'type': Tuple, 'description': "tokens to provide to the whisper decoder as initial prompt", @@ -265,5 +265,17 @@ 'description': 'calculate the geometric mean of token probabilities for each segment.', 'options': None, 'default': True + }, + 'vad': { + 'type': bool, + 'description': 'Enable VAD', + 'options': None, + 'default': False + }, + 'vad_model_path': { + 'type': str, + 'description': 'Path to VAD model', + 'options': None, + 'default': None } } diff --git a/src/main.cpp b/src/main.cpp index 6f36bc4..d5c8a0a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -38,10 +38,8 @@ py::function py_logits_filter_callback; // Thanks to https://github.com/pybind/pybind11/issues/2770 struct whisper_context_wrapper { whisper_context* ptr; - }; - // struct inside params struct greedy{ int best_of; @@ -299,14 +297,18 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c struct WhisperFullParamsWrapper : public whisper_full_params { std::string initial_prompt_str; std::string suppress_regex_str; + std::string vad_model_path_str; public: py::function py_progress_callback; WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params()) : whisper_full_params(params), initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""), - suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") { + suppress_regex_str(params.suppress_regex ? params.suppress_regex : ""), + vad_model_path_str(params.vad_model_path ? params.vad_model_path : "") + { initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); + vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str(); // progress callback progress_callback_user_data = this; progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { @@ -327,10 +329,12 @@ struct WhisperFullParamsWrapper : public whisper_full_params { : whisper_full_params(static_cast(other)), // Copy base struct initial_prompt_str(other.initial_prompt_str), suppress_regex_str(other.suppress_regex_str), + vad_model_path_str(other.vad_model_path_str), py_progress_callback(other.py_progress_callback) { // Reset pointers to new string copies initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str(); suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str(); + vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str(); progress_callback_user_data = this; progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) { auto* self = static_cast(user_data); @@ -354,6 +358,10 @@ struct WhisperFullParamsWrapper : public whisper_full_params { suppress_regex_str = regex; suppress_regex = suppress_regex_str.c_str(); } + void set_vad_model_path(const std::string& model_path) { + vad_model_path_str = model_path; + vad_model_path = vad_model_path_str.c_str(); + } }; WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) { return WhisperFullParamsWrapper(whisper_full_default_params(strategy)); @@ -411,6 +419,99 @@ py::dict get_greedy(whisper_full_params * params){ return d; } + +// Voice Activity Detection (VAD) +struct whisper_vad_context_wrapper { + whisper_vad_context* ptr; +}; + +struct whisper_vad_context_wrapper whisper_vad_init_from_file_with_params_wrapper(const char * path_model, struct whisper_vad_context_params params){ + struct whisper_vad_context * ctx = whisper_vad_init_from_file_with_params(path_model, params); + struct whisper_vad_context_wrapper ctw_w; + ctw_w.ptr = ctx; + return ctw_w; +} + +bool whisper_vad_detect_speech_wrapper( + struct whisper_vad_context_wrapper * ctx, + py::array_t samples, + int n_samples){ + py::buffer_info buf = samples.request(); + float *samples_ptr = static_cast(buf.ptr); + + py::gil_scoped_release release; + return whisper_vad_detect_speech(ctx->ptr, samples_ptr, n_samples); +} + +int whisper_vad_n_probs_wrapper(struct whisper_vad_context_wrapper * ctx){ + return whisper_vad_n_probs(ctx->ptr); +} + +py::array_t whisper_vad_probs_wrapper(struct whisper_vad_context_wrapper * ctx) { + float * probs_ptr = whisper_vad_probs(ctx->ptr); + int n_probs = whisper_vad_n_probs(ctx->ptr); + + if (probs_ptr == nullptr || n_probs <= 0) { + return py::array_t(0); + } + return py::array_t( + {n_probs}, + {sizeof(float)}, + probs_ptr + ); +} + +struct whisper_vad_segments_wrapper { + struct whisper_vad_segments * ptr; +}; + +struct whisper_vad_segments_wrapper whisper_vad_segments_from_probs_wrapper( + struct whisper_vad_context_wrapper * vctx_w, + struct whisper_vad_params params + ){ + struct whisper_vad_segments * wvs = whisper_vad_segments_from_probs(vctx_w->ptr, params); + struct whisper_vad_segments_wrapper wvs_w; + wvs_w.ptr = wvs; + return wvs_w; +} + +struct whisper_vad_segments_wrapper whisper_vad_segments_from_samples_wrapper( + struct whisper_vad_context_wrapper * vctx_w, + struct whisper_vad_params params, + py::array_t samples, + int n_samples){ + + py::buffer_info buf = samples.request(); + float *samples_ptr = static_cast(buf.ptr); + + struct whisper_vad_segments * wvs = whisper_vad_segments_from_samples(vctx_w->ptr, params, samples_ptr, n_samples); + struct whisper_vad_segments_wrapper wvs_w; + wvs_w.ptr = wvs; + return wvs_w; +} + +int whisper_vad_segments_n_segments_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper){ + return whisper_vad_segments_n_segments(segments_wrapper->ptr); +} + +float whisper_vad_segments_get_segment_t0_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper, int i_segment) { + return whisper_vad_segments_get_segment_t0(segments_wrapper->ptr, i_segment); +} + +float whisper_vad_segments_get_segment_t1_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper, int i_segment) { + return whisper_vad_segments_get_segment_t1(segments_wrapper->ptr, i_segment); +} + +void whisper_vad_free_segments_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper){ + return whisper_vad_free_segments(segments_wrapper->ptr); +} + +void whisper_vad_free_wrapper(struct whisper_vad_context_wrapper * ctx_w){ + return whisper_vad_free(ctx_w->ptr); +} + +//////////// + PYBIND11_MODULE(_pywhispercpp, m) { m.doc() = R"pbdoc( Pywhispercpp: Python binding to whisper.cpp @@ -665,7 +766,17 @@ PYBIND11_MODULE(_pywhispercpp, m) { [](WhisperFullParamsWrapper &self, py::dict dict) {self.beam_search.beam_size = dict["beam_size"].cast(); self.beam_search.patience = dict["patience"].cast();}) .def_readwrite("new_segment_callback_user_data", &WhisperFullParamsWrapper::new_segment_callback_user_data) .def_readwrite("encoder_begin_callback_user_data", &WhisperFullParamsWrapper::encoder_begin_callback_user_data) - .def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data); + .def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data) + .def_readwrite("vad", &WhisperFullParamsWrapper::vad) + .def_property("vad_model_path", + [](WhisperFullParamsWrapper &self) { + return py::str(self.vad_model_path ? self.vad_model_path : ""); + }, + [](WhisperFullParamsWrapper &self, const std::string &vad_model_path) { + self.set_vad_model_path(vad_model_path); + } + ) + .def_readwrite("vad_params", &WhisperFullParamsWrapper::vad_params); py::implicitly_convertible(); @@ -718,6 +829,40 @@ PYBIND11_MODULE(_pywhispercpp, m) { m.def("assign_logits_filter_callback", &assign_logits_filter_callback, "Assigns a logits_filter_callback, takes instance and a callable function with the same parameters which are defined in the interface", py::arg("params"), py::arg("callback")); + // VAD + py::class_(m,"whisper_vad_params") + .def(py::init<>()) + .def_readwrite("threshold", &whisper_vad_params::threshold) + .def_readwrite("min_speech_duration_ms", &whisper_vad_params::min_speech_duration_ms) + .def_readwrite("min_silence_duration_ms", &whisper_vad_params::min_silence_duration_ms) + .def_readwrite("max_speech_duration_s", &whisper_vad_params::max_speech_duration_s) + .def_readwrite("speech_pad_ms", &whisper_vad_params::speech_pad_ms) + .def_readwrite("samples_overlap", &whisper_vad_params::samples_overlap); + + m.def("whisper_vad_default_params", &whisper_vad_default_params); + + py::class_(m,"whisper_vad_context_params") + .def(py::init<>()) + .def_readwrite("n_threads", &whisper_vad_context_params::n_threads) + .def_readwrite("use_gpu", &whisper_vad_context_params::use_gpu) + .def_readwrite("gpu_device", &whisper_vad_context_params::gpu_device); + + m.def("whisper_vad_default_context_params", &whisper_vad_default_context_params); + m.def("whisper_vad_init_from_file_with_params", &whisper_vad_init_from_file_with_params_wrapper); + m.def("whisper_vad_detect_speech", &whisper_vad_detect_speech_wrapper); + m.def("whisper_vad_n_probs", &whisper_vad_n_probs_wrapper); + m.def("whisper_vad_probs", &whisper_vad_probs_wrapper); + py::class_(m, "whisper_vad_segments"); + m.def("whisper_vad_segments_from_probs", &whisper_vad_segments_from_probs_wrapper); + m.def("whisper_vad_segments_from_samples", &whisper_vad_segments_from_samples_wrapper); + m.def("whisper_vad_segments_n_segments", &whisper_vad_segments_n_segments_wrapper); + m.def("whisper_vad_segments_get_segment_t0", &whisper_vad_segments_get_segment_t0_wrapper); + m.def("whisper_vad_segments_get_segment_t1", &whisper_vad_segments_get_segment_t1_wrapper); + m.def("whisper_vad_free_segments", &whisper_vad_free_segments_wrapper); + m.def("whisper_vad_free", &whisper_vad_free_wrapper); + + + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);