diff --git a/.github/actions/pr-comment-data-export/action.yml b/.github/actions/pr-comment-data-export/action.yml index 7e2fb075f7..904c3ed589 100644 --- a/.github/actions/pr-comment-data-export/action.yml +++ b/.github/actions/pr-comment-data-export/action.yml @@ -31,7 +31,7 @@ runs: echo "${{ inputs.log-url }}" > comment-data/log-url fi - if: github.event_name == 'pull_request' - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 with: name: ${{ inputs.name }} path: comment-data diff --git a/.github/actions/quic-interop-runner/action.yml b/.github/actions/quic-interop-runner/action.yml index ba13102ecb..a8736a60bb 100644 --- a/.github/actions/quic-interop-runner/action.yml +++ b/.github/actions/quic-interop-runner/action.yml @@ -83,7 +83,7 @@ runs: done shell: bash - - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 id: upload-logs with: name: '${{ inputs.client }} vs. ${{ inputs.server }} logs' @@ -97,7 +97,7 @@ runs: mv result.json.tmp result.json shell: bash - - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 with: name: '${{ inputs.client }} vs. ${{ inputs.server }} results' path: | diff --git a/.github/actions/rust/action.yml b/.github/actions/rust/action.yml index 4b0741d748..07513b56ba 100644 --- a/.github/actions/rust/action.yml +++ b/.github/actions/rust/action.yml @@ -11,6 +11,9 @@ inputs: tools: description: 'Additional Rust tools to install' default: '' + token: + description: 'A Github PAT' + required: true runs: using: composite @@ -41,12 +44,16 @@ runs: - name: Install cargo-quickinstall shell: bash if: inputs.tools != '' + env: + GITHUB_TOKEN: ${{ inputs.token }} run: cargo +${{ inputs.version }} install cargo-quickinstall - name: Install Rust tools shell: bash if: inputs.tools != '' - run: cargo +${{ inputs.version }} quickinstall --no-binstall $(echo ${{ inputs.tools }} | tr -d ",") + env: + GITHUB_TOKEN: ${{ inputs.token }} + run: cargo +${{ inputs.version }} quickinstall $(echo ${{ inputs.tools }} | tr -d ",") # sccache slows CI down, so we leave it disabled. # Leaving the steps below commented out, so we can re-evaluate enabling it later. diff --git a/.github/workflows/bench.yml b/.github/workflows/bench.yml index 2faf13eccc..a30d067a19 100644 --- a/.github/workflows/bench.yml +++ b/.github/workflows/bench.yml @@ -45,6 +45,7 @@ jobs: with: version: $TOOLCHAIN tools: hyperfine + token: ${{ secrets.GITHUB_TOKEN }} - name: Get minimum NSS version id: nss-version @@ -237,7 +238,7 @@ jobs: - name: Export perf data id: export - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 with: name: ${{ github.event.repository.name }}-${{ github.sha }} path: | diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index f63857d868..2d3f7ba5bb 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -51,6 +51,7 @@ jobs: version: ${{ matrix.rust-toolchain }} components: rustfmt, clippy, llvm-tools-preview tools: cargo-llvm-cov, cargo-nextest, cargo-hack, cargo-fuzz, cargo-machete + token: ${{ secrets.GITHUB_TOKEN }} - name: Get minimum NSS version id: nss-version @@ -129,6 +130,8 @@ jobs: fail_ci_if_error: false token: ${{ secrets.CODECOV_TOKEN }} verbose: true + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} if: matrix.type == 'debug' && matrix.rust-toolchain == 'stable' bench: diff --git a/.github/workflows/firefox.yml b/.github/workflows/firefox.yml index eb94e49d64..03df7bbd16 100644 --- a/.github/workflows/firefox.yml +++ b/.github/workflows/firefox.yml @@ -113,7 +113,7 @@ jobs: - name: Export binary id: upload - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 with: name: ${{ runner.os }}-${{ env.FIREFOX }}-${{ matrix.type }}.tgz path: ${{ env.FIREFOX }}.tar @@ -122,7 +122,7 @@ jobs: - run: echo "${{ steps.upload.outputs.artifact-url }}" >> artifact - name: Export artifact URL - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 with: name: artifact-${{ runner.os }}-${{ env.FIREFOX }}-${{ matrix.type }} path: artifact diff --git a/.github/workflows/mutants.yml b/.github/workflows/mutants.yml index 7d4e525783..23365269d6 100644 --- a/.github/workflows/mutants.yml +++ b/.github/workflows/mutants.yml @@ -35,6 +35,7 @@ jobs: uses: ./.github/actions/rust with: tools: cargo-mutants + token: ${{ secrets.GITHUB_TOKEN }} - name: Find incremental mutants if: github.event_name == 'pull_request' @@ -63,7 +64,7 @@ jobs: } > "$GITHUB_STEP_SUMMARY" - name: Archive mutants.out - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 if: always() with: name: mutants.out diff --git a/.github/workflows/qns.yml b/.github/workflows/qns.yml index 8752680a36..40eade309a 100644 --- a/.github/workflows/qns.yml +++ b/.github/workflows/qns.yml @@ -35,7 +35,7 @@ jobs: packages: write steps: - uses: docker/setup-qemu-action@49b3bc8e6bdd4a60e6116a5414239cba5943d3cf # v3.2.0 - - uses: docker/setup-buildx-action@aa33708b10e362ff993539393ff100fa93ed6a27 # v3.5.0 + - uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: ghcr.io @@ -77,7 +77,7 @@ jobs: platforms: 'linux/amd64' outputs: type=docker,dest=/tmp/${{ env.LATEST }}.tar - - uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 + - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # v4.3.5 with: name: '${{ env.LATEST }} Docker image' path: /tmp/${{ env.LATEST }}.tar @@ -186,6 +186,7 @@ jobs: fi jq < "$RUN/result.json" ' . as $data | + .results[][].result //= "failed" | { results: [.results[] | group_by(.result)[] | {(.[0].result): [.[] | .abbr]}] | add @@ -193,6 +194,7 @@ jobs: . + {log_url: $data.log_url} ' > "$RUN/$ROLE.grouped.json" for ROLE in client server; do + [ ! -e "$RUN/$ROLE.grouped.json" ] && continue for GROUP in $(jq -r < "$RUN/$ROLE.grouped.json" '.results | keys[]'); do RESULT=$(jq < "$RUN/$ROLE.grouped.json" -r '.results.'"$GROUP"'[]' | fmt -w 1000) LOG=$(jq -r < "$RUN/$ROLE.grouped.json" -r '.log_url') diff --git a/Cargo.toml b/Cargo.toml index 3bc14976b2..765aa63aef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ description = "Neqo, the Mozilla implementation of QUIC in Rust." keywords = ["quic", "http3", "neqo", "mozilla", "ietf", "firefox"] categories = ["network-programming", "web-programming"] readme = "README.md" -version = "0.8.1" +version = "0.8.2" # Keep in sync with `.rustfmt.toml` `edition`. edition = "2021" license = "MIT OR Apache-2.0" diff --git a/README.md b/README.md index beadf22ecf..bb309ac824 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,13 @@ Compile Gecko as usual with Note: Using newer Neqo code with Gecko may also require changes (likely to `neqo_glue`) if something has changed. +### Connect with Firefox to local neqo-server + +1. Run `neqo-server` via `cargo run --bin neqo-server -- 'localhost:12345' --db ./test-fixture/db`. +2. On Firefox, set `about:config` preference `network.http.http3.alt-svc-mapping-for-testing` to `localhost;h3=":12345"`. +3. Optionally enable logging via `about:logging` or profiling via https://profiler.firefox.com/. +4. Navigate to https://localhost:12345 and accept self-signed certificate. + [NSS]: https://hg.mozilla.org/projects/nss [NSPR]: https://hg.mozilla.org/projects/nspr [GYP]: https://github.com/nodejs/gyp-next diff --git a/neqo-bin/src/client/mod.rs b/neqo-bin/src/client/mod.rs index d65f57f4cd..a68dfd104f 100644 --- a/neqo-bin/src/client/mod.rs +++ b/neqo-bin/src/client/mod.rs @@ -23,14 +23,13 @@ use futures::{ future::{select, Either}, FutureExt, TryFutureExt, }; -use neqo_common::{self as common, qdebug, qerror, qinfo, qlog::NeqoQlog, qwarn, Datagram, Role}; +use neqo_common::{qdebug, qerror, qinfo, qlog::NeqoQlog, qwarn, Datagram, Role}; use neqo_crypto::{ constants::{TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256}, init, Cipher, ResumptionToken, }; use neqo_http3::Output; use neqo_transport::{AppError, CloseReason, ConnectionId, Version}; -use qlog::{events::EventImportance, streamer::QlogStreamer}; use tokio::time::Sleep; use url::{Origin, Url}; @@ -46,7 +45,7 @@ pub enum Error { ArgumentError(&'static str), Http3Error(neqo_http3::Error), IoError(io::Error), - QlogError, + QlogError(qlog::Error), TransportError(neqo_transport::Error), ApplicationError(neqo_transport::AppError), CryptoError(neqo_crypto::Error), @@ -71,8 +70,8 @@ impl From for Error { } impl From for Error { - fn from(_err: qlog::Error) -> Self { - Self::QlogError + fn from(err: qlog::Error) -> Self { + Self::QlogError(err) } } @@ -174,7 +173,7 @@ pub struct Args { impl Args { #[must_use] - #[cfg(feature = "bench")] + #[cfg(any(test, feature = "bench"))] #[allow(clippy::missing_panics_doc)] pub fn new(requests: &[u64]) -> Self { use std::str::FromStr; @@ -277,6 +276,11 @@ impl Args { _ => exit(127), } } + + #[cfg(any(test, feature = "bench"))] + pub fn set_qlog_dir(&mut self, dir: PathBuf) { + self.shared.qlog_dir = Some(dir); + } } fn get_output_file( @@ -453,32 +457,26 @@ impl<'a, H: Handler> Runner<'a, H> { } fn qlog_new(args: &Args, hostname: &str, cid: &ConnectionId) -> Res { - if let Some(qlog_dir) = &args.shared.qlog_dir { - let mut qlog_path = qlog_dir.clone(); - let filename = format!("{hostname}-{cid}.sqlog"); - qlog_path.push(filename); - - let f = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(&qlog_path)?; - - let streamer = QlogStreamer::new( - qlog::QLOG_VERSION.to_string(), - Some("Example qlog".to_string()), - Some("Example qlog description".to_string()), - None, - std::time::Instant::now(), - common::qlog::new_trace(Role::Client), - EventImportance::Base, - Box::new(f), - ); - - Ok(NeqoQlog::enabled(streamer, qlog_path)?) - } else { - Ok(NeqoQlog::disabled()) - } + let Some(qlog_dir) = args.shared.qlog_dir.clone() else { + return Ok(NeqoQlog::disabled()); + }; + + // hostname might be an IPv6 address, e.g. `[::1]`. `:` is an invalid + // Windows file name character. + #[cfg(windows)] + let hostname: String = hostname + .chars() + .map(|c| if c == ':' { '_' } else { c }) + .collect(); + + NeqoQlog::enabled_with_file( + qlog_dir, + Role::Client, + Some("Example qlog".to_string()), + Some("Example qlog description".to_string()), + format!("{hostname}-{cid}"), + ) + .map_err(Error::QlogError) } pub async fn client(mut args: Args) -> Res<()> { diff --git a/neqo-bin/src/lib.rs b/neqo-bin/src/lib.rs index 998a3f0d8e..4dfa770ad6 100644 --- a/neqo-bin/src/lib.rs +++ b/neqo-bin/src/lib.rs @@ -65,7 +65,7 @@ pub struct SharedArgs { pub quic_parameters: QuicParameters, } -#[cfg(feature = "bench")] +#[cfg(any(test, feature = "bench"))] impl Default for SharedArgs { fn default() -> Self { Self { @@ -132,7 +132,7 @@ pub struct QuicParameters { pub preferred_address_v6: Option, } -#[cfg(feature = "bench")] +#[cfg(any(test, feature = "bench"))] impl Default for QuicParameters { fn default() -> Self { Self { @@ -252,3 +252,73 @@ impl Display for Error { } impl std::error::Error for Error {} + +#[cfg(test)] +mod tests { + use std::{fs, path::PathBuf, str::FromStr, time::SystemTime}; + + use crate::{client, server}; + + struct TempDir { + path: PathBuf, + } + + impl TempDir { + fn new() -> Self { + let mut dir = std::env::temp_dir(); + dir.push(format!( + "neqo-bin-test-{}", + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() + )); + fs::create_dir(&dir).unwrap(); + Self { path: dir } + } + + fn path(&self) -> PathBuf { + self.path.clone() + } + } + + impl Drop for TempDir { + fn drop(&mut self) { + if self.path.exists() { + fs::remove_dir_all(&self.path).unwrap(); + } + } + } + + #[tokio::test] + async fn write_qlog_file() { + neqo_crypto::init_db(PathBuf::from_str("../test-fixture/db").unwrap()).unwrap(); + + let temp_dir = TempDir::new(); + + let mut client_args = client::Args::new(&[1]); + client_args.set_qlog_dir(temp_dir.path()); + let mut server_args = server::Args::default(); + server_args.set_qlog_dir(temp_dir.path()); + + let client = client::client(client_args); + let server = Box::pin(server::server(server_args)); + tokio::select! { + _ = client => {} + res = server => panic!("expect server not to terminate: {res:?}"), + }; + + // Verify that the directory contains two non-empty files + let entries: Vec<_> = fs::read_dir(temp_dir.path()) + .unwrap() + .filter_map(Result::ok) + .collect(); + assert_eq!(entries.len(), 2, "expect 2 files in the directory"); + + for entry in entries { + let metadata = entry.metadata().unwrap(); + assert!(metadata.is_file(), "expect a file, found something else"); + assert!(metadata.len() > 0, "expect file not be empty"); + } + } +} diff --git a/neqo-bin/src/server/mod.rs b/neqo-bin/src/server/mod.rs index 06b6146728..3c3e2923ec 100644 --- a/neqo-bin/src/server/mod.rs +++ b/neqo-bin/src/server/mod.rs @@ -118,7 +118,7 @@ pub struct Args { ech: bool, } -#[cfg(feature = "bench")] +#[cfg(any(test, feature = "bench"))] impl Default for Args { fn default() -> Self { use std::str::FromStr; @@ -175,6 +175,11 @@ impl Args { Instant::now() } } + + #[cfg(any(test, feature = "bench"))] + pub fn set_qlog_dir(&mut self, dir: PathBuf) { + self.shared.qlog_dir = Some(dir); + } } fn qns_read_response(filename: &str) -> Result, io::Error> { diff --git a/neqo-bin/src/udp.rs b/neqo-bin/src/udp.rs index 0f386ec2ea..94488032d7 100644 --- a/neqo-bin/src/udp.rs +++ b/neqo-bin/src/udp.rs @@ -57,7 +57,7 @@ impl Socket { pub fn recv(&self, local_address: &SocketAddr) -> Result, io::Error> { self.inner .try_io(tokio::io::Interest::READABLE, || { - neqo_udp::recv_inner(local_address, &self.state, (&self.inner).into()) + neqo_udp::recv_inner(local_address, &self.state, &self.inner) }) .or_else(|e| { if e.kind() == io::ErrorKind::WouldBlock { diff --git a/neqo-common/src/qlog.rs b/neqo-common/src/qlog.rs index c0a4cbc960..b1f98b95c4 100644 --- a/neqo-common/src/qlog.rs +++ b/neqo-common/src/qlog.rs @@ -6,8 +6,10 @@ use std::{ cell::RefCell, - fmt, - path::{Path, PathBuf}, + fmt::{self, Display}, + fs::OpenOptions, + io::BufWriter, + path::PathBuf, rc::Rc, }; @@ -29,21 +31,53 @@ pub struct NeqoQlogShared { } impl NeqoQlog { - /// Create an enabled `NeqoQlog` configuration. + /// Create an enabled `NeqoQlog` configuration backed by a file. /// /// # Errors /// - /// Will return `qlog::Error` if cannot write to the new log. - pub fn enabled( - mut streamer: QlogStreamer, - qlog_path: impl AsRef, + /// Will return `qlog::Error` if it cannot write to the new file. + pub fn enabled_with_file( + mut qlog_path: PathBuf, + role: Role, + title: Option, + description: Option, + file_prefix: impl Display, ) -> Result { + qlog_path.push(format!("{file_prefix}.sqlog")); + + let file = OpenOptions::new() + .write(true) + // As a server, the original DCID is chosen by the client. Using + // create_new() prevents attackers from overwriting existing logs. + .create_new(true) + .open(&qlog_path) + .map_err(qlog::Error::IoError)?; + + let streamer = QlogStreamer::new( + qlog::QLOG_VERSION.to_string(), + title, + description, + None, + std::time::Instant::now(), + new_trace(role), + qlog::events::EventImportance::Base, + Box::new(BufWriter::new(file)), + ); + Self::enabled(streamer, qlog_path) + } + + /// Create an enabled `NeqoQlog` configuration. + /// + /// # Errors + /// + /// Will return `qlog::Error` if it cannot write to the new log. + pub fn enabled(mut streamer: QlogStreamer, qlog_path: PathBuf) -> Result { streamer.start_log()?; Ok(Self { inner: Rc::new(RefCell::new(Some(NeqoQlogShared { + qlog_path, streamer, - qlog_path: qlog_path.as_ref().to_owned(), }))), }) } diff --git a/neqo-crypto/bindings/bindings.toml b/neqo-crypto/bindings/bindings.toml index 5d692f78b5..01a4e178ac 100644 --- a/neqo-crypto/bindings/bindings.toml +++ b/neqo-crypto/bindings/bindings.toml @@ -45,7 +45,6 @@ functions = [ "SSL_OptionSet", "SSL_OptionGetDefault", "SSL_PeerCertificate", - "SSL_PeerCertificateChain", "SSL_PeerSignedCertTimestamps", "SSL_PeerStapledOCSPResponses", "SSL_ResetHandshake", @@ -137,8 +136,6 @@ variables = [ [nss_p11] types = [ - "CERTCertList", - "CERTCertListNode", "CK_CHACHA20_PARAMS", "CK_ATTRIBUTE_TYPE", "CK_FLAGS", @@ -151,7 +148,6 @@ types = [ ] functions = [ "CERT_DestroyCertificate", - "CERT_DestroyCertList", "CERT_GetCertificateDer", "NSS_SetAlgorithmPolicy", "PK11_CipherOp", @@ -173,6 +169,7 @@ functions = [ "PK11_ImportDataKey", "PK11_ReadRawAttribute", "PK11_ReferenceSymKey", + "SECITEM_FreeArray", "SECITEM_FreeItem", "SECKEY_CopyPrivateKey", "SECKEY_CopyPublicKey", diff --git a/neqo-crypto/min_version.txt b/neqo-crypto/min_version.txt index 422c9c7093..eaa18a6df7 100644 --- a/neqo-crypto/min_version.txt +++ b/neqo-crypto/min_version.txt @@ -1 +1 @@ -3.98 +3.103 diff --git a/neqo-crypto/src/cert.rs b/neqo-crypto/src/cert.rs index 48afbd95f9..80609f5316 100644 --- a/neqo-crypto/src/cert.rs +++ b/neqo-crypto/src/cert.rs @@ -4,23 +4,23 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use std::ptr::{addr_of, NonNull}; +use std::ptr::NonNull; use neqo_common::qerror; use crate::{ - err::secstatus_to_res, - null_safe_slice, - p11::{CERTCertListNode, CERT_GetCertificateDer, CertList, Item, SECItem, SECItemArray}, - ssl::{ - PRFileDesc, SSL_PeerCertificateChain, SSL_PeerSignedCertTimestamps, - SSL_PeerStapledOCSPResponses, - }, + experimental_api, null_safe_slice, + p11::{ItemArray, ItemArrayIterator, SECItem, SECItemArray}, + ssl::{PRFileDesc, SSL_PeerSignedCertTimestamps, SSL_PeerStapledOCSPResponses}, }; +experimental_api!(SSL_PeerCertificateChainDER( + fd: *mut PRFileDesc, + out: *mut *mut SECItemArray, +)); + pub struct CertificateInfo { - certs: CertList, - cursor: *const CERTCertListNode, + certs: ItemArray, /// `stapled_ocsp_responses` and `signed_cert_timestamp` are properties /// associated with each of the certificates. Right now, NSS only /// reports the value for the end-entity certificate (the first). @@ -28,12 +28,14 @@ pub struct CertificateInfo { signed_cert_timestamp: Option>, } -fn peer_certificate_chain(fd: *mut PRFileDesc) -> Option<(CertList, *const CERTCertListNode)> { - let chain = unsafe { SSL_PeerCertificateChain(fd) }; - CertList::from_ptr(chain.cast()).ok().map(|certs| { - let cursor = CertificateInfo::head(&certs); - (certs, cursor) - }) +fn peer_certificate_chain(fd: *mut PRFileDesc) -> Option { + let mut chain_ptr: *mut SECItemArray = std::ptr::null_mut(); + let rv = unsafe { SSL_PeerCertificateChainDER(fd, &mut chain_ptr) }; + if rv.is_ok() { + ItemArray::from_ptr(chain_ptr).ok() + } else { + None + } } // As explained in rfc6961, an OCSPResponseList can have at most @@ -72,32 +74,26 @@ fn signed_cert_timestamp(fd: *mut PRFileDesc) -> Option> { impl CertificateInfo { pub(crate) fn new(fd: *mut PRFileDesc) -> Option { - peer_certificate_chain(fd).map(|(certs, cursor)| Self { + peer_certificate_chain(fd).map(|certs| Self { certs, - cursor, stapled_ocsp_responses: stapled_ocsp_responses(fd), signed_cert_timestamp: signed_cert_timestamp(fd), }) } +} - fn head(certs: &CertList) -> *const CERTCertListNode { - // Three stars: one for the reference, one for the wrapper, one to deference the pointer. - unsafe { addr_of!((***certs).list).cast() } +impl CertificateInfo { + #[must_use] + pub fn iter(&self) -> ItemArrayIterator<'_> { + self.certs.into_iter() } } -impl<'a> Iterator for &'a mut CertificateInfo { +impl<'a> IntoIterator for &'a CertificateInfo { + type IntoIter = ItemArrayIterator<'a>; type Item = &'a [u8]; - fn next(&mut self) -> Option<&'a [u8]> { - self.cursor = unsafe { *self.cursor }.links.next.cast(); - if self.cursor == CertificateInfo::head(&self.certs) { - return None; - } - let mut item = Item::make_empty(); - let cert = unsafe { *self.cursor }.cert; - secstatus_to_res(unsafe { CERT_GetCertificateDer(cert, &mut item) }) - .expect("getting DER from certificate should work"); - Some(unsafe { null_safe_slice(item.data, item.len) }) + fn into_iter(self) -> Self::IntoIter { + self.iter() } } diff --git a/neqo-crypto/src/ech.rs b/neqo-crypto/src/ech.rs index 4ff2cda7e8..76fd362c14 100644 --- a/neqo-crypto/src/ech.rs +++ b/neqo-crypto/src/ech.rs @@ -102,8 +102,8 @@ pub fn generate_keys() -> Res<(PrivateKey, PublicKey)> { let oid = unsafe { oid_data.as_ref() }.ok_or(Error::InternalError)?; let oid_slc = unsafe { null_safe_slice(oid.oid.data, oid.oid.len) }; let mut params: Vec = Vec::with_capacity(oid_slc.len() + 2); - params.push(u8::try_from(p11::SEC_ASN1_OBJECT_ID).unwrap()); - params.push(u8::try_from(oid.oid.len).unwrap()); + params.push(u8::try_from(p11::SEC_ASN1_OBJECT_ID)?); + params.push(u8::try_from(oid.oid.len)?); params.extend_from_slice(oid_slc); let mut public_ptr: *mut SECKEYPublicKey = null_mut(); diff --git a/neqo-crypto/src/hp.rs b/neqo-crypto/src/hp.rs index f10b913039..e8412b646e 100644 --- a/neqo-crypto/src/hp.rs +++ b/neqo-crypto/src/hp.rs @@ -118,7 +118,7 @@ impl HpKey { debug_assert_eq!( res.block_size(), - usize::try_from(unsafe { PK11_GetBlockSize(mech, null_mut()) }).unwrap() + usize::try_from(unsafe { PK11_GetBlockSize(mech, null_mut()) })? ); Ok(res) } @@ -154,10 +154,10 @@ impl HpKey { &mut output_len, c_int::try_from(output.len())?, sample[..Self::SAMPLE_SIZE].as_ptr().cast(), - c_int::try_from(Self::SAMPLE_SIZE).unwrap(), + c_int::try_from(Self::SAMPLE_SIZE)?, ) })?; - debug_assert_eq!(usize::try_from(output_len).unwrap(), output.len()); + debug_assert_eq!(usize::try_from(output_len)?, output.len()); Ok(output) } @@ -182,7 +182,7 @@ impl HpKey { c_uint::try_from(Self::SAMPLE_SIZE)?, ) })?; - debug_assert_eq!(usize::try_from(output_len).unwrap(), output.len()); + debug_assert_eq!(usize::try_from(output_len)?, output.len()); Ok(output) } } diff --git a/neqo-crypto/src/lib.rs b/neqo-crypto/src/lib.rs index 52e1f42e73..e35c499dea 100644 --- a/neqo-crypto/src/lib.rs +++ b/neqo-crypto/src/lib.rs @@ -61,6 +61,7 @@ pub use self::{ mod min_version; use min_version::MINIMUM_NSS_VERSION; +use neqo_common::qerror; #[allow(non_upper_case_globals)] mod nss { @@ -94,13 +95,13 @@ fn already_initialized() -> bool { unsafe { nss::NSS_IsInitialized() != 0 } } -fn version_check() { - let min_ver = CString::new(MINIMUM_NSS_VERSION).unwrap(); - assert_ne!( - unsafe { nss::NSS_VersionCheck(min_ver.as_ptr()) }, - 0, - "Minimum NSS version of {MINIMUM_NSS_VERSION} not supported", - ); +fn version_check() -> Res<()> { + let min_ver = CString::new(MINIMUM_NSS_VERSION)?; + if unsafe { nss::NSS_VersionCheck(min_ver.as_ptr()) } == 0 { + qerror!("Minimum NSS version of {MINIMUM_NSS_VERSION} not supported"); + return Err(Error::UnsupportedVersion); + } + Ok(()) } /// Initialize NSS. This only executes the initialization routines once, so if there is any chance @@ -113,7 +114,7 @@ pub fn init() -> Res<()> { // Set time zero. time::init(); let res = INITIALIZED.get_or_init(|| { - version_check(); + version_check()?; if already_initialized() { return Ok(NssLoaded::External); } @@ -152,7 +153,7 @@ fn enable_ssl_trace() -> Res<()> { pub fn init_db>(dir: P) -> Res<()> { time::init(); let res = INITIALIZED.get_or_init(|| { - version_check(); + version_check()?; if already_initialized() { return Ok(NssLoaded::External); } @@ -213,16 +214,15 @@ pub fn assert_initialized() { /// # Safety /// The caller must adhere to the safety constraints of `std::slice::from_raw_parts`, /// except that this will accept a null value for `data`. -unsafe fn null_safe_slice<'a, T>(data: *const u8, len: T) -> &'a [u8] +unsafe fn null_safe_slice<'a, T, L>(data: *const T, len: L) -> &'a [T] where - usize: TryFrom, + usize: TryFrom, { - if data.is_null() { + let len = usize::try_from(len).unwrap_or_else(|_| panic!("null_safe_slice: size overflow")); + if data.is_null() || len == 0 { &[] - } else if let Ok(len) = usize::try_from(len) { + } else { #[allow(clippy::disallowed_methods)] std::slice::from_raw_parts(data, len) - } else { - panic!("null_safe_slice: size overflow"); } } diff --git a/neqo-crypto/src/p11.rs b/neqo-crypto/src/p11.rs index dbd1f68276..20a88ba70f 100644 --- a/neqo-crypto/src/p11.rs +++ b/neqo-crypto/src/p11.rs @@ -15,6 +15,7 @@ use std::{ ops::{Deref, DerefMut}, os::raw::c_uint, ptr::null_mut, + slice::Iter as SliceIter, }; use neqo_common::hex_with_len; @@ -76,7 +77,6 @@ macro_rules! scoped_ptr { } scoped_ptr!(Certificate, CERTCertificate, CERT_DestroyCertificate); -scoped_ptr!(CertList, CERTCertList, CERT_DestroyCertList); scoped_ptr!(PublicKey, SECKEYPublicKey, SECKEY_DestroyPublicKey); impl PublicKey { @@ -97,10 +97,10 @@ impl PublicKey { **self, buf.as_mut_ptr(), &mut len, - c_uint::try_from(buf.len()).unwrap(), + c_uint::try_from(buf.len())?, ) })?; - buf.truncate(usize::try_from(len).unwrap()); + buf.truncate(usize::try_from(len)?); Ok(buf) } } @@ -237,6 +237,12 @@ unsafe fn destroy_secitem(item: *mut SECItem) { } scoped_ptr!(Item, SECItem, destroy_secitem); +impl AsRef<[u8]> for SECItem { + fn as_ref(&self) -> &[u8] { + unsafe { null_safe_slice(self.data, self.len) } + } +} + impl Item { /// Create a wrapper for a slice of this object. /// Creating this object is technically safe, but using it is extremely dangerous. @@ -287,6 +293,38 @@ impl Item { } } +unsafe fn destroy_secitem_array(array: *mut SECItemArray) { + SECITEM_FreeArray(array, PRBool::from(true)); +} +scoped_ptr!(ItemArray, SECItemArray, destroy_secitem_array); + +impl<'a> IntoIterator for &'a ItemArray { + type Item = &'a [u8]; + type IntoIter = ItemArrayIterator<'a>; + fn into_iter(self) -> Self::IntoIter { + Self::IntoIter { + iter: AsRef::<[SECItem]>::as_ref(self).iter(), + } + } +} + +impl AsRef<[SECItem]> for ItemArray { + fn as_ref(&self) -> &[SECItem] { + unsafe { null_safe_slice((*self.ptr).items, (*self.ptr).len) } + } +} + +pub struct ItemArrayIterator<'a> { + iter: SliceIter<'a, SECItem>, +} + +impl<'a> Iterator for ItemArrayIterator<'a> { + type Item = &'a [u8]; + fn next(&mut self) -> Option<&'a [u8]> { + self.iter.next().map(AsRef::<[u8]>::as_ref) + } +} + #[cfg(feature = "disable-random")] thread_local! { static CURRENT_VALUE: std::cell::Cell = const { std::cell::Cell::new(0) }; diff --git a/neqo-crypto/tests/agent.rs b/neqo-crypto/tests/agent.rs index 80bf816930..b049f4cfb7 100644 --- a/neqo-crypto/tests/agent.rs +++ b/neqo-crypto/tests/agent.rs @@ -129,8 +129,8 @@ fn raw() { assert!(server.state().is_connected()); // The client should have one certificate for the server. - let mut certs = client.peer_certificate().unwrap(); - assert_eq!(1, certs.count()); + let certs = client.peer_certificate().unwrap(); + assert_eq!(1, certs.into_iter().count()); // The server shouldn't have a client certificate. assert!(server.peer_certificate().is_none()); diff --git a/neqo-transport/src/connection/dump.rs b/neqo-transport/src/connection/dump.rs index 22d4ede474..10e1025524 100644 --- a/neqo-transport/src/connection/dump.rs +++ b/neqo-transport/src/connection/dump.rs @@ -42,7 +42,7 @@ pub fn dump_packet( }; let x = f.dump(); if !x.is_empty() { - write!(&mut s, "\n {} {}", dir, &x).unwrap(); + _ = write!(&mut s, "\n {} {}", dir, &x); } } qdebug!( diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 1d68880f2c..276d407f17 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -328,7 +328,7 @@ impl Connection { c.conn_params.get_versions().compatible(), Role::Client, &dcid, - ); + )?; c.original_destination_cid = Some(dcid); let path = Path::temporary( local_addr, @@ -1129,21 +1129,25 @@ impl Connection { output } - fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) { + fn handle_retry(&mut self, packet: &PublicPacket, now: Instant) -> Res<()> { qinfo!([self], "received Retry"); if matches!(self.address_validation, AddressValidationInfo::Retry { .. }) { self.stats.borrow_mut().pkt_dropped("Extra Retry"); - return; + return Ok(()); } if packet.token().is_empty() { self.stats.borrow_mut().pkt_dropped("Retry without a token"); - return; + return Ok(()); } - if !packet.is_valid_retry(self.original_destination_cid.as_ref().unwrap()) { + if !packet.is_valid_retry( + self.original_destination_cid + .as_ref() + .ok_or(Error::InvalidRetry)?, + ) { self.stats .borrow_mut() .pkt_dropped("Retry with bad integrity tag"); - return; + return Ok(()); } // At this point, we should only have the connection ID that we generated. // Update to the one that the server prefers. @@ -1151,7 +1155,7 @@ impl Connection { self.stats .borrow_mut() .pkt_dropped("Retry without an existing path"); - return; + return Ok(()); }; path.borrow_mut().set_remote_cid(packet.scid()); @@ -1171,11 +1175,12 @@ impl Connection { self.conn_params.get_versions().compatible(), self.role, &retry_scid, - ); + )?; self.address_validation = AddressValidationInfo::Retry { token: packet.token().to_vec(), retry_source_cid: retry_scid, }; + Ok(()) } fn discard_keys(&mut self, space: PacketNumberSpace, now: Instant) { @@ -1194,8 +1199,8 @@ impl Connection { if d.len() < 16 || !self.state.connected() { return false; } - let token = <&[u8; 16]>::try_from(&d[d.len() - 16..]).unwrap(); - path.borrow().is_stateless_reset(token) + <&[u8; 16]>::try_from(&d[d.len() - 16..]) + .map_or(false, |token| path.borrow().is_stateless_reset(token)) } fn check_stateless_reset( @@ -1265,7 +1270,7 @@ impl Connection { .clone() .versions(version, self.conn_params.get_versions().all().to_vec()); let mut c = Self::new_client( - self.crypto.server_name().unwrap(), + self.crypto.server_name().ok_or(Error::VersionNegotiation)?, self.crypto.protocols(), self.cid_manager.generator(), local_addr, @@ -1324,7 +1329,7 @@ impl Connection { match (packet.packet_type(), &self.state, &self.role) { (PacketType::Initial, State::Init, Role::Server) => { - let version = *packet.version().as_ref().unwrap(); + let version = packet.version().ok_or(Error::ProtocolViolation)?; if !packet.is_valid_initial() || !self.conn_params.get_versions().all().contains(&version) { @@ -1340,7 +1345,7 @@ impl Connection { // Record the client's selected CID so that it can be accepted until // the client starts using a real connection ID. let dcid = ConnectionId::from(packet.dcid()); - self.crypto.states.init_server(version, &dcid); + self.crypto.states.init_server(version, &dcid)?; self.original_destination_cid = Some(dcid); self.set_state(State::WaitInitial); @@ -1359,7 +1364,7 @@ impl Connection { if versions.is_empty() || versions.contains(&self.version().wire_version()) || versions.contains(&0) - || &packet.scid() != self.odcid().unwrap() + || &packet.scid() != self.odcid().ok_or(Error::InternalError)? || matches!(self.address_validation, AddressValidationInfo::Retry { .. }) { // Ignore VersionNegotiation packets that contain the current version. @@ -1375,7 +1380,7 @@ impl Connection { return Ok(PreprocessResult::End); } (PacketType::Retry, State::WaitInitial, Role::Client) => { - self.handle_retry(packet, now); + self.handle_retry(packet, now)?; return Ok(PreprocessResult::Next); } (PacketType::Handshake | PacketType::Short, State::WaitInitial, Role::Client) => { @@ -1679,7 +1684,11 @@ impl Connection { self.paths.make_permanent(path, None, cid); Ok(()) } else if let Some(primary) = self.paths.primary() { - if primary.borrow().remote_cid().is_empty() { + if primary + .borrow() + .remote_cid() + .map_or(true, |id| id.is_empty()) + { self.paths .make_permanent(path, None, ConnectionIdEntry::empty_remote()); Ok(()) @@ -1908,7 +1917,7 @@ impl Connection { // a packet on a new path, we avoid sending (and the privacy risk) rather // than reuse a connection ID. let res = if path.borrow().is_temporary() { - assert!(!cfg!(test), "attempting to close with a temporary path"); + qerror!([self], "Attempting to close with a temporary path"); Err(Error::InternalError) } else { self.output_path(&path, now, &Some(details)) @@ -1932,16 +1941,15 @@ impl Connection { ) -> (PacketType, PacketBuilder) { let pt = PacketType::from(cspace); let mut builder = if pt == PacketType::Short { - qdebug!("Building Short dcid {}", path.remote_cid()); + qdebug!("Building Short dcid {:?}", path.remote_cid()); PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid()) } else { qdebug!( - "Building {:?} dcid {} scid {}", + "Building {:?} dcid {:?} scid {:?}", pt, path.remote_cid(), path.local_cid(), ); - PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid()) }; if builder.remaining() > 0 { @@ -2335,7 +2343,11 @@ impl Connection { ); self.stats.borrow_mut().packets_tx += 1; - let tx = self.crypto.states.tx_mut(self.version, cspace).unwrap(); + let tx = self + .crypto + .states + .tx_mut(self.version, cspace) + .ok_or(Error::InternalError)?; encoder = builder.build(tx)?; self.crypto.states.auto_update()?; @@ -2489,13 +2501,17 @@ impl Connection { self.validate_versions()?; { let tps = self.tps.borrow(); - let remote = tps.remote.as_ref().unwrap(); + let remote = tps.remote.as_ref().ok_or(Error::TransportParameterError)?; // If the peer provided a preferred address, then we have to be a client // and they have to be using a non-empty connection ID. if remote.get_preferred_address().is_some() && (self.role == Role::Server - || self.remote_initial_source_cid.as_ref().unwrap().is_empty()) + || self + .remote_initial_source_cid + .as_ref() + .ok_or(Error::UnknownConnectionId)? + .is_empty()) { return Err(Error::TransportParameterError); } @@ -2531,7 +2547,7 @@ impl Connection { fn validate_cids(&self) -> Res<()> { let tph = self.tps.borrow(); - let remote_tps = tph.remote.as_ref().unwrap(); + let remote_tps = tph.remote.as_ref().ok_or(Error::TransportParameterError)?; let tp = remote_tps.get_bytes(tparams::INITIAL_SOURCE_CONNECTION_ID); if self @@ -2592,7 +2608,7 @@ impl Connection { /// Validate the `version_negotiation` transport parameter from the peer. fn validate_versions(&self) -> Res<()> { let tph = self.tps.borrow(); - let remote_tps = tph.remote.as_ref().unwrap(); + let remote_tps = tph.remote.as_ref().ok_or(Error::TransportParameterError)?; // `current` and `other` are the value from the peer's transport parameters. // We're checking that these match our expectations. if let Some((current, other)) = remote_tps.get_versions() { @@ -2655,19 +2671,23 @@ impl Connection { self.version = v; } - fn compatible_upgrade(&mut self, packet_version: Version) { + fn compatible_upgrade(&mut self, packet_version: Version) -> Res<()> { if !matches!(self.state, State::WaitInitial | State::WaitVersion) { - return; + return Ok(()); } if self.role == Role::Client { self.confirm_version(packet_version); } else if self.tps.borrow().remote.is_some() { let version = self.tps.borrow().version(); - let dcid = self.original_destination_cid.as_ref().unwrap(); - self.crypto.states.init_server(version, dcid); + let dcid = self + .original_destination_cid + .as_ref() + .ok_or(Error::ProtocolViolation)?; + self.crypto.states.init_server(version, dcid)?; self.confirm_version(version); } + Ok(()) } fn handshake( @@ -2698,14 +2718,15 @@ impl Connection { } } _ => { - unreachable!("Crypto state should not be new or failed after successful handshake") + qerror!("Crypto state should not be new or failed after successful handshake"); + return Err(Error::CryptoError(neqo_crypto::Error::InternalError)); } } // There is a chance that this could be called less often, but getting the // conditions right is a little tricky, so call whenever CRYPTO data is used. if try_update { - self.compatible_upgrade(packet_version); + self.compatible_upgrade(packet_version)?; // We have transport parameters, it's go time. if self.tps.borrow().remote.is_some() { self.set_initial_limits(); @@ -2801,7 +2822,7 @@ impl Connection { if self.crypto.streams.data_ready(space) { let mut buf = Vec::new(); let read = self.crypto.streams.read_to_end(space, &mut buf); - qdebug!("Read {} bytes", read); + qdebug!("Read {:?} bytes", read); self.handshake(now, packet_version, space, Some(&buf))?; self.create_resumption_token(now); } else { @@ -3045,7 +3066,13 @@ impl Connection { // Generate a qlog event that the server connection started. qlog::server_connection_started(&self.qlog, &path); } else { - self.zero_rtt_state = if self.crypto.tls.info().unwrap().early_data_accepted() { + self.zero_rtt_state = if self + .crypto + .tls + .info() + .ok_or(Error::InternalError)? + .early_data_accepted() + { ZeroRttState::AcceptedClient } else { self.client_0rtt_rejected(now); @@ -3062,7 +3089,12 @@ impl Connection { self.create_resumption_token(now); self.saved_datagrams .make_available(CryptoSpace::ApplicationData); - self.stats.borrow_mut().resumed = self.crypto.tls.info().unwrap().resumed(); + self.stats.borrow_mut().resumed = self + .crypto + .tls + .info() + .ok_or(Error::InternalError)? + .resumed(); if self.role == Role::Server { self.state_signaling.handshake_done(); self.set_confirmed()?; @@ -3357,7 +3389,7 @@ impl Connection { ); let data_len_possible = - u64::try_from(mtu.saturating_sub(tx.expansion() + builder.len() + 1)).unwrap(); + u64::try_from(mtu.saturating_sub(tx.expansion() + builder.len() + 1))?; Ok(min(data_len_possible, max_dgram_size)) } diff --git a/neqo-transport/src/connection/params.rs b/neqo-transport/src/connection/params.rs index e305771ff4..201f245543 100644 --- a/neqo-transport/src/connection/params.rs +++ b/neqo-transport/src/connection/params.rs @@ -371,17 +371,17 @@ impl ConnectionParameters { // default parameters tps.local.set_integer( tparams::ACTIVE_CONNECTION_ID_LIMIT, - u64::try_from(LOCAL_ACTIVE_CID_LIMIT).unwrap(), + u64::try_from(LOCAL_ACTIVE_CID_LIMIT)?, ); tps.local.set_empty(tparams::DISABLE_MIGRATION); tps.local.set_empty(tparams::GREASE_QUIC_BIT); tps.local.set_integer( tparams::MAX_ACK_DELAY, - u64::try_from(DEFAULT_ACK_DELAY.as_millis()).unwrap(), + u64::try_from(DEFAULT_ACK_DELAY.as_millis())?, ); tps.local.set_integer( tparams::MIN_ACK_DELAY, - u64::try_from(GRANULARITY.as_micros()).unwrap(), + u64::try_from(GRANULARITY.as_micros())?, ); // set configurable parameters diff --git a/neqo-transport/src/connection/tests/datagram.rs b/neqo-transport/src/connection/tests/datagram.rs index 6d02419fcd..ec2795a232 100644 --- a/neqo-transport/src/connection/tests/datagram.rs +++ b/neqo-transport/src/connection/tests/datagram.rs @@ -599,7 +599,7 @@ fn datagram_fill() { let path = p.borrow(); // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. - path.plpmtu() - path.remote_cid().len() - 19 + path.plpmtu() - path.remote_cid().unwrap().len() - 19 }; assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large. diff --git a/neqo-transport/src/connection/tests/idle.rs b/neqo-transport/src/connection/tests/idle.rs index 336648f776..55d2ac8f16 100644 --- a/neqo-transport/src/connection/tests/idle.rs +++ b/neqo-transport/src/connection/tests/idle.rs @@ -287,7 +287,7 @@ fn idle_caching() { let mut client = default_client(); let mut server = default_server(); let start = now(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // Perform the first round trip, but drop the Initial from the server. // The client then caches the Handshake packet. diff --git a/neqo-transport/src/connection/tests/migration.rs b/neqo-transport/src/connection/tests/migration.rs index 3ee88943dd..64c025f98b 100644 --- a/neqo-transport/src/connection/tests/migration.rs +++ b/neqo-transport/src/connection/tests/migration.rs @@ -9,7 +9,7 @@ use std::{ mem, net::{IpAddr, Ipv6Addr, SocketAddr}, rc::Rc, - time::{Duration, Instant}, + time::Duration, }; use neqo_common::{Datagram, Decoder}; @@ -65,14 +65,6 @@ fn change_source_port(d: &Datagram) -> Datagram { Datagram::new(new_port(d.source()), d.destination(), d.tos(), &d[..]) } -/// As these tests use a new path, that path often has a non-zero RTT. -/// Pacing can be a problem when testing that path. This skips time forward. -fn skip_pacing(c: &mut Connection, now: Instant) -> Instant { - let pacing = c.process_output(now).callback(); - assert_ne!(pacing, Duration::new(0, 0)); - now + pacing -} - #[test] fn rebinding_port() { let mut client = default_client(); @@ -100,7 +92,7 @@ fn path_forwarding_attack() { let mut client = default_client(); let mut server = default_server(); connect_force_idle(&mut client, &mut server); - let mut now = now(); + let now = now(); let dgram = send_something(&mut client, now); let dgram = change_path(&dgram, DEFAULT_ADDR_V4); @@ -160,16 +152,15 @@ fn path_forwarding_attack() { assert_v6_path(&client_data2, false); // The server keeps sending on the new path. - now = skip_pacing(&mut server, now); let server_data2 = send_something(&mut server, now); assert_v4_path(&server_data2, false); // Until new data is received from the client on the old path. server.process_input(&client_data2, now); - // The server sends a probe on the "old" path. + // The server sends a probe on the new path. let server_data3 = send_something(&mut server, now); assert_v4_path(&server_data3, true); - // But switches data transmission to the "new" path. + // But switches data transmission to the old path. let server_data4 = server.process_output(now).dgram().unwrap(); assert_v6_path(&server_data4, false); } @@ -955,7 +946,6 @@ impl crate::connection::test_internal::FrameWriter for GarbageWriter { /// Test the case that we run out of connection ID and receive an invalid frame /// from a new path. #[test] -#[should_panic(expected = "attempting to close with a temporary path")] fn error_on_new_path_with_no_connection_id() { let mut client = default_client(); let mut server = default_server(); @@ -976,5 +966,23 @@ fn error_on_new_path_with_no_connection_id() { // See issue #1697. We had a crash when the client had a temporary path and // process_output is called. + let closing_frames = client.stats().frame_tx.connection_close; mem::drop(client.process_output(now())); + assert!(matches!( + client.state(), + State::Closing { + error: CloseReason::Transport(Error::UnknownFrameType), + .. + } + )); + // Wait until the connection is closed. + let mut now = now(); + now += client.process(None, now).callback(); + _ = client.process_output(now); + // No closing frames should be sent, and the connection should be closed. + assert_eq!(client.stats().frame_tx.connection_close, closing_frames); + assert!(matches!( + client.state(), + State::Closed(CloseReason::Transport(Error::UnknownFrameType)) + )); } diff --git a/neqo-transport/src/crypto.rs b/neqo-transport/src/crypto.rs index 4733fe76a8..3a6d890cde 100644 --- a/neqo-transport/src/crypto.rs +++ b/neqo-transport/src/crypto.rs @@ -238,13 +238,13 @@ impl Crypto { }; let secret = secret.ok_or(Error::InternalError)?; self.states - .set_0rtt_keys(version, dir, &secret, cipher.unwrap()); + .set_0rtt_keys(version, dir, &secret, cipher.ok_or(Error::InternalError)?)?; Ok(true) } /// Lock in a compatible upgrade. pub fn confirm_version(&mut self, confirmed: Version) { - self.states.confirm_version(self.version, confirmed); + _ = self.states.confirm_version(self.version, confirmed); self.version = confirmed; } @@ -277,7 +277,7 @@ impl Crypto { } .ok_or(Error::InternalError)?; self.states - .set_handshake_keys(self.version, &write_secret, &read_secret, cipher); + .set_handshake_keys(self.version, &write_secret, &read_secret, cipher)?; qdebug!([self], "Handshake keys installed"); Ok(true) } @@ -313,7 +313,8 @@ impl Crypto { return Err(Error::ProtocolViolation); } qtrace!([self], "Adding CRYPTO data {:?}", r); - self.streams.send(PacketNumberSpace::from(r.epoch), &r.data); + self.streams + .send(PacketNumberSpace::from(r.epoch), &r.data)?; } Ok(()) } @@ -443,7 +444,7 @@ impl CryptoDxState { epoch: Epoch, secret: &SymKey, cipher: Cipher, - ) -> Self { + ) -> Res { qdebug!( "Making {:?} {} CryptoDxState, v={:?} cipher={}", direction, @@ -452,17 +453,17 @@ impl CryptoDxState { cipher, ); let hplabel = String::from(version.label_prefix()) + "hp"; - Self { + Ok(Self { version, direction, epoch: usize::from(epoch), - aead: Aead::new(TLS_VERSION_1_3, cipher, secret, version.label_prefix()).unwrap(), - hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, &hplabel).unwrap(), + aead: Aead::new(TLS_VERSION_1_3, cipher, secret, version.label_prefix())?, + hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, &hplabel)?, used_pn: 0..0, min_pn: 0, invocations: Self::limit(direction, cipher), largest_packet_len: INITIAL_LARGEST_PACKET_LEN, - } + }) } pub fn new_initial( @@ -470,20 +471,18 @@ impl CryptoDxState { direction: CryptoDxDirection, label: &str, dcid: &[u8], - ) -> Self { + ) -> Res { qtrace!("new_initial {:?} {}", version, ConnectionIdRef::from(dcid)); let salt = version.initial_salt(); let cipher = TLS_AES_128_GCM_SHA256; let initial_secret = hkdf::extract( TLS_VERSION_1_3, cipher, - Some(hkdf::import_key(TLS_VERSION_1_3, salt).as_ref().unwrap()), - hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(), - ) - .unwrap(); + Some(&hkdf::import_key(TLS_VERSION_1_3, salt)?), + &hkdf::import_key(TLS_VERSION_1_3, dcid)?, + )?; - let secret = - hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap(); + let secret = hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label)?; Self::new(version, direction, TLS_EPOCH_INITIAL, &secret, cipher) } @@ -530,7 +529,7 @@ impl CryptoDxState { self.invocations <= UPDATE_WRITE_KEYS_AT } - pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self { + pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Res { let pn = self.next_pn(); // We count invocations of each write key just for that key, but all // attempts to invocations to read count toward a single limit. @@ -540,7 +539,7 @@ impl CryptoDxState { } else { Self::limit(CryptoDxDirection::Write, cipher) }; - Self { + Ok(Self { version: self.version, direction: self.direction, epoch: self.epoch + 1, @@ -549,14 +548,13 @@ impl CryptoDxState { cipher, next_secret, self.version.label_prefix(), - ) - .unwrap(), + )?, hpkey: self.hpkey.clone(), used_pn: pn..pn, min_pn: pn, invocations, largest_packet_len: INITIAL_LARGEST_PACKET_LEN, - } + }) } #[must_use] @@ -703,6 +701,7 @@ impl CryptoDxState { "server in", CLIENT_CID, ) + .unwrap() } /// Get the amount of extra padding packets protected with this profile need. @@ -763,7 +762,7 @@ impl CryptoDxAppData { cipher: Cipher, ) -> Res { Ok(Self { - dx: CryptoDxState::new(version, dir, TLS_EPOCH_APPLICATION_DATA, secret, cipher), + dx: CryptoDxState::new(version, dir, TLS_EPOCH_APPLICATION_DATA, secret, cipher)?, cipher, next_secret: Self::update_secret(cipher, secret)?, }) @@ -781,7 +780,7 @@ impl CryptoDxAppData { } let next_secret = Self::update_secret(self.cipher, &self.next_secret)?; Ok(Self { - dx: self.dx.next(&self.next_secret, self.cipher), + dx: self.dx.next(&self.next_secret, self.cipher)?, cipher: self.cipher, next_secret, }) @@ -946,7 +945,7 @@ impl CryptoStates { /// Create the initial crypto state. /// Note that the version here can change and that's OK. - pub fn init<'v, V>(&mut self, versions: V, role: Role, dcid: &[u8]) + pub fn init<'v, V>(&mut self, versions: V, role: Role, dcid: &[u8]) -> Res<()> where V: IntoIterator, { @@ -968,8 +967,8 @@ impl CryptoStates { ); let mut initial = CryptoState { - tx: CryptoDxState::new_initial(*v, CryptoDxDirection::Write, write, dcid), - rx: CryptoDxState::new_initial(*v, CryptoDxDirection::Read, read, dcid), + tx: CryptoDxState::new_initial(*v, CryptoDxDirection::Write, write, dcid)?, + rx: CryptoDxState::new_initial(*v, CryptoDxDirection::Read, read, dcid)?, }; if let Some(prev) = self.initials.get(v) { qinfo!( @@ -977,10 +976,11 @@ impl CryptoStates { "Continue packet numbers for initial after retry (write is {:?})", prev.rx.used_pn, ); - initial.tx.continuation(&prev.tx).unwrap(); + initial.tx.continuation(&prev.tx)?; } self.initials.insert(*v, initial); } + Ok(()) } /// At a server, we can be more targeted in initializing. @@ -989,24 +989,29 @@ impl CryptoStates { /// This is maybe slightly inefficient in the first case, because we might /// not need the send keys if the packet is subsequently discarded, but /// the overall effort is small enough to write off. - pub fn init_server(&mut self, version: Version, dcid: &[u8]) { + pub fn init_server(&mut self, version: Version, dcid: &[u8]) -> Res<()> { if !self.initials.contains_key(&version) { - self.init(&[version], Role::Server, dcid); + self.init(&[version], Role::Server, dcid)?; } + Ok(()) } - pub fn confirm_version(&mut self, orig: Version, confirmed: Version) { + pub fn confirm_version(&mut self, orig: Version, confirmed: Version) -> Res<()> { if orig != confirmed { // This part where the old data is removed and then re-added is to // appease the borrow checker. // Note that on the server, we might not have initials for |orig| if it // was configured for |orig| and only |confirmed| Initial packets arrived. if let Some(prev) = self.initials.remove(&orig) { - let next = self.initials.get_mut(&confirmed).unwrap(); - next.tx.continuation(&prev.tx).unwrap(); + let next = self + .initials + .get_mut(&confirmed) + .ok_or(Error::VersionNegotiation)?; + next.tx.continuation(&prev.tx)?; self.initials.insert(orig, prev); } } + Ok(()) } pub fn set_0rtt_keys( @@ -1015,7 +1020,7 @@ impl CryptoStates { dir: CryptoDxDirection, secret: &SymKey, cipher: Cipher, - ) { + ) -> Res<()> { qtrace!([self], "install 0-RTT keys"); self.zero_rtt = Some(CryptoDxState::new( version, @@ -1023,7 +1028,8 @@ impl CryptoStates { TLS_EPOCH_ZERO_RTT, secret, cipher, - )); + )?); + Ok(()) } /// Discard keys and return true if that happened. @@ -1054,7 +1060,7 @@ impl CryptoStates { write_secret: &SymKey, read_secret: &SymKey, cipher: Cipher, - ) { + ) -> Res<()> { self.cipher = cipher; self.handshake = Some(CryptoState { tx: CryptoDxState::new( @@ -1063,15 +1069,16 @@ impl CryptoStates { TLS_EPOCH_HANDSHAKE, write_secret, cipher, - ), + )?, rx: CryptoDxState::new( version, CryptoDxDirection::Read, TLS_EPOCH_HANDSHAKE, read_secret, cipher, - ), + )?, }); + Ok(()) } pub fn set_application_write_key(&mut self, version: Version, secret: &SymKey) -> Res<()> { @@ -1114,7 +1121,7 @@ impl CryptoStates { // received an acknowledgement for a packet in the current phase. // Also, skip this if we are waiting for read keys on the existing // key update to be rolled over. - let write = &self.app_write.as_ref().unwrap().dx; + let write = &self.app_write.as_ref().ok_or(Error::InternalError)?.dx; if write.can_update(largest_acknowledged) && self.read_update_time.is_none() { // This call additionally checks that we don't advance to the next // epoch while a key update is in progress. @@ -1136,8 +1143,8 @@ impl CryptoStates { // ahead of the read keys. If we initiated the key update, the write keys // will already be ahead. debug_assert!(self.read_update_time.is_none()); - let write = &self.app_write.as_ref().unwrap(); - let read = &self.app_read.as_ref().unwrap(); + let write = &self.app_write.as_ref().ok_or(Error::InternalError)?; + let read = &self.app_read.as_ref().ok_or(Error::InternalError)?; if write.epoch() == read.epoch() { qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1); self.app_write = Some(write.next()?); @@ -1208,7 +1215,8 @@ impl CryptoStates { } else { qtrace!([self], "Rotating read keys"); mem::swap(&mut self.app_read, &mut self.app_read_next); - self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?); + self.app_read_next = + Some(self.app_read.as_ref().ok_or(Error::InternalError)?.next()?); } self.read_update_time = None; } @@ -1231,8 +1239,8 @@ impl CryptoStates { // We only need to do the check while we are waiting for read keys to be updated. if self.read_update_time.is_some() { qtrace!([self], "Checking for PN overlap"); - let next_dx = &mut self.app_read_next.as_mut().unwrap().dx; - next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?; + let next_dx = &mut self.app_read_next.as_mut().ok_or(Error::InternalError)?.dx; + next_dx.continuation(&self.app_read.as_ref().ok_or(Error::InternalError)?.dx)?; } Ok(()) } @@ -1383,12 +1391,16 @@ impl CryptoStreams { } } - pub fn send(&mut self, space: PacketNumberSpace, data: &[u8]) { - self.get_mut(space).unwrap().tx.send(data); + pub fn send(&mut self, space: PacketNumberSpace, data: &[u8]) -> Res<()> { + self.get_mut(space) + .ok_or(Error::ProtocolViolation)? + .tx + .send(data); + Ok(()) } pub fn inbound_frame(&mut self, space: PacketNumberSpace, offset: u64, data: &[u8]) -> Res<()> { - let rx = &mut self.get_mut(space).unwrap().rx; + let rx = &mut self.get_mut(space).ok_or(Error::InternalError)?.rx; rx.inbound_frame(offset, data); if rx.received() - rx.retired() <= Self::BUFFER_LIMIT { Ok(()) @@ -1401,8 +1413,12 @@ impl CryptoStreams { self.get(space).map_or(false, |cs| cs.rx.data_ready()) } - pub fn read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec) -> usize { - self.get_mut(space).unwrap().rx.read_to_end(buf) + pub fn read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec) -> Res { + Ok(self + .get_mut(space) + .ok_or(Error::ProtocolViolation)? + .rx + .read_to_end(buf)) } pub fn acked(&mut self, token: &CryptoRecoveryToken) { diff --git a/neqo-transport/src/fc.rs b/neqo-transport/src/fc.rs index 37bb3daf57..acc4d6582d 100644 --- a/neqo-transport/src/fc.rs +++ b/neqo-transport/src/fc.rs @@ -810,7 +810,7 @@ mod test { fc[StreamType::BiDi].add_retired(1); fc[StreamType::BiDi].send_flowc_update(); // consume the frame - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); assert_eq!(tokens.len(), 1); diff --git a/neqo-transport/src/pace.rs b/neqo-transport/src/pace.rs index d34d015ab1..642a656da2 100644 --- a/neqo-transport/src/pace.rs +++ b/neqo-transport/src/pace.rs @@ -14,6 +14,8 @@ use std::{ use neqo_common::qtrace; +use crate::rtt::GRANULARITY; + /// This value determines how much faster the pacer operates than the /// congestion window. /// @@ -74,19 +76,26 @@ impl Pacer { /// the current time is). pub fn next(&self, rtt: Duration, cwnd: usize) -> Instant { if self.c >= self.p { - qtrace!([self], "next {}/{:?} no wait = {:?}", cwnd, rtt, self.t); - self.t - } else { - // This is the inverse of the function in `spend`: - // self.t + rtt * (self.p - self.c) / (PACER_SPEEDUP * cwnd) - let r = rtt.as_nanos(); - let d = r.saturating_mul(u128::try_from(self.p - self.c).unwrap()); - let add = d / u128::try_from(cwnd * PACER_SPEEDUP).unwrap(); - let w = u64::try_from(add).map(Duration::from_nanos).unwrap_or(rtt); - let nxt = self.t + w; - qtrace!([self], "next {}/{:?} wait {:?} = {:?}", cwnd, rtt, w, nxt); - nxt + qtrace!([self], "next {cwnd}/{rtt:?} no wait = {:?}", self.t); + return self.t; + } + + // This is the inverse of the function in `spend`: + // self.t + rtt * (self.p - self.c) / (PACER_SPEEDUP * cwnd) + let r = rtt.as_nanos(); + let d = r.saturating_mul(u128::try_from(self.p - self.c).unwrap()); + let add = d / u128::try_from(cwnd * PACER_SPEEDUP).unwrap(); + let w = u64::try_from(add).map(Duration::from_nanos).unwrap_or(rtt); + + // If the increment is below the timer granularity, send immediately. + if w < GRANULARITY { + qtrace!([self], "next {cwnd}/{rtt:?} below granularity ({w:?})",); + return self.t; } + + let nxt = self.t + w; + qtrace!([self], "next {cwnd}/{rtt:?} wait {w:?} = {nxt:?}"); + nxt } /// Spend credit. This cannot fail; users of this API are expected to call @@ -168,4 +177,18 @@ mod tests { p.spend(n, RTT, CWND, PACKET); assert_eq!(p.next(RTT, CWND), n); } + + #[test] + fn send_immediately_below_granularity() { + const SHORT_RTT: Duration = Duration::from_millis(10); + let n = now(); + let mut p = Pacer::new(true, n, PACKET, PACKET); + assert_eq!(p.next(SHORT_RTT, CWND), n); + p.spend(n, SHORT_RTT, CWND, PACKET); + assert_eq!( + p.next(SHORT_RTT, CWND), + n, + "Expect packet to be sent immediately, instead of being paced below timer granularity." + ); + } } diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 339800d700..09a4e19d26 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -149,15 +149,19 @@ impl PacketBuilder { /// /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get /// the encoder back. - pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self { + pub fn short(mut encoder: Encoder, key_phase: bool, dcid: Option>) -> Self { let mut limit = Self::infer_limit(&encoder); let header_start = encoder.len(); // Check that there is enough space for the header. // 5 = 1 (first byte) + 4 (packet number) - if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() { + if limit > encoder.len() + && 5 + dcid.as_ref().map_or(0, |d| d.as_ref().len()) < limit - encoder.len() + { encoder .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); - encoder.encode(dcid.as_ref()); + if let Some(dcid) = dcid { + encoder.encode(dcid.as_ref()); + } } else { limit = 0; } @@ -185,20 +189,23 @@ impl PacketBuilder { mut encoder: Encoder, pt: PacketType, version: Version, - dcid: impl AsRef<[u8]>, - scid: impl AsRef<[u8]>, + mut dcid: Option>, + mut scid: Option>, ) -> Self { let mut limit = Self::infer_limit(&encoder); let header_start = encoder.len(); // Check that there is enough space for the header. // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number) if limit > encoder.len() - && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len() + && 11 + + dcid.as_ref().map_or(0, |d| d.as_ref().len()) + + scid.as_ref().map_or(0, |d| d.as_ref().len()) + < limit - encoder.len() { encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4); encoder.encode_uint(4, version.wire_version()); - encoder.encode_vec(1, dcid.as_ref()); - encoder.encode_vec(1, scid.as_ref()); + encoder.encode_vec(1, dcid.take().as_ref().map_or(&[], AsRef::as_ref)); + encoder.encode_vec(1, scid.take().as_ref().map_or(&[], AsRef::as_ref)); } else { limit = 0; } @@ -994,8 +1001,8 @@ mod tests { Encoder::new(), PacketType::Initial, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(SERVER_CID), + None::<&[u8]>, + Some(ConnectionId::from(SERVER_CID)), ); builder.initial_token(&[]); builder.pn(1, 2); @@ -1058,7 +1065,7 @@ mod tests { fn build_short() { fixture_init(); let mut builder = - PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID)); + PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID))); builder.pn(0, 1); builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. let packet = builder @@ -1073,7 +1080,7 @@ mod tests { let mut firsts = Vec::new(); for _ in 0..64 { let mut builder = - PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID)); + PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID))); builder.scramble(true); builder.pn(0, 1); firsts.push(builder.as_ref()[0]); @@ -1136,8 +1143,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(SERVER_CID), - ConnectionId::from(CLIENT_CID), + Some(ConnectionId::from(SERVER_CID)), + Some(ConnectionId::from(CLIENT_CID)), ); builder.pn(0, 1); builder.encode(&[0; 3]); @@ -1145,7 +1152,8 @@ mod tests { assert_eq!(encoder.len(), 45); let first = encoder.clone(); - let mut builder = PacketBuilder::short(encoder, false, ConnectionId::from(SERVER_CID)); + let mut builder = + PacketBuilder::short(encoder, false, Some(ConnectionId::from(SERVER_CID))); builder.pn(1, 3); builder.encode(&[0]); // Minimal size (packet number is big enough). let encoder = builder.build(&mut prot).expect("build"); @@ -1170,8 +1178,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(&[][..]), + None::<&[u8]>, + None::<&[u8]>, ); builder.pn(0, 1); builder.encode(&[1, 2, 3]); @@ -1189,8 +1197,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(&[][..]), + None::<&[u8]>, + None::<&[u8]>, ); builder.pn(0, 1); builder.scramble(true); @@ -1210,8 +1218,8 @@ mod tests { Encoder::new(), PacketType::Initial, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(SERVER_CID), + None::<&[u8]>, + Some(ConnectionId::from(SERVER_CID)), ); assert_ne!(builder.remaining(), 0); builder.initial_token(&[]); @@ -1229,7 +1237,7 @@ mod tests { let mut builder = PacketBuilder::short( Encoder::with_capacity(100), true, - ConnectionId::from(SERVER_CID), + Some(ConnectionId::from(SERVER_CID)), ); builder.pn(0, 1); // Pad, but not up to the full capacity. Leave enough space for the @@ -1244,8 +1252,8 @@ mod tests { encoder, PacketType::Initial, Version::default(), - ConnectionId::from(SERVER_CID), - ConnectionId::from(SERVER_CID), + Some(ConnectionId::from(SERVER_CID)), + Some(ConnectionId::from(SERVER_CID)), ); assert_eq!(builder.remaining(), 0); assert_eq!(builder.abort(), encoder_copy); diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index 83a45ba9f4..3da334770b 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -660,8 +660,8 @@ impl Path { /// Get the first local connection ID. /// Only do this for the primary path during the handshake. - pub fn local_cid(&self) -> &ConnectionId { - self.local_cid.as_ref().unwrap() + pub const fn local_cid(&self) -> Option<&ConnectionId> { + self.local_cid.as_ref() } /// Set the remote connection ID based on the peer's choice. @@ -674,8 +674,10 @@ impl Path { } /// Access the remote connection ID. - pub fn remote_cid(&self) -> &ConnectionId { - self.remote_cid.as_ref().unwrap().connection_id() + pub fn remote_cid(&self) -> Option<&ConnectionId> { + self.remote_cid + .as_ref() + .map(super::cid::ConnectionIdEntry::connection_id) } /// Set the stateless reset token for the connection ID that is currently in use. diff --git a/neqo-transport/src/pmtud.rs b/neqo-transport/src/pmtud.rs index 5ee59e3dbf..9eec6b0eda 100644 --- a/neqo-transport/src/pmtud.rs +++ b/neqo-transport/src/pmtud.rs @@ -383,7 +383,7 @@ mod tests { let stats_before = stats.clone(); // Fake a packet number, so the builder logic works. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let pn = prot.next_pn(); builder.pn(pn, 4); builder.set_initial_limit(&SendProfile::new_limited(pmtud.plpmtu()), 16, pmtud); diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index 29f17bf6b9..fa127212f0 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -104,8 +104,8 @@ fn connection_started(qlog: &NeqoQlog, path: &PathRef) { protocol: Some("QUIC".into()), src_port: p.local_address().port().into(), dst_port: p.remote_address().port().into(), - src_cid: Some(format!("{}", p.local_cid())), - dst_cid: Some(format!("{}", p.remote_cid())), + src_cid: p.local_cid().map(ToString::to_string), + dst_cid: p.remote_cid().map(ToString::to_string), }); Some(ev_data) diff --git a/neqo-transport/src/quic_datagrams.rs b/neqo-transport/src/quic_datagrams.rs index af82d8124e..241ce30389 100644 --- a/neqo-transport/src/quic_datagrams.rs +++ b/neqo-transport/src/quic_datagrams.rs @@ -154,12 +154,15 @@ impl QuicDatagrams { tracking: DatagramTracking, stats: &mut Stats, ) -> Res<()> { - if u64::try_from(buf.len()).unwrap() > self.remote_datagram_size { + if u64::try_from(buf.len())? > self.remote_datagram_size { return Err(Error::TooMuchData); } if self.datagrams.len() == self.max_queued_outgoing_datagrams { self.conn_events.datagram_outcome( - self.datagrams.pop_front().unwrap().tracking(), + self.datagrams + .pop_front() + .ok_or(Error::InternalError)? + .tracking(), OutgoingDatagramOutcome::DroppedQueueFull, ); stats.datagram_tx.dropped_queue_full += 1; @@ -172,7 +175,7 @@ impl QuicDatagrams { } pub fn handle_datagram(&self, data: &[u8], stats: &mut Stats) -> Res<()> { - if self.local_datagram_size < u64::try_from(data.len()).unwrap() { + if self.local_datagram_size < u64::try_from(data.len())? { return Err(Error::ProtocolViolation); } self.conn_events diff --git a/neqo-transport/src/recv_stream.rs b/neqo-transport/src/recv_stream.rs index c022f5fbd0..7b46a386bc 100644 --- a/neqo-transport/src/recv_stream.rs +++ b/neqo-transport/src/recv_stream.rs @@ -643,7 +643,7 @@ impl RecvStream { // We should post a DataReadable event only once when we change from no-data-ready to // data-ready. Therefore remember the state before processing a new frame. let already_data_ready = self.data_ready(); - let new_end = offset + u64::try_from(data.len()).unwrap(); + let new_end = offset + u64::try_from(data.len())?; self.state.flow_control_consume_data(new_end, fin)?; @@ -1483,7 +1483,7 @@ mod tests { assert!(s.has_frames_to_write()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); s.write_frame(&mut builder, &mut token, &mut FrameStats::default()); @@ -1597,7 +1597,7 @@ mod tests { s.read(&mut buf).unwrap(); assert!(session_fc.borrow().frame_needed()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); session_fc .borrow_mut() @@ -1618,7 +1618,7 @@ mod tests { s.read(&mut buf).unwrap(); assert!(session_fc.borrow().frame_needed()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); session_fc .borrow_mut() @@ -1866,7 +1866,7 @@ mod tests { assert!(s.fc().unwrap().frame_needed()); // Write the fc update frame - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); let mut stats = FrameStats::default(); fc.borrow_mut() diff --git a/neqo-transport/src/send_stream.rs b/neqo-transport/src/send_stream.rs index a6e42cfdaf..3f0002da13 100644 --- a/neqo-transport/src/send_stream.rs +++ b/neqo-transport/src/send_stream.rs @@ -2596,7 +2596,7 @@ mod tests { ss.insert(StreamId::from(0), s); let mut tokens = Vec::new(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // Write a small frame: no fin. let written = builder.len(); @@ -2684,7 +2684,7 @@ mod tests { ss.insert(StreamId::from(0), s); let mut tokens = Vec::new(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); ss.write_frames( TransmissionPriority::default(), &mut builder, @@ -2762,7 +2762,7 @@ mod tests { assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..]))); // This doesn't report blocking yet. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_blocked_frame( @@ -2815,7 +2815,7 @@ mod tests { assert_eq!(s.send_atomic(b"abc").unwrap(), 0); // Assert that STREAM_DATA_BLOCKED is sent. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_blocked_frame( @@ -2902,7 +2902,7 @@ mod tests { s.mark_as_lost(len_u64, 0, true); // No frame should be sent here. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_stream_frame( @@ -2962,7 +2962,7 @@ mod tests { s.close(); } - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let header_len = builder.len(); builder.set_limit(header_len + space); @@ -3063,7 +3063,7 @@ mod tests { s.send(data).unwrap(); s.close(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let header_len = builder.len(); // Add 2 for the frame type and stream ID, then add the extra. builder.set_limit(header_len + data.len() + 2 + extra); diff --git a/neqo-transport/src/server.rs b/neqo-transport/src/server.rs index 8cafb4a4b1..e2066f2932 100644 --- a/neqo-transport/src/server.rs +++ b/neqo-transport/src/server.rs @@ -10,7 +10,6 @@ use std::{ cell::RefCell, cmp::min, collections::HashSet, - fs::OpenOptions, ops::{Deref, DerefMut}, path::PathBuf, rc::Rc, @@ -18,14 +17,12 @@ use std::{ }; use neqo_common::{ - self as common, event::Provider, hex, qdebug, qerror, qinfo, qlog::NeqoQlog, qtrace, qwarn, - Datagram, Role, + event::Provider, hex, qdebug, qerror, qinfo, qlog::NeqoQlog, qtrace, qwarn, Datagram, Role, }; use neqo_crypto::{ encode_ech_config, AntiReplay, Cipher, PrivateKey, PublicKey, ZeroRttCheckResult, ZeroRttChecker, }; -use qlog::streamer::QlogStreamer; pub use crate::addr_valid::ValidateAddress; use crate::{ @@ -258,49 +255,17 @@ impl Server { self.qlog_dir .as_ref() .map_or_else(NeqoQlog::disabled, |qlog_dir| { - let mut qlog_path = qlog_dir.clone(); - - qlog_path.push(format!("{odcid}.qlog")); - - // The original DCID is chosen by the client. Using create_new() - // prevents attackers from overwriting existing logs. - match OpenOptions::new() - .write(true) - .create_new(true) - .open(&qlog_path) - { - Ok(f) => { - qinfo!("Qlog output to {}", qlog_path.display()); - - let streamer = QlogStreamer::new( - qlog::QLOG_VERSION.to_string(), - Some("Neqo server qlog".to_string()), - Some("Neqo server qlog".to_string()), - None, - std::time::Instant::now(), - common::qlog::new_trace(Role::Server), - qlog::events::EventImportance::Base, - Box::new(f), - ); - let n_qlog = NeqoQlog::enabled(streamer, qlog_path); - match n_qlog { - Ok(nql) => nql, - Err(e) => { - // Keep going but w/o qlogging - qerror!("NeqoQlog error: {}", e); - NeqoQlog::disabled() - } - } - } - Err(e) => { - qerror!( - "Could not open file {} for qlog output: {}", - qlog_path.display(), - e - ); - NeqoQlog::disabled() - } - } + NeqoQlog::enabled_with_file( + qlog_dir.clone(), + Role::Server, + Some("Neqo server qlog".to_string()), + Some("Neqo server qlog".to_string()), + odcid, + ) + .unwrap_or_else(|e| { + qerror!("failed to create NeqoQlog: {}", e); + NeqoQlog::disabled() + }) }) } diff --git a/neqo-transport/src/tparams.rs b/neqo-transport/src/tparams.rs index 4b83533e3c..ade493cd65 100644 --- a/neqo-transport/src/tparams.rs +++ b/neqo-transport/src/tparams.rs @@ -184,9 +184,8 @@ impl TransportParameter { fn decode_preferred_address(d: &mut Decoder) -> Res { // IPv4 address (maybe) - let v4ip = - Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?).unwrap()); - let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap(); + let v4ip = Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?)?); + let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?)?; // Can't have non-zero IP and zero port, or vice versa. if v4ip.is_unspecified() ^ (v4port == 0) { return Err(Error::TransportParameterError); @@ -198,9 +197,10 @@ impl TransportParameter { }; // IPv6 address (mostly the same as v4) - let v6ip = - Ipv6Addr::from(<[u8; 16]>::try_from(d.decode(16).ok_or(Error::NoMoreData)?).unwrap()); - let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap(); + let v6ip = Ipv6Addr::from(<[u8; 16]>::try_from( + d.decode(16).ok_or(Error::NoMoreData)?, + )?); + let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?)?; if v6ip.is_unspecified() ^ (v6port == 0) { return Err(Error::TransportParameterError); } @@ -222,7 +222,7 @@ impl TransportParameter { // Stateless reset token let srtbuf = d.decode(16).ok_or(Error::NoMoreData)?; - let srt = <[u8; 16]>::try_from(srtbuf).unwrap(); + let srt = <[u8; 16]>::try_from(srtbuf)?; Ok(Self::PreferredAddress { v4, v6, cid, srt }) } diff --git a/neqo-transport/src/tracking.rs b/neqo-transport/src/tracking.rs index 90bbd0b54a..b7ab8bac50 100644 --- a/neqo-transport/src/tracking.rs +++ b/neqo-transport/src/tracking.rs @@ -797,7 +797,7 @@ mod tests { } fn write_frame_at(rp: &mut RecvdPackets, now: Instant) { - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut stats = FrameStats::default(); let mut tokens = Vec::new(); rp.write_frame(now, RTT, &mut builder, &mut tokens, &mut stats); @@ -952,7 +952,7 @@ mod tests { #[test] fn drop_spaces() { let mut tracker = AckTracker::default(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); tracker .get_mut(PacketNumberSpace::Initial) .unwrap() @@ -1017,7 +1017,7 @@ mod tests { .ack_time(now().checked_sub(Duration::from_millis(1)).unwrap()) .is_some()); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); builder.set_limit(10); let mut stats = FrameStats::default(); @@ -1048,7 +1048,7 @@ mod tests { .ack_time(now().checked_sub(Duration::from_millis(1)).unwrap()) .is_some()); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // The code pessimistically assumes that each range needs 16 bytes to express. // So this won't be enough for a second range. builder.set_limit(RecvdPackets::USEFUL_ACK_LEN + 8); diff --git a/neqo-udp/src/lib.rs b/neqo-udp/src/lib.rs index 688fb8ff65..5f1fb3dbe6 100644 --- a/neqo-udp/src/lib.rs +++ b/neqo-udp/src/lib.rs @@ -13,7 +13,7 @@ use std::{ slice, }; -use neqo_common::{qtrace, Datagram, IpTos}; +use neqo_common::{qdebug, qtrace, Datagram, IpTos}; use quinn_udp::{EcnCodepoint, RecvMeta, Transmit, UdpSocketState}; /// Socket receive buffer size. @@ -52,22 +52,44 @@ pub fn send_inner( Ok(()) } +#[cfg(unix)] +use std::os::fd::AsFd as SocketRef; +#[cfg(windows)] +use std::os::windows::io::AsSocket as SocketRef; + pub fn recv_inner( local_address: &SocketAddr, state: &UdpSocketState, - socket: quinn_udp::UdpSockRef<'_>, + socket: impl SocketRef, ) -> Result, io::Error> { let dgrams = RECV_BUF.with_borrow_mut(|recv_buf| -> Result, io::Error> { - let mut meta = RecvMeta::default(); + let mut meta; + + loop { + meta = RecvMeta::default(); + + state.recv( + (&socket).into(), + &mut [IoSliceMut::new(recv_buf)], + slice::from_mut(&mut meta), + )?; + + if meta.len == 0 || meta.stride == 0 { + qdebug!( + "ignoring datagram from {} to {} len {} stride {}", + meta.addr, + local_address, + meta.len, + meta.stride + ); + continue; + } - state.recv( - socket, - &mut [IoSliceMut::new(recv_buf)], - slice::from_mut(&mut meta), - )?; + break; + } Ok(recv_buf[0..meta.len] - .chunks(meta.stride.min(recv_buf.len())) + .chunks(meta.stride) .map(|d| { qtrace!( "received {} bytes from {} to {}", @@ -100,9 +122,7 @@ pub struct Socket { inner: S, } -impl<#[cfg(unix)] S: std::os::fd::AsFd, #[cfg(windows)] S: std::os::windows::io::AsSocket> - Socket -{ +impl Socket { /// Create a new [`Socket`] given a raw file descriptor managed externally. pub fn new(socket: S) -> Result { Ok(Self { @@ -119,7 +139,7 @@ impl<#[cfg(unix)] S: std::os::fd::AsFd, #[cfg(windows)] S: std::os::windows::io: /// Receive a batch of [`Datagram`]s on the given [`Socket`], each /// set with the provided local address. pub fn recv(&self, local_address: &SocketAddr) -> Result, io::Error> { - recv_inner(local_address, &self.state, (&self.inner).into()) + recv_inner(local_address, &self.state, &self.inner) } } @@ -136,6 +156,26 @@ mod tests { Ok(socket) } + #[test] + fn ignore_empty_datagram() -> Result<(), io::Error> { + let sender = socket()?; + let receiver = Socket::new(std::net::UdpSocket::bind("127.0.0.1:0")?)?; + let receiver_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + + let datagram = Datagram::new( + sender.inner.local_addr()?, + receiver.inner.local_addr()?, + IpTos::default(), + vec![], + ); + + sender.send(&datagram)?; + let res = receiver.recv(&receiver_addr); + assert_eq!(res.unwrap_err().kind(), std::io::ErrorKind::WouldBlock); + + Ok(()) + } + #[test] fn datagram_tos() -> Result<(), io::Error> { let sender = socket()?; diff --git a/test-fixture/src/lib.rs b/test-fixture/src/lib.rs index a39e20c1b4..005097cd25 100644 --- a/test-fixture/src/lib.rs +++ b/test-fixture/src/lib.rs @@ -13,6 +13,7 @@ use std::{ io::{Cursor, Result, Write}, mem, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + path::PathBuf, rc::Rc, sync::{Arc, Mutex}, time::{Duration, Instant}, @@ -410,7 +411,7 @@ pub fn new_neqo_qlog() -> (NeqoQlog, SharedVec) { EventImportance::Base, Box::new(buf), ); - let log = NeqoQlog::enabled(streamer, ""); + let log = NeqoQlog::enabled(streamer, PathBuf::from("")); (log.expect("to be able to write to new log"), contents) }