Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub struct Server {
github: Arc<GitHub>,
client: Client,
path: Arc<String>,
ssh_key: Arc<Option<String>>,
}

impl Server {
Expand All @@ -28,6 +29,7 @@ impl Server {
status: Arc<Mutex<Status>>,
github: Arc<GitHub>,
path: Arc<String>,
ssh_key: Arc<Option<String>>,
) -> reqwest::Result<Self> {
let policy = Policy::custom(move |attempt| {
if attempt.previous().len() > Self::REDIRECTS {
Expand Down Expand Up @@ -79,6 +81,7 @@ impl Server {
github,
client: client_builder.build()?,
path,
ssh_key,
})
}

Expand All @@ -90,11 +93,12 @@ impl Server {
let github = self.github.clone();
let client = self.client.clone();
let path = self.path.clone();
let ssh_key = self.ssh_key.clone();

// Spawn a new task to handle the connection.
tokio::spawn(async move {
let stream = TokioIo::new(stream);
let service = Service::new(addr.ip(), status, github, client, path);
let service = Service::new(addr.ip(), status, github, client, path, ssh_key);
Builder::new().serve_connection(stream, service).await
});
}
Expand Down
40 changes: 40 additions & 0 deletions src/http/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct Service {
github: Arc<GitHub>,
client: Client,
path: Arc<String>,
ssh_key: Arc<Option<String>>,
}

impl Service {
Expand All @@ -36,13 +37,15 @@ impl Service {
github: Arc<GitHub>,
client: Client,
path: Arc<String>,
ssh_key: Arc<Option<String>>,
) -> Self {
Self {
remote,
status,
github,
client,
path,
ssh_key,
}
}

Expand Down Expand Up @@ -70,6 +73,36 @@ impl Service {
_ => client.get(url),
}
}

/// Fetch a user-provided public SSH key
fn fetch_ssh(
ssh_key: Option<&String>,
path: &str,
base_path: &str,
) -> Option<Response<BoxBody<Bytes, Infallible>>> {
// Check for SSH key endpoint
if path != format!("{base_path}/ssh-key") {
return None;
}

// Return key if it exists
match ssh_key {
Some(key) => {
let bytes = Bytes::from(key.clone());
Some(
Response::builder()
.status(Code::OK)
.header("content-type", "text/plain")
.header("content-length", bytes.len())
.body(BoxBody::new(StreamBody::new(Box::pin(stream::once(
async move { Ok(Frame::data(bytes)) },
)))))
.ok()?,
)
}
None => Some(EMPTY.reply(Code::NOT_FOUND, None, None)),
}
}
}

impl hyper::service::Service<Request<Incoming>> for Service {
Expand All @@ -82,6 +115,7 @@ impl hyper::service::Service<Request<Incoming>> for Service {
let client = self.client.clone();
let github = self.github.clone();
let path = self.path.clone();
let ssh_key = self.ssh_key.clone();
let remote = self.remote;

Box::pin(async move {
Expand Down Expand Up @@ -146,6 +180,12 @@ impl hyper::service::Service<Request<Incoming>> for Service {

// The GET request is used to fetch the assigned asset.
Method::GET => {
if let Some(response) =
Self::fetch_ssh(ssh_key.as_ref().as_ref(), req.uri().path(), &path)
{
return Ok(response);
}

match status.clone().assign(remote).await {
// No asset assigned, return poweroff EFI binary.
None => return Ok(POWEROFF_EFI.reply(None, Type::Efi, None)),
Expand Down
13 changes: 12 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ struct Args {
/// Path to offer services on
#[arg(short = 'p', long, default_value = concat!("/", std::env!("CARGO_PKG_NAME")))]
path: String,

/// Path to SSH public key file
#[arg(long)]
ssh_key_file: Option<std::path::PathBuf>,
}

/// Guard that ensures term settings are restored upon program exit
Expand All @@ -64,6 +68,13 @@ async fn main() -> Result<()> {
// Parse arguments
let args = Args::parse();

// Read SSH public key file if provided
let ssh_key = match args.ssh_key_file {
Some(path) => Some(std::fs::read_to_string(path)?),
None => None,
};
let ssh_key = Arc::new(ssh_key);

// Ensure we're authenticated with GitHub
let github = Arc::new(args.github.login().await?);

Expand All @@ -83,7 +94,7 @@ async fn main() -> Result<()> {
status.lock().await.render()?;

// Create the HTTP server
let server = Server::new(listener, status.clone(), github, path.clone())?;
let server = Server::new(listener, status.clone(), github, path.clone(), ssh_key)?;

// Create TXT records
let name = std::env!("CARGO_PKG_NAME");
Expand Down