mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
feat: synthesis
This commit is contained in:
@@ -205,6 +205,23 @@ impl TTSModelHolder {
|
||||
) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
|
||||
crate::tts_util::parse_text_blocking(
|
||||
text,
|
||||
None,
|
||||
&self.jtalk,
|
||||
&self.tokenizer,
|
||||
|token_ids, attention_masks| {
|
||||
crate::bert::predict(&mut self.bert, token_ids, attention_masks)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn parse_text_neo(
|
||||
&mut self,
|
||||
text: String,
|
||||
given_tones: Option<Vec<i32>>,
|
||||
) -> Result<(Array2<f32>, Array1<i64>, Array1<i64>, Array1<i64>)> {
|
||||
crate::tts_util::parse_text_blocking(
|
||||
&text,
|
||||
given_tones,
|
||||
&self.jtalk,
|
||||
&self.tokenizer,
|
||||
|token_ids, attention_masks| {
|
||||
@@ -347,6 +364,78 @@ impl TTSModelHolder {
|
||||
};
|
||||
tts_util::array_to_vec(audio_array)
|
||||
}
|
||||
|
||||
pub fn easy_synthesize_neo<I: Into<TTSIdent> + Copy>(
|
||||
&mut self,
|
||||
ident: I,
|
||||
text: &str,
|
||||
given_tones: Option<Vec<i32>>,
|
||||
style_id: i32,
|
||||
speaker_id: i64,
|
||||
options: SynthesizeOptions,
|
||||
) -> Result<Vec<u8>> {
|
||||
self.find_and_load_model(ident)?;
|
||||
let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
|
||||
let audio_array = if options.split_sentences {
|
||||
let texts: Vec<&str> = text.split('\n').collect();
|
||||
let mut audios = vec![];
|
||||
for (i, t) in texts.iter().enumerate() {
|
||||
if t.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let (bert_ori, phones, tones, lang_ids) = self.parse_text_neo(t, given_tones)?;
|
||||
|
||||
let vits2 = self
|
||||
.find_model(ident)?
|
||||
.vits2
|
||||
.as_mut()
|
||||
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
|
||||
let audio = model::synthesize(
|
||||
vits2,
|
||||
bert_ori.to_owned(),
|
||||
phones,
|
||||
Array1::from_vec(vec![speaker_id]),
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector.clone(),
|
||||
options.sdp_ratio,
|
||||
options.length_scale,
|
||||
0.677,
|
||||
0.8,
|
||||
)?;
|
||||
audios.push(audio.clone());
|
||||
if i != texts.len() - 1 {
|
||||
audios.push(Array3::zeros((1, 1, 22050)));
|
||||
}
|
||||
}
|
||||
concatenate(
|
||||
Axis(2),
|
||||
&audios.iter().map(|x| x.view()).collect::<Vec<_>>(),
|
||||
)?
|
||||
} else {
|
||||
let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
|
||||
|
||||
let vits2 = self
|
||||
.find_model(ident)?
|
||||
.vits2
|
||||
.as_mut()
|
||||
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
|
||||
model::synthesize(
|
||||
vits2,
|
||||
bert_ori.to_owned(),
|
||||
phones,
|
||||
Array1::from_vec(vec![speaker_id]),
|
||||
tones,
|
||||
lang_ids,
|
||||
style_vector,
|
||||
options.sdp_ratio,
|
||||
options.length_scale,
|
||||
0.677,
|
||||
0.8,
|
||||
)?
|
||||
};
|
||||
tts_util::array_to_vec(audio_array)
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthesize options
|
||||
|
||||
@@ -34,7 +34,6 @@ pub async fn parse_text(
|
||||
let (normalized_text, process) = preprocess_parse_text(text, jtalk)?;
|
||||
let (phones, tones, mut word2ph) = process.g2p()?;
|
||||
let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
|
||||
|
||||
let phones = utils::intersperse(&phones, 0);
|
||||
let tones = utils::intersperse(&tones, 0);
|
||||
let lang_ids = utils::intersperse(&lang_ids, 0);
|
||||
@@ -99,6 +98,7 @@ pub async fn parse_text(
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn parse_text_blocking(
|
||||
text: &str,
|
||||
given_tones: Option<Vec<i32>>,
|
||||
jtalk: &jtalk::JTalk,
|
||||
tokenizer: &Tokenizer,
|
||||
bert_predict: impl FnOnce(Vec<i64>, Vec<i64>) -> Result<ndarray::Array2<f32>>,
|
||||
@@ -107,7 +107,10 @@ pub fn parse_text_blocking(
|
||||
let normalized_text = norm::normalize_text(&text);
|
||||
|
||||
let process = jtalk.process_text(&normalized_text)?;
|
||||
let (phones, tones, mut word2ph) = process.g2p()?;
|
||||
let (phones, mut tones, mut word2ph) = process.g2p()?;
|
||||
if let Some(given_tones) = given_tones {
|
||||
tones = given_tones;
|
||||
}
|
||||
let (phones, tones, lang_ids) = nlp::cleaned_text_to_sequence(phones, tones);
|
||||
|
||||
let phones = utils::intersperse(&phones, 0);
|
||||
|
||||
@@ -11,6 +11,7 @@ documentation.workspace = true
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum = "0.8.1"
|
||||
log = "0.4.27"
|
||||
sbv2_core = { version = "0.2.0-alpha6", path = "../sbv2_core" }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
tokio = { version = "1.44.1", features = ["full"] }
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
use axum::{extract::Query, response::IntoResponse, routing::get, Json, Router};
|
||||
use sbv2_core::{jtalk::JTalk, tts_util::preprocess_parse_text};
|
||||
use axum::extract::State;
|
||||
use axum::{extract::Query, response::IntoResponse, routing::{get, post}, Json, Router};
|
||||
use sbv2_core::{jtalk::JTalk, tts::TTSModelHolder, tts_util::preprocess_parse_text};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::{fs, net::TcpListener, sync::Mutex};
|
||||
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
|
||||
use error::AppResult;
|
||||
|
||||
@@ -12,32 +16,162 @@ struct RequestCreateAudioQuery {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ResponseCreateAudioQuery {
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AudioQuery {
|
||||
kana: String,
|
||||
tone: i32,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ResponseCreateAudioQuery {
|
||||
audio_query: Vec<AudioQuery>,
|
||||
text: String,
|
||||
}
|
||||
|
||||
async fn create_audio_query(
|
||||
Query(request): Query<RequestCreateAudioQuery>,
|
||||
) -> AppResult<impl IntoResponse> {
|
||||
let (_, process) = preprocess_parse_text(&request.text, &JTalk::new()?)?;
|
||||
let (text, process) = preprocess_parse_text(&request.text, &JTalk::new()?)?;
|
||||
let kana_tone_list = process.g2kana_tone()?;
|
||||
let response = kana_tone_list
|
||||
let audio_query = kana_tone_list
|
||||
.iter()
|
||||
.map(|(kana, tone)| ResponseCreateAudioQuery {
|
||||
.map(|(kana, tone)| AudioQuery {
|
||||
kana: kana.clone(),
|
||||
tone: *tone,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
Ok(Json(response))
|
||||
Ok(Json(ResponseCreateAudioQuery { audio_query, text }))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct RequestSynthesis {
|
||||
text: String,
|
||||
speaker_id: i32,
|
||||
sdp_ratio: f32,
|
||||
length_scale: f32,
|
||||
style_id: i32,
|
||||
audio_query: Vec<AudioQuery>,
|
||||
}
|
||||
|
||||
async fn synthesis(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<RequestSynthesis>,
|
||||
) -> AppResult<impl IntoResponse> {
|
||||
let mut tones: Vec<i32> = request.audio_query.iter().map(|query| query.tone).collect();
|
||||
tones.insert(0, 0);
|
||||
tones.push(0);
|
||||
let buffer = {
|
||||
let mut tts_model = state.tts_model.lock().await;
|
||||
tts_model.easy_synthesize_neo(
|
||||
&ident,
|
||||
&text,
|
||||
Some(tones),
|
||||
request.style_id,
|
||||
request.speaker_id,
|
||||
SynthesizeOptions {
|
||||
sdp_ratio: request.sdp_ratio,
|
||||
length_scale: request.length_scale,
|
||||
..Default::default()
|
||||
},
|
||||
)?
|
||||
};
|
||||
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
tts_model: Arc<Mutex<TTSModelHolder>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new() -> anyhow::Result<Self> {
|
||||
let mut tts_model = TTSModelHolder::new(
|
||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||
env::var("HOLDER_MAX_LOADED_MODElS")
|
||||
.ok()
|
||||
.and_then(|x| x.parse().ok()),
|
||||
)?;
|
||||
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
|
||||
let mut f = fs::read_dir(&models).await?;
|
||||
let mut entries = vec![];
|
||||
while let Ok(Some(e)) = f.next_entry().await {
|
||||
let name = e.file_name().to_string_lossy().to_string();
|
||||
if name.ends_with(".onnx") && name.starts_with("model_") {
|
||||
let name_len = name.len();
|
||||
let name = name.chars();
|
||||
entries.push(
|
||||
name.collect::<Vec<_>>()[6..name_len - 5]
|
||||
.iter()
|
||||
.collect::<String>(),
|
||||
);
|
||||
} else if name.ends_with(".sbv2") {
|
||||
let entry = &name[..name.len() - 5];
|
||||
log::info!("Try loading: {entry}");
|
||||
let sbv2_bytes = match fs::read(format!("{models}/{entry}.sbv2")).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("Error loading sbv2_bytes from file {entry}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Err(e) = tts_model.load_sbv2file(entry, sbv2_bytes) {
|
||||
log::warn!("Error loading {entry}: {e}");
|
||||
};
|
||||
log::info!("Loaded: {entry}");
|
||||
} else if name.ends_with(".aivmx") {
|
||||
let entry = &name[..name.len() - 6];
|
||||
log::info!("Try loading: {entry}");
|
||||
let aivmx_bytes = match fs::read(format!("{models}/{entry}.aivmx")).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("Error loading aivmx bytes from file {entry}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Err(e) = tts_model.load_aivmx(entry, aivmx_bytes) {
|
||||
log::error!("Error loading {entry}: {e}");
|
||||
}
|
||||
log::info!("Loaded: {entry}");
|
||||
}
|
||||
}
|
||||
for entry in entries {
|
||||
log::info!("Try loading: {entry}");
|
||||
let style_vectors_bytes =
|
||||
match fs::read(format!("{models}/style_vectors_{entry}.json")).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("Error loading style_vectors_bytes from file {entry}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
log::warn!("Error loading vits2_bytes from file {entry}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
|
||||
log::warn!("Error loading {entry}: {e}");
|
||||
};
|
||||
log::info!("Loaded: {entry}");
|
||||
}
|
||||
Ok(Self {
|
||||
tts_model: Arc::new(Mutex::new(tts_model)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
dotenvy::dotenv_override().ok();
|
||||
env_logger::init();
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "Hello, world!" }))
|
||||
.route("/audio_query", get(create_audio_query));
|
||||
.route("/audio_query", get(create_audio_query))
|
||||
.route("/synthesis", post(synthesis))
|
||||
.with_state(AppState::new().await?);
|
||||
let listener = TcpListener::bind("0.0.0.0:8080").await?;
|
||||
axum::serve(listener, app).await?;
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user