mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
Merge branch 'main' of https://github.com/tuna2134/sbv2-api
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
BERT_MODEL_PATH=models/debert.onnx
|
||||
MAIN_MODEL_PATH=models/model_opt.onnx
|
||||
STYLE_VECTORS_PATH=models/style_vectors.json
|
||||
STYLE_VECTORS_PATH=models/style_vectors.json
|
||||
TOKENIZER_PATH=models/tokenizer.json
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,6 +1,6 @@
|
||||
target
|
||||
models/*.onnx
|
||||
models/*.json
|
||||
venv
|
||||
venv/
|
||||
.env
|
||||
output.wav
|
||||
@@ -1,6 +1,6 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [ "sbv2_api","sbv2_core"]
|
||||
members = ["sbv2_api", "sbv2_core"]
|
||||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1.0.86"
|
||||
|
||||
@@ -10,3 +10,7 @@ dotenvy = "0.15.7"
|
||||
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
|
||||
serde = { version = "1.0.210", features = ["derive"] }
|
||||
tokio = { version = "1.40.0", features = ["full"] }
|
||||
|
||||
[features]
|
||||
cuda = ["sbv2_core/cuda"]
|
||||
cuda_tf32 = ["sbv2_core/cuda_tf32"]
|
||||
@@ -9,6 +9,7 @@ use sbv2_core::tts::TTSModel;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::fs;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
mod error;
|
||||
@@ -29,9 +30,10 @@ async fn synthesize(
|
||||
tts_model
|
||||
} else {
|
||||
*tts_model = Some(TTSModel::new(
|
||||
&env::var("BERT_MODEL_PATH")?,
|
||||
&env::var("MAIN_MODEL_PATH")?,
|
||||
&env::var("STYLE_VECTORS_PATH")?,
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||
&fs::read(env::var("MAIN_MODEL_PATH")?).await?,
|
||||
&fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||
)?);
|
||||
tts_model.as_ref().unwrap()
|
||||
};
|
||||
|
||||
@@ -16,3 +16,7 @@ serde = { version = "1.0.210", features = ["derive"] }
|
||||
serde_json = "1.0.128"
|
||||
thiserror = "1.0.63"
|
||||
tokenizers = "0.20.0"
|
||||
|
||||
[features]
|
||||
cuda = ["ort/cuda"]
|
||||
cuda_tf32 = []
|
||||
@@ -4,13 +4,13 @@ use crate::norm::{replace_punctuation, PUNCTUATIONS};
|
||||
use jpreprocess::*;
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
type JPreprocessType = JPreprocess<DefaultFetcher>;
|
||||
|
||||
fn get_jtalk() -> Result<JPreprocessType> {
|
||||
fn initialize_jtalk() -> Result<JPreprocessType> {
|
||||
let config = JPreprocessConfig {
|
||||
dictionary: SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic),
|
||||
user_dictionary: None,
|
||||
@@ -50,7 +50,7 @@ pub struct JTalk {
|
||||
|
||||
impl JTalk {
|
||||
pub fn new() -> Result<Self> {
|
||||
let jpreprocess = Arc::new(get_jtalk()?);
|
||||
let jpreprocess = Arc::new(initialize_jtalk()?);
|
||||
Ok(Self { jpreprocess })
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ impl JTalk {
|
||||
static KATAKANA_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"[\u30A0-\u30FF]+").unwrap());
|
||||
static MORA_PATTERN: Lazy<Vec<String>> = Lazy::new(|| {
|
||||
let mut sorted_keys: Vec<String> = MORA_KATA_TO_MORA_PHONEMES.keys().cloned().collect();
|
||||
sorted_keys.sort_by(|a, b| b.len().cmp(&a.len()));
|
||||
sorted_keys.sort_by_key(|b| Reverse(b.len()));
|
||||
sorted_keys
|
||||
});
|
||||
static LONG_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\w)(ー*)").unwrap());
|
||||
@@ -128,8 +128,8 @@ impl JTalkProcess {
|
||||
JTalkProcess::align_tones(phone_w_punct, phone_tone_list_wo_punct)?;
|
||||
|
||||
let mut sep_tokenized: Vec<Vec<String>> = Vec::new();
|
||||
for i in 0..seq_text.len() {
|
||||
let text = seq_text[i].clone();
|
||||
for seq_text_item in &seq_text {
|
||||
let text = seq_text_item.clone();
|
||||
if !PUNCTUATIONS.contains(&text.as_str()) {
|
||||
sep_tokenized.push(text.chars().map(|x| x.to_string()).collect());
|
||||
} else {
|
||||
@@ -390,22 +390,3 @@ impl JTalkProcess {
|
||||
Ok(phones)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_tokenizer() -> Result<Tokenizer> {
|
||||
let tokenizer = Tokenizer::from_file("tokenizer.json")?;
|
||||
Ok(tokenizer)
|
||||
}
|
||||
|
||||
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
|
||||
let mut token_ids = vec![1];
|
||||
let mut attention_masks = vec![1];
|
||||
for content in text.chars() {
|
||||
let token = tokenizer.encode(content.to_string(), false)?;
|
||||
let ids = token.get_ids();
|
||||
token_ids.extend(ids.iter().map(|&x| x as i64));
|
||||
attention_masks.extend(token.get_attention_mask().iter().map(|&x| x as i64));
|
||||
}
|
||||
token_ids.push(2);
|
||||
attention_masks.push(1);
|
||||
Ok((token_ids, attention_masks))
|
||||
}
|
||||
|
||||
@@ -6,20 +6,6 @@ pub mod mora;
|
||||
pub mod nlp;
|
||||
pub mod norm;
|
||||
pub mod style;
|
||||
pub mod tokenizer;
|
||||
pub mod tts;
|
||||
pub mod utils;
|
||||
|
||||
pub fn add(left: usize, right: usize) -> usize {
|
||||
left + right
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
let result = add(2, 2);
|
||||
assert_eq!(result, 4);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +1,38 @@
|
||||
use std::{fs, time::Instant};
|
||||
|
||||
use sbv2_core::{error, tts};
|
||||
|
||||
fn main() -> error::Result<()> {
|
||||
let text = "眠たい";
|
||||
|
||||
let tts_model = tts::TTSModel::new(
|
||||
"models/debert.onnx",
|
||||
"models/model_opt.onnx",
|
||||
"models/style_vectors.json",
|
||||
fs::read("models/debert.onnx")?,
|
||||
fs::read("models/model_opt.onnx")?,
|
||||
fs::read("models/style_vectors.json")?,
|
||||
fs::read("models/tokenizer.json")?,
|
||||
)?;
|
||||
|
||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
|
||||
|
||||
let style_vector = tts_model.get_style_vector(0, 1.0)?;
|
||||
let data = tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;
|
||||
|
||||
let data = tts_model.synthesize(
|
||||
bert_ori.to_owned(),
|
||||
phones.clone(),
|
||||
tones.clone(),
|
||||
lang_ids.clone(),
|
||||
style_vector.clone(),
|
||||
)?;
|
||||
std::fs::write("output.wav", data)?;
|
||||
|
||||
let now = Instant::now();
|
||||
for _ in 0..10 {
|
||||
tts_model.synthesize(
|
||||
bert_ori.to_owned(),
|
||||
phones.clone(),
|
||||
tones.clone(),
|
||||
lang_ids.clone(),
|
||||
style_vector.clone(),
|
||||
)?;
|
||||
}
|
||||
println!("Time taken: {}", now.elapsed().as_millis());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,14 +4,27 @@ use ndarray::{array, s, Array1, Array2, Axis};
|
||||
use ort::{GraphOptimizationLevel, Session};
|
||||
use std::io::Cursor;
|
||||
|
||||
pub fn load_model(model_file: &str) -> Result<Session> {
|
||||
let session = Session::builder()?
|
||||
#[allow(clippy::vec_init_then_push)]
|
||||
pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
|
||||
let mut exp = Vec::new();
|
||||
#[cfg(feature = "cuda")]
|
||||
{
|
||||
let mut cuda = ort::CUDAExecutionProvider::default()
|
||||
.with_conv_algorithm_search(ort::CUDAExecutionProviderCuDNNConvAlgoSearch::Default);
|
||||
#[cfg(feature = "cuda_tf32")]
|
||||
{
|
||||
cuda = cuda.with_tf32(true);
|
||||
}
|
||||
exp.push(cuda.build());
|
||||
}
|
||||
exp.push(ort::CPUExecutionProvider::default().build());
|
||||
Ok(Session::builder()?
|
||||
.with_execution_providers(exp)?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(num_cpus::get_physical())?
|
||||
.with_parallel_execution(true)?
|
||||
.with_inter_threads(num_cpus::get_physical())?
|
||||
.commit_from_file(model_file)?;
|
||||
Ok(session)
|
||||
.commit_from_memory(model_file.as_ref())?)
|
||||
}
|
||||
|
||||
pub fn synthesize(
|
||||
|
||||
@@ -8,8 +8,8 @@ pub struct Data {
|
||||
pub data: Vec<Vec<f32>>,
|
||||
}
|
||||
|
||||
pub fn load_style(path: &str) -> Result<Array2<f32>> {
|
||||
let data: Data = serde_json::from_str(&std::fs::read_to_string(path)?)?;
|
||||
pub fn load_style<P: AsRef<[u8]>>(path: P) -> Result<Array2<f32>> {
|
||||
let data: Data = serde_json::from_slice(path.as_ref())?;
|
||||
Ok(Array2::from_shape_vec(
|
||||
data.shape,
|
||||
data.data.iter().flatten().copied().collect(),
|
||||
|
||||
21
sbv2_core/src/tokenizer.rs
Normal file
21
sbv2_core/src/tokenizer.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
use crate::error::Result;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
pub fn get_tokenizer<P: AsRef<[u8]>>(p: P) -> Result<Tokenizer> {
|
||||
let tokenizer = Tokenizer::from_bytes(p)?;
|
||||
Ok(tokenizer)
|
||||
}
|
||||
|
||||
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
|
||||
let mut token_ids = vec![1];
|
||||
let mut attention_masks = vec![1];
|
||||
for content in text.chars() {
|
||||
let token = tokenizer.encode(content.to_string(), false)?;
|
||||
let ids = token.get_ids();
|
||||
token_ids.extend(ids.iter().map(|&x| x as i64));
|
||||
attention_masks.extend(token.get_attention_mask().iter().map(|&x| x as i64));
|
||||
}
|
||||
token_ids.push(2);
|
||||
attention_masks.push(1);
|
||||
Ok((token_ids, attention_masks))
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::error::Result;
|
||||
use crate::{bert, jtalk, model, nlp, norm, style, utils};
|
||||
use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils};
|
||||
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
|
||||
use ort::Session;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
pub struct TTSModel {
|
||||
tokenizer: Tokenizer,
|
||||
bert: Session,
|
||||
vits2: Session,
|
||||
style_vectors: Array2<f32>,
|
||||
@@ -11,23 +13,26 @@ pub struct TTSModel {
|
||||
}
|
||||
|
||||
impl TTSModel {
|
||||
pub fn new(
|
||||
bert_model_path: &str,
|
||||
main_model_path: &str,
|
||||
style_vector_path: &str,
|
||||
pub fn new<P: AsRef<[u8]>>(
|
||||
bert_model_bytes: P,
|
||||
main_model_bytes: P,
|
||||
style_vector_bytes: P,
|
||||
tokenizer_bytes: P,
|
||||
) -> Result<Self> {
|
||||
let bert = model::load_model(bert_model_path)?;
|
||||
let vits2 = model::load_model(main_model_path)?;
|
||||
let style_vectors = style::load_style(style_vector_path)?;
|
||||
let bert = model::load_model(bert_model_bytes)?;
|
||||
let vits2 = model::load_model(main_model_bytes)?;
|
||||
let style_vectors = style::load_style(style_vector_bytes)?;
|
||||
let jtalk = jtalk::JTalk::new()?;
|
||||
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
||||
Ok(TTSModel {
|
||||
bert,
|
||||
vits2,
|
||||
style_vectors,
|
||||
jtalk,
|
||||
tokenizer,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn parse_text(
|
||||
&self,
|
||||
text: &str,
|
||||
@@ -40,13 +45,11 @@ impl TTSModel {
|
||||
let phones = utils::intersperse(&phones, 0);
|
||||
let tones = utils::intersperse(&tones, 0);
|
||||
let lang_ids = utils::intersperse(&lang_ids, 0);
|
||||
for i in 0..word2ph.len() {
|
||||
word2ph[i] *= 2;
|
||||
for item in &mut word2ph {
|
||||
*item *= 2;
|
||||
}
|
||||
word2ph[0] += 1;
|
||||
|
||||
let tokenizer = jtalk::get_tokenizer()?;
|
||||
let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
|
||||
let (token_ids, attention_masks) = tokenizer::tokenize(&normalized_text, &self.tokenizer)?;
|
||||
|
||||
let bert_content = bert::predict(&self.bert, token_ids, attention_masks)?;
|
||||
|
||||
@@ -58,9 +61,9 @@ impl TTSModel {
|
||||
);
|
||||
|
||||
let mut phone_level_feature = vec![];
|
||||
for i in 0..word2ph.len() {
|
||||
for (i, reps) in word2ph.iter().enumerate() {
|
||||
let repeat_feature = {
|
||||
let (reps_rows, reps_cols) = (word2ph[i], 1);
|
||||
let (reps_rows, reps_cols) = (*reps, 1);
|
||||
let arr_len = bert_content.slice(s![i, ..]).len();
|
||||
|
||||
let mut results: Array2<f32> =
|
||||
|
||||
1
test.py
1
test.py
@@ -1,6 +1,5 @@
|
||||
import requests
|
||||
|
||||
|
||||
res = requests.post('http://localhost:3000/synthesize', json={
|
||||
"text": "おはようございます",
|
||||
})
|
||||
|
||||
22116
tokenizer.json
22116
tokenizer.json
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user