diff --git a/CHANGELOG.md b/CHANGELOG.md
index c20e488..0e8d0b7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@ In Development
 - Reduced the sizes of a number of streams & futures
 - Added doc comments to much of the code
 - Return 502 status when a backend returns an invalid response
+- Require `--api-url` (and other URLs retrieved from APIs) to be HTTP(S)
 
 v0.4.0 (2024-07-09)
 -------------------
diff --git a/src/dandi/mod.rs b/src/dandi/mod.rs
index 56a1ca3..2b1bbd8 100644
--- a/src/dandi/mod.rs
+++ b/src/dandi/mod.rs
@@ -9,7 +9,7 @@ pub(crate) use self::types::*;
 pub(crate) use self::version_id::*;
 use crate::consts::S3CLIENT_CACHE_SIZE;
 use crate::dav::ErrorClass;
-use crate::httputil::{urljoin_slashed, BuildClientError, Client, HttpError};
+use crate::httputil::{BuildClientError, Client, HttpError, HttpUrl};
 use crate::paths::{ParsePureDirPathError, PureDirPath, PurePath};
 use crate::s3::{
     BucketSpec, GetBucketRegionError, PrefixedS3Client, S3Client, S3Error, S3Location,
@@ -20,7 +20,6 @@ use serde::de::DeserializeOwned;
 use smartstring::alias::CompactString;
 use std::sync::Arc;
 use thiserror::Error;
-use url::Url;
 
 /// A client for fetching data about Dandisets, their versions, and their
 /// assets from a DANDI Archive instance
@@ -30,7 +29,7 @@ pub(crate) struct DandiClient {
     inner: Client,
 
     /// The base API URL of the Archive instance
-    api_url: Url,
+    api_url: HttpUrl,
 
     /// A cache of [`S3Client`] instances that are used for listing Zarr
     /// entries on the Archive's S3 bucket.
@@ -51,7 +50,7 @@ impl DandiClient {
     /// # Errors
     ///
    /// Returns an error if construction of the inner `reqwest::Client` fails
-    pub(crate) fn new(api_url: Url) -> Result<Self, BuildClientError> {
+    pub(crate) fn new(api_url: HttpUrl) -> Result<Self, BuildClientError> {
         let inner = Client::new()?;
         let s3clients = CacheBuilder::new(S3CLIENT_CACHE_SIZE)
             .name("s3clients")
@@ -65,24 +64,26 @@
     /// Return the URL formed by appending the given path segments and a
     /// trailing slash to the path of the API base URL
-    fn get_url<I>(&self, segments: I) -> Url
+    fn get_url<I>(&self, segments: I) -> HttpUrl
     where
         I: IntoIterator,
         I::Item: AsRef<str>,
     {
-        urljoin_slashed(&self.api_url, segments)
+        let mut url = self.api_url.clone();
+        url.extend(segments).ensure_dirpath();
+        url
     }
 
     /// Perform a `GET` request to the given URL and return the deserialized
     /// JSON response body
-    async fn get<T: DeserializeOwned>(&self, url: Url) -> Result<T, DandiError> {
+    async fn get<T: DeserializeOwned>(&self, url: HttpUrl) -> Result<T, DandiError> {
         self.inner.get_json(url).await.map_err(Into::into)
     }
 
     /// Return a [`futures_util::Stream`] that makes paginated `GET` requests
     /// to the given URL and its subsequent pages and yields a `Result<T,
     /// DandiError>` value for each item deserialized from the responses
-    fn paginate<T>(&self, url: Url) -> Paginate<T> {
+    fn paginate<T>(&self, url: HttpUrl) -> Paginate<T> {
         Paginate::new(self, url)
     }
@@ -158,7 +159,7 @@ impl DandiClient {
 
     /// Return the URL for the metadata for the given version of the given
     /// Dandiset
-    fn version_metadata_url(&self, dandiset_id: &DandisetId, version_id: &VersionId) -> Url {
+    fn version_metadata_url(&self, dandiset_id: &DandisetId, version_id: &VersionId) -> HttpUrl {
         self.get_url([
             "dandisets",
             dandiset_id.as_ref(),
@@ -396,7 +397,7 @@ impl<'a> VersionEndpoint<'a> {
     }
 
     /// Return the URL for the version's metadata
-    fn metadata_url(&self) -> Url {
+    fn metadata_url(&self) -> HttpUrl {
         self.client
             .version_metadata_url(&self.dandiset_id, &self.version_id)
     }
@@ -421,7 +422,7 @@ impl<'a> VersionEndpoint<'a> {
 
     /// Return the URL for the metadata of the asset in this version with the
     /// given asset ID
-    fn asset_metadata_url(&self, asset_id: &str) -> Url {
+    fn asset_metadata_url(&self, asset_id: &str) -> HttpUrl {
         self.client.get_url([
             "dandisets",
             self.dandiset_id.as_ref(),
@@ -447,10 +448,9 @@
             self.version_id.as_ref(),
             "assets",
         ]);
-        url.query_pairs_mut()
-            .append_pair("path", path.as_ref())
-            .append_pair("metadata", "1")
-            .append_pair("order", "path");
+        url.append_query_param("path", path.as_ref());
+        url.append_query_param("metadata", "1");
+        url.append_query_param("order", "path");
         let dirpath = path.to_dir_path();
         let mut stream = self.client.paginate::<RawAsset>(url.clone());
         while let Some(asset) = stream.try_next().await? {
@@ -480,8 +480,7 @@
             "paths",
         ]);
         if let Some(path) = path {
-            url.query_pairs_mut()
-                .append_pair("path_prefix", path.as_ref());
+            url.append_query_param("path_prefix", path.as_ref());
         }
         self.client.paginate(url)
     }
diff --git a/src/dandi/streams.rs b/src/dandi/streams.rs
index 624149f..c61caa3 100644
--- a/src/dandi/streams.rs
+++ b/src/dandi/streams.rs
@@ -1,12 +1,10 @@
-use super::types::Page;
 use super::{DandiClient, DandiError};
-use crate::httputil::{Client, HttpError};
+use crate::httputil::{Client, HttpError, HttpUrl};
 use futures_util::{future::BoxFuture, FutureExt, Stream};
 use pin_project::pin_project;
-use serde::de::DeserializeOwned;
+use serde::{de::DeserializeOwned, Deserialize};
 use std::pin::Pin;
 use std::task::{ready, Context, Poll};
-use url::Url;
 
 // Implementing paginate() as a manually-implemented Stream instead of via
 // async_stream lets us save about 4700 bytes on dandidav's top-level Futures.
@@ -21,13 +19,13 @@ enum PaginateState<T> {
     Requesting(BoxFuture<'static, Result<Page<T>, HttpError>>),
     Yielding {
         results: std::vec::IntoIter<T>,
-        next: Option<Url>,
+        next: Option<HttpUrl>,
     },
     Done,
 }
 
 impl<T> Paginate<T> {
-    pub(super) fn new(client: &DandiClient, url: Url) -> Self {
+    pub(super) fn new(client: &DandiClient, url: HttpUrl) -> Self {
         Paginate {
             client: client.inner.clone(),
             state: PaginateState::Yielding {
@@ -78,3 +76,9 @@ where
         }
     }
 }
+
+#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
+struct Page<T> {
+    next: Option<HttpUrl>,
+    results: Vec<T>,
+}
diff --git a/src/dandi/types.rs b/src/dandi/types.rs
index 0b3c8d2..526519b 100644
--- a/src/dandi/types.rs
+++ b/src/dandi/types.rs
@@ -1,16 +1,10 @@
 use super::{DandisetId, VersionId};
+use crate::httputil::HttpUrl;
 use crate::paths::{PureDirPath, PurePath};
 use crate::s3::{PrefixedS3Client, S3Entry, S3Folder, S3Location, S3Object};
 use serde::Deserialize;
 use thiserror::Error;
 use time::OffsetDateTime;
-use url::Url;
-
-#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
-pub(super) struct Page<T> {
-    pub(super) next: Option<Url>,
-    pub(super) results: Vec<T>,
-}
 
 #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
 pub(super) struct RawDandiset {
@@ -67,7 +61,7 @@ pub(super) struct RawDandisetVersion {
 }
 
 impl RawDandisetVersion {
-    pub(super) fn with_metadata_url(self, metadata_url: Url) -> DandisetVersion {
+    pub(super) fn with_metadata_url(self, metadata_url: HttpUrl) -> DandisetVersion {
         DandisetVersion {
             version: self.version,
             size: self.size,
@@ -84,7 +78,7 @@ pub(crate) struct DandisetVersion {
     pub(crate) size: i64,
     pub(crate) created: OffsetDateTime,
     pub(crate) modified: OffsetDateTime,
-    pub(crate) metadata_url: Url,
+    pub(crate) metadata_url: HttpUrl,
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
@@ -166,7 +160,7 @@ pub(crate) struct BlobAsset {
     pub(crate) created: OffsetDateTime,
     pub(crate) modified: OffsetDateTime,
     pub(crate) metadata: AssetMetadata,
-    pub(crate) metadata_url: Url,
+    pub(crate) metadata_url: HttpUrl,
 }
 
 impl BlobAsset {
@@ -178,18 +172,18 @@ impl BlobAsset {
         self.metadata.digest.dandi_etag.as_deref()
     }
 
-    pub(crate) fn archive_url(&self) -> Option<&Url> {
+    pub(crate) fn archive_url(&self) -> Option<&HttpUrl> {
         self.metadata
             .content_url
             .iter()
-            .find(|url| S3Location::parse_url(url).is_err())
+            .find(|url| S3Location::parse_url(url.as_url()).is_err())
     }
 
-    pub(crate) fn s3_url(&self) -> Option<&Url> {
+    pub(crate) fn s3_url(&self) -> Option<&HttpUrl> {
         self.metadata
             .content_url
             .iter()
-            .find(|url| S3Location::parse_url(url).is_ok())
+            .find(|url| S3Location::parse_url(url.as_url()).is_ok())
     }
 }
 
@@ -202,7 +196,7 @@ pub(crate) struct ZarrAsset {
     pub(crate) created: OffsetDateTime,
     pub(crate) modified: OffsetDateTime,
     pub(crate) metadata: AssetMetadata,
-    pub(crate) metadata_url: Url,
+    pub(crate) metadata_url: HttpUrl,
 }
 
 impl ZarrAsset {
@@ -210,7 +204,7 @@ impl ZarrAsset {
         self.metadata
             .content_url
             .iter()
-            .find_map(|url| S3Location::parse_url(url).ok())
+            .find_map(|url| S3Location::parse_url(url.as_url()).ok())
     }
 
     pub(crate) fn make_resource(&self, value: S3Entry) -> DandiResource {
@@ -246,7 +240,7 @@
 #[serde(rename_all = "camelCase")]
 pub(crate) struct AssetMetadata {
     encoding_format: Option<String>,
-    content_url: Vec<Url>,
+    content_url: Vec<HttpUrl>,
     digest: AssetDigests,
 }
@@ -374,7 +368,7 @@ pub(crate) struct ZarrEntry {
     pub(crate) size: i64,
     pub(crate) modified: OffsetDateTime,
     pub(crate) etag: String,
-    pub(crate) url: Url,
+    pub(crate) url: HttpUrl,
 }
 
 #[derive(Clone, Debug)]
diff --git a/src/dav/types.rs b/src/dav/types.rs
index fb7d9a2..547d7cb 100644
--- a/src/dav/types.rs
+++ b/src/dav/types.rs
@@ -3,12 +3,12 @@ use super::xml::{PropValue, Property};
 use super::VersionSpec;
 use crate::consts::{DEFAULT_CONTENT_TYPE, YAML_CONTENT_TYPE};
 use crate::dandi::*;
+use crate::httputil::HttpUrl;
 use crate::paths::{PureDirPath, PurePath};
 use crate::zarrman::*;
 use enum_dispatch::enum_dispatch;
 use serde::{ser::Serializer, Serialize};
 use time::OffsetDateTime;
-use url::Url;
 
 /// Trait for querying the values of WebDAV properties from WebDAV resources
 ///
@@ -282,7 +282,7 @@ pub(super) struct DavCollection {
 
     /// A URL for retrieving the resource's associated metadata (if any) from
     /// the Archive instance
-    pub(super) metadata_url: Option<Url>,
+    pub(super) metadata_url: Option<HttpUrl>,
 }
 
 impl DavCollection {
@@ -552,7 +552,7 @@ pub(super) struct DavItem {
 
     /// A URL for retrieving the resource's associated metadata (if any) from
     /// the Archive instance
-    pub(super) metadata_url: Option<Url>,
+    pub(super) metadata_url: Option<HttpUrl>,
 }
 
 impl DavItem {
@@ -727,11 +727,11 @@ pub(super) enum DavContent {
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub(super) enum Redirect {
     /// A single URL to always redirect to
-    Direct(Url),
+    Direct(HttpUrl),
 
     /// An S3 URL and an Archive instance URL, to be selected between based on
     /// whether `--prefer-s3-redirects` was supplied at program invocation
-    Alt { s3: Url, archive: Url },
+    Alt { s3: HttpUrl, archive: HttpUrl },
 }
 
 impl Redirect {
@@ -739,7 +739,7 @@ impl Redirect {
     ///
     /// If `prefer_s3` is `true`, `Alt` variants resolve to their `s3` field;
     /// otherwise, they resolve to their `archive` field.
-    pub(super) fn get_url(&self, prefer_s3: bool) -> &Url {
+    pub(super) fn get_url(&self, prefer_s3: bool) -> &HttpUrl {
         match self {
             Redirect::Direct(u) => u,
             Redirect::Alt { s3, archive } => {
diff --git a/src/dav/util.rs b/src/dav/util.rs
index 3bc11b6..73bf41d 100644
--- a/src/dav/util.rs
+++ b/src/dav/util.rs
@@ -1,6 +1,7 @@
 use super::VersionSpec;
 use crate::consts::DAV_XML_CONTENT_TYPE;
 use crate::dandi::DandisetId;
+use crate::httputil::HttpUrl;
 use crate::paths::PureDirPath;
 use axum::{
     async_trait,
@@ -129,14 +130,14 @@ impl AsRef<str> for Href {
     }
 }
 
-impl From<url::Url> for Href {
-    fn from(value: url::Url) -> Href {
-        Href(value.into())
+impl From<HttpUrl> for Href {
+    fn from(value: HttpUrl) -> Href {
+        Href(value.to_string())
     }
 }
 
-impl From<&url::Url> for Href {
-    fn from(value: &url::Url) -> Href {
+impl From<&HttpUrl> for Href {
+    fn from(value: &HttpUrl) -> Href {
         Href(value.to_string())
     }
 }
diff --git a/src/dav/xml/multistatus.rs b/src/dav/xml/multistatus.rs
index 1435102..d351ee3 100644
--- a/src/dav/xml/multistatus.rs
+++ b/src/dav/xml/multistatus.rs
@@ -153,6 +153,7 @@ pub(crate) enum ToXmlError {
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::httputil::HttpUrl;
     use indoc::indoc;
     use pretty_assertions::assert_eq;
@@ -222,7 +223,8 @@ mod tests {
                 status: "HTTP/1.1 307 TEMPORARY REDIRECT".into(),
             }],
             location: Some(
-                url::Url::parse("https://www.example.com/data/quux.dat")
+                "https://www.example.com/data/quux.dat"
+                    .parse::<HttpUrl>()
                     .unwrap()
                     .into(),
             ),
diff --git a/src/httputil.rs b/src/httputil.rs
index 137e7fe..65316cc 100644
--- a/src/httputil.rs
+++ b/src/httputil.rs
@@ -4,8 +4,13 @@ use crate::dav::ErrorClass;
 use reqwest::{Method, Request, Response, StatusCode};
 use reqwest_middleware::{Middleware, Next};
 use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
-use serde::de::DeserializeOwned;
+use serde::{
+    de::{DeserializeOwned, Deserializer, Error as _},
+    Deserialize,
+};
+use std::fmt;
 use std::future::Future;
+use std::str::FromStr;
 use thiserror::Error;
 use tracing::Instrument;
 use url::Url;
@@ -43,10 +48,14 @@
     ///
     /// If sending the request fails or the response has a 4xx or 5xx status,
     /// an error is returned.
-    pub(crate) async fn request(&self, method: Method, url: Url) -> Result<Response, HttpError> {
+    pub(crate) async fn request(
+        &self,
+        method: Method,
+        url: HttpUrl,
+    ) -> Result<Response, HttpError> {
         let r = self
             .0
-            .request(method, url.clone())
+            .request(method, Url::from(url.clone()))
             .send()
             .await
             .map_err(|source| HttpError::Send {
@@ -66,7 +75,7 @@
     ///
     /// If sending the request fails or the response has a 4xx or 5xx status,
     /// an error is returned.
-    pub(crate) async fn head(&self, url: Url) -> Result<Response, HttpError> {
+    pub(crate) async fn head(&self, url: HttpUrl) -> Result<Response, HttpError> {
         self.request(Method::HEAD, url).await
     }
 
@@ -76,7 +85,7 @@
     ///
     /// If sending the request fails or the response has a 4xx or 5xx status,
     /// an error is returned.
-    pub(crate) async fn get(&self, url: Url) -> Result<Response, HttpError> {
+    pub(crate) async fn get(&self, url: HttpUrl) -> Result<Response, HttpError> {
         self.request(Method::GET, url).await
     }
 
@@ -89,7 +98,7 @@
     /// deserialization of the response body fails, an error is returned.
     pub(crate) fn get_json<T: DeserializeOwned>(
         &self,
-        url: Url,
+        url: HttpUrl,
     ) -> impl Future<Output = Result<T, HttpError>> {
         // Clone the client and move it into an async block (as opposed to just
         // writing a "normal" async function) so that the resulting Future will
@@ -147,21 +156,27 @@ pub(crate) enum HttpError {
     /// Sending the request failed
     #[error("failed to make request to {url}")]
     Send {
-        url: Url,
+        url: HttpUrl,
         source: reqwest_middleware::Error,
     },
 
     /// The server returned a 404 response
     #[error("no such resource: {url}")]
-    NotFound { url: Url },
+    NotFound { url: HttpUrl },
 
     /// The server returned a 4xx or 5xx response other than 404
     #[error("request to {url} returned error")]
-    Status { url: Url, source: reqwest::Error },
+    Status {
+        url: HttpUrl,
+        source: reqwest::Error,
+    },
 
     /// Deserializing the response body as JSON failed
     #[error("failed to deserialize response body from {url}")]
-    Deserialize { url: Url, source: reqwest::Error },
+    Deserialize {
+        url: HttpUrl,
+        source: reqwest::Error,
+    },
 }
 
 impl HttpError {
@@ -174,128 +189,187 @@
     }
 }
 
-/// Create a URL by extending `url`'s path with the path segments `segments`.
-/// The resulting URL will not end with a slash (but see
-/// [`urljoin_slashed()`]).
-///
-/// If `url` does not end with a forward slash, one will be appended, and then
-/// the segments will be added after that.
-///
-/// # Panics
-///
-/// Panics if `url` cannot be a base URL.  (Note that HTTP(S) URLs can be base
-/// URLs.)
-pub(crate) fn urljoin<I>(url: &Url, segments: I) -> Url
-where
-    I: IntoIterator,
-    I::Item: AsRef<str>,
-{
-    let mut url = url.clone();
-    url.path_segments_mut()
-        .expect("URL should be able to be a base")
-        .pop_if_empty()
-        .extend(segments);
-    url
-}
+/// A wrapper around [`url::Url`] that enforces a scheme of "http" or "https"
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub(crate) struct HttpUrl(Url);
 
-/// Create a URL by extending `url`'s path with the path segments `segments`
-/// and then terminating the result with a forward slash.
-///
-/// If `url` does not end with a forward slash, one will be appended, and then
-/// the segments will be added after that.
-///
-/// # Panics
-///
-/// Panics if `url` cannot be a base URL.  (Note that HTTP(S) URLs can be base
-/// URLs.)
-pub(crate) fn urljoin_slashed<I>(url: &Url, segments: I) -> Url
-where
-    I: IntoIterator,
-    I::Item: AsRef<str>,
-{
-    let mut url = url.clone();
-    url.path_segments_mut()
-        .expect("URL should be able to be a base")
-        .pop_if_empty()
-        .extend(segments)
-        // Add an empty segment so that the final URL will end with a slash:
-        .push("");
-    url
-}
+impl HttpUrl {
+    /// Return the URL as a string
+    pub(crate) fn as_str(&self) -> &str {
+        self.0.as_str()
+    }
 
-#[cfg(test)]
-mod tests {
-    use super::*;
+    /// Return a reference to the underlying [`url::Url`]
+    pub(crate) fn as_url(&self) -> &Url {
+        &self.0
+    }
 
-    mod urljoin {
-        use super::*;
-        use rstest::rstest;
-
-        #[rstest]
-        #[case("https://api.github.com")]
-        #[case("https://api.github.com/")]
-        fn nopath(#[case] base: Url) {
-            let u = urljoin(&base, ["foo"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo");
-            let u = urljoin(&base, ["foo", "bar"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/bar");
-        }
+    /// Append the given path segment to this URL's path component.
+    ///
+    /// If the URL does not end with a forward slash, one will be appended, and
+    /// then the segment will be added after that.
+    pub(crate) fn push<S: AsRef<str>>(&mut self, segment: S) -> &mut Self {
+        {
+            let Ok(mut ps) = self.0.path_segments_mut() else {
+                unreachable!("HTTP(S) URLs should always be able to be a base");
+            };
+            ps.pop_if_empty().push(segment.as_ref());
+        }
+        self
+    }
 
-        #[rstest]
-        #[case("https://api.github.com/foo/bar")]
-        #[case("https://api.github.com/foo/bar/")]
-        fn path(#[case] base: Url) {
-            let u = urljoin(&base, ["gnusto"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/bar/gnusto");
-            let u = urljoin(&base, ["gnusto", "cleesh"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/bar/gnusto/cleesh");
-        }
+    /// Append the given path segments to this URL's path component.
+    ///
+    /// If the URL does not end with a forward slash, one will be appended, and
+    /// then the segments will be added after that.
+    pub(crate) fn extend<I>(&mut self, segments: I) -> &mut Self
+    where
+        I: IntoIterator,
+        I::Item: AsRef<str>,
+    {
+        {
+            let Ok(mut ps) = self.0.path_segments_mut() else {
+                unreachable!("HTTP(S) URLs should always be able to be a base");
+            };
+            ps.pop_if_empty().extend(segments);
+        }
+        self
+    }
 
-        #[rstest]
-        #[case("foo#bar", "https://api.github.com/base/foo%23bar")]
-        #[case("foo%bar", "https://api.github.com/base/foo%25bar")]
-        #[case("foo/bar", "https://api.github.com/base/foo%2Fbar")]
-        #[case("foo?bar", "https://api.github.com/base/foo%3Fbar")]
-        fn special_chars(#[case] path: &str, #[case] expected: &str) {
-            let base = Url::parse("https://api.github.com/base").unwrap();
-            let u = urljoin(&base, [path]);
-            assert_eq!(u.as_str(), expected);
-        }
-    }
+    /// Append a trailing forward slash to the URL if it does not already end
+    /// with one
+    pub(crate) fn ensure_dirpath(&mut self) -> &mut Self {
+        {
+            let Ok(mut ps) = self.0.path_segments_mut() else {
+                unreachable!("HTTP(S) URLs should always be able to be a base");
+            };
+            ps.pop_if_empty().push("");
+        }
+        self
+    }
 
-    mod urljoin_slashed {
-        use super::*;
-        use rstest::rstest;
-
-        #[rstest]
-        #[case("https://api.github.com")]
-        #[case("https://api.github.com/")]
-        fn nopath(#[case] base: Url) {
-            let u = urljoin_slashed(&base, ["foo"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/");
-            let u = urljoin_slashed(&base, ["foo", "bar"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/bar/");
-        }
+    /// Append `"{key}={value}"` (after percent-encoding) to the URL's query
+    /// parameters
+    pub(crate) fn append_query_param(&mut self, key: &str, value: &str) -> &mut Self {
+        self.0.query_pairs_mut().append_pair(key, value);
+        self
+    }
+}
+
+impl From<HttpUrl> for Url {
+    fn from(value: HttpUrl) -> Url {
+        value.0
+    }
+}
+
+impl fmt::Display for HttpUrl {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl FromStr for HttpUrl {
+    type Err = ParseHttpUrlError;
 
-        #[rstest]
-        #[case("https://api.github.com/foo/bar")]
-        #[case("https://api.github.com/foo/bar/")]
-        fn path(#[case] base: Url) {
-            let u = urljoin_slashed(&base, ["gnusto"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/bar/gnusto/");
-            let u = urljoin_slashed(&base, ["gnusto", "cleesh"]);
-            assert_eq!(u.as_str(), "https://api.github.com/foo/bar/gnusto/cleesh/");
-        }
+    fn from_str(s: &str) -> Result<HttpUrl, ParseHttpUrlError> {
+        let url = s.parse::<Url>()?;
+        if matches!(url.scheme(), "http" | "https") {
+            Ok(HttpUrl(url))
+        } else {
+            Err(ParseHttpUrlError::BadScheme)
+        }
+    }
+}
 
-        #[rstest]
-        #[case("foo#bar", "https://api.github.com/base/foo%23bar/")]
-        #[case("foo%bar", "https://api.github.com/base/foo%25bar/")]
-        #[case("foo/bar", "https://api.github.com/base/foo%2Fbar/")]
-        #[case("foo?bar", "https://api.github.com/base/foo%3Fbar/")]
-        fn special_chars(#[case] path: &str, #[case] expected: &str) {
-            let base = Url::parse("https://api.github.com/base").unwrap();
-            let u = urljoin_slashed(&base, [path]);
-            assert_eq!(u.as_str(), expected);
-        }
-    }
-}
+impl<'de> Deserialize<'de> for HttpUrl {
+    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<HttpUrl, D::Error> {
+        let url = Url::deserialize(deserializer)?;
+        if matches!(url.scheme(), "http" | "https") {
+            Ok(HttpUrl(url))
+        } else {
+            Err(D::Error::custom("expected URL with HTTP(S) scheme"))
+        }
+    }
+}
+
+/// Error returned by [`HttpUrl`]'s `FromStr` implementation
+#[derive(Clone, Copy, Debug, Eq, Error, PartialEq)]
+pub(crate) enum ParseHttpUrlError {
+    /// The string was a valid URL, but the scheme was neither HTTP nor HTTPS
+    #[error(r#"URL scheme must be "http" or "https""#)]
+    BadScheme,
+
+    /// The string was not a valid URL
+    #[error(transparent)]
+    Url(#[from] url::ParseError),
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rstest::rstest;
+
+    #[rstest]
+    #[case("foo#bar", "https://api.github.com/base/foo%23bar")]
+    #[case("foo%bar", "https://api.github.com/base/foo%25bar")]
+    #[case("foo/bar", "https://api.github.com/base/foo%2Fbar")]
+    #[case("foo?bar", "https://api.github.com/base/foo%3Fbar")]
+    fn push_special_chars(#[case] path: &str, #[case] expected: &str) {
+        let mut base = "https://api.github.com/base".parse::<HttpUrl>().unwrap();
+        base.push(path);
+        assert_eq!(base.as_str(), expected);
+    }
+
+    #[rstest]
+    #[case(&["foo"], "https://api.github.com/foo")]
+    #[case(&["foo", "bar"], "https://api.github.com/foo/bar")]
+    fn extend_nopath(
+        #[values("https://api.github.com", "https://api.github.com/")] mut base: HttpUrl,
+        #[case] segments: &[&str],
+        #[case] expected: &str,
+    ) {
+        base.extend(segments);
+        assert_eq!(base.as_str(), expected);
+    }
+
+    #[rstest]
+    #[case(&["gnusto"], "https://api.github.com/foo/bar/gnusto")]
+    #[case(&["gnusto", "cleesh"], "https://api.github.com/foo/bar/gnusto/cleesh")]
+    fn extend_path(
+        #[values("https://api.github.com/foo/bar", "https://api.github.com/foo/bar/")]
+        mut base: HttpUrl,
+        #[case] segments: &[&str],
+        #[case] expected: &str,
+    ) {
+        base.extend(segments);
+        assert_eq!(base.as_str(), expected);
+    }
+
+    #[rstest]
+    #[case("https://api.github.com", "https://api.github.com/")]
+    #[case("https://api.github.com/", "https://api.github.com/")]
+    #[case("https://api.github.com/foo", "https://api.github.com/foo/")]
+    #[case("https://api.github.com/foo/", "https://api.github.com/foo/")]
+    fn ensure_dirpath(#[case] mut before: HttpUrl, #[case] after: &str) {
+        before.ensure_dirpath();
+        assert_eq!(before.as_str(), after);
+    }
+
+    #[test]
+    fn append_query_param() {
+        let mut url = "https://api.github.com/foo".parse::<HttpUrl>().unwrap();
+        assert_eq!(url.as_str(), "https://api.github.com/foo");
+        url.append_query_param("bar", "baz");
+        assert_eq!(url.as_str(), "https://api.github.com/foo?bar=baz");
+        url.append_query_param("quux", "with space");
+        assert_eq!(
+            url.as_str(),
+            "https://api.github.com/foo?bar=baz&quux=with+space"
+        );
+        url.append_query_param("bar", "rod");
+        assert_eq!(
+            url.as_str(),
+            "https://api.github.com/foo?bar=baz&quux=with+space&bar=rod"
+        );
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index f20bf36..19b5e68 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,6 +12,7 @@ mod zarrman;
 
 use crate::consts::{CSS_CONTENT_TYPE, DEFAULT_API_URL, SERVER_VALUE};
 use crate::dandi::DandiClient;
 use crate::dav::{DandiDav, Templater};
+use crate::httputil::HttpUrl;
 use crate::zarrman::ZarrManClient;
 use anyhow::Context;
 use axum::{
@@ -47,7 +48,7 @@ static STYLESHEET: &str = include_str!("dav/static/styles.css");
 struct Arguments {
     /// API URL of the DANDI Archive instance to serve
     #[arg(long, default_value = DEFAULT_API_URL, value_name = "URL")]
-    api_url: url::Url,
+    api_url: HttpUrl,
 
     /// IP address to listen on
     #[arg(long, default_value = "127.0.0.1")]
diff --git a/src/s3/mod.rs b/src/s3/mod.rs
index a757056..554469b 100644
--- a/src/s3/mod.rs
+++ b/src/s3/mod.rs
@@ -2,7 +2,7 @@ mod streams;
 
 use self::streams::ListEntryPages;
 use crate::dav::ErrorClass;
-use crate::httputil::{self, BuildClientError, HttpError};
+use crate::httputil::{self, BuildClientError, HttpError, HttpUrl, ParseHttpUrlError};
 use crate::paths::{ParsePureDirPathError, ParsePurePathError, PureDirPath, PurePath};
 use crate::streamutil::TryStreamUtil;
 use crate::validstr::TryFromStringError;
@@ -320,7 +320,7 @@ pub(crate) struct S3Object {
     pub(crate) modified: OffsetDateTime,
     pub(crate) size: i64,
     pub(crate) etag: String,
-    pub(crate) download_url: Url,
+    pub(crate) download_url: HttpUrl,
 }
 
 impl S3Object {
@@ -341,14 +341,12 @@
             return Err(TryFromAwsObjectError::NoSize { key });
         };
         let keypath = PurePath::try_from(key.clone()).map_err(TryFromAwsObjectError::BadKey)?;
-        let mut download_url = Url::parse(&format!("https://{bucket}.s3.amazonaws.com"))
+        let mut download_url = format!("https://{bucket}.s3.amazonaws.com")
+            .parse::<HttpUrl>()
             .expect("bucket should be a valid hostname component");
         // Adding the key this way is necessary in order for URL-unsafe
         // characters to be percent-encoded:
-        download_url
-            .path_segments_mut()
-            .expect("HTTPS URL should be able to be a base")
-            .extend(key.split('/'));
+        download_url.extend(key.split('/'));
         let modified = modified
             .to_time()
             .map_err(|source| TryFromAwsObjectError::BadModified {
@@ -443,7 +441,7 @@ pub(crate) enum TryFromAwsObjectError {
 pub(crate) async fn get_bucket_region(bucket: &str) -> Result<String, GetBucketRegionError> {
     let url_str = format!("https://{bucket}.s3.amazonaws.com");
     let url = url_str
-        .parse::<Url>()
+        .parse::<HttpUrl>()
         .map_err(|source| GetBucketRegionError::BadUrl {
             url: url_str,
             source,
@@ -466,7 +464,7 @@ pub(crate) enum GetBucketRegionError {
     #[error("URL constructed for bucket is invalid: {url:?}")]
     BadUrl {
         url: String,
-        source: url::ParseError,
+        source: ParseHttpUrlError,
     },
     #[error("S3 response lacked x-amz-bucket-region header")]
     NoHeader,
diff --git a/src/zarrman/mod.rs b/src/zarrman/mod.rs
index 674aae0..769cff8 100644
--- a/src/zarrman/mod.rs
+++ b/src/zarrman/mod.rs
@@ -18,13 +18,12 @@ mod resources;
 use self::path::ReqPath;
 pub(crate) use self::resources::*;
 use crate::dav::ErrorClass;
-use crate::httputil::{urljoin, urljoin_slashed, BuildClientError, Client, HttpError};
+use crate::httputil::{BuildClientError, Client, HttpError, HttpUrl};
 use crate::paths::{Component, PureDirPath, PurePath};
 use moka::future::{Cache, CacheBuilder};
 use serde::Deserialize;
 use std::sync::Arc;
 use thiserror::Error;
-use url::Url;
 
 /// The manifest root URL.
 ///
@@ -57,11 +56,11 @@ pub(crate) struct ZarrManClient {
     /// `MANIFEST_ROOT_URL`
     manifests: Cache<ManifestPath, Arc<Manifest>>,
 
-    /// [`MANIFEST_ROOT_URL`], parsed into a [`url::Url`]
-    manifest_root_url: Url,
+    /// [`MANIFEST_ROOT_URL`], parsed into an [`HttpUrl`]
+    manifest_root_url: HttpUrl,
 
-    /// [`ENTRY_DOWNLOAD_PREFIX`], parsed into a [`url::Url`]
-    entry_download_prefix: Url,
+    /// [`ENTRY_DOWNLOAD_PREFIX`], parsed into an [`HttpUrl`]
+    entry_download_prefix: HttpUrl,
 
     /// The directory path `"zarrs/"`, used at various points in the code,
     /// pre-parsed for convenience
@@ -79,13 +78,15 @@ impl ZarrManClient {
         let manifests = CacheBuilder::new(MANIFEST_CACHE_SIZE)
             .name("zarr-manifests")
             .build();
-        let manifest_root_url =
-            Url::parse(MANIFEST_ROOT_URL).expect("MANIFEST_ROOT_URL should be a valid URL");
-        let entry_download_prefix =
-            Url::parse(ENTRY_DOWNLOAD_PREFIX).expect("ENTRY_DOWNLOAD_PREFIX should be a valid URL");
+        let manifest_root_url = MANIFEST_ROOT_URL
+            .parse::<HttpUrl>()
+            .expect("MANIFEST_ROOT_URL should be a valid HTTP URL");
+        let entry_download_prefix = ENTRY_DOWNLOAD_PREFIX
+            .parse::<HttpUrl>()
+            .expect("ENTRY_DOWNLOAD_PREFIX should be a valid HTTP URL");
         let web_path_prefix = "zarrs/"
             .parse::<PureDirPath>()
-            .expect(r#""zarrs/" should be a valid URL"#);
+            .expect(r#""zarrs/" should be a valid directory path"#);
         Ok(ZarrManClient {
             inner,
             manifests,
@@ -218,10 +219,10 @@ impl ZarrManClient {
         &self,
         path: Option<&PureDirPath>,
     ) -> Result<Vec<ZarrManResource>, ZarrManError> {
-        let url = match path {
-            Some(p) => urljoin_slashed(&self.manifest_root_url, p.component_strs()),
-            None => self.manifest_root_url.clone(),
-        };
+        let mut url = self.manifest_root_url.clone();
+        if let Some(p) = path {
+            url.extend(p.component_strs()).ensure_dirpath();
+        }
         let index = self.inner.get_json::<Index>(url).await?;
         let mut entries =
             Vec::with_capacity(index.files.len().saturating_add(index.directories.len()));
@@ -290,12 +291,10 @@ impl ZarrManClient {
         entry: &manifest::ManifestEntry,
     ) -> ManifestEntry {
         let web_path = manifest_path.to_web_path().join(entry_path);
-        let mut url = urljoin(
-            &self.entry_download_prefix,
-            std::iter::once(manifest_path.zarr_id()).chain(entry_path.component_strs()),
-        );
-        url.query_pairs_mut()
-            .append_pair("versionId", &entry.version_id);
+        let mut url = self.entry_download_prefix.clone();
+        url.push(manifest_path.zarr_id());
+        url.extend(entry_path.component_strs());
+        url.append_query_param("versionId", &entry.version_id);
         ManifestEntry {
             web_path,
             size: entry.size,
diff --git a/src/zarrman/resources.rs b/src/zarrman/resources.rs
index 5a18f2a..b793305 100644
--- a/src/zarrman/resources.rs
+++ b/src/zarrman/resources.rs
@@ -1,9 +1,7 @@
-use crate::httputil::urljoin;
+use crate::httputil::HttpUrl;
 use crate::paths::{Component, PureDirPath, PurePath};
-use std::borrow::Cow;
 use std::fmt;
 use time::OffsetDateTime;
-use url::Url;
 
 /// A resource served under `dandidav`'s `/zarrs/` hierarchy, not including
 /// information on child resources
@@ -66,18 +64,12 @@ impl ManifestPath {
     }
 
     /// Returns the URL of the Zarr manifest underneath the given manifest root
-    pub(crate) fn under_manifest_root(&self, manifest_root_url: &Url) -> Url {
-        urljoin(
-            manifest_root_url,
-            self.prefix
-                .component_strs()
-                .map(Cow::from)
-                .chain(std::iter::once(Cow::from(&*self.zarr_id)))
-                .chain(std::iter::once(Cow::from(format!(
-                    "{}.json",
-                    self.checksum
-                )))),
-        )
+    pub(crate) fn under_manifest_root(&self, manifest_root_url: &HttpUrl) -> HttpUrl {
+        let mut url = manifest_root_url.clone();
+        url.extend(self.prefix.component_strs());
+        url.push(&self.zarr_id);
+        url.push(format!("{}.json", self.checksum));
+        url
     }
 }
 
@@ -135,7 +127,7 @@ pub(crate) struct ManifestEntry {
     pub(crate) etag: String,
 
     /// The download URL for the entry
-    pub(crate) url: Url,
+    pub(crate) url: HttpUrl,
 }
 
 #[cfg(test)]
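// ---------------------------------------------------------------------------
// Reviewer's note (not part of the patch): a minimal sketch of how the new
// HttpUrl builder API is meant to compose, using only the methods added in
// src/httputil.rs above.  The URLs and the function name `demo` are made-up
// examples, and the code assumes it lives inside the crate, since HttpUrl and
// ParseHttpUrlError are pub(crate).

use crate::httputil::{HttpUrl, ParseHttpUrlError};

fn demo() -> Result<(), ParseHttpUrlError> {
    // FromStr enforces the scheme: anything other than http/https is rejected.
    assert!("ftp://example.com/".parse::<HttpUrl>().is_err());

    // The builder methods return &mut Self and so can be chained; this mirrors
    // the rewritten DandiClient::get_url() in src/dandi/mod.rs.
    let mut url = "https://api.example.com".parse::<HttpUrl>()?;
    url.extend(["dandisets", "000001"]).ensure_dirpath();
    url.append_query_param("order", "path");
    assert_eq!(
        url.as_str(),
        "https://api.example.com/dandisets/000001/?order=path"
    );
    Ok(())
}
// ---------------------------------------------------------------------------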