mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-01-09 16:02:56 +00:00
fix bug
This commit is contained in:
@@ -10,7 +10,11 @@ pub fn load_model() -> Result<Session> {
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>) -> Result<()> {
|
||||
pub fn predict(
|
||||
session: &Session,
|
||||
token_ids: Vec<i64>,
|
||||
attention_masks: Vec<i64>,
|
||||
) -> Result<Array2<f32>> {
|
||||
let outputs = session.run(
|
||||
ort::inputs! {
|
||||
"input_ids" => Array2::from_shape_vec((1, token_ids.len()), token_ids).unwrap(),
|
||||
@@ -20,7 +24,12 @@ pub fn predict(session: &Session, token_ids: Vec<i64>, attention_masks: Vec<i64>
|
||||
|
||||
let output = outputs.get("output").unwrap();
|
||||
|
||||
println!("{:?}", output);
|
||||
let content = output.try_extract_tensor::<f32>()?.to_owned();
|
||||
println!("{:?}", content);
|
||||
|
||||
Ok(())
|
||||
Ok(Array2::from_shape_vec(
|
||||
(content.shape()[0], content.shape()[1]),
|
||||
content.into_raw_vec(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
@@ -143,21 +143,21 @@ impl JTalkProcess {
|
||||
for (token, phoneme) in sep_tokenized.iter().zip(sep_phonemes.iter()) {
|
||||
let phone_len = phoneme.len() as i32;
|
||||
let word_len = token.len() as i32;
|
||||
word2ph.extend(JTalkProcess::distribute_phone(phone_len, word_len));
|
||||
word2ph.append(&mut JTalkProcess::distribute_phone(phone_len, word_len));
|
||||
}
|
||||
|
||||
let mut new_phone_tone_list = vec![("_".to_string(), 0)];
|
||||
new_phone_tone_list.append(&mut phone_tone_list);
|
||||
new_phone_tone_list.push(("_".to_string(), 0));
|
||||
|
||||
let mut word2ph = vec![1];
|
||||
word2ph.append(&mut word2ph.clone());
|
||||
word2ph.push(1);
|
||||
let mut new_word2ph = vec![1];
|
||||
new_word2ph.extend(word2ph.clone());
|
||||
new_word2ph.push(1);
|
||||
|
||||
let phones: Vec<String> = new_phone_tone_list.iter().map(|(x, _)| x.clone()).collect();
|
||||
let tones: Vec<i32> = new_phone_tone_list.iter().map(|(_, x)| *x).collect();
|
||||
|
||||
Ok((phones, tones, word2ph))
|
||||
Ok((phones, tones, new_word2ph))
|
||||
}
|
||||
|
||||
fn distribute_phone(n_phone: i32, n_word: i32) -> Vec<i32> {
|
||||
|
||||
@@ -4,6 +4,7 @@ pub mod jtalk;
|
||||
pub mod mora;
|
||||
pub mod nlp;
|
||||
pub mod norm;
|
||||
pub mod utils;
|
||||
|
||||
pub fn add(left: usize, right: usize) -> usize {
|
||||
left + right
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use sbv2_core::{bert, error, jtalk, nlp, norm};
|
||||
use sbv2_core::{bert, error, jtalk, nlp, norm, utils};
|
||||
|
||||
fn main() -> error::Result<()> {
|
||||
let text = "こんにちは,世界!";
|
||||
@@ -7,18 +7,32 @@ fn main() -> error::Result<()> {
|
||||
println!("{}", normalized_text);
|
||||
|
||||
let jtalk = jtalk::JTalk::new()?;
|
||||
let (phones, tones, word2ph) = jtalk.g2p(&normalized_text)?;
|
||||
let (phones, tones, mut word2ph) = jtalk.g2p(&normalized_text)?;
|
||||
let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
|
||||
|
||||
let tokenizer = jtalk::get_tokenizer()?;
|
||||
println!("{:?}", tokenizer);
|
||||
// add blank
|
||||
let phones = utils::intersperse(&phones, 0);
|
||||
let tones = utils::intersperse(&tones, 0);
|
||||
let lang_ids = utils::intersperse(&lang_ids, 0);
|
||||
for i in 0..word2ph.len() {
|
||||
word2ph[i] = word2ph[i] * 2;
|
||||
}
|
||||
word2ph[0] += 1;
|
||||
|
||||
let tokenizer = jtalk::get_tokenizer()?;
|
||||
let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
|
||||
println!("{:?}", token_ids);
|
||||
|
||||
let session = bert::load_model()?;
|
||||
let bert_content = bert::predict(&session, token_ids, attention_masks)?;
|
||||
|
||||
bert::predict(&session, token_ids, attention_masks)?;
|
||||
println!("{:?}", word2ph);
|
||||
|
||||
assert!(
|
||||
word2ph.len() == normalized_text.chars().count() + 2,
|
||||
"{} {}",
|
||||
word2ph.len(),
|
||||
normalized_text.chars().count()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
17
sbv2_core/src/utils.rs
Normal file
17
sbv2_core/src/utils.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
/// Surrounds every element of `slice` with `sep`.
///
/// The returned vector has length `slice.len() * 2 + 1`: even indices hold a
/// clone of `sep` and odd indices hold the elements of `slice` in their
/// original order, so the result both starts and ends with `sep`. An empty
/// `slice` yields `vec![sep]`.
///
/// This mirrors the Python reference implementation:
///
/// ```text
/// result = [item] * (len(lst) * 2 + 1)
/// result[1::2] = lst
/// return result
/// ```
pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
where
    T: Clone,
{
    let mut result = Vec::with_capacity(slice.len() * 2 + 1);
    for item in slice {
        // Separator first, so the elements land on the odd indices (`[1::2]`).
        result.push(sep.clone());
        result.push(item.clone());
    }
    // Trailing separator; `sep` is moved here, saving one clone.
    result.push(sep);
    result
}
|
||||
Reference in New Issue
Block a user