Skip to content

Commit

Permalink
Add only_authenticated option to the client (#7545)
Browse files Browse the repository at this point in the history
  • Loading branch information
konstin committed Sep 21, 2024
1 parent 0d81bfb commit d9a5f5c
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 39 deletions.
82 changes: 55 additions & 27 deletions crates/uv-auth/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
realm::Realm,
CredentialsCache, KeyringProvider, CREDENTIALS_CACHE,
};
use anyhow::anyhow;
use anyhow::{anyhow, format_err};
use netrc::Netrc;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
Expand All @@ -22,6 +22,9 @@ pub struct AuthMiddleware {
netrc: Option<Netrc>,
keyring: Option<KeyringProvider>,
cache: Option<CredentialsCache>,
/// We know that the endpoint needs authentication, so we don't try to send an unauthenticated
/// request, avoiding cloning an uncloneable request.
only_authenticated: bool,
}

impl AuthMiddleware {
Expand All @@ -30,6 +33,7 @@ impl AuthMiddleware {
netrc: Netrc::new().ok(),
keyring: None,
cache: None,
only_authenticated: false,
}
}

Expand All @@ -56,6 +60,14 @@ impl AuthMiddleware {
self
}

/// We know that the endpoint needs authentication, so we don't try to send an unauthenticated
/// request, avoiding cloning an uncloneable request.
#[must_use]
pub fn with_only_authenticated(mut self, only_authenticated: bool) -> Self {
self.only_authenticated = only_authenticated;
self
}

/// Get the configured authentication store.
///
/// If not set, the global store is used.
Expand Down Expand Up @@ -198,32 +210,42 @@ impl Middleware for AuthMiddleware {
.as_ref()
.is_some_and(|credentials| credentials.username().is_some());

// Otherwise, attempt an anonymous request
trace!("Attempting unauthenticated request for {url}");

// <https://github.com/TrueLayer/reqwest-middleware/blob/abdf1844c37092d323683c2396b7eefda1418d3c/reqwest-retry/src/middleware.rs#L141-L149>
// Clone the request so we can retry it on authentication failure
let mut retry_request = request.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not cloneable. Are you passing a streaming body?".to_string()
))
})?;

let response = next.clone().run(request, extensions).await?;

// If we don't fail with authorization related codes, return the response
if !matches!(
response.status(),
StatusCode::FORBIDDEN | StatusCode::NOT_FOUND | StatusCode::UNAUTHORIZED
) {
return Ok(response);
}
let (mut retry_request, response) = if self.only_authenticated {
// For endpoints where we require the user to provide credentials, we don't try the
// unauthenticated request first.
trace!("Checking for credentials for {url}");
(request, None)
} else {
// Otherwise, attempt an anonymous request
trace!("Attempting unauthenticated request for {url}");

// <https://github.com/TrueLayer/reqwest-middleware/blob/abdf1844c37092d323683c2396b7eefda1418d3c/reqwest-retry/src/middleware.rs#L141-L149>
// Clone the request so we can retry it on authentication failure
let retry_request = request.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not cloneable. Are you passing a streaming body?"
.to_string()
))
})?;

let response = next.clone().run(request, extensions).await?;

// If we don't fail with authorization related codes, return the response
if !matches!(
response.status(),
StatusCode::FORBIDDEN | StatusCode::NOT_FOUND | StatusCode::UNAUTHORIZED
) {
return Ok(response);
}

// Otherwise, search for credentials
trace!(
"Request for {url} failed with {}, checking for credentials",
response.status()
);
// Otherwise, search for credentials
trace!(
"Request for {url} failed with {}, checking for credentials",
response.status()
);

(retry_request, Some(response))
};

// Check in the cache first
let credentials = self.cache().get_realm(
Expand Down Expand Up @@ -265,7 +287,13 @@ impl Middleware for AuthMiddleware {
}
}

Ok(response)
if let Some(response) = response {
Ok(response)
} else {
Err(Error::Middleware(format_err!(
"Missing credentials for {url}"
)))
}
}
}

Expand Down
38 changes: 26 additions & 12 deletions crates/uv-client/src/base_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct BaseClientBuilder<'a> {
client: Option<Client>,
markers: Option<&'a MarkerEnvironment>,
platform: Option<&'a Platform>,
only_authenticated: bool,
}

impl Default for BaseClientBuilder<'_> {
Expand All @@ -55,6 +56,7 @@ impl BaseClientBuilder<'_> {
client: None,
markers: None,
platform: None,
only_authenticated: false,
}
}
}
Expand Down Expand Up @@ -108,6 +110,12 @@ impl<'a> BaseClientBuilder<'a> {
self
}

#[must_use]
pub fn only_authenticated(mut self, only_authenticated: bool) -> Self {
self.only_authenticated = only_authenticated;
self
}

pub fn is_offline(&self) -> bool {
matches!(self.connectivity, Connectivity::Offline)
}
Expand Down Expand Up @@ -230,20 +238,26 @@ impl<'a> BaseClientBuilder<'a> {
fn apply_middleware(&self, client: Client) -> ClientWithMiddleware {
match self.connectivity {
Connectivity::Online => {
let client = reqwest_middleware::ClientBuilder::new(client);

// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
retry_policy,
UvRetryableStrategy,
);
let client = client.with(retry_strategy);
let mut client = reqwest_middleware::ClientBuilder::new(client);

// Avoid uncloneable errors with a streaming body during publish.
if self.retries > 0 {
// Initialize the retry strategy.
let retry_policy =
ExponentialBackoff::builder().build_with_max_retries(self.retries);
let retry_strategy = RetryTransientMiddleware::new_with_policy_and_strategy(
retry_policy,
UvRetryableStrategy,
);
client = client.with(retry_strategy);
}

// Initialize the authentication middleware to set headers.
let client =
client.with(AuthMiddleware::new().with_keyring(self.keyring.to_provider()));
client = client.with(
AuthMiddleware::new()
.with_keyring(self.keyring.to_provider())
.with_only_authenticated(self.only_authenticated),
);

client.build()
}
Expand Down

0 comments on commit d9a5f5c

Please sign in to comment.