diff --git a/Cargo.lock b/Cargo.lock index cf397ea..839c8a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1425,9 +1425,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.26" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "macro_rules_attribute" @@ -2318,6 +2318,7 @@ version = "0.2.0-alpha6" dependencies = [ "anyhow", "axum", + "log", "sbv2_core", "serde", "tokio", diff --git a/crates/sbv2_core/src/tts.rs b/crates/sbv2_core/src/tts.rs index 0c45da5..4ec9fae 100644 --- a/crates/sbv2_core/src/tts.rs +++ b/crates/sbv2_core/src/tts.rs @@ -205,6 +205,23 @@ impl TTSModelHolder { ) -> Result<(Array2, Array1, Array1, Array1)> { crate::tts_util::parse_text_blocking( text, + None, + &self.jtalk, + &self.tokenizer, + |token_ids, attention_masks| { + crate::bert::predict(&mut self.bert, token_ids, attention_masks) + }, + ) + } + + pub fn parse_text_neo( + &mut self, + text: String, + given_tones: Option>, + ) -> Result<(Array2, Array1, Array1, Array1)> { + crate::tts_util::parse_text_blocking( + &text, + given_tones, &self.jtalk, &self.tokenizer, |token_ids, attention_masks| { @@ -347,6 +364,78 @@ impl TTSModelHolder { }; tts_util::array_to_vec(audio_array) } + + pub fn easy_synthesize_neo + Copy>( + &mut self, + ident: I, + text: &str, + given_tones: Option>, + style_id: i32, + speaker_id: i64, + options: SynthesizeOptions, + ) -> Result> { + self.find_and_load_model(ident)?; + let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?; + let audio_array = if options.split_sentences { + let texts: Vec<&str> = text.split('\n').collect(); + let mut audios = vec![]; + for (i, t) in texts.iter().enumerate() { + if t.is_empty() { + continue; + } + let (bert_ori, phones, tones, lang_ids) = self.parse_text_neo(t, given_tones)?; + + let vits2 = self + .find_model(ident)? + .vits2 + .as_mut() + .ok_or(Error::ModelNotFoundError(ident.into().to_string()))?; + let audio = model::synthesize( + vits2, + bert_ori.to_owned(), + phones, + Array1::from_vec(vec![speaker_id]), + tones, + lang_ids, + style_vector.clone(), + options.sdp_ratio, + options.length_scale, + 0.677, + 0.8, + )?; + audios.push(audio.clone()); + if i != texts.len() - 1 { + audios.push(Array3::zeros((1, 1, 22050))); + } + } + concatenate( + Axis(2), + &audios.iter().map(|x| x.view()).collect::>(), + )? + } else { + let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?; + + let vits2 = self + .find_model(ident)? + .vits2 + .as_mut() + .ok_or(Error::ModelNotFoundError(ident.into().to_string()))?; + model::synthesize( + vits2, + bert_ori.to_owned(), + phones, + Array1::from_vec(vec![speaker_id]), + tones, + lang_ids, + style_vector, + options.sdp_ratio, + options.length_scale, + 0.677, + 0.8, + )? + }; + tts_util::array_to_vec(audio_array) + } } /// Synthesize options diff --git a/crates/sbv2_core/src/tts_util.rs b/crates/sbv2_core/src/tts_util.rs index 17eb9d7..3fc4f77 100644 --- a/crates/sbv2_core/src/tts_util.rs +++ b/crates/sbv2_core/src/tts_util.rs @@ -34,7 +34,6 @@ pub async fn parse_text( let (normalized_text, process) = preprocess_parse_text(text, jtalk)?; let (phones, tones, mut word2ph) = process.g2p()?; let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones); - let phones = utils::intersperse(&phones, 0); let tones = utils::intersperse(&tones, 0); let lang_ids = utils::intersperse(&lang_ids, 0); @@ -99,6 +98,7 @@ pub async fn parse_text( #[allow(clippy::type_complexity)] pub fn parse_text_blocking( text: &str, + given_tones: Option>, jtalk: &jtalk::JTalk, tokenizer: &Tokenizer, bert_predict: impl FnOnce(Vec, Vec) -> Result>, @@ -107,7 +107,10 @@ pub fn parse_text_blocking( let normalized_text = norm::normalize_text(&text); let process = jtalk.process_text(&normalized_text)?; - let (phones, tones, mut word2ph) = process.g2p()?; + let (phones, mut tones, mut word2ph) = process.g2p()?; + if let Some(given_tones) = given_tones { + tones = given_tones; + } let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones); let phones = utils::intersperse(&phones, 0); diff --git a/crates/sbv2_editor/Cargo.toml b/crates/sbv2_editor/Cargo.toml index 734109d..3d7ffab 100644 --- a/crates/sbv2_editor/Cargo.toml +++ b/crates/sbv2_editor/Cargo.toml @@ -11,6 +11,7 @@ documentation.workspace = true [dependencies] anyhow.workspace = true axum = "0.8.1" +log = "0.4.27" sbv2_core = { version = "0.2.0-alpha6", path = "../sbv2_core" } serde = { version = "1.0.219", features = ["derive"] } tokio = { version = "1.44.1", features = ["full"] } diff --git a/crates/sbv2_editor/src/main.rs b/crates/sbv2_editor/src/main.rs index ac77598..2268210 100644 --- a/crates/sbv2_editor/src/main.rs +++ b/crates/sbv2_editor/src/main.rs @@ -1,7 +1,11 @@ -use axum::{extract::Query, response::IntoResponse, routing::get, Json, Router}; -use sbv2_core::{jtalk::JTalk, tts_util::preprocess_parse_text}; +use axum::extract::State; +use axum::{extract::Query, response::IntoResponse, routing::{get, post}, Json, Router}; +use sbv2_core::{jtalk::JTalk, tts::TTSModelHolder, tts_util::preprocess_parse_text}; use serde::{Deserialize, Serialize}; -use tokio::net::TcpListener; +use tokio::{fs, net::TcpListener, sync::Mutex}; + +use std::env; +use std::sync::Arc; use error::AppResult; @@ -12,32 +16,162 @@ struct RequestCreateAudioQuery { text: String, } -#[derive(Serialize)] -struct ResponseCreateAudioQuery { +#[derive(Serialize, Deserialize)] +struct AudioQuery { kana: String, tone: i32, } +#[derive(Serialize)] +struct ResponseCreateAudioQuery { + audio_query: Vec, + text: String, +} + async fn create_audio_query( Query(request): Query, ) -> AppResult { - let (_, process) = preprocess_parse_text(&request.text, &JTalk::new()?)?; + let (text, process) = preprocess_parse_text(&request.text, &JTalk::new()?)?; let kana_tone_list = process.g2kana_tone()?; - let response = kana_tone_list + let audio_query = kana_tone_list .iter() - .map(|(kana, tone)| ResponseCreateAudioQuery { + .map(|(kana, tone)| AudioQuery { kana: kana.clone(), tone: *tone, }) .collect::>(); - Ok(Json(response)) + Ok(Json(ResponseCreateAudioQuery { audio_query, text })) +} + +#[derive(Deserialize)] +pub struct RequestSynthesis { + text: String, + speaker_id: i32, + sdp_ratio: f32, + length_scale: f32, + style_id: i32, + audio_query: Vec, +} + +async fn synthesis( + State(state): State, + Json(request): Json, +) -> AppResult { + let mut tones: Vec = request.audio_query.iter().map(|query| query.tone).collect(); + tones.insert(0, 0); + tones.push(0); + let buffer = { + let mut tts_model = state.tts_model.lock().await; + tts_model.easy_synthesize_neo( + &ident, + &text, + Some(tones), + request.style_id, + request.speaker_id, + SynthesizeOptions { + sdp_ratio: request.sdp_ratio, + length_scale: request.length_scale, + ..Default::default() + }, + )? + }; + Ok(([(CONTENT_TYPE, "audio/wav")], buffer)) +} + +#[derive(Clone)] +struct AppState { + tts_model: Arc>, +} + +impl AppState { + pub async fn new() -> anyhow::Result { + let mut tts_model = TTSModelHolder::new( + &fs::read(env::var("BERT_MODEL_PATH")?).await?, + &fs::read(env::var("TOKENIZER_PATH")?).await?, + env::var("HOLDER_MAX_LOADED_MODElS") + .ok() + .and_then(|x| x.parse().ok()), + )?; + let models = env::var("MODELS_PATH").unwrap_or("models".to_string()); + let mut f = fs::read_dir(&models).await?; + let mut entries = vec![]; + while let Ok(Some(e)) = f.next_entry().await { + let name = e.file_name().to_string_lossy().to_string(); + if name.ends_with(".onnx") && name.starts_with("model_") { + let name_len = name.len(); + let name = name.chars(); + entries.push( + name.collect::>()[6..name_len - 5] + .iter() + .collect::(), + ); + } else if name.ends_with(".sbv2") { + let entry = &name[..name.len() - 5]; + log::info!("Try loading: {entry}"); + let sbv2_bytes = match fs::read(format!("{models}/{entry}.sbv2")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading sbv2_bytes from file {entry}: {e}"); + continue; + } + }; + if let Err(e) = tts_model.load_sbv2file(entry, sbv2_bytes) { + log::warn!("Error loading {entry}: {e}"); + }; + log::info!("Loaded: {entry}"); + } else if name.ends_with(".aivmx") { + let entry = &name[..name.len() - 6]; + log::info!("Try loading: {entry}"); + let aivmx_bytes = match fs::read(format!("{models}/{entry}.aivmx")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading aivmx bytes from file {entry}: {e}"); + continue; + } + }; + if let Err(e) = tts_model.load_aivmx(entry, aivmx_bytes) { + log::error!("Error loading {entry}: {e}"); + } + log::info!("Loaded: {entry}"); + } + } + for entry in entries { + log::info!("Try loading: {entry}"); + let style_vectors_bytes = + match fs::read(format!("{models}/style_vectors_{entry}.json")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading style_vectors_bytes from file {entry}: {e}"); + continue; + } + }; + let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await { + Ok(b) => b, + Err(e) => { + log::warn!("Error loading vits2_bytes from file {entry}: {e}"); + continue; + } + }; + if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) { + log::warn!("Error loading {entry}: {e}"); + }; + log::info!("Loaded: {entry}"); + } + Ok(Self { + tts_model: Arc::new(Mutex::new(tts_model)), + }) + } } #[tokio::main] async fn main() -> anyhow::Result<()> { + dotenvy::dotenv_override().ok(); + env_logger::init(); let app = Router::new() .route("/", get(|| async { "Hello, world!" })) - .route("/audio_query", get(create_audio_query)); + .route("/audio_query", get(create_audio_query)) + .route("/synthesis", post(synthesis)) + .with_state(AppState::new().await?); let listener = TcpListener::bind("0.0.0.0:8080").await?; axum::serve(listener, app).await?; Ok(())