From 23ca6b6bdb591b993ee35995165da6c3d804aa3a Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Tue, 10 Sep 2024 01:29:29 +0000 Subject: [PATCH] fix bug --- sbv2_core/src/bert.rs | 15 ++++++++++++--- sbv2_core/src/jtalk.rs | 10 +++++----- sbv2_core/src/lib.rs | 1 + sbv2_core/src/main.rs | 26 ++++++++++++++++++++------ sbv2_core/src/utils.rs | 17 +++++++++++++++++ 5 files changed, 55 insertions(+), 14 deletions(-) create mode 100644 sbv2_core/src/utils.rs diff --git a/sbv2_core/src/bert.rs b/sbv2_core/src/bert.rs index 3377bdc..688723a 100644 --- a/sbv2_core/src/bert.rs +++ b/sbv2_core/src/bert.rs @@ -10,7 +10,11 @@ pub fn load_model() -> Result { Ok(session) } -pub fn predict(session: &Session, token_ids: Vec, attention_masks: Vec) -> Result<()> { +pub fn predict( + session: &Session, + token_ids: Vec, + attention_masks: Vec, +) -> Result> { let outputs = session.run( ort::inputs! { "input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids).unwrap(), @@ -20,7 +24,12 @@ pub fn predict(session: &Session, token_ids: Vec, attention_masks: Vec let output = outputs.get("output").unwrap(); - println!("{:?}", output); + let content = output.try_extract_tensor::()?.to_owned(); + println!("{:?}", content); - Ok(()) + Ok(Array2::from_shape_vec( + (content.shape()[0], content.shape()[1]), + content.into_raw_vec(), + ) + .unwrap()) } diff --git a/sbv2_core/src/jtalk.rs b/sbv2_core/src/jtalk.rs index f221389..4fe2fb9 100644 --- a/sbv2_core/src/jtalk.rs +++ b/sbv2_core/src/jtalk.rs @@ -143,21 +143,21 @@ impl JTalkProcess { for (token, phoneme) in sep_tokenized.iter().zip(sep_phonemes.iter()) { let phone_len = phoneme.len() as i32; let word_len = token.len() as i32; - word2ph.extend(JTalkProcess::distribute_phone(phone_len, word_len)); + word2ph.append(&mut JTalkProcess::distribute_phone(phone_len, word_len)); } let mut new_phone_tone_list = vec![("_".to_string(), 0)]; new_phone_tone_list.append(&mut phone_tone_list); new_phone_tone_list.push(("_".to_string(), 0)); - let 
mut word2ph = vec![1]; - word2ph.append(&mut word2ph.clone()); - word2ph.push(1); + let mut new_word2ph = vec![1]; + new_word2ph.extend(word2ph.clone()); + new_word2ph.push(1); let phones: Vec = new_phone_tone_list.iter().map(|(x, _)| x.clone()).collect(); let tones: Vec = new_phone_tone_list.iter().map(|(_, x)| *x).collect(); - Ok((phones, tones, word2ph)) + Ok((phones, tones, new_word2ph)) } fn distribute_phone(n_phone: i32, n_word: i32) -> Vec { diff --git a/sbv2_core/src/lib.rs b/sbv2_core/src/lib.rs index bd05473..dae37ec 100644 --- a/sbv2_core/src/lib.rs +++ b/sbv2_core/src/lib.rs @@ -4,6 +4,7 @@ pub mod jtalk; pub mod mora; pub mod nlp; pub mod norm; +pub mod utils; pub fn add(left: usize, right: usize) -> usize { left + right diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index 046c31f..b625160 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -1,4 +1,4 @@ -use sbv2_core::{bert, error, jtalk, nlp, norm}; +use sbv2_core::{bert, error, jtalk, nlp, norm, utils}; fn main() -> error::Result<()> { let text = "こんにちは,世界!"; @@ -7,18 +7,32 @@ fn main() -> error::Result<()> { println!("{}", normalized_text); let jtalk = jtalk::JTalk::new()?; - let (phones, tones, word2ph) = jtalk.g2p(&normalized_text)?; + let (phones, tones, mut word2ph) = jtalk.g2p(&normalized_text)?; let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones); - let tokenizer = jtalk::get_tokenizer()?; - println!("{:?}", tokenizer); + // add blank + let phones = utils::intersperse(&phones, 0); + let tones = utils::intersperse(&tones, 0); + let lang_ids = utils::intersperse(&lang_ids, 0); + for i in 0..word2ph.len() { + word2ph[i] = word2ph[i] * 2; + } + word2ph[0] += 1; + let tokenizer = jtalk::get_tokenizer()?; let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?; - println!("{:?}", token_ids); let session = bert::load_model()?; + let bert_content = bert::predict(&session, token_ids, attention_masks)?; - 
bert::predict(&session, token_ids, attention_masks)?; + println!("{:?}", word2ph); + + assert!( + word2ph.len() == normalized_text.chars().count() + 2, + "{} {}", + word2ph.len(), + normalized_text.chars().count() + ); Ok(()) } diff --git a/sbv2_core/src/utils.rs b/sbv2_core/src/utils.rs new file mode 100644 index 0000000..d36a04c --- /dev/null +++ b/sbv2_core/src/utils.rs @@ -0,0 +1,17 @@ +/// Returns `slice` with `sep` placed before, between and after its items +/// (`[a, b]` becomes `[sep, a, sep, b, sep]`), mirroring the Python idiom +/// `result = [sep] * (len(lst) * 2 + 1); result[1::2] = lst`. +/// The output always holds `slice.len() * 2 + 1` elements. +pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T> +where + T: Clone, +{ + let mut result = vec![sep; slice.len() * 2 + 1]; + result + .iter_mut() + .skip(1) + .step_by(2) + .zip(slice) + .for_each(|(r, s)| *r = s.clone()); + result +}