diff --git a/sbv2_core/src/bert.rs b/sbv2_core/src/bert.rs index 3bd8396..cf0bdcf 100644 --- a/sbv2_core/src/bert.rs +++ b/sbv2_core/src/bert.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use ndarray::Array2; +use ndarray::{Array2, Ix3}; use ort::Session; pub fn predict( @@ -14,10 +14,10 @@ pub fn predict( }? )?; - let output = outputs.get("output").unwrap(); + let output = outputs["output"] + .try_extract_tensor::()? + .into_dimensionality::()? + .to_owned(); - let content = output.try_extract_tensor::()?.to_owned(); - let (data, _) = content.clone().into_raw_vec_and_offset(); - - Ok(Array2::from_shape_vec((content.shape()[0], content.shape()[1]), data).unwrap()) + Ok(output) } diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 485c3f7..9f2a221 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use ndarray::{array, Array1, Array2, Array3, Axis}; +use ndarray::{array, Array1, Array2, Array3, Axis, Ix3}; use ort::{GraphOptimizationLevel, Session}; #[allow(clippy::vec_init_then_push, unused_variables)] @@ -76,18 +76,10 @@ pub fn synthesize( "length_scale" => array![length_scale], }?)?; - let audio_array = outputs - .get("output") - .unwrap() + let audio_array = outputs["output"] .try_extract_tensor::()? + .into_dimensionality::()? .to_owned(); - Ok(Array3::from_shape_vec( - ( - audio_array.shape()[0], - audio_array.shape()[1], - audio_array.shape()[2], - ), - audio_array.into_raw_vec_and_offset().0, - )?) + Ok(audio_array) }