diff --git a/sbv2_core/src/bert.rs b/sbv2_core/src/bert.rs index cf0bdcf..0710bf6 100644 --- a/sbv2_core/src/bert.rs +++ b/sbv2_core/src/bert.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use ndarray::{Array2, Ix3}; +use ndarray::{Array2, Ix2}; use ort::Session; pub fn predict( @@ -16,7 +16,7 @@ pub fn predict( let output = outputs["output"] .try_extract_tensor::()? - .into_dimensionality::()? + .into_dimensionality::()? .to_owned(); Ok(output)