diff --git a/payjoin/src/uri/url_ext.rs b/payjoin/src/uri/url_ext.rs index d1f42def..21a4cb96 100644 --- a/payjoin/src/uri/url_ext.rs +++ b/payjoin/src/uri/url_ext.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use std::str::FromStr; use percent_encoding::{AsciiSet, PercentDecodeError, CONTROLS}; use url::Url; @@ -19,96 +19,79 @@ const BIP21_CONFLICTING: &AsciiSet = &CONTROLS.add(b'=').add(b'&'); impl UrlExt for Url { /// Retrieve the ohttp parameter from the URL fragment fn ohttp(&self) -> Result, PercentDecodeError> { - use std::str::FromStr; - if let Some(fragment) = self.fragment() { - let decoded_fragment = - percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy(); - for param in decoded_fragment.split('&') { - if let Some(value) = param.strip_prefix("ohttp=") { - let ohttp = Cow::from(value); - return Ok(OhttpKeys::from_str(&ohttp).ok()); - } - } - } - Ok(None) + get_param(self, "ohttp=", |value| OhttpKeys::from_str(value).ok()) } /// Set the ohttp parameter in the URL fragment fn set_ohttp(&mut self, ohttp: Option) -> Result<(), PercentDecodeError> { - let fragment = self.fragment().unwrap_or("").to_string(); - let mut fragment = - percent_encoding::percent_decode_str(&fragment)?.decode_utf8_lossy().to_string(); - if let Some(start) = fragment.find("ohttp=") { - let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); - fragment.replace_range(start..end, ""); - if fragment.ends_with('&') { - fragment.pop(); - } - } - if let Some(ohttp) = ohttp { - let new_ohttp = format!("ohttp={}", ohttp); - if !fragment.is_empty() { - fragment.push('&'); - } - fragment.push_str(&new_ohttp); - } - let encoded_fragment = - percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string(); - self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) }); - Ok(()) + set_param(self, "ohttp=", ohttp.map(|o| o.to_string())) } /// Retrieve the exp parameter from the URL fragment fn exp(&self) -> Result, PercentDecodeError> { - if let Some(fragment) = self.fragment() { - let decoded_fragment = - percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy(); - for param in decoded_fragment.split('&') { - if let Some(value) = param.strip_prefix("exp=") { - if let Ok(timestamp) = value.parse::() { - return Ok(Some( - std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp), - )); - } - } - } - } - Ok(None) + get_param(self, "exp=", |value| { + value + .parse::() + .ok() + .map(|timestamp| std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp)) + }) } /// Set the exp parameter in the URL fragment fn set_exp(&mut self, exp: Option) -> Result<(), PercentDecodeError> { - let fragment = self.fragment().unwrap_or("").to_string(); - let mut fragment = - percent_encoding::percent_decode_str(&fragment)?.decode_utf8_lossy().to_string(); - if let Some(start) = fragment.find("exp=") { - let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); - fragment.replace_range(start..end, ""); - if fragment.ends_with('&') { - fragment.pop(); + let exp_str = exp.map(|e| { + match e.duration_since(std::time::UNIX_EPOCH) { + Ok(duration) => duration.as_secs().to_string(), + Err(_) => "0".to_string(), // Handle times before Unix epoch by setting to "0" } - } - if let Some(exp) = exp { - let timestamp = exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(); - let new_exp = format!("exp={}", timestamp); - if !fragment.is_empty() { - fragment.push('&'); + }); + set_param(self, "exp=", exp_str) + } +} + +fn get_param(url: &Url, prefix: &str, parse: F) -> Result, PercentDecodeError> +where + F: Fn(&str) -> Option, +{ + if let Some(fragment) = url.fragment() { + let decoded_fragment = percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy(); + for param in decoded_fragment.split('&') { + if let Some(value) = param.strip_prefix(prefix) { + return Ok(parse(value)); } - fragment.push_str(&new_exp); } - let encoded_fragment = - percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string(); - self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) }); - Ok(()) } + Ok(None) } -#[cfg(test)] -mod tests { - use std::str::FromStr; +fn set_param(url: &mut Url, prefix: &str, value: Option) -> Result<(), PercentDecodeError> { + let fragment = url.fragment().unwrap_or(""); + let mut fragment = percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy(); - use url::Url; + if let Some(start) = fragment.find(prefix) { + let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i); + fragment.to_mut().replace_range(start..end, ""); + if fragment.ends_with('&') { + fragment.to_mut().pop(); + } + } + if let Some(value) = value { + let new_param = format!("{}{}", prefix, value); + if !fragment.is_empty() { + fragment.to_mut().push('&'); + } + fragment.to_mut().push_str(&new_param); + } + + let encoded_fragment = + percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string(); + url.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) }); + Ok(()) +} + +#[cfg(test)] +mod tests { use super::*; use crate::{Uri, UriExt};