diff --git a/Cargo.lock b/Cargo.lock index 66532ec..4d67acd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,9 +101,9 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "cassowary" diff --git a/src/http/server.rs b/src/http/server.rs index e2aba99..66b5657 100644 --- a/src/http/server.rs +++ b/src/http/server.rs @@ -17,6 +17,7 @@ pub struct Server { github: Arc, client: Client, path: Arc, + ssh_key: Arc>, } impl Server { @@ -28,6 +29,7 @@ impl Server { status: Arc>, github: Arc, path: Arc, + ssh_key: Arc>, ) -> reqwest::Result { let policy = Policy::custom(move |attempt| { if attempt.previous().len() > Self::REDIRECTS { @@ -79,6 +81,7 @@ impl Server { github, client: client_builder.build()?, path, + ssh_key, }) } @@ -90,11 +93,12 @@ impl Server { let github = self.github.clone(); let client = self.client.clone(); let path = self.path.clone(); + let ssh_key = self.ssh_key.clone(); // Spawn a new task to handle the connection. tokio::spawn(async move { let stream = TokioIo::new(stream); - let service = Service::new(addr.ip(), status, github, client, path); + let service = Service::new(addr.ip(), status, github, client, path, ssh_key); Builder::new().serve_connection(stream, service).await }); } diff --git a/src/http/service.rs b/src/http/service.rs index 028dad4..304525f 100644 --- a/src/http/service.rs +++ b/src/http/service.rs @@ -27,6 +27,7 @@ pub struct Service { github: Arc, client: Client, path: Arc, + ssh_key: Arc>, } impl Service { @@ -36,6 +37,7 @@ impl Service { github: Arc, client: Client, path: Arc, + ssh_key: Arc>, ) -> Self { Self { remote, @@ -43,6 +45,7 @@ impl Service { github, client, path, + ssh_key, } } @@ -70,6 +73,36 @@ impl Service { _ => client.get(url), } } + + /// Fetch a user-provided public SSH key + fn fetch_ssh( + ssh_key: Option<&String>, + path: &str, + base_path: &str, + ) -> Option>> { + // Check for SSH key endpoint + if path != format!("{base_path}/ssh-key") { + return None; + } + + // Return key if it exists + match ssh_key { + Some(key) => { + let bytes = Bytes::from(key.clone()); + Some( + Response::builder() + .status(Code::OK) + .header("content-type", "text/plain") + .header("content-length", bytes.len()) + .body(BoxBody::new(StreamBody::new(Box::pin(stream::once( + async move { Ok(Frame::data(bytes)) }, + ))))) + .ok()?, + ) + } + None => Some(EMPTY.reply(Code::NOT_FOUND, None, None)), + } + } } impl hyper::service::Service> for Service { @@ -82,6 +115,7 @@ impl hyper::service::Service> for Service { let client = self.client.clone(); let github = self.github.clone(); let path = self.path.clone(); + let ssh_key = self.ssh_key.clone(); let remote = self.remote; Box::pin(async move { @@ -146,6 +180,12 @@ impl hyper::service::Service> for Service { // The GET request is used to fetch the assigned asset. Method::GET => { + if let Some(response) = + Self::fetch_ssh(ssh_key.as_ref().as_ref(), req.uri().path(), &path) + { + return Ok(response); + } + match status.clone().assign(remote).await { // No asset assigned, return poweroff EFI binary. None => return Ok(POWEROFF_EFI.reply(None, Type::Efi, None)), diff --git a/src/main.rs b/src/main.rs index 2f7db99..e8b3f17 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,6 +44,10 @@ struct Args { /// Path to offer services on #[arg(short = 'p', long, default_value = concat!("/", std::env!("CARGO_PKG_NAME")))] path: String, + + /// Path to SSH public key file + #[arg(long)] + ssh_key_file: Option, } /// Guard that ensures term settings are restored upon program exit @@ -64,6 +68,13 @@ async fn main() -> Result<()> { // Parse arguments let args = Args::parse(); + // Read SSH public key file if provided + let ssh_key = match args.ssh_key_file { + Some(path) => Some(std::fs::read_to_string(path)?), + None => None, + }; + let ssh_key = Arc::new(ssh_key); + // Ensure we're authenticated with GitHub let github = Arc::new(args.github.login().await?); @@ -83,7 +94,7 @@ async fn main() -> Result<()> { status.lock().await.render()?; // Create the HTTP server - let server = Server::new(listener, status.clone(), github, path.clone())?; + let server = Server::new(listener, status.clone(), github, path.clone(), ssh_key)?; // Create TXT records let name = std::env!("CARGO_PKG_NAME");