This commit is contained in:
tuna2134
2024-10-18 13:32:35 +00:00
parent c312fb0ce4
commit c4005808bd
2 changed files with 10 additions and 18 deletions

View File

@@ -1,5 +1,5 @@
use crate::error::Result;
use ndarray::Array2;
use ndarray::{Array2, Ix3};
use ort::Session;
pub fn predict(
@@ -14,10 +14,10 @@ pub fn predict(
}?
)?;
let output = outputs.get("output").unwrap();
let output = outputs["output"]
.try_extract_tensor::<f32>()?
.into_dimensionality::<Ix3>()?
.to_owned();
let content = output.try_extract_tensor::<f32>()?.to_owned();
let (data, _) = content.clone().into_raw_vec_and_offset();
Ok(Array2::from_shape_vec((content.shape()[0], content.shape()[1]), data).unwrap())
Ok(output)
}

View File

@@ -1,5 +1,5 @@
use crate::error::Result;
use ndarray::{array, Array1, Array2, Array3, Axis};
use ndarray::{array, Array1, Array2, Array3, Axis, Ix3};
use ort::{GraphOptimizationLevel, Session};
#[allow(clippy::vec_init_then_push, unused_variables)]
@@ -76,18 +76,10 @@ pub fn synthesize(
"length_scale" => array![length_scale],
}?)?;
let audio_array = outputs
.get("output")
.unwrap()
let audio_array = outputs["output"]
.try_extract_tensor::<f32>()?
.into_dimensionality::<Ix3>()?
.to_owned();
Ok(Array3::from_shape_vec(
(
audio_array.shape()[0],
audio_array.shape()[1],
audio_array.shape()[2],
),
audio_array.into_raw_vec_and_offset().0,
)?)
Ok(audio_array)
}