mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-01-06 14:32:57 +00:00
wip: max loaded models
This commit is contained in:
@@ -69,7 +69,7 @@ async fn synthesize(
|
|||||||
) -> AppResult<impl IntoResponse> {
|
) -> AppResult<impl IntoResponse> {
|
||||||
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
|
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
|
||||||
let buffer = {
|
let buffer = {
|
||||||
let tts_model = state.tts_model.lock().await;
|
let mut tts_model = state.tts_model.lock().await;
|
||||||
tts_model.easy_synthesize(
|
tts_model.easy_synthesize(
|
||||||
&ident,
|
&ident,
|
||||||
&text,
|
&text,
|
||||||
@@ -94,6 +94,9 @@ impl AppState {
|
|||||||
let mut tts_model = TTSModelHolder::new(
|
let mut tts_model = TTSModelHolder::new(
|
||||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||||
|
env::var("HOLDER_MAX_LOADED_MODElS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|x| x.parse().ok()),
|
||||||
)?;
|
)?;
|
||||||
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
|
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
|
||||||
let mut f = fs::read_dir(&models).await?;
|
let mut f = fs::read_dir(&models).await?;
|
||||||
|
|||||||
@@ -23,10 +23,15 @@ pub struct TTSModel {
|
|||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
impl TTSModel {
|
impl TTSModel {
|
||||||
|
#[pyo3(signature = (bert_model_bytes, tokenizer_bytes, max_loaded_models=None))]
|
||||||
#[new]
|
#[new]
|
||||||
fn new(bert_model_bytes: Vec<u8>, tokenizer_bytes: Vec<u8>) -> anyhow::Result<Self> {
|
fn new(
|
||||||
|
bert_model_bytes: Vec<u8>,
|
||||||
|
tokenizer_bytes: Vec<u8>,
|
||||||
|
max_loaded_models: Option<usize>,
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes)?,
|
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes, max_loaded_models)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,10 +43,21 @@ impl TTSModel {
|
|||||||
/// BERTモデルのパス
|
/// BERTモデルのパス
|
||||||
/// tokenizer_path : str
|
/// tokenizer_path : str
|
||||||
/// トークナイザーのパス
|
/// トークナイザーのパス
|
||||||
|
/// max_loaded_models: int | None
|
||||||
|
/// 同時にVRAMに存在するモデルの数
|
||||||
|
#[pyo3(signature = (bert_model_path, tokenizer_path, max_loaded_models=None))]
|
||||||
#[staticmethod]
|
#[staticmethod]
|
||||||
fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> {
|
fn from_path(
|
||||||
|
bert_model_path: String,
|
||||||
|
tokenizer_path: String,
|
||||||
|
max_loaded_models: Option<usize>,
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model: TTSModelHolder::new(fs::read(bert_model_path)?, fs::read(tokenizer_path)?)?,
|
model: TTSModelHolder::new(
|
||||||
|
fs::read(bert_model_path)?,
|
||||||
|
fs::read(tokenizer_path)?,
|
||||||
|
max_loaded_models,
|
||||||
|
)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,7 +137,7 @@ impl TTSModel {
|
|||||||
/// voice_data : bytes
|
/// voice_data : bytes
|
||||||
/// 音声データ
|
/// 音声データ
|
||||||
fn synthesize<'p>(
|
fn synthesize<'p>(
|
||||||
&'p self,
|
&'p mut self,
|
||||||
py: Python<'p>,
|
py: Python<'p>,
|
||||||
text: String,
|
text: String,
|
||||||
ident: String,
|
ident: String,
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ fn main_inner() -> anyhow::Result<()> {
|
|||||||
let mut tts_holder = tts::TTSModelHolder::new(
|
let mut tts_holder = tts::TTSModelHolder::new(
|
||||||
&fs::read(env::var("BERT_MODEL_PATH")?)?,
|
&fs::read(env::var("BERT_MODEL_PATH")?)?,
|
||||||
&fs::read(env::var("TOKENIZER_PATH")?)?,
|
&fs::read(env::var("TOKENIZER_PATH")?)?,
|
||||||
|
env::var("HOLDER_MAX_LOADED_MODElS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|x| x.parse().ok()),
|
||||||
)?;
|
)?;
|
||||||
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
|
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct TTSModel {
|
pub struct TTSModel {
|
||||||
vits2: Session,
|
vits2: Option<Session>,
|
||||||
style_vectors: Array2<f32>,
|
style_vectors: Array2<f32>,
|
||||||
ident: TTSIdent,
|
ident: TTSIdent,
|
||||||
|
bytes: Option<Vec<u8>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// High-level Style-Bert-VITS2's API
|
/// High-level Style-Bert-VITS2's API
|
||||||
@@ -35,6 +36,7 @@ pub struct TTSModelHolder {
|
|||||||
bert: Session,
|
bert: Session,
|
||||||
models: Vec<TTSModel>,
|
models: Vec<TTSModel>,
|
||||||
jtalk: jtalk::JTalk,
|
jtalk: jtalk::JTalk,
|
||||||
|
max_loaded_models: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TTSModelHolder {
|
impl TTSModelHolder {
|
||||||
@@ -43,9 +45,13 @@ impl TTSModelHolder {
|
|||||||
/// # Examples
|
/// # Examples
|
||||||
///
|
///
|
||||||
/// ```rs
|
/// ```rs
|
||||||
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?)?;
|
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?, None)?;
|
||||||
/// ```
|
/// ```
|
||||||
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
|
pub fn new<P: AsRef<[u8]>>(
|
||||||
|
bert_model_bytes: P,
|
||||||
|
tokenizer_bytes: P,
|
||||||
|
max_loaded_models: Option<usize>,
|
||||||
|
) -> Result<Self> {
|
||||||
let bert = model::load_model(bert_model_bytes, true)?;
|
let bert = model::load_model(bert_model_bytes, true)?;
|
||||||
let jtalk = jtalk::JTalk::new()?;
|
let jtalk = jtalk::JTalk::new()?;
|
||||||
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
||||||
@@ -54,6 +60,7 @@ impl TTSModelHolder {
|
|||||||
models: vec![],
|
models: vec![],
|
||||||
jtalk,
|
jtalk,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
max_loaded_models,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,10 +101,25 @@ impl TTSModelHolder {
|
|||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let ident = ident.into();
|
let ident = ident.into();
|
||||||
if self.find_model(ident.clone()).is_err() {
|
if self.find_model(ident.clone()).is_err() {
|
||||||
|
let mut load = true;
|
||||||
|
if let Some(max) = self.max_loaded_models {
|
||||||
|
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
|
||||||
|
load = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
self.models.push(TTSModel {
|
self.models.push(TTSModel {
|
||||||
vits2: model::load_model(vits2_bytes, false)?,
|
vits2: if load {
|
||||||
|
Some(model::load_model(&vits2_bytes, false)?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
style_vectors: style::load_style(style_vectors_bytes)?,
|
style_vectors: style::load_style(style_vectors_bytes)?,
|
||||||
ident,
|
ident,
|
||||||
|
bytes: if self.max_loaded_models.is_some() {
|
||||||
|
Some(vits2_bytes.as_ref().to_vec())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -145,6 +167,42 @@ impl TTSModelHolder {
|
|||||||
.find(|m| m.ident == ident)
|
.find(|m| m.ident == ident)
|
||||||
.ok_or(Error::ModelNotFoundError(ident.to_string()))
|
.ok_or(Error::ModelNotFoundError(ident.to_string()))
|
||||||
}
|
}
|
||||||
|
fn find_and_load_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<bool> {
|
||||||
|
let ident = ident.into();
|
||||||
|
let (bytes, style_vectors) = {
|
||||||
|
let model = self
|
||||||
|
.models
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.ident == ident)
|
||||||
|
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
|
||||||
|
if model.vits2.is_some() {
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
(model.bytes.clone().unwrap(), model.style_vectors.clone())
|
||||||
|
};
|
||||||
|
self.unload(ident.clone());
|
||||||
|
let s = model::load_model(&bytes, false)?;
|
||||||
|
if let Some(max) = self.max_loaded_models {
|
||||||
|
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
|
||||||
|
self.unload(self.models.first().unwrap().ident.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.models.push(TTSModel {
|
||||||
|
bytes: Some(bytes.to_vec()),
|
||||||
|
vits2: Some(s),
|
||||||
|
style_vectors,
|
||||||
|
ident: ident.clone(),
|
||||||
|
});
|
||||||
|
let model = self
|
||||||
|
.models
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.ident == ident)
|
||||||
|
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
|
||||||
|
if model.vits2.is_some() {
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
Err(Error::ModelNotFoundError(ident.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Get style vector by style id and weight
|
/// Get style vector by style id and weight
|
||||||
///
|
///
|
||||||
@@ -167,12 +225,18 @@ impl TTSModelHolder {
|
|||||||
/// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
|
/// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
|
||||||
/// ```
|
/// ```
|
||||||
pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
|
pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
|
||||||
&self,
|
&mut self,
|
||||||
ident: I,
|
ident: I,
|
||||||
text: &str,
|
text: &str,
|
||||||
style_id: i32,
|
style_id: i32,
|
||||||
options: SynthesizeOptions,
|
options: SynthesizeOptions,
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
|
self.find_and_load_model(ident)?;
|
||||||
|
let vits2 = &self
|
||||||
|
.find_model(ident)?
|
||||||
|
.vits2
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
|
||||||
let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
|
let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
|
||||||
let audio_array = if options.split_sentences {
|
let audio_array = if options.split_sentences {
|
||||||
let texts: Vec<&str> = text.split('\n').collect();
|
let texts: Vec<&str> = text.split('\n').collect();
|
||||||
@@ -183,7 +247,7 @@ impl TTSModelHolder {
|
|||||||
}
|
}
|
||||||
let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
|
let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
|
||||||
let audio = model::synthesize(
|
let audio = model::synthesize(
|
||||||
&self.find_model(ident)?.vits2,
|
&vits2,
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones,
|
phones,
|
||||||
tones,
|
tones,
|
||||||
@@ -204,7 +268,7 @@ impl TTSModelHolder {
|
|||||||
} else {
|
} else {
|
||||||
let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
|
let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
|
||||||
model::synthesize(
|
model::synthesize(
|
||||||
&self.find_model(ident)?.vits2,
|
&vits2,
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones,
|
phones,
|
||||||
tones,
|
tones,
|
||||||
@@ -222,8 +286,8 @@ impl TTSModelHolder {
|
|||||||
/// # Note
|
/// # Note
|
||||||
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
|
/// This function is for low-level usage, use `easy_synthesize` for high-level usage.
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn synthesize<I: Into<TTSIdent>>(
|
pub fn synthesize<I: Into<TTSIdent> + Copy>(
|
||||||
&self,
|
&mut self,
|
||||||
ident: I,
|
ident: I,
|
||||||
bert_ori: Array2<f32>,
|
bert_ori: Array2<f32>,
|
||||||
phones: Array1<i64>,
|
phones: Array1<i64>,
|
||||||
@@ -233,8 +297,14 @@ impl TTSModelHolder {
|
|||||||
sdp_ratio: f32,
|
sdp_ratio: f32,
|
||||||
length_scale: f32,
|
length_scale: f32,
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
|
self.find_and_load_model(ident)?;
|
||||||
|
let vits2 = &self
|
||||||
|
.find_model(ident)?
|
||||||
|
.vits2
|
||||||
|
.as_ref()
|
||||||
|
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
|
||||||
let audio_array = model::synthesize(
|
let audio_array = model::synthesize(
|
||||||
&self.find_model(ident)?.vits2,
|
&vits2,
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones,
|
phones,
|
||||||
tones,
|
tones,
|
||||||
|
|||||||
Reference in New Issue
Block a user