From e55871c936dc22a6cea7537214f3898a1f203ec4 Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Tue, 10 Sep 2024 06:00:13 +0000 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=EF=BC=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 ++- sbv2_core/src/bert.rs | 4 ++-- sbv2_core/src/error.rs | 4 ++++ sbv2_core/src/lib.rs | 2 ++ sbv2_core/src/main.rs | 52 ++++++++++++++++++++++++++++++++++++++---- sbv2_core/src/model.rs | 29 +++++++++++++++++++++++ sbv2_core/src/style.rs | 28 +++++++++++++++++++++++ sbv2_core/src/utils.rs | 23 +++++++++++++++++++ 8 files changed, 138 insertions(+), 7 deletions(-) create mode 100644 sbv2_core/src/model.rs create mode 100644 sbv2_core/src/style.rs diff --git a/.gitignore b/.gitignore index c4ceed7..a601f1b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target -models/*.onnx \ No newline at end of file +models/*.onnx +models/*.json \ No newline at end of file diff --git a/sbv2_core/src/bert.rs b/sbv2_core/src/bert.rs index 688723a..021d2bb 100644 --- a/sbv2_core/src/bert.rs +++ b/sbv2_core/src/bert.rs @@ -2,11 +2,11 @@ use crate::error::Result; use ndarray::Array2; use ort::{GraphOptimizationLevel, Session}; -pub fn load_model() -> Result { +pub fn load_model(model_file: &str) -> Result { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(1)? - .commit_from_file("models/debert.onnx")?; + .commit_from_file(model_file)?; Ok(session) } diff --git a/sbv2_core/src/error.rs b/sbv2_core/src/error.rs index 53281cc..0443941 100644 --- a/sbv2_core/src/error.rs +++ b/sbv2_core/src/error.rs @@ -12,6 +12,10 @@ pub enum Error { NdArrayError(#[from] ndarray::ShapeError), #[error("Value error: {0}")] ValueError(String), + #[error("Serde_json error: {0}")] + SerdeJsonError(#[from] serde_json::Error), + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), } pub type Result = std::result::Result; diff --git a/sbv2_core/src/lib.rs b/sbv2_core/src/lib.rs index dae37ec..b058762 100644 --- a/sbv2_core/src/lib.rs +++ b/sbv2_core/src/lib.rs @@ -1,9 +1,11 @@ pub mod bert; pub mod error; pub mod jtalk; +pub mod model; pub mod mora; pub mod nlp; pub mod norm; +pub mod style; pub mod utils; pub fn add(left: usize, right: usize) -> usize { diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index 74d3543..ada215b 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -1,4 +1,5 @@ -use sbv2_core::{bert, error, jtalk, nlp, norm, utils}; +use ndarray::{concatenate, s, Array, Array1, Array2, Axis}; +use sbv2_core::{bert, error, jtalk, model, nlp, norm, style, utils}; fn main() -> error::Result<()> { let text = "こんにちは,世界!"; @@ -22,11 +23,9 @@ fn main() -> error::Result<()> { let tokenizer = jtalk::get_tokenizer()?; let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?; - let session = bert::load_model()?; + let session = bert::load_model("models/debert.onnx")?; let bert_content = bert::predict(&session, token_ids, attention_masks)?; - println!("{:?}", word2ph); - assert!( word2ph.len() == normalized_text.chars().count() + 2, "{} {}", @@ -34,5 +33,50 @@ fn main() -> error::Result<()> { normalized_text.chars().count() ); + let mut phone_level_feature = vec![]; + for i in 0..word2ph.len() { + // repeat_feature = np.tile(bert_content[i], (word2ph[i], 1)) + let repeat_feature = { + let (reps_rows, reps_cols) = (word2ph[i], 1); + let arr_len = bert_content.slice(s![i, ..]).len(); + + let mut results: Array2 = Array::zeros((reps_rows as usize, arr_len * reps_cols)); + + for j in 0..reps_rows { + for k in 0..reps_cols { + let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]); + view.assign(&bert_content.slice(s![i, ..])); + } + } + results + }; + phone_level_feature.push(repeat_feature); + } + // ph = np.concatenate(phone_level_feature, axis=0) + // bert_ori = ph.T + let phone_level_feature = concatenate( + Axis(0), + &phone_level_feature + .iter() + .map(|x| x.view()) + .collect::>(), + )?; + let bert_ori = phone_level_feature.t(); + println!("{:?}", bert_ori.shape()); + // let data: Array2 = Array2::from_shape_vec((bert_ori.shape()[0], bert_ori.shape()[1]), bert_ori.to_vec()).unwrap(); + // data + + let session = bert::load_model("models/model_opt.onnx")?; + let style_vectors = style::load_style("models/style_vectors.json")?; + let style_vector = style::get_style_vector(style_vectors, 0, 1.0)?; + model::synthesize( + &session, + bert_ori.to_owned(), + Array1::from_vec(phones.iter().map(|x| *x as i64).collect()), + Array1::from_vec(tones.iter().map(|x| *x as i64).collect()), + Array1::from_vec(lang_ids.iter().map(|x| *x as i64).collect()), + style_vector, + )?; + Ok(()) } diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs new file mode 100644 index 0000000..bb04127 --- /dev/null +++ b/sbv2_core/src/model.rs @@ -0,0 +1,29 @@ +use crate::error::Result; +use ndarray::{array, Array1, Array2, Axis}; +use ort::Session; + +pub fn synthesize( + session: &Session, + bert_ori: Array2, + x_tst: Array1, + tones: Array1, + lang_ids: Array1, + style_vector: Array1, +) -> Result<()> { + let bert = bert_ori.insert_axis(Axis(0)); + let x_tst_lengths: Array1 = array![x_tst.shape()[0] as i64]; + let x_tst = x_tst.insert_axis(Axis(0)); + let lang_ids = lang_ids.insert_axis(Axis(0)); + let tones = tones.insert_axis(Axis(0)); + let style_vector = style_vector.insert_axis(Axis(0)); + let outputs = session.run(ort::inputs! { + "x_tst" => x_tst, + "x_tst_lengths" => x_tst_lengths, + "sid" => array![0 as i64], + "tones" => tones, + "language" => lang_ids, + "bert" => bert, + "ja_bert" => style_vector, + }?)?; + Ok(()) +} diff --git a/sbv2_core/src/style.rs b/sbv2_core/src/style.rs new file mode 100644 index 0000000..07b87dc --- /dev/null +++ b/sbv2_core/src/style.rs @@ -0,0 +1,28 @@ +use crate::error::Result; +use ndarray::{s, Array1, Array2}; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct Data { + pub shape: [usize; 2], + pub data: Vec>, +} + +pub fn load_style(path: &str) -> Result> { + let data: Data = serde_json::from_str(&std::fs::read_to_string(path)?)?; + Ok(Array2::from_shape_vec( + data.shape, + data.data.iter().flatten().copied().collect(), + )?) +} + +pub fn get_style_vector( + style_vectors: Array2, + style_id: i32, + weight: f32, +) -> Result> { + let mean = style_vectors.slice(s![0, ..]).to_owned(); + let style_vector = style_vectors.slice(s![style_id as usize, ..]).to_owned(); + let diff = (style_vector - &mean) * weight; + Ok(mean + &diff) +} diff --git a/sbv2_core/src/utils.rs b/sbv2_core/src/utils.rs index d36a04c..05f6c4c 100644 --- a/sbv2_core/src/utils.rs +++ b/sbv2_core/src/utils.rs @@ -1,3 +1,5 @@ +use ndarray::{s, Array, Array2}; + pub fn intersperse(slice: &[T], sep: T) -> Vec where T: Clone, @@ -15,3 +17,24 @@ where .for_each(|(r, s)| *r = s.clone()); result } + +/* +fn tile(arr: &Array2, reps: (usize, usize)) -> Array2 { + let (rows, cols) = arr.dim(); + let (rep_rows, rep_cols) = reps; + + let mut result = Array::zeros((rows * rep_rows, cols * rep_cols)); + + for i in 0..rep_rows { + for j in 0..rep_cols { + let view = result.slice_mut(s![ + i * rows..(i + 1) * rows, + j * cols..(j + 1) * cols + ]); + view.assign(arr); + } + } + + result +} +*/