mirror of https://github.com/neodyland/sbv2-api.git, synced 2026-01-05 22:12:57 +00:00
Make it a library
BIN output.wav
Binary file not shown.
@@ -1,14 +1,6 @@
 use crate::error::Result;
 use ndarray::Array2;
-use ort::{GraphOptimizationLevel, Session};
-
-pub fn load_model(model_file: &str) -> Result<Session> {
-    let session = Session::builder()?
-        .with_optimization_level(GraphOptimizationLevel::Level3)?
-        .with_intra_threads(1)?
-        .commit_from_file(model_file)?;
-    Ok(session)
-}
+use ort::Session;

 pub fn predict(
     session: &Session,
@@ -6,6 +6,7 @@ pub mod mora;
 pub mod nlp;
 pub mod norm;
 pub mod style;
+pub mod tts;
 pub mod utils;

 pub fn add(left: usize, right: usize) -> usize {
@@ -1,78 +1,18 @@
-use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
-use sbv2_core::{bert, error, jtalk, model, nlp, norm, style, utils};
+use sbv2_core::{error, tts};

 fn main() -> error::Result<()> {
-    let text = "隣の客はよくかき食う客だ";
+    let text = "おはようございます。";

-    let normalized_text = norm::normalize_text(text);
-
-    let jtalk = jtalk::JTalk::new()?;
-    let (phones, tones, mut word2ph) = jtalk.g2p(&normalized_text)?;
-    let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
-
-    // add black
-    let phones = utils::intersperse(&phones, 0);
-    let tones = utils::intersperse(&tones, 0);
-    let lang_ids = utils::intersperse(&lang_ids, 0);
-    for i in 0..word2ph.len() {
-        word2ph[i] *= 2;
-    }
-    word2ph[0] += 1;
-
-    let tokenizer = jtalk::get_tokenizer()?;
-    let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
-
-    let session = bert::load_model("models/debert.onnx")?;
-    let bert_content = bert::predict(&session, token_ids, attention_masks)?;
-
-    assert!(
-        word2ph.len() == normalized_text.chars().count() + 2,
-        "{} {}",
-        word2ph.len(),
-        normalized_text.chars().count()
-    );
-
-    let mut phone_level_feature = vec![];
-    for i in 0..word2ph.len() {
-        // repeat_feature = np.tile(bert_content[i], (word2ph[i], 1))
-        let repeat_feature = {
-            let (reps_rows, reps_cols) = (word2ph[i], 1);
-            let arr_len = bert_content.slice(s![i, ..]).len();
-
-            let mut results: Array2<f32> = Array::zeros((reps_rows as usize, arr_len * reps_cols));
-
-            for j in 0..reps_rows {
-                for k in 0..reps_cols {
-                    let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
-                    view.assign(&bert_content.slice(s![i, ..]));
-                }
-            }
-            results
-        };
-        phone_level_feature.push(repeat_feature);
-    }
-    // ph = np.concatenate(phone_level_feature, axis=0)
-    // bert_ori = ph.T
-    let phone_level_feature = concatenate(
-        Axis(0),
-        &phone_level_feature
-            .iter()
-            .map(|x| x.view())
-            .collect::<Vec<_>>(),
-    )?;
-    let bert_ori = phone_level_feature.t();
-
-    let session = bert::load_model("models/model_opt.onnx")?;
-    let style_vectors = style::load_style("models/style_vectors.json")?;
-    let style_vector = style::get_style_vector(style_vectors, 0, 1.0)?;
-    model::synthesize(
-        &session,
-        bert_ori.to_owned(),
-        Array1::from_vec(phones.iter().map(|x| *x as i64).collect()),
-        Array1::from_vec(tones.iter().map(|x| *x as i64).collect()),
-        Array1::from_vec(lang_ids.iter().map(|x| *x as i64).collect()),
-        style_vector,
-    )?;
+    let tts_model = tts::TTSModel::new(
+        "models/debert.onnx",
+        "models/model_opt.onnx",
+        "models/style_vectors.json",
+    )?;
+
+    let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
+
+    let style_vector = tts_model.get_style_vector(0, 1.0)?;
+    tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;

     Ok(())
 }
@@ -1,7 +1,15 @@
 use crate::error::Result;
 use hound::{SampleFormat, WavSpec, WavWriter};
 use ndarray::{array, Array1, Array2, Axis};
-use ort::Session;
+use ort::{GraphOptimizationLevel, Session};
+
+pub fn load_model(model_file: &str) -> Result<Session> {
+    let session = Session::builder()?
+        .with_optimization_level(GraphOptimizationLevel::Level3)?
+        .with_intra_threads(1)?
+        .commit_from_file(model_file)?;
+    Ok(session)
+}

 fn write_wav(file_path: &str, audio: &[f32], sample_rate: u32) -> Result<()> {
     let spec = WavSpec {
@@ -13,12 +13,12 @@ static SYMBOL_TO_ID: Lazy<HashMap<String, i32>> = Lazy::new(|| {
 pub fn cleaned_text_to_sequence(
     cleaned_phones: Vec<String>,
     tones: Vec<i32>,
-) -> (Vec<i32>, Vec<i32>, Vec<i32>) {
-    let phones: Vec<i32> = cleaned_phones
+) -> (Vec<i64>, Vec<i64>, Vec<i64>) {
+    let phones: Vec<i64> = cleaned_phones
         .iter()
-        .map(|phone| *SYMBOL_TO_ID.get(phone).unwrap())
+        .map(|phone| *SYMBOL_TO_ID.get(phone).unwrap() as i64)
         .collect();
-    let tones: Vec<i32> = tones.iter().map(|tone| *tone + 6).collect();
-    let lang_ids: Vec<i32> = vec![1; phones.len()];
+    let tones: Vec<i64> = tones.iter().map(|tone| (*tone + 6) as i64).collect();
+    let lang_ids: Vec<i64> = vec![1; phones.len()];
     (phones, tones, lang_ids)
 }
121 sbv2_core/src/tts.rs Normal file
@@ -0,0 +1,121 @@
+use crate::error::Result;
+use crate::{bert, jtalk, model, nlp, norm, style, utils};
+use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
+use ort::Session;
+
+pub struct TTSModel {
+    bert: Session,
+    vits2: Session,
+    style_vectors: Array2<f32>,
+    jtalk: jtalk::JTalk,
+}
+
+impl TTSModel {
+    pub fn new(
+        bert_model_path: &str,
+        main_model_path: &str,
+        style_vector_path: &str,
+    ) -> Result<Self> {
+        let bert = model::load_model(bert_model_path)?;
+        let vits2 = model::load_model(main_model_path)?;
+        let style_vectors = style::load_style(style_vector_path)?;
+        let jtalk = jtalk::JTalk::new()?;
+        Ok(TTSModel {
+            bert,
+            vits2,
+            style_vectors,
+            jtalk,
+        })
+    }
+
+    pub fn parse_text(
+        &self,
+        text: &str,
+    ) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
+        let normalized_text = norm::normalize_text(text);
+
+        let (phones, tones, mut word2ph) = self.jtalk.g2p(&normalized_text)?;
+        let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
+
+        let phones = utils::intersperse(&phones, 0);
+        let tones = utils::intersperse(&tones, 0);
+        let lang_ids = utils::intersperse(&lang_ids, 0);
+        for i in 0..word2ph.len() {
+            word2ph[i] *= 2;
+        }
+        word2ph[0] += 1;
+
+        let tokenizer = jtalk::get_tokenizer()?;
+        let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
+
+        let bert_content = bert::predict(&self.bert, token_ids, attention_masks)?;
+
+        assert!(
+            word2ph.len() == normalized_text.chars().count() + 2,
+            "{} {}",
+            word2ph.len(),
+            normalized_text.chars().count()
+        );
+
+        let mut phone_level_feature = vec![];
+        for i in 0..word2ph.len() {
+            let repeat_feature = {
+                let (reps_rows, reps_cols) = (word2ph[i], 1);
+                let arr_len = bert_content.slice(s![i, ..]).len();
+
+                let mut results: Array2<f32> =
+                    Array::zeros((reps_rows as usize, arr_len * reps_cols));
+
+                for j in 0..reps_rows {
+                    for k in 0..reps_cols {
+                        let mut view = results.slice_mut(s![j, k * arr_len..(k + 1) * arr_len]);
+                        view.assign(&bert_content.slice(s![i, ..]));
+                    }
+                }
+                results
+            };
+            phone_level_feature.push(repeat_feature);
+        }
+        let phone_level_feature = concatenate(
+            Axis(0),
+            &phone_level_feature
+                .iter()
+                .map(|x| x.view())
+                .collect::<Vec<_>>(),
+        )?;
+        let bert_ori = phone_level_feature.t();
+        Ok((
+            bert_ori.to_owned(),
+            phones.into(),
+            tones.into(),
+            lang_ids.into(),
+        ))
+    }
+
+    pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result<Array1<f32>> {
+        Ok(style::get_style_vector(
+            self.style_vectors.clone(),
+            style_id,
+            weight,
+        )?)
+    }
+
+    pub fn synthesize(
+        &self,
+        bert_ori: Array2<f32>,
+        phones: Array1<i64>,
+        tones: Array1<i64>,
+        lang_ids: Array1<i64>,
+        style_vector: Array1<f32>,
+    ) -> Result<()> {
+        model::synthesize(
+            &self.vits2,
+            bert_ori.to_owned(),
+            phones,
+            tones,
+            lang_ids,
+            style_vector,
+        )?;
+        Ok(())
+    }
+}
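For reference, a minimal usage sketch of the new TTSModel API, mirroring the updated example entry point in this commit; the ONNX and style-vector paths under models/ are the same ones the example assumes to exist:

use sbv2_core::{error, tts};

fn main() -> error::Result<()> {
    // Load the BERT model, the VITS2 model, the style vectors, and JTalk once.
    let tts_model = tts::TTSModel::new(
        "models/debert.onnx",
        "models/model_opt.onnx",
        "models/style_vectors.json",
    )?;

    // Text -> phone-level BERT features, phone ids, tone ids, language ids.
    let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text("おはようございます。")?;

    // Style id 0 with weight 1.0, then run synthesis.
    let style_vector = tts_model.get_style_vector(0, 1.0)?;
    tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;

    Ok(())
}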