// coqui_stt/model.rs
#![allow(clippy::missing_safety_doc)]
use crate::{Metadata, Stream};
use std::ffi::CStr;
use std::os::raw::c_uint;

/// A trained Coqui STT model.
///
/// Thin owning wrapper around a raw `coqui_stt_sys::ModelState` pointer;
/// the pointer is released with `STT_FreeModel` when the wrapper is dropped.
pub struct Model(pub(crate) *mut coqui_stt_sys::ModelState);

// SAFETY: these impls assert that the underlying `ModelState` may be moved to and
// shared between threads. The compiler cannot check this for a raw pointer, so it
// rests on the C library's behaviour — NOTE(review): confirm against the Coqui STT docs.
unsafe impl Send for Model {}
unsafe impl Sync for Model {}

impl Drop for Model {
    /// Release the native model state when the wrapper goes out of scope.
    #[inline]
    fn drop(&mut self) {
        // SAFETY: `self.0` is the model state owned by this wrapper (obtained from a
        // successful create call or handed over via `from_model_state`), and `drop` runs
        // at most once, so the state is freed exactly once. `into_inner` opts out of this
        // by wrapping `self` in `ManuallyDrop`.
        unsafe { coqui_stt_sys::STT_FreeModel(self.0) }
    }
}

impl Model {
    /// Load a model from the file at `model_path`.
    ///
    /// The generic argument is converted to an owned `String` here and the actual
    /// work is done by the non-generic `_new`.
    ///
    /// # Errors
    /// Returns an error if the model path is invalid (including a path containing a
    /// NUL byte, which cannot be turned into a C string), or for other reasons
    /// reported by the C library.
    #[inline]
    pub fn new(model_path: impl Into<String>) -> crate::Result<Self> {
        let owned_path: String = model_path.into();
        Self::_new(owned_path)
    }

    fn _new(model_path: String) -> crate::Result<Self> {
        let mut model_path = model_path.into_bytes();
        model_path.reserve_exact(1);
        model_path.push(b'\0');
        let model_path = CStr::from_bytes_with_nul(model_path.as_ref())?;

        let mut state = std::ptr::null_mut::<coqui_stt_sys::ModelState>();

        // SAFETY: creating a model is only done with a null pointer and a model path,
        // both of which have been checked
        let retval = unsafe {
            coqui_stt_sys::STT_CreateModel(model_path.as_ptr(), std::ptr::addr_of_mut!(state))
        };

        if let Some(e) = crate::Error::from_c_int(retval) {
            return Err(e);
        }

        if state.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(Self(state))
    }

    /// Create a new model from a memory buffer.
    ///
    /// `buffer` holds the serialized model. Anything that can be viewed as a byte
    /// slice is accepted (`&[u8]`, `Vec<u8>`, arrays, ...), mirroring
    /// [`enable_external_scorer_from_buffer`](crate::Model::enable_external_scorer_from_buffer).
    ///
    /// This function is only compiled on non-Windows targets.
    ///
    /// # Errors
    /// Returns an error if the model is invalid, or for other reasons.
    #[inline]
    #[cfg(not(target_os = "windows"))]
    pub fn new_from_buffer(buffer: impl AsRef<[u8]>) -> crate::Result<Self> {
        // The bound is `AsRef<[u8]>` (not `AsRef<&[u8]>`): no standard type implements
        // the latter, which made this constructor impossible to call with ordinary buffers.
        Self::_new_from_buffer(buffer.as_ref())
    }

    #[inline]
    #[cfg(not(target_os = "windows"))]
    fn _new_from_buffer(buffer: &[u8]) -> crate::Result<Self> {
        let mut state = std::ptr::null_mut::<coqui_stt_sys::ModelState>();

        // SAFETY: creating a model is only done with a null pointer and a model buffer
        // both of which have been checked
        let retval = unsafe {
            coqui_stt_sys::STT_CreateModelFromBuffer(
                buffer.as_ptr().cast::<i8>(),
                buffer.len() as c_uint,
                std::ptr::addr_of_mut!(state),
            )
        };

        if let Some(e) = crate::Error::from_c_int(retval) {
            return Err(e);
        }

        if state.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(Self(state))
    }

    /// Consume this model and hand back the raw model state pointer.
    ///
    /// This is an escape hatch for functionality the safe API does not expose.
    ///
    /// # Safety
    /// After this call the wrapper's destructor will not run, so the model's memory
    /// is no longer managed for you.
    ///
    /// You are responsible for eventually calling `STT_FreeModel` on the returned
    /// pointer to dispose of the model properly.
    #[inline]
    #[must_use]
    pub unsafe fn into_inner(self) -> *mut coqui_stt_sys::ModelState {
        // Copy the pointer out first, then suppress `Drop` so it is not freed here.
        let raw_state = self.0;
        std::mem::forget(self);

        raw_state
    }

