Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the implementation a bit #57

Merged
merged 4 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ all APIs might be changed.

## Unreleased - xxxx-xx-xx

### Breaking Changes

- Subscription IDs sent to the server are now just monotonic numbers rather
than uuids.
- `SubscriptionStream` no longer takes `GraphqlClient` as a generic parameter

## v0.7.0 - 2024-01-03

### Breaking Changes
Expand Down
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ members = ["examples", "examples-wasm"]
default = ["async-tungstenite"]
client-cynic = ["async-tungstenite", "cynic"]
client-graphql-client = ["async-tungstenite", "graphql_client"]
ws_stream_wasm = ["dep:ws_stream_wasm", "uuid/js", "no-logging", "pharos", "pin-project-lite"]
ws_stream_wasm = ["dep:ws_stream_wasm", "no-logging", "pharos", "pin-project-lite"]
no-logging = []

[dependencies]
Expand All @@ -31,7 +31,6 @@ pin-project = "1"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "1.0"
uuid = { version = "1.0", features = ["v4"] }

cynic = { version = "3", optional = true }
async-tungstenite = { version = "0.24", optional = true }
Expand Down
113 changes: 37 additions & 76 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
use std::{collections::HashMap, marker::PhantomData, pin::Pin, sync::Arc};
use std::{
collections::HashMap,
marker::PhantomData,
pin::Pin,
sync::{
atomic::{self, AtomicU64},
Arc,
},
};

use futures::{
channel::{mpsc, oneshot},
channel::mpsc,
future::RemoteHandle,
lock::Mutex,
sink::{Sink, SinkExt},
stream::{Stream, StreamExt},
task::{Context, Poll, SpawnExt},
};
use serde::Serialize;
use uuid::Uuid;

use super::{
graphql::{self, GraphqlOperation},
Expand All @@ -27,6 +34,7 @@ where
{
inner: Arc<ClientInner<GraphqlClient>>,
sender_sink: mpsc::Sender<WsMessage>,
next_id: AtomicU64,
phantom: PhantomData<GraphqlClient>,
}

Expand Down Expand Up @@ -133,15 +141,14 @@ where

let (mut sender_sink, sender_stream) = mpsc::channel(1);

let (shutdown_sender, shutdown_receiver) = oneshot::channel();

let sender_handle = runtime
.spawn_with_handle(sender_loop(
sender_stream,
websocket_sink,
Arc::clone(&operations),
shutdown_receiver,
))
.spawn_with_handle(async move {
sender_stream
.map(Ok)
.forward(websocket_sink)
.await
.map_err(|error| Error::Send(error.to_string()))
})
.map_err(|err| Error::SpawnHandle(err.to_string()))?;

// wait for ack before entering receiver loop:
Expand Down Expand Up @@ -185,7 +192,6 @@ where
websocket_stream,
sender_sink.clone(),
Arc::clone(&operations),
shutdown_sender,
))
.map_err(|err| Error::SpawnHandle(err.to_string()))?;

Expand All @@ -195,6 +201,7 @@ where
operations,
sender_handle,
}),
next_id: 0.into(),
sender_sink,
phantom: PhantomData,
})
Expand All @@ -218,12 +225,12 @@ where
pub async fn streaming_operation<'a, Operation>(
&mut self,
op: Operation,
) -> Result<SubscriptionStream<GraphqlClient, Operation>, Error>
) -> Result<SubscriptionStream<Operation>, Error>
where
Operation:
GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin + Send + 'static,
{
let id = Uuid::new_v4();
let id = self.next_id.fetch_add(1, atomic::Ordering::Relaxed);
let (sender, receiver) = mpsc::channel(SUBSCRIPTION_BUFFER_SIZE);

self.inner.operations.lock().await.insert(id, sender);
Expand All @@ -242,7 +249,7 @@ where
let mut sender_clone = self.sender_sink.clone();
let id_clone = id.to_string();

Ok(SubscriptionStream::<GraphqlClient, Operation> {
Ok(SubscriptionStream::<Operation> {
id: id.to_string(),
stream: Box::pin(receiver.map(move |response| {
op.decode(response)
Expand All @@ -260,7 +267,6 @@ where
Ok(())
})
}),
phantom: PhantomData,
})
}
}
Expand All @@ -269,32 +275,28 @@ where
///
/// Emits an item for each message received by the subscription.
#[pin_project::pin_project]
pub struct SubscriptionStream<GraphqlClient, Operation>
pub struct SubscriptionStream<Operation>
where
GraphqlClient: graphql::GraphqlClient,
Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response>,
Operation: GraphqlOperation,
{
id: String,
stream: Pin<Box<dyn Stream<Item = Result<Operation::Response, Error>> + Send>>,
cancel_func: Box<dyn FnOnce() -> futures::future::BoxFuture<'static, Result<(), Error>> + Send>,
phantom: PhantomData<GraphqlClient>,
}

