diff --git a/cfdp-core/Cargo.toml b/cfdp-core/Cargo.toml index 05f87a5..8891ca5 100644 --- a/cfdp-core/Cargo.toml +++ b/cfdp-core/Cargo.toml @@ -17,9 +17,13 @@ num-traits = "0.2" pathdiff = "~0.2" tempfile = "~3.3" thiserror = "~1.0" +tokio = {version = "1.28.2", features = ["rt-multi-thread", "net", "sync", "time"]} +async-trait = "0.1" [dev-dependencies] rstest = "0.15.0" rstest_reuse = "0.5.0" signal-hook = "~0.3" dirs = "~4.0" +tokio = {version = "1.28.2", features = ["parking_lot", "macros"]} +env_logger = "0.10" \ No newline at end of file diff --git a/cfdp-core/src/daemon.rs b/cfdp-core/src/daemon.rs index 4cfff6f..68fcf51 100644 --- a/cfdp-core/src/daemon.rs +++ b/cfdp-core/src/daemon.rs @@ -6,15 +6,21 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, - thread::{self, JoinHandle}, time::{Duration, Instant}, }; use camino::Utf8PathBuf; -use crossbeam_channel::{bounded, unbounded, Receiver, Select, Sender, TryRecvError}; use log::{error, info, warn}; use num_traits::FromPrimitive; use thiserror::Error; +use tokio::{ + select, + sync::{ + mpsc::{channel, Receiver, Sender}, + oneshot, + }, + task::JoinHandle, +}; use crate::{ filestore::{ChecksumType, FileStore}, @@ -89,11 +95,11 @@ fn construct_metadata(req: PutRequest, config: EntityConfig, file_size: u64) -> } } -#[derive(Debug, Clone)] +#[derive(Debug)] /// Possible User Primitives sent from a end user application via the user primitive channel pub enum UserPrimitive { /// Initiate a Put transaction with the specified [PutRequest] configuration. - Put(PutRequest, Sender), + Put(PutRequest, oneshot::Sender), /// Cancel the give transaction. Cancel(TransactionID), /// Suspend operations of the given transaction. @@ -101,7 +107,7 @@ pub enum UserPrimitive { /// Resume operations of the given transaction. Resume(TransactionID), /// Report progress of the given transaction. - Report(TransactionID, Sender), + Report(TransactionID, oneshot::Sender), /// Send the designated PromptPDU from the given transaction. /// This primitive is only valid for [Send](crate::transaction::SendTransaction) transactions Prompt(TransactionID, NakOrKeepAlive), @@ -114,7 +120,7 @@ enum Command { Cancel, Suspend, Resume, - Report(Sender), + Report(oneshot::Sender), Prompt(NakOrKeepAlive), // may find a use for abandon in the future. #[allow(unused)] @@ -343,10 +349,10 @@ impl Daemon { indication_tx: Sender, ) -> Self { let mut transport_tx_map: HashMap> = HashMap::new(); - let (pdu_send, pdu_receive) = unbounded(); + let (pdu_send, pdu_receive) = channel(100); let terminate = Arc::new(AtomicBool::new(false)); for (vec, mut transport) in transport_map.into_iter() { - let (remote_send, remote_receive) = bounded(1); + let (remote_send, remote_receive) = channel(1); vec.iter().for_each(|id| { transport_tx_map.insert(*id, remote_send.clone()); @@ -354,7 +360,9 @@ impl Daemon { let signal = terminate.clone(); let sender = pdu_send.clone(); - thread::spawn(move || transport.pdu_handler(signal, sender, remote_receive)); + tokio::task::spawn(async move { + transport.pdu_handler(signal, sender, remote_receive).await + }); } Self { transaction_handles: vec![], @@ -370,6 +378,7 @@ impl Daemon { primitive_rx, } } + fn spawn_receive_transaction( header: &PDUHeader, transport_tx: Sender<(VariableID, PDU)>, @@ -377,7 +386,7 @@ impl Daemon { filestore: Arc, indication_tx: Sender, ) -> Result> { - let (transaction_tx, transaction_rx) = unbounded(); + let (transaction_tx, mut transaction_rx) = channel(100); let config = TransactionConfig { source_entity_id: header.source_entity_id, @@ -394,10 +403,10 @@ impl Daemon { ack_timeout: entity_config.ack_timeout, nak_timeout: entity_config.nak_timeout, }; - let name = format!( + /* let name = format!( "({}, {})", &config.source_entity_id, &config.sequence_number - ); + );*/ let mut transaction = RecvTransaction::new( config, entity_config.nak_procedure, @@ -406,75 +415,64 @@ impl Daemon { ); let id = transaction.id(); - let handle = thread::Builder::new().name(name).spawn(move || { + // tokio tasks can have names but that seems an unsable feature + let handle = tokio::task::spawn(async move { transaction.send_report(None)?; - let mut sel = Select::new(); - let rx_select_id = sel.recv(&transaction_rx); - - let mut tx_select_id = Option::::None; - while transaction.get_state() != TransactionState::Terminated { - if transaction.has_pdu_to_send() { - tx_select_id.get_or_insert_with(||sel.send(&transport_tx)); - } else if let Some(idx) = tx_select_id.take() { - sel.remove(idx); - } - let timeout = transaction.until_timeout(); - let oper = sel.ready_timeout(timeout); - match oper { - Err(_) => { - transaction.handle_timeout()?; - } - Ok(id) => { - if tx_select_id == Some(id) { - transaction.send_pdu(&transport_tx)?; - } else if id == rx_select_id { - match transaction_rx.try_recv() { - Ok(command) => { - match command { - Command::Pdu(pdu) => { - match transaction.process_pdu(pdu) { - Ok(()) => {} - Err(err @ TransactionError::UnexpectedPDU(..)) => { - info!("Transaction {} Received Unexpected PDU: {err}", transaction.id()); - // log some info on the unexpected PDU? - } - Err(err) => return Err(err), - } - } - Command::Resume => transaction.resume()?, - Command::Cancel => transaction.cancel()?, - Command::Suspend => transaction.suspend()?, - Command::Abandon => transaction.shutdown(), - Command::Report(sender) => { - transaction.send_report(Some(sender))? - } - Command::Prompt(_) =>{ - // prompt is a no-op for a receive transaction. + select! { + permit = transport_tx.reserve(), if transaction.has_pdu_to_send() => { + if let Ok(permit) = permit { + transaction.send_pdu(permit)? + } else { + log::error!("Channel to transport severed for transaction {}", transaction.id()); + break; + } + }, + command = transaction_rx.recv() => { + if let Some(command) = command { + match command { + Command::Pdu(pdu) => { + match transaction.process_pdu(pdu) { + Ok(()) => {} + Err(err @ TransactionError::UnexpectedPDU(..)) => { + info!("Transaction {} Received Unexpected PDU: {err}", transaction.id()); + // log some info on the unexpected PDU? } + Err(err) => return Err(err), } } - Err(TryRecvError::Empty) => { - // this normally should not happen + Command::Resume => transaction.resume()?, + Command::Cancel => transaction.cancel()?, + Command::Suspend => transaction.suspend()?, + Command::Abandon => transaction.shutdown(), + Command::Report(sender) => { + transaction.send_report(Some(sender))? } - Err(TryRecvError::Disconnected) => { - // Really do not expect to be in this situation - // probably the thread should exit - panic!( - "Connection to Daemon Severed for Transaction {}", - transaction.id() - ); + Command::Prompt(_) =>{ + // prompt is a no-op for a receive transaction. } - }; + } + } else { + log::warn!( + "Connection to Daemon Severed for Transaction {}", + transaction.id() + ); + break; } } + _= tokio::time::sleep(timeout) => { + transaction.handle_timeout()?; + }, + + } } + transaction.send_report(None)?; Ok(transaction.id()) - })?; + }); Ok((id, transaction_tx, handle)) } @@ -489,7 +487,7 @@ impl Daemon { filestore: Arc, indication_tx: Sender, ) -> Result> { - let (transaction_tx, transaction_rx) = unbounded(); + let (transaction_tx, mut transaction_rx) = channel(10); let id = TransactionID(source_entity_id, sequence_number); let destination_entity_id = request.destination_entity_id; @@ -511,127 +509,97 @@ impl Daemon { }; let mut metadata = construct_metadata(request, entity_config, 0_u64); - let handle = thread::Builder::new() - .name(format!( - "({}, {})", - config.source_entity_id, config.sequence_number - )) - .spawn(move || { - let file_size = match &metadata.source_filename.file_name().is_none() { - true => 0_u64, - false => filestore.get_size(&metadata.source_filename)?, - }; - - metadata.file_size = file_size; - config.file_size_flag = match metadata.file_size <= u32::MAX.into() { - true => FileSizeFlag::Small, - false => FileSizeFlag::Large, - }; - - let mut transaction = - SendTransaction::new(config, metadata, filestore, indication_tx)?; - transaction.send_report(None)?; - let mut sel = Select::new(); - let rx_select_id = sel.recv(&transaction_rx); + let handle = tokio::task::spawn(async move { + let file_size = match &metadata.source_filename.file_name().is_none() { + true => 0_u64, + false => filestore.get_size(&metadata.source_filename)?, + }; - let mut tx_select_id = Option::::None; + metadata.file_size = file_size; + config.file_size_flag = match metadata.file_size <= u32::MAX.into() { + true => FileSizeFlag::Small, + false => FileSizeFlag::Large, + }; - while transaction.get_state() != TransactionState::Terminated { - if transaction.has_pdu_to_send() { - if tx_select_id.is_none() { - tx_select_id = Some(sel.send(&transport_tx)); - } - } else if let Some(idx) = tx_select_id { - sel.remove(idx); - tx_select_id = None; - } + let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx)?; + transaction.send_report(None)?; - let timeout = transaction.until_timeout(); - let oper = sel.ready_timeout(timeout); + while transaction.get_state() != TransactionState::Terminated { + let timeout = transaction.until_timeout(); - match oper { - Err(_) => { - transaction.handle_timeout()?; + select! { + permit = transport_tx.reserve(), if transaction.has_pdu_to_send() => { + if let Ok(permit) = permit { + transaction.send_pdu(permit)?; + } else { + log::error!("Connection to transport severed for transaction {}", transaction.id()); + break; } - Ok(id) => { - if tx_select_id == Some(id) { - // println!("transport_tx capacity :{}", transport_tx.len()); - transaction.send_pdu(&transport_tx)?; - // println!("dupa transport_tx capacity :{}", transport_tx.len()); - } else if id == rx_select_id { - match transaction_rx.try_recv() { - Ok(command) => { - match command { - Command::Pdu(pdu) => { - match transaction.process_pdu(pdu) { - Ok(()) => {} - Err( - err @ TransactionError::UnexpectedPDU(..), - ) => { - info!("Recieved Unexpected PDU: {err}"); - // log some info on the unexpected PDU? - } - Err(err) => { - return Err(err); - } - } - } - Command::Resume => transaction.resume()?, - Command::Cancel => transaction.cancel()?, - Command::Suspend => transaction.suspend()?, - Command::Abandon => transaction.shutdown(), - Command::Report(sender) => { - transaction.send_report(Some(sender))? - } - Command::Prompt(option) => { - transaction.prepare_prompt(option) - } + }, + + command = transaction_rx.recv() => { + if let Some(command) = command { + match command { + Command::Pdu(pdu) => { + match transaction.process_pdu(pdu) { + Ok(()) => {} + Err( + err @ TransactionError::UnexpectedPDU(..), + ) => { + info!("Recieved Unexpected PDU: {err}"); + // log some info on the unexpected PDU? + } + Err(err) => { + return Err(err); } } - Err(TryRecvError::Empty) => { - // nothing for us at this time just sleep - } - Err(TryRecvError::Disconnected) => { - // Really do not expect to be in this situation - // probably the thread should exit - panic!( - "Connection to Daemon Severed for Transaction {}", - transaction.id() - ) - } + } + Command::Resume => transaction.resume()?, + Command::Cancel => transaction.cancel()?, + Command::Suspend => transaction.suspend()?, + Command::Abandon => transaction.shutdown(), + Command::Report(sender) => { + transaction.send_report(Some(sender))? + } + Command::Prompt(option) => { + transaction.prepare_prompt(option) } } + } else { + panic!( + "Connection to Daemon Severed for Transaction {}", + transaction.id() + ) } - }; - } - transaction.send_report(None)?; - Ok(id) - })?; + }, + _ = tokio::time::sleep(timeout) => { + transaction.handle_timeout()?; + } + }; + } + transaction.send_report(None)?; + Ok(id) + }); Ok((id, transaction_tx, handle)) } /// This function will consist of the main logic loop in any daemon process. - pub fn manage_transactions(&mut self) -> Result<(), Box> { + pub async fn manage_transactions(&mut self) -> Result<(), Box> { let mut sequence_num = self.sequence_num; // Create the selection object to check if any messages are available. // the returned index will be used to determine which action to take. - let mut selector = Select::new(); - let transport = selector.recv(&self.transport_rx); - let user_primitive = selector.recv(&self.primitive_rx); - // mapping of unique transaction ids to channels used to talk to each transaction let mut transaction_channels: HashMap> = HashMap::new(); let mut cleanup = Instant::now(); loop { - match selector.select() { - oper if oper.index() == transport => { - match oper.recv(&self.transport_rx) { - Ok(pdu) => { - // find the entity this entity will be sending too. + select! { + pdu = self.transport_rx.recv() => { + if let Some(pdu) = pdu { + // find the entity this entity will be sending too. // If this PDU is to the sender, we send to the destination // if this PDU is to the receiver, we send to the source let transport_entity = match &pdu.header.direction { @@ -680,7 +648,7 @@ impl Daemon { } }; - match channel.send(Command::Pdu(pdu.clone())) { + match channel.send(Command::Pdu(pdu.clone())).await { Ok(()) => {} Err(_) => { // the transaction is completed. @@ -705,7 +673,7 @@ impl Daemon { )?; self.transaction_handles.push(handle); - new_channel.send(Command::Pdu(pdu.clone()))?; + new_channel.send(Command::Pdu(pdu.clone())).await?; // update the dict to have the new channel transaction_channels.insert(key, new_channel); } else { @@ -716,9 +684,8 @@ impl Daemon { } } }; - } - Err(_err) => { - // the channel is empty and disconnected + } else { + // the channel is empty and disconnected // this should only happen when we are cleaning up // but may happen when the transport crashes or quits if !self.terminate.load(Ordering::Relaxed) { @@ -726,93 +693,84 @@ impl Daemon { } break; } - }; - } - // received a UserPrimitive from the user implementation - oper if oper.index() == user_primitive => { - match oper.recv(&self.primitive_rx) { - Ok(primitive) => { - match primitive { - UserPrimitive::Put(request, put_sender) => { - let sequence_number = sequence_num.get_and_increment(); - - let entity_config = self - .entity_configs - .get(&request.destination_entity_id) - .unwrap_or(&self.default_config) - .clone(); - - if let Some(transport_tx) = self - .transport_tx_map - .get(&request.destination_entity_id) - .cloned() - { - let (id, sender, handle) = Self::spawn_send_transaction( - request, - sequence_number, - self.entity_id, - transport_tx, - entity_config, - self.filestore.clone(), - self.indication_tx.clone(), - )?; - - self.transaction_handles.push(handle); - transaction_channels.insert(id, sender); - - // ignore the possible error if the user disconnected; - let _ = put_sender.send(id); - } else { - warn!( - "No Transport available for EntityID: {}. Skipping transaction creation.", - request.destination_entity_id - ) - } + }, + primitive = self.primitive_rx.recv() => { + if let Some(primitive) = primitive { + match primitive { + UserPrimitive::Put(request, put_sender) => { + let sequence_number = sequence_num.get_and_increment(); + + let entity_config = self + .entity_configs + .get(&request.destination_entity_id) + .unwrap_or(&self.default_config) + .clone(); + + if let Some(transport_tx) = self + .transport_tx_map + .get(&request.destination_entity_id) + .cloned() + { + let (id, sender, handle) = Self::spawn_send_transaction( + request, + sequence_number, + self.entity_id, + transport_tx, + entity_config, + self.filestore.clone(), + self.indication_tx.clone(), + )?; + self.transaction_handles.push(handle); + transaction_channels.insert(id, sender); + + // ignore the possible error if the user disconnected; + let _ = put_sender.send(id); + } else { + warn!( + "No Transport available for EntityID: {}. Skipping transaction creation.", + request.destination_entity_id + ) } - UserPrimitive::Cancel(id) => { - if let Some(channel) = transaction_channels.get(&id) { - channel.send(Command::Cancel)?; - } + } + UserPrimitive::Cancel(id) => { + if let Some(channel) = transaction_channels.get(&id) { + channel.send(Command::Cancel).await?; } - UserPrimitive::Suspend(id) => { - if let Some(channel) = transaction_channels.get(&id) { - channel.send(Command::Suspend)?; - } + } + UserPrimitive::Suspend(id) => { + if let Some(channel) = transaction_channels.get(&id) { + channel.send(Command::Suspend).await?; } - UserPrimitive::Resume(id) => { - if let Some(channel) = transaction_channels.get(&id) { - channel.send(Command::Resume)?; - } + } + UserPrimitive::Resume(id) => { + if let Some(channel) = transaction_channels.get(&id) { + channel.send(Command::Resume).await?; } - UserPrimitive::Report(id, report_sender) => { - if let Some(channel) = transaction_channels.get(&id) { - // It's possible for the user to ask for a report after the Transaction is finished - // but before the channel is cleaned up. - // for now ignore errors until a better solution is found. - // maybe possible to trigger a cleanup immediately after a transaction finishes? - let _ = channel.send(Command::Report(report_sender)); - } + } + UserPrimitive::Report(id, report_sender) => { + if let Some(channel) = transaction_channels.get(&id) { + // It's possible for the user to ask for a report after the Transaction is finished + // but before the channel is cleaned up. + // for now ignore errors until a better solution is found. + // maybe possible to trigger a cleanup immediately after a transaction finishes? + let _ = channel.send(Command::Report(report_sender)).await; } - UserPrimitive::Prompt(id, option) => { - if let Some(channel) = transaction_channels.get(&id) { - channel.send(Command::Prompt(option))?; - } + } + UserPrimitive::Prompt(id, option) => { + if let Some(channel) = transaction_channels.get(&id) { + channel.send(Command::Prompt(option)).await?; } - }; - } - Err(_err) => { - // The channel is disconnected - // this is only an issue if the channel was the user interface - error!("User interface disconnected from daemon."); - self.terminate.store(true, Ordering::Relaxed); - break; - } + } + }; + } else { + // The channel is disconnected + // this is only an issue if the channel was the user interface + error!("User interface disconnected from daemon."); + self.terminate.store(true, Ordering::Relaxed); + break; } } - _ => { - unreachable!() - } - }; + } // join any handles that have completed // maybe should only run every so often? @@ -821,7 +779,7 @@ impl Daemon { while ind < self.transaction_handles.len() { if self.transaction_handles[ind].is_finished() { let handle = self.transaction_handles.remove(ind); - match handle.join() { + match handle.await { Ok(Ok(id)) => { // remove the channel for this transaction if it is complete let _ = transaction_channels.remove(&id); @@ -842,13 +800,13 @@ impl Daemon { // a final cleanup while let Some(handle) = self.transaction_handles.pop() { - match handle.join() { + match handle.await { Ok(Ok(id)) => { // remove the channel for this transaction if it is complete let _ = transaction_channels.remove(&id); } Ok(Err(err)) => { - info!("Error occured during transaction: {}", err) + info!("Error occurred during transaction: {}", err) } Err(_) => error!("Unable to join handle!"), }; diff --git a/cfdp-core/src/transaction/error.rs b/cfdp-core/src/transaction/error.rs index c257f40..f82bc1f 100644 --- a/cfdp-core/src/transaction/error.rs +++ b/cfdp-core/src/transaction/error.rs @@ -1,7 +1,7 @@ use std::num::TryFromIntError; -use crossbeam_channel::SendError; use thiserror::Error; +use tokio::sync::mpsc::error::SendError; use crate::{ daemon::{Indication, Report}, @@ -20,7 +20,7 @@ pub enum TransactionError { #[error("Error Communicating with transport: {0}")] Transport(#[from] Box>), - #[error("Error transfering Indication {0}")] + #[error("Error transferring Indication {0}")] UserMessage(#[from] Box>), #[error("No open file in transaction: {0:?}")] diff --git a/cfdp-core/src/transaction/recv.rs b/cfdp-core/src/transaction/recv.rs index 665d41a..1e071ae 100644 --- a/cfdp-core/src/transaction/recv.rs +++ b/cfdp-core/src/transaction/recv.rs @@ -7,8 +7,11 @@ use std::{ }; use camino::Utf8PathBuf; -use crossbeam_channel::Sender; use log::{debug, warn}; +use tokio::sync::{ + mpsc::{Permit, Sender}, + oneshot, +}; use super::{ config::{Metadata, TransactionConfig, TransactionState}, @@ -169,26 +172,23 @@ impl RecvTransaction { self.timer.until_timeout() } - pub(crate) fn send_pdu( - &mut self, - transport_tx: &Sender<(VariableID, PDU)>, - ) -> TransactionResult<()> { + pub(crate) fn send_pdu(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if self.prompt.is_some() { - self.answer_prompt(transport_tx)?; + self.answer_prompt(permit)?; } else { match self.recv_state { RecvState::ReceiveData => { if self.ack.is_some() { - self.send_ack_eof(transport_tx)?; + self.send_ack_eof(permit)?; } else if !self.naks.is_empty() { - self.send_naks(transport_tx)?; + self.send_naks(permit)?; } } RecvState::Finished | RecvState::Cancelled => { if self.ack.is_some() { - self.send_ack_eof(transport_tx)?; + self.send_ack_eof(permit)?; } else if self.finished.as_ref().map_or(false, |x| x.1) { - self.send_finished(transport_tx)?; + self.send_finished(permit)?; } } } @@ -240,12 +240,12 @@ impl RecvTransaction { Ok(()) } - fn answer_prompt(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + fn answer_prompt(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if let Some(prompt) = self.prompt.take() { match prompt.nak_or_keep_alive { NakOrKeepAlive::Nak => { self.naks = self.get_all_naks(); - self.send_naks(transport_tx)?; + self.send_naks(permit)?; } NakOrKeepAlive::KeepAlive => { let progress = self.get_progress(); @@ -264,7 +264,7 @@ impl RecvTransaction { let destination = header.source_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); } } } @@ -309,12 +309,12 @@ impl RecvTransaction { } } - pub fn send_report(&self, sender: Option>) -> TransactionResult<()> { + pub fn send_report(&self, sender: Option>) -> TransactionResult<()> { let report = self.generate_report(); if let Some(channel) = sender { - channel.send(report.clone())?; + let _ = channel.send(report.clone()); } - self.indication_tx.send(Indication::Report(report))?; + self.send_indication(Indication::Report(report)); Ok(()) } @@ -483,11 +483,10 @@ impl RecvTransaction { self.timer.inactivity.pause(); self.state = TransactionState::Suspended; - self.indication_tx - .send(Indication::Suspended(SuspendIndication { - id: self.id(), - condition: self.condition, - }))?; + self.send_indication(Indication::Suspended(SuspendIndication { + id: self.id(), + condition: self.condition, + })); Ok(()) } @@ -504,11 +503,10 @@ impl RecvTransaction { RecvState::Finished | RecvState::Cancelled => self.timer.reset_ack(), } self.state = TransactionState::Active; - self.indication_tx - .send(Indication::Resumed(ResumeIndication { - id: self.id(), - progress: self.get_progress(), - }))?; + self.send_indication(Indication::Resumed(ResumeIndication { + id: self.id(), + progress: self.get_progress(), + })); Ok(()) } @@ -551,7 +549,7 @@ impl RecvTransaction { transaction_status: self.status, }); } - fn send_ack_eof(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + fn send_ack_eof(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if let Some(ack) = self.ack.take() { let payload = PDUPayload::Directive(Operations::Ack(ack)); let payload_len = payload.encoded_len(self.config.file_size_flag); @@ -566,7 +564,7 @@ impl RecvTransaction { let destination = header.source_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); } Ok(()) } @@ -600,7 +598,7 @@ impl RecvTransaction { )); } - fn send_finished(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + fn send_finished(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { self.timer.restart_ack(); if let Some((finished, true)) = &self.finished { let payload = PDUPayload::Directive(Operations::Finished(finished.clone())); @@ -617,7 +615,7 @@ impl RecvTransaction { let destination = header.source_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); debug!("Transaction {0} sent Finished", self.id()); self.set_finished_flag(false); } @@ -659,7 +657,7 @@ impl RecvTransaction { naks } - fn send_naks(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + fn send_naks(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if self.nak_received_file_size == self.received_file_size { if self.timer.nak.limit_reached() { self.handle_fault(Condition::NakLimitReached)?; @@ -709,7 +707,7 @@ impl RecvTransaction { let destination = header.source_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); Ok(()) } @@ -770,14 +768,13 @@ impl RecvTransaction { out }; // send indication this transaction is finished. - self.indication_tx - .send(Indication::Finished(FinishedIndication { - id: self.id(), - report: self.generate_report(), - file_status: self.file_status, - delivery_code: self.delivery_code, - filestore_responses: self.filestore_response.clone(), - }))?; + self.send_indication(Indication::Finished(FinishedIndication { + id: self.id(), + report: self.generate_report(), + file_status: self.file_status, + delivery_code: self.delivery_code, + filestore_responses: self.filestore_response.clone(), + })); Ok(()) } @@ -796,13 +793,11 @@ impl RecvTransaction { let (offset, length) = self.store_file_data(filedata)?; - self.indication_tx.send(Indication::FileSegmentRecv( - FileSegmentIndication { - id: self.id(), - offset, - length: length as u64, - }, - ))?; + self.send_indication(Indication::FileSegmentRecv(FileSegmentIndication { + id: self.id(), + offset, + length: length as u64, + })); if self.nak_procedure == NakProcedure::Immediate && !self.eof_received() { if self.timer.nak.timeout_occured() { @@ -829,7 +824,7 @@ impl RecvTransaction { self.prepare_ack_eof(); self.checksum = Some(eof.checksum); - self.indication_tx.send(Indication::EoFRecv(self.id()))?; + self.send_indication(Indication::EoFRecv(self.id())); if self.condition == Condition::NoError { self.check_file_size(eof.file_size)?; @@ -900,7 +895,7 @@ impl RecvTransaction { .map_err(FileStoreError::UTF8)? .into(); - self.indication_tx.send(Indication::MetadataRecv( + self.send_indication(Indication::MetadataRecv( MetadataRecvIndication { id: self.id(), source_filename: source_filename.clone(), @@ -909,7 +904,7 @@ impl RecvTransaction { transmission_mode: self.config.transmission_mode, user_messages: message_to_user.clone().collect(), }, - ))?; + )); self.metadata = Some(Metadata { source_filename, @@ -958,13 +953,11 @@ impl RecvTransaction { match payload { PDUPayload::FileData(filedata) => { let (offset, length) = self.store_file_data(filedata)?; - self.indication_tx.send(Indication::FileSegmentRecv( - FileSegmentIndication { - id: self.id(), - offset, - length: length as u64, - }, - ))?; + self.send_indication(Indication::FileSegmentRecv(FileSegmentIndication { + id: self.id(), + offset, + length: length as u64, + })); Ok(()) } PDUPayload::Directive(operation) => { @@ -996,7 +989,7 @@ impl RecvTransaction { self.condition = eof.condition; self.checksum = Some(eof.checksum); - self.indication_tx.send(Indication::EoFRecv(self.id()))?; + self.send_indication(Indication::EoFRecv(self.id())); if self.condition == Condition::NoError { self.check_file_size(eof.file_size)?; @@ -1044,7 +1037,7 @@ impl RecvTransaction { .map_err(FileStoreError::UTF8)? .into(); - self.indication_tx.send(Indication::MetadataRecv( + self.send_indication(Indication::MetadataRecv( MetadataRecvIndication { id: self.id(), source_filename: source_filename.clone(), @@ -1053,7 +1046,7 @@ impl RecvTransaction { transmission_mode: self.config.transmission_mode, user_messages: message_to_user.clone().collect(), }, - ))?; + )); self.metadata = Some(Metadata { source_filename, @@ -1121,6 +1114,12 @@ impl RecvTransaction { } Ok(()) } + + //send the indication in another task, no need to wait for it + fn send_indication(&self, indication: Indication) { + let tx = self.indication_tx.clone(); + tokio::task::spawn(async move { tx.send(indication).await }); + } } #[cfg(test)] @@ -1141,10 +1140,9 @@ mod test { use super::*; use camino::{Utf8Path, Utf8PathBuf}; - use crossbeam_channel::unbounded; use rstest::{fixture, rstest}; - use std::thread; use tempfile::TempDir; + use tokio::sync::mpsc::channel; #[fixture] #[once] @@ -1154,14 +1152,14 @@ mod test { #[rstest] fn header(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _) = unbounded(); + let (indication_tx, _) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let payload_len = 12; let expected = PDUHeader { version: U3::One, @@ -1190,14 +1188,14 @@ mod test { #[rstest] fn test_if_file_transfer(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _) = unbounded(); + let (indication_tx, _) = channel(10); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); assert_eq!( TransactionStatus::Undefined, transaction.get_status().clone() @@ -1220,14 +1218,14 @@ mod test { #[rstest] fn store_filedata(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _) = unbounded(); + let (indication_tx, _) = channel(10); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let input = vec![0, 5, 255, 99]; let data = FileDataPDU::Unsegmented(UnsegmentedFileData { @@ -1251,7 +1249,7 @@ mod test { #[rstest] fn finalize_file(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _) = unbounded(); + let (indication_tx, _) = channel(10); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), @@ -1268,7 +1266,7 @@ mod test { config, NakProcedure::Deferred, filestore.clone(), - message_tx, + indication_tx, ); transaction.metadata = Some(Metadata { closure_requested: false, @@ -1309,13 +1307,14 @@ mod test { } #[rstest] - fn test_naks( + #[tokio::test] + async fn test_naks( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(FileSizeFlag::Small, FileSizeFlag::Large)] file_size_flag: FileSizeFlag, ) { - let (transport_tx, transport_rx) = unbounded(); - let (message_tx, _) = unbounded(); + let (transport_tx, mut transport_rx) = channel(10); + let (indication_tx, _) = channel(10); let mut config = default_config.clone(); config.file_size_flag = file_size_flag; let file_size = match &file_size_flag { @@ -1327,7 +1326,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let input = vec![0, 5, 255, 99]; let data = FileDataPDU::Unsegmented(UnsegmentedFileData { @@ -1370,9 +1369,11 @@ mod test { let pdu = PDU { header, payload }; transaction.naks = transaction.get_all_naks(); - transaction.send_naks(&transport_tx).unwrap(); + transaction + .send_naks(transport_tx.reserve().await.unwrap()) + .unwrap(); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.source_entity_id; assert_eq!(expected_id, destination_id); @@ -1380,14 +1381,15 @@ mod test { } #[rstest] - fn cancel_receive( + #[tokio::test] + async fn cancel_receive( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(TransmissionMode::Unacknowledged, TransmissionMode::Acknowledged)] transmission_mode: TransmissionMode, ) { - let (transport_tx, transport_rx) = unbounded(); - let (message_tx, _) = unbounded(); + let (transport_tx, mut transport_rx) = channel(1); + let (indication_tx, _) = channel(1); let mut config = default_config.clone(); config.transmission_mode = transmission_mode; @@ -1398,7 +1400,7 @@ mod test { config.clone(), NakProcedure::Deferred, filestore, - message_tx, + indication_tx, ); let path = { @@ -1441,10 +1443,12 @@ mod test { assert_eq!(TransactionState::Terminated, transaction.state); } - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); if config.transmission_mode == TransmissionMode::Acknowledged { - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.source_entity_id; assert_eq!(expected_id, destination_id); @@ -1453,15 +1457,16 @@ mod test { } #[rstest] - fn suspend(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _message_rx) = unbounded(); + #[tokio::test] + async fn suspend(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (indication_tx, _indication_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); transaction.timer.restart_ack(); transaction.timer.restart_inactivity(); @@ -1490,15 +1495,16 @@ mod test { } #[rstest] - fn resume(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _message_rx) = unbounded(); + #[tokio::test] + async fn resume(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (indication_tx, _indication_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); transaction.timer.restart_ack(); transaction.timer.restart_inactivity(); @@ -1542,9 +1548,10 @@ mod test { } #[rstest] - fn send_ack_eof(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); - let (message_tx, _) = unbounded(); + #[tokio::test] + async fn send_ack_eof(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(10); + let (indication_tx, _) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( @@ -1554,7 +1561,7 @@ mod test { config.clone(), NakProcedure::Deferred, filestore, - message_tx, + indication_tx, ); let path = { @@ -1592,12 +1599,14 @@ mod test { ); let pdu = PDU { header, payload }; - thread::spawn(move || { + tokio::task::spawn(async move { transaction.prepare_ack_eof(); - transaction.send_ack_eof(&transport_tx).unwrap(); + transaction + .send_ack_eof(transport_tx.reserve().await.unwrap()) + .unwrap(); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.source_entity_id; assert_eq!(expected_id, destination_id); @@ -1605,8 +1614,9 @@ mod test { } #[rstest] - fn finalize_receive(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (message_tx, _message_rx) = unbounded(); + #[tokio::test] + async fn finalize_receive(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (indication_tx, _indication_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( @@ -1616,7 +1626,7 @@ mod test { config, NakProcedure::Deferred, filestore.clone(), - message_tx, + indication_tx, ); let path = { @@ -1709,7 +1719,7 @@ mod test { )] operation: Operations, ) { - let (message_tx, _) = unbounded(); + let (indication_tx, _) = channel(1); let mut config = default_config.clone(); config.transmission_mode = TransmissionMode::Unacknowledged; @@ -1717,7 +1727,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let path = { let mut path = Utf8PathBuf::new(); @@ -1795,7 +1805,7 @@ mod test { )] operation: Operations, ) { - let (message_tx, _) = unbounded(); + let (indication_tx, _) = channel(1); let mut config = default_config.clone(); config.transmission_mode = TransmissionMode::Acknowledged; @@ -1803,7 +1813,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let path = { let mut path = Utf8PathBuf::new(); @@ -1845,13 +1855,14 @@ mod test { } #[rstest] - fn recv_store_data( + #[tokio::test] + async fn recv_store_data( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(TransmissionMode::Unacknowledged, TransmissionMode::Acknowledged)] transmission_mode: TransmissionMode, ) { - let (message_tx, _message_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = transmission_mode; @@ -1859,7 +1870,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let path = { let mut path = Utf8PathBuf::new(); @@ -1908,13 +1919,14 @@ mod test { } #[rstest] - fn recv_store_metadata( + #[tokio::test] + async fn recv_store_metadata( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(TransmissionMode::Unacknowledged, TransmissionMode::Acknowledged)] transmission_mode: TransmissionMode, ) { - let (message_tx, message_rx) = unbounded(); + let (indication_tx, mut indication_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = transmission_mode; @@ -1922,7 +1934,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let path = { let mut path = Utf8PathBuf::new(); @@ -1973,7 +1985,7 @@ mod test { transaction.process_pdu(pdu).unwrap(); assert_eq!(expected, transaction.metadata); - let indication = message_rx.recv().unwrap(); + let indication = indication_rx.recv().await.unwrap(); if let Indication::MetadataRecv(MetadataRecvIndication { user_messages: user_msg, .. @@ -1987,14 +1999,15 @@ mod test { } #[rstest] - fn recv_eof_all_data( + #[tokio::test] + async fn recv_eof_all_data( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(TransmissionMode::Unacknowledged, TransmissionMode::Acknowledged)] transmission_mode: TransmissionMode, ) { - let (transport_tx, transport_rx) = unbounded(); - let (message_tx, _message_rx) = unbounded(); + let (transport_tx, mut transport_rx) = channel(10); + let (indication_tx, _indication_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = transmission_mode; @@ -2004,7 +2017,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let path = { let mut path = Utf8PathBuf::new(); @@ -2096,10 +2109,14 @@ mod test { transaction.process_pdu(file_pdu).unwrap(); transaction.process_pdu(input_pdu).unwrap(); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); if transaction.config.transmission_mode == TransmissionMode::Acknowledged { - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); assert!(transaction.timer.ack.is_ticking()); let ack_fin = { @@ -2147,26 +2164,27 @@ mod test { PDU { header, payload } }; - let (destination_id, end_of_file) = transport_rx.recv().unwrap(); + let (destination_id, end_of_file) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(eof_pdu, end_of_file) } - let (destination_id, finished_pdu) = transport_rx.recv().unwrap(); + let (destination_id, finished_pdu) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(expected_pdu, finished_pdu) } #[rstest] - fn recv_prompt( + #[tokio::test] + async fn recv_prompt( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(NakOrKeepAlive::Nak, NakOrKeepAlive::KeepAlive)] nak_or_keep_alive: NakOrKeepAlive, ) { - let (transport_tx, transport_rx) = unbounded(); - let (message_tx, _message_rx) = unbounded(); + let (transport_tx, mut transport_rx) = channel(10); + let (indication_tx, _indication_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = TransmissionMode::Acknowledged; @@ -2176,7 +2194,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); let path = { let mut buf = Utf8PathBuf::new(); buf.push("test_file"); @@ -2251,17 +2269,20 @@ mod test { transaction.process_pdu(prompt_pdu).unwrap(); assert!(transaction.has_pdu_to_send()); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(expected_pdu, received_pdu) } #[rstest] - fn nak_split(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); - let (message_tx, _message_rx) = unbounded(); + #[tokio::test] + async fn nak_split(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(10); + let (indication_tx, _indication_rx) = channel(10); let mut config = default_config.clone(); config.file_size_segment = 16; config.transmission_mode = TransmissionMode::Acknowledged; @@ -2272,7 +2293,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); let mut transaction = - RecvTransaction::new(config, NakProcedure::Deferred, filestore, message_tx); + RecvTransaction::new(config, NakProcedure::Deferred, filestore, indication_tx); transaction.nak_procedure = NakProcedure::Immediate; let file_pdu1 = { @@ -2353,14 +2374,18 @@ mod test { transaction.process_pdu(file_pdu2).unwrap(); assert!(transaction.has_pdu_to_send()); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); assert!(transaction.has_pdu_to_send()); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(expected_nak1, received_pdu); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(expected_nak2, received_pdu) } diff --git a/cfdp-core/src/transaction/send.rs b/cfdp-core/src/transaction/send.rs index e30bcd2..3e71a17 100644 --- a/cfdp-core/src/transaction/send.rs +++ b/cfdp-core/src/transaction/send.rs @@ -6,8 +6,11 @@ use std::{ time::Duration, }; -use crossbeam_channel::Sender; use log::{debug, info}; +use tokio::sync::{ + mpsc::{Permit, Sender}, + oneshot, +}; use super::{ config::{Metadata, TransactionConfig, TransactionState}, @@ -135,7 +138,7 @@ impl SendTransaction { indication_tx, send_eof_indication: true, }; - me.indication_tx.send(Indication::Transaction(me.id()))?; + me.send_indication(Indication::Transaction(me.id())); Ok(me) } @@ -161,14 +164,14 @@ impl SendTransaction { pub(crate) fn send_pdu( &mut self, - transport_tx: &Sender<(VariableID, PDU)>, + permit: Permit<'_, (VariableID, PDU)>, ) -> TransactionResult<()> { if self.prompt.is_some() { - self.send_prompt(transport_tx)?; + self.send_prompt(permit)?; } else { match self.send_state { SendState::SendMetadata => { - self.send_metadata(transport_tx)?; + self.send_metadata(permit)?; if self.is_file_transfer() { self.send_state = SendState::SendData; } else { @@ -177,34 +180,31 @@ impl SendTransaction { } } SendState::SendData => { - while !transport_tx.is_full() { - if !self.naks.is_empty() { - // if we have received a NAK send the missing data - self.send_missing_data(transport_tx)?; - } else { - self.send_file_segment(None, None, transport_tx)? - } + if !self.naks.is_empty() { + // if we have received a NAK send the missing data + self.send_missing_data(permit)?; + } else { + self.send_file_segment(None, None, permit)? + } - let handle = self.get_handle()?; - if handle.stream_position().map_err(FileStoreError::IO)? - == handle.metadata().map_err(FileStoreError::IO)?.len() - { - self.prepare_eof(None)?; - self.send_state = SendState::SendEof; - break; - } + let handle = self.get_handle()?; + if handle.stream_position().map_err(FileStoreError::IO)? + == handle.metadata().map_err(FileStoreError::IO)?.len() + { + self.prepare_eof(None)?; + self.send_state = SendState::SendEof; } } SendState::SendEof => { if !self.naks.is_empty() { // if we have received a NAK send the missing data - self.send_missing_data(transport_tx)?; + self.send_missing_data(permit)?; } else { - self.send_eof(transport_tx)?; + self.send_eof(permit)?; // EoFSent indication only needs to be sent for the initial EoF transmission if self.send_eof_indication { - self.indication_tx.send(Indication::EoFSent(self.id()))?; + self.send_indication(Indication::EoFSent(self.id())); self.send_eof_indication = false; } @@ -212,25 +212,23 @@ impl SendTransaction { // if closure is not requested, this transaction is finished. // indicate as much to the User if !self.metadata.closure_requested { - self.indication_tx.send(Indication::Finished( - FinishedIndication { - id: self.id(), - report: self.generate_report(), - filestore_responses: vec![], - file_status: self.file_status, - delivery_code: self.delivery_code, - }, - ))?; + self.send_indication(Indication::Finished(FinishedIndication { + id: self.id(), + report: self.generate_report(), + filestore_responses: vec![], + file_status: self.file_status, + delivery_code: self.delivery_code, + })); } self.shutdown(); } } } SendState::Cancelled => { - self.send_eof(transport_tx)?; + self.send_eof(permit)?; } SendState::Finished => { - self.send_ack(transport_tx)?; + self.send_ack(permit)?; } } } @@ -289,12 +287,13 @@ impl SendTransaction { } } - pub fn send_report(&self, sender: Option>) -> TransactionResult<()> { + pub fn send_report(&self, sender: Option>) -> TransactionResult<()> { let report = self.generate_report(); + if let Some(channel) = sender { - channel.send(report.clone())?; + let _ = channel.send(report.clone()); } - self.indication_tx.send(Indication::Report(report))?; + self.send_indication(Indication::Report(report)); Ok(()) } @@ -421,7 +420,7 @@ impl SendTransaction { &mut self, offset: Option, length: Option, - transport_tx: &Sender<(VariableID, PDU)>, + permit: Permit<(VariableID, PDU)>, ) -> TransactionResult<()> { let (offset, file_data) = self.get_file_segment(offset, length)?; @@ -447,14 +446,14 @@ impl SendTransaction { ); let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); Ok(()) } pub fn send_missing_data( &mut self, - transport_tx: &Sender<(VariableID, PDU)>, + permit: Permit<(VariableID, PDU)>, ) -> TransactionResult<()> { match self.naks.pop_front() { Some(request) => { @@ -466,14 +465,14 @@ impl SendTransaction { ); match offset == 0 && length == 0 { - true => self.send_metadata(transport_tx), + true => self.send_metadata(permit), false => { let current_pos = { let handle = self.get_handle()?; handle.stream_position().map_err(FileStoreError::IO)? }; - self.send_file_segment(Some(offset), Some(length), transport_tx)?; + self.send_file_segment(Some(offset), Some(length), permit)?; // restore to original location in the file let handle = self.get_handle()?; @@ -502,7 +501,7 @@ impl SendTransaction { Ok(()) } - pub fn send_eof(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + pub fn send_eof(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if let Some((eof, true)) = &self.eof { self.timer.restart_ack(); @@ -518,7 +517,7 @@ impl SendTransaction { let destination = header.destination_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); debug!("Transaction {0} sent EndOfFile.", self.id()); self.set_eof_flag(false); } @@ -538,10 +537,7 @@ impl SendTransaction { }) } - pub fn send_prompt( - &mut self, - transport_tx: &Sender<(VariableID, PDU)>, - ) -> TransactionResult<()> { + pub fn send_prompt(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if let Some(prompt) = self.prompt.take() { self.timer.restart_ack(); @@ -557,7 +553,7 @@ impl SendTransaction { let destination = header.destination_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); debug!("Transaction {0} sent PromptPdu.", self.id()); } Ok(()) @@ -592,11 +588,11 @@ impl SendTransaction { self.timer.inactivity.pause(); self.state = TransactionState::Suspended; - self.indication_tx - .send(Indication::Suspended(SuspendIndication { - id: self.id(), - condition: self.condition, - }))?; + self.send_indication(Indication::Suspended(SuspendIndication { + id: self.id(), + condition: self.condition, + })); + Ok(()) } @@ -620,11 +616,10 @@ impl SendTransaction { false => 0, }; - self.indication_tx - .send(Indication::Resumed(ResumeIndication { - id: self.id(), - progress, - }))?; + self.send_indication(Indication::Resumed(ResumeIndication { + id: self.id(), + progress, + })); Ok(()) } @@ -669,7 +664,7 @@ impl SendTransaction { }); } - fn send_ack(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + fn send_ack(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { if let Some(ack) = self.ack.take() { let payload = PDUPayload::Directive(Operations::Ack(ack)); let payload_len = payload.encoded_len(self.config.file_size_flag); @@ -684,7 +679,7 @@ impl SendTransaction { let destination = header.destination_entity_id; let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); debug!("Transaction {0} sent Ack(Finished).", self.id()); self.shutdown(); } @@ -735,14 +730,13 @@ impl SendTransaction { self.condition ); - self.indication_tx - .send(Indication::Finished(FinishedIndication { - id: self.id(), - report: self.generate_report(), - filestore_responses: finished.filestore_response, - file_status: self.file_status, - delivery_code: self.delivery_code, - }))?; + self.send_indication(Indication::Finished(FinishedIndication { + id: self.id(), + report: self.generate_report(), + filestore_responses: finished.filestore_response, + file_status: self.file_status, + delivery_code: self.delivery_code, + })); match self.condition != Condition::NoError { true => { @@ -865,14 +859,13 @@ impl SendTransaction { self.condition = finished.condition; self.delivery_code = finished.delivery_code; - self.indication_tx - .send(Indication::Finished(FinishedIndication { - id: self.id(), - report: self.generate_report(), - filestore_responses: finished.filestore_response, - file_status: self.file_status, - delivery_code: self.delivery_code, - }))?; + self.send_indication(Indication::Finished(FinishedIndication { + id: self.id(), + report: self.generate_report(), + filestore_responses: finished.filestore_response, + file_status: self.file_status, + delivery_code: self.delivery_code, + })); self.shutdown(); Ok(()) @@ -893,7 +886,7 @@ impl SendTransaction { } } - fn send_metadata(&mut self, transport_tx: &Sender<(VariableID, PDU)>) -> TransactionResult<()> { + fn send_metadata(&mut self, permit: Permit<(VariableID, PDU)>) -> TransactionResult<()> { let destination = self.config.destination_entity_id; let metadata = MetadataPDU { closure_requested: self.metadata.closure_requested, @@ -933,10 +926,16 @@ impl SendTransaction { let pdu = PDU { header, payload }; - transport_tx.send((destination, pdu))?; + permit.send((destination, pdu)); debug!("Transaction {0} sent Metadata.", self.id()); Ok(()) } + + //send the indication in another task, no need to wait for it + fn send_indication(&self, indication: Indication) { + let tx = self.indication_tx.clone(); + tokio::task::spawn(async move { tx.send(indication).await }); + } } #[cfg(test)] @@ -957,10 +956,9 @@ mod test { use super::*; use camino::{Utf8Path, Utf8PathBuf}; - use crossbeam_channel::unbounded; use rstest::{fixture, rstest}; - use std::thread; use tempfile::TempDir; + use tokio::sync::mpsc::channel; #[fixture] #[once] fn tempdir_fixture() -> TempDir { @@ -968,13 +966,14 @@ mod test { } #[rstest] - fn header(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + #[tokio::test] + async fn header(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(10, Utf8PathBuf::from("")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx) .expect("unable to start transaction."); @@ -1006,13 +1005,14 @@ mod test { } #[rstest] - fn test_if_file_transfer(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + #[tokio::test] + async fn test_if_file_transfer(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(600_u64, Utf8PathBuf::from("a")); let transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1025,8 +1025,9 @@ mod test { } #[rstest] - fn send_filedata(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); + #[tokio::test] + async fn send_filedata(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), @@ -1034,7 +1035,7 @@ mod test { let path = Utf8PathBuf::from("testfile"); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(10, path.clone()); let mut transaction = SendTransaction::new(config, metadata, filestore.clone(), indication_tx).unwrap(); @@ -1055,7 +1056,7 @@ mod test { ); let pdu = PDU { header, payload }; - thread::spawn(move || { + tokio::task::spawn(async move { let fname = transaction.metadata.source_filename.clone(); { @@ -1076,10 +1077,14 @@ mod test { let length = 4; transaction - .send_file_segment(Some(offset), Some(length as u16), &transport_tx) + .send_file_segment( + Some(offset), + Some(length as u16), + transport_tx.reserve().await.unwrap(), + ) .unwrap(); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.destination_entity_id; assert_eq!(expected_id, destination_id); assert_eq!(pdu, received_pdu); @@ -1089,12 +1094,13 @@ mod test { #[rstest] #[case(SegmentRequestForm { start_offset: 6, end_offset: 10 })] #[case(SegmentRequestForm { start_offset: 0, end_offset: 0 })] - fn send_missing( + #[tokio::test] + async fn send_missing( #[case] nak: SegmentRequestForm, default_config: &TransactionConfig, tempdir_fixture: &TempDir, ) { - let (transport_tx, transport_rx) = unbounded(); + let (transport_tx, mut transport_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), @@ -1103,7 +1109,7 @@ mod test { let path = Utf8PathBuf::from("testfile_missing"); let metadata = test_metadata(10, path.clone()); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1148,7 +1154,7 @@ mod test { } }; - thread::spawn(move || { + tokio::task::spawn(async move { let fname = transaction.metadata.source_filename.clone(); { @@ -1169,27 +1175,30 @@ mod test { } transaction.naks.push_back(nak); - transaction.send_missing_data(&transport_tx).unwrap(); + transaction + .send_missing_data(transport_tx.reserve().await.unwrap()) + .unwrap(); transaction .filestore .delete_file(path.clone()) .expect("cannot remove file"); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.destination_entity_id; assert_eq!(expected_id, destination_id); assert_eq!(pdu, received_pdu); } #[rstest] - fn checksum_cache(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + #[tokio::test] + async fn checksum_cache(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(1); let metadata = test_metadata(10, Utf8PathBuf::from("")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1206,8 +1215,9 @@ mod test { } #[rstest] - fn send_eof(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); + #[tokio::test] + async fn send_eof(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), @@ -1215,7 +1225,7 @@ mod test { let input = "Here is some test data to write!$*#*.\n"; - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let path = Utf8PathBuf::from("test_eof.dat"); let metadata = test_metadata(input.as_bytes().len() as u64, path.clone()); let mut transaction = @@ -1249,7 +1259,7 @@ mod test { ); let pdu = PDU { header, payload }; - thread::spawn(move || { + tokio::task::spawn(async move { let fname = transaction.metadata.source_filename.clone(); { @@ -1263,10 +1273,12 @@ mod test { handle.sync_all().expect("Bad file sync."); } transaction.prepare_eof(None).unwrap(); - transaction.send_eof(&transport_tx).unwrap() + transaction + .send_eof(transport_tx.reserve().await.unwrap()) + .unwrap() }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.destination_entity_id; assert_eq!(expected_id, destination_id); @@ -1276,13 +1288,14 @@ mod test { } #[rstest] - fn cancel_send( + #[tokio::test] + async fn cancel_send( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(TransmissionMode::Unacknowledged, TransmissionMode::Acknowledged)] transmission_mode: TransmissionMode, ) { - let (transport_tx, transport_rx) = unbounded(); + let (transport_tx, mut transport_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = transmission_mode; @@ -1292,7 +1305,7 @@ mod test { let input = "Here is some test data to write!$*#*.\n"; - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let path = Utf8PathBuf::from(format!("test_eof_{:}.dat", config.transmission_mode as u8)); let metadata = test_metadata(input.as_bytes().len() as u64, path.clone()); let mut transaction = @@ -1327,7 +1340,7 @@ mod test { ); let pdu = PDU { header, payload }; - thread::spawn(move || { + tokio::task::spawn(async move { let fname = transaction.metadata.source_filename.clone(); { @@ -1345,14 +1358,16 @@ mod test { } transaction.cancel().unwrap(); assert!(transaction.has_pdu_to_send()); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); if transaction.config.transmission_mode == TransmissionMode::Unacknowledged { assert_eq!(TransactionStatus::Terminated, transaction.status); } }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.destination_entity_id; assert_eq!(expected_id, destination_id); @@ -1362,13 +1377,14 @@ mod test { } #[rstest] - fn suspend(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + #[tokio::test] + async fn suspend(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(0, Utf8PathBuf::from("")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1392,13 +1408,14 @@ mod test { } #[rstest] - fn resume(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + #[tokio::test] + async fn resume(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(0, Utf8PathBuf::from("")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1429,8 +1446,9 @@ mod test { } #[rstest] - fn send_ack_fin(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); + #[tokio::test] + async fn send_ack_fin(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( @@ -1440,7 +1458,7 @@ mod test { let path = Utf8PathBuf::from(format!("test_eof_{:}.dat", config.transmission_mode as u8)); let input = "Here is some test data to write!$*#*.\n"; - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(input.as_bytes().len() as u64, path); let mut transaction = SendTransaction::new(config.clone(), metadata, filestore, indication_tx).unwrap(); @@ -1462,12 +1480,14 @@ mod test { ); let pdu = PDU { header, payload }; - thread::spawn(move || { + tokio::task::spawn(async move { transaction.prepare_ack(); - transaction.send_ack(&transport_tx).unwrap(); + transaction + .send_ack(transport_tx.reserve().await.unwrap()) + .unwrap(); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.destination_entity_id; assert_eq!(expected_id, destination_id); @@ -1475,7 +1495,8 @@ mod test { } #[rstest] - fn pdu_error_unack_send( + #[tokio::test] + async fn pdu_error_unack_send( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values( @@ -1535,7 +1556,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(600, Utf8PathBuf::from("Test_file.txt")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1563,7 +1584,8 @@ mod test { } #[rstest] - fn pdu_error_ack_send( + #[tokio::test] + async fn pdu_error_ack_send( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values( @@ -1600,7 +1622,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(600, Utf8PathBuf::from("Test_file.txt")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1628,7 +1650,8 @@ mod test { } #[rstest] - fn pdu_error_send_data( + #[tokio::test] + async fn pdu_error_send_data( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(TransmissionMode::Unacknowledged, TransmissionMode::Acknowledged)] @@ -1641,7 +1664,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(600, Utf8PathBuf::from("Test_file.txt")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1664,8 +1687,9 @@ mod test { assert_err!(result, Err(TransactionError::UnexpectedPDU(_, _, _))) } #[rstest] - fn recv_finished(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); + #[tokio::test] + async fn recv_finished(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = TransmissionMode::Acknowledged; @@ -1675,7 +1699,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let path = Utf8PathBuf::from("test_file"); let metadata = test_metadata(600, path); let mut transaction = @@ -1719,19 +1743,22 @@ mod test { PDU { header, payload } }; - thread::spawn(move || { + tokio::task::spawn(async move { transaction.process_pdu(finished_pdu).unwrap(); assert!(transaction.has_pdu_to_send()); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(expected_pdu, received_pdu) } #[rstest] - fn recv_nak( + #[tokio::test] + async fn recv_nak( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(2_u64, u32::MAX as u64 + 100_u64)] total_size: u64, @@ -1749,7 +1776,7 @@ mod test { Utf8Path::from_path(tempdir_fixture.path()).expect("Unable to make utf8 tempdir"), )); - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(total_size, Utf8PathBuf::from("test_file")); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1797,8 +1824,9 @@ mod test { } #[rstest] - fn recv_ack_eof(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { - let (transport_tx, transport_rx) = unbounded(); + #[tokio::test] + async fn recv_ack_eof(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + let (transport_tx, mut transport_rx) = channel(1); let mut config = default_config.clone(); config.transmission_mode = TransmissionMode::Acknowledged; @@ -1820,7 +1848,7 @@ mod test { checksum_type: ChecksumType::Null, }; - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); @@ -1863,21 +1891,24 @@ mod test { PDU { header, payload } }; - thread::spawn(move || { + tokio::task::spawn(async move { transaction.prepare_eof(None).unwrap(); - transaction.send_eof(&transport_tx).unwrap(); + transaction + .send_eof(transport_tx.reserve().await.unwrap()) + .unwrap(); assert!(transaction.timer.ack.is_ticking()); transaction.process_pdu(ack_pdu).unwrap(); assert!(!transaction.timer.ack.is_ticking()); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); assert_eq!(expected_id, destination_id); assert_eq!(expected_pdu, received_pdu) } #[rstest] - fn recv_keepalive(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { + #[tokio::test] + async fn recv_keepalive(default_config: &TransactionConfig, tempdir_fixture: &TempDir) { let mut config = default_config.clone(); config.transmission_mode = TransmissionMode::Acknowledged; @@ -1897,7 +1928,7 @@ mod test { checksum_type: ChecksumType::Null, }; - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let mut transaction = SendTransaction::new(config, metadata, filestore, indication_tx).unwrap(); transaction.checksum = Some(0); @@ -1935,12 +1966,13 @@ mod test { } #[rstest] - fn send_prompt( + #[tokio::test] + async fn send_prompt( default_config: &TransactionConfig, tempdir_fixture: &TempDir, #[values(NakOrKeepAlive::Nak, NakOrKeepAlive::KeepAlive)] option: NakOrKeepAlive, ) { - let (transport_tx, transport_rx) = unbounded(); + let (transport_tx, mut transport_rx) = channel(1); let config = default_config.clone(); let filestore = Arc::new(NativeFileStore::new( @@ -1950,7 +1982,7 @@ mod test { let path = Utf8PathBuf::from(format!("test_eof_{:}.dat", config.transmission_mode as u8)); let input = "Here is some test data to write!$*#*.\n"; - let (indication_tx, _indication_rx) = unbounded(); + let (indication_tx, _indication_rx) = channel(10); let metadata = test_metadata(input.as_bytes().len() as u64, path); let mut transaction = SendTransaction::new(config.clone(), metadata, filestore, indication_tx).unwrap(); @@ -1969,12 +2001,14 @@ mod test { ); let pdu = PDU { header, payload }; - thread::spawn(move || { + tokio::task::spawn(async move { transaction.prepare_prompt(option); - transaction.send_pdu(&transport_tx).unwrap(); + transaction + .send_pdu(transport_tx.reserve().await.unwrap()) + .unwrap(); }); - let (destination_id, received_pdu) = transport_rx.recv().unwrap(); + let (destination_id, received_pdu) = transport_rx.recv().await.unwrap(); let expected_id = default_config.destination_entity_id; assert_eq!(expected_id, destination_id); diff --git a/cfdp-core/src/transport.rs b/cfdp-core/src/transport.rs index e3fac32..90126d9 100644 --- a/cfdp-core/src/transport.rs +++ b/cfdp-core/src/transport.rs @@ -1,26 +1,31 @@ use std::{ collections::HashMap, + fmt::Debug, io::{Error as IoError, ErrorKind}, - net::{SocketAddr, ToSocketAddrs, UdpSocket}, + net::SocketAddr, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, - thread, time::Duration, }; -use crossbeam_channel::{Receiver, Sender}; +use async_trait::async_trait; use log::error; +use tokio::{ + net::{ToSocketAddrs, UdpSocket}, + sync::mpsc::{Receiver, Sender}, +}; use crate::pdu::{PDUEncode, VariableID, PDU}; /// Transports are designed to run in a thread in the background /// inside a [Daemon](crate::daemon::Daemon) process +#[async_trait] pub trait PDUTransport { /// Send input PDU to the remote /// The implementation must have a method to lookup an Entity's address from the ID - fn request(&mut self, destination: VariableID, pdu: PDU) -> Result<(), IoError>; + async fn request(&mut self, destination: VariableID, pdu: PDU) -> Result<(), IoError>; /// Provides logic for listening for incoming PDUs and sending any outbound PDUs @@ -30,11 +35,11 @@ pub trait PDUTransport { /// The [Daemon](crate::daemon::Daemon) is responsible for receiving messages and distribute them to each /// transaction [Send](crate::transaction::SendTransaction) or [Recv](crate::transaction::RecvTransaction) /// The signal is used to indicate a shutdown operation was requested. - fn pdu_handler( + async fn pdu_handler( &mut self, signal: Arc, sender: Sender, - recv: Receiver<(VariableID, PDU)>, + mut recv: Receiver<(VariableID, PDU)>, ) -> Result<(), IoError>; } @@ -45,14 +50,11 @@ pub struct UdpTransport { entity_map: HashMap, } impl UdpTransport { - pub fn new( + pub async fn new( addr: T, entity_map: HashMap, ) -> Result { - let socket = UdpSocket::bind(addr)?; - socket.set_read_timeout(Some(Duration::from_secs(1)))?; - socket.set_write_timeout(Some(Duration::from_secs(1)))?; - socket.set_nonblocking(true)?; + let socket = UdpSocket::bind(addr).await?; Ok(Self { socket, entity_map }) } } @@ -63,71 +65,62 @@ impl TryFrom<(UdpSocket, HashMap)> for UdpTransport { socket: inputs.0, entity_map: inputs.1, }; - me.socket.set_read_timeout(Some(Duration::from_secs(1)))?; - me.socket.set_write_timeout(Some(Duration::from_secs(1)))?; - me.socket.set_nonblocking(true)?; Ok(me) } } + +#[async_trait] impl PDUTransport for UdpTransport { - fn request(&mut self, destination: VariableID, pdu: PDU) -> Result<(), IoError> { - self.entity_map + async fn request(&mut self, destination: VariableID, pdu: PDU) -> Result<(), IoError> { + let addr = self + .entity_map .get(&destination) - .ok_or_else(|| IoError::from(ErrorKind::AddrNotAvailable)) - .and_then(|addr| { - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - }) + .ok_or_else(|| IoError::from(ErrorKind::AddrNotAvailable))?; + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ())?; + Ok(()) } - fn pdu_handler( + async fn pdu_handler( &mut self, signal: Arc, sender: Sender, - recv: Receiver<(VariableID, PDU)>, + mut recv: Receiver<(VariableID, PDU)>, ) -> Result<(), IoError> { // this buffer will be 511 KiB, should be sufficiently small; let mut buffer = vec![0_u8; u16::MAX as usize]; while !signal.load(Ordering::Relaxed) { - match self.socket.recv_from(&mut buffer) { - Ok(_n) => match PDU::decode(&mut buffer.as_slice()) { - Ok(pdu) => { - match sender.send(pdu) { - Ok(()) => {} - Err(error) => { - error!("Transport found disconnect sending channel: {}", error); - return Err(IoError::from(ErrorKind::ConnectionAborted)); - } - }; - } - Err(error) => { - error!("Error decoding PDU: {}", error); - // might need to stop depending on the error. - // some are recoverable though + tokio::select! { + Ok((_n, _addr)) = self.socket.recv_from(&mut buffer) => { + match PDU::decode(&mut buffer.as_slice()) { + Ok(pdu) => { + match sender.send(pdu).await { + Ok(()) => {} + Err(error) => { + error!("Channel to daemon severed: {}", error); + return Err(IoError::from(ErrorKind::ConnectionAborted)); + } + }; + } + Err(error) => { + error!("Error decoding PDU: {}", error); + // might need to stop depending on the error. + // some are recoverable though + } } }, - Err(ref e) - if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => - { - // continue to trying to send - } - Err(e) => { - error!("encountered IO error: {e}"); - return Err(e); - } - }; - match recv.try_recv() { - Ok((entity, pdu)) => self.request(entity, pdu)?, - Err(crossbeam_channel::TryRecvError::Empty) => { - // nothing to do here - } - Err(err @ crossbeam_channel::TryRecvError::Disconnected) => { - error!("Transport found disconnected channel: {}", err); - return Err(IoError::from(ErrorKind::ConnectionAborted)); + Some((entity, pdu)) = recv.recv() => { + self.request(entity, pdu).await?; + }, + else => { + log::info!("UdpSocket or Channel disconnected"); + break } - }; - thread::sleep(Duration::from_micros(500)) + } + // this should be at minimum made configurable + tokio::time::sleep(Duration::from_micros(100)).await; } Ok(()) } diff --git a/cfdp-core/tests/common/mod.rs b/cfdp-core/tests/common/mod.rs index 72497e1..d09fb1a 100644 --- a/cfdp-core/tests/common/mod.rs +++ b/cfdp-core/tests/common/mod.rs @@ -3,16 +3,15 @@ use std::{ fs::{self, OpenOptions}, io::{Error as IoError, ErrorKind, Write}, marker::PhantomData, - net::{SocketAddr, ToSocketAddrs, UdpSocket}, + net::SocketAddr, path::Path, sync::{ atomic::{AtomicBool, Ordering}, Arc, RwLock, }, - thread::{self, JoinHandle}, - time::Duration, }; +use async_trait::async_trait; use camino::Utf8PathBuf; use cfdp_core::{ daemon::{ @@ -30,12 +29,21 @@ use cfdp_core::{ transaction::TransactionID, transport::{PDUTransport, UdpTransport}, }; -use crossbeam_channel::{bounded, unbounded, Receiver, Sender}; + use itertools::{Either, Itertools}; use log::{error, info}; use tempfile::TempDir; use rstest::fixture; +use tokio::{ + net::{ToSocketAddrs, UdpSocket}, + runtime::{self}, + sync::{ + mpsc::{self, Receiver, Sender}, + oneshot, + }, + task::JoinHandle, +}; #[derive(Debug)] pub(crate) struct JoD<'a, T> { handle: Vec>, @@ -197,19 +205,20 @@ pub(crate) struct TestUser { // Indication listener thread indication_handle: JoinHandle<()>, history: Arc>>, + tokio_handle: tokio::runtime::Handle, } impl TestUser { pub(crate) fn new(filestore: Arc) -> Self { - let (internal_tx, internal_rx) = bounded::(1); - let (indication_tx, indication_rx) = unbounded::(); + let (internal_tx, internal_rx) = mpsc::channel::(1); + let (indication_tx, mut indication_rx) = mpsc::channel::(1000); let history = Arc::new(RwLock::new(HashMap::::new())); let auto_history = history.clone(); let auto_sender = internal_tx.clone(); - let indication_handle = thread::spawn(move || { + let indication_handle = tokio::task::spawn(async move { let mut proxy_map = HashMap::new(); - while let Ok(indication) = indication_rx.recv() { + while let Some(indication) = indication_rx.recv().await { // (origin_id, tx_mode, messages) match indication { Indication::MetadataRecv(MetadataRecvIndication { @@ -222,13 +231,14 @@ impl TestUser { categorize_user_msg(&origin_id, messages); for request in put_requests { - let (put_sender, put_recv) = bounded(1); + let (put_sender, put_recv) = oneshot::channel(); auto_sender .send(UserPrimitive::Put(request, put_sender)) + .await .expect("Unable to send auto request"); - let id = put_recv.recv().expect("Recv channel disconnected: "); + let id = put_recv.await.expect("Recv channel disconnected: "); proxy_map.insert(id, origin_id); } @@ -236,6 +246,7 @@ impl TestUser { let primitive = UserPrimitive::Cancel(id); auto_sender .send(primitive) + .await .map_err(|_| { IoError::new( ErrorKind::ConnectionReset, @@ -332,7 +343,7 @@ impl TestUser { }, }, UserRequest::RemoteStatusReport(report_request) => { - let (report_tx, report_rx) = bounded(0); + let (report_tx, report_rx) = oneshot::channel(); let id = TransactionID( report_request.source_entity_id, report_request.transaction_sequence_number, @@ -341,6 +352,7 @@ impl TestUser { auto_sender .send(primitive) + .await .map_err(|_| { IoError::new( ErrorKind::ConnectionReset, @@ -349,7 +361,7 @@ impl TestUser { }) .expect("error asking for report."); - let report = match report_rx.recv().map_err(|_| { + let report = match report_rx.await.map_err(|_| { IoError::new( ErrorKind::ConnectionReset, "Daemon Half of User disconnected.", @@ -405,6 +417,7 @@ impl TestUser { let suspend_indication = auto_sender .send(primitive) + .await .map_err(|_| { IoError::new( ErrorKind::ConnectionReset, @@ -452,6 +465,7 @@ impl TestUser { let suspend_indication = auto_sender .send(primitive) + .await .map_err(|_| { IoError::new( ErrorKind::ConnectionReset, @@ -492,9 +506,10 @@ impl TestUser { } } }; - let (sender, _recv) = bounded(0); + let (sender, _recv) = oneshot::channel(); auto_sender .send(UserPrimitive::Put(request, sender)) + .await .expect("Unable to send auto request"); } for response in responses { @@ -548,16 +563,16 @@ impl TestUser { // we should be able to connect to the socket we are running // just fine. but we can ignore errors per // CCSDS 727.0-B-5 ยง 6.2.5.1.2 - let (sender, _) = bounded(0); - let _ = - auto_sender - .send(UserPrimitive::Put(req, sender)) - .map_err(|_| { - IoError::new( - ErrorKind::ConnectionReset, - "Daemon Half of User disconnected.", - ) - }); + let (sender, _) = oneshot::channel(); + let _ = auto_sender + .send(UserPrimitive::Put(req, sender)) + .await + .map_err(|_| { + IoError::new( + ErrorKind::ConnectionReset, + "Daemon Half of User disconnected.", + ) + }); } } Indication::Report(report) => { @@ -575,6 +590,7 @@ impl TestUser { indication_tx, indication_handle, history, + tokio_handle: runtime::Handle::current(), } } @@ -585,12 +601,14 @@ impl TestUser { indication_tx, indication_handle, history, + tokio_handle, } = self; ( TestUserHalf { internal_tx, _indication_handle: indication_handle, history, + tokio_handle, }, internal_rx, indication_tx, @@ -634,24 +652,27 @@ pub struct TestUserHalf { internal_tx: Sender, _indication_handle: JoinHandle<()>, history: Arc>>, + tokio_handle: tokio::runtime::Handle, } impl TestUserHalf { #[allow(unused)] pub fn put(&self, request: PutRequest) -> Result { - let (put_send, put_recv) = bounded(1); - let primitive = UserPrimitive::Put(request, put_send); - - self.internal_tx.send(primitive).map_err(|_| { - IoError::new( - ErrorKind::ConnectionReset, - "Daemon Half of User disconnected.", - ) - })?; - put_recv.recv().map_err(|_| { - IoError::new( - ErrorKind::ConnectionReset, - "Daemon Half of User disconnected.", - ) + self.tokio_handle.block_on(async { + let (put_send, put_recv) = oneshot::channel(); + let primitive = UserPrimitive::Put(request, put_send); + + self.internal_tx.send(primitive).await.map_err(|_| { + IoError::new( + ErrorKind::ConnectionReset, + " 1 Daemon Half of User disconnected.", + ) + })?; + put_recv.await.map_err(|_| { + IoError::new( + ErrorKind::ConnectionReset, + "Daemon Half of User disconnected.", + ) + }) }) } @@ -659,40 +680,44 @@ impl TestUserHalf { // apparently related https://github.com/rust-lang/rust/issues/46379 #[allow(unused)] pub fn cancel(&self, transaction: TransactionID) -> Result<(), IoError> { - let primitive = UserPrimitive::Cancel(transaction); - self.internal_tx.send(primitive).map_err(|_| { - IoError::new( - ErrorKind::ConnectionReset, - "Daemon Half of User disconnected.", - ) + self.tokio_handle.block_on(async { + let primitive = UserPrimitive::Cancel(transaction); + self.internal_tx.send(primitive).await.map_err(|_| { + IoError::new( + ErrorKind::ConnectionReset, + "Daemon Half of User disconnected.", + ) + }) }) } #[allow(unused)] pub fn report(&self, transaction: TransactionID) -> Result, IoError> { - let (report_tx, report_rx) = bounded(1); - let primitive = UserPrimitive::Report(transaction, report_tx); - - self.internal_tx.send(primitive).map_err(|err| { - IoError::new( - ErrorKind::ConnectionReset, - format!("Daemon Half of User disconnected on send: {err}"), - ) - })?; - let response = match report_rx.recv() { - Ok(report) => Some(report), - // if the channel disconnects because the transaction is finished then just get from history. - Err(_) => self.history.read().unwrap().get(&transaction).cloned(), - }; - Ok(response) + self.tokio_handle.block_on(async { + let (report_tx, report_rx) = oneshot::channel(); + let primitive = UserPrimitive::Report(transaction, report_tx); + + self.internal_tx.send(primitive).await.map_err(|err| { + IoError::new( + ErrorKind::ConnectionReset, + format!("Daemon Half of User disconnected on send: {err}"), + ) + })?; + let response = match report_rx.await { + Ok(report) => Some(report), + // if the channel disconnects because the transaction is finished then just get from history. + Err(_) => self.history.read().unwrap().get(&transaction).cloned(), + }; + Ok(response) + }) } } impl<'a, T> Drop for JoD<'a, T> { fn drop(&mut self) { - let handle = self.handle.remove(0); - - handle.join().expect("Unable to join handle."); + for handle in self.handle.drain(..) { + handle.abort(); + } } } @@ -708,7 +733,7 @@ type DaemonType = ( type Timeouts = [Option; 3]; #[allow(clippy::too_many_arguments)] -pub(crate) fn create_daemons( +pub(crate) async fn create_daemons( filestore: Arc, local_transport_map: HashMap, Box>, remote_transport_map: HashMap, Box>, @@ -751,15 +776,14 @@ pub(crate) fn create_daemons( indication_tx, ); - let local_handle = thread::Builder::new() - .name("Local Daemon".to_string()) - .spawn(move || { - local_daemon - .manage_transactions() - .map_err(|e| e.to_string())?; - Ok(()) - }) - .expect("Unable to spwan local."); + let local_handle = tokio::task::spawn(async move { + local_daemon + .manage_transactions() + .await + .map_err(|e| e.to_string())?; + + Ok(()) + }); let remote_filestore = filestore; let remote_user = TestUser::new(remote_filestore.clone()); @@ -776,15 +800,13 @@ pub(crate) fn create_daemons( remote_indication_tx, ); - let remote_handle = thread::Builder::new() - .name("Remote Daemon".to_string()) - .spawn(move || { - remote_daemon - .manage_transactions() - .map_err(|e| e.to_string())?; - Ok(()) - }) - .expect("Unable to spawn remote."); + let remote_handle = tokio::task::spawn(async move { + remote_daemon + .manage_transactions() + .await + .map_err(|e| e.to_string())?; + Ok(()) + }); let _local_h = JoD::from(local_handle); let _remote_h: JoD<_> = JoD::from(remote_handle); @@ -792,17 +814,19 @@ pub(crate) fn create_daemons( (local_userhalf, remote_userhalf, _local_h, _remote_h) } -#[fixture] -#[once] -fn tempdir_fixture() -> TempDir { - TempDir::new().unwrap() +pub struct StaticAssets { + //we need to keep the object here because the directory is removed as soon as the object is dropped + _tempdir: TempDir, + pub filestore: Arc, + tokio_runtime: tokio::runtime::Runtime, } #[fixture] #[once] -pub(crate) fn filestore_fixture(tempdir_fixture: &TempDir) -> Arc { +pub fn static_assets() -> StaticAssets { + let tempdir = TempDir::new().unwrap(); let utf8_path = Utf8PathBuf::from( - tempdir_fixture + tempdir .path() .as_os_str() .to_str() @@ -829,7 +853,18 @@ pub(crate) fn filestore_fixture(tempdir_fixture: &TempDir) -> Arc, + static_assets: &StaticAssets, local_transport_issue: Option, remote_transport_issue: Option, timeouts: Timeouts, ) -> EntityConstructorReturn { - let remote_udp = UdpSocket::bind("127.0.0.1:0").expect("Unable to bind remote UDP."); - let remote_addr = remote_udp.local_addr().expect("Cannot find local address."); - - let local_udp = UdpSocket::bind("127.0.0.1:0").expect("Unable to bind local UDP."); - let local_addr = local_udp.local_addr().expect("Cannot find local address."); - - let entity_map = HashMap::from([ - (EntityID::from(0_u16), local_addr), - (EntityID::from(1_u16), remote_addr), - ]); + let (local_user, remote_user, local_handle, remote_handle) = + static_assets.tokio_runtime.block_on(async move { + let remote_udp = UdpSocket::bind("127.0.0.1:0") + .await + .expect("Unable to bind remote UDP."); + let remote_addr = remote_udp.local_addr().expect("Cannot find local address."); + + let local_udp = UdpSocket::bind("127.0.0.1:0") + .await + .expect("Unable to bind local UDP."); + let local_addr = local_udp.local_addr().expect("Cannot find local address."); + + let entity_map = HashMap::from([ + (EntityID::from(0_u16), local_addr), + (EntityID::from(1_u16), remote_addr), + ]); + + let local_transport = if let Some(issue) = local_transport_issue { + Box::new( + LossyTransport::try_from((local_udp, entity_map.clone(), issue)) + .expect("Unable to make Lossy Transport."), + ) as Box + } else { + Box::new( + UdpTransport::try_from((local_udp, entity_map.clone())) + .expect("Unable to make UDP Transport."), + ) as Box + }; - let local_transport = if let Some(issue) = local_transport_issue { - Box::new( - LossyTransport::try_from((local_udp, entity_map.clone(), issue)) - .expect("Unable to make Lossy Transport."), - ) as Box - } else { - Box::new( - UdpTransport::try_from((local_udp, entity_map.clone())) - .expect("Unable to make UDP Transport."), - ) as Box - }; + let remote_transport = if let Some(issue) = remote_transport_issue { + Box::new( + LossyTransport::try_from((remote_udp, entity_map.clone(), issue)) + .expect("Unable to make Lossy Transport."), + ) as Box + } else { + Box::new( + UdpTransport::try_from((remote_udp, entity_map.clone())) + .expect("Unable to make UDP Transport."), + ) as Box + }; - let remote_transport = if let Some(issue) = remote_transport_issue { - Box::new( - LossyTransport::try_from((remote_udp, entity_map.clone(), issue)) - .expect("Unable to make Lossy Transport."), - ) as Box - } else { - Box::new( - UdpTransport::try_from((remote_udp, entity_map.clone())) - .expect("Unable to make UDP Transport."), - ) as Box - }; + let remote_transport_map: HashMap, Box> = + HashMap::from([(vec![EntityID::from(0_u16)], remote_transport)]); - let remote_transport_map: HashMap, Box> = - HashMap::from([(vec![EntityID::from(0_u16)], remote_transport)]); + let local_transport_map: HashMap, Box> = + HashMap::from([(vec![EntityID::from(1_u16)], local_transport)]); - let local_transport_map: HashMap, Box> = - HashMap::from([(vec![EntityID::from(1_u16)], local_transport)]); + create_daemons( + static_assets.filestore.clone(), + local_transport_map, + remote_transport_map, + timeouts, + ) + .await + }); - let (local_user, remote_user, local_handle, remote_handle) = create_daemons( - filestore_fixture.clone(), - local_transport_map, - remote_transport_map, - timeouts, - ); ( local_user, remote_user, - filestore_fixture.clone(), + static_assets.filestore.clone(), local_handle, remote_handle, ) @@ -905,8 +949,8 @@ pub(crate) fn new_entities( #[fixture] #[once] -fn make_entities(filestore_fixture: &Arc) -> EntityConstructorReturn { - new_entities(filestore_fixture, None, None, [None; 3]) +fn make_entities(static_assets: &StaticAssets) -> EntityConstructorReturn { + new_entities(static_assets, None, None, [None; 3]) } pub(crate) type UsersAndFilestore = ( @@ -947,15 +991,12 @@ pub(crate) struct LossyTransport { } impl LossyTransport { #[allow(dead_code)] - pub fn new( + pub async fn new( addr: T, entity_map: HashMap, issue: TransportIssue, ) -> Result { - let socket = UdpSocket::bind(addr)?; - socket.set_read_timeout(Some(Duration::from_secs(1)))?; - socket.set_write_timeout(Some(Duration::from_secs(1)))?; - socket.set_nonblocking(true)?; + let socket = UdpSocket::bind(addr).await?; Ok(Self { socket, entity_map, @@ -978,196 +1019,195 @@ impl TryFrom<(UdpSocket, HashMap, TransportIssue)> for L issue: inputs.2, buffer: vec![], }; - me.socket.set_read_timeout(Some(Duration::from_secs(1)))?; - me.socket.set_write_timeout(Some(Duration::from_secs(1)))?; - me.socket.set_nonblocking(true)?; Ok(me) } } + +#[async_trait] impl PDUTransport for LossyTransport { - fn request(&mut self, destination: VariableID, pdu: PDU) -> Result<(), IoError> { - self.entity_map + async fn request(&mut self, destination: VariableID, pdu: PDU) -> Result<(), IoError> { + let addr = self + .entity_map .get(&destination) - .ok_or_else(|| IoError::from(ErrorKind::AddrNotAvailable)) - .and_then(|addr| { - // send a delayed packet if there are any - if !self.buffer.is_empty() { - let pdu = self.buffer.remove(0); + .ok_or_else(|| IoError::from(ErrorKind::AddrNotAvailable))?; + + // send a delayed packet if there are any + if !self.buffer.is_empty() { + let pdu = self.buffer.remove(0); + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ())?; + } + + match &self.issue { + TransportIssue::Rate(rate) => { + if self.counter % rate == 0 { + self.counter += 1; + Ok(()) + } else { + self.counter += 1; self.socket .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) + } + } + TransportIssue::Duplicate(rate) => { + if self.counter % rate == 0 { + self.counter += 1; + self.socket + .send_to(pdu.clone().encode().as_slice(), addr) + .await .map(|_n| ())?; + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) + } else { + self.counter += 1; + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) } - match &self.issue { - TransportIssue::Rate(rate) => { - if self.counter % rate == 0 { - self.counter += 1; - Ok(()) - } else { - self.counter += 1; - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } + } + TransportIssue::Reorder(rate) => { + if self.counter % rate == 0 { + self.counter += 1; + self.buffer.push(pdu); + Ok(()) + } else { + self.counter += 1; + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) + } + } + TransportIssue::Once(skip_directive) => match &pdu.payload { + PDUPayload::Directive(operation) => { + if self.counter == 1 && operation.get_directive() == *skip_directive { + self.counter += 1; + Ok(()) + } else { + if operation.get_directive() == *skip_directive {} + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) } - TransportIssue::Duplicate(rate) => { - if self.counter % rate == 0 { - self.counter += 1; - self.socket - .send_to(pdu.clone().encode().as_slice(), addr) - .map(|_n| ())?; - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } else { + } + PDUPayload::FileData(_data) => self + .socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()), + }, + TransportIssue::All(skip_directive) => match &pdu.payload { + PDUPayload::Directive(operation) => { + if skip_directive.contains(&operation.get_directive()) { + Ok(()) + } else { + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) + } + } + PDUPayload::FileData(_data) => self + .socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()), + }, + // only drop the PDUs if we have not yet send EoF. + // Flip the counter on EoF to signify we can send again. + TransportIssue::Every => match &pdu.payload { + PDUPayload::Directive(operation) => { + match (self.counter, operation.get_directive()) { + (1, PDUDirective::EoF) => { self.counter += 1; self.socket .send_to(pdu.encode().as_slice(), addr) + .await .map(|_n| ()) } - } - TransportIssue::Reorder(rate) => { - if self.counter % rate == 0 { + (1, PDUDirective::Ack) => { self.counter += 1; - self.buffer.push(pdu); + // increment counter but still don't send it Ok(()) - } else { - self.counter += 1; - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } - } - TransportIssue::Once(skip_directive) => match &pdu.payload { - PDUPayload::Directive(operation) => { - if self.counter == 1 && operation.get_directive() == *skip_directive { - self.counter += 1; - Ok(()) - } else { - if operation.get_directive() == *skip_directive {} - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } } - PDUPayload::FileData(_data) => self + (1, _) => Ok(()), + (_, _) => self .socket .send_to(pdu.encode().as_slice(), addr) + .await .map(|_n| ()), - }, - TransportIssue::All(skip_directive) => match &pdu.payload { - PDUPayload::Directive(operation) => { - if skip_directive.contains(&operation.get_directive()) { - Ok(()) - } else { - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } - } - PDUPayload::FileData(_data) => self - .socket + } + } + PDUPayload::FileData(_data) => { + if self.counter == 1 { + Ok(()) + } else { + self.socket .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()), - }, - // only drop the PDUs if we have not yet send EoF. - // Flip the counter on EoF to signify we can send again. - TransportIssue::Every => match &pdu.payload { - PDUPayload::Directive(operation) => { - match (self.counter, operation.get_directive()) { - (1, PDUDirective::EoF) => { - self.counter += 1; - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } - (1, PDUDirective::Ack) => { - self.counter += 1; - // increment counter but still don't send it - Ok(()) - } - (1, _) => Ok(()), - (_, _) => self - .socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()), - } - } - PDUPayload::FileData(_data) => { - if self.counter == 1 { - Ok(()) - } else { - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } - } - }, - TransportIssue::Inactivity => { - // Send the Metadata PDU only, and nothing else. - if self.counter == 1 { - self.counter += 1; - self.socket - .send_to(pdu.encode().as_slice(), addr) - .map(|_n| ()) - } else { - Ok(()) - } + .await + .map(|_n| ()) } } - }) + }, + TransportIssue::Inactivity => { + // Send the Metadata PDU only, and nothing else. + if self.counter == 1 { + self.counter += 1; + self.socket + .send_to(pdu.encode().as_slice(), addr) + .await + .map(|_n| ()) + } else { + Ok(()) + } + } + } } - fn pdu_handler( + async fn pdu_handler( &mut self, signal: Arc, sender: Sender, - recv: Receiver<(VariableID, PDU)>, + mut recv: Receiver<(VariableID, PDU)>, ) -> Result<(), IoError> { // this buffer will be 511 KiB, should be sufficiently small; let mut buffer = vec![0_u8; u16::MAX as usize]; while !signal.load(Ordering::Relaxed) { - match self.socket.recv_from(&mut buffer) { - Ok(_n) => match PDU::decode(&mut buffer.as_slice()) { - Ok(pdu) => { - match sender.send(pdu) { - Ok(()) => {} - Err(error) => { - error!("Transport found disconnect sending channel: {error}"); - return Err(IoError::from(ErrorKind::ConnectionAborted)); - } - }; - continue; - } - Err(error) => { - error!("Error decoding PDU: {error}"); - // might need to stop depending on the error. - // some are recoverable though + tokio::select! { + Ok((_n, _addr)) = self.socket.recv_from(&mut buffer) => { + match PDU::decode(&mut buffer.as_slice()) { + Ok(pdu) => { + match sender.send(pdu).await { + Ok(()) => {} + Err(error) => { + error!("Channel to daemon severed: {}", error); + return Err(IoError::from(ErrorKind::ConnectionAborted)); + } + }; + } + Err(error) => { + error!("Error decoding PDU: {}", error); + // might need to stop depending on the error. + // some are recoverable though + } } }, - Err(ref e) - if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => - { - // continue to trying to send - } - Err(e) => { - error!("encountered IO error: {e}"); - return Err(e); - } - }; - match recv.try_recv() { - Ok((entity, pdu)) => { - self.request(entity, pdu)?; - continue; - } - Err(crossbeam_channel::TryRecvError::Empty) => { - // nothing to do here - } - Err(err @ crossbeam_channel::TryRecvError::Disconnected) => { - error!("Transport found disconnected channel: {err}"); - return Err(IoError::from(ErrorKind::ConnectionAborted)); + Some((entity, pdu)) = recv.recv() => { + self.request(entity, pdu).await?; + }, + else => { + log::info!("UdpSocket or Channel disconnected"); + break } - }; - thread::sleep(Duration::from_millis(10)) + } } Ok(()) } diff --git a/cfdp-core/tests/series_f1.rs b/cfdp-core/tests/series_f1.rs index e5df2d3..4828a65 100644 --- a/cfdp-core/tests/series_f1.rs +++ b/cfdp-core/tests/series_f1.rs @@ -14,7 +14,8 @@ use rstest::{fixture, rstest}; mod common; use common::{ - get_filestore, new_entities, EntityConstructorReturn, TransportIssue, UsersAndFilestore, + get_filestore, new_entities, static_assets, EntityConstructorReturn, StaticAssets, + TransportIssue, UsersAndFilestore, }; #[rstest] @@ -116,9 +117,9 @@ fn f1s03(get_filestore: &UsersAndFilestore) { #[fixture] #[once] -fn fixture_f1s04(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f1s04(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Rate(13)), None, [None; 3], @@ -160,9 +161,9 @@ fn f1s04(fixture_f1s04: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f1s05(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f1s05(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Duplicate(13)), None, [None; 3], @@ -205,9 +206,9 @@ fn f1s05(fixture_f1s05: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f1s06(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f1s06(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Reorder(13)), None, [None; 3], diff --git a/cfdp-core/tests/series_f2.rs b/cfdp-core/tests/series_f2.rs index 4a7fb68..1db7d53 100644 --- a/cfdp-core/tests/series_f2.rs +++ b/cfdp-core/tests/series_f2.rs @@ -9,15 +9,13 @@ use cfdp_core::{ use rstest::{fixture, rstest}; mod common; -use common::{ - get_filestore, new_entities, EntityConstructorReturn, TransportIssue, UsersAndFilestore, -}; +use common::{new_entities, static_assets, EntityConstructorReturn, StaticAssets, TransportIssue}; #[fixture] #[once] -fn fixture_f2s01(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s01(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Once(PDUDirective::Metadata)), None, [None; 3], @@ -61,9 +59,9 @@ fn f2s01(fixture_f2s01: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s02(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s02(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Once(PDUDirective::EoF)), None, [None; 3], @@ -106,9 +104,9 @@ fn f2s02(fixture_f2s02: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s03(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s03(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Once(PDUDirective::Finished)), None, [None; 3], @@ -151,9 +149,9 @@ fn f2s03(fixture_f2s03: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s04(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s04(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, None, Some(TransportIssue::Once(PDUDirective::Ack)), [None; 3], @@ -196,9 +194,9 @@ fn f2s04(fixture_f2s04: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s05(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s05(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Once(PDUDirective::Ack)), None, [None; 3], @@ -241,9 +239,9 @@ fn f2s05(fixture_f2s05: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s06(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s06(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Every), Some(TransportIssue::Every), [None; 3], @@ -286,9 +284,9 @@ fn f2s06(fixture_f2s06: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s07(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s07(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, None, Some(TransportIssue::All(vec![ PDUDirective::Finished, @@ -349,9 +347,9 @@ fn f2s07(fixture_f2s07: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s08(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s08(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::All(vec![PDUDirective::Metadata])), Some(TransportIssue::All(vec![PDUDirective::Nak])), [Some(10), Some(1), Some(1)], @@ -406,9 +404,9 @@ fn f2s08(fixture_f2s08: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s09(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s09(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, None, Some(TransportIssue::All(vec![PDUDirective::Finished])), [Some(1), Some(10), Some(1)], @@ -464,9 +462,9 @@ fn f2s09(fixture_f2s09: &'static EntityConstructorReturn) { #[fixture] #[once] -fn fixture_f2s10(get_filestore: &UsersAndFilestore) -> EntityConstructorReturn { +fn fixture_f2s10(static_assets: &StaticAssets) -> EntityConstructorReturn { new_entities( - &get_filestore.2, + static_assets, Some(TransportIssue::Inactivity), None, [Some(1), Some(10), Some(10)],