    /// Create a new model from an existing model state.
    ///
    /// # Safety
    /// You must ensure `state` is a valid model state.
    #[inline]
    pub const unsafe fn from_model_state(state: *mut coqui_stt_sys::ModelState) -> Self {
        Self(state)
    }

    /// Enable an external scorer for this model, loaded from the file at `scorer_path`.
    ///
    /// The generic argument is converted to an owned `String` here and the actual
    /// work is done by the non-generic `_enable_external_scorer`.
    ///
    /// # Errors
    /// Returns an error if the `scorer_path`/file pointed to is invalid in some way,
    /// including a path that contains a NUL byte.
    #[inline]
    pub fn enable_external_scorer(&mut self, scorer_path: impl Into<String>) -> crate::Result<()> {
        let owned_path: String = scorer_path.into();
        self._enable_external_scorer(owned_path)
    }

    /// Non-generic worker behind
    /// [`enable_external_scorer`](crate::Model::enable_external_scorer).
    #[inline]
    fn _enable_external_scorer(&mut self, scorer_path: String) -> crate::Result<()> {
        // Build a NUL-terminated byte string in place; an interior NUL byte is
        // rejected by `from_bytes_with_nul` and surfaces as an error through `?`.
        let mut path_bytes = scorer_path.into_bytes();
        path_bytes.reserve_exact(1);
        path_bytes.push(b'\0');
        let c_path = CStr::from_bytes_with_nul(path_bytes.as_slice())?;
        handle_error!(coqui_stt_sys::STT_EnableExternalScorer(
            self.0,
            c_path.as_ptr()
        ))
    }

    /// Enable an external scorer for this model, reading the scorer from memory.
    ///
    /// `buffer` may be anything viewable as a byte slice. This function is only
    /// compiled on non-Windows targets.
    ///
    /// # Errors
    /// Returns an error if the scorer in memory is invalid in some way.
    #[inline]
    #[cfg(not(target_os = "windows"))]
    pub fn enable_external_scorer_from_buffer(
        &mut self,
        buffer: impl AsRef<[u8]>,
    ) -> crate::Result<()> {
        let scorer_bytes: &[u8] = buffer.as_ref();
        self._enable_external_scorer_from_buffer(scorer_bytes)
    }

    #[inline]
    #[cfg(not(target_os = "windows"))]
    fn _enable_external_scorer_from_buffer(&mut self, buffer: &[u8]) -> crate::Result<()> {
        handle_error!(coqui_stt_sys::STT_EnableExternalScorerFromBuffer(
            self.0,
            buffer.as_ptr().cast::<i8>(),
            buffer.len() as c_uint
        ))
    }

    /// Disable an external scorer that was previously set up with
    /// [`enable_external_scorer`](crate::Model::enable_external_scorer).
    ///
    /// # Errors
    /// Returns an error if an error happened while disabling the scorer.
    #[inline]
    pub fn disable_external_scorer(&mut self) -> crate::Result<()> {
        // `handle_error!` is defined elsewhere in the crate; presumably it performs the
        // FFI call and maps a failing status to `crate::Error` — verify against the macro.
        handle_error!(coqui_stt_sys::STT_DisableExternalScorer(self.0))
    }

    /// Register a hot-word together with its boost value.
    ///
    /// Words that don’t occur in the scorer (e.g. proper nouns),
    /// or strings that contain spaces won't be taken into account.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    ///
    /// Additionally, if the input word contains a NUL character anywhere in it, returns an error.
    #[inline]
    pub fn add_hot_word(&mut self, word: impl Into<String>, boost: f32) -> crate::Result<()> {
        let owned_word: String = word.into();
        self._add_hot_word(owned_word, boost)
    }

    /// Non-generic worker behind [`add_hot_word`](crate::Model::add_hot_word).
    #[inline]
    fn _add_hot_word(&mut self, word: String, boost: f32) -> crate::Result<()> {
        // Build a NUL-terminated byte string in place; an interior NUL byte is
        // rejected by `from_bytes_with_nul` and surfaces as an error through `?`.
        let mut word_bytes = word.into_bytes();
        word_bytes.reserve_exact(1);
        word_bytes.push(b'\0');
        let c_word = CStr::from_bytes_with_nul(word_bytes.as_slice())?;
        handle_error!(coqui_stt_sys::STT_AddHotWord(self.0, c_word.as_ptr(), boost))
    }

    /// Remove a single hot-word entry from the hot-words map.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    ///
    /// Additionally, if the input word contains a NUL character anywhere in it, returns an error.
    #[inline]
    pub fn erase_hot_word(&mut self, word: impl Into<String>) -> crate::Result<()> {
        let owned_word: String = word.into();
        self._erase_hot_word(owned_word)
    }

