Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streaming arrow data support #373

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion crates/duckdb/src/arrow_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::{
Statement,
};

/// An handle for the resulting RecordBatch of a query.
/// A handle for the resulting RecordBatch of a query.
#[must_use = "Arrow is lazy and will do nothing unless consumed"]
pub struct Arrow<'stmt> {
pub(crate) stmt: Option<&'stmt Statement<'stmt>>,
Expand All @@ -29,3 +29,34 @@ impl<'stmt> Iterator for Arrow<'stmt> {
Some(RecordBatch::from(&self.stmt?.step()?))
}
}

/// A handle for the resulting RecordBatch of a query in streaming
#[must_use = "Arrow stream is lazy and will not fetch data unless consumed"]
pub struct ArrowStream<'stmt> {
pub(crate) stmt: Option<&'stmt Statement<'stmt>>,
pub(crate) schema: SchemaRef,
}

impl<'stmt> ArrowStream<'stmt> {
#[inline]
pub(crate) fn new(stmt: &'stmt Statement<'stmt>, schema: SchemaRef) -> ArrowStream<'stmt> {
ArrowStream {
stmt: Some(stmt),
schema,
}
}

/// return arrow schema
#[inline]
pub fn get_schema(&self) -> SchemaRef {
self.schema.clone()
}
}

