Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion crates/fetcher/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ path = "bin/main.rs"

[dependencies]
clap = { workspace = true, features = ["derive"] }
reqwest = { workspace = true, features = ["blocking", "rustls-tls"] }
reqwest = { workspace = true, features = ["rustls-tls", "stream"] }
tokio = { workspace = true, features = ["io-util", "fs"] }
futures-util = "0.3"
pin-project-lite = "0.2"
bytes = "1.0"
uuid = { version = "1.0", features = ["v4"] }
eyre = { workspace = true }
url = "2.5"
flate2 = "1.0"
Expand All @@ -19,3 +24,6 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }

[dev-dependencies]
tokio = { workspace = true, features = ["rt", "macros"] }
7 changes: 5 additions & 2 deletions crates/fetcher/bin/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct Args {
skip_if_valid_checksum: bool,
}

fn main() {
#[tokio::main]
async fn main() {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.init();
Expand Down Expand Up @@ -61,7 +62,9 @@ fn main() {
&args.destination,
&mut progress,
args.checksum,
) {
)
.await
{
tracing::error!("Failed to fetch: {}", e);

// Print the error chain
Expand Down
184 changes: 137 additions & 47 deletions crates/fetcher/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use eyre::{Context, Result};
use flate2::read::GzDecoder;
use futures_util::StreamExt;
use pin_project_lite::pin_project;
use sha2::{Digest, Sha256};
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::task::{Context as TaskContext, Poll};
use tar::Archive;
use tokio::io::{AsyncRead, ReadBuf};
use url::Url;
use uuid::Uuid;

#[derive(Debug, Clone, Copy, PartialEq)]
enum ArchiveFormat {
Expand Down Expand Up @@ -90,11 +95,11 @@ impl ProgressTracker for ConsoleProgressTracker {
}
}

pub fn fetch(source: &str, destination: &PathBuf, checksum: Option<String>) -> Result<()> {
fetch_with_progress(source, destination, &mut NoOpProgressTracker, checksum)
pub async fn fetch(source: &str, destination: &PathBuf, checksum: Option<String>) -> Result<()> {
fetch_with_progress(source, destination, &mut NoOpProgressTracker, checksum).await
}

pub fn fetch_with_progress<T: ProgressTracker>(
pub async fn fetch_with_progress<T: ProgressTracker>(
source: &str,
destination: &PathBuf,
progress: &mut T,
Expand All @@ -105,21 +110,18 @@ pub fn fetch_with_progress<T: ProgressTracker>(
Url::parse(source).with_context(|| format!("Failed to parse source as URL: {}", source))?;

match url.scheme() {
"http" | "https" => fetch_http(&url, destination, progress),
"http" | "https" => fetch_http(&url, destination, progress, checksum).await,
scheme => eyre::bail!("Unsupported URL scheme: {}", scheme),
}?;

if let Some(checksum) = checksum {
verify_checksum(destination, &checksum)?
}

Ok(())
}

fn fetch_http<T: ProgressTracker>(
async fn fetch_http<T: ProgressTracker>(
url: &Url,
destination: &PathBuf,
progress: &mut T,
checksum: Option<String>,
) -> Result<()> {
// Detect if the URL points to an archive
let archive_format = ArchiveFormat::detect(url);
Expand All @@ -131,7 +133,8 @@ fn fetch_http<T: ProgressTracker>(
}

// Download the file
let response = reqwest::blocking::get(url.as_str())
let response = reqwest::get(url.as_str())
.await
.with_context(|| format!("Failed to download from: {}", url))?;

if !response.status().is_success() {
Expand All @@ -143,25 +146,42 @@ fn fetch_http<T: ProgressTracker>(
progress.set_total(total);
}

// Create a progress reader wrapper
let mut progress_reader = ProgressReader::new(response, progress);

match archive_format {
// Determine download path: use temp file for tar.gz, direct destination otherwise
let download_path = match archive_format {
ArchiveFormat::TarGz => {
tracing::info!("Detected tar.gz archive, streaming decompression...");
extract_tar_gz(&mut progress_reader, destination)?;
let temp_dir = std::env::temp_dir();
temp_dir.join(format!("fetcher_{}.tar.gz", Uuid::new_v4()))
}
ArchiveFormat::None => {
// Standard file download
let mut file = File::create(destination)
.with_context(|| format!("Failed to create file: {}", destination.display()))?;
ArchiveFormat::None => destination.clone(),
};

std::io::copy(&mut progress_reader, &mut file).context("Failed to write file")?;
}
}
// Create a progress reader wrapper and download to file
let mut progress_reader = AsyncProgressReader::new(response.bytes_stream(), progress);
let mut file = tokio::fs::File::create(&download_path)
.await
.with_context(|| format!("Failed to create file: {}", download_path.display()))?;

tokio::io::copy(&mut progress_reader, &mut file)
.await
.context("Failed to write file")?;

progress_reader.finish();

// Verify checksum if provided
if let Some(expected_checksum) = checksum {
verify_checksum(&download_path, &expected_checksum)?;
}

// Extract if tar.gz
if archive_format == ArchiveFormat::TarGz {
tracing::info!("Extracting tar.gz archive...");
extract_tar_gz_from_file(&download_path, destination)?;
// Clean up temp file
tokio::fs::remove_file(&download_path)
.await
.with_context(|| format!("Failed to remove temp file: {}", download_path.display()))?;
}

Ok(())
}

Expand Down Expand Up @@ -194,19 +214,29 @@ pub fn verify_checksum(file_path: &Path, expected_checksum: &str) -> Result<()>
Ok(())
}

/// A reader wrapper that tracks progress
struct ProgressReader<'a, R: Read, T: ProgressTracker> {
inner: R,
progress: &'a mut T,
downloaded: u64,
// An async reader wrapper that tracks progress
pin_project! {
struct AsyncProgressReader<'a, S, T: ProgressTracker> {
#[pin]
stream: S,
progress: &'a mut T,
downloaded: u64,
buffer: Vec<u8>,
buffer_pos: usize,
}
}

impl<'a, R: Read, T: ProgressTracker> ProgressReader<'a, R, T> {
fn new(inner: R, progress: &'a mut T) -> Self {
impl<'a, S, T: ProgressTracker> AsyncProgressReader<'a, S, T>
where
S: futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
{
fn new(stream: S, progress: &'a mut T) -> Self {
Self {
inner,
stream,
progress,
downloaded: 0,
buffer: Vec::new(),
buffer_pos: 0,
}
}

Expand All @@ -215,18 +245,70 @@ impl<'a, R: Read, T: ProgressTracker> ProgressReader<'a, R, T> {
}
}

impl<'a, R: Read, T: ProgressTracker> Read for ProgressReader<'a, R, T> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let bytes_read = self.inner.read(buf)?;
self.downloaded += bytes_read as u64;
self.progress.update(self.downloaded);
Ok(bytes_read)
impl<'a, S, T: ProgressTracker> AsyncRead for AsyncProgressReader<'a, S, T>
where
S: futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut TaskContext<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let mut this = self.project();

// If we have buffered data, use it first
if *this.buffer_pos < this.buffer.len() {
let remaining = &this.buffer[*this.buffer_pos..];
let to_copy = std::cmp::min(remaining.len(), buf.remaining());
buf.put_slice(&remaining[..to_copy]);
*this.buffer_pos += to_copy;

*this.downloaded += to_copy as u64;
this.progress.update(*this.downloaded);

return Poll::Ready(Ok(()));
}

// Clear the buffer if we've consumed it all
this.buffer.clear();
*this.buffer_pos = 0;

// Poll for next chunk from stream
match this.stream.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(chunk))) => {
let chunk_len = chunk.len();
if chunk_len == 0 {
return Poll::Ready(Ok(()));
}

// Copy what we can directly into the output buffer
let to_copy = std::cmp::min(chunk_len, buf.remaining());
buf.put_slice(&chunk[..to_copy]);

*this.downloaded += to_copy as u64;
this.progress.update(*this.downloaded);

// Store any remaining data in our buffer
if to_copy < chunk_len {
this.buffer.extend_from_slice(&chunk[to_copy..]);
}

Poll::Ready(Ok(()))
}
Poll::Ready(Some(Err(e))) => {
Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
}
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Pending => Poll::Pending,
}
}
}

/// Extract a tar.gz archive from a reader to a destination directory
fn extract_tar_gz<R: Read>(reader: R, destination: &Path) -> Result<()> {
let gz = GzDecoder::new(reader);
/// Extract a tar.gz archive from a file to a destination directory
fn extract_tar_gz_from_file(file_path: &Path, destination: &Path) -> Result<()> {
let file = File::open(file_path)
.with_context(|| format!("Failed to open tar.gz file: {}", file_path.display()))?;
let gz = GzDecoder::new(file);
let mut archive = Archive::new(gz);

// Extract to the destination directory
Expand Down Expand Up @@ -295,8 +377,8 @@ mod tests {
assert_eq!(actual_content, CONTENT_TXT);
}

#[test]
fn test_download_content_txt() {
#[tokio::test]
async fn test_download_content_txt() {
let filename = "content.txt";
let checksum = "3dc7bc0209231cc61cb7d09c2efdfdf7aacb1f0b098db150780e980fa10d6b7a";

Expand All @@ -312,7 +394,8 @@ mod tests {
&destination,
&mut progress,
Some(checksum.to_string()),
);
)
.await;
assert!(
result.is_ok(),
"Failed to download {} or verify checksum: {:?}",
Expand All @@ -326,9 +409,10 @@ mod tests {
let _ = fs::remove_file(&destination);
}

#[test]
fn test_download_content_tar_gz() {
#[tokio::test]
async fn test_download_content_tar_gz() {
let filename = "content.tar.gz";
let checksum = "aa7d1aae79175b06c5529409d65f4794479c9b060381e059a8b6d1510fa2ae48";

let source = get_fixture_path(filename);
let destination = PathBuf::from(format!("/tmp/fetcher_test_{}", filename));
Expand All @@ -337,7 +421,13 @@ mod tests {

let mut progress = TestProgressTracker::new();

let result = fetch_with_progress(&source, &destination, &mut progress, None);
let result = fetch_with_progress(
&source,
&destination,
&mut progress,
Some(checksum.to_string()),
)
.await;
assert!(
result.is_ok(),
"Failed to download {} or verify checksum: {:?}",
Expand Down
Loading