diff --git a/crates/sbv2_core/src/tts.rs b/crates/sbv2_core/src/tts.rs index eb6b3b7..07345c6 100644 --- a/crates/sbv2_core/src/tts.rs +++ b/crates/sbv2_core/src/tts.rs @@ -240,39 +240,43 @@ impl TTSModelHolder { } fn find_and_load_model>(&mut self, ident: I) -> Result { let ident = ident.into(); - let (bytes, style_vectors) = { - let model = self - .models - .iter() - .find(|m| m.ident == ident) - .ok_or(Error::ModelNotFoundError(ident.to_string()))?; - if model.vits2.is_some() { - return Ok(true); - } - (model.bytes.clone().unwrap(), model.style_vectors.clone()) - }; - self.unload(ident.clone()); - let s = model::load_model(&bytes, false)?; - if let Some(max) = self.max_loaded_models { - if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max { - self.unload(self.models.first().unwrap().ident.clone()); - } - } - self.models.push(TTSModel { - bytes: Some(bytes.to_vec()), - vits2: Some(s), - style_vectors, - ident: ident.clone(), - }); - let model = self + // Locate target model entry + let target_index = self .models .iter() - .find(|m| m.ident == ident) + .position(|m| m.ident == ident) .ok_or(Error::ModelNotFoundError(ident.to_string()))?; - if model.vits2.is_some() { + + // Already loaded + if self.models[target_index].vits2.is_some() { return Ok(true); } - Err(Error::ModelNotFoundError(ident.to_string())) + + // Get bytes to build a Session + let bytes = self.models[target_index] + .bytes + .clone() + .ok_or(Error::ModelNotFoundError(ident.to_string()))?; + + // Enforce max loaded models by evicting a different loaded model's session, not removing the entry + if let Some(max) = self.max_loaded_models { + let loaded_count = self.models.iter().filter(|m| m.vits2.is_some()).count(); + if loaded_count >= max { + if let Some(evict_index) = self + .models + .iter() + .position(|m| m.vits2.is_some() && m.ident != ident) + { + // Drop only the session to free memory; keep bytes/style for future reload + self.models[evict_index].vits2 = None; + } + } + } + + // Build and set session in-place for the target model + let s = model::load_model(&bytes, false)?; + self.models[target_index].vits2 = Some(s); + Ok(true) } /// Get style vector by style id and weight