diff --git a/src/lib.rs b/src/lib.rs index 6620143..3f18b08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,7 +100,7 @@ pub fn error_correction( cluster_path: &str, threads: usize, window_size: u32, - devices: Vec, + devices: Vec, batch_size: usize, aln_mode: AlnMode, ) where @@ -160,7 +160,7 @@ pub fn error_correction( s.spawn(move || { inference_worker( model_path, - tch::Device::Cuda(device), + if device == "cpu" { tch::Device::Cpu } else { tch::Device::Cuda(device.parse::().unwrap()) }, infer_recv, cons_sender, ) diff --git a/src/main.rs b/src/main.rs index 90e24de..d8de5ee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -87,9 +87,9 @@ struct InferenceArgs { short = 'd', value_delimiter = ',', default_value = "0", - help = "List of cuda devices in format d0,d1... (e.g 0,1,3) (default 0)" + help = "List of cuda devices in format d0,d1... (e.g 0,1,3), or cpu to use CPU (default 0)" )] - devices: Vec, + devices: Vec, #[arg( short = 'b',