mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-01-11 00:42:57 +00:00
support directml
This commit is contained in:
@@ -21,4 +21,5 @@ tokenizers = "0.20.0"
|
||||
[features]
|
||||
cuda = ["ort/cuda"]
|
||||
cuda_tf32 = []
|
||||
dynamic = ["ort/load-dynamic"]
|
||||
dynamic = ["ort/load-dynamic"]
|
||||
directml = ["ort/directml"]
|
||||
@@ -17,6 +17,10 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
|
||||
}
|
||||
exp.push(cuda.build());
|
||||
}
|
||||
#[cfg(feature = "directml")]
|
||||
{
|
||||
exp.push(ort::DirectMLExecutionProvider::default().build());
|
||||
}
|
||||
exp.push(ort::CPUExecutionProvider::default().build());
|
||||
Ok(Session::builder()?
|
||||
.with_execution_providers(exp)?
|
||||
@@ -26,6 +30,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
|
||||
.with_inter_threads(num_cpus::get_physical())?
|
||||
.commit_from_memory(model_file.as_ref())?)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn synthesize(
|
||||
session: &Session,
|
||||
|
||||
Reference in New Issue
Block a user