diff --git a/crates/reqsign-aws-v4/Cargo.toml b/crates/reqsign-aws-v4/Cargo.toml index 7d6dff2..c672835 100644 --- a/crates/reqsign-aws-v4/Cargo.toml +++ b/crates/reqsign-aws-v4/Cargo.toml @@ -39,6 +39,8 @@ hex.workspace = true macro_rules_attribute.workspace = true once_cell.workspace = true pretty_assertions.workspace = true +reqsign-file-read-tokio = { path = "../reqsign-file-read-tokio" } +reqsign-http-send-reqwest = { path = "../reqsign-http-send-reqwest" } reqwest = { workspace = true, features = ["rustls-tls"] } sha2.workspace = true temp-env.workspace = true diff --git a/crates/reqsign-aws-v4/src/config.rs b/crates/reqsign-aws-v4/src/config.rs index db1caa1..20097bd 100644 --- a/crates/reqsign-aws-v4/src/config.rs +++ b/crates/reqsign-aws-v4/src/config.rs @@ -1,8 +1,4 @@ -use std::collections::HashMap; -use std::env; -#[cfg(not(target_arch = "wasm32"))] -use std::fs; - +use super::constants::*; #[cfg(not(target_arch = "wasm32"))] use anyhow::anyhow; #[cfg(not(target_arch = "wasm32"))] @@ -11,10 +7,7 @@ use anyhow::Result; use ini::Ini; #[cfg(not(target_arch = "wasm32"))] use log::debug; - -use super::constants::*; -#[cfg(not(target_arch = "wasm32"))] -use reqsign::dirs::expand_homedir; +use reqsign::Context; /// Config for aws services. #[derive(Clone)] @@ -131,8 +124,8 @@ impl Default for Config { impl Config { /// Load config from env. - pub fn from_env(mut self) -> Self { - let envs = env::vars().collect::>(); + pub fn from_env(mut self, ctx: &Context) -> Self { + let envs = ctx.env_vars(); if let Some(v) = envs.get(AWS_CONFIG_FILE) { self.config_file = v.to_string(); @@ -178,30 +171,31 @@ impl Config { /// If the env var AWS_PROFILE is set, this profile will be used, /// otherwise the contents of `self.profile` will be used. #[cfg(not(target_arch = "wasm32"))] - pub fn from_profile(mut self) -> Self { + pub async fn from_profile(mut self, ctx: &Context) -> Self { // self.profile is checked by the two load methods. - if let Ok(profile) = env::var(AWS_PROFILE) { + if let Some(profile) = ctx.env_var(AWS_PROFILE) { self.profile = profile; } // make sure we're getting profile info from the correct place. // Respecting these env vars also makes it possible to unit test // this method. - if let Ok(config_file) = env::var(AWS_CONFIG_FILE) { + if let Some(config_file) = ctx.env_var(AWS_CONFIG_FILE) { self.config_file = config_file; } - if let Ok(shared_credentials_file) = env::var(AWS_SHARED_CREDENTIALS_FILE) { + if let Some(shared_credentials_file) = ctx.env_var(AWS_SHARED_CREDENTIALS_FILE) { self.shared_credentials_file = shared_credentials_file; } // Ignore all errors happened internally. - let _ = self.load_via_profile_config_file().map_err(|err| { + let _ = self.load_via_profile_config_file(ctx).await.map_err(|err| { debug!("load_via_profile_config_file failed: {err:?}"); }); let _ = self - .load_via_profile_shared_credentials_file() + .load_via_profile_shared_credentials_file(ctx) + .await .map_err(|err| debug!("load_via_profile_shared_credentials_file failed: {err:?}")); self @@ -213,13 +207,13 @@ impl Config { /// - `aws_secret_access_key` /// - `aws_session_token` #[cfg(not(target_arch = "wasm32"))] - fn load_via_profile_shared_credentials_file(&mut self) -> Result<()> { - let path = expand_homedir(&self.shared_credentials_file) + async fn load_via_profile_shared_credentials_file(&mut self, ctx: &Context) -> Result<()> { + let path = ctx + .expand_home_dir(&self.shared_credentials_file) .ok_or_else(|| anyhow!("expand homedir failed"))?; - let _ = fs::metadata(&path)?; - - let conf = Ini::load_from_file(path)?; + let content = ctx.file_read(&path).await?; + let conf = Ini::load_from_str(&String::from_utf8_lossy(&content))?; let props = conf .section(Some(&self.profile)) @@ -239,13 +233,13 @@ impl Config { } #[cfg(not(target_arch = "wasm32"))] - fn load_via_profile_config_file(&mut self) -> Result<()> { - let path = - expand_homedir(&self.config_file).ok_or_else(|| anyhow!("expand homedir failed"))?; - - let _ = fs::metadata(&path)?; + async fn load_via_profile_config_file(&mut self, ctx: &Context) -> Result<()> { + let path = ctx + .expand_home_dir(&self.config_file) + .ok_or_else(|| anyhow!("expand homedir failed"))?; - let conf = Ini::load_from_file(path)?; + let content = ctx.file_read(&path).await?; + let conf = Ini::load_from_str(&String::from_utf8_lossy(&content))?; let section = match self.profile.as_str() { "default" => "default".to_string(), @@ -291,13 +285,17 @@ impl Config { mod tests { use super::*; use pretty_assertions::assert_eq; + use reqsign::StaticEnv; + use reqsign_file_read_tokio::TokioFileRead; + use reqsign_http_send_reqwest::ReqwestHttpSend; + use std::collections::HashMap; use std::fs::File; use std::io::Write; use tempfile::tempdir; - #[test] + #[tokio::test] #[cfg(not(target_arch = "wasm32"))] - fn test_config_from_profile_shared_credentials() -> Result<()> { + async fn test_config_from_profile_shared_credentials() -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); // Create a dummy credentials file to test against @@ -314,37 +312,37 @@ mod tests { writeln!(tmp_file, "aws_secret_access_key = PROFILE1SECRETACCESSKEY")?; writeln!(tmp_file, "aws_session_token = PROFILE1SESSIONTOKEN")?; - temp_env::with_vars( - [ - (AWS_PROFILE, Some("profile1".to_owned())), - (AWS_CONFIG_FILE, None::), + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + (AWS_PROFILE.to_string(), "profile1".to_string()), ( - AWS_SHARED_CREDENTIALS_FILE, - Some(file_path.to_str().unwrap().to_owned()), + AWS_SHARED_CREDENTIALS_FILE.to_string(), + file_path.to_str().unwrap().to_owned(), ), - ], - || { - let config = Config::default().from_profile(); - - assert_eq!(config.profile, "profile1".to_owned()); - assert_eq!(config.access_key_id, Some("PROFILE1ACCESSKEYID".to_owned())); - assert_eq!( - config.secret_access_key, - Some("PROFILE1SECRETACCESSKEY".to_owned()) - ); - assert_eq!( - config.session_token, - Some("PROFILE1SESSIONTOKEN".to_owned()) - ); - }, + ]), + }); + + let config = Config::default().from_profile(&context).await; + + assert_eq!(config.profile, "profile1".to_owned()); + assert_eq!(config.access_key_id, Some("PROFILE1ACCESSKEYID".to_owned())); + assert_eq!( + config.secret_access_key, + Some("PROFILE1SECRETACCESSKEY".to_owned()) + ); + assert_eq!( + config.session_token, + Some("PROFILE1SESSIONTOKEN".to_owned()) ); Ok(()) } - #[test] + #[tokio::test] #[cfg(not(target_arch = "wasm32"))] - fn test_config_from_profile_config() -> Result<()> { + async fn test_config_from_profile_config() -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); // Create a dummy credentials file to test against @@ -361,29 +359,29 @@ mod tests { writeln!(tmp_file, "aws_secret_access_key = PROFILE1SECRETACCESSKEY")?; writeln!(tmp_file, "aws_session_token = PROFILE1SESSIONTOKEN")?; - temp_env::with_vars( - [ - (AWS_PROFILE, Some("profile1".to_owned())), + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + (AWS_PROFILE.to_string(), "profile1".to_string()), ( - AWS_CONFIG_FILE, - Some(file_path.to_str().unwrap().to_owned()), + AWS_CONFIG_FILE.to_string(), + file_path.to_str().unwrap().to_owned(), ), - (AWS_SHARED_CREDENTIALS_FILE, None::), - ], - || { - let config = Config::default().from_profile(); - - assert_eq!(config.profile, "profile1".to_owned()); - assert_eq!(config.access_key_id, Some("PROFILE1ACCESSKEYID".to_owned())); - assert_eq!( - config.secret_access_key, - Some("PROFILE1SECRETACCESSKEY".to_owned()) - ); - assert_eq!( - config.session_token, - Some("PROFILE1SESSIONTOKEN".to_owned()) - ); - }, + ]), + }); + + let config = Config::default().from_profile(&context).await; + + assert_eq!(config.profile, "profile1".to_owned()); + assert_eq!(config.access_key_id, Some("PROFILE1ACCESSKEYID".to_owned())); + assert_eq!( + config.secret_access_key, + Some("PROFILE1SECRETACCESSKEY".to_owned()) + ); + assert_eq!( + config.session_token, + Some("PROFILE1SESSIONTOKEN".to_owned()) ); Ok(()) diff --git a/crates/reqsign-aws-v4/src/credential.rs b/crates/reqsign-aws-v4/src/credential.rs index 3d631cc..f3a5a83 100644 --- a/crates/reqsign-aws-v4/src/credential.rs +++ b/crates/reqsign-aws-v4/src/credential.rs @@ -574,22 +574,25 @@ struct Ec2MetadataIamSecurityCredentials { #[cfg(test)] mod tests { + use std::collections::HashMap; use std::env; use std::str::FromStr; use std::vec; + use super::*; + use crate::constants::*; + use crate::signer::Signer; use anyhow::Result; use http::Request; use http::StatusCode; use once_cell::sync::Lazy; use quick_xml::de; + use reqsign::{Context, StaticEnv}; + use reqsign_file_read_tokio::TokioFileRead; + use reqsign_http_send_reqwest::ReqwestHttpSend; use reqwest::Client; use tokio::runtime::Runtime; - use super::*; - use crate::constants::*; - use crate::signer::Signer; - static RUNTIME: Lazy = Lazy::new(|| { tokio::runtime::Builder::new_multi_thread() .enable_all() @@ -611,154 +614,156 @@ mod tests { }); } - #[test] - fn test_credential_env_loader_with_env() { + #[tokio::test] + async fn test_credential_env_loader_with_env() { let _ = env_logger::builder().is_test(true).try_init(); - temp_env::with_vars( - vec![ - (AWS_ACCESS_KEY_ID, Some("access_key_id")), - (AWS_SECRET_ACCESS_KEY, Some("secret_access_key")), - ], - || { - RUNTIME.block_on(async { - let l = DefaultLoader::new(Client::new(), Config::default().from_env()); - let x = l.load().await.expect("load must succeed"); - - let x = x.expect("must load succeed"); - assert_eq!("access_key_id", x.access_key_id); - assert_eq!("secret_access_key", x.secret_access_key); - }) - }, - ); + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + (AWS_ACCESS_KEY_ID.to_string(), "access_key_id".to_string()), + ( + AWS_SECRET_ACCESS_KEY.to_string(), + "secret_access_key".to_string(), + ), + ]), + }); + + let l = DefaultLoader::new(Client::new(), Config::default().from_env(&context)); + let x = l.load().await.expect("load must succeed"); + + let x = x.expect("must load succeed"); + assert_eq!("access_key_id", x.access_key_id); + assert_eq!("secret_access_key", x.secret_access_key); } - #[test] - fn test_credential_profile_loader_from_config() { + #[tokio::test] + async fn test_credential_profile_loader_from_config() { let _ = env_logger::builder().is_test(true).try_init(); - temp_env::with_vars( - vec![ - (AWS_ACCESS_KEY_ID, None), - (AWS_SECRET_ACCESS_KEY, None), + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ ( - AWS_CONFIG_FILE, - Some(format!( + AWS_CONFIG_FILE.to_string(), + format!( "{}/testdata/default_config", env::current_dir() .expect("current_dir must exist") .to_string_lossy() - )), + ), ), ( - AWS_SHARED_CREDENTIALS_FILE, - Some(format!( + AWS_SHARED_CREDENTIALS_FILE.to_string(), + format!( "{}/testdata/not_exist", env::current_dir() .expect("current_dir must exist") .to_string_lossy() - )), + ), ), - ], - || { - RUNTIME.block_on(async { - let l = DefaultLoader::new( - Client::new(), - Config::default().from_env().from_profile(), - ); - let x = l.load().await.unwrap().unwrap(); - assert_eq!("config_access_key_id", x.access_key_id); - assert_eq!("config_secret_access_key", x.secret_access_key); - }) - }, + ]), + }); + + let l = DefaultLoader::new( + Client::new(), + Config::default() + .from_env(&context) + .from_profile(&context) + .await, ); + let x = l.load().await.unwrap().unwrap(); + assert_eq!("config_access_key_id", x.access_key_id); + assert_eq!("config_secret_access_key", x.secret_access_key); } - #[test] - fn test_credential_profile_loader_from_shared() { + #[tokio::test] + async fn test_credential_profile_loader_from_shared() { let _ = env_logger::builder().is_test(true).try_init(); - temp_env::with_vars( - vec![ - (AWS_ACCESS_KEY_ID, None), - (AWS_SECRET_ACCESS_KEY, None), + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ ( - AWS_CONFIG_FILE, - Some(format!( + AWS_CONFIG_FILE.to_string(), + format!( "{}/testdata/not_exist", env::current_dir() - .expect("load must exist") + .expect("current_dir must exist") .to_string_lossy() - )), + ), ), ( - AWS_SHARED_CREDENTIALS_FILE, - Some(format!( + AWS_SHARED_CREDENTIALS_FILE.to_string(), + format!( "{}/testdata/default_credential", env::current_dir() - .expect("load must exist") + .expect("current_dir must exist") .to_string_lossy() - )), + ), ), - ], - || { - RUNTIME.block_on(async { - let l = DefaultLoader::new( - Client::new(), - Config::default().from_env().from_profile(), - ); - let x = l.load().await.unwrap().unwrap(); - assert_eq!("shared_access_key_id", x.access_key_id); - assert_eq!("shared_secret_access_key", x.secret_access_key); - }) - }, + ]), + }); + + let l = DefaultLoader::new( + Client::new(), + Config::default() + .from_env(&context) + .from_profile(&context) + .await, ); + let x = l.load().await.unwrap().unwrap(); + assert_eq!("shared_access_key_id", x.access_key_id); + assert_eq!("shared_secret_access_key", x.secret_access_key); } /// AWS_SHARED_CREDENTIALS_FILE should be taken first. - #[test] - fn test_credential_profile_loader_from_both() { + #[tokio::test] + async fn test_credential_profile_loader_from_both() { let _ = env_logger::builder().is_test(true).try_init(); - temp_env::with_vars( - vec![ - (AWS_ACCESS_KEY_ID, None), - (AWS_SECRET_ACCESS_KEY, None), + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ ( - AWS_CONFIG_FILE, - Some(format!( + AWS_CONFIG_FILE.to_string(), + format!( "{}/testdata/default_config", env::current_dir() .expect("current_dir must exist") .to_string_lossy() - )), + ), ), ( - AWS_SHARED_CREDENTIALS_FILE, - Some(format!( + AWS_SHARED_CREDENTIALS_FILE.to_string(), + format!( "{}/testdata/default_credential", env::current_dir() .expect("current_dir must exist") .to_string_lossy() - )), + ), ), - ], - || { - RUNTIME.block_on(async { - let l = DefaultLoader::new( - Client::new(), - Config::default().from_env().from_profile(), - ); - let x = l.load().await.expect("load must success").unwrap(); - assert_eq!("shared_access_key_id", x.access_key_id); - assert_eq!("shared_secret_access_key", x.secret_access_key); - }) - }, + ]), + }); + + let l = DefaultLoader::new( + Client::new(), + Config::default() + .from_env(&context) + .from_profile(&context) + .await, ); + let x = l.load().await.expect("load must success").unwrap(); + assert_eq!("shared_access_key_id", x.access_key_id); + assert_eq!("shared_secret_access_key", x.secret_access_key); } - #[test] - fn test_signer_with_web_loader() -> Result<()> { + #[tokio::test] + async fn test_signer_with_web_loader() -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); dotenv::from_filename(".env").ok(); @@ -788,54 +793,55 @@ mod tests { ); fs::write(&file_path, github_token)?; - temp_env::with_vars( - vec![ - (AWS_REGION, Some(®ion)), - (AWS_ROLE_ARN, Some(&role_arn)), - (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)), - ], - || { - RUNTIME.block_on(async { - let config = Config::default().from_env(); - let loader = DefaultLoader::new(reqwest::Client::new(), config); - - let signer = Signer::new("s3", ®ion); - - let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); - let mut req = Request::new(""); - *req.method_mut() = http::Method::GET; - *req.uri_mut() = - http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); - - let cred = loader - .load() - .await - .expect("credential must be valid") - .unwrap(); - - let (mut req, body) = req.into_parts(); - signer.sign(&mut req, &cred).expect("sign must success"); - let req = Request::from_parts(req, body); - - debug!("signed request url: {:?}", req.uri().to_string()); - debug!("signed request: {:?}", req); - - let client = Client::new(); - let resp = client.execute(req.try_into().unwrap()).await.unwrap(); - - let status = resp.status(); - debug!("got response: {:?}", resp); - debug!("got response content: {:?}", resp.text().await.unwrap()); - assert_eq!(status, StatusCode::NOT_FOUND); - }) - }, - ); + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + (AWS_REGION.to_string(), region.to_string()), + (AWS_ROLE_ARN.to_string(), role_arn.to_string()), + ( + AWS_WEB_IDENTITY_TOKEN_FILE.to_string(), + file_path.to_string(), + ), + ]), + }); + + let config = Config::default().from_env(&context); + let loader = DefaultLoader::new(reqwest::Client::new(), config); + + let signer = Signer::new("s3", ®ion); + + let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); + let mut req = Request::new(""); + *req.method_mut() = http::Method::GET; + *req.uri_mut() = + http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); + + let cred = loader + .load() + .await + .expect("credential must be valid") + .unwrap(); + + let (mut req, body) = req.into_parts(); + signer.sign(&mut req, &cred).expect("sign must success"); + let req = Request::from_parts(req, body); + + debug!("signed request url: {:?}", req.uri().to_string()); + debug!("signed request: {:?}", req); + let client = Client::new(); + let resp = client.execute(req.try_into().unwrap()).await.unwrap(); + + let status = resp.status(); + debug!("got response: {:?}", resp); + debug!("got response content: {:?}", resp.text().await.unwrap()); + assert_eq!(status, StatusCode::NOT_FOUND); Ok(()) } - #[test] - fn test_signer_with_web_loader_assume_role() -> Result<()> { + #[tokio::test] + async fn test_signer_with_web_loader_assume_role() -> Result<()> { let _ = env_logger::builder().is_test(true).try_init(); dotenv::from_filename(".env").ok(); @@ -870,56 +876,57 @@ mod tests { ); fs::write(&file_path, github_token)?; - temp_env::with_vars( - vec![ - (AWS_REGION, Some(®ion)), - (AWS_ROLE_ARN, Some(&role_arn)), - (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)), - ], - || { - RUNTIME.block_on(async { - let client = reqwest::Client::new(); - let default_loader = - DefaultLoader::new(client.clone(), Config::default().from_env()) - .with_disable_ec2_metadata(); - - let cfg = Config { - role_arn: Some(assume_role_arn.clone()), - region: Some(region.clone()), - sts_regional_endpoints: "regional".to_string(), - ..Default::default() - }; - let loader = - AssumeRoleLoader::new(client.clone(), cfg, Box::new(default_loader)) - .expect("AssumeRoleLoader must be valid"); - - let signer = Signer::new("s3", ®ion); - let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); - let mut req = Request::new(""); - *req.method_mut() = http::Method::GET; - *req.uri_mut() = - http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); - let cred = loader - .load() - .await - .expect("credential must be valid") - .unwrap(); - - let (mut parts, body) = req.into_parts(); - signer.sign(&mut parts, &cred).expect("sign must success"); - let req = Request::from_parts(parts, body); - - debug!("signed request url: {:?}", req.uri().to_string()); - debug!("signed request: {:?}", req); - let client = Client::new(); - let resp = client.execute(req.try_into().unwrap()).await.unwrap(); - let status = resp.status(); - debug!("got response: {:?}", resp); - debug!("got response content: {:?}", resp.text().await.unwrap()); - assert_eq!(status, StatusCode::NOT_FOUND); - }) - }, - ); + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let context = context.with_env(StaticEnv { + home_dir: None, + envs: HashMap::from_iter([ + (AWS_REGION.to_string(), region.to_string()), + (AWS_ROLE_ARN.to_string(), role_arn.to_string()), + ( + AWS_WEB_IDENTITY_TOKEN_FILE.to_string(), + file_path.to_string(), + ), + ]), + }); + + let client = reqwest::Client::new(); + let default_loader = + DefaultLoader::new(client.clone(), Config::default().from_env(&context)) + .with_disable_ec2_metadata(); + + let cfg = Config { + role_arn: Some(assume_role_arn.clone()), + region: Some(region.clone()), + sts_regional_endpoints: "regional".to_string(), + ..Default::default() + }; + let loader = AssumeRoleLoader::new(client.clone(), cfg, Box::new(default_loader)) + .expect("AssumeRoleLoader must be valid"); + + let signer = Signer::new("s3", ®ion); + let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); + let mut req = Request::new(""); + *req.method_mut() = http::Method::GET; + *req.uri_mut() = + http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); + let cred = loader + .load() + .await + .expect("credential must be valid") + .unwrap(); + + let (mut parts, body) = req.into_parts(); + signer.sign(&mut parts, &cred).expect("sign must success"); + let req = Request::from_parts(parts, body); + + debug!("signed request url: {:?}", req.uri().to_string()); + debug!("signed request: {:?}", req); + let client = Client::new(); + let resp = client.execute(req.try_into().unwrap()).await.unwrap(); + let status = resp.status(); + debug!("got response: {:?}", resp); + debug!("got response content: {:?}", resp.text().await.unwrap()); + assert_eq!(status, StatusCode::NOT_FOUND); Ok(()) } diff --git a/crates/reqsign-aws-v4/tests/main.rs b/crates/reqsign-aws-v4/tests/main.rs index 6e82e12..c04493e 100644 --- a/crates/reqsign-aws-v4/tests/main.rs +++ b/crates/reqsign-aws-v4/tests/main.rs @@ -9,14 +9,17 @@ use log::debug; use log::warn; use percent_encoding::utf8_percent_encode; use percent_encoding::NON_ALPHANUMERIC; +use reqsign::Context; use reqsign_aws_v4::Config; use reqsign_aws_v4::DefaultLoader; use reqsign_aws_v4::Signer; +use reqsign_file_read_tokio::TokioFileRead; +use reqsign_http_send_reqwest::ReqwestHttpSend; use reqwest::Client; use sha2::Digest; use sha2::Sha256; -fn init_signer() -> Option<(DefaultLoader, Signer)> { +async fn init_signer() -> Option<(DefaultLoader, Signer)> { let _ = env_logger::builder().is_test(true).try_init(); dotenv::from_filename(".env").ok(); @@ -26,6 +29,8 @@ fn init_signer() -> Option<(DefaultLoader, Signer)> { return None; } + let context = Context::new(TokioFileRead, ReqwestHttpSend::default()); + let config = Config { region: Some( env::var("REQSIGN_AWS_V4_REGION").expect("env REQSIGN_AWS_V4_REGION must set"), @@ -38,8 +43,9 @@ fn init_signer() -> Option<(DefaultLoader, Signer)> { ), ..Default::default() } - .from_env() - .from_profile(); + .from_env(&context) + .from_profile(&context) + .await; let region = config.region.as_deref().unwrap().to_string(); @@ -55,7 +61,7 @@ fn init_signer() -> Option<(DefaultLoader, Signer)> { #[tokio::test] async fn test_head_object() -> Result<()> { - let signer = init_signer(); + let signer = init_signer().await; if signer.is_none() { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); @@ -97,7 +103,7 @@ async fn test_head_object() -> Result<()> { #[tokio::test] async fn test_put_object_with_query() -> Result<()> { - let signer = init_signer(); + let signer = init_signer().await; if signer.is_none() { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); @@ -149,7 +155,7 @@ async fn test_put_object_with_query() -> Result<()> { #[tokio::test] async fn test_get_object_with_query() -> Result<()> { - let signer = init_signer(); + let signer = init_signer().await; if signer.is_none() { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); @@ -191,7 +197,7 @@ async fn test_get_object_with_query() -> Result<()> { #[tokio::test] async fn test_head_object_with_special_characters() -> Result<()> { - let signer = init_signer(); + let signer = init_signer().await; if signer.is_none() { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); @@ -237,7 +243,7 @@ async fn test_head_object_with_special_characters() -> Result<()> { #[tokio::test] async fn test_head_object_with_encoded_characters() -> Result<()> { - let signer = init_signer(); + let signer = init_signer().await; if signer.is_none() { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); @@ -283,7 +289,7 @@ async fn test_head_object_with_encoded_characters() -> Result<()> { #[tokio::test] async fn test_list_bucket() -> Result<()> { - let signer = init_signer(); + let signer = init_signer().await; if signer.is_none() { warn!("REQSIGN_AWS_V4_TEST is not set, skipped"); return Ok(()); diff --git a/crates/reqsign-file-read-tokio/Cargo.toml b/crates/reqsign-file-read-tokio/Cargo.toml new file mode 100644 index 0000000..1dcb3fa --- /dev/null +++ b/crates/reqsign-file-read-tokio/Cargo.toml @@ -0,0 +1,15 @@ +[package] +categories.workspace = true +description.workspace = true +documentation.workspace = true +edition.workspace = true +license.workspace = true +name = "reqsign-file-read-tokio" +repository.workspace = true +version = "0.1.0" + +[dependencies] +anyhow = "1" +async-trait = "0.1" +reqsign = { version = "0.16", path = "../reqsign" } +tokio = { version = "1", features = ["fs"] } diff --git a/crates/reqsign-file-read-tokio/src/lib.rs b/crates/reqsign-file-read-tokio/src/lib.rs new file mode 100644 index 0000000..705439d --- /dev/null +++ b/crates/reqsign-file-read-tokio/src/lib.rs @@ -0,0 +1,13 @@ +use anyhow::Result; +use async_trait::async_trait; +use reqsign::FileRead; + +#[derive(Debug, Clone, Copy, Default)] +pub struct TokioFileRead; + +#[async_trait] +impl FileRead for TokioFileRead { + async fn file_read(&self, path: &str) -> Result> { + tokio::fs::read(path).await.map_err(Into::into) + } +} diff --git a/crates/reqsign-http-send-reqwest/Cargo.toml b/crates/reqsign-http-send-reqwest/Cargo.toml new file mode 100644 index 0000000..a676f17 --- /dev/null +++ b/crates/reqsign-http-send-reqwest/Cargo.toml @@ -0,0 +1,18 @@ +[package] +categories.workspace = true +description.workspace = true +documentation.workspace = true +edition.workspace = true +license.workspace = true +name = "reqsign-http-send-reqwest" +repository.workspace = true +version = "0.1.0" + +[dependencies] +anyhow = "1" +async-trait = "0.1" +bytes.workspace = true +http-body-util = "0.1.2" +http.workspace = true +reqsign = { version = "0.16.0", path = "../reqsign" } +reqwest = { version = "0.12", default-features = false } diff --git a/crates/reqsign-http-send-reqwest/src/lib.rs b/crates/reqsign-http-send-reqwest/src/lib.rs new file mode 100644 index 0000000..a9e244e --- /dev/null +++ b/crates/reqsign-http-send-reqwest/src/lib.rs @@ -0,0 +1,29 @@ +use async_trait::async_trait; +use bytes::Bytes; +use http_body_util::BodyExt; +use reqsign::HttpSend; +use reqwest::{Client, Request}; + +#[derive(Debug, Default)] +pub struct ReqwestHttpSend { + client: Client, +} + +impl ReqwestHttpSend { + /// Create a new ReqwestHttpSend with a reqwest::Client. + pub fn new(client: Client) -> Self { + Self { client } + } +} + +#[async_trait] +impl HttpSend for ReqwestHttpSend { + async fn http_send(&self, req: http::Request) -> anyhow::Result> { + let req = Request::try_from(req)?; + let resp: http::Response<_> = self.client.execute(req).await?.into(); + + let (parts, body) = resp.into_parts(); + let bs = BodyExt::collect(body).await.map(|buf| buf.to_bytes())?; + Ok(http::Response::from_parts(parts, bs)) + } +} diff --git a/crates/reqsign/src/context.rs b/crates/reqsign/src/context.rs index e327614..2f49a62 100644 --- a/crates/reqsign/src/context.rs +++ b/crates/reqsign/src/context.rs @@ -2,6 +2,8 @@ use crate::env::{Env, OsEnv}; use crate::{FileRead, HttpSend}; use anyhow::Result; use bytes::Bytes; +use std::collections::HashMap; +use std::path::PathBuf; use std::sync::Arc; /// Context provides the context for the request signing. @@ -41,4 +43,40 @@ impl Context { pub async fn http_send(&self, req: http::Request) -> Result> { self.http.http_send(req).await } + + /// Get the home directory of the current user. + #[inline] + pub fn home_dir(&self) -> Option { + self.env.home_dir() + } + + /// Expand `~` in input path. + /// + /// - If path not starts with `~/` or `~\\`, returns `Some(path)` directly. + /// - Otherwise, replace `~` with home dir instead. + /// - If home_dir is not found, returns `None`. + pub fn expand_home_dir(&self, path: &str) -> Option { + if !path.starts_with("~/") && !path.starts_with("~\\") { + Some(path.to_string()) + } else { + self.home_dir() + .map(|home| path.replace('~', &home.to_string_lossy())) + } + } + + /// Get the environment variable. + /// + /// - Returns `Some(v)` if the environment variable is found and is valid utf-8. + /// - Returns `None` if the environment variable is not found or value is invalid. + #[inline] + pub fn env_var(&self, key: &str) -> Option { + self.env.var(key) + } + + /// Returns an hashmap of (variable, value) pairs of strings, for all the + /// environment variables of the current process. + #[inline] + pub fn env_vars(&self) -> HashMap { + self.env.vars() + } } diff --git a/crates/reqsign/src/env.rs b/crates/reqsign/src/env.rs index 24fc37d..fa0cc61 100644 --- a/crates/reqsign/src/env.rs +++ b/crates/reqsign/src/env.rs @@ -1,10 +1,18 @@ +use std::collections::HashMap; use std::fmt::Debug; -use std::{ffi::OsString, path::PathBuf}; +use std::path::PathBuf; /// Permits parameterizing the home functions via the _from variants pub trait Env: Debug + 'static { - /// Get an environment variable, as per std::env::var_os. - fn var_os(&self, key: &str) -> Option; + /// Get an environment variable. + /// + /// - Returns `Some(v)` if the environment variable is found and is valid utf-8. + /// - Returns `None` if the environment variable is not found or value is invalid. + fn var(&self, key: &str) -> Option; + + /// Returns an hashmap of (variable, value) pairs of strings, for all the + /// environment variables of the current process. + fn vars(&self) -> HashMap; /// Return the path to the users home dir, returns `None` if any error occurs. fn home_dir(&self) -> Option; @@ -15,8 +23,12 @@ pub trait Env: Debug + 'static { pub struct OsEnv; impl Env for OsEnv { - fn var_os(&self, key: &str) -> Option { - std::env::var_os(key) + fn var(&self, key: &str) -> Option { + std::env::var_os(key)?.into_string().ok() + } + + fn vars(&self) -> HashMap { + std::env::vars().collect() } #[cfg(any(unix, target_os = "redox"))] @@ -36,20 +48,26 @@ impl Env for OsEnv { } } -/// Implements Env for the mock context. -#[cfg(test)] +/// StaticEnv provides a static env environment. +/// +/// This is useful for testing or for providing a fixed environment. #[derive(Debug, Clone)] -pub struct MockEnv { +pub struct StaticEnv { + /// The home directory to use. pub home_dir: Option, - pub envs: std::collections::HashMap, + /// The environment variables to use. + pub envs: HashMap, } -#[cfg(test)] -impl Env for MockEnv { - fn var_os(&self, key: &str) -> Option { +impl Env for StaticEnv { + fn var(&self, key: &str) -> Option { self.envs.get(key).cloned() } + fn vars(&self) -> HashMap { + self.envs.clone() + } + fn home_dir(&self) -> Option { self.home_dir.clone() } diff --git a/crates/reqsign/src/lib.rs b/crates/reqsign/src/lib.rs index fee1543..f05698a 100644 --- a/crates/reqsign/src/lib.rs +++ b/crates/reqsign/src/lib.rs @@ -14,7 +14,6 @@ mod sign; pub use sign::*; -pub mod dirs; pub mod hash; pub mod time; @@ -24,5 +23,6 @@ mod http; pub use http::HttpSend; mod env; pub use env::Env; +pub use env::StaticEnv; mod context; pub use context::Context;