From 21b768373a234ae17505d84069b345787802d690 Mon Sep 17 00:00:00 2001 From: Ferran Borreguero Date: Sun, 11 Jan 2026 17:00:45 +0800 Subject: [PATCH] Make fetcher async --- Cargo.lock | 21 +++++ crates/fetcher/Cargo.toml | 10 +- crates/fetcher/bin/main.rs | 7 +- crates/fetcher/src/lib.rs | 184 +++++++++++++++++++++++++++---------- 4 files changed, 172 insertions(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a719a67..0e9fd42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1772,17 +1772,22 @@ dependencies = [ name = "fetcher" version = "0.1.0" dependencies = [ + "bytes", "clap", "eyre", "flate2", + "futures-util", + "pin-project-lite", "reqwest", "serde", "serde_json", "sha2", "tar", + "tokio", "tracing", "tracing-subscriber", "url", + "uuid", ] [[package]] @@ -3497,12 +3502,14 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -4716,6 +4723,7 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ + "getrandom 0.3.4", "js-sys", "wasm-bindgen", ] @@ -4829,6 +4837,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmtimer" version = "0.4.3" diff --git a/crates/fetcher/Cargo.toml b/crates/fetcher/Cargo.toml index aa21943..0f1bc2c 100644 --- a/crates/fetcher/Cargo.toml +++ b/crates/fetcher/Cargo.toml @@ -9,7 +9,12 @@ path = "bin/main.rs" [dependencies] clap = { workspace = true, features = ["derive"] } -reqwest = { workspace = true, features = ["blocking", "rustls-tls"] } +reqwest = { workspace = true, features = ["rustls-tls", "stream"] } +tokio = { workspace = true, features = ["io-util", "fs"] } +futures-util = "0.3" +pin-project-lite = "0.2" +bytes = "1.0" +uuid = { version = "1.0", features = ["v4"] } eyre = { workspace = true } url = "2.5" flate2 = "1.0" @@ -19,3 +24,6 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt", "macros"] } diff --git a/crates/fetcher/bin/main.rs b/crates/fetcher/bin/main.rs index 8f78f50..c888561 100644 --- a/crates/fetcher/bin/main.rs +++ b/crates/fetcher/bin/main.rs @@ -22,7 +22,8 @@ struct Args { skip_if_valid_checksum: bool, } -fn main() { +#[tokio::main] +async fn main() { tracing_subscriber::fmt() .with_max_level(tracing::Level::INFO) .init(); @@ -61,7 +62,9 @@ fn main() { &args.destination, &mut progress, args.checksum, - ) { + ) + .await + { tracing::error!("Failed to fetch: {}", e); // Print the error chain diff --git a/crates/fetcher/src/lib.rs b/crates/fetcher/src/lib.rs index 76d739f..c2cbe14 100644 --- a/crates/fetcher/src/lib.rs +++ b/crates/fetcher/src/lib.rs @@ -1,11 +1,16 @@ use eyre::{Context, Result}; use flate2::read::GzDecoder; +use futures_util::StreamExt; +use pin_project_lite::pin_project; use sha2::{Digest, Sha256}; use std::fs::File; -use std::io::Read; use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context as TaskContext, Poll}; use tar::Archive; +use tokio::io::{AsyncRead, ReadBuf}; use url::Url; +use uuid::Uuid; #[derive(Debug, Clone, Copy, PartialEq)] enum ArchiveFormat { @@ -90,11 +95,11 @@ impl ProgressTracker for ConsoleProgressTracker { } } -pub fn fetch(source: &str, destination: &PathBuf, checksum: Option) -> Result<()> { - fetch_with_progress(source, destination, &mut NoOpProgressTracker, checksum) +pub async fn fetch(source: &str, destination: &PathBuf, checksum: Option) -> Result<()> { + fetch_with_progress(source, destination, &mut NoOpProgressTracker, checksum).await } -pub fn fetch_with_progress( +pub async fn fetch_with_progress( source: &str, destination: &PathBuf, progress: &mut T, @@ -105,21 +110,18 @@ pub fn fetch_with_progress( Url::parse(source).with_context(|| format!("Failed to parse source as URL: {}", source))?; match url.scheme() { - "http" | "https" => fetch_http(&url, destination, progress), + "http" | "https" => fetch_http(&url, destination, progress, checksum).await, scheme => eyre::bail!("Unsupported URL scheme: {}", scheme), }?; - if let Some(checksum) = checksum { - verify_checksum(destination, &checksum)? - } - Ok(()) } -fn fetch_http( +async fn fetch_http( url: &Url, destination: &PathBuf, progress: &mut T, + checksum: Option, ) -> Result<()> { // Detect if the URL points to an archive let archive_format = ArchiveFormat::detect(url); @@ -131,7 +133,8 @@ fn fetch_http( } // Download the file - let response = reqwest::blocking::get(url.as_str()) + let response = reqwest::get(url.as_str()) + .await .with_context(|| format!("Failed to download from: {}", url))?; if !response.status().is_success() { @@ -143,25 +146,42 @@ fn fetch_http( progress.set_total(total); } - // Create a progress reader wrapper - let mut progress_reader = ProgressReader::new(response, progress); - - match archive_format { + // Determine download path: use temp file for tar.gz, direct destination otherwise + let download_path = match archive_format { ArchiveFormat::TarGz => { - tracing::info!("Detected tar.gz archive, streaming decompression..."); - extract_tar_gz(&mut progress_reader, destination)?; + let temp_dir = std::env::temp_dir(); + temp_dir.join(format!("fetcher_{}.tar.gz", Uuid::new_v4())) } - ArchiveFormat::None => { - // Standard file download - let mut file = File::create(destination) - .with_context(|| format!("Failed to create file: {}", destination.display()))?; + ArchiveFormat::None => destination.clone(), + }; - std::io::copy(&mut progress_reader, &mut file).context("Failed to write file")?; - } - } + // Create a progress reader wrapper and download to file + let mut progress_reader = AsyncProgressReader::new(response.bytes_stream(), progress); + let mut file = tokio::fs::File::create(&download_path) + .await + .with_context(|| format!("Failed to create file: {}", download_path.display()))?; + + tokio::io::copy(&mut progress_reader, &mut file) + .await + .context("Failed to write file")?; progress_reader.finish(); + // Verify checksum if provided + if let Some(expected_checksum) = checksum { + verify_checksum(&download_path, &expected_checksum)?; + } + + // Extract if tar.gz + if archive_format == ArchiveFormat::TarGz { + tracing::info!("Extracting tar.gz archive..."); + extract_tar_gz_from_file(&download_path, destination)?; + // Clean up temp file + tokio::fs::remove_file(&download_path) + .await + .with_context(|| format!("Failed to remove temp file: {}", download_path.display()))?; + } + Ok(()) } @@ -194,19 +214,29 @@ pub fn verify_checksum(file_path: &Path, expected_checksum: &str) -> Result<()> Ok(()) } -/// A reader wrapper that tracks progress -struct ProgressReader<'a, R: Read, T: ProgressTracker> { - inner: R, - progress: &'a mut T, - downloaded: u64, +// An async reader wrapper that tracks progress +pin_project! { + struct AsyncProgressReader<'a, S, T: ProgressTracker> { + #[pin] + stream: S, + progress: &'a mut T, + downloaded: u64, + buffer: Vec, + buffer_pos: usize, + } } -impl<'a, R: Read, T: ProgressTracker> ProgressReader<'a, R, T> { - fn new(inner: R, progress: &'a mut T) -> Self { +impl<'a, S, T: ProgressTracker> AsyncProgressReader<'a, S, T> +where + S: futures_util::Stream> + Unpin, +{ + fn new(stream: S, progress: &'a mut T) -> Self { Self { - inner, + stream, progress, downloaded: 0, + buffer: Vec::new(), + buffer_pos: 0, } } @@ -215,18 +245,70 @@ impl<'a, R: Read, T: ProgressTracker> ProgressReader<'a, R, T> { } } -impl<'a, R: Read, T: ProgressTracker> Read for ProgressReader<'a, R, T> { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let bytes_read = self.inner.read(buf)?; - self.downloaded += bytes_read as u64; - self.progress.update(self.downloaded); - Ok(bytes_read) +impl<'a, S, T: ProgressTracker> AsyncRead for AsyncProgressReader<'a, S, T> +where + S: futures_util::Stream> + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut this = self.project(); + + // If we have buffered data, use it first + if *this.buffer_pos < this.buffer.len() { + let remaining = &this.buffer[*this.buffer_pos..]; + let to_copy = std::cmp::min(remaining.len(), buf.remaining()); + buf.put_slice(&remaining[..to_copy]); + *this.buffer_pos += to_copy; + + *this.downloaded += to_copy as u64; + this.progress.update(*this.downloaded); + + return Poll::Ready(Ok(())); + } + + // Clear the buffer if we've consumed it all + this.buffer.clear(); + *this.buffer_pos = 0; + + // Poll for next chunk from stream + match this.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(chunk))) => { + let chunk_len = chunk.len(); + if chunk_len == 0 { + return Poll::Ready(Ok(())); + } + + // Copy what we can directly into the output buffer + let to_copy = std::cmp::min(chunk_len, buf.remaining()); + buf.put_slice(&chunk[..to_copy]); + + *this.downloaded += to_copy as u64; + this.progress.update(*this.downloaded); + + // Store any remaining data in our buffer + if to_copy < chunk_len { + this.buffer.extend_from_slice(&chunk[to_copy..]); + } + + Poll::Ready(Ok(())) + } + Poll::Ready(Some(Err(e))) => { + Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))) + } + Poll::Ready(None) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } } } -/// Extract a tar.gz archive from a reader to a destination directory -fn extract_tar_gz(reader: R, destination: &Path) -> Result<()> { - let gz = GzDecoder::new(reader); +/// Extract a tar.gz archive from a file to a destination directory +fn extract_tar_gz_from_file(file_path: &Path, destination: &Path) -> Result<()> { + let file = File::open(file_path) + .with_context(|| format!("Failed to open tar.gz file: {}", file_path.display()))?; + let gz = GzDecoder::new(file); let mut archive = Archive::new(gz); // Extract to the destination directory @@ -295,8 +377,8 @@ mod tests { assert_eq!(actual_content, CONTENT_TXT); } - #[test] - fn test_download_content_txt() { + #[tokio::test] + async fn test_download_content_txt() { let filename = "content.txt"; let checksum = "3dc7bc0209231cc61cb7d09c2efdfdf7aacb1f0b098db150780e980fa10d6b7a"; @@ -312,7 +394,8 @@ mod tests { &destination, &mut progress, Some(checksum.to_string()), - ); + ) + .await; assert!( result.is_ok(), "Failed to download {} or verify checksum: {:?}", @@ -326,9 +409,10 @@ mod tests { let _ = fs::remove_file(&destination); } - #[test] - fn test_download_content_tar_gz() { + #[tokio::test] + async fn test_download_content_tar_gz() { let filename = "content.tar.gz"; + let checksum = "aa7d1aae79175b06c5529409d65f4794479c9b060381e059a8b6d1510fa2ae48"; let source = get_fixture_path(filename); let destination = PathBuf::from(format!("/tmp/fetcher_test_{}", filename)); @@ -337,7 +421,13 @@ mod tests { let mut progress = TestProgressTracker::new(); - let result = fetch_with_progress(&source, &destination, &mut progress, None); + let result = fetch_with_progress( + &source, + &destination, + &mut progress, + Some(checksum.to_string()), + ) + .await; assert!( result.is_ok(), "Failed to download {} or verify checksum: {:?}",