wip: max loaded models

This commit is contained in:
Googlefan
2024-11-06 10:43:41 +00:00
parent 380daf479c
commit 14d631eeaa
4 changed files with 108 additions and 16 deletions

View File

@@ -69,7 +69,7 @@ async fn synthesize(
) -> AppResult<impl IntoResponse> { ) -> AppResult<impl IntoResponse> {
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}"); log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
let buffer = { let buffer = {
let tts_model = state.tts_model.lock().await; let mut tts_model = state.tts_model.lock().await;
tts_model.easy_synthesize( tts_model.easy_synthesize(
&ident, &ident,
&text, &text,
@@ -94,6 +94,9 @@ impl AppState {
let mut tts_model = TTSModelHolder::new( let mut tts_model = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?, &fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?, &fs::read(env::var("TOKENIZER_PATH")?).await?,
env::var("HOLDER_MAX_LOADED_MODElS")
.ok()
.and_then(|x| x.parse().ok()),
)?; )?;
let models = env::var("MODELS_PATH").unwrap_or("models".to_string()); let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
let mut f = fs::read_dir(&models).await?; let mut f = fs::read_dir(&models).await?;

View File

@@ -23,10 +23,15 @@ pub struct TTSModel {
#[pymethods] #[pymethods]
impl TTSModel { impl TTSModel {
#[pyo3(signature = (bert_model_bytes, tokenizer_bytes, max_loaded_models=None))]
#[new] #[new]
fn new(bert_model_bytes: Vec<u8>, tokenizer_bytes: Vec<u8>) -> anyhow::Result<Self> { fn new(
bert_model_bytes: Vec<u8>,
tokenizer_bytes: Vec<u8>,
max_loaded_models: Option<usize>,
) -> anyhow::Result<Self> {
Ok(Self { Ok(Self {
model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes)?, model: TTSModelHolder::new(bert_model_bytes, tokenizer_bytes, max_loaded_models)?,
}) })
} }
@@ -38,10 +43,21 @@ impl TTSModel {
/// BERTモデルのパス /// BERTモデルのパス
/// tokenizer_path : str /// tokenizer_path : str
/// トークナイザーのパス /// トークナイザーのパス
/// max_loaded_models: int | None
/// 同時にVRAMに存在するモデルの数
#[pyo3(signature = (bert_model_path, tokenizer_path, max_loaded_models=None))]
#[staticmethod] #[staticmethod]
fn from_path(bert_model_path: String, tokenizer_path: String) -> anyhow::Result<Self> { fn from_path(
bert_model_path: String,
tokenizer_path: String,
max_loaded_models: Option<usize>,
) -> anyhow::Result<Self> {
Ok(Self { Ok(Self {
model: TTSModelHolder::new(fs::read(bert_model_path)?, fs::read(tokenizer_path)?)?, model: TTSModelHolder::new(
fs::read(bert_model_path)?,
fs::read(tokenizer_path)?,
max_loaded_models,
)?,
}) })
} }
@@ -121,7 +137,7 @@ impl TTSModel {
/// voice_data : bytes /// voice_data : bytes
/// 音声データ /// 音声データ
fn synthesize<'p>( fn synthesize<'p>(
&'p self, &'p mut self,
py: Python<'p>, py: Python<'p>,
text: String, text: String,
ident: String, ident: String,

View File

@@ -11,6 +11,9 @@ fn main_inner() -> anyhow::Result<()> {
let mut tts_holder = tts::TTSModelHolder::new( let mut tts_holder = tts::TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?)?, &fs::read(env::var("BERT_MODEL_PATH")?)?,
&fs::read(env::var("TOKENIZER_PATH")?)?, &fs::read(env::var("TOKENIZER_PATH")?)?,
env::var("HOLDER_MAX_LOADED_MODElS")
.ok()
.and_then(|x| x.parse().ok()),
)?; )?;
tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?; tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;

View File

@@ -24,9 +24,10 @@ where
} }
pub struct TTSModel { pub struct TTSModel {
vits2: Session, vits2: Option<Session>,
style_vectors: Array2<f32>, style_vectors: Array2<f32>,
ident: TTSIdent, ident: TTSIdent,
bytes: Option<Vec<u8>>,
} }
/// High-level Style-Bert-VITS2's API /// High-level Style-Bert-VITS2's API
@@ -35,6 +36,7 @@ pub struct TTSModelHolder {
bert: Session, bert: Session,
models: Vec<TTSModel>, models: Vec<TTSModel>,
jtalk: jtalk::JTalk, jtalk: jtalk::JTalk,
max_loaded_models: Option<usize>,
} }
impl TTSModelHolder { impl TTSModelHolder {
@@ -43,9 +45,13 @@ impl TTSModelHolder {
/// # Examples /// # Examples
/// ///
/// ```rs /// ```rs
/// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?)?; /// let mut tts_holder = TTSModelHolder::new(std::fs::read("deberta.onnx")?, std::fs::read("tokenizer.json")?, None)?;
/// ``` /// ```
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> { pub fn new<P: AsRef<[u8]>>(
bert_model_bytes: P,
tokenizer_bytes: P,
max_loaded_models: Option<usize>,
) -> Result<Self> {
let bert = model::load_model(bert_model_bytes, true)?; let bert = model::load_model(bert_model_bytes, true)?;
let jtalk = jtalk::JTalk::new()?; let jtalk = jtalk::JTalk::new()?;
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?; let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
@@ -54,6 +60,7 @@ impl TTSModelHolder {
models: vec![], models: vec![],
jtalk, jtalk,
tokenizer, tokenizer,
max_loaded_models,
}) })
} }
@@ -94,10 +101,25 @@ impl TTSModelHolder {
) -> Result<()> { ) -> Result<()> {
let ident = ident.into(); let ident = ident.into();
if self.find_model(ident.clone()).is_err() { if self.find_model(ident.clone()).is_err() {
let mut load = true;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
load = false;
}
}
self.models.push(TTSModel { self.models.push(TTSModel {
vits2: model::load_model(vits2_bytes, false)?, vits2: if load {
Some(model::load_model(&vits2_bytes, false)?)
} else {
None
},
style_vectors: style::load_style(style_vectors_bytes)?, style_vectors: style::load_style(style_vectors_bytes)?,
ident, ident,
bytes: if self.max_loaded_models.is_some() {
Some(vits2_bytes.as_ref().to_vec())
} else {
None
},
}) })
} }
Ok(()) Ok(())
@@ -145,6 +167,42 @@ impl TTSModelHolder {
.find(|m| m.ident == ident) .find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string())) .ok_or(Error::ModelNotFoundError(ident.to_string()))
} }
fn find_and_load_model<I: Into<TTSIdent>>(&mut self, ident: I) -> Result<bool> {
let ident = ident.into();
let (bytes, style_vectors) = {
let model = self
.models
.iter()
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
if model.vits2.is_some() {
return Ok(true);
}
(model.bytes.clone().unwrap(), model.style_vectors.clone())
};
self.unload(ident.clone());
let s = model::load_model(&bytes, false)?;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
self.unload(self.models.first().unwrap().ident.clone());
}
}
self.models.push(TTSModel {
bytes: Some(bytes.to_vec()),
vits2: Some(s),
style_vectors,
ident: ident.clone(),
});
let model = self
.models
.iter()
.find(|m| m.ident == ident)
.ok_or(Error::ModelNotFoundError(ident.to_string()))?;
if model.vits2.is_some() {
return Ok(true);
}
Err(Error::ModelNotFoundError(ident.to_string()))
}
/// Get style vector by style id and weight /// Get style vector by style id and weight
/// ///
@@ -167,12 +225,18 @@ impl TTSModelHolder {
/// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?; /// let audio = tts_holder.easy_synthesize("tsukuyomi", "こんにちは", 0, SynthesizeOptions::default())?;
/// ``` /// ```
pub fn easy_synthesize<I: Into<TTSIdent> + Copy>( pub fn easy_synthesize<I: Into<TTSIdent> + Copy>(
&self, &mut self,
ident: I, ident: I,
text: &str, text: &str,
style_id: i32, style_id: i32,
options: SynthesizeOptions, options: SynthesizeOptions,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
self.find_and_load_model(ident)?;
let vits2 = &self
.find_model(ident)?
.vits2
.as_ref()
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?; let style_vector = self.get_style_vector(ident, style_id, options.style_weight)?;
let audio_array = if options.split_sentences { let audio_array = if options.split_sentences {
let texts: Vec<&str> = text.split('\n').collect(); let texts: Vec<&str> = text.split('\n').collect();
@@ -183,7 +247,7 @@ impl TTSModelHolder {
} }
let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?; let (bert_ori, phones, tones, lang_ids) = self.parse_text(t)?;
let audio = model::synthesize( let audio = model::synthesize(
&self.find_model(ident)?.vits2, &vits2,
bert_ori.to_owned(), bert_ori.to_owned(),
phones, phones,
tones, tones,
@@ -204,7 +268,7 @@ impl TTSModelHolder {
} else { } else {
let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?; let (bert_ori, phones, tones, lang_ids) = self.parse_text(text)?;
model::synthesize( model::synthesize(
&self.find_model(ident)?.vits2, &vits2,
bert_ori.to_owned(), bert_ori.to_owned(),
phones, phones,
tones, tones,
@@ -222,8 +286,8 @@ impl TTSModelHolder {
/// # Note /// # Note
/// This function is for low-level usage, use `easy_synthesize` for high-level usage. /// This function is for low-level usage, use `easy_synthesize` for high-level usage.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn synthesize<I: Into<TTSIdent>>( pub fn synthesize<I: Into<TTSIdent> + Copy>(
&self, &mut self,
ident: I, ident: I,
bert_ori: Array2<f32>, bert_ori: Array2<f32>,
phones: Array1<i64>, phones: Array1<i64>,
@@ -233,8 +297,14 @@ impl TTSModelHolder {
sdp_ratio: f32, sdp_ratio: f32,
length_scale: f32, length_scale: f32,
) -> Result<Vec<u8>> { ) -> Result<Vec<u8>> {
self.find_and_load_model(ident)?;
let vits2 = &self
.find_model(ident)?
.vits2
.as_ref()
.ok_or(Error::ModelNotFoundError(ident.into().to_string()))?;
let audio_array = model::synthesize( let audio_array = model::synthesize(
&self.find_model(ident)?.vits2, &vits2,
bert_ori.to_owned(), bert_ori.to_owned(),
phones, phones,
tones, tones,