fix types

This commit is contained in:
tuna2134
2024-09-12 12:39:22 +00:00
parent 13ceebac9d
commit 0f586ec4b3
2 changed files with 30 additions and 14 deletions

View File

@@ -2,22 +2,18 @@ from sbv2_bindings import TTSModel
def main():
with open("../models/debert.onnx", "rb") as f:
bert = f.read()
with open("../models/tokenizer.json", "rb") as f:
tokenizer = f.read()
print("Loading models...")
model = TTSModel(bert, tokenizer)
model = TTSModel.from_path("../models/debert.onnx", "../models/tokenizer.json")
print("Models loaded!")
with open("../models/amitaro.sbv2", "rb") as f:
model.load_sbv2file(f.read())
model.load_sbv2file_from_path("amitaro", "../models/amitaro.sbv2")
print("All setup is done!")
style_vector = model.get_style_vector("amitaro", 0, 1.0)
with open("output.wav", "wb") as f:
f.write(model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5))
data = model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5)
print(data)
f.write(data)
if __name__ == "__main__":

View File

@@ -1,8 +1,11 @@
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use sbv2_core::tts::TTSModelHolder;
use crate::style::StyleVector;
use std::fs;
#[pyclass]
pub struct TTSModel {
pub model: TTSModelHolder,
@@ -17,11 +20,26 @@ impl TTSModel {
})
}
#[staticmethod]
fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> {
Ok(Self {
model: TTSModelHolder::new(
fs::read(bert_model_path)?,
fs::read(tokenizer_path)?,
)?,
})
}
fn load_sbv2file(&mut self, ident: String, sbv2file_bytes: Vec<u8>) -> anyhow::Result<()> {
self.model.load_sbv2file(ident, sbv2file_bytes)?;
Ok(())
}
fn load_sbv2file_from_path(&mut self, ident: String, sbv2file_path: String) -> anyhow::Result<()> {
self.model.load_sbv2file(ident, fs::read(sbv2file_path)?)?;
Ok(())
}
fn get_style_vector(
&self,
ident: String,
@@ -33,16 +51,17 @@ impl TTSModel {
))
}
fn synthesize(
&self,
fn synthesize<'p>(
&'p self,
py: Python<'p>,
text: String,
ident: String,
style_vector: StyleVector,
sdp_ratio: f32,
length_scale: f32,
) -> anyhow::Result<Vec<u8>> {
) -> anyhow::Result<Bound<PyBytes>> {
let (bert_ori, phones, tones, lang_ids) = self.model.parse_text(&text)?;
Ok(self.model.synthesize(
let data = self.model.synthesize(
ident,
bert_ori,
phones,
@@ -51,6 +70,7 @@ impl TTSModel {
style_vector.get(),
sdp_ratio,
length_scale,
)?)
)?;
Ok(PyBytes::new_bound(py, &data))
}
}