From 250f82966621d99ee8b7a511a0fb48c499790167 Mon Sep 17 00:00:00 2001 From: Mikko Rautiainen Date: Mon, 27 May 2024 14:52:15 +0300 Subject: [PATCH] simple CPU-only inference --- src/lib.rs | 4 ++-- src/main.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6620143..3f18b08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,7 +100,7 @@ pub fn error_correction( cluster_path: &str, threads: usize, window_size: u32, - devices: Vec, + devices: Vec, batch_size: usize, aln_mode: AlnMode, ) where @@ -160,7 +160,7 @@ pub fn error_correction( s.spawn(move || { inference_worker( model_path, - tch::Device::Cuda(device), + if device == "cpu" { tch::Device::Cpu } else { tch::Device::Cuda(device.parse::().unwrap()) }, infer_recv, cons_sender, ) diff --git a/src/main.rs b/src/main.rs index 90e24de..d8de5ee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -87,9 +87,9 @@ struct InferenceArgs { short = 'd', value_delimiter = ',', default_value = "0", - help = "List of cuda devices in format d0,d1... (e.g 0,1,3) (default 0)" + help = "List of cuda devices in format d0,d1... (e.g 0,1,3), or cpu to use CPU (default 0)" )] - devices: Vec, + devices: Vec, #[arg( short = 'b',