Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let mut buffer: String = String::new();
loop {
let prompt = if !buffer.trim_start().is_empty() { "~> " } else { "=> " };
let prompt = if !buffer.trim_start().is_empty() {
"~> "
} else if context.args.extra.iter().any(|arg| arg.starts_with("transaction_id=")) {
"*> "
} else {
"=> "
};
let readline = rl.readline(prompt);

match readline {
Ok(line) => {
buffer += line.as_str();

if buffer.trim() == "quit" || buffer.trim() == "exit" {
break;
}

Comment on lines +71 to +75
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just bundling into; useful for testing with script

buffer += "\n";
if !line.is_empty() {
let queries = try_split_queries(&buffer).unwrap_or_default();
Expand Down
82 changes: 45 additions & 37 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use std::time::Instant;
use tokio::{select, signal, task};
use tokio_util::sync::CancellationToken;

use crate::USER_AGENT;
use crate::FIREBOLT_PROTOCOL_VERSION;
use crate::args::normalize_extras;
use crate::auth::authenticate_service_account;
use crate::context::Context;
use crate::utils::spin;
use crate::FIREBOLT_PROTOCOL_VERSION;
use crate::USER_AGENT;

// Set parameters via query
pub fn set_args(context: &mut Context, query: &str) -> Result<bool, Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -44,10 +44,6 @@ pub fn set_args(context: &mut Context, query: &str) -> Result<bool, Box<dyn std:

context.update_url();

if !context.args.concise && !context.args.hide_pii {
eprintln!("URL: {}", context.url);
}

return Ok(true);
}

Expand All @@ -67,10 +63,6 @@ pub fn unset_args(context: &mut Context, query: &str) -> Result<bool, Box<dyn st

context.update_url();

if !context.args.concise && !context.args.hide_pii {
eprintln!("URL: {}", context.url);
}

return Ok(true);
}

Expand All @@ -81,10 +73,18 @@ pub fn unset_args(context: &mut Context, query: &str) -> Result<bool, Box<dyn st
pub async fn query(context: &mut Context, query_text: String) -> Result<(), Box<dyn std::error::Error>> {
// Handle set/unset commands
if set_args(context, &query_text)? {
if !context.args.concise && !context.args.hide_pii {
eprintln!("URL: {}", context.url);
}

return Ok(());
}

if unset_args(context, &query_text)? {
if !context.args.concise && !context.args.hide_pii {
eprintln!("URL: {}", context.url);
}

return Ok(());
}

Expand Down Expand Up @@ -150,37 +150,45 @@ pub async fn query(context: &mut Context, query_text: String) -> Result<(), Box<
let mut maybe_request_id: Option<String> = None;
match response {
Ok(resp) => {
if let Some(header) = resp.headers().get("X-REQUEST-ID") {
maybe_request_id = header.to_str().map_or(None, |l| Some(String::from(l)));
}
if let Some(header) = resp.headers().get("firebolt-update-parameters") {
set_args(context, format!("set {}", header.to_str().unwrap()).as_str())?;
}
if let Some(header) = resp.headers().get("firebolt-remove-parameters") {
unset_args(context, format!("unset {}", header.to_str().unwrap()).as_str())?;
}
if let Some(header) = resp.headers().get("firebolt-update-endpoint") {
let header_str = header.to_str().unwrap();
// Split the header at the '?' character
if let Some(pos) = header_str.find('?') {
// Extract base URL and query part
let base_url = &header_str[..pos];
let query_part = &header_str[pos+1..];

// Update the context URL with just the base part
context.args.host = base_url.to_string();

// Process each query parameter
for param in query_part.split('&') {
if !param.is_empty() {
set_args(context, format!("set {};", param).as_str())?;
let mut updated_url = false;
for (header, value) in resp.headers() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you also need to handle Firebolt-Reset-Session header, which removes all parameters that were added with Firebolt-Update-Parameters

if header == "firebolt-remove-parameters" {
unset_args(context, format!("unset {}", value.to_str()?).as_str())?;
updated_url = true;
} else if header == "firebolt-update-parameters" {
set_args(context, format!("set {}", value.to_str()?).as_str())?;
updated_url = true;
} else if header == "X-REQUEST-ID" {
maybe_request_id = value.to_str().map_or(None, |l| Some(String::from(l)));
updated_url = true;
} else if header == "firebolt-update-endpoint" {
let header_str = value.to_str()?;
// Split the header at the '?' character
if let Some(pos) = header_str.find('?') {
// Extract base URL and query part
let base_url = &header_str[..pos];
let query_part = &header_str[pos+1..];

// Update the context URL with just the base part
context.args.host = base_url.to_string();

// Process each query parameter
for param in query_part.split('&') {
if !param.is_empty() {
set_args(context, format!("set {};", param).as_str())?;
}
}
} else {
// No query parameters, just set the URL
context.args.host = header_str.to_string();
}
} else {
// No query parameters, just set the URL
context.args.host = header_str.to_string();
updated_url = true;
}
}
if updated_url && !context.args.concise && !context.args.hide_pii {
eprintln!("URL: {}", context.url);
}

// on stdout, on purpose
println!("{}", resp.text().await?);
}
Expand Down
30 changes: 30 additions & 0 deletions tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,33 @@ fn test_command_parsing() {

assert!(stdout.contains("1339"));
}

#[test]
fn test_exiting() {
let mut child = Command::new(env!("CARGO_BIN_EXE_fb"))
.args(&[
"--core",
"--concise",
"-f",
"TabSeparatedWithNamesAndTypes",
])
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.spawn()
.unwrap();

let mut stdin = child.stdin.take().unwrap();
writeln!(stdin, "SELECT 42;").unwrap();
writeln!(stdin, "quit").unwrap();
drop(stdin); // Close stdin to end interactive mode

let output = child.wait_with_output().unwrap();
let stdout = String::from_utf8(output.stdout).unwrap();

assert!(output.status.success());
let mut lines = stdout.lines();
assert_eq!(lines.next().unwrap(), "?column?");
lines.next();
assert_eq!(lines.next().unwrap(), "42");
lines.next();
}