Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions pywhispercpp/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@
'default': 0
},
'initial_prompt': {
'type': str,
'description': "Initial prompt, these are prepended to any existing text context from a previous call",
'options': None,
'default': None
},
'type': str,
'description': "Initial prompt, these are prepended to any existing text context from a previous call",
'options': None,
'default': None
},
'prompt_tokens': {
'type': Tuple,
'description': "tokens to provide to the whisper decoder as initial prompt",
Expand Down Expand Up @@ -265,5 +265,17 @@
'description': 'calculate the geometric mean of token probabilities for each segment.',
'options': None,
'default': True
},
'vad': {
'type': bool,
'description': 'Enable VAD',
'options': None,
'default': False
},
'vad_model_path': {
'type': str,
'description': 'Path to VAD model',
'options': None,
'default': None
}
}
153 changes: 149 additions & 4 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@ py::function py_logits_filter_callback;
// Thanks to https://github.com/pybind/pybind11/issues/2770
struct whisper_context_wrapper {
whisper_context* ptr;

};


// struct inside params
struct greedy{
int best_of;
Expand Down Expand Up @@ -299,14 +297,18 @@ int whisper_ctx_init_openvino_encoder_wrapper(struct whisper_context_wrapper * c
struct WhisperFullParamsWrapper : public whisper_full_params {
std::string initial_prompt_str;
std::string suppress_regex_str;
std::string vad_model_path_str;
public:
py::function py_progress_callback;
WhisperFullParamsWrapper(const whisper_full_params& params = whisper_full_params())
: whisper_full_params(params),
initial_prompt_str(params.initial_prompt ? params.initial_prompt : ""),
suppress_regex_str(params.suppress_regex ? params.suppress_regex : "") {
suppress_regex_str(params.suppress_regex ? params.suppress_regex : ""),
vad_model_path_str(params.vad_model_path ? params.vad_model_path : "")
{
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str();
// progress callback
progress_callback_user_data = this;
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
Expand All @@ -327,10 +329,12 @@ struct WhisperFullParamsWrapper : public whisper_full_params {
: whisper_full_params(static_cast<whisper_full_params>(other)), // Copy base struct
initial_prompt_str(other.initial_prompt_str),
suppress_regex_str(other.suppress_regex_str),
vad_model_path_str(other.vad_model_path_str),
py_progress_callback(other.py_progress_callback) {
// Reset pointers to new string copies
initial_prompt = initial_prompt_str.empty() ? nullptr : initial_prompt_str.c_str();
suppress_regex = suppress_regex_str.empty() ? nullptr : suppress_regex_str.c_str();
vad_model_path = vad_model_path_str.empty() ? nullptr : vad_model_path_str.c_str();
progress_callback_user_data = this;
progress_callback = [](struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
auto* self = static_cast<WhisperFullParamsWrapper*>(user_data);
Expand All @@ -354,6 +358,10 @@ struct WhisperFullParamsWrapper : public whisper_full_params {
suppress_regex_str = regex;
suppress_regex = suppress_regex_str.c_str();
}
void set_vad_model_path(const std::string& model_path) {
vad_model_path_str = model_path;
vad_model_path = vad_model_path_str.c_str();
}
};
WhisperFullParamsWrapper whisper_full_default_params_wrapper(enum whisper_sampling_strategy strategy) {
return WhisperFullParamsWrapper(whisper_full_default_params(strategy));
Expand Down Expand Up @@ -411,6 +419,99 @@ py::dict get_greedy(whisper_full_params * params){
return d;
}


// Voice Activity Detection (VAD)
struct whisper_vad_context_wrapper {
whisper_vad_context* ptr;
};

struct whisper_vad_context_wrapper whisper_vad_init_from_file_with_params_wrapper(const char * path_model, struct whisper_vad_context_params params){
struct whisper_vad_context * ctx = whisper_vad_init_from_file_with_params(path_model, params);
struct whisper_vad_context_wrapper ctw_w;
ctw_w.ptr = ctx;
return ctw_w;
}

bool whisper_vad_detect_speech_wrapper(
struct whisper_vad_context_wrapper * ctx,
py::array_t<float> samples,
int n_samples){
py::buffer_info buf = samples.request();
float *samples_ptr = static_cast<float *>(buf.ptr);

py::gil_scoped_release release;
return whisper_vad_detect_speech(ctx->ptr, samples_ptr, n_samples);
}

int whisper_vad_n_probs_wrapper(struct whisper_vad_context_wrapper * ctx){
return whisper_vad_n_probs(ctx->ptr);
}

py::array_t<float> whisper_vad_probs_wrapper(struct whisper_vad_context_wrapper * ctx) {
float * probs_ptr = whisper_vad_probs(ctx->ptr);
int n_probs = whisper_vad_n_probs(ctx->ptr);

if (probs_ptr == nullptr || n_probs <= 0) {
return py::array_t<float>(0);
}
return py::array_t<float>(
{n_probs},
{sizeof(float)},
probs_ptr
);
}

struct whisper_vad_segments_wrapper {
struct whisper_vad_segments * ptr;
};

struct whisper_vad_segments_wrapper whisper_vad_segments_from_probs_wrapper(
struct whisper_vad_context_wrapper * vctx_w,
struct whisper_vad_params params
){
struct whisper_vad_segments * wvs = whisper_vad_segments_from_probs(vctx_w->ptr, params);
struct whisper_vad_segments_wrapper wvs_w;
wvs_w.ptr = wvs;
return wvs_w;
}

struct whisper_vad_segments_wrapper whisper_vad_segments_from_samples_wrapper(
struct whisper_vad_context_wrapper * vctx_w,
struct whisper_vad_params params,
py::array_t<float> samples,
int n_samples){

py::buffer_info buf = samples.request();
float *samples_ptr = static_cast<float *>(buf.ptr);

struct whisper_vad_segments * wvs = whisper_vad_segments_from_samples(vctx_w->ptr, params, samples_ptr, n_samples);
struct whisper_vad_segments_wrapper wvs_w;
wvs_w.ptr = wvs;
return wvs_w;
}

int whisper_vad_segments_n_segments_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper){
return whisper_vad_segments_n_segments(segments_wrapper->ptr);
}

float whisper_vad_segments_get_segment_t0_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper, int i_segment) {
return whisper_vad_segments_get_segment_t0(segments_wrapper->ptr, i_segment);
}

float whisper_vad_segments_get_segment_t1_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper, int i_segment) {
return whisper_vad_segments_get_segment_t1(segments_wrapper->ptr, i_segment);
}

void whisper_vad_free_segments_wrapper(struct whisper_vad_segments_wrapper * segments_wrapper){
return whisper_vad_free_segments(segments_wrapper->ptr);
}

void whisper_vad_free_wrapper(struct whisper_vad_context_wrapper * ctx_w){
return whisper_vad_free(ctx_w->ptr);
}

////////////

PYBIND11_MODULE(_pywhispercpp, m) {
m.doc() = R"pbdoc(
Pywhispercpp: Python binding to whisper.cpp
Expand Down Expand Up @@ -665,7 +766,17 @@ PYBIND11_MODULE(_pywhispercpp, m) {
[](WhisperFullParamsWrapper &self, py::dict dict) {self.beam_search.beam_size = dict["beam_size"].cast<int>(); self.beam_search.patience = dict["patience"].cast<float>();})
.def_readwrite("new_segment_callback_user_data", &WhisperFullParamsWrapper::new_segment_callback_user_data)
.def_readwrite("encoder_begin_callback_user_data", &WhisperFullParamsWrapper::encoder_begin_callback_user_data)
.def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data);
.def_readwrite("logits_filter_callback_user_data", &WhisperFullParamsWrapper::logits_filter_callback_user_data)
.def_readwrite("vad", &WhisperFullParamsWrapper::vad)
.def_property("vad_model_path",
[](WhisperFullParamsWrapper &self) {
return py::str(self.vad_model_path ? self.vad_model_path : "");
},
[](WhisperFullParamsWrapper &self, const std::string &vad_model_path) {
self.set_vad_model_path(vad_model_path);
}
)
.def_readwrite("vad_params", &WhisperFullParamsWrapper::vad_params);


py::implicitly_convertible<whisper_full_params, WhisperFullParamsWrapper>();
Expand Down Expand Up @@ -718,6 +829,40 @@ PYBIND11_MODULE(_pywhispercpp, m) {
m.def("assign_logits_filter_callback", &assign_logits_filter_callback, "Assigns a logits_filter_callback, takes <whisper_full_params> instance and a callable function with the same parameters which are defined in the interface",
py::arg("params"), py::arg("callback"));

// VAD
py::class_<whisper_vad_params>(m,"whisper_vad_params")
.def(py::init<>())
.def_readwrite("threshold", &whisper_vad_params::threshold)
.def_readwrite("min_speech_duration_ms", &whisper_vad_params::min_speech_duration_ms)
.def_readwrite("min_silence_duration_ms", &whisper_vad_params::min_silence_duration_ms)
.def_readwrite("max_speech_duration_s", &whisper_vad_params::max_speech_duration_s)
.def_readwrite("speech_pad_ms", &whisper_vad_params::speech_pad_ms)
.def_readwrite("samples_overlap", &whisper_vad_params::samples_overlap);

m.def("whisper_vad_default_params", &whisper_vad_default_params);

py::class_<whisper_vad_context_params>(m,"whisper_vad_context_params")
.def(py::init<>())
.def_readwrite("n_threads", &whisper_vad_context_params::n_threads)
.def_readwrite("use_gpu", &whisper_vad_context_params::use_gpu)
.def_readwrite("gpu_device", &whisper_vad_context_params::gpu_device);

m.def("whisper_vad_default_context_params", &whisper_vad_default_context_params);
m.def("whisper_vad_init_from_file_with_params", &whisper_vad_init_from_file_with_params_wrapper);
m.def("whisper_vad_detect_speech", &whisper_vad_detect_speech_wrapper);
m.def("whisper_vad_n_probs", &whisper_vad_n_probs_wrapper);
m.def("whisper_vad_probs", &whisper_vad_probs_wrapper);
py::class_<whisper_vad_segments_wrapper>(m, "whisper_vad_segments");
m.def("whisper_vad_segments_from_probs", &whisper_vad_segments_from_probs_wrapper);
m.def("whisper_vad_segments_from_samples", &whisper_vad_segments_from_samples_wrapper);
m.def("whisper_vad_segments_n_segments", &whisper_vad_segments_n_segments_wrapper);
m.def("whisper_vad_segments_get_segment_t0", &whisper_vad_segments_get_segment_t0_wrapper);
m.def("whisper_vad_segments_get_segment_t1", &whisper_vad_segments_get_segment_t1_wrapper);
m.def("whisper_vad_free_segments", &whisper_vad_free_segments_wrapper);
m.def("whisper_vad_free", &whisper_vad_free_wrapper);




#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
Expand Down
Loading