完成!

This commit is contained in:
tuna2134
2024-09-10 06:00:13 +00:00
parent 64bd29d93d
commit e55871c936
8 changed files with 138 additions and 7 deletions

3
.gitignore vendored
View File

@@ -1,2 +1,3 @@
target
models/*.onnx
models/*.onnx
models/*.json

View File

@@ -2,11 +2,11 @@ use crate::error::Result;
use ndarray::Array2;
use ort::{GraphOptimizationLevel, Session};
pub fn load_model() -> Result<Session> {
pub fn load_model(model_file: &str) -> Result<Session> {
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(1)?
.commit_from_file("models/debert.onnx")?;
.commit_from_file(model_file)?;
Ok(session)
}

View File

@@ -12,6 +12,10 @@ pub enum Error {
NdArrayError(#[from] ndarray::ShapeError),
#[error("Value error: {0}")]
ValueError(String),
#[error("Serde_json error: {0}")]
SerdeJsonError(#[from] serde_json::Error),
#[error("IO error: {0}")]
IoError(#[from] std::io::Error),
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,9 +1,11 @@
pub mod bert;
pub mod error;
pub mod jtalk;
pub mod model;
pub mod mora;
pub mod nlp;
pub mod norm;
pub mod style;
pub mod utils;
pub fn add(left: usize, right: usize) -> usize {

View File

@@ -1,4 +1,5 @@
use sbv2_core::{bert, error, jtalk, nlp, norm, utils};
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
use sbv2_core::{bert, error, jtalk, model, nlp, norm, style, utils};
fn main() -> error::Result<()> {
let text = "こんにちは,世界!";
@@ -22,11 +23,9 @@ fn main() -> error::Result<()> {
let tokenizer = jtalk::get_tokenizer()?;
let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
let session = bert::load_model()?;
let session = bert::load_model("models/debert.onnx")?;
let bert_content = bert::predict(&session, token_ids, attention_masks)?;
println!("{:?}", word2ph);
assert!(
word2ph.len() == normalized_text.chars().count() + 2,
"{} {}",
@@ -34,5 +33,50 @@ fn main() -> error::Result<()> {
normalized_text.chars().count()
);
let mut phone_level_feature = vec![];
for i in 0..word2ph.len() {
// repeat_feature = np.tile(bert_content[i], (word2ph[i], 1))
let repeat_feature = {
let (reps_rows, reps_cols) = (word2ph[i], 1);
let arr_len = bert_content.slice(s![i, ..]).len();
let mut results: Array2<f32> = Array::zeros((reps_rows as usize, arr_len * reps_cols));
for j in 0..reps_rows {
for k in 0..reps_cols {
let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
view.assign(&bert_content.slice(s![i, ..]));
}
}
results
};
phone_level_feature.push(repeat_feature);
}
// ph = np.concatenate(phone_level_feature, axis=0)
// bert_ori = ph.T
let phone_level_feature = concatenate(
Axis(0),
&phone_level_feature
.iter()
.map(|x| x.view())
.collect::<Vec<_>>(),
)?;
let bert_ori = phone_level_feature.t();
println!("{:?}", bert_ori.shape());
// let data: Array2<f32> = Array2::from_shape_vec((bert_ori.shape()[0], bert_ori.shape()[1]), bert_ori.to_vec()).unwrap();
// data
let session = bert::load_model("models/model_opt.onnx")?;
let style_vectors = style::load_style("models/style_vectors.json")?;
let style_vector = style::get_style_vector(style_vectors, 0, 1.0)?;
model::synthesize(
&session,
bert_ori.to_owned(),
Array1::from_vec(phones.iter().map(|x| *x as i64).collect()),
Array1::from_vec(tones.iter().map(|x| *x as i64).collect()),
Array1::from_vec(lang_ids.iter().map(|x| *x as i64).collect()),
style_vector,
)?;
Ok(())
}

29
sbv2_core/src/model.rs Normal file
View File

@@ -0,0 +1,29 @@
use crate::error::Result;
use ndarray::{array, Array1, Array2, Axis};
use ort::Session;
/// Runs the synthesis session once for a single utterance.
///
/// Every tensor argument gets a leading axis of length one (via
/// `insert_axis(Axis(0))`) before it is passed to the session. Two extra
/// inputs are built here: `x_tst_lengths`, a one-element array holding the
/// length of `x_tst`, and `sid`, a one-element array holding `0`.
///
/// # Errors
///
/// Propagates any error raised while building the input set or running the
/// session.
///
/// The session outputs are currently discarded; on success the function
/// returns `Ok(())`.
pub fn synthesize(
    session: &Session,
    bert_ori: Array2<f32>,
    x_tst: Array1<i64>,
    tones: Array1<i64>,
    lang_ids: Array1<i64>,
    style_vector: Array1<f32>,
) -> Result<()> {
    // The length must be read before `x_tst` is consumed by `insert_axis`.
    let sequence_len = x_tst.len() as i64;
    let lengths_input: Array1<i64> = array![sequence_len];
    let speaker_input = array![0_i64];

    let bert_input = bert_ori.insert_axis(Axis(0));
    let phone_input = x_tst.insert_axis(Axis(0));
    let language_input = lang_ids.insert_axis(Axis(0));
    let tone_input = tones.insert_axis(Axis(0));
    let style_input = style_vector.insert_axis(Axis(0));

    // NOTE(review): the style vector is bound to the input named "ja_bert";
    // confirm this matches the input names of the exported model.
    let _outputs = session.run(ort::inputs! {
        "x_tst" => phone_input,
        "x_tst_lengths" => lengths_input,
        "sid" => speaker_input,
        "tones" => tone_input,
        "language" => language_input,
        "bert" => bert_input,
        "ja_bert" => style_input,
    }?)?;
    Ok(())
}

28
sbv2_core/src/style.rs Normal file
View File

@@ -0,0 +1,28 @@
use crate::error::Result;
use ndarray::{s, Array1, Array2};
use serde::Deserialize;
/// Deserialized form of a style-vector JSON file, as read by `load_style`.
#[derive(Debug, Clone, Deserialize)]
pub struct Data {
    /// Two-dimensional shape used when the rows are packed into an `Array2`.
    pub shape: [usize; 2],
    /// Nested rows of values; `load_style` flattens them in order.
    pub data: Vec<Vec<f32>>,
}
pub fn load_style(path: &str) -> Result<Array2<f32>> {
let data: Data = serde_json::from_str(&std::fs::read_to_string(path)?)?;
Ok(Array2::from_shape_vec(
data.shape,
data.data.iter().flatten().copied().collect(),
)?)
}
pub fn get_style_vector(
style_vectors: Array2<f32>,
style_id: i32,
weight: f32,
) -> Result<Array1<f32>> {
let mean = style_vectors.slice(s![0, ..]).to_owned();
let style_vector = style_vectors.slice(s![style_id as usize, ..]).to_owned();
let diff = (style_vector - &mean) * weight;
Ok(mean + &diff)
}

View File

@@ -1,3 +1,5 @@
use ndarray::{s, Array, Array2};
pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
where
T: Clone,
@@ -15,3 +17,24 @@ where
.for_each(|(r, s)| *r = s.clone());
result
}
/*
fn tile<T: Clone>(arr: &Array2<T>, reps: (usize, usize)) -> Array2<T> {
let (rows, cols) = arr.dim();
let (rep_rows, rep_cols) = reps;
let mut result = Array::zeros((rows * rep_rows, cols * rep_cols));
for i in 0..rep_rows {
for j in 0..rep_cols {
let view = result.slice_mut(s![
i * rows..(i + 1) * rows,
j * cols..(j + 1) * cols
]);
view.assign(arr);
}
}
result
}
*/