Skip to content

Commit

Permalink
Abstract common UrlExt helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Jul 15, 2024
1 parent e1e2e7f commit cf8714d
Showing 1 changed file with 54 additions and 71 deletions.
125 changes: 54 additions & 71 deletions payjoin/src/uri/url_ext.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use std::str::FromStr;

use percent_encoding::{AsciiSet, PercentDecodeError, CONTROLS};
use url::Url;
Expand All @@ -19,96 +19,79 @@ const BIP21_CONFLICTING: &AsciiSet = &CONTROLS.add(b'=').add(b'&');
impl UrlExt for Url {
/// Retrieve the ohttp parameter from the URL fragment
fn ohttp(&self) -> Result<Option<OhttpKeys>, PercentDecodeError> {
use std::str::FromStr;
if let Some(fragment) = self.fragment() {
let decoded_fragment =
percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy();
for param in decoded_fragment.split('&') {
if let Some(value) = param.strip_prefix("ohttp=") {
let ohttp = Cow::from(value);
return Ok(OhttpKeys::from_str(&ohttp).ok());
}
}
}
Ok(None)
get_param(self, "ohttp=", |value| OhttpKeys::from_str(value).ok())
}

/// Set the ohttp parameter in the URL fragment
fn set_ohttp(&mut self, ohttp: Option<OhttpKeys>) -> Result<(), PercentDecodeError> {
let fragment = self.fragment().unwrap_or("").to_string();
let mut fragment =
percent_encoding::percent_decode_str(&fragment)?.decode_utf8_lossy().to_string();
if let Some(start) = fragment.find("ohttp=") {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.replace_range(start..end, "");
if fragment.ends_with('&') {
fragment.pop();
}
}
if let Some(ohttp) = ohttp {
let new_ohttp = format!("ohttp={}", ohttp);
if !fragment.is_empty() {
fragment.push('&');
}
fragment.push_str(&new_ohttp);
}
let encoded_fragment =
percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string();
self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) });
Ok(())
set_param(self, "ohttp=", ohttp.map(|o| o.to_string()))
}

/// Retrieve the exp parameter from the URL fragment
fn exp(&self) -> Result<Option<std::time::SystemTime>, PercentDecodeError> {
if let Some(fragment) = self.fragment() {
let decoded_fragment =
percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy();
for param in decoded_fragment.split('&') {
if let Some(value) = param.strip_prefix("exp=") {
if let Ok(timestamp) = value.parse::<u64>() {
return Ok(Some(
std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp),
));
}
}
}
}
Ok(None)
get_param(self, "exp=", |value| {
value
.parse::<u64>()
.ok()
.map(|timestamp| std::time::UNIX_EPOCH + std::time::Duration::from_secs(timestamp))
})
}

/// Set the exp parameter in the URL fragment
fn set_exp(&mut self, exp: Option<std::time::SystemTime>) -> Result<(), PercentDecodeError> {
let fragment = self.fragment().unwrap_or("").to_string();
let mut fragment =
percent_encoding::percent_decode_str(&fragment)?.decode_utf8_lossy().to_string();
if let Some(start) = fragment.find("exp=") {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.replace_range(start..end, "");
if fragment.ends_with('&') {
fragment.pop();
let exp_str = exp.map(|e| {
match e.duration_since(std::time::UNIX_EPOCH) {
Ok(duration) => duration.as_secs().to_string(),
Err(_) => "0".to_string(), // Handle times before Unix epoch by setting to "0"
}
}
if let Some(exp) = exp {
let timestamp = exp.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
let new_exp = format!("exp={}", timestamp);
if !fragment.is_empty() {
fragment.push('&');
});
set_param(self, "exp=", exp_str)
}
}

fn get_param<F, T>(url: &Url, prefix: &str, parse: F) -> Result<Option<T>, PercentDecodeError>
where
F: Fn(&str) -> Option<T>,
{
if let Some(fragment) = url.fragment() {
let decoded_fragment = percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy();
for param in decoded_fragment.split('&') {
if let Some(value) = param.strip_prefix(prefix) {
return Ok(parse(value));
}
fragment.push_str(&new_exp);
}
let encoded_fragment =
percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string();
self.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) });
Ok(())
}
Ok(None)
}

#[cfg(test)]
mod tests {
use std::str::FromStr;
fn set_param(url: &mut Url, prefix: &str, value: Option<String>) -> Result<(), PercentDecodeError> {
let fragment = url.fragment().unwrap_or("");
let mut fragment = percent_encoding::percent_decode_str(fragment)?.decode_utf8_lossy();

use url::Url;
if let Some(start) = fragment.find(prefix) {
let end = fragment[start..].find('&').map_or(fragment.len(), |i| start + i);
fragment.to_mut().replace_range(start..end, "");
if fragment.ends_with('&') {
fragment.to_mut().pop();
}
}

if let Some(value) = value {
let new_param = format!("{}{}", prefix, value);
if !fragment.is_empty() {
fragment.to_mut().push('&');
}
fragment.to_mut().push_str(&new_param);
}

let encoded_fragment =
percent_encoding::utf8_percent_encode(&fragment, BIP21_CONFLICTING).to_string();
url.set_fragment(if encoded_fragment.is_empty() { None } else { Some(&encoded_fragment) });
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{Uri, UriExt};

Expand Down

0 comments on commit cf8714d

Please sign in to comment.