From 1b77a4c03cfdd1880425f4bf8564e036dfc88f06 Mon Sep 17 00:00:00 2001 From: Graeme Coupar Date: Tue, 30 Jan 2024 16:12:22 +0000 Subject: [PATCH] Implement the new API (#61) As outlined in #59 - I'm not happy with the current API. This PR introduces most of the new API I want to support - this is basically ended up as a rewrite rather than the refactor I'd intended. But that has some advantages - I'm probably going to release v0.8.0 with the old API still intact but marked as deprecated. I think this might provide a nicer update experience than being faced with an immediate wall of compiler errors. Future PRs will continue with the rest of the work. --- CHANGELOG.md | 6 + Cargo.toml | 2 +- src/client.rs | 9 +- src/lib.rs | 4 + src/next/actor.rs | 233 +++++++++++++++++++++++++++++++ src/next/builder.rs | 115 +++++++++++++++ src/next/connection.rs | 53 +++++++ src/next/mod.rs | 108 ++++++++++++++ src/next/stream.rs | 47 +++++++ src/protocol.rs | 1 - tests/subscription_server/mod.rs | 2 +- 11 files changed, 574 insertions(+), 6 deletions(-) create mode 100644 src/next/actor.rs create mode 100644 src/next/builder.rs create mode 100644 src/next/connection.rs create mode 100644 src/next/mod.rs create mode 100644 src/next/stream.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 828c19e..d93a3ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ This project intends to inhere to [Semantic Versioning](http://semver.org/spec/v2.0.0.html), but has not yet reached 1.0 so all APIs might be changed. +## Unreleased - xxxx-xx-xx + +### Breaking Changes + +- `Error::Close` now has a code as well as a reason. + ## v0.8.0-alpha.1 - 2024-01-19 ### Breaking Changes diff --git a/Cargo.toml b/Cargo.toml index 4232370..1b8b691 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ ws_stream_wasm = ["dep:ws_stream_wasm", "no-logging", "pharos", "pin-project-lit no-logging = [] [dependencies] +async-trait = "0.1" futures = "0.3" log = "0.4" pin-project = "1" @@ -47,7 +48,6 @@ async-graphql-axum = "5" async-tungstenite = { version = "0.24", features = ["tokio-runtime"] } axum = "0.6" cynic = { version = "3" } -futures-util = "0.3" insta = "1.11" tokio = { version = "1", features = ["macros"] } tokio-stream = { version = "0.1", features = ["sync"] } diff --git a/src/client.rs b/src/client.rs index 420f13c..4e3be8a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -44,11 +44,14 @@ pub enum Error { #[error("{0}: {1}")] Custom(String, String), /// Unexpected close frame - #[error("got close frame, reason: {0}")] - Close(String), + #[error("got close frame. code: {0}, reason: {1}")] + Close(u16, String), /// Decoding / parsing error #[error("message decode error, reason: {0}")] Decode(String), + /// Serializing error + #[error("couldn't serialize message, reason: {0}")] + Serializing(String), /// Sending error #[error("message sending error, reason: {0}")] Send(String), @@ -408,7 +411,7 @@ fn decode_message( if message.is_ping() || message.is_pong() { Ok(None) } else if message.is_close() { - Err(Error::Close(message.error_message().unwrap_or_default())) + Err(Error::Close(0, message.error_message().unwrap_or_default())) } else if let Some(s) = message.text() { trace!("Decoding message: {}", s); Ok(Some( diff --git a/src/lib.rs b/src/lib.rs index 6e8e1bf..cfd02e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,10 @@ mod protocol; pub mod graphql; pub mod websockets; +// TODO: next shouldn't be public really, and shouldn't allow missing_docs +#[allow(missing_docs)] +pub mod next; + #[cfg(feature = "ws_stream_wasm")] mod wasm; #[cfg(feature = "ws_stream_wasm")] diff --git a/src/next/actor.rs b/src/next/actor.rs new file mode 100644 index 0000000..b1f4d1c --- /dev/null +++ b/src/next/actor.rs @@ -0,0 +1,233 @@ +use std::{ + collections::{hash_map::Entry, HashMap}, + future::{Future, IntoFuture}, +}; + +use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt}; +use serde_json::{json, Value}; + +use crate::{logging::trace, protocol::Event, Error}; + +use super::{ + connection::{Connection, Message}, + ConnectionCommand, +}; + +#[must_use] +pub struct ConnectionActor { + client: Option>, + connection: Box, + operations: HashMap>, +} + +impl ConnectionActor { + pub(super) fn new( + connection: Box, + client: mpsc::Receiver, + ) -> Self { + ConnectionActor { + client: Some(client), + connection, + operations: HashMap::new(), + } + } + + async fn run(mut self) { + while let Some(next) = self.next().await { + let response = match next { + Next::Command(cmd) => self.handle_command(cmd).await, + Next::Message(message) => self.handle_message(message).await, + }; + + let Some(response) = response else { continue }; + + if matches!(response, Message::Close { .. }) { + self.connection.send(response).await.ok(); + return; + } + + if self.connection.send(response).await.is_err() { + return; + } + } + + self.connection + .send(Message::Close { + code: Some(100), + reason: None, + }) + .await + .ok(); + } + + async fn handle_command(&mut self, cmd: ConnectionCommand) -> Option { + match cmd { + ConnectionCommand::Subscribe { + request, + sender, + id, + } => { + assert!(self.operations.insert(id, sender).is_none()); + + Some(Message::Text(request)) + } + ConnectionCommand::Cancel(id) => { + if self.operations.remove(&id).is_some() { + return Some(Message::complete(id)); + } + None + } + ConnectionCommand::Close(code, reason) => Some(Message::Close { + code: Some(code), + reason: Some(reason), + }), + } + } + + async fn handle_message(&mut self, message: Message) -> Option { + let event = match extract_event(message) { + Ok(event) => event?, + Err(Error::Close(code, reason)) => { + return Some(Message::Close { + code: Some(code), + reason: Some(reason), + }) + } + Err(other) => { + return Some(Message::Close { + code: Some(4857), + reason: Some(format!("Error while decoding event: {other}")), + }) + } + }; + + match event { + event @ (Event::Next { .. } | Event::Error { .. }) => { + let id = match event.id().unwrap().parse::().ok() { + Some(id) => id, + None => return Some(Message::close(Reason::UnknownSubscription)), + }; + + let sender = self.operations.entry(id); + + let Entry::Occupied(mut sender) = sender else { + return None; + }; + + let payload = event.forwarding_payload().unwrap(); + + if sender.get_mut().send(payload).await.is_err() { + sender.remove(); + return Some(Message::complete(id)); + } + + None + } + Event::Complete { id } => { + let id = match id.parse::().ok() { + Some(id) => id, + None => return Some(Message::close(Reason::UnknownSubscription)), + }; + + trace!("Stream complete"); + + self.operations.remove(&id); + None + } + Event::ConnectionAck { .. } => Some(Message::close(Reason::UnexpectedAck)), + Event::Ping { .. } => Some(Message::Pong), + Event::Pong { .. } => None, + } + } + + async fn next(&mut self) -> Option { + loop { + if let Some(client) = &mut self.client { + let mut next_command = client.next().fuse(); + let mut next_message = self.connection.receive().fuse(); + futures::select! { + command = next_command => { + let Some(command) = command else { + self.client.take(); + continue; + }; + + return Some(Next::Command(command)); + }, + message = next_message => { + return Some(Next::Message(message?)); + }, + } + } + + if self.operations.is_empty() { + // If client has disconnected and we have no running operations + // then we should shut down + return None; + } + + return Some(Next::Message(self.connection.receive().await?)); + } + } +} + +enum Next { + Command(ConnectionCommand), + Message(Message), +} + +impl IntoFuture for ConnectionActor { + type Output = (); + + type IntoFuture = BoxFuture<'static, ()>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.run()) + } +} + +fn extract_event(message: Message) -> Result, Error> { + match message { + Message::Text(s) => { + trace!("Decoding message: {}", s); + Ok(Some( + serde_json::from_str(&s).map_err(|err| Error::Decode(err.to_string()))?, + )) + } + Message::Close { code, reason } => Err(Error::Close( + code.unwrap_or_default(), + reason.unwrap_or_default(), + )), + Message::Ping | Message::Pong => Ok(None), + } +} + +enum Reason { + UnexpectedAck, + UnknownSubscription, +} + +impl Message { + fn close(reason: Reason) -> Self { + match reason { + Reason::UnexpectedAck => Message::Close { + code: Some(4855), + reason: Some("too many acknowledges".into()), + }, + Reason::UnknownSubscription => Message::Close { + code: Some(4856), + reason: Some("unknown subscription".into()), + }, + } + } +} + +impl Event { + fn forwarding_payload(self) -> Option { + match self { + Event::Next { id, payload } => Some(payload), + Event::Error { id, payload } => Some(json!({"errors": payload})), + _ => None, + } + } +} diff --git a/src/next/builder.rs b/src/next/builder.rs new file mode 100644 index 0000000..1034199 --- /dev/null +++ b/src/next/builder.rs @@ -0,0 +1,115 @@ +use std::collections::HashMap; + +use futures::channel::mpsc; +use serde::Serialize; + +use crate::{logging::trace, protocol::Event, Error}; + +use super::{ + actor::ConnectionActor, + connection::{Connection, Message}, + Client, +}; + +/// A websocket client builder +#[derive(Default)] +pub struct ClientBuilder { + payload: Option, + subscription_buffer_size: Option, +} + +impl ClientBuilder { + /// Constructs an AsyncWebsocketClientBuilder + pub fn new() -> ClientBuilder { + ClientBuilder::default() + } + + /// Add payload to `connection_init` + pub fn payload(self, payload: NewPayload) -> Result + where + NewPayload: Serialize, + { + Ok(ClientBuilder { + payload: Some( + serde_json::to_value(payload) + .map_err(|error| Error::Serializing(error.to_string()))?, + ), + ..self + }) + } + + pub fn subscription_buffer_size(self, new: usize) -> Self { + ClientBuilder { + subscription_buffer_size: Some(new), + ..self + } + } +} + +impl ClientBuilder { + /// Constructs a Client + /// + /// Accepts an already built websocket connection, and returns the connection + /// and a future that must be awaited somewhere - if the future is dropped the + /// connection will also drop. + pub async fn build(self, connection: Conn) -> Result<(Client, ConnectionActor), Error> + where + Conn: Connection + Send + 'static, + { + self.build_impl(Box::new(connection)).await + } + + async fn build_impl( + self, + mut connection: Box, + ) -> Result<(Client, ConnectionActor), Error> { + connection.send(Message::init(self.payload)).await?; + + // wait for ack before entering receiver loop: + loop { + match connection.receive().await { + None => return Err(Error::Unknown("connection dropped".into())), + Some(Message::Close { code, reason }) => { + return Err(Error::Close( + code.unwrap_or_default(), + reason.unwrap_or_default(), + )) + } + Some(Message::Ping) | Some(Message::Pong) => {} + Some(message @ Message::Text(_)) => { + let event = message.deserialize::()?; + match event { + // pings can be sent at any time + Event::Ping { .. } => { + connection.send(Message::graphql_pong()).await?; + } + Event::Pong { .. } => {} + Event::ConnectionAck { .. } => { + // handshake completed, ready to enter main receiver loop + trace!("connection_ack received, handshake completed"); + break; + } + event => { + connection.send(Message::Close { + code: Some(4950), + reason: Some("Unexpected message while waiting for ack".into()), + }); + return Err(Error::Decode(format!( + "expected a connection_ack or ping, got {}", + event.r#type() + ))); + } + } + } + } + } + + let (command_sender, command_receiver) = mpsc::channel(5); + + let actor = ConnectionActor::new(connection, command_receiver); + + let client = Client::new(command_sender, self.subscription_buffer_size.unwrap_or(5)); + + Ok((client, actor)) + } +} diff --git a/src/next/connection.rs b/src/next/connection.rs new file mode 100644 index 0000000..2145424 --- /dev/null +++ b/src/next/connection.rs @@ -0,0 +1,53 @@ +use std::future::Future; + +use serde::Serialize; +use serde_json::json; + +use crate::{protocol, Error}; + +#[async_trait::async_trait] +pub trait Connection { + async fn receive(&mut self) -> Option; + async fn send(&mut self, message: Message) -> Result<(), Error>; +} + +pub enum Message { + Text(String), + Close { + code: Option, + reason: Option, + }, + Ping, + Pong, +} + +impl Message { + pub(crate) fn deserialize(self) -> Result + where + T: serde::de::DeserializeOwned, + { + let Message::Text(text) = self else { + panic!("Don't call deserialize on non-text messages"); + }; + + serde_json::from_str(&text).map_err(|error| Error::Decode(error.to_string())) + } + + pub(crate) fn init(payload: Option) -> Self { + Self::Text( + serde_json::to_string(&crate::protocol::ConnectionInit::new(payload)) + .expect("payload is already serialized so this shouldn't fail"), + ) + } + + pub(crate) fn graphql_pong() -> Self { + Self::Text(serde_json::to_string(&crate::protocol::Message::Pong::<()>).unwrap()) + } + + pub(crate) fn complete(id: usize) -> Self { + Self::Text( + serde_json::to_string(&crate::protocol::Message::Complete::<()> { id: id.to_string() }) + .unwrap(), + ) + } +} diff --git a/src/next/mod.rs b/src/next/mod.rs new file mode 100644 index 0000000..7a38e03 --- /dev/null +++ b/src/next/mod.rs @@ -0,0 +1,108 @@ +#![allow(unused)] // TEMPORARY + +use std::sync::{ + atomic::{AtomicU64, AtomicUsize, Ordering}, + Arc, +}; + +use futures::{ + channel::{mpsc, oneshot}, + SinkExt, StreamExt, +}; +use serde_json::Value; + +use crate::{ + graphql::GraphqlOperation, + protocol::{self}, + Error, +}; + +mod actor; +mod builder; +mod connection; +mod stream; + +pub use self::{ + actor::ConnectionActor, + builder::ClientBuilder, + connection::{Connection, Message}, + stream::SubscriptionStream, +}; + +pub struct Client { + actor: mpsc::Sender, + subscription_buffer_size: usize, + next_id: Arc, +} + +impl Client { + pub(super) fn new( + actor: mpsc::Sender, + subscription_buffer_size: usize, + ) -> Self { + Client { + actor, + subscription_buffer_size, + next_id: Arc::new(AtomicUsize::new(0)), + } + } + + // Starts a streaming operation on this client. + /// + /// Returns a `Stream` of responses. + pub async fn streaming_operation<'a, Operation>( + &mut self, + op: Operation, + ) -> Result, Error> + where + Operation: GraphqlOperation + Unpin + Send + 'static, + { + let (sender, receiver) = mpsc::channel(self.subscription_buffer_size); + + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + + let message = protocol::Message::Subscribe { + id: id.to_string(), + payload: &op, + }; + + let request = serde_json::to_string(&message) + .map_err(|error| Error::Serializing(error.to_string()))?; + + self.actor + .send(ConnectionCommand::Subscribe { + request, + sender, + id, + }) + .await + .map_err(|error| Error::Send(error.to_string()))?; + + Ok(SubscriptionStream:: { + id, + stream: Box::pin(receiver.map(move |response| { + op.decode(response) + .map_err(|err| Error::Decode(err.to_string())) + })), + actor: self.actor.clone(), + }) + } + + pub async fn close(mut self, code: u16, description: impl Into) { + self.actor + .send(ConnectionCommand::Close(code, description.into())) + .await + .ok(); + } +} + +pub(super) enum ConnectionCommand { + Subscribe { + /// The full subscribe request as a JSON encoded string. + request: String, + sender: mpsc::Sender, + id: usize, + }, + Cancel(usize), + Close(u16, String), +} diff --git a/src/next/stream.rs b/src/next/stream.rs new file mode 100644 index 0000000..fdbf3d1 --- /dev/null +++ b/src/next/stream.rs @@ -0,0 +1,47 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{channel::mpsc, SinkExt, Stream}; + +use crate::{graphql::GraphqlOperation, Error}; + +use super::{actor::ConnectionActor, ConnectionCommand}; + +/// A `futures::Stream` for a subscription. +/// +/// Emits an item for each message received by the subscription. +#[pin_project::pin_project] +pub struct SubscriptionStream +where + Operation: GraphqlOperation, +{ + pub(super) id: usize, + pub(super) stream: Pin> + Send>>, + pub(super) actor: mpsc::Sender, +} + +impl SubscriptionStream +where + Operation: GraphqlOperation + Send, +{ + /// Stops the operation by sending a Complete message to the server. + pub async fn stop_operation(mut self) -> Result<(), Error> { + self.actor + .send(ConnectionCommand::Cancel(self.id)) + .await + .map_err(|error| Error::Send(error.to_string())) + } +} + +impl Stream for SubscriptionStream +where + Operation: GraphqlOperation + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.as_mut().poll_next(cx) + } +} diff --git a/src/protocol.rs b/src/protocol.rs index 1e64aec..4b47f1d 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -38,7 +38,6 @@ pub enum Message<'a, Operation> { #[serde(rename = "subscribe")] Subscribe { id: String, payload: &'a Operation }, #[serde(rename = "complete")] - #[allow(dead_code)] Complete { id: String }, #[serde(rename = "pong")] Pong, diff --git a/tests/subscription_server/mod.rs b/tests/subscription_server/mod.rs index 065336c..1047948 100644 --- a/tests/subscription_server/mod.rs +++ b/tests/subscription_server/mod.rs @@ -4,7 +4,7 @@ use async_graphql::{EmptyMutation, Object, Schema, SimpleObject, Subscription, ID}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse, GraphQLSubscription}; use axum::{extract::Extension, routing::post, Router, Server}; -use futures_util::{Stream, StreamExt}; +use futures::{Stream, StreamExt}; use tokio::sync::broadcast::Sender; use tokio_stream::wrappers::BroadcastStream;