Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fallible stream for arrow-flight do_exchange call (#3462) #5698

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_exchange requires FlightData
/// .map(|batch|batch.unwrap());
/// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `RecordBatches`
/// let response: Vec<RecordBatch> = client
Expand All @@ -431,20 +429,40 @@ impl FlightClient {
/// .expect("error calling do_exchange");
/// # }
/// ```
pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let request = self.make_request(request);
let (sender, mut receiver) = futures::channel::oneshot::channel();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there is some way to avoid the repetition here with the code in do_get above 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did wonder about duplicating the code from the do_put method, which makes all errors (both in the request stream and in the response stream) go back to the client. Since it's my first MR here, I didn't want to take it further, but I'm happy to explore.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could do it as a follow-on PR 🤔

I think this PR is ready to merge as long as we can accept API changes in this release (we are working to slow down the rate of breaking API changes; see #5368).


let response = self
.inner
.do_exchange(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});

Ok(FlightRecordBatchStream::new_from_flight_data(response))
// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
}

/// Make a `ListFlights` call to the server with the provided
Expand Down
97 changes: 94 additions & 3 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ async fn test_do_exchange() {
.set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());

let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");

Expand Down Expand Up @@ -528,7 +528,7 @@ async fn test_do_exchange_error() {
let input_flight_data = test_flight_data().await;

let response = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Expand Down Expand Up @@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() {
test_server.set_do_exchange_response(response);

let response_stream = client
.do_exchange(futures::stream::iter(input_flight_data.clone()))
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
.await
.expect("error making request");

Expand All @@ -593,6 +593,97 @@ async fn test_do_exchange_error_stream() {
.await;
}

/// Checks that an error produced by the client's own request stream is the
/// error surfaced by the `do_exchange` response stream, and that the server
/// still records every request message that preceded that error.
#[tokio::test]
async fn test_do_exchange_error_stream_client() {
    do_test(|test_server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let client_status = Status::invalid_argument("bad arg: client");

        // Request stream: every good `FlightData` message, then a single
        // trailing error originating on the client side.
        let sent_messages = test_flight_data().await;
        let trailing_error =
            futures::stream::iter(vec![Err(FlightError::from(client_status.clone()))]);
        let request_stream = futures::stream::iter(sent_messages.clone())
            .map(Ok)
            .chain(trailing_error);

        // The server is primed to answer with exactly one well-formed message.
        let server_reply = FlightData::new()
            .with_descriptor(FlightDescriptor::new_cmd("Sample command"))
            .with_data_body("body".as_bytes())
            .with_data_header("header".as_bytes())
            .with_app_metadata("metadata".as_bytes());
        test_server.set_do_exchange_response(vec![Ok(server_reply)]);

        // Opening the exchange itself must succeed; the failure is only
        // expected once the response stream is drained.
        let batches = client
            .do_exchange(request_stream)
            .await
            .expect("error making request");

        let outcome: Result<Vec<_>, _> = batches.try_collect().await;
        let observed = if let Err(err) = outcome {
            err
        } else {
            panic!("unexpected success")
        };

        // The surfaced error must be the one injected into the request stream.
        expect_status(observed, client_status);

        // The server still received the request messages sent before the error.
        assert_eq!(
            test_server.take_do_exchange_request(),
            Some(sent_messages)
        );
        ensure_metadata(&client, &test_server);
    })
    .await;
}

/// Checks which error wins when both sides fail: the request stream ends with
/// a client-side error and the server is configured to answer with an error
/// of its own. The test expects the client-side error to be reported.
#[tokio::test]
async fn test_do_exchange_error_client_and_server() {
    do_test(|test_server, mut client| async move {
        client.add_header("foo-header", "bar-header-value").unwrap();

        let client_status = Status::invalid_argument("bad arg: client");

        // Request stream: every good `FlightData` message, then a single
        // trailing error originating on the client side.
        let sent_messages = test_flight_data().await;
        let trailing_error =
            futures::stream::iter(vec![Err(FlightError::from(client_status.clone()))]);
        let request_stream = futures::stream::iter(sent_messages.clone())
            .map(Ok)
            .chain(trailing_error);

        // The server answers with an error as well (e.g. because the data it
        // received was truncated).
        let server_status = Status::invalid_argument("bad arg: server");
        test_server.set_do_exchange_response(vec![Err(server_status)]);

        // Opening the exchange itself must succeed; the failure is only
        // expected once the response stream is drained.
        let batches = client
            .do_exchange(request_stream)
            .await
            .expect("error making request");

        let outcome: Result<Vec<_>, _> = batches.try_collect().await;
        let observed = if let Err(err) = outcome {
            err
        } else {
            panic!("unexpected success")
        };

        // The client-side error is expected here, not the server's.
        expect_status(observed, client_status);

        // The server still received the request messages sent before the error.
        assert_eq!(
            test_server.take_do_exchange_request(),
            Some(sent_messages)
        );
        ensure_metadata(&client, &test_server);
    })
    .await;
}

#[tokio::test]
async fn test_get_schema() {
do_test(|test_server, mut client| async move {
Expand Down
Loading