diff --git a/output.wav b/output.wav index 91bad2e..7883ed9 100644 Binary files a/output.wav and b/output.wav differ diff --git a/sbv2_core/src/bert.rs b/sbv2_core/src/bert.rs index 50762bf..3bd8396 100644 --- a/sbv2_core/src/bert.rs +++ b/sbv2_core/src/bert.rs @@ -1,14 +1,6 @@ use crate::error::Result; use ndarray::Array2; -use ort::{GraphOptimizationLevel, Session}; - -pub fn load_model(model_file: &str) -> Result { - let session = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(1)? - .commit_from_file(model_file)?; - Ok(session) -} +use ort::Session; pub fn predict( session: &Session, diff --git a/sbv2_core/src/lib.rs b/sbv2_core/src/lib.rs index b058762..0482063 100644 --- a/sbv2_core/src/lib.rs +++ b/sbv2_core/src/lib.rs @@ -6,6 +6,7 @@ pub mod mora; pub mod nlp; pub mod norm; pub mod style; +pub mod tts; pub mod utils; pub fn add(left: usize, right: usize) -> usize { diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index b2d9a0f..63ad56b 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -1,78 +1,18 @@ -use ndarray::{concatenate, s, Array, Array1, Array2, Axis}; -use sbv2_core::{bert, error, jtalk, model, nlp, norm, style, utils}; +use sbv2_core::{error, tts}; fn main() -> error::Result<()> { - let text = "隣の客はよくかき食う客だ"; + let text = "おはようございます。"; - let normalized_text = norm::normalize_text(text); - - let jtalk = jtalk::JTalk::new()?; - let (phones, tones, mut word2ph) = jtalk.g2p(&normalized_text)?; - let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones); - - // add black - let phones = utils::intersperse(&phones, 0); - let tones = utils::intersperse(&tones, 0); - let lang_ids = utils::intersperse(&lang_ids, 0); - for i in 0..word2ph.len() { - word2ph[i] *= 2; - } - word2ph[0] += 1; - - let tokenizer = jtalk::get_tokenizer()?; - let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?; - - let session = bert::load_model("models/debert.onnx")?; - let bert_content = bert::predict(&session, token_ids, attention_masks)?; - - assert!( - word2ph.len() == normalized_text.chars().count() + 2, - "{} {}", - word2ph.len(), - normalized_text.chars().count() - ); - - let mut phone_level_feature = vec![]; - for i in 0..word2ph.len() { - // repeat_feature = np.tile(bert_content[i], (word2ph[i], 1)) - let repeat_feature = { - let (reps_rows, reps_cols) = (word2ph[i], 1); - let arr_len = bert_content.slice(s![i, ..]).len(); - - let mut results: Array2 = Array::zeros((reps_rows as usize, arr_len * reps_cols)); - - for j in 0..reps_rows { - for k in 0..reps_cols { - let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]); - view.assign(&bert_content.slice(s![i, ..])); - } - } - results - }; - phone_level_feature.push(repeat_feature); - } - // ph = np.concatenate(phone_level_feature, axis=0) - // bert_ori = ph.T - let phone_level_feature = concatenate( - Axis(0), - &phone_level_feature - .iter() - .map(|x| x.view()) - .collect::>(), + let tts_model = tts::TTSModel::new( + "models/debert.onnx", + "models/model_opt.onnx", + "models/style_vectors.json", )?; - let bert_ori = phone_level_feature.t(); - let session = bert::load_model("models/model_opt.onnx")?; - let style_vectors = style::load_style("models/style_vectors.json")?; - let style_vector = style::get_style_vector(style_vectors, 0, 1.0)?; - model::synthesize( - &session, - bert_ori.to_owned(), - Array1::from_vec(phones.iter().map(|x| *x as i64).collect()), - Array1::from_vec(tones.iter().map(|x| *x as i64).collect()), - Array1::from_vec(lang_ids.iter().map(|x| *x as i64).collect()), - style_vector, - )?; + let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?; + + let style_vector = tts_model.get_style_vector(0, 1.0)?; + tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?; Ok(()) } diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 50da5b5..d5c0eb2 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -1,7 +1,15 @@ use crate::error::Result; use hound::{SampleFormat, WavSpec, WavWriter}; use ndarray::{array, Array1, Array2, Axis}; -use ort::Session; +use ort::{GraphOptimizationLevel, Session}; + +pub fn load_model(model_file: &str) -> Result { + let session = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(1)? + .commit_from_file(model_file)?; + Ok(session) +} fn write_wav(file_path: &str, audio: &[f32], sample_rate: u32) -> Result<()> { let spec = WavSpec { diff --git a/sbv2_core/src/nlp.rs b/sbv2_core/src/nlp.rs index 253edb6..541cfc1 100644 --- a/sbv2_core/src/nlp.rs +++ b/sbv2_core/src/nlp.rs @@ -13,12 +13,12 @@ static SYMBOL_TO_ID: Lazy> = Lazy::new(|| { pub fn cleaned_text_to_sequence( cleaned_phones: Vec, tones: Vec, -) -> (Vec, Vec, Vec) { - let phones: Vec = cleaned_phones +) -> (Vec, Vec, Vec) { + let phones: Vec = cleaned_phones .iter() - .map(|phone| *SYMBOL_TO_ID.get(phone).unwrap()) + .map(|phone| *SYMBOL_TO_ID.get(phone).unwrap() as i64) .collect(); - let tones: Vec = tones.iter().map(|tone| *tone + 6).collect(); - let lang_ids: Vec = vec![1; phones.len()]; + let tones: Vec = tones.iter().map(|tone| (*tone + 6) as i64).collect(); + let lang_ids: Vec = vec![1; phones.len()]; (phones, tones, lang_ids) } diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs new file mode 100644 index 0000000..26b7434 --- /dev/null +++ b/sbv2_core/src/tts.rs @@ -0,0 +1,121 @@ +use crate::error::Result; +use crate::{bert, jtalk, model, nlp, norm, style, utils}; +use ndarray::{concatenate, s, Array, Array1, Array2, Axis}; +use ort::Session; + +pub struct TTSModel { + bert: Session, + vits2: Session, + style_vectors: Array2, + jtalk: jtalk::JTalk, +} + +impl TTSModel { + pub fn new( + bert_model_path: &str, + main_model_path: &str, + style_vector_path: &str, + ) -> Result { + let bert = model::load_model(bert_model_path)?; + let vits2 = model::load_model(main_model_path)?; + let style_vectors = style::load_style(style_vector_path)?; + let jtalk = jtalk::JTalk::new()?; + Ok(TTSModel { + bert, + vits2, + style_vectors, + jtalk, + }) + } + + pub fn parse_text( + &self, + text: &str, + ) -> Result<(Array2, Array1, Array1, Array1)> { + let normalized_text = norm::normalize_text(text); + + let (phones, tones, mut word2ph) = self.jtalk.g2p(&normalized_text)?; + let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones); + + let phones = utils::intersperse(&phones, 0); + let tones = utils::intersperse(&tones, 0); + let lang_ids = utils::intersperse(&lang_ids, 0); + for i in 0..word2ph.len() { + word2ph[i] *= 2; + } + word2ph[0] += 1; + + let tokenizer = jtalk::get_tokenizer()?; + let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?; + + let bert_content = bert::predict(&self.bert, token_ids, attention_masks)?; + + assert!( + word2ph.len() == normalized_text.chars().count() + 2, + "{} {}", + word2ph.len(), + normalized_text.chars().count() + ); + + let mut phone_level_feature = vec![]; + for i in 0..word2ph.len() { + let repeat_feature = { + let (reps_rows, reps_cols) = (word2ph[i], 1); + let arr_len = bert_content.slice(s![i, ..]).len(); + + let mut results: Array2 = + Array::zeros((reps_rows as usize, arr_len * reps_cols)); + + for j in 0..reps_rows { + for k in 0..reps_cols { + let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]); + view.assign(&bert_content.slice(s![i, ..])); + } + } + results + }; + phone_level_feature.push(repeat_feature); + } + let phone_level_feature = concatenate( + Axis(0), + &phone_level_feature + .iter() + .map(|x| x.view()) + .collect::>(), + )?; + let bert_ori = phone_level_feature.t(); + Ok(( + bert_ori.to_owned(), + phones.into(), + tones.into(), + lang_ids.into(), + )) + } + + pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result> { + Ok(style::get_style_vector( + self.style_vectors.clone(), + style_id, + weight, + )?) + } + + pub fn synthesize( + &self, + bert_ori: Array2, + phones: Array1, + tones: Array1, + lang_ids: Array1, + style_vector: Array1, + ) -> Result<()> { + model::synthesize( + &self.vits2, + bert_ori.to_owned(), + phones, + tones, + lang_ids, + style_vector, + )?; + Ok(()) + } +}