diff --git a/Cargo.toml b/Cargo.toml index 475f731..409d662 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,11 @@ repository = "https://github.com/hyperium/hyper-util" license = "MIT" authors = ["Sean McArthur "] keywords = ["http", "hyper", "hyperium"] -categories = ["network-programming", "web-programming::http-client", "web-programming::http-server"] +categories = [ + "network-programming", + "web-programming::http-client", + "web-programming::http-server", +] edition = "2018" publish = false # no accidents while in dev diff --git a/src/client/mod.rs b/src/client/mod.rs index 37d33fc..0237511 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -6,3 +6,4 @@ pub mod connect; pub mod legacy; #[doc(hidden)] pub mod pool; +pub mod services; diff --git a/src/client/services/http1_request_target.rs b/src/client/services/http1_request_target.rs new file mode 100644 index 0000000..5f6c11d --- /dev/null +++ b/src/client/services/http1_request_target.rs @@ -0,0 +1,84 @@ +use http::{uri::Scheme, Method, Request, Uri}; +use hyper::service::Service; +use tracing::warn; + +pub struct Http1RequestTarget { + inner: S, + is_proxied: bool, +} + +impl Http1RequestTarget { + pub fn new(inner: S, is_proxied: bool) -> Self { + Self { inner, is_proxied } + } +} + +impl Service> for Http1RequestTarget +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, mut req: Request) -> Self::Future { + // CONNECT always sends authority-form, so check it first... + if req.method() == Method::CONNECT { + authority_form(req.uri_mut()); + } else if self.is_proxied { + absolute_form(req.uri_mut()); + } else { + origin_form(req.uri_mut()); + } + self.inner.call(req) + } +} + +fn origin_form(uri: &mut Uri) { + let path = match uri.path_and_query() { + Some(path) if path.as_str() != "/" => { + let mut parts = ::http::uri::Parts::default(); + parts.path_and_query = Some(path.clone()); + Uri::from_parts(parts).expect("path is valid uri") + } + _none_or_just_slash => { + debug_assert!(Uri::default() == "/"); + Uri::default() + } + }; + *uri = path +} + +fn absolute_form(uri: &mut Uri) { + debug_assert!(uri.scheme().is_some(), "absolute_form needs a scheme"); + debug_assert!( + uri.authority().is_some(), + "absolute_form needs an authority" + ); + // If the URI is to HTTPS, and the connector claimed to be a proxy, + // then it *should* have tunneled, and so we don't want to send + // absolute-form in that case. + if uri.scheme() == Some(&Scheme::HTTPS) { + origin_form(uri); + } +} + +fn authority_form(uri: &mut Uri) { + if let Some(path) = uri.path_and_query() { + // `https://hyper.rs` would parse with `/` path, don't + // annoy people about that... + if path != "/" { + warn!("HTTP/1.1 CONNECT request stripping path: {:?}", path); + } + } + *uri = match uri.authority() { + Some(auth) => { + let mut parts = ::http::uri::Parts::default(); + parts.authority = Some(auth.clone()); + Uri::from_parts(parts).expect("authority is valid") + } + None => { + unreachable!("authority_form with relative uri"); + } + }; +} diff --git a/src/client/services/mod.rs b/src/client/services/mod.rs new file mode 100644 index 0000000..4a0d335 --- /dev/null +++ b/src/client/services/mod.rs @@ -0,0 +1,5 @@ +mod http1_request_target; +mod set_host; + +pub use http1_request_target::Http1RequestTarget; +pub use set_host::SetHost; diff --git a/src/client/services/set_host.rs b/src/client/services/set_host.rs new file mode 100644 index 0000000..7ce67a6 --- /dev/null +++ b/src/client/services/set_host.rs @@ -0,0 +1,50 @@ +use http::{header::HOST, uri::Port, HeaderValue, Request, Uri}; +use hyper::service::Service; + +pub struct SetHost { + inner: S, +} + +impl SetHost { + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service> for SetHost +where + S: Service>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, mut req: Request) -> Self::Future { + let uri = req.uri().clone(); + req.headers_mut().entry(HOST).or_insert_with(|| { + let hostname = uri.host().expect("authority implies host"); + if let Some(port) = get_non_default_port(&uri) { + let s = format!("{}:{}", hostname, port); + HeaderValue::from_str(&s) + } else { + HeaderValue::from_str(hostname) + } + .expect("uri host is valid header value") + }); + self.inner.call(req) + } +} + +fn get_non_default_port(uri: &Uri) -> Option> { + match (uri.port().map(|p| p.as_u16()), is_schema_secure(uri)) { + (Some(443), true) => None, + (Some(80), false) => None, + _ => uri.port(), + } +} + +fn is_schema_secure(uri: &Uri) -> bool { + uri.scheme_str() + .map(|scheme_str| matches!(scheme_str, "wss" | "https")) + .unwrap_or_default() +}