From 1f87efd0527fbb69dbe706bcaa55da2f5e384960 Mon Sep 17 00:00:00 2001 From: Innes Anderson-Morrison Date: Thu, 26 Sep 2024 09:06:11 +0100 Subject: [PATCH 1/3] feat: checking for last-data in notices from snapd --- prompting-client/src/cli_actions/scripted.rs | 22 ++++++++---- prompting-client/src/daemon/poll.rs | 31 +++++++++------- prompting-client/src/snapd_client/mod.rs | 38 ++++++++++++++++---- prompting-client/src/snapd_client/prompt.rs | 16 +++++++++ prompting-client/tests/integration.rs | 7 ++-- 5 files changed, 87 insertions(+), 27 deletions(-) diff --git a/prompting-client/src/cli_actions/scripted.rs b/prompting-client/src/cli_actions/scripted.rs index 6ef1985..03572d8 100644 --- a/prompting-client/src/cli_actions/scripted.rs +++ b/prompting-client/src/cli_actions/scripted.rs @@ -6,7 +6,7 @@ use crate::{ home::{HomeConstraintsFilter, HomeInterface}, SnapInterface, }, - Action, PromptId, SnapdSocketClient, TypedPrompt, TypedPromptReply, + Action, PromptId, PromptNotice, SnapdSocketClient, TypedPrompt, TypedPromptReply, }, Error, Result, SNAP_NAME, }; @@ -19,10 +19,15 @@ use tracing::{debug, error, info, warn}; /// loop until at least one un-actioned prompt is encountered. async fn grace_period_deny_and_error(snapd_client: &mut SnapdSocketClient) -> Result<()> { loop { - let ids = snapd_client.pending_prompt_ids().await?; - let mut prompts = Vec::with_capacity(ids.len()); + let notices = snapd_client.pending_prompt_notices().await?; + let mut prompts = Vec::with_capacity(notices.len()); + + for notice in notices { + let id = match notice { + PromptNotice::Update(id) => id, + _ => continue, + }; - for id in ids { let prompt = match snapd_client.prompt_details(&id).await { Ok(p) => p, Err(_) => continue, @@ -76,8 +81,13 @@ impl ScriptedClient { tokio::task::spawn(async move { loop { - let pending = snapd_client.pending_prompt_ids().await.unwrap(); - for id in pending { + let notices = snapd_client.pending_prompt_notices().await.unwrap(); + for notice in notices { + let id = match notice { + PromptNotice::Update(id) => id, + _ => continue, + }; + match snapd_client.prompt_details(&id).await { Ok(TypedPrompt::Home(inner)) if filter.matches(&inner).is_success() => { debug!("allowing read of script file"); diff --git a/prompting-client/src/daemon/poll.rs b/prompting-client/src/daemon/poll.rs index a56cbd0..945b063 100644 --- a/prompting-client/src/daemon/poll.rs +++ b/prompting-client/src/daemon/poll.rs @@ -6,7 +6,7 @@ //! mapping into the data required for the prompt UI. use crate::{ daemon::{EnrichedPrompt, PromptUpdate}, - snapd_client::{PromptId, SnapMeta, SnapdSocketClient, TypedPrompt}, + snapd_client::{PromptId, PromptNotice, SnapMeta, SnapdSocketClient, TypedPrompt}, Error, }; use cached::proc_macro::cached; @@ -66,8 +66,8 @@ impl PollLoop { while self.running { info!("polling for notices"); - let pending = match self.client.pending_prompt_ids().await { - Ok(pending) => pending, + let notices = match self.client.pending_prompt_notices().await { + Ok(notices) => notices, Err(Error::SnapdError { status: StatusCode::FORBIDDEN, @@ -94,9 +94,12 @@ impl PollLoop { }; retries = 0; - debug!(?pending, "processing notices"); - for id in pending { - self.pull_and_process_prompt(id).await; + debug!(?notices, "processing notices"); + for notice in notices { + match notice { + PromptNotice::Update(id) => self.pull_and_process_prompt(id).await, + PromptNotice::Resolved(id) => self.send_update(PromptUpdate::Drop(id)), + } } } } @@ -161,18 +164,22 @@ impl PollLoop { // the ones that we need to provide for the notices API, so we deliberately set up an // overlap between pulling all pending prompts first before pulling pending prompt IDs // and updating our internal `after` timestamp. - let pending = match self.client.pending_prompt_ids().await { - Ok(pending) => pending, + let notices = match self.client.pending_prompt_notices().await { + Ok(notices) => notices, Err(error) => { error!(%error, "unable to pull pending prompt ids"); return; } }; - for id in pending { - if !seen.contains(&id) { - self.pull_and_process_prompt(id).await; - } + for notice in notices { + match notice { + PromptNotice::Update(id) if !seen.contains(&id) => { + self.pull_and_process_prompt(id).await + } + + _ => (), + }; } } } diff --git a/prompting-client/src/snapd_client/mod.rs b/prompting-client/src/snapd_client/mod.rs index 1273952..0adc2d4 100644 --- a/prompting-client/src/snapd_client/mod.rs +++ b/prompting-client/src/snapd_client/mod.rs @@ -13,8 +13,8 @@ pub mod interfaces; mod prompt; pub use prompt::{ - Action, Lifespan, Prompt, PromptId, PromptReply, TypedPrompt, TypedPromptReply, TypedUiInput, - UiInput, + Action, Lifespan, Prompt, PromptId, PromptNotice, PromptReply, TypedPrompt, TypedPromptReply, + TypedUiInput, UiInput, }; const FEATURE_NAME: &str = "apparmor-prompting"; @@ -169,20 +169,35 @@ where /// /// Calling this method will update our [Self::notices_after] field when we successfully obtain /// new notices from snapd. - pub async fn pending_prompt_ids(&mut self) -> Result> { + /// + /// Notices from snapd have an optional top level key of 'last-data' which can contain + /// metadata that allows us to filter what IDs we need to look at. If the 'resolved' key is + /// present and if its value is 'replied' then this is snapd telling us that a prompt has + /// been actioned and we should clear any internal state we have associated with that ID. + pub async fn pending_prompt_notices(&mut self) -> Result> { let path = format!( "notices?types={NOTICE_TYPES}&timeout={LONG_POLL_TIMEOUT}&after={}", self.notices_after ); - let notices: Vec = self.client.get_json(&path).await?; - if let Some(n) = notices.last() { + let raw_notices: Vec = self.client.get_json(&path).await?; + if let Some(n) = raw_notices.last() { n.last_occurred.clone_into(&mut self.notices_after); } - debug!("received notices: {notices:?}"); + debug!("received notices: {raw_notices:?}"); + + let notices: Vec = raw_notices + .into_iter() + .map(|n| match n.last_data { + Some(LastData { resolved: Some(s) }) if s == "replied" => { + PromptNotice::Resolved(n.key) + } + _ => PromptNotice::Update(n.key), + }) + .collect(); - return Ok(notices.into_iter().map(|n| n.key).collect()); + return Ok(notices); // serde structs @@ -191,10 +206,19 @@ where struct Notice { key: PromptId, last_occurred: String, + #[serde(default)] + last_data: Option, #[allow(dead_code)] #[serde(flatten)] extra: HashMap, } + + #[derive(Debug, Deserialize)] + #[serde(rename_all = "kebab-case")] + struct LastData { + #[serde(default)] + resolved: Option, + } } /// Pull details for all pending prompts from snapd diff --git a/prompting-client/src/snapd_client/prompt.rs b/prompting-client/src/snapd_client/prompt.rs index 6c9de33..677f91c 100644 --- a/prompting-client/src/snapd_client/prompt.rs +++ b/prompting-client/src/snapd_client/prompt.rs @@ -205,6 +205,22 @@ impl TypedUiInput { #[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct PromptId(pub String); +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +pub enum PromptNotice { + Update(PromptId), + Resolved(PromptId), +} + +impl PromptNotice { + /// Flatten this notice into the enclosed ID if it was an update + pub fn into_option_id(self) -> Option { + match self { + Self::Update(id) => Some(id), + _ => None, + } + } +} + #[derive( Debug, Default, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Display, EnumString, )] diff --git a/prompting-client/tests/integration.rs b/prompting-client/tests/integration.rs index 0c45157..a7862bc 100644 --- a/prompting-client/tests/integration.rs +++ b/prompting-client/tests/integration.rs @@ -77,8 +77,11 @@ fn setup_test_dir(subdir: Option<&str>, files: &[(&str, &str)]) -> io::Result<(S macro_rules! expect_single_prompt { ($c:expr, $expected_path:expr, $expected_permissions:expr) => { async { - let mut pending = match $c.pending_prompt_ids().await { - Ok(pending) => pending, + let mut pending: Vec<_> = match $c.pending_prompt_notices().await { + Ok(pending) => pending + .into_iter() + .flat_map(|n| n.into_option_id()) + .collect(), Err(e) => panic!("error pulling pending prompts: {e}"), }; assert_eq!(pending.len(), 1, "expected a single prompt"); From 5e1076f6eaee181ac8bd522e0fd9725e9f8cc531 Mon Sep 17 00:00:00 2001 From: Innes Anderson-Morrison Date: Thu, 26 Sep 2024 11:01:19 +0100 Subject: [PATCH 2/3] fix: no longer storing prompts_to_drop so we dont buffer dropped prompts indefinitely --- prompting-client/src/daemon/worker.rs | 83 ++++++++------------------- 1 file changed, 23 insertions(+), 60 deletions(-) diff --git a/prompting-client/src/daemon/worker.rs b/prompting-client/src/daemon/worker.rs index 4529333..975ca4b 100644 --- a/prompting-client/src/daemon/worker.rs +++ b/prompting-client/src/daemon/worker.rs @@ -77,7 +77,6 @@ where rx_actioned_prompts: UnboundedReceiver, active_prompt: Arc>>, pending_prompts: VecDeque, - prompts_to_drop: Vec, dead_prompts: Vec, recv_timeout: Duration, ui: S, @@ -99,7 +98,6 @@ impl Worker { rx_actioned_prompts, active_prompt: Arc::new(Mutex::new(None)), pending_prompts: VecDeque::new(), - prompts_to_drop: Vec::new(), dead_prompts: Vec::new(), recv_timeout: RECV_TIMEOUT, ui: FlutterUi { cmd }, @@ -159,28 +157,18 @@ where } } + fn drop_prompt(&mut self, id: PromptId) { + let len = self.pending_prompts.len(); + self.pending_prompts.retain(|ep| ep.prompt.id() != &id); + if self.pending_prompts.len() < len { + info!(id=%id.0, "dropping prompt as it has already been actioned"); + } + } + fn process_update(&mut self, update: PromptUpdate) { match update { - PromptUpdate::Add(ep) if self.prompts_to_drop.contains(ep.prompt.id()) => { - info!(id=%ep.prompt.id().0, "dropping prompt as it has already been actioned"); - self.prompts_to_drop.retain(|id| id != ep.prompt.id()); - } - PromptUpdate::Add(ep) => self.pending_prompts.push_back(ep), - - PromptUpdate::Drop(id) => { - // If this prompt was already pending then remove it now, otherwise keep track of - // it as one to drop as and when it comes in - let len = self.pending_prompts.len(); - self.pending_prompts.retain(|ep| ep.prompt.id() != &id); - if self.pending_prompts.len() < len { - info!(id=%id.0, "dropping prompt as it has already been actioned"); - } else { - // TODO: do we need to worry about this growing unchecked if we get bogus - // prompt IDs through that are never going to be cleared out? - self.prompts_to_drop.push(id); - } - } + PromptUpdate::Drop(id) => self.drop_prompt(id), } } @@ -252,8 +240,12 @@ where match timeout(self.recv_timeout, self.rx_actioned_prompts.recv()).await { Ok(Some(ActionedPrompt::Actioned { id, others })) => { debug!(recv_id=%id.0, "reply sent for prompt"); - debug!(to_drop=?others, "updating prompts to drop"); - self.prompts_to_drop.extend(others); + if !others.is_empty() { + debug!(to_drop=?others, "dropping prompts actioned by last reply"); + for id in others { + self.drop_prompt(id); + } + } if self.dead_prompts.contains(&id) { warn!(id=%id.0, "reply was for a dead prompt"); @@ -361,7 +353,6 @@ mod tests { rx_actioned_prompts, active_prompt: Arc::new(Mutex::new(None)), pending_prompts: [ep("1")].into_iter().collect(), - prompts_to_drop: Vec::new(), dead_prompts: Vec::new(), recv_timeout: Duration::from_millis(100), ui: FlutterUi { @@ -380,32 +371,20 @@ mod tests { ); } - #[test_case(add("1"), &[], &[], &["1"], &[]; "add new prompt")] - #[test_case(add("1"), &[], &["1"], &[], &[]; "add prompt that we have been told to drop")] - #[test_case(drop_id("1"), &["1"], &[], &[], &[]; "drop for pending prompt")] - #[test_case(drop_id("1"), &[], &[], &[], &["1"]; "drop prompt not seen yet")] + #[test_case(add("1"), &[], &["1"]; "add new prompt")] + #[test_case(drop_id("1"), &["1"], &[]; "drop for pending prompt")] + #[test_case(drop_id("1"), &[], &[]; "drop prompt not seen yet")] #[test] - fn process_update( - update: PromptUpdate, - current_pending: &[&str], - current_to_drop: &[&str], - expected_pending: &[&str], - expected_to_drop: &[&str], - ) { + fn process_update(update: PromptUpdate, current_pending: &[&str], expected_pending: &[&str]) { let (_, rx_prompts) = unbounded_channel(); let (_, rx_actioned_prompts) = unbounded_channel(); let pending_prompts = current_pending.iter().map(|id| ep(id)).collect(); - let prompts_to_drop = current_to_drop - .iter() - .map(|id| PromptId(id.to_string())) - .collect(); let mut w = Worker { rx_prompts, rx_actioned_prompts, active_prompt: Arc::new(Mutex::new(None)), pending_prompts, - prompts_to_drop, dead_prompts: Vec::new(), recv_timeout: Duration::from_millis(100), ui: FlutterUi { @@ -422,23 +401,20 @@ mod tests { .iter() .map(|ep| ep.prompt.id().0.as_str()) .collect(); - let to_drop: Vec<&str> = w.prompts_to_drop.iter().map(|id| id.0.as_str()).collect(); assert_eq!(pending, expected_pending); - assert_eq!(to_drop, expected_to_drop); } - #[test_case("1", "1", 10, Recv::Success, &["drop-me"], &["dead"]; "recv expected within timeout")] - #[test_case("2", "1", 10, Recv::Unexpected, &["drop-me"], &["dead"]; "recv unexpected within timeout")] - #[test_case("dead", "1", 10, Recv::DeadPrompt, &["drop-me"], &[]; "recv dead prompt")] - #[test_case("1", "1", 200, Recv::Timeout, &[], &["dead", "1"]; "recv expected after timeout")] + #[test_case("1", "1", 10, Recv::Success, &["dead"]; "recv expected within timeout")] + #[test_case("2", "1", 10, Recv::Unexpected, &["dead"]; "recv unexpected within timeout")] + #[test_case("dead", "1", 10, Recv::DeadPrompt, &[]; "recv dead prompt")] + #[test_case("1", "1", 200, Recv::Timeout, &["dead", "1"]; "recv expected after timeout")] #[tokio::test] async fn wait_for_expected_prompt( sent_id: &str, expected_id: &str, sleep_ms: u64, expected_recv: Recv, - expected_prompts_to_drop: &[&str], expected_dead_prompts: &[&str], ) { let (_, rx_prompts) = unbounded_channel(); @@ -449,7 +425,6 @@ mod tests { rx_actioned_prompts, active_prompt: Arc::new(Mutex::new(None)), pending_prompts: VecDeque::new(), - prompts_to_drop: Vec::new(), dead_prompts: vec![PromptId("dead".to_string())], recv_timeout: Duration::from_millis(100), ui: FlutterUi { @@ -472,14 +447,6 @@ mod tests { .await; assert_eq!(recv, expected_recv); - assert_eq!( - w.prompts_to_drop, - Vec::from_iter( - expected_prompts_to_drop - .iter() - .map(|id| PromptId(id.to_string())) - ) - ); assert_eq!( w.dead_prompts, Vec::from_iter( @@ -509,7 +476,6 @@ mod tests { rx_actioned_prompts, active_prompt: Arc::new(Mutex::new(None)), pending_prompts: VecDeque::new(), - prompts_to_drop: Vec::new(), dead_prompts: vec![PromptId("dead".to_string())], recv_timeout: Duration::from_millis(100), ui: FlutterUi { @@ -548,7 +514,6 @@ mod tests { rx_actioned_prompts, active_prompt: Arc::new(Mutex::new(None)), pending_prompts: VecDeque::new(), - prompts_to_drop: Vec::new(), dead_prompts: vec![PromptId("dead".to_string())], recv_timeout: Duration::from_millis(100), ui: FlutterUi { @@ -659,7 +624,6 @@ mod tests { rx_actioned_prompts, active_prompt, pending_prompts: VecDeque::new(), - prompts_to_drop: Vec::new(), dead_prompts: vec![], recv_timeout: Duration::from_millis(100), ui, @@ -719,7 +683,6 @@ mod tests { rx_actioned_prompts, active_prompt, pending_prompts: [ep("1")].into_iter().collect(), - prompts_to_drop: Vec::new(), dead_prompts: vec![], recv_timeout: Duration::from_millis(100), ui: StubUi, From d19e67c81dcbca37eee9c3b0021239bacdfb0c2e Mon Sep 17 00:00:00 2001 From: Innes Anderson-Morrison Date: Thu, 26 Sep 2024 14:13:04 +0100 Subject: [PATCH 3/3] refactor: match on enum varients directly instead of using a method --- prompting-client/src/snapd_client/prompt.rs | 10 ---------- prompting-client/tests/integration.rs | 7 +++++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/prompting-client/src/snapd_client/prompt.rs b/prompting-client/src/snapd_client/prompt.rs index 677f91c..14e766d 100644 --- a/prompting-client/src/snapd_client/prompt.rs +++ b/prompting-client/src/snapd_client/prompt.rs @@ -211,16 +211,6 @@ pub enum PromptNotice { Resolved(PromptId), } -impl PromptNotice { - /// Flatten this notice into the enclosed ID if it was an update - pub fn into_option_id(self) -> Option { - match self { - Self::Update(id) => Some(id), - _ => None, - } - } -} - #[derive( Debug, Default, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Display, EnumString, )] diff --git a/prompting-client/tests/integration.rs b/prompting-client/tests/integration.rs index a7862bc..a9e2ea2 100644 --- a/prompting-client/tests/integration.rs +++ b/prompting-client/tests/integration.rs @@ -11,7 +11,7 @@ use prompting_client::{ prompt_sequence::MatchError, snapd_client::{ interfaces::{home::HomeInterface, SnapInterface}, - Action, Lifespan, PromptId, SnapdSocketClient, TypedPrompt, + Action, Lifespan, PromptId, PromptNotice, SnapdSocketClient, TypedPrompt, }, Error, Result, }; @@ -80,7 +80,10 @@ macro_rules! expect_single_prompt { let mut pending: Vec<_> = match $c.pending_prompt_notices().await { Ok(pending) => pending .into_iter() - .flat_map(|n| n.into_option_id()) + .flat_map(|n| match n { + PromptNotice::Update(id) => Some(id), + _ => None, + }) .collect(), Err(e) => panic!("error pulling pending prompts: {e}"), };