Skip to content

Commit

Permalink
Feature: Add MPSC channel to AsyncRuntime
Browse files Browse the repository at this point in the history
`AsyncRuntime` trait defines the async-runtime such as tokio to run
Openraft. This commit add MPSC abstraction to `AsyncRuntime` and
MPSC implementations to tokio based runtime and monoio based runtime.
  • Loading branch information
drmingdrmer committed Aug 17, 2024
1 parent e54a1fd commit f8c2f8b
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 3 deletions.
83 changes: 83 additions & 0 deletions openraft/src/testing/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ use std::task::Poll;

use crate::async_runtime::watch::WatchReceiver;
use crate::async_runtime::watch::WatchSender;
use crate::async_runtime::Mpsc;
use crate::async_runtime::MpscReceiver;
use crate::async_runtime::MpscSender;
use crate::async_runtime::MpscUnboundedWeakSender;
use crate::async_runtime::MpscWeakSender;
use crate::instant::Instant;
use crate::type_config::async_runtime::mpsc_unbounded::MpscUnbounded;
use crate::type_config::async_runtime::mpsc_unbounded::MpscUnboundedReceiver;
Expand Down Expand Up @@ -43,11 +47,19 @@ impl<Rt: AsyncRuntime> Suite<Rt> {
Self::test_sleep_until().await;
Self::test_timeout().await;
Self::test_timeout_at().await;

Self::test_mpsc_recv_empty().await;
Self::test_mpsc_recv_channel_closed().await;
Self::test_mpsc_weak_sender_wont_prevent_channel_close().await;
Self::test_mpsc_weak_sender_upgrade().await;
Self::test_mpsc_send().await;

Self::test_unbounded_mpsc_recv_empty().await;
Self::test_unbounded_mpsc_recv_channel_closed().await;
Self::test_unbounded_mpsc_weak_sender_wont_prevent_channel_close().await;
Self::test_unbounded_mpsc_weak_sender_upgrade().await;
Self::test_unbounded_mpsc_send().await;

Self::test_watch_init_value().await;
Self::test_watch_overwrite_init_value().await;
Self::test_watch_send_error_no_receiver().await;
Expand Down Expand Up @@ -131,6 +143,77 @@ impl<Rt: AsyncRuntime> Suite<Rt> {
assert!(timeout_result.is_err());
}

pub async fn test_mpsc_recv_empty() {
let (_tx, mut rx) = Rt::Mpsc::channel::<()>(5);
let recv_err = rx.try_recv().unwrap_err();
assert!(matches!(recv_err, TryRecvError::Empty));
}

pub async fn test_mpsc_recv_channel_closed() {
let (_, mut rx) = Rt::Mpsc::channel::<()>(5);
let recv_err = rx.try_recv().unwrap_err();
assert!(matches!(recv_err, TryRecvError::Disconnected));

let recv_result = rx.recv().await;
assert!(recv_result.is_none());
}

pub async fn test_mpsc_weak_sender_wont_prevent_channel_close() {
let (tx, mut rx) = Rt::Mpsc::channel::<()>(5);

let _weak_tx = tx.downgrade();
drop(tx);
let recv_err = rx.try_recv().unwrap_err();
assert!(matches!(recv_err, TryRecvError::Disconnected));

let recv_result = rx.recv().await;
assert!(recv_result.is_none());
}

pub async fn test_mpsc_weak_sender_upgrade() {
let (tx, _rx) = Rt::Mpsc::channel::<()>(5);

let weak_tx = tx.downgrade();
let opt_tx = weak_tx.upgrade();
assert!(opt_tx.is_some());

drop(tx);
drop(opt_tx);
// now there is no Sender instances alive

let opt_tx = weak_tx.upgrade();
assert!(opt_tx.is_none());
}

pub async fn test_mpsc_send() {
let (tx, mut rx) = Rt::Mpsc::channel::<usize>(5);
let tx = Arc::new(tx);

let n_senders = 10_usize;
let recv_expected = (0..n_senders).collect::<Vec<_>>();

for idx in 0..n_senders {
let tx = tx.clone();
// no need to wait for senders here, we wait by recv()ing
let _handle = Rt::spawn(async move {
tx.send(idx).await.unwrap();
});
}

let mut recv = Vec::with_capacity(n_senders);
while let Some(recv_number) = rx.recv().await {
recv.push(recv_number);

if recv.len() == n_senders {
break;
}
}

recv.sort();

assert_eq!(recv_expected, recv);
}

pub async fn test_unbounded_mpsc_recv_empty() {
let (_tx, mut rx) = Rt::MpscUnbounded::channel::<()>();
let recv_err = rx.try_recv().unwrap_err();
Expand Down
7 changes: 7 additions & 0 deletions openraft/src/type_config/async_runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub(crate) mod tokio_impls {
mod tokio_runtime;
pub use tokio_runtime::TokioRuntime;
}
pub mod mpsc;
pub mod mpsc_unbounded;
pub mod mutex;
pub mod oneshot;
Expand All @@ -19,6 +20,10 @@ use std::fmt::Display;
use std::future::Future;
use std::time::Duration;

pub use mpsc::Mpsc;
pub use mpsc::MpscReceiver;
pub use mpsc::MpscSender;
pub use mpsc::MpscWeakSender;
pub use mpsc_unbounded::MpscUnbounded;
pub use mpsc_unbounded::MpscUnboundedReceiver;
pub use mpsc_unbounded::MpscUnboundedSender;
Expand Down Expand Up @@ -99,6 +104,8 @@ pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + Option
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;

type Mpsc: Mpsc;

type MpscUnbounded: MpscUnbounded;

type Watch: Watch;
Expand Down
73 changes: 73 additions & 0 deletions openraft/src/type_config/async_runtime/mpsc/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::future::Future;

use base::OptionalSend;
use base::OptionalSync;

/// mpsc shares the same error types as mpsc_unbounded
pub use super::mpsc_unbounded::SendError;
pub use super::mpsc_unbounded::TryRecvError;
use crate::base;

/// Multi-producer, single-consumer channel.
pub trait Mpsc: Sized + OptionalSend {
type Sender<T: OptionalSend>: MpscSender<Self, T>;
type Receiver<T: OptionalSend>: MpscReceiver<T>;
type WeakSender<T: OptionalSend>: MpscWeakSender<Self, T>;

/// Creates an unbounded mpsc channel for communicating between asynchronous
/// tasks without backpressure.
fn channel<T: OptionalSend>(buffer: usize) -> (Self::Sender<T>, Self::Receiver<T>);
}

/// Send values to the associated [`MpscReceiver`].
pub trait MpscSender<MU, T>: OptionalSend + OptionalSync + Clone
where
MU: Mpsc,
T: OptionalSend,
{
/// Attempts to send a message without blocking.
///
/// If the receiving half of the channel is closed, this
/// function returns an error. The error includes the value passed to `send`.
fn send(&self, msg: T) -> impl Future<Output = Result<(), SendError<T>>> + OptionalSend;

/// Converts the [`MpscSender`] to a [`MpscWeakSender`] that does not count
/// towards RAII semantics, i.e. if all `Sender` instances of the
/// channel were dropped and only `WeakSender` instances remain,
/// the channel is closed.
fn downgrade(&self) -> MU::WeakSender<T>;
}

/// Receive values from the associated [`MpscSender`].
pub trait MpscReceiver<T>: OptionalSend + OptionalSync {
/// Receives the next value for this receiver.
///
/// This method returns `None` if the channel has been closed and there are
/// no remaining messages in the channel's buffer.
fn recv(&mut self) -> impl Future<Output = Option<T>> + OptionalSend;

/// Tries to receive the next value for this receiver.
///
/// This method returns the [`TryRecvError::Empty`] error if the channel is currently
/// empty, but there are still outstanding senders.
///
/// This method returns the [`TryRecvError::Disconnected`] error if the channel is
/// currently empty, and there are no outstanding senders.
fn try_recv(&mut self) -> Result<T, TryRecvError>;
}

/// A sender that does not prevent the channel from being closed.
///
/// If all [`MpscSender`] instances of a channel were dropped and only
/// `WeakSender` instances remain, the channel is closed.
pub trait MpscWeakSender<MU, T>: OptionalSend + OptionalSync + Clone
where
MU: Mpsc,
T: OptionalSend,
{
/// Tries to convert a [`MpscWeakSender`] into an [`MpscSender`].
///
/// This will return `Some` if there are other `Sender` instances alive and
/// the channel wasn't previously dropped, otherwise `None` is returned.
fn upgrade(&self) -> Option<MU::Sender<T>>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ impl AsyncRuntime for TokioRuntime {
rand::thread_rng()
}

type Mpsc = mpsc_impl::TokioMpsc;
type MpscUnbounded = TokioMpscUnbounded;
type Watch = TokioWatch;
type Oneshot = TokioOneshot;
Expand Down Expand Up @@ -134,6 +135,75 @@ where T: OptionalSend
}
}

mod mpsc_impl {
use std::future::Future;

use futures::TryFutureExt;
use tokio::sync::mpsc;

use crate::async_runtime::Mpsc;
use crate::async_runtime::MpscReceiver;
use crate::async_runtime::MpscSender;
use crate::async_runtime::MpscWeakSender;
use crate::async_runtime::SendError;
use crate::async_runtime::TryRecvError;
use crate::OptionalSend;

pub struct TokioMpsc;

impl Mpsc for TokioMpsc {
type Sender<T: OptionalSend> = mpsc::Sender<T>;
type Receiver<T: OptionalSend> = mpsc::Receiver<T>;
type WeakSender<T: OptionalSend> = mpsc::WeakSender<T>;

/// Creates an unbounded mpsc channel for communicating between asynchronous
/// tasks without backpressure.
fn channel<T: OptionalSend>(buffer: usize) -> (Self::Sender<T>, Self::Receiver<T>) {
mpsc::channel(buffer)
}
}

impl<T> MpscSender<TokioMpsc, T> for mpsc::Sender<T>
where T: OptionalSend
{
#[inline]
fn send(&self, msg: T) -> impl Future<Output = Result<(), SendError<T>>> + OptionalSend {
self.send(msg).map_err(|e| SendError(e.0))
}

#[inline]
fn downgrade(&self) -> <TokioMpsc as Mpsc>::WeakSender<T> {
self.downgrade()
}
}

impl<T> MpscReceiver<T> for mpsc::Receiver<T>
where T: OptionalSend
{
#[inline]
fn recv(&mut self) -> impl Future<Output = Option<T>> + OptionalSend {
self.recv()
}

#[inline]
fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.try_recv().map_err(|e| match e {
mpsc::error::TryRecvError::Empty => TryRecvError::Empty,
mpsc::error::TryRecvError::Disconnected => TryRecvError::Disconnected,
})
}
}

impl<T> MpscWeakSender<TokioMpsc, T> for mpsc::WeakSender<T>
where T: OptionalSend
{
#[inline]
fn upgrade(&self) -> Option<<TokioMpsc as Mpsc>::Sender<T>> {
self.upgrade()
}
}
}

pub struct TokioWatch;

impl watch::Watch for TokioWatch {
Expand Down
7 changes: 5 additions & 2 deletions rt-monoio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ repository = "https://github.com/datafuselabs/openraft"
openraft = { path = "../openraft", version = "0.10.0", default-features = false, features = ["singlethreaded"] }

rand = "0.8"
tokio = { version = "1.22", features = ["sync"] }
monoio = "0.2.3"

futures = { version = "0.3" }
local-sync = "0.1.1"

monoio = "0.2.3"
tokio = { version = "1.22", features = ["sync"] }
Loading

0 comments on commit f8c2f8b

Please sign in to comment.