support aivmx

This commit is contained in:
tuna2134
2024-11-20 01:42:04 +00:00
parent f90904a337
commit 9c9119a107
4 changed files with 150 additions and 1 deletions

View File

@@ -21,6 +21,9 @@ pub enum Error {
HoundError(#[from] hound::Error),
#[error("model not found error")]
ModelNotFoundError(String),
#[cfg(feature = "base64")]
#[error("base64 error")]
Base64Error(#[from] base64::DecodeError),
#[error("other")]
OtherError(String),
}

View File

@@ -3,6 +3,12 @@ use crate::{jtalk, model, style, tokenizer, tts_util};
use ndarray::{concatenate, Array1, Array2, Array3, Axis};
use ort::Session;
use tokenizers::Tokenizer;
#[cfg(feature = "aivmx")]
use base64::prelude::{Engine as _, BASE64_STANDARD};
#[cfg(feature = "aivmx")]
use std::io::Cursor;
#[cfg(feature = "aivmx")]
use ndarray::ShapeBuilder;
#[derive(PartialEq, Eq, Clone)]
pub struct TTSIdent(String);
@@ -69,6 +75,55 @@ impl TTSModelHolder {
self.models.iter().map(|m| m.ident.to_string()).collect()
}
#[cfg(feature = "aivmx")]
pub fn load_aivmx<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
aivmx_bytes: P
) -> Result<()> {
let ident = ident.into();
if self.find_model(ident.clone()).is_err() {
let mut load = true;
if let Some(max) = self.max_loaded_models {
if self.models.iter().filter(|x| x.vits2.is_some()).count() >= max {
load = false;
}
}
let model = model::load_model(aivmx_bytes, false)?;
let metadata = model.metadata()?;
if let Some(aivm_style_vectors) = metadata.custom("aivm_style_vectors")? {
let style_vectors = Cursor::new(&BASE64_STANDARD.decode(aivm_style_vectors)?);
let reader = npyz::NpyFile::new(style_vectors)?;
let style_vectors = {
let shape = reader.shape().to_vec();
let order = reader.order();
let data = reader.into_vec::<f32>()?;
let shape = match shape[..] {
[i1, i2] => [i1 as usize, i2 as usize],
_ => panic!("expected 2D array"),
};
let true_shape = shape.set_f(order == npyz::Order::Fortran);
ndarray::Array2::from_shape_vec(true_shape, data)?
};
self.models.push(TTSModel {
vits2: if load {
Some(model)
} else {
None
},
bytes: if self.max_loaded_models.is_some() {
Some(aivmx_bytes.as_ref().to_vec())
} else {
None
},
ident,
style_vectors,
})
}
}
Ok(())
}
/// Load a .sbv2 file binary
///
/// # Examples