diff --git a/sbv2_bindings/examples/basic.py b/sbv2_bindings/examples/basic.py index 347f59b..8d9b8dd 100644 --- a/sbv2_bindings/examples/basic.py +++ b/sbv2_bindings/examples/basic.py @@ -2,22 +2,18 @@ from sbv2_bindings import TTSModel def main(): - with open("../models/debert.onnx", "rb") as f: - bert = f.read() - with open("../models/tokenizer.json", "rb") as f: - tokenizer = f.read() print("Loading models...") - - model = TTSModel(bert, tokenizer) + model = TTSModel.from_path("../models/debert.onnx", "../models/tokenizer.json") print("Models loaded!") - with open("../models/amitaro.sbv2", "rb") as f: - model.load_sbv2file(f.read()) + model.load_sbv2file_from_path("amitaro", "../models/amitaro.sbv2") print("All setup is done!") style_vector = model.get_style_vector("amitaro", 0, 1.0) with open("output.wav", "wb") as f: - f.write(model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5)) + data = model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5) + print(data) + f.write(data) if __name__ == "__main__": diff --git a/sbv2_bindings/src/sbv2.rs b/sbv2_bindings/src/sbv2.rs index ddece5f..b8d47c5 100644 --- a/sbv2_bindings/src/sbv2.rs +++ b/sbv2_bindings/src/sbv2.rs @@ -1,8 +1,11 @@ use pyo3::prelude::*; +use pyo3::types::PyBytes; use sbv2_core::tts::TTSModelHolder; use crate::style::StyleVector; +use std::fs; + #[pyclass] pub struct TTSModel { pub model: TTSModelHolder, @@ -17,11 +20,26 @@ impl TTSModel { }) } + #[staticmethod] + fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> { + Ok(Self { + model: TTSModelHolder::new( + fs::read(bert_model_path)?, + fs::read(tokenizer_path)?, + )?, + }) + } + fn load_sbv2file(&mut self, ident: String, sbv2file_bytes: Vec<u8>) -> anyhow::Result<()> { self.model.load_sbv2file(ident, sbv2file_bytes)?; Ok(()) } + fn load_sbv2file_from_path(&mut self, ident: String, sbv2file_path: String) -> anyhow::Result<()> { + self.model.load_sbv2file(ident, fs::read(sbv2file_path)?)?; + Ok(()) + } + fn 
get_style_vector( &self, ident: String, @@ -33,16 +51,17 @@ impl TTSModel { )) } - fn synthesize( - &self, + fn synthesize<'p>( + &'p self, + py: Python<'p>, text: String, ident: String, style_vector: StyleVector, sdp_ratio: f32, length_scale: f32, - ) -> anyhow::Result<Vec<u8>> { + ) -> anyhow::Result<Bound<'p, PyBytes>> { let (bert_ori, phones, tones, lang_ids) = self.model.parse_text(&text)?; - Ok(self.model.synthesize( + let data = self.model.synthesize( ident, bert_ori, phones, @@ -51,6 +70,7 @@ impl TTSModel { style_vector.get(), sdp_ratio, length_scale, - )?) + )?; + Ok(PyBytes::new_bound(py, &data)) } }