diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad7f17e3..0cf48553 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -188,4 +188,5 @@ jobs: REQSIGN_AWS_S3_TEST: on REQSIGN_AWS_S3_REGION: ap-northeast-1 REQSIGN_AWS_ROLE_ARN: ${{ secrets.REQSIGN_AWS_ROLE_ARN }} + REQSIGN_AWS_ASSUME_ROLE_ARN: ${{ secrets.REQSIGN_AWS_ASSUME_ROLE_ARN }} REQSIGN_AWS_PROVIDER_ARN: ${{ secrets.REQSIGN_AWS_PROVIDER_ARN }} diff --git a/src/aws/config.rs b/src/aws/config.rs index 06fa3565..b4e6a478 100644 --- a/src/aws/config.rs +++ b/src/aws/config.rs @@ -84,6 +84,22 @@ pub struct Config { /// - env value: [`AWS_WEB_IDENTITY_TOKEN_FILE`] /// - profile config: `web_identity_token_file` pub web_identity_token_file: Option, + + /// `assume_role_arn` indicates the role to assume. + pub assume_role_arn: Option, + /// `duration_seconds` indicates the duration (in seconds) of the role session. + /// available values: `900` to `43200` or configured maximum session duration + /// default to `3600` + pub duration_seconds: i64, + // /// `credential_source` indicates the source of the credentials + // /// to use for the initial AssumeRole call. + // /// `credential_source` and `source_profile` are mutually exclusive. + // /// available values: `Environment`, `Ec2InstanceMetadata` + // pub credential_source: Option, + // /// `source_profile` indicates the source profile to use for + // /// the initial AssumeRole call. + // /// `credential_source` and `source_profile` are mutually exclusive. + // pub source_profile: Option, } impl Default for Config { @@ -101,6 +117,8 @@ impl Default for Config { role_session_name: "reqsign".to_string(), external_id: None, web_identity_token_file: None, + assume_role_arn: None, + duration_seconds: DEFAULT_ROLE_DURATION_SECONDS, } } } diff --git a/src/aws/constants.rs b/src/aws/constants.rs index 2d5ec7a0..7db3799b 100644 --- a/src/aws/constants.rs +++ b/src/aws/constants.rs @@ -19,6 +19,8 @@ pub const AWS_ROLE_ARN: &str = "AWS_ROLE_ARN"; pub const AWS_ROLE_SESSION_NAME: &str = "AWS_ROLE_SESSION_NAME"; pub const AWS_STS_REGIONAL_ENDPOINTS: &str = "AWS_STS_REGIONAL_ENDPOINTS"; +pub const DEFAULT_ROLE_DURATION_SECONDS: i64 = 3600; + /// AsciiSet for [AWS UriEncode](https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html) /// /// - URI encode every byte except the unreserved characters: 'A'-'Z', 'a'-'z', '0'-'9', '-', '.', '_', and '~'. diff --git a/src/aws/credential.rs b/src/aws/credential.rs index f171c3d6..3cdac06d 100644 --- a/src/aws/credential.rs +++ b/src/aws/credential.rs @@ -14,6 +14,7 @@ use reqwest::Client; use serde::Deserialize; use super::config::Config; +use crate::aws::v4::Signer; use crate::time::now; use crate::time::parse_rfc3339; use crate::time::DateTime; @@ -109,8 +110,9 @@ impl Loader { /// 1. Environment variables /// 2. Shared config (`~/.aws/config`, `~/.aws/credentials`) /// 3. Web Identity Tokens - /// 4. ECS (IAM Roles for Tasks) & General HTTP credentials: - /// 5. EC2 IMDSv2 + /// 4. EC2 IMDSv2 + /// + /// Assume to Role if provided. pub async fn load(&self) -> Result> { // Return cached credential if it has been loaded at least once. match self.credential.lock().expect("lock poisoned").clone() { @@ -118,11 +120,15 @@ impl Loader { _ => (), } - let cred = self.load_inner().await?; + let source_cred = self.load_inner().await?; + let cred = if let Some(c) = source_cred { + self.load_via_assume_role(c).await? + } else { + None + }; let mut lock = self.credential.lock().expect("lock poisoned"); *lock = cred.clone(); - Ok(cred) } @@ -152,14 +158,6 @@ impl Loader { return Ok(Some(cred)); } - if let Ok(Some(cred)) = self - .load_via_assume_role() - .await - .map_err(|err| debug!("load credential via assume_role failed: {err:?}")) - { - return Ok(Some(cred)); - } - if let Ok(Some(cred)) = self .load_via_imds_v2() .await @@ -269,26 +267,41 @@ impl Loader { Ok(Some(cred)) } - async fn load_via_assume_role(&self) -> Result> { - let role_arn = match &self.config.role_arn { + async fn load_via_assume_role(&self, cred: Credential) -> Result> { + let role_arn = match &self.config.assume_role_arn { Some(role_arn) => role_arn, - None => return Ok(None), + None => return Ok(Some(cred)), }; + let duration_seconds = &self.config.duration_seconds; let role_session_name = &self.config.role_session_name; + let region = match &self.config.region { + Some(region) => region, + None => return Ok(Some(cred)), + }; let endpoint = self.sts_endpoint()?; + let signer = Signer::new("sts", region); + // Construct request to AWS STS Service. let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}"); if let Some(external_id) = &self.config.external_id { write!(url, "&ExternalId={external_id}")?; } - let req = self.client.get(&url).header( - http::header::CONTENT_TYPE.as_str(), - "application/x-www-form-urlencoded", - ); - - let resp = req.send().await?; + if *duration_seconds > 0 { + write!(url, "&DurationSeconds={duration_seconds}")?; + } + let mut req = self + .client + .get(&url) + .header( + http::header::CONTENT_TYPE.as_str(), + "application/x-www-form-urlencoded", + ) + .build()?; + signer.sign(&mut req, &cred)?; + + let resp = self.client.execute(req).await?; if resp.status() != http::StatusCode::OK { let content = resp.text().await?; return Err(anyhow!("request to AWS STS Services failed: {content}")); @@ -620,7 +633,7 @@ mod tests { } // Ignore test if role_arn not set - let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ROLE_ARN") { + let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") { v } else { return Ok(()); @@ -647,7 +660,91 @@ mod tests { || { RUNTIME.block_on(async { let config = Config::default().from_env(); - let loader = Loader::new(reqwest::Client::new(), config); + let loader = + Loader::new(reqwest::Client::new(), config).with_disable_ec2_metadata(); + + let signer = Signer::new("s3", ®ion); + + let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region); + let mut req = Request::new(""); + *req.method_mut() = http::Method::GET; + *req.uri_mut() = + http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap(); + + let cred = loader + .load() + .await + .expect("credential must be valid") + .unwrap(); + + signer.sign(&mut req, &cred).expect("sign must success"); + + debug!("signed request url: {:?}", req.uri().to_string()); + debug!("signed request: {:?}", req); + + let client = Client::new(); + let resp = client.execute(req.try_into().unwrap()).await.unwrap(); + + let status = resp.status(); + debug!("got response: {:?}", resp); + debug!("got response content: {:?}", resp.text().await.unwrap()); + assert_eq!(status, StatusCode::NOT_FOUND); + }) + }, + ); + + Ok(()) + } + + #[test] + fn test_signer_with_web_loader_assume_role() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + dotenv::from_filename(".env").ok(); + + if env::var("REQSIGN_AWS_S3_TEST").is_err() + || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on" + { + return Ok(()); + } + + // Ignore test if role_arn not set + let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ROLE_ARN") { + v + } else { + return Ok(()); + }; + // Ignore test if assume_role_arn not set + let assume_role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") { + v + } else { + return Ok(()); + }; + + // let provider_arn = env::var("REQSIGN_AWS_PROVIDER_ARN").expect("REQSIGN_AWS_PROVIDER_ARN not exist"); + let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist"); + + let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist"); + let file_path = format!( + "{}/testdata/services/aws/web_identity_token_file", + env::current_dir() + .expect("current_dir must exist") + .to_string_lossy() + ); + fs::write(&file_path, github_token)?; + + temp_env::with_vars( + vec![ + (AWS_REGION, Some(®ion)), + (AWS_ROLE_ARN, Some(&role_arn)), + (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)), + ], + || { + RUNTIME.block_on(async { + let mut config = Config::default().from_env(); + config.assume_role_arn = Some(assume_role_arn.clone()); + let loader = + Loader::new(reqwest::Client::new(), config).with_disable_ec2_metadata(); let signer = Signer::new("s3", ®ion);