Skip to content

Commit

Permalink
feat: checking for last-data in notices from snapd (#116)
Browse files Browse the repository at this point in the history
The data we get from the snapd notices endpoint now supports an optional
`last-data` field which we can use to be a little more efficient in how
we pull prompt details.
```
    {
      "id": "1674",
      "user-id": 1000,
      "type": "interfaces-requests-prompt",
      "key": "000000000000014E",
      "first-occurred": "2024-09-13T13:05:05.189006863Z",
      "last-occurred": "2024-09-13T13:05:05.232809028Z",
      "last-repeated": "2024-09-13T13:05:05.232809028Z",
      "occurrences": 2,
      "last-data": {
        "resolved": "replied"
      },
      "expire-after": "168h0m0s"
    }
```

In the case where we get a notice with `last-data.resolved == "replied"`
we can treat this in the same way we currently treat getting a 404 when
pulling prompt details (i.e. removing our internal state for that
prompt).
  • Loading branch information
sminez authored Sep 26, 2024
2 parents c9f491c + d19e67c commit d5d2b6d
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 88 deletions.
22 changes: 16 additions & 6 deletions prompting-client/src/cli_actions/scripted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
home::{HomeConstraintsFilter, HomeInterface},
SnapInterface,
},
Action, PromptId, SnapdSocketClient, TypedPrompt, TypedPromptReply,
Action, PromptId, PromptNotice, SnapdSocketClient, TypedPrompt, TypedPromptReply,
},
Error, Result, SNAP_NAME,
};
Expand All @@ -19,10 +19,15 @@ use tracing::{debug, error, info, warn};
/// loop until at least one un-actioned prompt is encountered.
async fn grace_period_deny_and_error(snapd_client: &mut SnapdSocketClient) -> Result<()> {
loop {
let ids = snapd_client.pending_prompt_ids().await?;
let mut prompts = Vec::with_capacity(ids.len());
let notices = snapd_client.pending_prompt_notices().await?;
let mut prompts = Vec::with_capacity(notices.len());

for notice in notices {
let id = match notice {
PromptNotice::Update(id) => id,
_ => continue,
};

for id in ids {
let prompt = match snapd_client.prompt_details(&id).await {
Ok(p) => p,
Err(_) => continue,
Expand Down Expand Up @@ -76,8 +81,13 @@ impl ScriptedClient {

tokio::task::spawn(async move {
loop {
let pending = snapd_client.pending_prompt_ids().await.unwrap();
for id in pending {
let notices = snapd_client.pending_prompt_notices().await.unwrap();
for notice in notices {
let id = match notice {
PromptNotice::Update(id) => id,
_ => continue,
};

match snapd_client.prompt_details(&id).await {
Ok(TypedPrompt::Home(inner)) if filter.matches(&inner).is_success() => {
debug!("allowing read of script file");
Expand Down
31 changes: 19 additions & 12 deletions prompting-client/src/daemon/poll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! mapping into the data required for the prompt UI.
use crate::{
daemon::{EnrichedPrompt, PromptUpdate},
snapd_client::{PromptId, SnapMeta, SnapdSocketClient, TypedPrompt},
snapd_client::{PromptId, PromptNotice, SnapMeta, SnapdSocketClient, TypedPrompt},
Error,
};
use cached::proc_macro::cached;
Expand Down Expand Up @@ -66,8 +66,8 @@ impl PollLoop {

while self.running {
info!("polling for notices");
let pending = match self.client.pending_prompt_ids().await {
Ok(pending) => pending,
let notices = match self.client.pending_prompt_notices().await {
Ok(notices) => notices,

Err(Error::SnapdError {
status: StatusCode::FORBIDDEN,
Expand All @@ -94,9 +94,12 @@ impl PollLoop {
};

retries = 0;
debug!(?pending, "processing notices");
for id in pending {
self.pull_and_process_prompt(id).await;
debug!(?notices, "processing notices");
for notice in notices {
match notice {
PromptNotice::Update(id) => self.pull_and_process_prompt(id).await,
PromptNotice::Resolved(id) => self.send_update(PromptUpdate::Drop(id)),
}
}
}
}
Expand Down Expand Up @@ -161,18 +164,22 @@ impl PollLoop {
// the ones that we need to provide for the notices API, so we deliberately set up an
// overlap between pulling all pending prompts first before pulling pending prompt IDs
// and updating our internal `after` timestamp.
let pending = match self.client.pending_prompt_ids().await {
Ok(pending) => pending,
let notices = match self.client.pending_prompt_notices().await {
Ok(notices) => notices,
Err(error) => {
error!(%error, "unable to pull pending prompt ids");
return;
}
};

for id in pending {
if !seen.contains(&id) {
self.pull_and_process_prompt(id).await;
}
for notice in notices {
match notice {
PromptNotice::Update(id) if !seen.contains(&id) => {
self.pull_and_process_prompt(id).await
}

_ => (),
};
}
}
}
83 changes: 23 additions & 60 deletions prompting-client/src/daemon/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ where
rx_actioned_prompts: UnboundedReceiver<ActionedPrompt>,
active_prompt: Arc<Mutex<Option<TypedUiInput>>>,
pending_prompts: VecDeque<EnrichedPrompt>,
prompts_to_drop: Vec<PromptId>,
dead_prompts: Vec<PromptId>,
recv_timeout: Duration,
ui: S,
Expand All @@ -99,7 +98,6 @@ impl Worker<FlutterUi, SnapdSocketClient> {
rx_actioned_prompts,
active_prompt: Arc::new(Mutex::new(None)),
pending_prompts: VecDeque::new(),
prompts_to_drop: Vec::new(),
dead_prompts: Vec::new(),
recv_timeout: RECV_TIMEOUT,
ui: FlutterUi { cmd },
Expand Down Expand Up @@ -159,28 +157,18 @@ where
}
}

fn drop_prompt(&mut self, id: PromptId) {
let len = self.pending_prompts.len();
self.pending_prompts.retain(|ep| ep.prompt.id() != &id);
if self.pending_prompts.len() < len {
info!(id=%id.0, "dropping prompt as it has already been actioned");
}
}

fn process_update(&mut self, update: PromptUpdate) {
match update {
PromptUpdate::Add(ep) if self.prompts_to_drop.contains(ep.prompt.id()) => {
info!(id=%ep.prompt.id().0, "dropping prompt as it has already been actioned");
self.prompts_to_drop.retain(|id| id != ep.prompt.id());
}

PromptUpdate::Add(ep) => self.pending_prompts.push_back(ep),

PromptUpdate::Drop(id) => {
// If this prompt was already pending then remove it now, otherwise keep track of
// it as one to drop as and when it comes in
let len = self.pending_prompts.len();
self.pending_prompts.retain(|ep| ep.prompt.id() != &id);
if self.pending_prompts.len() < len {
info!(id=%id.0, "dropping prompt as it has already been actioned");
} else {
// TODO: do we need to worry about this growing unchecked if we get bogus
// prompt IDs through that are never going to be cleared out?
self.prompts_to_drop.push(id);
}
}
PromptUpdate::Drop(id) => self.drop_prompt(id),
}
}

Expand Down Expand Up @@ -252,8 +240,12 @@ where
match timeout(self.recv_timeout, self.rx_actioned_prompts.recv()).await {
Ok(Some(ActionedPrompt::Actioned { id, others })) => {
debug!(recv_id=%id.0, "reply sent for prompt");
debug!(to_drop=?others, "updating prompts to drop");
self.prompts_to_drop.extend(others);
if !others.is_empty() {
debug!(to_drop=?others, "dropping prompts actioned by last reply");
for id in others {
self.drop_prompt(id);
}
}

if self.dead_prompts.contains(&id) {
warn!(id=%id.0, "reply was for a dead prompt");
Expand Down Expand Up @@ -361,7 +353,6 @@ mod tests {
rx_actioned_prompts,
active_prompt: Arc::new(Mutex::new(None)),
pending_prompts: [ep("1")].into_iter().collect(),
prompts_to_drop: Vec::new(),
dead_prompts: Vec::new(),
recv_timeout: Duration::from_millis(100),
ui: FlutterUi {
Expand All @@ -380,32 +371,20 @@ mod tests {
);
}

#[test_case(add("1"), &[], &[], &["1"], &[]; "add new prompt")]
#[test_case(add("1"), &[], &["1"], &[], &[]; "add prompt that we have been told to drop")]
#[test_case(drop_id("1"), &["1"], &[], &[], &[]; "drop for pending prompt")]
#[test_case(drop_id("1"), &[], &[], &[], &["1"]; "drop prompt not seen yet")]
#[test_case(add("1"), &[], &["1"]; "add new prompt")]
#[test_case(drop_id("1"), &["1"], &[]; "drop for pending prompt")]
#[test_case(drop_id("1"), &[], &[]; "drop prompt not seen yet")]
#[test]
fn process_update(
update: PromptUpdate,
current_pending: &[&str],
current_to_drop: &[&str],
expected_pending: &[&str],
expected_to_drop: &[&str],
) {
fn process_update(update: PromptUpdate, current_pending: &[&str], expected_pending: &[&str]) {
let (_, rx_prompts) = unbounded_channel();
let (_, rx_actioned_prompts) = unbounded_channel();
let pending_prompts = current_pending.iter().map(|id| ep(id)).collect();
let prompts_to_drop = current_to_drop
.iter()
.map(|id| PromptId(id.to_string()))
.collect();

let mut w = Worker {
rx_prompts,
rx_actioned_prompts,
active_prompt: Arc::new(Mutex::new(None)),
pending_prompts,
prompts_to_drop,
dead_prompts: Vec::new(),
recv_timeout: Duration::from_millis(100),
ui: FlutterUi {
Expand All @@ -422,23 +401,20 @@ mod tests {
.iter()
.map(|ep| ep.prompt.id().0.as_str())
.collect();
let to_drop: Vec<&str> = w.prompts_to_drop.iter().map(|id| id.0.as_str()).collect();

assert_eq!(pending, expected_pending);
assert_eq!(to_drop, expected_to_drop);
}

#[test_case("1", "1", 10, Recv::Success, &["drop-me"], &["dead"]; "recv expected within timeout")]
#[test_case("2", "1", 10, Recv::Unexpected, &["drop-me"], &["dead"]; "recv unexpected within timeout")]
#[test_case("dead", "1", 10, Recv::DeadPrompt, &["drop-me"], &[]; "recv dead prompt")]
#[test_case("1", "1", 200, Recv::Timeout, &[], &["dead", "1"]; "recv expected after timeout")]
#[test_case("1", "1", 10, Recv::Success, &["dead"]; "recv expected within timeout")]
#[test_case("2", "1", 10, Recv::Unexpected, &["dead"]; "recv unexpected within timeout")]
#[test_case("dead", "1", 10, Recv::DeadPrompt, &[]; "recv dead prompt")]
#[test_case("1", "1", 200, Recv::Timeout, &["dead", "1"]; "recv expected after timeout")]
#[tokio::test]
async fn wait_for_expected_prompt(
sent_id: &str,
expected_id: &str,
sleep_ms: u64,
expected_recv: Recv,
expected_prompts_to_drop: &[&str],
expected_dead_prompts: &[&str],
) {
let (_, rx_prompts) = unbounded_channel();
Expand All @@ -449,7 +425,6 @@ mod tests {
rx_actioned_prompts,
active_prompt: Arc::new(Mutex::new(None)),
pending_prompts: VecDeque::new(),
prompts_to_drop: Vec::new(),
dead_prompts: vec![PromptId("dead".to_string())],
recv_timeout: Duration::from_millis(100),
ui: FlutterUi {
Expand All @@ -472,14 +447,6 @@ mod tests {
.await;

assert_eq!(recv, expected_recv);
assert_eq!(
w.prompts_to_drop,
Vec::from_iter(
expected_prompts_to_drop
.iter()
.map(|id| PromptId(id.to_string()))
)
);
assert_eq!(
w.dead_prompts,
Vec::from_iter(
Expand Down Expand Up @@ -509,7 +476,6 @@ mod tests {
rx_actioned_prompts,
active_prompt: Arc::new(Mutex::new(None)),
pending_prompts: VecDeque::new(),
prompts_to_drop: Vec::new(),
dead_prompts: vec![PromptId("dead".to_string())],
recv_timeout: Duration::from_millis(100),
ui: FlutterUi {
Expand Down Expand Up @@ -548,7 +514,6 @@ mod tests {
rx_actioned_prompts,
active_prompt: Arc::new(Mutex::new(None)),
pending_prompts: VecDeque::new(),
prompts_to_drop: Vec::new(),
dead_prompts: vec![PromptId("dead".to_string())],
recv_timeout: Duration::from_millis(100),
ui: FlutterUi {
Expand Down Expand Up @@ -659,7 +624,6 @@ mod tests {
rx_actioned_prompts,
active_prompt,
pending_prompts: VecDeque::new(),
prompts_to_drop: Vec::new(),
dead_prompts: vec![],
recv_timeout: Duration::from_millis(100),
ui,
Expand Down Expand Up @@ -719,7 +683,6 @@ mod tests {
rx_actioned_prompts,
active_prompt,
pending_prompts: [ep("1")].into_iter().collect(),
prompts_to_drop: Vec::new(),
dead_prompts: vec![],
recv_timeout: Duration::from_millis(100),
ui: StubUi,
Expand Down
38 changes: 31 additions & 7 deletions prompting-client/src/snapd_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ pub mod interfaces;
mod prompt;

pub use prompt::{
Action, Lifespan, Prompt, PromptId, PromptReply, TypedPrompt, TypedPromptReply, TypedUiInput,
UiInput,
Action, Lifespan, Prompt, PromptId, PromptNotice, PromptReply, TypedPrompt, TypedPromptReply,
TypedUiInput, UiInput,
};

const FEATURE_NAME: &str = "apparmor-prompting";
Expand Down Expand Up @@ -169,20 +169,35 @@ where
///
/// Calling this method will update our [Self::notices_after] field when we successfully obtain
/// new notices from snapd.
pub async fn pending_prompt_ids(&mut self) -> Result<Vec<PromptId>> {
///
/// Notices from snapd have an optional top level key of 'last-data' which can contain
/// metadata that allows us to filter what IDs we need to look at. If the 'resolved' key is
/// present and if its value is 'replied' then this is snapd telling us that a prompt has
/// been actioned and we should clear any internal state we have associated with that ID.
pub async fn pending_prompt_notices(&mut self) -> Result<Vec<PromptNotice>> {
let path = format!(
"notices?types={NOTICE_TYPES}&timeout={LONG_POLL_TIMEOUT}&after={}",
self.notices_after
);

let notices: Vec<Notice> = self.client.get_json(&path).await?;
if let Some(n) = notices.last() {
let raw_notices: Vec<Notice> = self.client.get_json(&path).await?;
if let Some(n) = raw_notices.last() {
n.last_occurred.clone_into(&mut self.notices_after);
}

debug!("received notices: {notices:?}");
debug!("received notices: {raw_notices:?}");

let notices: Vec<PromptNotice> = raw_notices
.into_iter()
.map(|n| match n.last_data {
Some(LastData { resolved: Some(s) }) if s == "replied" => {
PromptNotice::Resolved(n.key)
}
_ => PromptNotice::Update(n.key),
})
.collect();

return Ok(notices.into_iter().map(|n| n.key).collect());
return Ok(notices);

// serde structs

Expand All @@ -191,10 +206,19 @@ where
struct Notice {
key: PromptId,
last_occurred: String,
#[serde(default)]
last_data: Option<LastData>,
#[allow(dead_code)]
#[serde(flatten)]
extra: HashMap<String, serde_json::Value>,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
struct LastData {
#[serde(default)]
resolved: Option<String>,
}
}

/// Pull details for all pending prompts from snapd
Expand Down
6 changes: 6 additions & 0 deletions prompting-client/src/snapd_client/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,12 @@ impl TypedUiInput {
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct PromptId(pub String);

#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub enum PromptNotice {
Update(PromptId),
Resolved(PromptId),
}

#[derive(
Debug, Default, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, Display, EnumString,
)]
Expand Down
Loading

0 comments on commit d5d2b6d

Please sign in to comment.