libraryか (roughly: "make it a library?")

Author: tuna2134
Date: 2024-09-10 08:34:59 +00:00
parent c5d24203fb
commit 72eb1f2aa8
7 changed files with 147 additions and 85 deletions

Binary file not shown.


@@ -1,14 +1,6 @@
 use crate::error::Result;
 use ndarray::Array2;
-use ort::{GraphOptimizationLevel, Session};
-pub fn load_model(model_file: &str) -> Result<Session> {
-    let session = Session::builder()?
-        .with_optimization_level(GraphOptimizationLevel::Level3)?
-        .with_intra_threads(1)?
-        .commit_from_file(model_file)?;
-    Ok(session)
-}
+use ort::Session;
 pub fn predict(
     session: &Session,
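The ONNX session loader removed here is re-added to model.rs later in this commit, so BERT inference now goes through the shared loader. A minimal sketch of a caller after this change, using only functions that appear in this diff and assuming the bert, jtalk, model, and norm modules stay publicly exported as the previous main.rs suggests (the model path and wrapper function are illustrative):

use sbv2_core::{bert, error, jtalk, model, norm};

fn main() -> error::Result<()> {
    // Normalize and tokenize the input, then run the DeBERTa ONNX model
    // through the loader that this commit moves into model.rs.
    let normalized_text = norm::normalize_text("おはようございます。");
    let tokenizer = jtalk::get_tokenizer()?;
    let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
    let session = model::load_model("models/debert.onnx")?;
    let _bert_content = bert::predict(&session, token_ids, attention_masks)?;
    Ok(())
}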


@@ -6,6 +6,7 @@ pub mod mora;
 pub mod nlp;
 pub mod norm;
 pub mod style;
+pub mod tts;
 pub mod utils;
 pub fn add(left: usize, right: usize) -> usize {


@@ -1,78 +1,18 @@
-use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
-use sbv2_core::{bert, error, jtalk, model, nlp, norm, style, utils};
+use sbv2_core::{error, tts};
 fn main() -> error::Result<()> {
-    let text = "隣の客はよくかき食う客だ";
+    let text = "おはようございます。";
-    let normalized_text = norm::normalize_text(text);
-    let jtalk = jtalk::JTalk::new()?;
-    let (phones, tones, mut word2ph) = jtalk.g2p(&normalized_text)?;
-    let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
-    // add black
-    let phones = utils::intersperse(&phones, 0);
-    let tones = utils::intersperse(&tones, 0);
-    let lang_ids = utils::intersperse(&lang_ids, 0);
-    for i in 0..word2ph.len() {
-        word2ph[i] *= 2;
-    }
-    word2ph[0] += 1;
-    let tokenizer = jtalk::get_tokenizer()?;
-    let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
-    let session = bert::load_model("models/debert.onnx")?;
-    let bert_content = bert::predict(&session, token_ids, attention_masks)?;
-    assert!(
-        word2ph.len() == normalized_text.chars().count() + 2,
-        "{} {}",
-        word2ph.len(),
-        normalized_text.chars().count()
-    );
-    let mut phone_level_feature = vec![];
-    for i in 0..word2ph.len() {
-        // repeat_feature = np.tile(bert_content[i], (word2ph[i], 1))
-        let repeat_feature = {
-            let (reps_rows, reps_cols) = (word2ph[i], 1);
-            let arr_len = bert_content.slice(s![i, ..]).len();
-            let mut results: Array2<f32> = Array::zeros((reps_rows as usize, arr_len * reps_cols));
-            for j in 0..reps_rows {
-                for k in 0..reps_cols {
-                    let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
-                    view.assign(&bert_content.slice(s![i, ..]));
-                }
-            }
-            results
-        };
-        phone_level_feature.push(repeat_feature);
-    }
-    // ph = np.concatenate(phone_level_feature, axis=0)
-    // bert_ori = ph.T
-    let phone_level_feature = concatenate(
-        Axis(0),
-        &phone_level_feature
-            .iter()
-            .map(|x| x.view())
-            .collect::<Vec<_>>(),
+    let tts_model = tts::TTSModel::new(
+        "models/debert.onnx",
+        "models/model_opt.onnx",
+        "models/style_vectors.json",
     )?;
-    let bert_ori = phone_level_feature.t();
-    let session = bert::load_model("models/model_opt.onnx")?;
-    let style_vectors = style::load_style("models/style_vectors.json")?;
-    let style_vector = style::get_style_vector(style_vectors, 0, 1.0)?;
-    model::synthesize(
-        &session,
-        bert_ori.to_owned(),
-        Array1::from_vec(phones.iter().map(|x| *x as i64).collect()),
-        Array1::from_vec(tones.iter().map(|x| *x as i64).collect()),
-        Array1::from_vec(lang_ids.iter().map(|x| *x as i64).collect()),
-        style_vector,
-    )?;
+    let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
+    let style_vector = tts_model.get_style_vector(0, 1.0)?;
+    tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;
     Ok(())
 }
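Both the removed pipeline above and parse_text in the new tts.rs below call utils::intersperse(&phones, 0) to insert the blank id 0 into the sequences. Its implementation is not part of this diff; a hypothetical sketch of what it is assumed to do, mirroring the upstream Style-Bert-VITS2 preprocessing:

// Hypothetical stand-in for utils::intersperse, for illustration only:
// put `item` before, between, and after every element, so n ids become 2n + 1.
fn intersperse<T: Copy>(slice: &[T], item: T) -> Vec<T> {
    let mut out = vec![item; slice.len() * 2 + 1];
    for (i, &x) in slice.iter().enumerate() {
        out[2 * i + 1] = x;
    }
    out
}

Under that assumption, doubling word2ph and adding 1 to its first entry accounts exactly for the extra blanks: the doubled counts sum to 2n, plus one for the leading blank.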


@@ -1,7 +1,15 @@
 use crate::error::Result;
 use hound::{SampleFormat, WavSpec, WavWriter};
 use ndarray::{array, Array1, Array2, Axis};
-use ort::Session;
+use ort::{GraphOptimizationLevel, Session};
+pub fn load_model(model_file: &str) -> Result<Session> {
+    let session = Session::builder()?
+        .with_optimization_level(GraphOptimizationLevel::Level3)?
+        .with_intra_threads(1)?
+        .commit_from_file(model_file)?;
+    Ok(session)
+}
 fn write_wav(file_path: &str, audio: &[f32], sample_rate: u32) -> Result<()> {
     let spec = WavSpec {
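The hunk context stops at the WavSpec literal. For orientation, a hedged sketch of what a hound-based writer for f32 samples typically looks like; the channel count, bit depth, and error type below are assumptions, not taken from model.rs:

use hound::{SampleFormat, WavSpec, WavWriter};

// Illustrative only: write f32 samples as a 32-bit float WAV file.
fn write_wav_sketch(file_path: &str, audio: &[f32], sample_rate: u32) -> hound::Result<()> {
    let spec = WavSpec {
        channels: 1, // assumption: mono output
        sample_rate,
        bits_per_sample: 32,
        sample_format: SampleFormat::Float,
    };
    let mut writer = WavWriter::create(file_path, spec)?;
    for &sample in audio {
        writer.write_sample(sample)?;
    }
    writer.finalize()?;
    Ok(())
}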


@@ -13,12 +13,12 @@ static SYMBOL_TO_ID: Lazy<HashMap<String, i32>> = Lazy::new(|| {
 pub fn cleaned_text_to_sequence(
     cleaned_phones: Vec<String>,
     tones: Vec<i32>,
-) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
-    let phones: Vec<i32> = cleaned_phones
+) -> (Vec<i64>, Vec<i64>, Vec<i64>) {
+    let phones: Vec<i64> = cleaned_phones
         .iter()
-        .map(|phone| *SYMBOL_TO_ID.get(phone).unwrap())
+        .map(|phone| *SYMBOL_TO_ID.get(phone).unwrap() as i64)
         .collect();
-    let tones: Vec<i32> = tones.iter().map(|tone| *tone + 6).collect();
-    let lang_ids: Vec<i32> = vec![1; phones.len()];
+    let tones: Vec<i64> = tones.iter().map(|tone| (*tone + 6) as i64).collect();
+    let lang_ids: Vec<i64> = vec![1; phones.len()];
     (phones, tones, lang_ids)
 }
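The switch from i32 to i64 matches the tensors built downstream: the removed main.rs cast every id while calling Array1::from_vec, whereas tts.rs below converts the vectors directly. A small before/after sketch (illustrative, not part of the commit):

use ndarray::Array1;

// Before this change: cast each i32 id while building the tensor.
fn ids_to_tensor_old(phones: Vec<i32>) -> Array1<i64> {
    Array1::from_vec(phones.iter().map(|x| *x as i64).collect())
}

// After this change: the ids are already i64, so a plain conversion is enough.
fn ids_to_tensor_new(phones: Vec<i64>) -> Array1<i64> {
    phones.into()
}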

sbv2_core/src/tts.rs (new file, 121 additions)

@@ -0,0 +1,121 @@
use crate::error::Result;
use crate::{bert, jtalk, model, nlp, norm, style, utils};
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
use ort::Session;

pub struct TTSModel {
    bert: Session,
    vits2: Session,
    style_vectors: Array2<f32>,
    jtalk: jtalk::JTalk,
}

impl TTSModel {
    pub fn new(
        bert_model_path: &str,
        main_model_path: &str,
        style_vector_path: &str,
    ) -> Result<Self> {
        let bert = model::load_model(bert_model_path)?;
        let vits2 = model::load_model(main_model_path)?;
        let style_vectors = style::load_style(style_vector_path)?;
        let jtalk = jtalk::JTalk::new()?;
        Ok(TTSModel {
            bert,
            vits2,
            style_vectors,
            jtalk,
        })
    }

    pub fn parse_text(
        &self,
        text: &str,
    ) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
        let normalized_text = norm::normalize_text(text);
        let (phones, tones, mut word2ph) = self.jtalk.g2p(&normalized_text)?;
        let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
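        // Insert the blank id (0) into the phoneme, tone, and language-id sequences.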
        let phones = utils::intersperse(&phones, 0);
        let tones = utils::intersperse(&tones, 0);
        let lang_ids = utils::intersperse(&lang_ids, 0);
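        // word2ph holds phonemes per input character; double it for the
        // interspersed blanks and give the extra leading blank to the first entry.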
        for i in 0..word2ph.len() {
            word2ph[i] *= 2;
        }
        word2ph[0] += 1;
        let tokenizer = jtalk::get_tokenizer()?;
        let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
        let bert_content = bert::predict(&self.bert, token_ids, attention_masks)?;
        assert!(
            word2ph.len() == normalized_text.chars().count() + 2,
            "{} {}",
            word2ph.len(),
            normalized_text.chars().count()
        );
        let mut phone_level_feature = vec![];
        for i in 0..word2ph.len() {
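            // Equivalent of np.tile(bert_content[i], (word2ph[i], 1)) in the
            // Python original: repeat the i-th BERT row word2ph[i] times.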
            let repeat_feature = {
                let (reps_rows, reps_cols) = (word2ph[i], 1);
                let arr_len = bert_content.slice(s![i, ..]).len();
                let mut results: Array2<f32> =
                    Array::zeros((reps_rows as usize, arr_len * reps_cols));
                for j in 0..reps_rows {
                    for k in 0..reps_cols {
                        let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
                        view.assign(&bert_content.slice(s![i, ..]));
                    }
                }
                results
            };
            phone_level_feature.push(repeat_feature);
        }
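        // ph = np.concatenate(phone_level_feature, axis=0); bert_ori = ph.T
        // in the Python original.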
        let phone_level_feature = concatenate(
            Axis(0),
            &phone_level_feature
                .iter()
                .map(|x| x.view())
                .collect::<Vec<_>>(),
        )?;
        let bert_ori = phone_level_feature.t();
        Ok((
            bert_ori.to_owned(),
            phones.into(),
            tones.into(),
            lang_ids.into(),
        ))
    }
    pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result<Array1<f32>> {
        Ok(style::get_style_vector(
            self.style_vectors.clone(),
            style_id,
            weight,
        )?)
    }

    pub fn synthesize(
        &self,
        bert_ori: Array2<f32>,
        phones: Array1<i64>,
        tones: Array1<i64>,
        lang_ids: Array1<i64>,
        style_vector: Array1<f32>,
    ) -> Result<()> {
        model::synthesize(
            &self.vits2,
            bert_ori.to_owned(),
            phones,
            tones,
            lang_ids,
            style_vector,
        )?;
        Ok(())
    }
}
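For reference, the end-to-end call sequence this struct enables, as exercised by the updated main.rs above; the comments are explanatory and not part of the commit:

use sbv2_core::{error, tts};

fn main() -> error::Result<()> {
    // Load the BERT model, the VITS2 model, and the style vectors once.
    let tts_model = tts::TTSModel::new(
        "models/debert.onnx",
        "models/model_opt.onnx",
        "models/style_vectors.json",
    )?;
    // Text to BERT features plus phoneme / tone / language-id tensors.
    let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text("おはようございます。")?;
    // Style id 0 at weight 1.0, matching main.rs.
    let style_vector = tts_model.get_style_vector(0, 1.0)?;
    tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;
    Ok(())
}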