impl<GraphqlClient, Operation> SubscriptionStream<GraphqlClient, Operation>
impl<Operation> SubscriptionStream<Operation>
where
GraphqlClient: graphql::GraphqlClient + Send,
Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Send,
Operation: GraphqlOperation + Send,
{
/// Stops the operation by sending a Complete message to the server.
pub async fn stop_operation(self) -> Result<(), Error> {
(self.cancel_func)().await
}
}

impl<GraphqlClient, Operation> Stream for SubscriptionStream<GraphqlClient, Operation>
impl<Operation> Stream for SubscriptionStream<Operation>
where
GraphqlClient: graphql::GraphqlClient,
Operation: GraphqlOperation<GenericResponse = GraphqlClient::Response> + Unpin,
Operation: GraphqlOperation + Unpin,
{
type Item = Result<Operation::Response, Error>;

Expand All @@ -305,13 +307,12 @@ where

type OperationSender<GenericResponse> = mpsc::Sender<GenericResponse>;

type OperationMap<GenericResponse> = Arc<Mutex<HashMap<Uuid, OperationSender<GenericResponse>>>>;
type OperationMap<GenericResponse> = Arc<Mutex<HashMap<u64, OperationSender<GenericResponse>>>>;

async fn receiver_loop<S, WsMessage, GraphqlClient>(
mut receiver: S,
mut sender: mpsc::Sender<WsMessage>,
operations: OperationMap<GraphqlClient::Response>,
shutdown: oneshot::Sender<()>,
) -> Result<(), Error>
where
S: Stream<Item = Result<WsMessage, WsMessage::Error>> + Unpin,
Expand All @@ -330,9 +331,10 @@ where
}
}

shutdown
.send(())
.map_err(|_| Error::SenderShutdown("Couldn't shutdown sender".to_owned()))
// Clear out any operations
operations.lock().await.clear();

Ok(())
}

async fn handle_message<WsMessage, GraphqlClient>(
Expand All @@ -355,7 +357,10 @@ where
};

let id = match event.id() {
Some(id) => Some(Uuid::parse_str(id).map_err(|err| Error::Decode(err.to_string()))?),
Some(id) => Some(
id.parse::<u64>()
.map_err(|err| Error::Decode(err.to_string()))?,
),
None => None,
};

Expand Down Expand Up @@ -414,50 +419,6 @@ where
Ok(())
}

async fn sender_loop<M, S, E, GenericResponse>(
message_stream: mpsc::Receiver<M>,
mut ws_sender: S,
operations: OperationMap<GenericResponse>,
shutdown: oneshot::Receiver<()>,
) -> Result<(), Error>
where
M: WebsocketMessage,
S: Sink<M, Error = E> + Unpin,
E: std::error::Error,
{
use futures::{future::FutureExt, select};

let mut message_stream = message_stream.fuse();
let mut shutdown = shutdown.fuse();

loop {
select! {
msg = message_stream.next() => {
if let Some(msg) = msg {
trace!("Sending message: {:?}", msg);
ws_sender
.send(msg)
.await
.map_err(|err| Error::Send(err.to_string()))?;
} else {
return Ok(());
}
}
_ = shutdown => {
// Shutdown the incoming message stream
let mut message_stream = message_stream.into_inner();
message_stream.close();
while message_stream.next().await.is_some() {}

// Clear out any operations
operations.lock().await.clear();

return Ok(());
}
}
}
}

struct ClientInner<GraphqlClient>
where
GraphqlClient: crate::graphql::GraphqlClient,
Expand Down
Loading