    /// Non-generic worker behind [`erase_hot_word`](crate::Model::erase_hot_word).
    #[inline]
    fn _erase_hot_word(&mut self, word: String) -> crate::Result<()> {
        // Build a NUL-terminated byte string in place; an interior NUL byte is
        // rejected by `from_bytes_with_nul` and surfaces as an error through `?`.
        let mut word_bytes = word.into_bytes();
        word_bytes.reserve_exact(1);
        word_bytes.push(b'\0');
        let c_word = CStr::from_bytes_with_nul(word_bytes.as_slice())?;
        handle_error!(coqui_stt_sys::STT_EraseHotWord(self.0, c_word.as_ptr()))
    }

    /// Removes all elements from the hot-words map.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[inline]
    pub fn clear_hot_words(&mut self) -> crate::Result<()> {
        // `handle_error!` is defined elsewhere in the crate; presumably it performs the
        // FFI call and maps a failing status to `crate::Error` — verify against the macro.
        handle_error!(coqui_stt_sys::STT_ClearHotWords(self.0))
    }

    /// Set hyperparameters alpha and beta of the external scorer.
    ///
    /// `alpha` is the alpha hyperparameter of the decoder. Language model weight.
    ///
    /// `beta` is the beta hyperparameter of the decoder. Word insertion weight.
    ///
    /// Both values are forwarded unchanged to `STT_SetScorerAlphaBeta`.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[inline]
    pub fn set_scorer_alpha_beta(&mut self, alpha: f32, beta: f32) -> crate::Result<()> {
        // `handle_error!` is defined elsewhere in the crate; presumably it performs the
        // FFI call and maps a failing status to `crate::Error` — verify against the macro.
        handle_error!(coqui_stt_sys::STT_SetScorerAlphaBeta(self.0, alpha, beta))
    }

    /// Return the sample rate, in Hz, that this model expects its input audio to use.
    #[inline]
    #[must_use]
    pub fn get_sample_rate(&self) -> i32 {
        // The C getter only needs a read-only view of the state.
        let state: *const coqui_stt_sys::ModelState = self.0;

        // SAFETY: `state` is the model state owned by this wrapper, which stays alive
        // for the duration of the borrow of `self`.
        unsafe { coqui_stt_sys::STT_GetModelSampleRate(state) }
    }

    /// Use the Coqui STT model to convert speech to text.
    ///
    /// `buffer` should be a 16-bit, mono, raw audio signal
    /// at the appropriate sample rate, matching what the model was trained on.
    /// The required sample rate can be obtained from [`get_sample_rate`](crate::Model::get_sample_rate).
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    ///
    /// Additionally, if the returned string is not valid UTF-8, this function returns an error.
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn speech_to_text(&mut self, buffer: &[i16]) -> crate::Result<String> {
        let ptr = unsafe {
            coqui_stt_sys::STT_SpeechToText(self.0, buffer.as_ptr(), buffer.len() as c_uint)
        };

        if ptr.is_null() {
            return Err(crate::Error::Unknown);
        }

        // SAFETY: STT_SpeechToText will always return a valid CStr
        let cstr = unsafe { CStr::from_ptr(ptr) };
        let mut unchecked_str = Vec::new();
        unchecked_str.extend_from_slice(cstr.to_bytes());

        // SAFETY: the pointer the string points to is not used anywhere after this call
        unsafe { coqui_stt_sys::STT_FreeString(ptr) }

        Ok(String::from_utf8(unchecked_str)?)
    }

    /// Use the Coqui STT model to convert speech to text and output results including metadata.
    ///
    /// `buffer` should be a 16-bit, mono, raw audio signal
    /// at the appropriate sample rate, matching what the model was trained on.
    /// The required sample rate can be obtained from [`get_sample_rate`](crate::Model::get_sample_rate).
    ///
    /// `num_results` is the maximum number of possible transcriptions to return.
    /// Note that it is not guaranteed this many will be returned at minimum,
    /// but there will never be more than this number at maximum.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[inline]
    pub fn speech_to_text_with_metadata(
        &mut self,
        buffer: &[i16],
        num_results: u32,
    ) -> crate::Result<Metadata> {
        let ptr = unsafe {
            coqui_stt_sys::STT_SpeechToTextWithMetadata(
                self.0,
                buffer.as_ptr(),
                buffer.len() as c_uint,
                num_results,
            )
        };

        if ptr.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(crate::Metadata::new(ptr))
    }

    /// Convert this model into one used for streaming inference states.
    ///
    /// Note that this requires exclusive access to the model,
    /// so it is not possible to use the same model for multiple streams concurrently.
    ///
    /// # Errors
    /// Passes through any errors from the C library. See enum [`Error`](crate::Error).
    #[allow(clippy::missing_inline_in_public_items)]
    pub fn as_streaming(&mut self) -> crate::Result<Stream> {
        let mut state = std::ptr::null_mut();

        let retval = unsafe { coqui_stt_sys::STT_CreateStream(self.0, &mut state) };

        if let Some(e) = crate::Error::from_c_int(retval) {
            return Err(e);
        }

        if state.is_null() {
            return Err(crate::Error::Unknown);
        }

        Ok(Stream {
            model: self,
            state,
            already_freed: false,
        })
    }
}