support directml

This commit is contained in:
tuna2134
2024-09-11 10:10:09 +00:00
parent ab7fc0b4da
commit bb7913e0fe
2 changed files with 7 additions and 1 deletions

View File

@@ -21,4 +21,5 @@ tokenizers = "0.20.0"
[features]
cuda = ["ort/cuda"]
cuda_tf32 = []
dynamic = ["ort/load-dynamic"]
dynamic = ["ort/load-dynamic"]
directml = ["ort/directml"]

View File

@@ -17,6 +17,10 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
}
exp.push(cuda.build());
}
#[cfg(feature = "directml")]
{
exp.push(ort::DirectMLExecutionProvider::default().build());
}
exp.push(ort::CPUExecutionProvider::default().build());
Ok(Session::builder()?
.with_execution_providers(exp)?
@@ -26,6 +30,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
.with_inter_threads(num_cpus::get_physical())?
.commit_from_memory(model_file.as_ref())?)
}
#[allow(clippy::too_many_arguments)]
pub fn synthesize(
session: &Session,