mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
完成!
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,2 +1,3 @@
|
||||
target
|
||||
models/*.onnx
|
||||
models/*.json
|
||||
@@ -2,11 +2,11 @@ use crate::error::Result;
|
||||
use ndarray::Array2;
|
||||
use ort::{GraphOptimizationLevel, Session};
|
||||
|
||||
pub fn load_model() -> Result<Session> {
|
||||
pub fn load_model(model_file: &str) -> Result<Session> {
|
||||
let session = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(1)?
|
||||
.commit_from_file("models/debert.onnx")?;
|
||||
.commit_from_file(model_file)?;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ pub enum Error {
|
||||
NdArrayError(#[from] ndarray::ShapeError),
|
||||
#[error("Value error: {0}")]
|
||||
ValueError(String),
|
||||
#[error("Serde_json error: {0}")]
|
||||
SerdeJsonError(#[from] serde_json::Error),
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
pub mod bert;
|
||||
pub mod error;
|
||||
pub mod jtalk;
|
||||
pub mod model;
|
||||
pub mod mora;
|
||||
pub mod nlp;
|
||||
pub mod norm;
|
||||
pub mod style;
|
||||
pub mod utils;
|
||||
|
||||
pub fn add(left: usize, right: usize) -> usize {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use sbv2_core::{bert, error, jtalk, nlp, norm, utils};
|
||||
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
|
||||
use sbv2_core::{bert, error, jtalk, model, nlp, norm, style, utils};
|
||||
|
||||
fn main() -> error::Result<()> {
|
||||
let text = "こんにちは,世界!";
|
||||
@@ -22,11 +23,9 @@ fn main() -> error::Result<()> {
|
||||
let tokenizer = jtalk::get_tokenizer()?;
|
||||
let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
|
||||
|
||||
let session = bert::load_model()?;
|
||||
let session = bert::load_model("models/debert.onnx")?;
|
||||
let bert_content = bert::predict(&session, token_ids, attention_masks)?;
|
||||
|
||||
println!("{:?}", word2ph);
|
||||
|
||||
assert!(
|
||||
word2ph.len() == normalized_text.chars().count() + 2,
|
||||
"{} {}",
|
||||
@@ -34,5 +33,50 @@ fn main() -> error::Result<()> {
|
||||
normalized_text.chars().count()
|
||||
);
|
||||
|
||||
let mut phone_level_feature = vec![];
|
||||
for i in 0..word2ph.len() {
|
||||
// repeat_feature = np.tile(bert_content[i], (word2ph[i], 1))
|
||||
let repeat_feature = {
|
||||
let (reps_rows, reps_cols) = (word2ph[i], 1);
|
||||
let arr_len = bert_content.slice(s![i, ..]).len();
|
||||
|
||||
let mut results: Array2<f32> = Array::zeros((reps_rows as usize, arr_len * reps_cols));
|
||||
|
||||
for j in 0..reps_rows {
|
||||
for k in 0..reps_cols {
|
||||
let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
|
||||
view.assign(&bert_content.slice(s![i, ..]));
|
||||
}
|
||||
}
|
||||
results
|
||||
};
|
||||
phone_level_feature.push(repeat_feature);
|
||||
}
|
||||
// ph = np.concatenate(phone_level_feature, axis=0)
|
||||
// bert_ori = ph.T
|
||||
let phone_level_feature = concatenate(
|
||||
Axis(0),
|
||||
&phone_level_feature
|
||||
.iter()
|
||||
.map(|x| x.view())
|
||||
.collect::<Vec<_>>(),
|
||||
)?;
|
||||
let bert_ori = phone_level_feature.t();
|
||||
println!("{:?}", bert_ori.shape());
|
||||
// let data: Array2<f32> = Array2::from_shape_vec((bert_ori.shape()[0], bert_ori.shape()[1]), bert_ori.to_vec()).unwrap();
|
||||
// data
|
||||
|
||||
let session = bert::load_model("models/model_opt.onnx")?;
|
||||
let style_vectors = style::load_style("models/style_vectors.json")?;
|
||||
let style_vector = style::get_style_vector(style_vectors, 0, 1.0)?;
|
||||
model::synthesize(
|
||||
&session,
|
||||
bert_ori.to_owned(),
|
||||
Array1::from_vec(phones.iter().map(|x| *x as i64).collect()),
|
||||
Array1::from_vec(tones.iter().map(|x| *x as i64).collect()),
|
||||
Array1::from_vec(lang_ids.iter().map(|x| *x as i64).collect()),
|
||||
style_vector,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
29
sbv2_core/src/model.rs
Normal file
29
sbv2_core/src/model.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use crate::error::Result;
|
||||
use ndarray::{array, Array1, Array2, Axis};
|
||||
use ort::Session;
|
||||
|
||||
pub fn synthesize(
|
||||
session: &Session,
|
||||
bert_ori: Array2<f32>,
|
||||
x_tst: Array1<i64>,
|
||||
tones: Array1<i64>,
|
||||
lang_ids: Array1<i64>,
|
||||
style_vector: Array1<f32>,
|
||||
) -> Result<()> {
|
||||
let bert = bert_ori.insert_axis(Axis(0));
|
||||
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
|
||||
let x_tst = x_tst.insert_axis(Axis(0));
|
||||
let lang_ids = lang_ids.insert_axis(Axis(0));
|
||||
let tones = tones.insert_axis(Axis(0));
|
||||
let style_vector = style_vector.insert_axis(Axis(0));
|
||||
let outputs = session.run(ort::inputs! {
|
||||
"x_tst" => x_tst,
|
||||
"x_tst_lengths" => x_tst_lengths,
|
||||
"sid" => array![0 as i64],
|
||||
"tones" => tones,
|
||||
"language" => lang_ids,
|
||||
"bert" => bert,
|
||||
"ja_bert" => style_vector,
|
||||
}?)?;
|
||||
Ok(())
|
||||
}
|
||||
28
sbv2_core/src/style.rs
Normal file
28
sbv2_core/src/style.rs
Normal file
@@ -0,0 +1,28 @@
|
||||
use crate::error::Result;
|
||||
use ndarray::{s, Array1, Array2};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Data {
|
||||
pub shape: [usize; 2],
|
||||
pub data: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
pub fn load_style(path: &str) -> Result<Array2<f32>> {
|
||||
let data: Data = serde_json::from_str(&std::fs::read_to_string(path)?)?;
|
||||
Ok(Array2::from_shape_vec(
|
||||
data.shape,
|
||||
data.data.iter().flatten().copied().collect(),
|
||||
)?)
|
||||
}
|
||||
|
||||
pub fn get_style_vector(
|
||||
style_vectors: Array2<f32>,
|
||||
style_id: i32,
|
||||
weight: f32,
|
||||
) -> Result<Array1<f32>> {
|
||||
let mean = style_vectors.slice(s![0, ..]).to_owned();
|
||||
let style_vector = style_vectors.slice(s![style_id as usize, ..]).to_owned();
|
||||
let diff = (style_vector - &mean) * weight;
|
||||
Ok(mean + &diff)
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
use ndarray::{s, Array, Array2};
|
||||
|
||||
pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
|
||||
where
|
||||
T: Clone,
|
||||
@@ -15,3 +17,24 @@ where
|
||||
.for_each(|(r, s)| *r = s.clone());
|
||||
result
|
||||
}
|
||||
|
||||
/*
|
||||
fn tile<T: Clone>(arr: &Array2<T>, reps: (usize, usize)) -> Array2<T> {
|
||||
let (rows, cols) = arr.dim();
|
||||
let (rep_rows, rep_cols) = reps;
|
||||
|
||||
let mut result = Array::zeros((rows * rep_rows, cols * rep_cols));
|
||||
|
||||
for i in 0..rep_rows {
|
||||
for j in 0..rep_cols {
|
||||
let view = result.slice_mut(s![
|
||||
i * rows..(i + 1) * rows,
|
||||
j * cols..(j + 1) * cols
|
||||
]);
|
||||
view.assign(arr);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user