From 2a6562741468f5b4271d1810e742a5f2e5866d69 Mon Sep 17 00:00:00 2001
From: tuna2134
Date: Thu, 12 Sep 2024 12:12:21 +0000
Subject: [PATCH] initial commit

---
 Notes (this area is not part of the commit message):
  * TTSModel.load_sbv2file takes two arguments, (ident, sbv2file_bytes).
    examples/basic.py uses the same identifier ("amitaro") for loading,
    style lookup and synthesis.
  * TTSModel.synthesize returns Vec<u8>; with PyO3 0.22 that reaches Python
    as a list of ints, so the example converts it with bytes() before
    writing the WAV file.

 Cargo.lock                      |  3 ++-
 Cargo.toml                      |  1 +
 sbv2_bindings/Cargo.toml        |  1 +
 sbv2_bindings/examples/basic.py | 21 +++++++++++++++++
 sbv2_bindings/src/lib.rs        |  1 +
 sbv2_bindings/src/sbv2.rs       | 41 +++++++++++++++++++++++++++++++--
 sbv2_bindings/src/style.rs      | 16 +++++++++++++
 sbv2_core/Cargo.toml            |  2 +-
 8 files changed, 82 insertions(+), 4 deletions(-)
 create mode 100644 sbv2_bindings/examples/basic.py
 create mode 100644 sbv2_bindings/src/style.rs

diff --git a/Cargo.lock b/Cargo.lock
index a61dc9b..efc6247 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1834,13 +1834,14 @@ name = "sbv2_bindings"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "ndarray",
  "pyo3",
  "sbv2_core",
 ]
 
 [[package]]
 name = "sbv2_core"
-version = "0.1.1"
+version = "0.1.2"
 dependencies = [
  "anyhow",
  "dotenvy",
diff --git a/Cargo.toml b/Cargo.toml
index a9a7433..822de4d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,3 +6,4 @@ members = ["sbv2_api", "sbv2_core", "sbv2_bindings"]
 anyhow = "1.0.86"
 dotenvy = "0.15.7"
 env_logger = "0.11.5"
+ndarray = "0.16.1"
\ No newline at end of file
diff --git a/sbv2_bindings/Cargo.toml b/sbv2_bindings/Cargo.toml
index a4968d8..f372191 100644
--- a/sbv2_bindings/Cargo.toml
+++ b/sbv2_bindings/Cargo.toml
@@ -10,5 +10,6 @@ crate-type = ["cdylib"]
 
 [dependencies]
 anyhow.workspace = true
+ndarray.workspace = true
 pyo3 = { version = "0.22.0", features = ["anyhow"] }
 sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
diff --git a/sbv2_bindings/examples/basic.py b/sbv2_bindings/examples/basic.py
new file mode 100644
index 0000000..f8fcc9e
--- /dev/null
+++ b/sbv2_bindings/examples/basic.py
@@ -0,0 +1,21 @@
+from sbv2_bindings import TTSModel
+
+
+def main():
+    with open("../models/debert.onnx", "rb") as f:
+        bert = f.read()
+    with open("../models/tokenizer.json", "rb") as f:
+        tokenizer = f.read()
+
+    model = TTSModel(bert, tokenizer)
+
+    with open("../models/amitaro.sbv2", "rb") as f:
+        model.load_sbv2file("amitaro", f.read())
+
+    style_vector = model.get_style_vector("amitaro", 0, 1.0)
+    with open("output.wav", "wb") as f:
+        f.write(bytes(model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5)))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/sbv2_bindings/src/lib.rs b/sbv2_bindings/src/lib.rs
index 5516fe2..cd0bd9a 100644
--- a/sbv2_bindings/src/lib.rs
+++ b/sbv2_bindings/src/lib.rs
@@ -1,5 +1,6 @@
 use pyo3::prelude::*;
 mod sbv2;
+pub mod style;
 
 /// Formats the sum of two numbers as string.
 #[pyfunction]
diff --git a/sbv2_bindings/src/sbv2.rs b/sbv2_bindings/src/sbv2.rs
index 1715b5d..ddece5f 100644
--- a/sbv2_bindings/src/sbv2.rs
+++ b/sbv2_bindings/src/sbv2.rs
@@ -1,6 +1,8 @@
 use pyo3::prelude::*;
 use sbv2_core::tts::TTSModelHolder;
 
+use crate::style::StyleVector;
+
 #[pyclass]
 pub struct TTSModel {
     pub model: TTSModelHolder,
@@ -15,5 +17,40 @@ impl TTSModel {
         })
     }
 
-    fn load()
-}
\ No newline at end of file
+    fn load_sbv2file(&mut self, ident: String, sbv2file_bytes: Vec<u8>) -> anyhow::Result<()> {
+        self.model.load_sbv2file(ident, sbv2file_bytes)?;
+        Ok(())
+    }
+
+    fn get_style_vector(
+        &self,
+        ident: String,
+        style_id: i32,
+        weight: f32,
+    ) -> anyhow::Result<StyleVector> {
+        Ok(StyleVector::new(
+            self.model.get_style_vector(ident, style_id, weight)?,
+        ))
+    }
+
+    fn synthesize(
+        &self,
+        text: String,
+        ident: String,
+        style_vector: StyleVector,
+        sdp_ratio: f32,
+        length_scale: f32,
+    ) -> anyhow::Result<Vec<u8>> {
+        let (bert_ori, phones, tones, lang_ids) = self.model.parse_text(&text)?;
+        Ok(self.model.synthesize(
+            ident,
+            bert_ori,
+            phones,
+            tones,
+            lang_ids,
+            style_vector.get(),
+            sdp_ratio,
+            length_scale,
+        )?)
+    }
+}
diff --git a/sbv2_bindings/src/style.rs b/sbv2_bindings/src/style.rs
new file mode 100644
index 0000000..56a8256
--- /dev/null
+++ b/sbv2_bindings/src/style.rs
@@ -0,0 +1,16 @@
+use ndarray::Array1;
+use pyo3::prelude::*;
+
+#[pyclass]
+#[derive(Clone)]
+pub struct StyleVector(Array1<f32>);
+
+impl StyleVector {
+    pub fn new(data: Array1<f32>) -> Self {
+        StyleVector(data)
+    }
+
+    pub fn get(&self) -> Array1<f32> {
+        self.0.clone()
+    }
+}
diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml
index 40efa96..f361d84 100644
--- a/sbv2_core/Cargo.toml
+++ b/sbv2_core/Cargo.toml
@@ -14,7 +14,7 @@ dotenvy.workspace = true
 env_logger.workspace = true
 hound = "3.5.1"
 jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
-ndarray = "0.16.1"
+ndarray.workspace = true
 num_cpus = "1.16.0"
 once_cell = "1.19.0"
 ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.6" }