Commit 23ca6b6bdb (parent 9e8c07857c) by tuna2134, 2024-09-10 01:29:29 +00:00
5 changed files with 55 additions and 14 deletions


@@ -10,7 +10,11 @@ pub fn load_model() -> Result<Session> {
     Ok(session)
 }
 
-pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>) -> Result<()> {
+pub fn predict(
+    session: &Session,
+    token_ids: Vec<i64>,
+    attention_masks: Vec<i64>,
+) -> Result<Array2<f32>> {
     let outputs = session.run(
         ort::inputs! {
             "input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids).unwrap(),
@@ -20,7 +24,12 @@ pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>
     let output = outputs.get("output").unwrap();
     println!("{:?}", output);
     let content = output.try_extract_tensor::<f32>()?.to_owned();
     println!("{:?}", content);
-    Ok(())
+    Ok(Array2::from_shape_vec(
+        (content.shape()[0], content.shape()[1]),
+        content.into_raw_vec(),
+    )
+    .unwrap())
 }
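
As a side note, the rank-2 conversion above can also be written without the raw-Vec round trip. The sketch below is illustrative only; it assumes content is an ndarray::ArrayD<f32> (what try_extract_tensor::<f32>()?.to_owned() yields) and uses anyhow for the error type, neither of which is taken from this commit.

use ndarray::{Array2, ArrayD, Ix2};

// Convert the dynamic-rank tensor into a fixed (rows, cols) array directly.
// Returns an error instead of panicking if the tensor is not rank-2.
fn to_array2(content: ArrayD<f32>) -> anyhow::Result<Array2<f32>> {
    Ok(content.into_dimensionality::<Ix2>()?)
}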


@@ -143,21 +143,21 @@ impl JTalkProcess {
         for (token, phoneme) in sep_tokenized.iter().zip(sep_phonemes.iter()) {
             let phone_len = phoneme.len() as i32;
             let word_len = token.len() as i32;
-            word2ph.extend(JTalkProcess::distribute_phone(phone_len, word_len));
+            word2ph.append(&mut JTalkProcess::distribute_phone(phone_len, word_len));
         }
         let mut new_phone_tone_list = vec![("_".to_string(), 0)];
         new_phone_tone_list.append(&mut phone_tone_list);
         new_phone_tone_list.push(("_".to_string(), 0));
-        let mut word2ph = vec![1];
-        word2ph.append(&mut word2ph.clone());
-        word2ph.push(1);
+        let mut new_word2ph = vec![1];
+        new_word2ph.extend(word2ph.clone());
+        new_word2ph.push(1);
         let phones: Vec<String> = new_phone_tone_list.iter().map(|(x, _)| x.clone()).collect();
         let tones: Vec<i32> = new_phone_tone_list.iter().map(|(_, x)| *x).collect();
-        Ok((phones, tones, word2ph))
+        Ok((phones, tones, new_word2ph))
     }
 
     fn distribute_phone(n_phone: i32, n_word: i32) -> Vec<i32> {
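
The word2ph fix is easiest to see in isolation: the per-token phone counts are now wrapped with a leading and trailing 1, mirroring the "_" entries added around phone_tone_list, whereas the old code appended a clone of the freshly created vec![1] and discarded the computed counts. A minimal, self-contained sketch with a hypothetical helper name and toy counts:

// Hypothetical helper mirroring the corrected padding logic above.
fn pad_word2ph(word2ph: &[i32]) -> Vec<i32> {
    let mut new_word2ph = vec![1]; // leading "_"
    new_word2ph.extend_from_slice(word2ph);
    new_word2ph.push(1); // trailing "_"
    new_word2ph
}

fn main() {
    // Toy input: three tokens contributing 2, 3 and 1 phones respectively.
    assert_eq!(pad_word2ph(&[2, 3, 1]), vec![1, 2, 3, 1, 1]);
}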


@@ -4,6 +4,7 @@ pub mod jtalk;
 pub mod mora;
 pub mod nlp;
 pub mod norm;
+pub mod utils;
 
 pub fn add(left: usize, right: usize) -> usize {
     left + right


@@ -1,4 +1,4 @@
-use sbv2_core::{bert, error, jtalk, nlp, norm};
+use sbv2_core::{bert, error, jtalk, nlp, norm, utils};
 
 fn main() -> error::Result<()> {
     let text = "こんにちは,世界!";
@@ -7,18 +7,32 @@ fn main() -> error::Result<()> {
     println!("{}", normalized_text);
     let jtalk = jtalk::JTalk::new()?;
-    let (phones, tones, word2ph) = jtalk.g2p(&normalized_text)?;
+    let (phones, tones, mut word2ph) = jtalk.g2p(&normalized_text)?;
     let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
-    let tokenizer = jtalk::get_tokenizer()?;
-    println!("{:?}", tokenizer);
+    // add blank
+    let phones = utils::intersperse(&phones, 0);
+    let tones = utils::intersperse(&tones, 0);
+    let lang_ids = utils::intersperse(&lang_ids, 0);
+    for i in 0..word2ph.len() {
+        word2ph[i] = word2ph[i] * 2;
+    }
+    word2ph[0] += 1;
+    let tokenizer = jtalk::get_tokenizer()?;
     let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
     println!("{:?}", token_ids);
     let session = bert::load_model()?;
-    let bert_content = bert::predict(&session, token_ids, attention_masks)?;
+    bert::predict(&session, token_ids, attention_masks)?;
+    println!("{:?}", word2ph);
+    assert!(
+        word2ph.len() == normalized_text.chars().count() + 2,
+        "{} {}",
+        word2ph.len(),
+        normalized_text.chars().count()
+    );
     Ok(())
 }
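
The word2ph bookkeeping above maintains a simple invariant: interspersing a blank (0) turns n phones into 2 * n + 1 entries, and doubling every word2ph count while adding 1 to the first keeps the counts summing to the new phone length. A self-contained sketch with toy values (not the project's real data), using the intersperse layout documented in sbv2_core/src/utils.rs:

// Same layout as sbv2_core::utils::intersperse: items end up on the odd indices.
fn intersperse<T: Clone>(slice: &[T], sep: T) -> Vec<T> {
    let mut result = vec![sep.clone(); slice.len() * 2 + 1];
    result
        .iter_mut()
        .skip(1)
        .step_by(2)
        .zip(slice.iter())
        .for_each(|(r, s)| *r = s.clone());
    result
}

fn main() {
    let phones = vec![10, 20, 30]; // toy phone ids
    let mut word2ph = vec![1, 2]; // toy per-character counts, summing to 3

    let phones = intersperse(&phones, 0); // 3 phones -> 2 * 3 + 1 = 7 entries
    for w in word2ph.iter_mut() {
        *w *= 2; // each original phone slot now covers a phone plus one blank
    }
    word2ph[0] += 1; // the single leftover blank is assigned to the first entry

    assert_eq!(word2ph.iter().sum::<i32>() as usize, phones.len());
}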

sbv2_core/src/utils.rs (new file, 17 lines)

@@ -0,0 +1,17 @@
+pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
+where
+    T: Clone,
+{
+    /*
+    Python reference:
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+    */
+    // Start with separators everywhere, then write the items from `slice`
+    // into the odd indices (1, 3, 5, ...), i.e. result[1::2] = lst.
+    let mut result = vec![sep.clone(); slice.len() * 2 + 1];
+    result
+        .iter_mut()
+        .skip(1)
+        .step_by(2)
+        .zip(slice.iter())
+        .for_each(|(r, s)| *r = s.clone());
+    result
+}
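
A small unit test, not part of this commit, pins down the layout described by the Python reference (result[1::2] = lst): separators at the even indices, items at the odd ones.

#[cfg(test)]
mod tests {
    use super::intersperse;

    #[test]
    fn items_land_on_odd_indices() {
        // [1, 2, 3] interspersed with 0 gives blanks around every item.
        assert_eq!(intersperse(&[1, 2, 3], 0), vec![0, 1, 0, 2, 0, 3, 0]);
    }
}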