mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-25 00:29:57 +00:00
Compare commits
30 Commits
v0.2.0-alp
...
v0.2.0-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e0c8591cd | ||
|
|
997b562682 | ||
|
|
fbd62315d0 | ||
|
|
060af0c187 | ||
|
|
b76738f467 | ||
|
|
8598167114 | ||
|
|
001f61bb6a | ||
|
|
9b9962ed29 | ||
|
|
b414d22a3b | ||
|
|
248363ae4a | ||
|
|
c4b61a36db | ||
|
|
35d16d88a8 | ||
|
|
fe48d6a034 | ||
|
|
bca4b2053f | ||
|
|
3330242cd8 | ||
|
|
f10f71f29b | ||
|
|
7bd39b7182 | ||
|
|
2d557fb0ee | ||
|
|
14d631eeaa | ||
|
|
380daf479c | ||
|
|
cb814a9952 | ||
|
|
795caf626c | ||
|
|
fb32357f31 | ||
|
|
e4010b3b83 | ||
|
|
17244a9ede | ||
|
|
61b04fd3d7 | ||
|
|
4e57a22a40 | ||
|
|
8e10057882 | ||
|
|
0222b9a189 | ||
|
|
5e96d5aef7 |
5
.github/workflows/CI.yml
vendored
5
.github/workflows/CI.yml
vendored
@@ -140,9 +140,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
tag: [cpu, cuda]
|
||||
platform:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up QEMU
|
||||
@@ -163,4 +160,4 @@ jobs:
|
||||
tags: |
|
||||
ghcr.io/${{ github.repository }}:${{ matrix.tag }}
|
||||
file: docker/${{ matrix.tag }}.Dockerfile
|
||||
platforms: ${{ matrix.platform }}
|
||||
platforms: linux/amd64, linux/arm64
|
||||
|
||||
64
Cargo.lock
generated
64
Cargo.lock
generated
@@ -77,9 +77,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.92"
|
||||
version = "1.0.93"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13"
|
||||
checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775"
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
@@ -1608,9 +1608,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pyo3"
|
||||
version = "0.22.5"
|
||||
version = "0.22.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d922163ba1f79c04bc49073ba7b32fd5a8d3b76a87c955921234b8e77333c51"
|
||||
checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cfg-if",
|
||||
@@ -1627,9 +1627,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-build-config"
|
||||
version = "0.22.5"
|
||||
version = "0.22.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc38c5feeb496c8321091edf3d63e9a6829eab4b863b4a6a65f26f3e9cc6b179"
|
||||
checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"target-lexicon",
|
||||
@@ -1637,9 +1637,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-ffi"
|
||||
version = "0.22.5"
|
||||
version = "0.22.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94845622d88ae274d2729fcefc850e63d7a3ddff5e3ce11bd88486db9f1d357d"
|
||||
checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"pyo3-build-config",
|
||||
@@ -1647,9 +1647,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros"
|
||||
version = "0.22.5"
|
||||
version = "0.22.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e655aad15e09b94ffdb3ce3d217acf652e26bbc37697ef012f5e5e348c716e5e"
|
||||
checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
@@ -1659,9 +1659,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros-backend"
|
||||
version = "0.22.5"
|
||||
version = "0.22.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae1e3f09eecd94618f60a455a23def79f79eba4dc561a97324bf9ac8c6df30ce"
|
||||
checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
@@ -1930,18 +1930,18 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.214"
|
||||
version = "1.0.215"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5"
|
||||
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.214"
|
||||
version = "1.0.215"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766"
|
||||
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2095,9 +2095,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.86"
|
||||
version = "2.0.87"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c"
|
||||
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2118,9 +2118,9 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.42"
|
||||
version = "0.4.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ff6c40d3aedb5e06b57c6f669ad17ab063dd1e63d977c6a88e7f4dfa4f04020"
|
||||
checksum = "c65998313f8e17d0d553d28f91a0df93e4dbbbf770279c7bc21ca0f09ea1a1f6"
|
||||
dependencies = [
|
||||
"filetime",
|
||||
"libc",
|
||||
@@ -2135,18 +2135,18 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.66"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede"
|
||||
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.66"
|
||||
version = "1.0.69"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5"
|
||||
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -2170,9 +2170,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.20.1"
|
||||
version = "0.20.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b172ffa9a2e5c31bbddc940cd5725d933ced983a9333bbebc4c7eda3bbce1557"
|
||||
checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"derive_builder",
|
||||
@@ -2203,9 +2203,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.41.0"
|
||||
version = "1.41.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb"
|
||||
checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytes",
|
||||
@@ -2379,9 +2379,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "utoipa"
|
||||
version = "5.1.3"
|
||||
version = "5.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d9ba0ade4e2f024cd1842dfbaf9dbc540639fc082299acf7649d71bd14eaca3"
|
||||
checksum = "514a48569e4e21c86d0b84b5612b5e73c0b2cf09db63260134ba426d4e8ea714"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"serde",
|
||||
@@ -2391,9 +2391,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "utoipa-gen"
|
||||
version = "5.1.3"
|
||||
version = "5.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4cf390d6503c9c9eac988447c38ba934a707b0b768b14511a493b4fc0e8ecb00"
|
||||
checksum = "5629efe65599d0ccd5d493688cbf6e03aa7c1da07fe59ff97cf5977ed0637f66"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
||||
@@ -40,6 +40,14 @@ fn length_default() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn style_id_default() -> i32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn speaker_id_default() -> i64 {
|
||||
0
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
struct SynthesizeRequest {
|
||||
text: String,
|
||||
@@ -48,6 +56,10 @@ struct SynthesizeRequest {
|
||||
sdp_ratio: f32,
|
||||
#[serde(default = "length_default")]
|
||||
length_scale: f32,
|
||||
#[serde(default = "style_id_default")]
|
||||
style_id: i32,
|
||||
#[serde(default = "speaker_id_default")]
|
||||
speaker_id: i64,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
@@ -65,15 +77,18 @@ async fn synthesize(
|
||||
ident,
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
style_id,
|
||||
speaker_id,
|
||||
}): Json<SynthesizeRequest>,
|
||||
) -> AppResult<impl IntoResponse> {
|
||||
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
|
||||
let buffer = {
|
||||
let tts_model = state.tts_model.lock().await;
|
||||
let mut tts_model = state.tts_model.lock().await;
|
||||
tts_model.easy_synthesize(
|
||||
&ident,
|
||||
&text,
|
||||
0,
|
||||
style_id,
|
||||
speaker_id,
|
||||
SynthesizeOptions {
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
@@ -94,6 +109,9 @@ impl AppState {
|
||||
let mut tts_model = TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||
env::var("HOLDER_MAX_LOADED_MODElS")
|
||||
.ok()
|
||||
.and_then(|x| x.parse().ok()),
|
||||
)?;
|
||||
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
|
||||
let mut f = fs::read_dir(&models).await?;
|
||||
|
||||
@@ -23,10 +23,15 @@ pub struct TTSModel {
|
||||
|
||||
#[pymethods]
|
||||
impl TTSModel {
|
||||
#[pyo3(signature = (bert_model_bytes, tokenizer_bytes, max_loaded_models=None))]
|
||||
#[new]
|
||||
fn new(bert_model_bytes: Vec<u8>, tokenizer_bytes: Vec<u8>) -> anyhow::Result<Self> {
|
||||
fn new(
|
||||
bert_model_bytes: Vec<u8>,
|
||||
tokenizer_bytes: Vec<u8>,
|
||||
max_loaded_models: Option<usize>,
|
||||
) -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes)?,
|
||||
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes, max_loaded_models)?,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -38,10 +43,21 @@ impl TTSModel {
|
||||
/// BERTモデルのパス
|
||||
/// tokenizer_path : str
|
||||
/// トークナイザーのパス
|
||||
/// max_loaded_models: int | None
|
||||
/// 同時にVRAMに存在するモデルの数
|
||||
#[pyo3(signature = (bert_model_path, tokenizer_path, max_loaded_models=None))]
|
||||
#[staticmethod]
|
||||
fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> {
|
||||
fn from_path(
|
||||
bert_model_path: String,
|
||||
tokenizer_path: String,
|
||||
max_loaded_models: Option<usize>,
|
||||
) -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
model: TTSModelHolder::new(fs::read(bert_model_path)?, fs::read(tokenizer_path)?)?,
|
||||
model: TTSModelHolder::new(
|
||||
fs::read(bert_model_path)?,
|
||||
fs::read(tokenizer_path)?,
|
||||
max_loaded_models,
|
||||
)?,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -121,11 +137,12 @@ impl TTSModel {
|
||||
/// voice_data : bytes
|
||||
/// 音声データ
|
||||
fn synthesize<'p>(
|
||||
&'p self,
|
||||
&'p mut self,
|
||||
py: Python<'p>,
|
||||
text: String,
|
||||
ident: String,
|
||||
style_id: i32,
|
||||
speaker_id: i64,
|
||||
sdp_ratio: f32,
|
||||
length_scale: f32,
|
||||
) -> anyhow::Result<Bound<PyBytes>> {
|
||||
@@ -133,6 +150,7 @@ impl TTSModel {
|
||||
ident.as_str(),
|
||||
&text,
|
||||
style_id,
|
||||
speaker_id,
|
||||
SynthesizeOptions {
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
|
||||
@@ -11,10 +11,14 @@ fn main_inner() -> anyhow::Result<()> {
|
||||
let mut tts_holder = tts::TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?)?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?)?,
|
||||
env::var("HOLDER_MAX_LOADED_MODElS")
|
||||
.ok()
|
||||
.and_then(|x| x.parse().ok()),
|
||||
)?;
|
||||
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
|
||||
|
||||
let audio = tts_holder.easy_synthesize(ident, &text, 0, tts::SynthesizeOptions::default())?;
|
||||
let audio =
|
||||
tts_holder.easy_synthesize(ident, &text, 0, 0, tts::SynthesizeOptions::default())?;
|
||||
fs::write("output.wav", audio)?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -52,6 +52,7 @@ pub fn synthesize(
|
||||
session: &Session,
|
||||
bert_ori: Array2<f32>,
|
||||
x_tst: Array1<i64>,
|
||||
sid: Array1<i64>,
|
||||
tones: Array1<i64>,
|
||||
lang_ids: Array1<i64>,
|
||||
style_vector: Array1<f32>,
|
||||
@@ -67,7 +68,7 @@ pub fn synthesize(
|
||||
let outputs = session.run(ort::inputs! {
|
||||
"x_tst" => x_tst,
|
||||
"x_tst_lengths" => x_tst_lengths,
|
||||
"sid" => array![0_i64],
|
||||
"sid" => sid,
|
||||
"tones" => tones,
|
||||
"language" => lang_ids,
|
||||
"bert" => bert,
|
||||
|
||||
@@ -24,9 +24,10 @@ where
|
||||
}
|
||||
|
||||
pub struct TTSModel {
|
||||
vits2: Session,
|
||||
vits2: Option<Session>,
|
||||
style_vectors: Array2<f32>,
|
||||
ident: TTSIdent,
|
||||
bytes: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
/// High-level Style-Bert-VITS2's API
|
||||
@@ -35,6 +36,7 @@ pub struct TTSModelHolder {
|
||||
bert: Session,
|
||||
models: Vec<TTSModel>,
|
||||
jtalk: jtalk::JTalk,
|
||||
max_loaded_models: Option<usize>,
|
||||
}
|
||||
|
||||
impl TTSModelHolder {
|
||||
@@ -43,9 +45,13 @@ impl TTSModelHolder {
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rs
|
||||
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?)?;
|
||||
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?, None)?;
|
||||
/// ```
|
||||
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
|
||||
pub fn new<P: AsRef<[u8]>>(
|
||||
bert_model_bytes: P,
|
||||
tokenizer_bytes: P,
|
||||
max_loaded_models: Option<usize>,
|
||||
) -> Result<Self> {
|
||||
let bert = model::load_model(bert_model_bytes, true)?;
|
||||
let jtalk = jtalk::JTalk::new()?;
|
||||
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
||||
@@ -54,6 +60,7 @@ impl TTSModelHolder {
|
||||
models: vec![],
|
||||
jtalk,
|
||||
tokenizer,
|
||||
max_loaded_models,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -94,10 +101,25 @@ impl TTSModelHolder {
|
||||
) -> Result<()> {
|
||||
let ident = ident.into();
|
||||
if self.find_model(ident.clone()).is_err() {
|
||||
let mut load = true;
|
||||
if let Some(max) = self.max_loaded_models {
|
||||
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
|
||||
load = false;
|
||||
}
|
||||
}
|
||||
self.models.push(TTSModel {
|
||||
vits2: model::load_model(vits2_bytes, false)?,
|
||||
vits2: if load {
|
||||
Some(model::load_model(&vits2_bytes, false)?)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
style_vectors: style::load_style(style_vectors_bytes)?,
|
||||
ident,
|
||||
bytes: if self.max_loaded_models.is_some() {
|
||||
Some(vits2_bytes.as_ref().to_vec())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
}
|
||||
Ok(())
|
||||
@@ -145,6 +167,42 @@ impl TTSModelHolder {
|
||||
.find(|m| m.ident == ident)
|
||||
.ok_or(Error::ModelNotFoundError(ident.to_string()))
|
||||
}
|
||||
fn find_and_load_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<bool> {
|
||||
let ident = ident.into();
|
||||
let (bytes, style_vectors) = {
|
||||
let model = self
|
||||
.models
|
||||
.iter()
|
||||
.find(|m| m.ident == ident)
|
||||
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
|
||||
if model.vits2.is_some() {
|
||||
return Ok(true);
|
||||
}
|
||||
(model.bytes.clone().unwrap(), model.style_vectors.clone())
|
||||
};
|
||||
self.unload(ident.clone());
|
||||
let s = model::load_model(&bytes, false)?;
|
||||
if let Some(max) = self.max_loaded_models {
|
||||
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
|
||||
self.unload(self.models.first().unwrap().ident.clone());
|
||||
}
|
||||
}
|
||||
self.models.push(TTSModel {
|
||||
bytes: Some(bytes.to_vec()),
|
||||
vits2: Some(s),
|
||||
style_vectors,
|
||||
ident: ident.clone(),
|
||||
});
|
||||
let model = self
|
||||
.models
|
||||
.iter()
|
||||
.find(|m| m.ident == ident)
|
||||
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
|
||||
if model.vits2.is_some() {
|
||||
return Ok(true);
|
||||
}
|
||||
Err(Error::ModelNotFoundError(ident.to_string()))
|
||||
}
|
||||
|
||||
/// Get style vector by style id and weight
|
||||
///
|
||||
@@ -167,12 +225,19 @@ impl TTSModelHolder {
|
||||
/// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
|
||||
/// ```
|
||||
pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
|
||||
&self,
|
||||
&mut self,
|
||||
ident: I,
|
||||
text: &str,
|
||||
style_id: i32,
|
||||
speaker_id: i64,
|
||||
options: SynthesizeOptions,
|
||||
) -> Result<Vec<u8>> {
|
||||
self.find_and_load_model(ident)?;
|
||||
let vits2 = &self
|
||||
.find_model(ident)?
|
||||
.vits2
|
||||
.as_ref()
|
||||
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
|
||||
let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
|
||||
let audio_array = if options.split_sentences {
|
||||
let texts: Vec<&str> = text.split('\n').collect();
|
||||
@@ -183,9 +248,10 @@ impl TTSModelHolder {
|
||||
}
|
||||
let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
|
||||
let audio = model::synthesize(
|
||||
&self.find_model(ident)?.vits2,
|
||||
vits2,
|
||||
bert_ori.to_owned(),
|
||||
phones,
|
||||
Array1::from_vec(vec![speaker_id]),
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector.clone(),
|
||||
@@ -204,9 +270,10 @@ impl TTSModelHolder {
|
||||
} else {
|
||||
let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
|
||||
model::synthesize(
|
||||
&self.find_model(ident)?.vits2,
|
||||
vits2,
|
||||
bert_ori.to_owned(),
|
||||
phones,
|
||||
Array1::from_vec(vec![speaker_id]),
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector,
|
||||
@@ -216,35 +283,6 @@ impl TTSModelHolder {
|
||||
};
|
||||
tts_util::array_to_vec(audio_array)
|
||||
}
|
||||
|
||||
/// Synthesize text to audio
|
||||
///
|
||||
/// # Note
|
||||
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn synthesize<I: Into<TTSIdent>>(
|
||||
&self,
|
||||
ident: I,
|
||||
bert_ori: Array2<f32>,
|
||||
phones: Array1<i64>,
|
||||
tones: Array1<i64>,
|
||||
lang_ids: Array1<i64>,
|
||||
style_vector: Array1<f32>,
|
||||
sdp_ratio: f32,
|
||||
length_scale: f32,
|
||||
) -> Result<Vec<u8>> {
|
||||
let audio_array = model::synthesize(
|
||||
&self.find_model(ident)?.vits2,
|
||||
bert_ori.to_owned(),
|
||||
phones,
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector,
|
||||
sdp_ratio,
|
||||
length_scale,
|
||||
)?;
|
||||
tts_util::array_to_vec(audio_array)
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthesize options
|
||||
|
||||
Reference in New Issue
Block a user