diff --git a/Cargo.lock b/Cargo.lock index 36c12fa1b55..b8dc31c0af3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12958,8 +12958,12 @@ dependencies = [ name = "ic-xnet-endpoint" version = "0.9.0" dependencies = [ + "axum 0.7.5", "bytes", - "hyper 0.14.29", + "crossbeam-channel", + "hyper 1.4.1", + "hyper-util", + "ic-async-utils", "ic-crypto-tls-interfaces", "ic-crypto-tls-interfaces-mocks", "ic-interfaces-certified-stream-store", @@ -12977,7 +12981,6 @@ dependencies = [ "ic-test-utilities-metrics", "ic-test-utilities-types", "ic-types", - "ic-xnet-hyper", "maplit", "prometheus", "prost", @@ -12986,6 +12989,8 @@ dependencies = [ "serde_json", "slog", "tokio", + "tokio-rustls 0.26.0", + "tower", "url", ] diff --git a/rs/async_utils/src/lib.rs b/rs/async_utils/src/lib.rs index 963c771ae85..af016630c64 100644 --- a/rs/async_utils/src/lib.rs +++ b/rs/async_utils/src/lib.rs @@ -48,7 +48,11 @@ pub async fn shutdown_signal(log: Logger) { /// Recommended way of starting a TCP listener given a socket addr. The function /// will panic if it cannot start the listener, because the OS error can't be /// handled by the caller. -pub fn start_tcp_listener(local_addr: std::net::SocketAddr) -> tokio::net::TcpListener { +pub fn start_tcp_listener( + local_addr: std::net::SocketAddr, + runtime_handle: &tokio::runtime::Handle, +) -> tokio::net::TcpListener { + let _enter = runtime_handle.enter(); let err_msg = format!("Could not start TCP listener at addr = {}", local_addr); let socket = if local_addr.is_ipv6() { tokio::net::TcpSocket::new_v6().expect(&err_msg) diff --git a/rs/http_endpoints/metrics/src/lib.rs b/rs/http_endpoints/metrics/src/lib.rs index aa04787fc77..82580898467 100644 --- a/rs/http_endpoints/metrics/src/lib.rs +++ b/rs/http_endpoints/metrics/src/lib.rs @@ -149,11 +149,11 @@ impl MetricsHttpEndpoint { fn start_http(&self, address: SocketAddr) { // we need to enter the tokio context in order to create the timeout layer and the tcp // socket - let _enter = self.rt_handle.enter(); let mut addr = "[::]:9090".parse::().unwrap(); addr.set_port(address.port()); - let tcp_listener = start_tcp_listener(addr); + let tcp_listener = start_tcp_listener(addr, &self.rt_handle); + let _enter: tokio::runtime::EnterGuard = self.rt_handle.enter(); let metrics_service = get(metrics_endpoint) .layer( ServiceBuilder::new() diff --git a/rs/http_endpoints/public/src/lib.rs b/rs/http_endpoints/public/src/lib.rs index a2c453db44d..2922c07b1a5 100644 --- a/rs/http_endpoints/public/src/lib.rs +++ b/rs/http_endpoints/public/src/lib.rs @@ -304,14 +304,13 @@ pub fn start_server( let listen_addr = config.listen_addr; info!(log, "Starting HTTP server..."); - let _enter = rt_handle.enter(); // TODO(OR4-60): temporarily listen on [::] so that we accept both IPv4 and // IPv6 connections. This requires net.ipv6.bindv6only = 0. Revert this once // we have rolled out IPv6 in prometheus and ic_p8s_service_discovery. let mut addr = "[::]:8080".parse::().unwrap(); addr.set_port(listen_addr.port()); - let tcp_listener = start_tcp_listener(addr); - + let tcp_listener = start_tcp_listener(addr, &rt_handle); + let _enter = rt_handle.enter(); if !AtomicCell::::is_lock_free() { error!(log, "Replica health status uses locks instead of atomics."); } diff --git a/rs/replica/src/setup_ic_stack.rs b/rs/replica/src/setup_ic_stack.rs index 435c6044bf6..a7d0a92c6ee 100755 --- a/rs/replica/src/setup_ic_stack.rs +++ b/rs/replica/src/setup_ic_stack.rs @@ -254,7 +254,7 @@ pub fn construct_ic_stack( let message_router = Arc::new(message_router); let xnet_config = XNetEndpointConfig::from(Arc::clone(®istry) as Arc<_>, node_id, log); let xnet_endpoint = XNetEndpoint::new( - rt_handle_xnet.clone(), + rt_handle_http.clone(), Arc::clone(&certified_stream_store), Arc::clone(&crypto) as Arc<_>, registry.clone(), diff --git a/rs/xnet/endpoint/BUILD.bazel b/rs/xnet/endpoint/BUILD.bazel index 7312962db0e..fee7c0054f5 100644 --- a/rs/xnet/endpoint/BUILD.bazel +++ b/rs/xnet/endpoint/BUILD.bazel @@ -4,6 +4,7 @@ package(default_visibility = ["//visibility:public"]) DEPENDENCIES = [ # Keep sorted. + "//rs/async_utils", "//rs/crypto/tls_interfaces", "//rs/interfaces/certified_stream_store", "//rs/interfaces/registry", @@ -12,13 +13,17 @@ DEPENDENCIES = [ "//rs/protobuf", "//rs/registry/helpers", "//rs/types/types", - "//rs/xnet/hyper", - "@crate_index//:hyper_0_14_27", + "@crate_index//:axum", + "@crate_index//:crossbeam-channel", + "@crate_index//:hyper", + "@crate_index//:hyper-util", "@crate_index//:prometheus", "@crate_index//:serde", "@crate_index//:serde_json", "@crate_index//:slog", "@crate_index//:tokio", + "@crate_index//:tokio-rustls", + "@crate_index//:tower", "@crate_index//:url", ] diff --git a/rs/xnet/endpoint/Cargo.toml b/rs/xnet/endpoint/Cargo.toml index 3fae649c1fb..81c8949664d 100644 --- a/rs/xnet/endpoint/Cargo.toml +++ b/rs/xnet/endpoint/Cargo.toml @@ -7,7 +7,11 @@ description.workspace = true documentation.workspace = true [dependencies] -hyper = { version = "0.14.18", features = ["full", "tcp"] } +axum = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true } +crossbeam-channel = { workspace = true } +ic-async-utils = { path = "../../async_utils" } ic-crypto-tls-interfaces = { path = "../../crypto/tls_interfaces" } ic-interfaces-certified-stream-store = { path = "../../interfaces/certified_stream_store" } ic-interfaces-registry = { path = "../../interfaces/registry" } @@ -16,12 +20,13 @@ ic-metrics = { path = "../../monitoring/metrics" } ic-protobuf = { path = "../../protobuf" } ic-registry-client-helpers = { path = "../../registry/helpers" } ic-types = { path = "../../types/types" } -ic-xnet-hyper = { path = "../hyper" } prometheus = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } slog = { workspace = true } tokio = { workspace = true } +tokio-rustls = { workspace = true } +tower = { workspace = true } url = { workspace = true } [dev-dependencies] diff --git a/rs/xnet/endpoint/src/lib.rs b/rs/xnet/endpoint/src/lib.rs index 3ec0984d65b..6991ff6307d 100644 --- a/rs/xnet/endpoint/src/lib.rs +++ b/rs/xnet/endpoint/src/lib.rs @@ -3,11 +3,14 @@ mod config_tests; #[cfg(test)] mod tests; -use hyper::{Body, Request, Response, StatusCode}; +use axum::{body::Body, extract::State, response::IntoResponse, routing::any}; +use hyper::{body::Incoming, Request, Response, StatusCode}; +use hyper_util::{rt::TokioIo, server::graceful::GracefulShutdown}; +use ic_async_utils::start_tcp_listener; use ic_crypto_tls_interfaces::TlsConfig; use ic_interfaces_certified_stream_store::{CertifiedStreamStore, EncodeStreamError}; use ic_interfaces_registry::RegistryClient; -use ic_logger::{debug, info, warn, ReplicaLogger}; +use ic_logger::{info, warn, ReplicaLogger}; use ic_metrics::{buckets::decimal_buckets, MetricsRegistry}; use ic_protobuf::messaging::xnet::v1 as pb; use ic_protobuf::proxy::ProtoProxy; @@ -21,9 +24,10 @@ use std::str::FromStr; use std::sync::Arc; use std::time::Instant; use tokio::{ - runtime, + runtime, select, sync::{Notify, Semaphore}, }; +use tower::Service; use url::Url; pub struct XNetEndpointMetrics { @@ -106,6 +110,171 @@ impl Drop for XNetEndpoint { const API_URL_STREAMS: &str = "/api/v1/streams"; const API_URL_STREAM_PREFIX: &str = "/api/v1/stream/"; +/// Struct passed to each request handled by `enqueue_task`. +#[derive(Clone)] +struct Context { + log: ReplicaLogger, + semaphore: Arc, + metrics: Arc, + certified_stream_store: Arc, + base_url: Url, +} + +fn ok(t: T) -> Result { + Ok(t) +} + +/// Handles an incoming HTTP request by taking a permit from the semaphore, parsing the URL, +/// handing over to `route_request()` and replying with the produced response. +async fn handle_xnet_request( + State(ctx): State, + request: Request, +) -> impl IntoResponse { + let owned_permit = match ctx.semaphore.try_acquire_owned() { + Ok(permit) => permit, + Err(_) => { + ctx.metrics + .request_duration + .with_label_values(&[RESOURCE_UNKNOWN, StatusCode::SERVICE_UNAVAILABLE.as_str()]) + .observe(0.0); + + return ok(Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .body(Body::from("Queue full")) + .unwrap()); + } + }; + + ok(tokio::task::spawn_blocking(move || { + let _permit = owned_permit; + + match ctx.base_url.join( + request + .uri() + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or(""), + ) { + Ok(url) => route_request(url, ctx.certified_stream_store.as_ref(), &ctx.metrics), + Err(e) => { + let msg = format!("Invalid URL {}: {}", request.uri(), e); + warn!(ctx.log, "{}", msg); + bad_request(msg) + } + } + }) + .await + .expect("Processing http request panicked!")) +} + +fn start_server( + address: SocketAddr, + metrics: Arc, + certified_stream_store: Arc, + runtime_handle: runtime::Handle, + tls: Arc, + registry_client: Arc, + log: ReplicaLogger, + shutdown_notify: Arc, +) -> SocketAddr { + let listener = start_tcp_listener(address, &runtime_handle); + let address = listener.local_addr().expect("Failed to get local addr."); + + let _guard: runtime::EnterGuard<'_> = runtime_handle.enter(); + let ctx = Context { + log: log.clone(), + metrics: Arc::clone(&metrics), + semaphore: Arc::new(Semaphore::new(XNET_ENDPOINT_MAX_CONCURRENT_REQUESTS)), + certified_stream_store, + base_url: Url::parse(&format!("http://{}/", address)).unwrap(), + }; + + // Create a router that handles all requests by calling `enqueue_task` + // and attaches the `Context` as state. + let router = any(handle_xnet_request).with_state(ctx); + + let hyper_service = + hyper::service::service_fn(move |request: Request| router.clone().call(request)); + + let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()); + let graceful_shutdown = GracefulShutdown::new(); + + tokio::spawn(async move { + loop { + select! { + Ok((stream, _peer_addr)) = listener.accept() => { + let log = log.clone(); + let hyper_service = hyper_service.clone(); + + #[cfg(test)] + { + // TLS is not used in tests. + let _ = tls; + let _ = registry_client; + + let io = TokioIo::new(stream); + let conn = server.serve_connection_with_upgrades(io, hyper_service); + let conn = graceful_shutdown.watch(conn.into_owned()); + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!(log, "failed to serve connection: {err}"); + } + }); + } + + #[cfg(not(test))] + { + // Creates a new TLS server config and uses it to accept the request. + let registry_version = registry_client.get_latest_version(); + let mut server_config = match tls.server_config( + ic_crypto_tls_interfaces::SomeOrAllNodes::All, + registry_version, + ) { + Ok(config) => config, + Err(err) => { + warn!(log, "Failed to get server config from crypto {err}"); + return; + } + }; + /// [TLS Application-Layer Protocol Negotiation (ALPN) Protocol `HTTP/2 over TLS` ID][spec] + /// [spec]: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids) + const ALPN_HTTP2: &[u8; 2] = b"h2"; + + /// [TLS Application-Layer Protocol Negotiation (ALPN) Protocol `HTTP/1.1` ID][spec] + /// [spec]: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids) + const ALPN_HTTP1_1: &[u8; 8] = b"http/1.1"; + + server_config.alpn_protocols = vec![ALPN_HTTP2.to_vec(), ALPN_HTTP1_1.to_vec()]; + let tls_acceptor = + tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); + match tls_acceptor.accept(stream).await { + Ok(tls_stream) => { + let io = TokioIo::new(tls_stream); + let conn = server.serve_connection_with_upgrades(io, hyper_service); + let conn = graceful_shutdown.watch(conn.into_owned()); + tokio::spawn(async move { + if let Err(err) = conn.await { + warn!(log, "failed to serve connection: {err}"); + } + }); + } + Err(err) => { + warn!(log, "Error setting up TLS stream: {err}"); + } + }; + } + } + _ = shutdown_notify.notified() => { + graceful_shutdown.shutdown().await; + break; + } + }; + } + }); + + address +} + impl XNetEndpoint { /// Creates and starts an `XNetEndpoint` to publish XNet `Streams`. pub fn new( @@ -117,135 +286,22 @@ impl XNetEndpoint { metrics: &MetricsRegistry, log: ReplicaLogger, ) -> Self { - use hyper::service::{make_service_fn, service_fn}; - use ic_xnet_hyper::{ExecuteOnRuntime, TlsConnection}; - let metrics = Arc::new(XNetEndpointMetrics::new(metrics)); - // Spawn a request handler. We pass the certified stream store, which is - // currently realized by the state manager. - - let make_service_closure = |address: SocketAddr| { - make_service_fn({ - let base_url = Url::parse(&format!("http://{}/", address)).unwrap(); - #[derive(Clone)] - struct Context { - log: ReplicaLogger, - semaphore: Arc, - metrics: Arc, - } - - let ctx = Context { - log: log.clone(), - metrics: Arc::clone(&metrics), - semaphore: Arc::new(Semaphore::new(XNET_ENDPOINT_MAX_CONCURRENT_REQUESTS)), - }; - - fn ok(t: T) -> Result { - Ok(t) - } - - move |tls_conn: &TlsConnection| { - let ctx = ctx.clone(); - let certified_stream_store = certified_stream_store.clone(); - let base_url = base_url.clone(); - - debug!( - ctx.log, - "Serving XNet streams to peer {:?}", - tls_conn.peer() - ); - - async move { - ok(service_fn({ - move |request: Request| { - let ctx = ctx.clone(); - let certified_stream_store = certified_stream_store.clone(); - let base_url = base_url.clone(); - - async move { - let owned_permit = match ctx.semaphore.try_acquire_owned() { - Ok(permit) => permit, - Err(_) => { - ctx.metrics - .request_duration - .with_label_values(&[ - RESOURCE_UNKNOWN, - StatusCode::SERVICE_UNAVAILABLE.as_str(), - ]) - .observe(0.0); - - return ok(Response::builder() - .status(StatusCode::SERVICE_UNAVAILABLE) - .body(Body::from("Queue full")) - .unwrap()); - } - }; - let metrics = ctx.metrics.clone(); - let log = ctx.log.clone(); - - Ok(tokio::task::spawn_blocking(move || { - let _permit = owned_permit; - - handle_http_request( - request, - certified_stream_store.as_ref(), - &base_url, - &metrics, - &log, - ) - }) - .await - .expect("Processing http request panicked!")) - } - } - })) - } - } - }) - }; - let (address, server) = { - let _guard = runtime_handle.enter(); - - #[cfg(test)] - use ic_xnet_hyper::tls_bind_for_test as tls_bind; - - #[cfg(not(test))] - use ic_xnet_hyper::tls_bind; - - let (addr, builder) = - tls_bind(&config.address, tls, registry_client).unwrap_or_else(|e| { - panic!( - "failed to bind XNet socket, address {:?}: {}", - config.address, e - ) - }); - - ( - addr, - builder - .executor(ExecuteOnRuntime(runtime_handle.clone())) - .serve(make_service_closure(addr)), - ) - }; - - info!(log, "XNet Endpoint listening on {}", address); - let shutdown_notify = Arc::new(Notify::new()); - let shutdown = server.with_graceful_shutdown({ - let shutdown_notify = Arc::clone(&shutdown_notify); - async move { shutdown_notify.notified().await } - }); + let address = start_server( + config.address, + metrics, + certified_stream_store, + runtime_handle.clone(), + tls, + registry_client, + log.clone(), + shutdown_notify.clone(), + ); - runtime_handle.spawn({ - let log = log.clone(); - async move { - if let Err(e) = shutdown.await { - warn!(log, "XNet http server failed: {}", e); - } - } - }); + info!(log, "XNet Endpoint listening on {}", address); Self { server_address: address, @@ -265,31 +321,6 @@ impl XNetEndpoint { } } -/// Handles an incoming HTTP request by parsing the URL, handing over to -/// `route_request()` and replying with the produced response. -fn handle_http_request( - request: Request, - certified_stream_store: &dyn CertifiedStreamStore, - base_url: &Url, - metrics: &XNetEndpointMetrics, - log: &ReplicaLogger, -) -> Response { - match base_url.join( - request - .uri() - .path_and_query() - .map(|pq| pq.as_str()) - .unwrap_or(""), - ) { - Ok(url) => route_request(url, certified_stream_store, metrics), - Err(e) => { - let msg = format!("Invalid URL {}: {}", request.uri(), e); - warn!(log, "{}", msg); - bad_request(msg) - } - } -} - /// Routes an `XNetEndpoint` request to the appropriate handler; or produces an /// HTTP 404 Not Found response if the URL doesn't match any handler. fn route_request( diff --git a/rs/xnet/endpoint/src/tests.rs b/rs/xnet/endpoint/src/tests.rs index a196410d8f0..76187d281c7 100644 --- a/rs/xnet/endpoint/src/tests.rs +++ b/rs/xnet/endpoint/src/tests.rs @@ -605,7 +605,7 @@ async fn http_get(url: &str) -> Bytes { /// Parses a `Response` into status code and body. async fn parse_response(response: Response) -> (u16, Vec) { let status = response.status().as_u16(); - let body = hyper::body::to_bytes(response.into_body()) + let body = axum::body::to_bytes(response.into_body(), usize::MAX) .await .unwrap() .to_vec();