diff --git a/.env.sample b/.env.sample index 300cd96..2d42907 100644 --- a/.env.sample +++ b/.env.sample @@ -1,4 +1,6 @@ BERT_MODEL_PATH=models/deberta.onnx MODEL_PATH=models/model_tsukuyomi.onnx +MODELS_PATH=models STYLE_VECTORS_PATH=models/style_vectors.json -TOKENIZER_PATH=models/tokenizer.json \ No newline at end of file +TOKENIZER_PATH=models/tokenizer.json +ADDR=localhost:3000 \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 596d85a..33a8baa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,55 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.6.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" + +[[package]] +name = "anstyle-parse" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "anyhow" version = "1.0.87" @@ -188,6 +237,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "colorchoice" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" + [[package]] name = "console" version = "0.15.8" @@ -457,6 +512,29 @@ dependencies = [ "encoding_rs", ] +[[package]] +name = "env_filter" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "errno" version = "0.3.9" @@ -653,6 +731,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "hyper" version = "1.4.1" @@ -725,6 +809,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.11.0" @@ -1642,6 +1732,8 @@ dependencies = [ "anyhow", "axum", "dotenvy", + "env_logger", + "log", "sbv2_core", "serde", "tokio", @@ -1652,6 +1744,7 @@ name = "sbv2_core" version = "0.1.0" dependencies = [ "anyhow", + "dotenvy", "hound", "jpreprocess", "ndarray", @@ -2101,6 +2194,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "version_check" version = "0.9.5" diff --git a/Cargo.toml b/Cargo.toml index e9b0da0..7c30b9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,4 @@ members = ["sbv2_api", "sbv2_core"] [workspace.dependencies] anyhow = "1.0.86" +dotenvy = "0.15.7" \ No newline at end of file diff --git a/convert/convert_model.py b/convert/convert_model.py index 9ec97b7..a8097db 100644 --- a/convert/convert_model.py +++ b/convert/convert_model.py @@ -90,8 +90,18 @@ model = get_net_g( ) -def forward(*args): - return model.infer(*args) +def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio): + return model.infer( + x, + x_len, + sid, + tone, + lang, + bert, + style, + sdp_ratio=sdp_ratio, + length_scale=length_scale, + ) model.forward = forward @@ -106,6 +116,8 @@ torch.onnx.export( lang_ids, bert, style_vec_tensor, + torch.tensor(1.0), + torch.tensor(0.0), ), f"../models/model_{out_name}.onnx", verbose=True, @@ -124,6 +136,8 @@ torch.onnx.export( "language", "bert", "style_vec", + "length_scale", + "sdp_ratio", ], output_names=["output"], ) diff --git a/docker/run.sh b/docker/run.sh index 40f384a..7b32775 100644 --- a/docker/run.sh +++ b/docker/run.sh @@ -1 +1 @@ -docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env.sample sbv2 \ No newline at end of file +docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env sbv2 \ No newline at end of file diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml index f14fd27..b1ac73a 100644 --- a/sbv2_api/Cargo.toml +++ b/sbv2_api/Cargo.toml @@ -6,7 +6,9 @@ edition = "2021" [dependencies] anyhow.workspace = true axum = "0.7.5" -dotenvy = "0.15.7" +dotenvy.workspace = true +env_logger = "0.11.5" +log = "0.4.22" sbv2_core = { version = "0.1.0", path = "../sbv2_core" } serde = { version = "1.0.210", features = ["derive"] } tokio = { version = "1.40.0", features = ["full"] } @@ -14,4 +16,4 @@ tokio = { version = "1.40.0", features = ["full"] } [features] cuda = ["sbv2_core/cuda"] cuda_tf32 = ["sbv2_core/cuda_tf32"] -dynamic = ["sbv2_core/dynamic"] \ No newline at end of file +dynamic = ["sbv2_core/dynamic"] diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs index 0602acf..e956e93 100644 --- a/sbv2_api/src/main.rs +++ b/sbv2_api/src/main.rs @@ -15,33 +15,39 @@ use tokio::sync::Mutex; mod error; use crate::error::AppResult; +async fn models(State(state): State) -> AppResult { + Ok(Json(state.tts_model.lock().await.models())) +} + +fn sdp_default() -> f32 { + 0.0 +} + +fn length_default() -> f32 { + 1.0 +} #[derive(Deserialize)] struct SynthesizeRequest { text: String, ident: String, + #[serde(default = "sdp_default")] + sdp_ratio: f32, + #[serde(default = "length_default")] + length_scale: f32, } async fn synthesize( - State(state): State>, - Json(SynthesizeRequest { text, ident }): Json, + State(state): State, + Json(SynthesizeRequest { + text, + ident, + sdp_ratio, + length_scale, + }): Json, ) -> AppResult { + log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}"); let buffer = { - let mut tts_model = state.tts_model.lock().await; - let tts_model = if let Some(tts_model) = &*tts_model { - tts_model - } else { - let mut tts_holder = TTSModelHolder::new( - &fs::read(env::var("BERT_MODEL_PATH")?).await?, - &fs::read(env::var("TOKENIZER_PATH")?).await?, - )?; - tts_holder.load( - "tsukuyomi", - fs::read(env::var("STYLE_VECTORS_PATH")?).await?, - fs::read(env::var("MODEL_PATH")?).await?, - )?; - *tts_model = Some(tts_holder); - tts_model.as_ref().unwrap() - }; + let tts_model = state.tts_model.lock().await; let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?; let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?; tts_model.synthesize( @@ -51,26 +57,78 @@ async fn synthesize( tones, lang_ids, style_vector, + sdp_ratio, + length_scale, )? }; Ok(([(CONTENT_TYPE, "audio/wav")], buffer)) } +#[derive(Clone)] struct AppState { - tts_model: Arc>>, + tts_model: Arc>, +} + +impl AppState { + pub async fn new() -> anyhow::Result { + let mut tts_model = TTSModelHolder::new( + &fs::read(env::var("BERT_MODEL_PATH")?).await?, + &fs::read(env::var("TOKENIZER_PATH")?).await?, + )?; + let models = env::var("MODELS_PATH").unwrap_or("models".to_string()); + let mut f = fs::read_dir(&models).await?; + let mut entries = vec![]; + while let Ok(Some(e)) = f.next_entry().await { + let name = e.file_name().to_string_lossy().to_string(); + if name.ends_with(".onnx") && name.starts_with("model_") { + let name_len = name.len(); + let name = name.chars(); + entries.push( + name.collect::>()[6..name_len - 5] + .iter() + .collect::(), + ); + } + } + for entry in entries { + log::info!("Try loading: {entry}"); + let style_vectors_bytes = + match fs::read(format!("{models}/style_vectors_{entry}.json")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading style_vectors_bytes from file {entry}: {e}"); + continue; + } + }; + let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading vits2_bytes from file {entry}: {e}"); + continue; + } + }; + if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) { + log::warn!("Error loading {entry}: {e}"); + }; + } + Ok(Self { + tts_model: Arc::new(Mutex::new(tts_model)), + }) + } } #[tokio::main] async fn main() -> anyhow::Result<()> { dotenvy::dotenv().ok(); + env_logger::init(); let app = Router::new() .route("/", get(|| async { "Hello, World!" })) .route("/synthesize", post(synthesize)) - .with_state(Arc::new(AppState { - tts_model: Arc::new(Mutex::new(None)), - })); - - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?; + .route("/models", get(models)) + .with_state(AppState::new().await?); + let addr = env::var("ADDR").unwrap_or("0.0.0.0:3000".to_string()); + let listener = tokio::net::TcpListener::bind(&addr).await?; + log::info!("Listening on {addr}"); axum::serve(listener, app).await?; Ok(()) diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 817748c..37203d9 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] anyhow.workspace = true +dotenvy.workspace = true hound = "3.5.1" jpreprocess = { version = "0.10.0", features = ["naist-jdic"] } ndarray = "0.16.1" diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index 8db83e9..5d27c65 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -1,41 +1,47 @@ use std::{fs, time::Instant}; -use sbv2_core::{error, tts}; +use sbv2_core::tts; +use std::env; -fn main() -> error::Result<()> { +fn main() -> anyhow::Result<()> { + dotenvy::dotenv().ok(); let text = "眠たい"; - - let mut tts_model = tts::TTSModelHolder::new( - fs::read("models/debert.onnx")?, - fs::read("models/model_opt.onnx")?, + let ident = "aaa"; + let mut tts_holder = tts::TTSModelHolder::new( + &fs::read(env::var("BERT_MODEL_PATH")?)?, + &fs::read(env::var("TOKENIZER_PATH")?)?, )?; - tts_model.load( - "tsukuyomi", - fs::read("models/style_vectors.json")?, - fs::read("models/tokenizer.json")?, + tts_holder.load( + ident, + fs::read(env::var("STYLE_VECTORS_PATH")?)?, + fs::read(env::var("MODEL_PATH")?)?, )?; - let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?; + let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?; - let style_vector = tts_model.get_style_vector("tsukuyomi", 0, 1.0)?; - let data = tts_model.synthesize( - "tsukuyomi", + let style_vector = tts_holder.get_style_vector(ident, 0, 1.0)?; + let data = tts_holder.synthesize( + ident, bert_ori.to_owned(), phones.clone(), tones.clone(), lang_ids.clone(), style_vector.clone(), + 0.0, + 0.5, )?; std::fs::write("output.wav", data)?; let now = Instant::now(); for _ in 0..10 { - tts_model.synthesize( - "tsukuyomi", + tts_holder.synthesize( + ident, bert_ori.to_owned(), phones.clone(), tones.clone(), lang_ids.clone(), style_vector.clone(), + 0.0, + 1.0, )?; } println!("Time taken: {}", now.elapsed().as_millis()); diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index a01e710..1598e9b 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -26,7 +26,7 @@ pub fn load_model>(model_file: P) -> Result { .with_inter_threads(num_cpus::get_physical())? .commit_from_memory(model_file.as_ref())?) } - +#[allow(clippy::too_many_arguments)] pub fn synthesize( session: &Session, bert_ori: Array2, @@ -34,6 +34,8 @@ pub fn synthesize( tones: Array1, lang_ids: Array1, style_vector: Array1, + sdp_ratio: f32, + length_scale: f32, ) -> Result> { let bert = bert_ori.insert_axis(Axis(0)); let x_tst_lengths: Array1 = array![x_tst.shape()[0] as i64]; @@ -49,6 +51,8 @@ pub fn synthesize( "language" => lang_ids, "bert" => bert, "style_vec" => style_vector, + "sdp_ratio" => array![sdp_ratio], + "length_scale" => array![length_scale], }?)?; let audio_array = outputs diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs index 28e7d93..dad322f 100644 --- a/sbv2_core/src/tts.rs +++ b/sbv2_core/src/tts.rs @@ -48,6 +48,9 @@ impl TTSModelHolder { tokenizer, }) } + pub fn models(&self) -> Vec { + self.models.iter().map(|m| m.ident.to_string()).collect() + } pub fn load, P: AsRef<[u8]>>( &mut self, ident: I, @@ -156,7 +159,7 @@ impl TTSModelHolder { ) -> Result> { style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight) } - + #[allow(clippy::too_many_arguments)] pub fn synthesize>( &self, ident: I, @@ -165,6 +168,8 @@ impl TTSModelHolder { tones: Array1, lang_ids: Array1, style_vector: Array1, + sdp_ratio: f32, + length_scale: f32, ) -> Result> { let buffer = model::synthesize( &self.find_model(ident)?.vits2, @@ -173,6 +178,8 @@ impl TTSModelHolder { tones, lang_ids, style_vector, + sdp_ratio, + length_scale, )?; Ok(buffer) } diff --git a/test.py b/test.py index 907c2a4..571b861 100644 --- a/test.py +++ b/test.py @@ -1,7 +1,7 @@ import requests res = requests.post( - "http://localhost:3000/synthesize", + "http://localhost:3001/synthesize", json={"text": "おはようございます", "ident": "tsukuyomi"}, ) with open("output.wav", "wb") as f: