This commit is contained in:
tuna2134
2024-09-12 11:52:56 +00:00
parent f5951413b3
commit d5e729aa81
3 changed files with 6 additions and 7 deletions

View File

@@ -10,6 +10,5 @@ crate-type = ["cdylib"]
[dependencies]
anyhow.workspace = true
numpy = "0.21.0"
pyo3 = { version = "0.22.0", features = ["anyhow"] }
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }

View File

@@ -1,20 +1,19 @@
use pyo3::prelude::*;
use sbv2_core::tts::TTSModel as BaseTTSModel;
use numpy::{convert::IntoPyArray
use sbv2_core::tts::TTSModelHolder;
#[pyclass]
pub struct TTSModel {
pub model: BaseTTSModel,
pub model: TTSModelHolder,
}
#[pymethods]
impl TTSModel {
#[new]
fn new(bert_model_path: &str, main_model_path: &str, style_vectors_path: &str) -> anyhow::Result<Self> {
fn new(bert_model_bytes: Vec<u8>, tokenizer_bytes: Vec<u8>) -> anyhow::Result<Self> {
Ok(Self {
model: BaseTTSModel::new(bert_model_path, main_model_path, style_vectors_path)?,
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes)?,
})
}
fn get_style_vector
fn load()
}

View File

@@ -6,6 +6,7 @@ edition = "2021"
license = "MIT"
readme = "../README.md"
repository = "https://github.com/tuna2134/sbv2-api"
documentation = "https://docs.rs/sbv2_core"
[dependencies]
anyhow.workspace = true