From 6762a877609aa1f626c49954c80ebdc2e3cd915f Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Tue, 10 Sep 2024 00:06:57 +0000 Subject: [PATCH] new: cleaned_sequence --- sbv2_core/src/{text.rs => jtalk.rs} | 1 - sbv2_core/src/lib.rs | 3 ++- sbv2_core/src/main.rs | 10 +++++----- sbv2_core/src/nlp.rs | 24 ++++++++++++++++++++++++ sbv2_core/src/norm.rs | 19 +++++++++++++++++++ 5 files changed, 50 insertions(+), 7 deletions(-) rename sbv2_core/src/{text.rs => jtalk.rs} (99%) create mode 100644 sbv2_core/src/nlp.rs diff --git a/sbv2_core/src/text.rs b/sbv2_core/src/jtalk.rs similarity index 99% rename from sbv2_core/src/text.rs rename to sbv2_core/src/jtalk.rs index 89823f0..b778d0e 100644 --- a/sbv2_core/src/text.rs +++ b/sbv2_core/src/jtalk.rs @@ -128,7 +128,6 @@ impl JTalkProcess { let mut phone_tone_list = JTalkProcess::align_tones(phone_w_punct, phone_tone_list_wo_punct)?; - println!("{:?}", phone_tone_list); let mut sep_tokenized: Vec> = Vec::new(); for i in 0..seq_text.len() { diff --git a/sbv2_core/src/lib.rs b/sbv2_core/src/lib.rs index b035dd1..bd05473 100644 --- a/sbv2_core/src/lib.rs +++ b/sbv2_core/src/lib.rs @@ -1,8 +1,9 @@ pub mod bert; pub mod error; +pub mod jtalk; pub mod mora; +pub mod nlp; pub mod norm; -pub mod text; pub fn add(left: usize, right: usize) -> usize { left + right diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index 17404c8..9a0f7ab 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -1,19 +1,19 @@ -use sbv2_core::{bert, error, text}; +use sbv2_core::{bert, error, jtalk}; fn main() -> error::Result<()> { let text = "こんにちは,世界!"; - let normalized_text = text::normalize_text(text); + let normalized_text = jtalk::normalize_text(text); println!("{}", normalized_text); - let jtalk = text::JTalk::new()?; + let jtalk = jtalk::JTalk::new()?; let (phones, tones, _) = jtalk.g2p(&normalized_text)?; println!("{:?}", tones); - let tokenizer = text::get_tokenizer()?; + let tokenizer = jtalk::get_tokenizer()?; println!("{:?}", tokenizer); - let (token_ids, attention_masks) = text::tokenize(&normalized_text, &tokenizer)?; + let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?; println!("{:?}", token_ids); let session = bert::load_model()?; diff --git a/sbv2_core/src/nlp.rs b/sbv2_core/src/nlp.rs new file mode 100644 index 0000000..e4c24d3 --- /dev/null +++ b/sbv2_core/src/nlp.rs @@ -0,0 +1,24 @@ +use crate::norm::SYMBOLS; +use once_cell::sync::Lazy; +use std::collections::HashMap; + +static SYMBOL_TO_ID: Lazy> = Lazy::new(|| { + let mut map = HashMap::new(); + for (i, symbols) in SYMBOLS.iter().enumerate() { + map.insert(symbols.to_string(), i as i32); + } + map +}); + +pub fn cleaned_text_to_sequence( + cleaned_phones: Vec, + tones: Vec, +) -> (Vec, Vec, Vec) { + let phones: Vec = cleaned_phones + .iter() + .map(|phone| SYMBOL_TO_ID.get(phone).unwrap()) + .collect(); + let tones: Vec = tones.iter().map(|tone| tone + 6).collect(); + let lang_ids: Vec = vec![1; phones.len()]; + (phones, tones, lang_ids) +} diff --git a/sbv2_core/src/norm.rs b/sbv2_core/src/norm.rs index 26a2491..681e18c 100644 --- a/sbv2_core/src/norm.rs +++ b/sbv2_core/src/norm.rs @@ -69,7 +69,26 @@ __PUNCTUATION_CLEANUP_PATTERN = re.compile( ) */ +pub const JP_SYMBOLS: [&str; 42] = [ + "N", "a", "a:", "b", "by", "ch", "d", "dy", "e", "e:", "f", "g", "gy", "h", "hy", "i", "i:", + "j", "k", "ky", "m", "my", "n", "ny", "o", "o:", "p", "py", "q", "r", "ry", "s", "sh", "t", + "ts", "ty", "u", "u:", "w", "y", "z", "zy", +]; + pub static PUNCTUATIONS: [&str; 7] = ["!", "?", "…", ",", ".", "'", "-"]; +pub static PUNCTUATION_SYMBOLS: Lazy> = Lazy::new(|| { + let mut symbols = PUNCTUATIONS.to_vec(); + symbols.append(&mut vec!["SP", "UNK"]); + symbols +}); +const PAD: &str = "_"; +pub static SYMBOLS: Lazy> = Lazy::new(|| { + let mut symbols = JP_SYMBOLS.to_vec(); + symbols.append(&mut JP_SYMBOLS.to_vec()); + symbols.append(&mut PUNCTUATION_SYMBOLS.to_vec()); + symbols +}); + static PUNCTUATION_CLEANUP_PATTERN: Lazy = Lazy::new(|| { let pattern = r"[^\u{3040}-\u{309F}\u{30A0}-\u{30FF}\u{4E00}-\u{9FFF}\u{3400}-\u{4DBF}\u{3005}" .to_owned()