Skip to content

Commit

Permalink
Use chashmap to remove mutexes (#42)
Browse files Browse the repository at this point in the history
Remove mutexes - using concurrent hashmaps when necessary instead - in order to allow async data send
  • Loading branch information
stuqdog committed Sep 20, 2022
1 parent a2c008c commit 86ce462
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 82 deletions.
55 changes: 18 additions & 37 deletions src/rpc/client_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@ use anyhow::Result;
use chashmap::CHashMap;
use hyper::Body;
use prost::Message;
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering},
Arc, Mutex,
},
use std::sync::{
atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering},
Arc,
};
use webrtc::{
data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel},
Expand All @@ -26,7 +23,7 @@ pub struct WebRTCClientChannel {
pub base_channel: Arc<WebRTCBaseChannel>,
stream_id_counter: AtomicU64,
message_ready: Arc<AtomicBool>,
pub streams: Mutex<HashMap<u64, ActiveWebRTCClientStream>>,
pub streams: CHashMap<u64, WebRTCClientStream>,
pub receiver_bodies: CHashMap<u64, hyper::Body>,
}

Expand All @@ -39,7 +36,7 @@ impl WebRTCClientChannel {
let channel = Self {
base_channel,
message_ready: Arc::new(AtomicBool::new(false)),
streams: Mutex::new(HashMap::new()),
streams: CHashMap::new(),
stream_id_counter: AtomicU64::new(0),
receiver_bodies: CHashMap::new(),
};
Expand All @@ -51,7 +48,7 @@ impl WebRTCClientChannel {
.on_message(Box::new(move |msg: DataChannelMessage| {
let channel = channel.clone();
Box::pin(async move {
if let Err(e) = channel.on_channel_message(msg) {
if let Err(e) = channel.on_channel_message(msg).await {
log::error!("error deserializing message: {e}");
}
})
Expand All @@ -60,37 +57,32 @@ impl WebRTCClientChannel {
ret_channel
}

pub fn new_stream(&self) -> Arc<Mutex<WebRTCClientStream>> {
pub fn new_stream(&self) -> Stream {
let id = self.stream_id_counter.fetch_add(1, Ordering::AcqRel);
let stream = Stream { id };

let (message_sender, receiver_body) = hyper::Body::channel();

let base_stream = WebRTCBaseStream {
stream,
stream: stream.clone(),
message_sender,
closed: AtomicBool::new(false),
packet_buffer: Vec::new(),
closed_reason: AtomicPtr::new(&mut None),
};

let client_stream = Arc::new(Mutex::new(WebRTCClientStream {
let client_stream = WebRTCClientStream {
base_stream,
message_sent: AtomicBool::new(false),
headers_received: AtomicBool::new(false),
trailers_received: AtomicBool::new(false),
}));

let stream = ActiveWebRTCClientStream {
client_stream: client_stream.clone(),
};
let _ = self.streams.lock().unwrap().insert(id, stream);

let _ = self.streams.insert(id, client_stream);
let _ = self.receiver_bodies.insert(id, receiver_body);
client_stream
stream
}

fn on_channel_message(&self, msg: DataChannelMessage) -> Result<()> {
let streams = self.streams.lock().unwrap();
async fn on_channel_message(&self, msg: DataChannelMessage) -> Result<()> {
let response = Response::decode(&*msg.data.to_vec())?;
let should_drop_stream = match response.r#type {
Some(RespType::Trailers(_)) => true,
Expand All @@ -106,7 +98,7 @@ impl WebRTCClientChannel {
}
Some(stream) => {
let id: u64 = stream.id;
let stream = streams.get(&stream.id).ok_or_else(|| {
let stream = self.streams.get_mut(&stream.id).ok_or_else(|| {
anyhow::anyhow!(
"No stream found for id {}: discarding response {:?}",
&stream.id,
Expand All @@ -118,19 +110,14 @@ impl WebRTCClientChannel {
};

let message_sent = match active_stream {
Ok(active_stream) => active_stream
.client_stream
.lock()
.unwrap()
.on_response(response),
Ok(mut active_stream) => active_stream.on_response(response).await,
Err(e) => {
log::error!("{e}");
return Ok(());
}
}?;
drop(streams);
if should_drop_stream {
self.streams.lock().unwrap().remove(&stream_id);
self.streams.remove(&stream_id);
}
self.message_ready.store(message_sent, Ordering::Release);
Ok(())
Expand Down Expand Up @@ -247,14 +234,8 @@ impl WebRTCClientChannel {
}

pub fn close_stream_with_recv_error(&self, stream_id: u64, error: anyhow::Error) {
let mut stream_lock = self.streams.lock().unwrap();
match stream_lock.remove(&stream_id) {
Some(stream) => stream
.client_stream
.lock()
.unwrap()
.base_stream
.close_with_recv_error(&mut Some(&error)),
match self.streams.remove(&stream_id) {
Some(stream) => stream.base_stream.close_with_recv_error(&mut Some(&error)),
None => {
log::error!("attempted to close stream with id {stream_id}, but it wasn't found!")
}
Expand Down
57 changes: 21 additions & 36 deletions src/rpc/client_stream.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
use super::base_stream::*;
use super::{base_stream::*, webrtc::trailers_from_proto};
use crate::gen::proto::rpc::webrtc::v1::{
response::Type, Response, ResponseHeaders, ResponseMessage, ResponseTrailers,
};
use anyhow::Result;
use byteorder::{BigEndian, WriteBytesExt};
use bytes::Bytes;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};

pub struct ActiveWebRTCClientStream {
pub client_stream: Arc<Mutex<WebRTCClientStream>>,
}
use std::sync::atomic::{AtomicBool, Ordering};

pub struct WebRTCClientStream {
pub base_stream: WebRTCBaseStream,
Expand All @@ -28,7 +21,7 @@ impl WebRTCClientStream {

// processes a response message, returns true if and only if message was actually
// processed
fn process_message(&mut self, response: ResponseMessage) -> Result<bool> {
async fn process_message(&mut self, response: ResponseMessage) -> Result<bool> {
let mut message_processed = false;
if let Some(message) = response.packet_message {
match self.base_stream.process_message(message) {
Expand All @@ -40,27 +33,10 @@ impl WebRTCClientStream {
message_buf.write_u32::<BigEndian>(len)?;
message_buf.append(&mut data);
let data = Bytes::from(message_buf);
// when sending data synchronously, we run the risk of sending
// too many messages too quickly, which results in messages being
// dropped. Unfortunately we cannot send asynchronously under the
// current code structure because mutexes don't play nicely with
// async. this loop is ugly, but ensures that we properly send
// all messages while still failing in cases where send is just
// not possible.
// TODO(RSDK-651) - we should refactor here so we can use the
// async `send_data(data)` fn, eliminating the need for this loop.
let now = std::time::SystemTime::now();
let timeout = std::time::Duration::from_secs(5);
loop {
if now.elapsed().unwrap() >= timeout {
return Err(anyhow::anyhow!("Error sending data: {data:?}"));
}
if let Ok(()) =
self.base_stream.message_sender.try_send_data(data.clone())
{
break;
}
}
self.base_stream
.message_sender
.send_data(data.clone())
.await?;
}
}

Expand All @@ -72,9 +48,18 @@ impl WebRTCClientStream {
Ok(message_processed)
}

fn process_trailers(&mut self, trailers: ResponseTrailers) {
self.trailers_received.store(true, Ordering::Release);
async fn process_trailers(&mut self, trailers: ResponseTrailers) {
let trailers_to_send = trailers_from_proto(trailers.clone());
if let Err(e) = self
.base_stream
.message_sender
.send_trailers(trailers_to_send)
.await
{
log::error!("Error sending trailers to http response: {e}");
}

self.trailers_received.store(true, Ordering::Release);
let err = match trailers.status {
None => None,
Some(status) => {
Expand All @@ -93,7 +78,7 @@ impl WebRTCClientStream {

// processes response. returns true if and only if a message was sent and we're done
// processing (i.e., trailers were processed)
pub fn on_response(&mut self, response: Response) -> Result<bool> {
pub async fn on_response(&mut self, response: Response) -> Result<bool> {
match &response.r#type {
Some(Type::Headers(headers)) => {
if self.headers_received.load(Ordering::Acquire) {
Expand Down Expand Up @@ -124,11 +109,11 @@ impl WebRTCClientStream {
return Err(err);
}

self.process_message(message.to_owned())
self.process_message(message.to_owned()).await
}

Some(Type::Trailers(trailers)) => {
self.process_trailers(trailers.to_owned());
self.process_trailers(trailers.to_owned()).await;
Ok(true)
}
None => Ok(false),
Expand Down
8 changes: 1 addition & 7 deletions src/rpc/dial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,7 @@ impl Service<http::Request<BoxBody>> for ViamChannel {
let fut = async move {
let (parts, body) = request.into_parts();

let stream = channel
.new_stream()
.lock()
.unwrap()
.base_stream
.stream
.clone();
let stream = channel.new_stream();
let stream_id = stream.id;
let metadata = Some(metadata_from_parts(&parts));
let headers = RequestHeaders {
Expand Down
51 changes: 49 additions & 2 deletions src/rpc/webrtc.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::gen::proto::rpc::webrtc::v1::{IceServer, WebRtcConfig};
use crate::gen::proto::rpc::webrtc::v1::{IceServer, ResponseTrailers, WebRtcConfig};
use anyhow::Result;
use bytes::Bytes;
use core::fmt;
use futures::Future;
use http::Uri;
use http::{header::HeaderName, HeaderMap, HeaderValue, Uri};
use std::{
hint,
pin::Pin,
str::FromStr,
sync::{atomic::AtomicBool, Arc},
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -303,3 +304,49 @@ pub async fn webrtc_action_with_timeout<T>(f: impl Future<Output = T>) -> Result
}
}
}

pub fn trailers_from_proto(proto: ResponseTrailers) -> HeaderMap {
let mut trailers = HeaderMap::new();
if let Some(metadata) = proto.metadata {
let mut vals = metadata.md.iter();
while let Some((k, v)) = vals.next() {
let k = HeaderName::from_str(k);
let v = HeaderValue::from_str(&v.values.concat());
let (k, v) = match (k, v) {
(Ok(k), Ok(v)) => (k, v),
(Err(e), _) => {
log::error!("Error converting proto trailer key: [{e}]");
continue;
}
(_, Err(e)) => {
log::error!("Error converting proto trailer value: [{e}]");
continue;
}
};
trailers.insert(k, v);
}
};

let status_name = "grpc-status";
let status_code = match proto.status {
Some(ref status) => status.code.to_string(),
None => "0".to_string(),
};

let k = match HeaderName::from_str(status_name) {
Ok(k) => k,
Err(e) => {
log::error!("Error parsing HeaderName: {e}");
return trailers;
}
};
let v = match HeaderValue::from_str(&status_code) {
Ok(v) => v,
Err(e) => {
log::error!("Error parsing HeaderValue: {e}");
return trailers;
}
};
trailers.insert(k, v);
trailers
}

0 comments on commit 86ce462

Please sign in to comment.