Skip to content

Commit

Permalink
feat: add {http1,http2}_only for auto conn (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
dswij authored May 24, 2024
1 parent 4b24573 commit 1635bcc
Showing 1 changed file with 116 additions and 12 deletions.
128 changes: 116 additions & 12 deletions src/server/conn/auto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct Builder<E> {
http1: http1::Builder,
#[cfg(feature = "http2")]
http2: http2::Builder<E>,
#[cfg(any(feature = "http1", feature = "http2"))]
version: Option<Version>,
#[cfg(not(feature = "http2"))]
_executor: E,
}
Expand All @@ -84,6 +86,8 @@ impl<E> Builder<E> {
http1: http1::Builder::new(),
#[cfg(feature = "http2")]
http2: http2::Builder::new(executor),
#[cfg(any(feature = "http1", feature = "http2"))]
version: None,
#[cfg(not(feature = "http2"))]
_executor: executor,
}
Expand All @@ -101,6 +105,26 @@ impl<E> Builder<E> {
Http2Builder { inner: self }
}

/// Only accepts HTTP/2
///
/// Does not do anything if used with [`serve_connection_with_upgrades`]
#[cfg(feature = "http2")]
pub fn http2_only(mut self) -> Self {
assert!(self.version.is_none());
self.version = Some(Version::H2);
self
}

/// Only accepts HTTP/1
///
/// Does not do anything if used with [`serve_connection_with_upgrades`]
#[cfg(feature = "http1")]
pub fn http1_only(mut self) -> Self {
assert!(self.version.is_none());
self.version = Some(Version::H1);
self
}

/// Bind a connection together with a [`Service`].
pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
where
Expand All @@ -112,13 +136,28 @@ impl<E> Builder<E> {
I: Read + Write + Unpin + 'static,
E: HttpServerConnExec<S::Future, B>,
{
Connection {
state: ConnState::ReadVersion {
let state = match self.version {
#[cfg(feature = "http1")]
Some(Version::H1) => {
let io = Rewind::new_buffered(io, Bytes::new());
let conn = self.http1.serve_connection(io, service);
ConnState::H1 { conn }
}
#[cfg(feature = "http2")]
Some(Version::H2) => {
let io = Rewind::new_buffered(io, Bytes::new());
let conn = self.http2.serve_connection(io, service);
ConnState::H2 { conn }
}
#[cfg(any(feature = "http1", feature = "http2"))]
_ => ConnState::ReadVersion {
read_version: read_version(io),
builder: self,
service: Some(service),
},
}
};

Connection { state }
}

/// Bind a connection together with a [`Service`], with the ability to
Expand Down Expand Up @@ -148,7 +187,7 @@ impl<E> Builder<E> {
}
}

#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
enum Version {
H1,
H2,
Expand Down Expand Up @@ -906,7 +945,7 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http1() {
let addr = start_server().await;
let addr = start_server(false, false).await;
let mut sender = connect_h1(addr).await;

let response = sender
Expand All @@ -922,7 +961,23 @@ mod tests {
#[cfg(not(miri))]
#[tokio::test]
async fn http2() {
let addr = start_server().await;
let addr = start_server(false, false).await;
let mut sender = connect_h2(addr).await;

let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();

let body = response.into_body().collect().await.unwrap().to_bytes();

assert_eq!(body, BODY);
}

#[cfg(not(miri))]
#[tokio::test]
async fn http2_only() {
let addr = start_server(false, true).await;
let mut sender = connect_h2(addr).await;

let response = sender
Expand All @@ -935,6 +990,46 @@ mod tests {
assert_eq!(body, BODY);
}

#[cfg(not(miri))]
#[tokio::test]
async fn http2_only_fail_if_client_is_http1() {
let addr = start_server(false, true).await;
let mut sender = connect_h1(addr).await;

let _ = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.expect_err("should fail");
}

#[cfg(not(miri))]
#[tokio::test]
async fn http1_only() {
let addr = start_server(true, false).await;
let mut sender = connect_h1(addr).await;

let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();

let body = response.into_body().collect().await.unwrap().to_bytes();

assert_eq!(body, BODY);
}

#[cfg(not(miri))]
#[tokio::test]
async fn http1_only_fail_if_client_is_http2() {
let addr = start_server(true, false).await;
let mut sender = connect_h2(addr).await;

let _ = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.expect_err("should fail");
}

#[cfg(not(miri))]
#[tokio::test]
async fn graceful_shutdown() {
Expand Down Expand Up @@ -1000,7 +1095,7 @@ mod tests {
sender
}

async fn start_server() -> SocketAddr {
async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
let listener = TcpListener::bind(addr).await.unwrap();

Expand All @@ -1011,11 +1106,20 @@ mod tests {
let (stream, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(stream);
tokio::task::spawn(async move {
let _ = auto::Builder::new(TokioExecutor::new())
.http2()
.max_header_list_size(4096)
.serve_connection_with_upgrades(stream, service_fn(hello))
.await;
let mut builder = auto::Builder::new(TokioExecutor::new());
if h1_only {
builder = builder.http1_only();
builder.serve_connection(stream, service_fn(hello)).await;
} else if h2_only {
builder = builder.http2_only();
builder.serve_connection(stream, service_fn(hello)).await;
} else {
builder
.http2()
.max_header_list_size(4096)
.serve_connection_with_upgrades(stream, service_fn(hello))
.await;
}
});
}
});
Expand Down

0 comments on commit 1635bcc

Please sign in to comment.