impl<'stmt> Iterator for ArrowStream<'stmt> {
type Item = RecordBatch;

fn next(&mut self) -> Option<Self::Item> {
Some(RecordBatch::from(&self.stmt?.stream_step(self.get_schema())?))
}
}
2 changes: 1 addition & 1 deletion crates/duckdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub use crate::r2d2::DuckdbConnectionManager;
pub use crate::{
appender::Appender,
appender_params::{appender_params_from_iter, AppenderParams, AppenderParamsFromIter},
arrow_batch::Arrow,
arrow_batch::{Arrow, ArrowStream},
cache::CachedStatement,
column::Column,
config::{AccessMode, Config, DefaultNullOrder, DefaultOrder},
Expand Down
61 changes: 59 additions & 2 deletions crates/duckdb/src/raw_statement.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{ffi::CStr, ptr, rc::Rc, sync::Arc};
use std::{ffi::CStr, ops::Deref, ptr, rc::Rc, sync::Arc};

use arrow::{
array::StructArray,
Expand All @@ -9,14 +9,15 @@ use arrow::{
use super::{ffi, Result};
#[cfg(feature = "polars")]
use crate::arrow2;
use crate::error::result_from_duckdb_arrow;
use crate::{error::result_from_duckdb_arrow, Error};

// Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
// TODO: destroy statement and result
#[derive(Debug)]
pub struct RawStatement {
ptr: ffi::duckdb_prepared_statement,
result: Option<ffi::duckdb_arrow>,
duckdb_result: Option<ffi::duckdb_result>,
schema: Option<SchemaRef>,
// Cached SQL (trimmed) that we use as the key when we're in the statement
// cache. This is None for statements which didn't come from the statement
Expand All @@ -38,6 +39,7 @@ impl RawStatement {
ptr: stmt,
result: None,
schema: None,
duckdb_result: None,
statement_cache_key: None,
}
}
Expand Down Expand Up @@ -110,6 +112,39 @@ impl RawStatement {
}
}

/// Fetches the next data chunk from a pending streaming result and converts
/// it into an Arrow `StructArray`.
///
/// Returns `None` when no streaming result is active, when the stream is
/// exhausted (null chunk), when the exported array is empty, or when the
/// provided `schema` cannot be converted to an FFI Arrow schema.
#[inline]
pub fn streaming_step(&self, schema: SchemaRef) -> Option<StructArray> {
    if let Some(result) = self.duckdb_result {
        unsafe {
            let mut out = ffi::duckdb_stream_fetch_chunk(result);

            // A null chunk signals the end of the stream.
            if out.is_null() {
                return None;
            }

            let mut arrays = FFI_ArrowArray::empty();
            ffi::duckdb_result_arrow_array(
                result,
                out,
                &mut std::ptr::addr_of_mut!(arrays) as *mut _ as *mut ffi::duckdb_arrow_array,
            );

            // The chunk's contents have been exported into `arrays`; release the chunk.
            ffi::duckdb_destroy_data_chunk(&mut out);

            if arrays.is_empty() {
                return None;
            }

            let schema = FFI_ArrowSchema::try_from(schema.deref()).ok()?;
            // Panic here indicates the FFI array and the caller-supplied schema
            // disagree, which is a usage bug rather than a recoverable error.
            let array_data =
                from_ffi(arrays, &schema).expect("FFI array did not match the provided Arrow schema");
            let struct_array = StructArray::from(array_data);
            return Some(struct_array);
        }
    }

    None
}

#[cfg(feature = "polars")]
#[inline]
pub fn step2(&self) -> Option<arrow2::array::StructArray> {
Expand Down Expand Up @@ -242,6 +277,22 @@ impl RawStatement {
}
}

/// Executes the prepared statement in streaming mode, storing the native
/// result handle in `self.duckdb_result` so chunks can later be pulled one
/// at a time via `streaming_step`.
///
/// Any previously held result is released first through `reset_result`.
///
/// # Errors
///
/// Returns `Error::DuckDBFailure` when `duckdb_execute_prepared_streaming`
/// does not report `DuckDBSuccess`.
pub fn execute_streaming(&mut self) -> Result<()> {
    self.reset_result();
    unsafe {
        // `duckdb_result` is a plain C struct; a zeroed value is the expected
        // "empty" state for the FFI call to fill in.
        let mut out: ffi::duckdb_result = std::mem::zeroed();

        let rc = ffi::duckdb_execute_prepared_streaming(self.ptr, &mut out);
        if rc != ffi::DuckDBSuccess {
            return Err(Error::DuckDBFailure(ffi::Error::new(rc), None));
        }

        self.duckdb_result = Some(out);

        Ok(())
    }
}

#[inline]
pub fn reset_result(&mut self) {
self.schema = None;
Expand All @@ -251,6 +302,12 @@ impl RawStatement {
}
self.result = None;
}
if let Some(mut result) = self.duckdb_result {
unsafe {
ffi::duckdb_destroy_result(&mut result);
}
self.duckdb_result = None;
}
}

#[inline]
Expand Down
32 changes: 31 additions & 1 deletion crates/duckdb/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::{ffi, AndThenRows, Connection, Error, MappedRows, Params, RawStatemen
#[cfg(feature = "polars")]
use crate::{arrow2, polars_dataframe::Polars};
use crate::{
arrow_batch::Arrow,
arrow_batch::{Arrow, ArrowStream},
error::result_from_duckdb_prepare,
types::{TimeUnit, ToSql, ToSqlOutput},
};
Expand Down Expand Up @@ -109,6 +109,30 @@ impl Statement<'_> {
Ok(Arrow::new(self))
}

/// Execute the prepared statement, returning a handle that yields the
/// resulting arrow `RecordBatch`es lazily, one chunk at a time, instead of
/// materializing the whole result set.
///
/// The caller supplies the expected Arrow schema of the result.
///
/// ## Example
///
/// ```rust,no_run
/// # use duckdb::{Result, Connection};
/// # use arrow::record_batch::RecordBatch;
/// # use arrow::datatypes::SchemaRef;
/// fn get_arrow_data(conn: &Connection, schema: SchemaRef) -> Result<Vec<RecordBatch>> {
///     Ok(conn.prepare("SELECT * FROM test")?.stream_arrow([], schema)?.collect())
/// }
/// ```
///
/// # Failure
///
/// Will return `Err` if binding parameters fails, or if executing the
/// statement in streaming mode fails.
#[inline]
pub fn stream_arrow<P: Params>(&mut self, params: P, schema: SchemaRef) -> Result<ArrowStream<'_>> {
    params.__bind_in(self)?;
    self.stmt.execute_streaming()?;
    Ok(ArrowStream::new(self, schema))
}

/// Execute the prepared statement, returning a handle to the resulting
/// vector of polars DataFrame.
///
Expand Down Expand Up @@ -337,6 +361,12 @@ impl Statement<'_> {
self.stmt.step()
}

/// Get the next batch of records as an arrow-rs `StructArray`, fetched in a
/// streaming fashion; returns `None` once the stream is exhausted.
#[inline]
pub fn stream_step(&self, schema: SchemaRef) -> Option<StructArray> {
    self.stmt.streaming_step(schema)
}

#[cfg(feature = "polars")]
/// Get next batch records in arrow2
#[inline]
Expand Down
Loading