diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..a234263 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,141 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +`fb-cli` is a Rust-based CLI tool for interacting with the Firebolt analytical database. It supports both one-off queries and an interactive REPL mode with history, search, and multi-line query support. + +Repo upstream: https://github.com/firebolt-db/fb-cli + +## Common Commands + +### Building and Running +```bash +# Build the project +cargo build + +# Build and install locally +cargo install --path . --locked + +# Run the CLI (after installation) +fb + +# Run directly with cargo +cargo run + +# Run with arguments +cargo run -- select 42 +``` + +### Testing +```bash +# Run all tests +cargo test + +# Run a specific test +cargo test test_name + +# Run tests with output +cargo test -- --nocapture + +# Run tests in a specific file/module +cargo test query::tests +``` + +### Code Quality +```bash +# Format code (uses rustfmt.toml config) +cargo fmt + +# Check code without building +cargo check + +# Lint with clippy +cargo clippy +``` + +## Architecture + +### Core Modules + +- **main.rs**: Entry point; handles REPL loop with rustyline for readline functionality. Binds Ctrl+O to insert newlines in the REPL. + +- **args.rs**: Command-line argument parsing using gumdrop. Manages configuration defaults stored in `~/.firebolt/fb_config` (YAML format). The `get_args()` function merges CLI args with saved defaults. URL construction happens in `get_url()`. + +- **query.rs**: Query execution logic. Uses Pest parser (defined in sql.pest) to split multi-statement queries while handling SQL comments and string literals. Sends HTTP/2 requests to Firebolt with keep-alive. Handles `set`/`unset` commands for dynamic parameter configuration. Supports async cancellation via Ctrl+C. + +- **auth.rs**: OAuth authentication for Service Accounts. Caches tokens in `~/.firebolt/fb_sa_token/` (valid for 30 minutes). Supports both JWT and bearer token authentication. + +- **context.rs**: Application state container holding Args, URL, and optional ServiceAccountToken. The URL is rebuilt when parameters change. + +- **utils.rs**: Helper functions for file paths (`~/.firebolt/` directory management), terminal spinner, and time formatting. + +### SQL Query Parsing + +The tool uses Pest parser generator (src/sql.pest) to split multiple SQL queries while correctly handling: +- Single and double-quoted strings with escaped quotes +- E-strings (PostgreSQL-style escaped strings) +- Raw strings ($$-delimited) +- Line comments (--) and nested block comments (/* */) +- Semicolons inside strings/comments (not treated as query terminators) + +The parser requires all queries to end with semicolons. Incomplete queries in the REPL are accumulated across lines until a semicolon is encountered. + +### Configuration and State + +- **Defaults**: Stored in `~/.firebolt/fb_config` (YAML). Updated with `--update-defaults` flag. +- **JWT tokens**: Can be loaded from `~/.firebolt/jwt` with `--jwt-from-file`. +- **Service Account tokens**: Cached in `~/.firebolt/fb_sa_token/` (YAML). +- **REPL history**: Saved to `~/.firebolt/history` (max 10,000 entries). + +### Authentication Modes + +1. **Local development**: Direct connection (default: localhost:8123 or localhost:9123 with JWT) +2. **Firebolt Core**: Use `-C` flag (connects to localhost:3473, uses PSQL format) +3. **Service Account**: Provide `--sa-id`, `--sa-secret`, and `--oauth-env` (staging/app) +4. **Bearer token**: Use `--bearer` flag for browser-extracted tokens +5. **JWT**: Use `--jwt` or `--jwt-from-file` + +### HTTP Protocol + +- Uses HTTP/2 with keep-alive (3600s timeout, 60s interval) +- Custom headers: `user-agent` (fdb-cli/version), `Firebolt-Protocol-Version` (2.3) +- Supports dynamic parameter updates via response headers: + - `firebolt-update-parameters`: Sets query parameters + - `firebolt-remove-parameters`: Removes query parameters + - `firebolt-update-endpoint`: Changes host/endpoint with parameters + +## Firebolt-Specific Features + +### SET/UNSET Commands +In REPL mode, dynamically modify query parameters: +```sql +set format = Vertical; +set engine = my_engine; +unset enable_result_cache; +``` + +These modify the URL query string. `format` is a special case that maps to `output_format` parameter. + +### Output Formats +Controlled by `--format` flag or `set format = ...`: +- PSQL (default for most modes) +- TabSeparatedWithNames +- TabSeparatedWithNamesAndTypes +- JSONLines_Compact +- CSVWithNames +- Vertical + +### Query Labels +Add `--label` to tag queries for tracking in Firebolt logs. + +## Testing + +The codebase has comprehensive unit tests, especially for: +- SQL query splitting with complex edge cases (see query.rs tests) +- Parameter encoding and URL generation (see args.rs tests) +- SET/UNSET command parsing +- String literals and comment handling + +When modifying the SQL parser (sql.pest), run `cargo test` to validate against the extensive test suite in query.rs. diff --git a/Cargo.lock b/Cargo.lock index 3432a9e..897f5ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,6 +118,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width 0.2.2", + "windows-sys 0.59.0", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -181,9 +194,15 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys", + "windows-sys 0.48.0", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.33" @@ -212,7 +231,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -235,6 +254,7 @@ checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" name = "fb" version = "0.2.3" dependencies = [ + "console", "dirs", "gumdrop", "once_cell", @@ -261,7 +281,7 @@ checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -417,7 +437,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -620,7 +640,7 @@ checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -766,7 +786,7 @@ dependencies = [ "libc", "redox_syscall 0.4.1", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -1002,7 +1022,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1033,7 +1053,7 @@ dependencies = [ "regex", "scopeguard", "unicode-segmentation", - "unicode-width", + "unicode-width 0.1.11", "utf8parse", "winapi", ] @@ -1050,7 +1070,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1189,7 +1209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1257,7 +1277,7 @@ dependencies = [ "fastrand", "redox_syscall 0.4.1", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1331,7 +1351,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1502,6 +1522,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unsafe-libyaml" version = "0.2.11" @@ -1662,7 +1688,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -1671,13 +1706,29 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -1686,42 +1737,90 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "winnow" version = "0.5.19" @@ -1738,5 +1837,5 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ "cfg-if", - "windows-sys", + "windows-sys 0.48.0", ] diff --git a/Cargo.toml b/Cargo.toml index ac2a217..1b90a4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ license = "Apache-2.0" rustyline = { version = "12.0.0", features = ["case_insensitive_history_search"] } gumdrop = { version = "0.8.1", features = ["default_expr"] } reqwest = { version = "0.12", features = ["json", "http2"] } -openssl = { version = "*", features = ["vendored"] } +openssl = { version = "*", features = ["vendored"] } tokio = { version = "1", features = ["full"] } tokio-util = "0.7.10" dirs = "5.0" @@ -23,3 +23,4 @@ toml = "0.8" urlencoding = "2.1" pest = "2.7" pest_derive = "2.7" +console = "0.15" diff --git a/src/args.rs b/src/args.rs index 4e4d4c7..9da6c6b 100644 --- a/src/args.rs +++ b/src/args.rs @@ -62,6 +62,10 @@ pub struct Args { #[serde(skip_serializing, skip_deserializing)] pub sa_secret: String, + #[options(no_short, help = "Account name for Service Account authentication")] + #[serde(skip_serializing, skip_deserializing)] + pub account_name: String, + #[options(no_short, help = "Load JWT from file (~/.firebolt/jwt)")] #[serde(default)] pub jwt_from_file: bool, @@ -149,10 +153,20 @@ pub fn get_args() -> Result> { serde_yaml::from_str("")? }; - let mut args = Args::parse_args_default_or_exit(); + let args_vec: Vec = std::env::args().skip(1).collect(); + let mut args = match Args::parse_args_default(&args_vec) { + Ok(args) => args, + Err(e) => { + eprintln!("{}", e); + std::process::exit(2); + } + }; args.extra = normalize_extras(args.extra, true)?; + // Auto-load saved credentials + crate::auth::load_saved_credentials(&mut args)?; + args.jwt_from_file = args.jwt_from_file || defaults.jwt_from_file; if args.jwt_from_file { let jwt_path = init_root_path()?.join("jwt"); @@ -234,10 +248,24 @@ pub fn get_url(args: &Args) -> String { }; let advanced_mode = if is_localhost { "" } else { "&advanced_mode=1" }; - format!( - "{protocol}://{host}/?{database}{query_label}{extra}{output_format}{advanced_mode}", - host = args.host - ) + // Build all query parameters + let all_params = format!("{database}{query_label}{extra}{output_format}{advanced_mode}"); + + // Check if host already contains query parameters (e.g., ?engine=name) + // If yes, append with &, otherwise start with /? + let params = if args.host.contains('?') { + // Host has query params, strip leading & from our params + all_params.strip_prefix('&').unwrap_or(&all_params) + } else { + // Host has no query params, use /? + all_params.as_str() + }; + + if args.host.contains('?') { + format!("{protocol}://{host}&{params}", host = args.host) + } else { + format!("{protocol}://{host}/?{params}", host = args.host) + } } #[cfg(test)] diff --git a/src/auth.rs b/src/auth.rs index 43d634c..0bcbb8d 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -4,13 +4,62 @@ use std::time::SystemTime; use tokio::task; use tokio_util::sync::CancellationToken; -use crate::context::{Context, ServiceAccountToken}; -use crate::utils::{format_remaining_time, sa_token_path, spin}; +use crate::context::{Context, SavedCredentials, ServiceAccountToken}; +use crate::utils::{credentials_path, format_remaining_time, sa_token_path, spin}; +use std::io::{self, Write}; + +/// Helper to create an authenticated context from saved credentials +async fn create_context_from_credentials( + host: String, + database: String, + format: String, + no_spinner: bool, +) -> Result> { + let creds_path = credentials_path()?; + if !creds_path.exists() { + return Err("No saved credentials found. Run 'fb auth' first.".into()); + } + + let saved_creds: SavedCredentials = serde_yaml::from_str(&fs::read_to_string(&creds_path)?)?; + + let temp_args = crate::args::Args { + command: String::new(), + core: false, + host, + database, + format, + extra: vec![], + label: String::new(), + jwt: String::new(), + sa_id: saved_creds.sa_id, + sa_secret: saved_creds.sa_secret, + account_name: saved_creds.account_name, + jwt_from_file: false, + oauth_env: saved_creds.oauth_env, + verbose: false, + concise: true, + hide_pii: false, + no_spinner, + update_defaults: false, + version: false, + help: false, + query: vec![], + }; + + let mut context = Context::new(temp_args); + context.update_url(); + authenticate_service_account(&mut context).await?; + + Ok(context) +} pub async fn authenticate_service_account(context: &mut Context) -> Result<(), Box> { if let Some(sa_token) = &context.sa_token { let valid_until = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(sa_token.until); - if sa_token.sa_id == context.args.sa_id && sa_token.sa_secret == context.args.sa_secret && valid_until > SystemTime::now() { + if sa_token.sa_id == context.args.sa_id && + sa_token.sa_secret == context.args.sa_secret && + sa_token.oauth_env == context.args.oauth_env && + valid_until > SystemTime::now() { return Ok(()); } } @@ -37,7 +86,10 @@ pub async fn authenticate_service_account(context: &mut Context) -> Result<(), B if sa_token_path.exists() { if let Some(sa_token) = serde_yaml::from_str::>(&fs::read_to_string(&sa_token_path)?)? { let valid_until = SystemTime::UNIX_EPOCH + std::time::Duration::from_secs(sa_token.until); - if sa_token.sa_id == args.sa_id && sa_token.sa_secret == args.sa_secret && valid_until > SystemTime::now() { + if sa_token.sa_id == args.sa_id && + sa_token.sa_secret == args.sa_secret && + sa_token.oauth_env == args.oauth_env && + valid_until > SystemTime::now() { if args.verbose { eprintln!( "Using cached SA token from {:?}, valid for {:}", @@ -108,6 +160,7 @@ pub async fn authenticate_service_account(context: &mut Context) -> Result<(), B sa_secret: args.sa_secret.clone(), token: response.access_token.unwrap().to_string(), until: valid_until.duration_since(SystemTime::UNIX_EPOCH)?.as_secs(), + oauth_env: args.oauth_env.clone(), }; args.jwt.clear(); @@ -143,6 +196,532 @@ pub async fn maybe_authenticate(context: &mut Context) -> Result<(), Box Result> { + let gateway_url = format!( + "https://{}/web/v3/account/{}/engineUrl", + api_endpoint, account_name + ); + + let client = reqwest::Client::new(); + let response = client + .get(&gateway_url) + .header("Authorization", format!("Bearer {}", access_token)) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await?; + return Err(format!( + "Failed to discover system engine URL (status {}): {}", + status, text + ) + .into()); + } + + #[derive(Deserialize)] + struct EngineUrlResponse { + #[serde(rename = "engineUrl")] + engine_url: String, + } + + let response_data: EngineUrlResponse = response.json().await?; + Ok(response_data.engine_url) +} + +/// Interactive prompt for authentication setup +pub async fn interactive_auth_setup() -> Result<(), Box> { + println!("Welcome to Firebolt CLI authentication setup!\n"); + + print!("Enter Service Account ID: "); + io::stdout().flush()?; + let mut sa_id = String::new(); + io::stdin().read_line(&mut sa_id)?; + let sa_id = sa_id.trim().to_string(); + + // Use dialoguer for better password input with visual feedback + use console::Term; + let term = Term::stderr(); + + eprint!("Enter Service Account Secret: "); + io::stderr().flush()?; + + let mut sa_secret = String::new(); + loop { + if let Ok(key) = term.read_key() { + match key { + console::Key::Enter => { + eprintln!(); + break; + } + console::Key::Backspace => { + if !sa_secret.is_empty() { + sa_secret.pop(); + eprint!("\x08 \x08"); + io::stderr().flush()?; + } + } + console::Key::Char(c) if !c.is_control() => { + sa_secret.push(c); + eprint!("*"); + io::stderr().flush()?; + } + _ => {} + } + } + } + let sa_secret = sa_secret.trim().to_string(); + + print!("Enter account name (e.g., my_account): "); + io::stdout().flush()?; + let mut account_name = String::new(); + io::stdin().read_line(&mut account_name)?; + let account_name = account_name.trim().to_string(); + + // Always use "app" environment (production) + let oauth_env = "app"; + let api_endpoint = "api.app.firebolt.io"; + + // Step 1: Authenticate to get access token + println!("\nAuthenticating..."); + + // Create minimal args for authentication + let temp_args = crate::args::Args { + command: String::new(), + core: false, + host: api_endpoint.to_string(), // Temporarily use API endpoint + database: String::new(), + format: String::new(), + extra: vec![], + label: String::new(), + jwt: String::new(), + sa_id: sa_id.clone(), + sa_secret: sa_secret.clone(), + account_name: account_name.clone(), + jwt_from_file: false, + oauth_env: oauth_env.to_string(), + verbose: false, + concise: false, + hide_pii: false, + no_spinner: true, + update_defaults: false, + version: false, + help: false, + query: vec![], + }; + + let mut temp_context = crate::context::Context::new(temp_args); + authenticate_service_account(&mut temp_context).await?; + + let access_token = if let Some(token) = &temp_context.sa_token { + token.token.clone() + } else { + return Err("Failed to obtain access token".into()); + }; + + println!("✓ Authentication successful!"); + + // Step 2: Discover system engine URL + println!("Discovering system engine endpoint..."); + let system_engine_url = discover_system_engine_url(&account_name, &access_token, api_endpoint).await?; + println!("✓ System engine URL: {}", system_engine_url); + + // Update context with system engine URL + temp_context.args.host = system_engine_url.clone(); + temp_context.update_url(); + + // Step 3: Optional database and engine configuration + println!("\n(Optional) Configure defaults:"); + + print!("Default database name [press Enter to skip]: "); + io::stdout().flush()?; + let mut database_input = String::new(); + io::stdin().read_line(&mut database_input)?; + let database_name = database_input.trim(); + + let final_database = if !database_name.is_empty() { + // Validate database exists + println!("Validating database '{}'...", database_name); + let check_query = format!( + "SELECT catalog_name FROM information_schema.catalogs WHERE catalog_name = '{}'", + database_name.replace("'", "''") + ); + + match execute_query_internal(&mut temp_context, check_query).await { + Ok(response) => { + if response.contains(database_name) { + println!("✓ Database '{}' validated", database_name); + temp_context.args.database = database_name.to_string(); + temp_context.update_url(); + Some(database_name.to_string()) + } else { + eprintln!("⚠ Warning: Database '{}' does not exist.", database_name); + eprintln!(" Skipping database configuration."); + None + } + } + Err(e) => { + eprintln!("⚠ Warning: Failed to validate database: {}", e); + eprintln!(" Skipping database configuration."); + None + } + } + } else { + None + }; + + print!("Default engine name [press Enter to skip]: "); + io::stdout().flush()?; + let mut engine_input = String::new(); + io::stdin().read_line(&mut engine_input)?; + let engine_name = engine_input.trim(); + + let final_host = if !engine_name.is_empty() { + // Validate engine exists + println!("Validating engine '{}'...", engine_name); + let check_query = format!( + "SELECT engine_name FROM information_schema.engines WHERE engine_name = '{}'", + engine_name.replace("'", "''") + ); + + match execute_query_internal(&mut temp_context, check_query).await { + Ok(response) => { + if response.contains(engine_name) { + // Engine exists, now resolve its endpoint + println!("Resolving engine '{}' endpoint...", engine_name); + let use_engine_query = format!("USE ENGINE {}", engine_name); + match crate::query::query(&mut temp_context, use_engine_query).await { + Ok(_) => { + println!("✓ Engine '{}' configured: {}", engine_name, temp_context.args.host); + Some(temp_context.args.host.clone()) + } + Err(e) => { + eprintln!("⚠ Warning: Failed to resolve engine '{}' endpoint: {}", engine_name, e); + eprintln!(" Continuing with system engine endpoint."); + Some(system_engine_url) + } + } + } else { + eprintln!("⚠ Warning: Engine '{}' does not exist.", engine_name); + eprintln!(" Continuing with system engine endpoint."); + Some(system_engine_url) + } + } + Err(e) => { + eprintln!("⚠ Warning: Failed to validate engine: {}", e); + eprintln!(" Continuing with system engine endpoint."); + Some(system_engine_url) + } + } + } else { + Some(system_engine_url) + }; + + let saved_creds = SavedCredentials { + sa_id, + sa_secret, + oauth_env: oauth_env.to_string(), + account_name, + host: final_host, + database: final_database, + }; + + let creds_path = credentials_path()?; + fs::write(&creds_path, serde_yaml::to_string(&saved_creds)?)?; + + println!("\nCredentials saved to {:?}", creds_path); + println!("\n✓ Setup complete! You can now run queries:"); + if saved_creds.database.is_some() { + println!(" fb \"select 42\""); + println!(" fb \"select * from my_table\""); + } else { + println!(" fb -d \"select 42\""); + println!(" fb -d \"select * from my_table\""); + } + + Ok(()) +} + +/// Helper function to execute a query and return the response text (doesn't print to stdout) +async fn execute_query_internal( + context: &mut Context, + query_text: String, +) -> Result> { + use crate::{FIREBOLT_PROTOCOL_VERSION, USER_AGENT}; + + let mut request = reqwest::Client::builder() + .http2_keep_alive_timeout(std::time::Duration::from_secs(3600)) + .http2_keep_alive_interval(Some(std::time::Duration::from_secs(60))) + .http2_keep_alive_while_idle(false) + .tcp_keepalive(Some(std::time::Duration::from_secs(60))) + .build()? + .post(context.url.clone()) + .header("user-agent", USER_AGENT) + .header("Firebolt-Protocol-Version", FIREBOLT_PROTOCOL_VERSION) + .body(query_text); + + if let Some(sa_token) = &context.sa_token { + request = request.header("authorization", format!("Bearer {}", sa_token.token)); + } + + let response = request.send().await?; + + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await?; + return Err(format!("Query failed (status {}): {}", status, text).into()); + } + + Ok(response.text().await?) +} + +/// Set default database in saved credentials (with validation) +pub async fn set_default_database(database_name: String) -> Result<(), Box> { + let creds_path = credentials_path()?; + let saved_creds: SavedCredentials = serde_yaml::from_str(&fs::read_to_string(&creds_path)?)?; + + // Get system engine host for querying information_schema + let system_engine_host = if let Some(host) = &saved_creds.host { + if let Some(pos) = host.find("?engine=") { + host[..pos].to_string() + } else { + host.clone() + } + } else { + return Err("No host configured. Run 'fb auth' to set up credentials.".into()); + }; + + let mut temp_context = create_context_from_credentials( + system_engine_host, + String::new(), + String::from("TabSeparatedWithNames"), + true, + ) + .await?; + + // Query information_schema.catalogs to check if database exists + println!("Validating database '{}'...", database_name); + + let check_query = format!( + "SELECT catalog_name FROM information_schema.catalogs WHERE catalog_name = '{}'", + database_name.replace("'", "''") // Escape single quotes + ); + + match execute_query_internal(&mut temp_context, check_query).await { + Ok(response) => { + // Check if response contains the database name (simple check - response will have header + data row if exists) + if response.contains(&database_name) { + // Database exists! Save it to credentials + let mut updated_creds = saved_creds; + updated_creds.database = Some(database_name.clone()); + + fs::write(&creds_path, serde_yaml::to_string(&updated_creds)?)?; + println!("✓ Default database set to: {}", database_name); + Ok(()) + } else { + eprintln!("Database '{}' does not exist.", database_name); + eprintln!("Run 'fb show databases' to see available databases."); + Err("Database validation failed".into()) + } + } + Err(e) => { + Err(format!("Failed to validate database: {}", e).into()) + } + } +} + +/// Set default engine in saved credentials by running USE ENGINE and capturing endpoint +pub async fn set_default_engine(engine_name: String) -> Result<(), Box> { + let creds_path = credentials_path()?; + if !creds_path.exists() { + return Err("No saved credentials found. Run 'fb auth' first.".into()); + } + + let mut saved_creds: SavedCredentials = serde_yaml::from_str(&fs::read_to_string(&creds_path)?)?; + + // Need to get the system engine URL first + let system_engine_host = if let Some(host) = &saved_creds.host { + // If the host contains ?engine=, strip it to get system engine + if let Some(pos) = host.find("?engine=") { + host[..pos].to_string() + } else { + host.clone() + } + } else { + // No host saved, need to discover it + println!("Discovering system engine endpoint..."); + + let api_endpoint = "api.app.firebolt.io"; + let temp_context = create_context_from_credentials( + api_endpoint.to_string(), + String::new(), + String::new(), + true, + ) + .await?; + + let access_token = if let Some(token) = &temp_context.sa_token { + token.token.clone() + } else { + return Err("Failed to obtain access token".into()); + }; + + let system_url = + discover_system_engine_url(&saved_creds.account_name, &access_token, api_endpoint) + .await?; + println!("✓ System engine: {}", system_url); + system_url + }; + + // Create context with system engine to run USE ENGINE + let mut temp_context = create_context_from_credentials( + system_engine_host.clone(), + saved_creds.database.clone().unwrap_or_default(), + String::new(), + true, + ) + .await?; + + // First, validate that the engine exists by querying information_schema.engines + println!("Validating engine '{}'...", engine_name); + let check_query = format!( + "SELECT engine_name FROM information_schema.engines WHERE engine_name = '{}'", + engine_name.replace("'", "''") // Escape single quotes + ); + + match execute_query_internal(&mut temp_context, check_query).await { + Ok(response) => { + // Check if response contains the engine name + if !response.contains(&engine_name) { + eprintln!("Engine '{}' does not exist.", engine_name); + eprintln!("Run 'fb show engines' to see available engines."); + return Err("Engine validation failed".into()); + } + } + Err(e) => { + return Err(format!("Failed to validate engine: {}", e).into()); + } + } + + // Engine exists! Now run USE ENGINE to get the endpoint + println!("Resolving engine '{}' endpoint...", engine_name); + let use_engine_query = format!("USE ENGINE {}", engine_name); + crate::query::query(&mut temp_context, use_engine_query).await?; + + // The query function will have updated temp_context.args.host with the engine endpoint + saved_creds.host = Some(temp_context.args.host.clone()); + + fs::write(&creds_path, serde_yaml::to_string(&saved_creds)?)?; + println!("✓ Default engine set to: {} ({})", engine_name, temp_context.args.host); + + Ok(()) +} + +/// Display current authentication status +pub fn show_auth_status() -> Result<(), Box> { + let creds_path = credentials_path()?; + if !creds_path.exists() { + println!("No saved credentials found."); + println!("Run 'fb auth' to set up authentication."); + return Ok(()); + } + + let saved_creds: SavedCredentials = serde_yaml::from_str(&fs::read_to_string(&creds_path)?)?; + + println!("Authenticated as: Service Account"); + println!(" ID: {}", saved_creds.sa_id); + println!(" Account: {}", saved_creds.account_name); + println!(" Environment: {}", saved_creds.oauth_env); + + // Check if token is cached and valid + if let Ok(sa_token_path) = sa_token_path() { + if sa_token_path.exists() { + if let Ok(content) = fs::read_to_string(&sa_token_path) { + if let Ok(Some(token)) = + serde_yaml::from_str::>(&content) + { + let valid_until = SystemTime::UNIX_EPOCH + + std::time::Duration::from_secs(token.until); + if valid_until > SystemTime::now() { + println!( + " Token valid for: {}", + format_remaining_time(valid_until, "".into())? + ); + } else { + println!(" Token: expired"); + } + } + } + } + } + + if let Some(host) = &saved_creds.host { + println!(" System engine: {}", host); + } + if let Some(database) = &saved_creds.database { + println!(" Default database: {}", database); + } + + Ok(()) +} + +/// Clear saved credentials +pub fn clear_auth() -> Result<(), Box> { + let creds_path = credentials_path()?; + if creds_path.exists() { + fs::remove_file(&creds_path)?; + println!("Credentials cleared from {:?}", creds_path); + } else { + println!("No saved credentials to clear."); + } + Ok(()) +} + +/// Load saved credentials into Args +pub fn load_saved_credentials( + args: &mut crate::args::Args, +) -> Result<(), Box> { + let creds_path = credentials_path()?; + if !creds_path.exists() { + return Ok(()); // No saved credentials + } + + let saved_creds: SavedCredentials = serde_yaml::from_str(&fs::read_to_string(&creds_path)?)?; + + // Only apply if not explicitly provided on command line + if args.sa_id.is_empty() { + args.sa_id = saved_creds.sa_id; + } + if args.sa_secret.is_empty() { + args.sa_secret = saved_creds.sa_secret; + } + // Always use the saved oauth_env + args.oauth_env = saved_creds.oauth_env; + + // Apply default host/database if saved + if args.host.is_empty() { + if let Some(host) = saved_creds.host { + args.host = host; + } + } + if args.database.is_empty() { + if let Some(database) = saved_creds.database { + args.database = database; + } + } + + if args.verbose { + eprintln!("Loaded credentials from {:?}", creds_path); + } + + Ok(()) +} + #[cfg(test)] mod tests { // Add tests for authentication functionality when possible diff --git a/src/context.rs b/src/context.rs index b9b868c..c2746da 100644 --- a/src/context.rs +++ b/src/context.rs @@ -7,6 +7,22 @@ pub struct ServiceAccountToken { pub sa_secret: String, pub token: String, pub until: u64, + #[serde(default = "default_oauth_env")] + pub oauth_env: String, +} + +fn default_oauth_env() -> String { + "app".to_string() +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct SavedCredentials { + pub sa_id: String, + pub sa_secret: String, + pub oauth_env: String, + pub account_name: String, + pub host: Option, + pub database: Option, } pub struct Context { diff --git a/src/main.rs b/src/main.rs index 45df4b4..e297f55 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod auth; mod context; mod meta_commands; mod query; +mod show; mod utils; use args::get_args; @@ -19,6 +20,65 @@ pub const CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); pub const USER_AGENT: &str = concat!("fdb-cli/", env!("CARGO_PKG_VERSION")); pub const FIREBOLT_PROTOCOL_VERSION: &str = "2.3"; +fn print_help() { + println!("fb - Firebolt CLI v{}\n", CLI_VERSION); + println!("USAGE:"); + println!(" fb [OPTIONS] [QUERY]"); + println!(" fb auth [SUBCOMMAND]"); + println!(" fb use "); + println!(" fb show "); + println!(); + println!("QUERY EXECUTION:"); + println!(" fb \"SELECT 42\" Run a single query"); + println!(" fb Start interactive REPL"); + println!(); + println!("AUTHENTICATION:"); + println!(" fb auth Interactive authentication setup"); + println!(" fb auth check Show authentication status"); + println!(" fb auth clear Clear saved credentials"); + println!(); + println!("CONFIGURATION:"); + println!(" fb use database Set default database"); + println!(" fb use engine Set default engine (resolves endpoint)"); + println!(); + println!("DISCOVERY:"); + println!(" fb show databases List all available databases"); + println!(" fb show engines List all available engines"); + println!(); + println!("OPTIONS:"); + println!(" --database Database name (transient override)"); + println!(" -d Alias for --database"); + println!(" --host Hostname (transient override)"); + println!(" --format Output format (PSQL, TabSeparatedWithNames, etc.)"); + println!(" --label