This commit is contained in:
tuna2134
2024-11-20 02:14:59 +00:00
parent 9c9119a107
commit 2eda2fe9ca

View File

@@ -1,14 +1,14 @@
use crate::error::{Error, Result};
use crate::{jtalk, model, style, tokenizer, tts_util};
use ndarray::{concatenate, Array1, Array2, Array3, Axis};
use ort::Session;
use tokenizers::Tokenizer;
#[cfg(feature = "aivmx")]
use base64::prelude::{Engine as _, BASE64_STANDARD};
#[cfg(feature = "aivmx")]
use std::io::Cursor;
#[cfg(feature = "aivmx")]
use ndarray::ShapeBuilder;
use ndarray::{concatenate, Array1, Array2, Array3, Axis};
use ort::Session;
#[cfg(feature = "aivmx")]
use std::io::Cursor;
use tokenizers::Tokenizer;
#[derive(PartialEq, Eq, Clone)]
pub struct TTSIdent(String);
@@ -79,7 +79,7 @@ impl TTSModelHolder {
pub fn load_aivmx<I: Into<TTSIdent>, P: AsRef<[u8]>>(
&mut self,
ident: I,
aivmx_bytes: P
aivmx_bytes: P,
) -> Result<()> {
let ident = ident.into();
if self.find_model(ident.clone()).is_err() {
@@ -89,10 +89,11 @@ impl TTSModelHolder {
load = false;
}
}
let model = model::load_model(aivmx_bytes, false)?;
let model = model::load_model(&aivmx_bytes, false)?;
let metadata = model.metadata()?;
if let Some(aivm_style_vectors) = metadata.custom("aivm_style_vectors")? {
let style_vectors = Cursor::new(&BASE64_STANDARD.decode(aivm_style_vectors)?);
let aivm_style_vectors = BASE64_STANDARD.decode(aivm_style_vectors)?;
let style_vectors = Cursor::new(&aivm_style_vectors);
let reader = npyz::NpyFile::new(style_vectors)?;
let style_vectors = {
let shape = reader.shape().to_vec();
@@ -100,17 +101,14 @@ impl TTSModelHolder {
let data = reader.into_vec::<f32>()?;
let shape = match shape[..] {
[i1, i2] => [i1 as usize, i2 as usize],
_ => panic!("expected 2D array"),
_ => panic!("expected 2D array"),
};
let true_shape = shape.set_f(order == npyz::Order::Fortran);
ndarray::Array2::from_shape_vec(true_shape, data)?
};
drop(metadata);
self.models.push(TTSModel {
vits2: if load {
Some(model)
} else {
None
},
vits2: if load { Some(model) } else { None },
bytes: if self.max_loaded_models.is_some() {
Some(aivmx_bytes.as_ref().to_vec())
} else {