mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-05-14 12:50:40 +00:00
fix types
This commit is contained in:
@@ -2,22 +2,18 @@ from sbv2_bindings import TTSModel
|
||||
|
||||
|
||||
def main():
|
||||
with open("../models/debert.onnx", "rb") as f:
|
||||
bert = f.read()
|
||||
with open("../models/tokenizer.json", "rb") as f:
|
||||
tokenizer = f.read()
|
||||
print("Loading models...")
|
||||
|
||||
model = TTSModel(bert, tokenizer)
|
||||
model = TTSModel.from_path("../models/debert.onnx", "../models/tokenizer.json")
|
||||
print("Models loaded!")
|
||||
|
||||
with open("../models/amitaro.sbv2", "rb") as f:
|
||||
model.load_sbv2file(f.read())
|
||||
model.load_sbv2file_from_path("amitaro", "../models/amitaro.sbv2")
|
||||
print("All setup is done!")
|
||||
|
||||
style_vector = model.get_style_vector("amitaro", 0, 1.0)
|
||||
with open("output.wav", "wb") as f:
|
||||
f.write(model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5))
|
||||
data = model.synthesize("こんにちは", "amitaro", style_vector, 0.0, 0.5)
|
||||
print(data)
|
||||
f.write(data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyBytes;
|
||||
use sbv2_core::tts::TTSModelHolder;
|
||||
|
||||
use crate::style::StyleVector;
|
||||
|
||||
use std::fs;
|
||||
|
||||
#[pyclass]
|
||||
pub struct TTSModel {
|
||||
pub model: TTSModelHolder,
|
||||
@@ -17,11 +20,26 @@ impl TTSModel {
|
||||
})
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
model: TTSModelHolder::new(
|
||||
fs::read(bert_model_path)?,
|
||||
fs::read(tokenizer_path)?,
|
||||
)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Loads an sbv2 model from in-memory bytes under the identifier `ident`.
///
/// This is a thin forwarder to `TTSModelHolder::load_sbv2file`; any error it
/// reports is propagated to the caller.
fn load_sbv2file(&mut self, ident: String, sbv2file_bytes: Vec<u8>) -> anyhow::Result<()> {
    let holder = &mut self.model;
    holder.load_sbv2file(ident, sbv2file_bytes)?;
    Ok(())
}
|
||||
|
||||
fn load_sbv2file_from_path(&mut self, ident: String, sbv2file_path: String) -> anyhow::Result<()> {
|
||||
self.model.load_sbv2file(ident, fs::read(sbv2file_path)?)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_style_vector(
|
||||
&self,
|
||||
ident: String,
|
||||
@@ -33,16 +51,17 @@ impl TTSModel {
|
||||
))
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
fn synthesize<'p>(
|
||||
&'p self,
|
||||
py: Python<'p>,
|
||||
text: String,
|
||||
ident: String,
|
||||
style_vector: StyleVector,
|
||||
sdp_ratio: f32,
|
||||
length_scale: f32,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
) -> anyhow::Result<Bound<PyBytes>> {
|
||||
let (bert_ori, phones, tones, lang_ids) = self.model.parse_text(&text)?;
|
||||
Ok(self.model.synthesize(
|
||||
let data = self.model.synthesize(
|
||||
ident,
|
||||
bert_ori,
|
||||
phones,
|
||||
@@ -51,6 +70,7 @@ impl TTSModel {
|
||||
style_vector.get(),
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
)?)
|
||||
)?;
|
||||
Ok(PyBytes::new_bound(py, &data))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user