From 9dbdd97fa74e481eb366d0e06b73be3b86bf45f1 Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 25 Jun 2024 12:43:34 -0400 Subject: [PATCH] Handle fragments with uri::UrlExt trait This extension trait defines functions to parse and set the ohttp parameter in the fragment of a `pj=` URL. Close #298 --- payjoin/src/send/error.rs | 54 ++++++++++++++++++++++----- payjoin/src/send/mod.rs | 65 +++++++++++---------------------- payjoin/src/uri/error.rs | 7 ---- payjoin/src/uri/mod.rs | 70 ++++++++++------------------------- payjoin/src/uri/pj_url.rs | 43 ---------------------- payjoin/src/uri/url_ext.rs | 75 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 153 deletions(-) delete mode 100644 payjoin/src/uri/pj_url.rs create mode 100644 payjoin/src/uri/url_ext.rs diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index fad708d9..ddaab8cb 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -192,9 +192,7 @@ pub(crate) enum InternalCreateRequestError { #[cfg(feature = "v2")] OhttpEncapsulation(crate::v2::OhttpEncapsulationError), #[cfg(feature = "v2")] - SubdirectoryNotBase64(bitcoin::base64::DecodeError), - #[cfg(feature = "v2")] - SubdirectoryInvalidPubkey(bitcoin::secp256k1::Error), + ParseSubdirectory(ParseSubdirectoryError), #[cfg(feature = "v2")] MissingOhttpConfig, } @@ -223,9 +221,7 @@ impl fmt::Display for CreateRequestError { #[cfg(feature = "v2")] OhttpEncapsulation(e) => write!(f, "v2 error: {}", e), #[cfg(feature = "v2")] - SubdirectoryNotBase64(e) => write!(f, "subdirectory is not valid base64 error: {}", e), - #[cfg(feature = "v2")] - SubdirectoryInvalidPubkey(e) => write!(f, "subdirectory does not represent a valid pubkey: {}", e), + ParseSubdirectory(e) => write!(f, "cannot parse subdirectory: {}", e), #[cfg(feature = "v2")] MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"), } @@ -256,9 +252,7 @@ impl std::error::Error for CreateRequestError { #[cfg(feature = "v2")] OhttpEncapsulation(error) => Some(error), #[cfg(feature = "v2")] - SubdirectoryNotBase64(error) => Some(error), - #[cfg(feature = "v2")] - SubdirectoryInvalidPubkey(error) => Some(error), + ParseSubdirectory(error) => Some(error), #[cfg(feature = "v2")] MissingOhttpConfig => None, } @@ -269,6 +263,48 @@ impl From for CreateRequestError { fn from(value: InternalCreateRequestError) -> Self { CreateRequestError(value) } } +#[cfg(feature = "v2")] +impl From for CreateRequestError { + fn from(value: ParseSubdirectoryError) -> Self { + CreateRequestError(InternalCreateRequestError::ParseSubdirectory(value)) + } +} + +#[cfg(feature = "v2")] +#[derive(Debug)] +pub(crate) enum ParseSubdirectoryError { + MissingSubdirectory, + SubdirectoryNotBase64(bitcoin::base64::DecodeError), + SubdirectoryInvalidPubkey(bitcoin::secp256k1::Error), +} + +#[cfg(feature = "v2")] +impl std::fmt::Display for ParseSubdirectoryError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + use ParseSubdirectoryError::*; + + match &self { + MissingSubdirectory => write!(f, "subdirectory is missing"), + SubdirectoryNotBase64(e) => write!(f, "subdirectory is not valid base64: {}", e), + SubdirectoryInvalidPubkey(e) => + write!(f, "subdirectory does not represent a valid pubkey: {}", e), + } + } +} + +#[cfg(feature = "v2")] +impl std::error::Error for ParseSubdirectoryError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + use ParseSubdirectoryError::*; + + match &self { + MissingSubdirectory => None, + SubdirectoryNotBase64(error) => Some(error), + SubdirectoryInvalidPubkey(error) => Some(error), + } + } +} + /// Represent an error returned by Payjoin receiver. pub enum ResponseError { /// `WellKnown` Errors are defined in the [`BIP78::ReceiverWellKnownError`] spec. diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 1f247bfc..3bff4714 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -202,8 +202,6 @@ impl<'a> RequestBuilder<'a> { psbt.validate_input_utxos(true) .map_err(InternalCreateRequestError::InvalidOriginalInput)?; let endpoint = self.uri.extras.endpoint.clone(); - #[cfg(feature = "v2")] - let ohttp_keys = self.uri.extras.ohttp_keys; let disable_output_substitution = self.uri.extras.disable_output_substitution || self.disable_output_substitution; let payee = self.uri.address.script_pubkey(); @@ -234,8 +232,6 @@ impl<'a> RequestBuilder<'a> { Ok(RequestContext { psbt, endpoint, - #[cfg(feature = "v2")] - ohttp_keys, disable_output_substitution, fee_contribution, payee, @@ -252,8 +248,6 @@ impl<'a> RequestBuilder<'a> { pub struct RequestContext { psbt: Psbt, endpoint: Url, - #[cfg(feature = "v2")] - ohttp_keys: Option, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, @@ -271,7 +265,7 @@ impl RequestContext { /// Extract serialized V1 Request and Context froma Payjoin Proposal pub fn extract_v1(self) -> Result<(Request, ContextV1), CreateRequestError> { let url = serialize_url( - self.endpoint.into(), + self.endpoint, self.disable_output_substitution, self.fee_contribution, self.min_fee_rate, @@ -303,6 +297,7 @@ impl RequestContext { &mut self, ohttp_relay: Url, ) -> Result<(Request, ContextV2), CreateRequestError> { + use crate::uri::UrlExt; let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -314,7 +309,7 @@ impl RequestContext { let body = crate::v2::encrypt_message_a(body, self.e, rs) .map_err(InternalCreateRequestError::Hpke)?; let (body, ohttp_res) = crate::v2::ohttp_encapsulate( - self.ohttp_keys.as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?, + self.endpoint.ohttp().as_mut().ok_or(InternalCreateRequestError::MissingOhttpConfig)?, "POST", url.as_str(), Some(&body), @@ -342,33 +337,22 @@ impl RequestContext { #[cfg(feature = "v2")] fn rs_pubkey_from_dir_endpoint(endpoint: &Url) -> Result { - let path_and_query: String; - - if let Some(pos) = endpoint.as_str().rfind('/') { - path_and_query = endpoint.as_str()[pos + 1..].to_string(); - } else { - path_and_query = endpoint.to_string(); - } - - let subdirectory: String; - - if let Some(pos) = path_and_query.find('?') { - subdirectory = path_and_query[..pos].to_string(); - } else { - subdirectory = path_and_query; - } - - let pubkey_bytes = - bitcoin::base64::decode_config(subdirectory, bitcoin::base64::URL_SAFE_NO_PAD) - .map_err(InternalCreateRequestError::SubdirectoryNotBase64)?; - Ok(bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes) - .map_err(InternalCreateRequestError::SubdirectoryInvalidPubkey)?) - } - - #[cfg(feature = "v2")] - pub fn public_key(&self) -> PublicKey { - let secp = bitcoin::secp256k1::Secp256k1::new(); - self.e.public_key(&secp) + use bitcoin::base64; + + use crate::send::error::ParseSubdirectoryError; + + let subdirectory = endpoint + .path_segments() + .ok_or(ParseSubdirectoryError::MissingSubdirectory)? + .next() + .ok_or(ParseSubdirectoryError::MissingSubdirectory)? + .to_string(); + + let pubkey_bytes = base64::decode_config(subdirectory, base64::URL_SAFE_NO_PAD) + .map_err(ParseSubdirectoryError::SubdirectoryNotBase64)?; + bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes) + .map_err(ParseSubdirectoryError::SubdirectoryInvalidPubkey) + .map_err(CreateRequestError::from) } pub fn endpoint(&self) -> &Url { &self.endpoint } @@ -383,7 +367,6 @@ impl Serialize for RequestContext { let mut state = serializer.serialize_struct("RequestContext", 8)?; state.serialize_field("psbt", &self.psbt.to_string())?; state.serialize_field("endpoint", &self.endpoint.as_str())?; - state.serialize_field("ohttp_keys", &self.ohttp_keys)?; state.serialize_field("disable_output_substitution", &self.disable_output_substitution)?; state.serialize_field( "fee_contribution", @@ -432,7 +415,6 @@ impl<'de> Deserialize<'de> for RequestContext { { let mut psbt = None; let mut endpoint = None; - let mut ohttp_keys = None; let mut disable_output_substitution = None; let mut fee_contribution = None; let mut min_fee_rate = None; @@ -452,7 +434,6 @@ impl<'de> Deserialize<'de> for RequestContext { url::Url::from_str(&map.next_value::()?) .map_err(de::Error::custom)?, ), - "ohttp_keys" => ohttp_keys = Some(map.next_value()?), "disable_output_substitution" => disable_output_substitution = Some(map.next_value()?), "fee_contribution" => { @@ -478,7 +459,6 @@ impl<'de> Deserialize<'de> for RequestContext { Ok(RequestContext { psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?, endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?, - ohttp_keys: ohttp_keys.ok_or_else(|| de::Error::missing_field("ohttp_keys"))?, disable_output_substitution: disable_output_substitution .ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?, fee_contribution, @@ -974,7 +954,7 @@ fn serialize_v2_body( ) -> Result, CreateRequestError> { // Grug say localhost base be discarded anyway. no big brain needed. let placeholder_url = serialize_url( - "http:/localhost".to_string(), + Url::parse("http://localhost").unwrap(), disable_output_substitution, fee_contribution, min_feerate, @@ -986,12 +966,12 @@ fn serialize_v2_body( } fn serialize_url( - endpoint: String, + endpoint: Url, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, ) -> Result { - let mut url = Url::parse(&endpoint)?; + let mut url = endpoint; url.query_pairs_mut().append_pair("v", "1"); if disable_output_substitution { url.query_pairs_mut().append_pair("disableoutputsubstitution", "1"); @@ -1065,7 +1045,6 @@ mod test { let req_ctx = RequestContext { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), endpoint: Url::parse("http://localhost:1234").unwrap(), - ohttp_keys: None, disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, diff --git a/payjoin/src/uri/error.rs b/payjoin/src/uri/error.rs index eb1636e5..03732db3 100644 --- a/payjoin/src/uri/error.rs +++ b/payjoin/src/uri/error.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "v2")] -use crate::v2::ParseOhttpKeysError; - #[derive(Debug)] pub struct PjParseError(InternalPjParseError); @@ -11,8 +8,6 @@ pub(crate) enum InternalPjParseError { MissingEndpoint, NotUtf8, BadEndpoint, - #[cfg(feature = "v2")] - BadOhttpKeys(ParseOhttpKeysError), UnsecureEndpoint, } @@ -30,8 +25,6 @@ impl std::fmt::Display for PjParseError { InternalPjParseError::MissingEndpoint => write!(f, "Missing payjoin endpoint"), InternalPjParseError::NotUtf8 => write!(f, "Endpoint is not valid UTF-8"), InternalPjParseError::BadEndpoint => write!(f, "Endpoint is not valid"), - #[cfg(feature = "v2")] - InternalPjParseError::BadOhttpKeys(e) => write!(f, "OHTTP keys are not valid: {}", e), InternalPjParseError::UnsecureEndpoint => { write!(f, "Endpoint scheme is not secure (https or onion)") } diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 89d96b5c..79ee13d9 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -7,9 +7,13 @@ use url::Url; use crate::uri::error::InternalPjParseError; #[cfg(feature = "v2")] +pub(crate) use crate::uri::url_ext::UrlExt; +#[cfg(feature = "v2")] use crate::OhttpKeys; pub mod error; +#[cfg(feature = "v2")] +pub(crate) mod url_ext; #[derive(Clone)] pub enum MaybePayjoinExtras { @@ -30,8 +34,6 @@ impl MaybePayjoinExtras { pub struct PayjoinExtras { pub(crate) endpoint: Url, pub(crate) disable_output_substitution: bool, - #[cfg(feature = "v2")] - pub(crate) ohttp_keys: Option, } impl PayjoinExtras { @@ -96,30 +98,25 @@ pub struct PjUriBuilder { pj: Url, /// Whether or not payjoin output substitution is allowed pjos: bool, - #[cfg(feature = "v2")] - /// Config for ohttp. - /// - /// Required only for v2 payjoin. - ohttp: Option, } impl PjUriBuilder { /// Create a new `PjUriBuilder` with required parameters. + /// + /// ## Parameters + /// - `address`: Represents a bitcoin address. + /// - `origin`: Represents either the payjoin endpoint in v1 or the directory in v2. + /// - `ohttp_keys`: Optional OHTTP keys for v2 (only available if the "v2" feature is enabled). pub fn new( address: Address, - pj: Url, + origin: Url, #[cfg(feature = "v2")] ohttp_keys: Option, ) -> Self { - Self { - address, - amount: None, - message: None, - label: None, - pj, - pjos: false, - #[cfg(feature = "v2")] - ohttp: ohttp_keys, - } + #[allow(unused_mut)] + let mut pj = origin; + #[cfg(feature = "v2")] + pj.set_ohttp(ohttp_keys); + Self { address, amount: None, message: None, label: None, pj, pjos: false } } /// Set the amount you want to receive. pub fn amount(mut self, amount: Amount) -> Self { @@ -150,12 +147,7 @@ impl PjUriBuilder { /// Constructs a `bip21::Uri` with PayjoinParams from the /// parameters set in the builder. pub fn build<'a>(self) -> PjUri<'a> { - let extras = PayjoinExtras { - endpoint: self.pj, - disable_output_substitution: self.pjos, - #[cfg(feature = "v2")] - ohttp_keys: self.ohttp, - }; + let extras = PayjoinExtras { endpoint: self.pj, disable_output_substitution: self.pjos }; let mut pj_uri = bip21::Uri::with_extras(self.address, extras); pj_uri.amount = self.amount; pj_uri.label = self.label.map(Into::into); @@ -180,8 +172,6 @@ impl<'a> bip21::de::DeserializeParams<'a> for MaybePayjoinExtras { pub struct DeserializationState { pj: Option, pjos: Option, - #[cfg(feature = "v2")] - ohttp: Option, } impl<'a> bip21::SerializeParams for &'a MaybePayjoinExtras { @@ -203,18 +193,11 @@ impl<'a> bip21::SerializeParams for &'a PayjoinExtras { type Iterator = std::vec::IntoIter<(Self::Key, Self::Value)>; fn serialize_params(self) -> Self::Iterator { - #[allow(unused_mut)] - let mut params = vec![ + vec![ ("pj", self.endpoint.as_str().to_string()), ("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()), - ]; - #[cfg(feature = "v2")] - if let Some(ohttp_keys) = &self.ohttp_keys { - params.push(("ohttp", ohttp_keys.to_string())); - } else { - log::warn!("Failed to encode ohttp config, ignoring"); - } - params.into_iter() + ] + .into_iter() } } @@ -232,19 +215,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { ::Error, > { match key { - #[cfg(feature = "v2")] - "ohttp" if self.ohttp.is_none() => { - use std::str::FromStr; - - let base64_config = - Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; - let config = OhttpKeys::from_str(&base64_config) - .map_err(InternalPjParseError::BadOhttpKeys)?; - self.ohttp = Some(config); - Ok(bip21::de::ParamKind::Known) - } - #[cfg(feature = "v2")] - "ohttp" => Err(InternalPjParseError::DuplicateParams("ohttp").into()), "pj" if self.pj.is_none() => { let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?; let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?; @@ -280,8 +250,6 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState { Ok(MaybePayjoinExtras::Supported(PayjoinExtras { endpoint, disable_output_substitution: pjos.unwrap_or(false), - #[cfg(feature = "v2")] - ohttp_keys: self.ohttp, })) } else { Err(InternalPjParseError::UnsecureEndpoint.into()) diff --git a/payjoin/src/uri/pj_url.rs b/payjoin/src/uri/pj_url.rs deleted file mode 100644 index 199c922a..00000000 --- a/payjoin/src/uri/pj_url.rs +++ /dev/null @@ -1,43 +0,0 @@ -use url::Url; - -pub struct PjUrl { - url: Url, - ohttp: Option, -} - -impl PjUrl { - pub fn new(url: Url) -> Self { - let (url, ohttp) = Self::extract_ohttp(url); - PjUrl { url, ohttp } - } - - fn extract_ohttp(mut url: Url) -> (Url, Option) { - let fragment = &mut url.fragment().and_then(|f| { - let parts: Vec<&str> = f.splitn(2, "ohttp=").collect(); - if parts.len() == 2 { - Some((parts[0].trim_end_matches('&'), parts[1].to_string())) - } else { - None - } - }); - - if let Some((remaining_fragment, ohttp)) = fragment { - url.set_fragment(Some(remaining_fragment)); - (url, Some(ohttp)) - } else { - (url, None) - } - } - - pub fn into_url(self) -> Url { - let mut url = self.url; - if let Some(ohttp) = self.ohttp { - let fragment = url - .fragment() - .map(|f| format!("{}&ohttp={}", f, ohttp)) - .unwrap_or_else(|| format!("ohttp={}", ohttp)); - url.set_fragment(Some(&fragment)); - } - url - } -} diff --git a/payjoin/src/uri/url_ext.rs b/payjoin/src/uri/url_ext.rs new file mode 100644 index 00000000..50223074 --- /dev/null +++ b/payjoin/src/uri/url_ext.rs @@ -0,0 +1,75 @@ +use std::borrow::Cow; + +use url::Url; + +use crate::OhttpKeys; + +/// Parse and set fragment parameters from `&pj=` URI parameter URLs +pub(crate) trait UrlExt { + fn ohttp(&self) -> Option; + fn set_ohttp(&mut self, ohttp: Option); +} + +impl UrlExt for Url { + /// Retrieve the ohttp parameter from the URL fragment + fn ohttp(&self) -> Option { + use std::str::FromStr; + self.fragment().and_then(|f| { + for param in f.split('&') { + if let Some(value) = param.strip_prefix("ohttp=") { + let ohttp = Cow::from(value); + return OhttpKeys::from_str(&ohttp).ok(); + } + } + None + }) + } + + /// Set the ohttp parameter in the URL fragment + fn set_ohttp(&mut self, ohttp: Option) { + let mut fragment = self.fragment().unwrap_or("").to_string(); + if let Some(start) = fragment.find("ohttp=") { + let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); + fragment.replace_range(start..end, ""); + if fragment.ends_with('&') { + fragment.pop(); + } + } + if let Some(ohttp) = ohttp { + let new_ohttp = format!("ohttp={}", ohttp); + if !fragment.is_empty() { + fragment.push('&'); + } + fragment.push_str(&new_ohttp); + } + self.set_fragment(if fragment.is_empty() { None } else { Some(&fragment) }); + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use url::Url; + + use super::*; + + #[test] + fn test_ohttp_get_set() { + let mut url = Url::parse("https://example.com").unwrap(); + + let ohttp_keys = + OhttpKeys::from_str("AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM").unwrap(); + url.set_ohttp(Some(ohttp_keys.clone())); + assert_eq!( + url.fragment(), + Some("ohttp=AQAg3WpRjS0aqAxQUoLvpas2VYjT2oIg6-3XSiB-QiYI1BAABAABAAM") + ); + + let retrieved_ohttp = url.ohttp(); + assert_eq!(retrieved_ohttp, Some(ohttp_keys)); + + url.set_ohttp(None); + assert_eq!(url.fragment(), None); + } +}