feat: docker(cpu), 1:n bert:vits2 support

This commit is contained in:
Googlefan
2024-09-11 03:19:08 +00:00
parent 70059fc040
commit 7d191ca37d
10 changed files with 129 additions and 34 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
target/
models/
docker/
.env*
renovate.json
*.py

View File

@@ -1,4 +1,4 @@
BERT_MODEL_PATH=models/debert.onnx BERT_MODEL_PATH=models/deberta.onnx
MAIN_MODEL_PATH=models/model_opt.onnx MODEL_PATH=models/model_tsukuyomi.onnx
STYLE_VECTORS_PATH=models/style_vectors.json STYLE_VECTORS_PATH=models/style_vectors.json
TOKENIZER_PATH=models/tokenizer.json TOKENIZER_PATH=models/tokenizer.json

9
docker/cpu.Dockerfile Normal file
View File

@@ -0,0 +1,9 @@
# --- Build stage: compile the API server in release mode -------------------
FROM rust AS builder
WORKDIR /work
COPY . .
# `-r` is shorthand for `--release`; only the `sbv2_api` binary is built.
RUN cargo build -r --bin sbv2_api

# --- Runtime stage: minimal distroless image --------------------------------
FROM gcr.io/distroless/cc-debian12
WORKDIR /work
COPY --from=builder /work/target/release/sbv2_api /work/main
# Shared objects emitted next to the binary (presumably the ONNX Runtime
# libraries -- confirm). The wildcard may match several files, so the
# destination must be a directory path ending in `/`.
COPY --from=builder /work/target/release/*.so /work/
CMD ["/work/main"]

1
docker/run.sh Normal file
View File

@@ -0,0 +1 @@
# Runs the `sbv2` image interactively as container `sbv2`, removing it on exit.
# Publishes port 3000, bind-mounts ./models to /work/models inside the
# container, and loads environment variables from .env.sample.
docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env.sample sbv2

View File

@@ -5,7 +5,7 @@ use axum::{
routing::{get, post}, routing::{get, post},
Json, Router, Json, Router,
}; };
use sbv2_core::tts::TTSModel; use sbv2_core::tts::TTSModelHolder;
use serde::Deserialize; use serde::Deserialize;
use std::env; use std::env;
use std::sync::Arc; use std::sync::Arc;
@@ -18,34 +18,46 @@ use crate::error::AppResult;
#[derive(Deserialize)] #[derive(Deserialize)]
struct SynthesizeRequest { struct SynthesizeRequest {
text: String, text: String,
ident: String,
} }
async fn synthesize( async fn synthesize(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(SynthesizeRequest { text }): Json<SynthesizeRequest>, Json(SynthesizeRequest { text, ident }): Json<SynthesizeRequest>,
) -> AppResult<impl IntoResponse> { ) -> AppResult<impl IntoResponse> {
let buffer = { let buffer = {
let mut tts_model = state.tts_model.lock().await; let mut tts_model = state.tts_model.lock().await;
let tts_model = if let Some(tts_model) = &*tts_model { let tts_model = if let Some(tts_model) = &*tts_model {
tts_model tts_model
} else { } else {
*tts_model = Some(TTSModel::new( let mut tts_holder = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?, &fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("MAIN_MODEL_PATH")?).await?,
&fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?, &fs::read(env::var("TOKENIZER_PATH")?).await?,
)?); )?;
tts_holder.load(
"tsukuyomi",
fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
fs::read(env::var("MODEL_PATH")?).await?,
)?;
*tts_model = Some(tts_holder);
tts_model.as_ref().unwrap() tts_model.as_ref().unwrap()
}; };
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?; let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
let style_vector = tts_model.get_style_vector(0, 1.0)?; let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?;
tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)? tts_model.synthesize(
ident,
bert_ori.to_owned(),
phones,
tones,
lang_ids,
style_vector,
)?
}; };
Ok(([(CONTENT_TYPE, "audio/wav")], buffer)) Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
} }
struct AppState { struct AppState {
tts_model: Arc<Mutex<Option<TTSModel>>>, tts_model: Arc<Mutex<Option<TTSModelHolder>>>,
} }
#[tokio::main] #[tokio::main]

View File

@@ -18,6 +18,8 @@ pub enum Error {
IoError(#[from] std::io::Error), IoError(#[from] std::io::Error),
#[error("hound error: {0}")] #[error("hound error: {0}")]
HoundError(#[from] hound::Error), HoundError(#[from] hound::Error),
#[error("model not found error")]
ModelNotFoundError(String),
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -5,17 +5,21 @@ use sbv2_core::{error, tts};
fn main() -> error::Result<()> { fn main() -> error::Result<()> {
let text = "眠たい"; let text = "眠たい";
let tts_model = tts::TTSModel::new( let mut tts_model = tts::TTSModelHolder::new(
fs::read("models/debert.onnx")?, fs::read("models/debert.onnx")?,
fs::read("models/model_opt.onnx")?, fs::read("models/model_opt.onnx")?,
)?;
tts_model.load(
"tsukuyomi",
fs::read("models/style_vectors.json")?, fs::read("models/style_vectors.json")?,
fs::read("models/tokenizer.json")?, fs::read("models/tokenizer.json")?,
)?; )?;
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?; let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
let style_vector = tts_model.get_style_vector(0, 1.0)?; let style_vector = tts_model.get_style_vector("tsukuyomi", 0, 1.0)?;
let data = tts_model.synthesize( let data = tts_model.synthesize(
"tsukuyomi",
bert_ori.to_owned(), bert_ori.to_owned(),
phones.clone(), phones.clone(),
tones.clone(), tones.clone(),
@@ -26,6 +30,7 @@ fn main() -> error::Result<()> {
let now = Instant::now(); let now = Instant::now();
for _ in 0..10 { for _ in 0..10 {
tts_model.synthesize( tts_model.synthesize(
"tsukuyomi",
bert_ori.to_owned(), bert_ori.to_owned(),
phones.clone(), phones.clone(),
tones.clone(), tones.clone(),

View File

@@ -17,7 +17,7 @@ pub fn load_style<P: AsRef<[u8]>>(path: P) -> Result<Array2<f32>> {
} }
pub fn get_style_vector( pub fn get_style_vector(
style_vectors: Array2<f32>, style_vectors: &Array2<f32>,
style_id: i32, style_id: i32,
weight: f32, weight: f32,
) -> Result<Array1<f32>> { ) -> Result<Array1<f32>> {

View File

@@ -1,37 +1,83 @@
use crate::error::Result; use crate::error::{Error, Result};
use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils}; use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils};
use ndarray::{concatenate, s, Array, Array1, Array2, Axis}; use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
use ort::Session; use ort::Session;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
/// Identifier used to tell loaded TTS models apart.
///
/// This is a thin newtype around `String`. It can be built from anything that
/// implements `AsRef<str>` (for example `&str` or `String`) through `From`/`Into`,
/// and it renders as the wrapped string via `Display`.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct TTSIdent(String);

impl std::fmt::Display for TTSIdent {
    /// Writes the wrapped identifier string unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl<S> From<S> for TTSIdent
where
    S: AsRef<str>,
{
    /// Copies the string slice behind `value` into a new identifier.
    fn from(value: S) -> Self {
        TTSIdent(value.as_ref().to_string())
    }
}
pub struct TTSModel { pub struct TTSModel {
tokenizer: Tokenizer,
bert: Session,
vits2: Session, vits2: Session,
style_vectors: Array2<f32>, style_vectors: Array2<f32>,
ident: TTSIdent,
}
pub struct TTSModelHolder {
tokenizer: Tokenizer,
bert: Session,
models: Vec<TTSModel>,
jtalk: jtalk::JTalk, jtalk: jtalk::JTalk,
} }
impl TTSModel { impl TTSModelHolder {
pub fn new<P: AsRef<[u8]>>( pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
bert_model_bytes: P,
main_model_bytes: P,
style_vector_bytes: P,
tokenizer_bytes: P,
) -> Result<Self> {
let bert = model::load_model(bert_model_bytes)?; let bert = model::load_model(bert_model_bytes)?;
let vits2 = model::load_model(main_model_bytes)?;
let style_vectors = style::load_style(style_vector_bytes)?;
let jtalk = jtalk::JTalk::new()?; let jtalk = jtalk::JTalk::new()?;
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?; let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
Ok(TTSModel { Ok(TTSModelHolder {
bert, bert,
vits2, models: vec![],
style_vectors,
jtalk, jtalk,
tokenizer, tokenizer,
}) })
} }
/// Registers a VITS2 model and its style vectors under `ident`.
///
/// If a model with the same identifier is already registered, the call is a
/// no-op: the byte buffers are not parsed and `Ok(())` is returned.
///
/// # Errors
///
/// Propagates any error from `model::load_model` or `style::load_style`.
pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
    &mut self,
    ident: I,
    style_vectors_bytes: P,
    vits2_bytes: P,
) -> Result<()> {
    let ident = ident.into();
    let already_registered = self.models.iter().any(|entry| entry.ident == ident);
    if already_registered {
        return Ok(());
    }
    // The session is built before the style vectors, so a failure in either
    // step leaves the holder untouched.
    let vits2 = model::load_model(vits2_bytes)?;
    let style_vectors = style::load_style(style_vectors_bytes)?;
    self.models.push(TTSModel {
        vits2,
        style_vectors,
        ident,
    });
    Ok(())
}
/// Removes the model registered under `ident`.
///
/// Returns `true` when a matching model was found and removed, `false` when
/// no model with that identifier is registered. Only the first match is
/// removed.
pub fn unload<I: Into<TTSIdent>>(&mut self, ident: I) -> bool {
    let ident = ident.into();
    match self.models.iter().position(|entry| entry.ident == ident) {
        Some(index) => {
            // `remove` keeps the remaining models in their original order.
            self.models.remove(index);
            true
        }
        None => false,
    }
}
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
pub fn parse_text( pub fn parse_text(
&self, &self,
@@ -94,13 +140,26 @@ impl TTSModel {
lang_ids.into(), lang_ids.into(),
)) ))
} }
fn find_model<I: Into<TTSIdent>>(&self, ident: I) -> Result<&TTSModel> {
pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result<Array1<f32>> { let ident = ident.into();
style::get_style_vector(self.style_vectors.clone(), style_id, weight) self.models
.iter()
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))
} }
pub fn synthesize( pub fn get_style_vector<I: Into<TTSIdent>>(
&self, &self,
ident: I,
style_id: i32,
weight: f32,
) -> Result<Array1<f32>> {
style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight)
}
pub fn synthesize<I: Into<TTSIdent>>(
&self,
ident: I,
bert_ori: Array2<f32>, bert_ori: Array2<f32>,
phones: Array1<i64>, phones: Array1<i64>,
tones: Array1<i64>, tones: Array1<i64>,
@@ -108,7 +167,7 @@ impl TTSModel {
style_vector: Array1<f32>, style_vector: Array1<f32>,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
let buffer = model::synthesize( let buffer = model::synthesize(
&self.vits2, &self.find_model(ident)?.vits2,
bert_ori.to_owned(), bert_ori.to_owned(),
phones, phones,
tones, tones,

View File

@@ -2,6 +2,7 @@ import requests
res = requests.post('http://localhost:3000/synthesize', json={ res = requests.post('http://localhost:3000/synthesize', json={
"text": "おはようございます", "text": "おはようございます",
"ident": "tsukuyomi"
}) })
with open('output.wav', 'wb') as f: with open('output.wav', 'wb') as f:
f.write(res.content) f.write(res.content)