Skip to content

Commit

Permalink
Separate uri.rs into module folder
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Jul 2, 2024
1 parent b7cf5e8 commit 22afe85
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 46 deletions.
40 changes: 40 additions & 0 deletions payjoin/src/uri/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#[cfg(feature = "v2")]
use crate::v2::ParseOhttpKeysError;

#[derive(Debug)]
pub struct PjParseError(InternalPjParseError);

#[derive(Debug)]
pub(crate) enum InternalPjParseError {
BadPjOs,
MultipleParams(&'static str),
MissingEndpoint,
NotUtf8,
BadEndpoint,
#[cfg(feature = "v2")]
BadOhttpKeys(ParseOhttpKeysError),
UnsecureEndpoint,
}

impl From<InternalPjParseError> for PjParseError {
fn from(value: InternalPjParseError) -> Self { PjParseError(value) }
}

impl std::fmt::Display for PjParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
InternalPjParseError::BadPjOs => write!(f, "Bad pjos parameter"),
InternalPjParseError::MultipleParams(param) => {
write!(f, "Multiple instances of parameter '{}'", param)
}
InternalPjParseError::MissingEndpoint => write!(f, "Missing payjoin endpoint"),
InternalPjParseError::NotUtf8 => write!(f, "Endpoint is not valid UTF-8"),
InternalPjParseError::BadEndpoint => write!(f, "Endpoint is not valid"),
#[cfg(feature = "v2")]
InternalPjParseError::BadOhttpKeys(e) => write!(f, "OHTTP keys are not valid: {}", e),
InternalPjParseError::UnsecureEndpoint => {
write!(f, "Endpoint scheme is not secure (https or onion)")
}
}
}
}
56 changes: 10 additions & 46 deletions payjoin/src/uri.rs → payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::borrow::Cow;
#[cfg(feature = "v2")]
use std::str::FromStr;

use bitcoin::address::NetworkChecked;
use bitcoin::{Address, Amount};
pub use error::PjParseError;
use url::Url;

#[cfg(feature = "v2")]
use crate::v2::ParseOhttpKeysError;
use crate::uri::error::InternalPjParseError;
#[cfg(feature = "v2")]
use crate::OhttpKeys;

pub mod error;

#[derive(Clone)]
pub enum MaybePayjoinExtras {
Supported(PayjoinExtras),
Expand Down Expand Up @@ -184,13 +184,6 @@ pub struct DeserializationState {
ohttp: Option<OhttpKeys>,
}

#[derive(Debug)]
pub struct PjParseError(InternalPjParseError);

impl From<InternalPjParseError> for PjParseError {
fn from(value: InternalPjParseError) -> Self { PjParseError(value) }
}

impl<'a> bip21::SerializeParams for &'a MaybePayjoinExtras {
type Key = &'static str;
type Value = String;
Expand Down Expand Up @@ -241,15 +234,17 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
match key {
#[cfg(feature = "v2")]
"ohttp" if self.ohttp.is_none() => {
use std::str::FromStr;

let base64_config =
Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let config = OhttpKeys::from_str(&base64_config)
.map_err(InternalPjParseError::ParseOhttpKeys)?;
.map_err(InternalPjParseError::BadOhttpKeys)?;
self.ohttp = Some(config);
Ok(bip21::de::ParamKind::Known)
}
#[cfg(feature = "v2")]
"ohttp" => Err(PjParseError(InternalPjParseError::MultipleParams("ohttp"))),
"ohttp" => Err(InternalPjParseError::MultipleParams("ohttp").into()),
"pj" if self.pj.is_none() => {
let endpoint = Cow::try_from(value).map_err(|_| InternalPjParseError::NotUtf8)?;
let url = Url::parse(&endpoint).map_err(|_| InternalPjParseError::BadEndpoint)?;
Expand All @@ -276,7 +271,7 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
) -> std::result::Result<Self::Value, <Self::Value as bip21::DeserializationError>::Error> {
match (self.pj, self.pjos) {
(None, None) => Ok(MaybePayjoinExtras::Unsupported),
(None, Some(_)) => Err(PjParseError(InternalPjParseError::MissingEndpoint)),
(None, Some(_)) => Err(InternalPjParseError::MissingEndpoint.into()),
(Some(endpoint), pjos) => {
if endpoint.scheme() == "https"
|| endpoint.scheme() == "http"
Expand All @@ -289,44 +284,13 @@ impl<'a> bip21::de::DeserializationState<'a> for DeserializationState {
ohttp_keys: self.ohttp,
}))
} else {
Err(PjParseError(InternalPjParseError::UnsecureEndpoint))
Err(InternalPjParseError::UnsecureEndpoint.into())
}
}
}
}
}

impl std::fmt::Display for PjParseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
InternalPjParseError::BadPjOs => write!(f, "Bad pjos parameter"),
InternalPjParseError::MultipleParams(param) => {
write!(f, "Multiple instances of parameter '{}'", param)
}
InternalPjParseError::MissingEndpoint => write!(f, "Missing payjoin endpoint"),
InternalPjParseError::NotUtf8 => write!(f, "Endpoint is not valid UTF-8"),
InternalPjParseError::BadEndpoint => write!(f, "Endpoint is not valid"),
#[cfg(feature = "v2")]
InternalPjParseError::ParseOhttpKeys(e) => write!(f, "OHTTP keys are not valid: {}", e),
InternalPjParseError::UnsecureEndpoint => {
write!(f, "Endpoint scheme is not secure (https or onion)")
}
}
}
}

#[derive(Debug)]
enum InternalPjParseError {
BadPjOs,
MultipleParams(&'static str),
MissingEndpoint,
NotUtf8,
BadEndpoint,
#[cfg(feature = "v2")]
ParseOhttpKeys(ParseOhttpKeysError),
UnsecureEndpoint,
}

#[cfg(test)]
mod tests {
use std::convert::TryFrom;
Expand Down
43 changes: 43 additions & 0 deletions payjoin/src/uri/pj_url.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use url::Url;

pub struct PjUrl {
url: Url,
ohttp: Option<String>,
}

impl PjUrl {
pub fn new(url: Url) -> Self {
let (url, ohttp) = Self::extract_ohttp(url);
PjUrl { url, ohttp }
}

fn extract_ohttp(mut url: Url) -> (Url, Option<String>) {
let fragment = &mut url.fragment().and_then(|f| {
let parts: Vec<&str> = f.splitn(2, "ohttp=").collect();
if parts.len() == 2 {
Some((parts[0].trim_end_matches('&'), parts[1].to_string()))
} else {
None
}
});

if let Some((remaining_fragment, ohttp)) = fragment {
url.set_fragment(Some(remaining_fragment));
(url, Some(ohttp))
} else {
(url, None)
}
}

pub fn into_url(self) -> Url {
let mut url = self.url;
if let Some(ohttp) = self.ohttp {
let fragment = url
.fragment()
.map(|f| format!("{}&ohttp={}", f, ohttp))
.unwrap_or_else(|| format!("ohttp={}", ohttp));
url.set_fragment(Some(&fragment));
}
url
}
}

0 comments on commit 22afe85

Please sign in to comment.