diff --git a/crates/duckdb/src/vtab/function.rs b/crates/duckdb/src/vtab/function.rs index a1ce75aa..722c4513 100644 --- a/crates/duckdb/src/vtab/function.rs +++ b/crates/duckdb/src/vtab/function.rs @@ -7,7 +7,7 @@ use super::{ duckdb_table_function_add_named_parameter, duckdb_table_function_add_parameter, duckdb_table_function_init_t, duckdb_table_function_set_bind, duckdb_table_function_set_extra_info, duckdb_table_function_set_function, duckdb_table_function_set_init, duckdb_table_function_set_local_init, duckdb_table_function_set_name, - duckdb_table_function_supports_projection_pushdown, idx_t, + duckdb_table_function_supports_projection_pushdown, duckdb_vector, idx_t, }, LogicalType, Value, }; @@ -385,3 +385,97 @@ impl From for FunctionInfo { Self(ptr) } } + +use super::ffi::{ + duckdb_create_scalar_function, duckdb_destroy_scalar_function, duckdb_scalar_function, + duckdb_scalar_function_add_parameter, duckdb_scalar_function_set_extra_info, duckdb_scalar_function_set_function, + duckdb_scalar_function_set_name, duckdb_scalar_function_set_return_type, +}; + +/// A function that returns a queryable scalar function +#[derive(Debug)] +pub struct ScalarFunction { + pub(crate) ptr: duckdb_scalar_function, +} + +impl Drop for ScalarFunction { + fn drop(&mut self) { + unsafe { + duckdb_destroy_scalar_function(&mut self.ptr); + } + } +} + +impl ScalarFunction { + /// Adds a parameter to the scalar function. + /// + /// # Arguments + /// * `logical_type`: The type of the parameter to add. + pub fn add_parameter(&self, logical_type: &LogicalType) -> &Self { + unsafe { + duckdb_scalar_function_add_parameter(self.ptr, logical_type.ptr); + } + self + } + + /// Sets the return type of the scalar function. + /// + /// # Arguments + /// * `logical_type`: The return type of the scalar function. + pub fn set_return_type(&self, logical_type: &LogicalType) -> &Self { + unsafe { + duckdb_scalar_function_set_return_type(self.ptr, logical_type.ptr); + } + self + } + + /// Sets the main function of the scalar function + /// + /// # Arguments + /// * `function`: The function + pub fn set_function( + &self, + func: Option, + ) -> &Self { + unsafe { + duckdb_scalar_function_set_function(self.ptr, func); + } + self + } + + /// Creates a new empty scalar function. + pub fn new() -> Self { + Self { + ptr: unsafe { duckdb_create_scalar_function() }, + } + } + + /// Sets the name of the given scalar function. + /// + /// # Arguments + /// * `name`: The name of the scalar function + pub fn set_name(&self, name: &str) -> &ScalarFunction { + unsafe { + let string = CString::from_vec_unchecked(name.as_bytes().into()); + duckdb_scalar_function_set_name(self.ptr, string.as_ptr()); + } + self + } + + /// Assigns extra information to the scalar function that can be fetched during binding, etc. + /// + /// # Arguments + /// * `extra_info`: The extra information + /// * `destroy`: The callback that will be called to destroy the bind data (if any) + /// + /// # Safety + pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) { + duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy); + } +} + +impl Default for ScalarFunction { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/duckdb/src/vtab/mod.rs b/crates/duckdb/src/vtab/mod.rs index 40717819..8141e538 100644 --- a/crates/duckdb/src/vtab/mod.rs +++ b/crates/duckdb/src/vtab/mod.rs @@ -21,12 +21,12 @@ pub use self::arrow::{ mod excel; pub use data_chunk::DataChunk; -pub use function::{BindInfo, FunctionInfo, InitInfo, TableFunction}; +pub use function::{BindInfo, FunctionInfo, InitInfo, ScalarFunction, TableFunction}; pub use logical_type::{LogicalType, LogicalTypeId}; pub use value::Value; pub use vector::{FlatVector, Inserter, ListVector, StructVector, Vector}; -use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info}; +use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info, duckdb_vector}; use ffi::duckdb_malloc; use std::mem::size_of; @@ -161,6 +161,48 @@ where } } +/// Duckdb scalar function trait +/// +pub trait VScalar: Sized { + /// The actual function + /// + /// # Safety + /// + /// This function is unsafe because it: + /// + /// - Dereferences multiple raw pointers (`func``). + /// + unsafe fn func( + func: &FunctionInfo, + input: &mut DataChunk, + output: &mut FlatVector, + ) -> Result<(), Box>; + /// The parameters of the table function + /// default is None + fn parameters() -> Option> { + None + } + + /// The return type of the scalar function + /// default is None + fn return_type() -> LogicalType { + panic!("return_type not implemented") + } +} + +unsafe extern "C" fn scalar_func(info: duckdb_function_info, input: duckdb_data_chunk, output: duckdb_vector) +where + T: VScalar, +{ + let info = FunctionInfo::from(info); + let mut input = DataChunk::from(input); + let mut output_vector = FlatVector::from(output); + let result = T::func(&info, &mut input, &mut output_vector); + if result.is_err() { + info.set_error(&result.err().unwrap().to_string()); + } +} + impl Connection { /// Register the given TableFunction with the current db #[inline] @@ -180,6 +222,21 @@ impl Connection { } self.db.borrow_mut().register_table_function(table_function) } + + /// Register the given ScalarFunction with the current db + #[inline] + pub fn register_scalar_function(&self, name: &str) -> Result<()> { + let scalar_function = ScalarFunction::default(); + scalar_function + .set_name(name) + .set_return_type(&S::return_type()) + //.set_extra_info() + .set_function(Some(scalar_func::)); + for ty in S::parameters().unwrap_or_default() { + scalar_function.add_parameter(&ty); + } + self.db.borrow_mut().register_scalar_function(scalar_function) + } } impl InnerConnection { @@ -193,6 +250,17 @@ impl InnerConnection { } Ok(()) } + + /// Register the given ScalarFunction with the current db + pub fn register_scalar_function(&mut self, scalar_function: ScalarFunction) -> Result<()> { + unsafe { + let rc = ffi::duckdb_register_scalar_function(self.con, scalar_function.ptr); + if rc != ffi::DuckDBSuccess { + return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); + } + } + Ok(()) + } } #[cfg(test)] diff --git a/crates/libduckdb-sys/src/lib.rs b/crates/libduckdb-sys/src/lib.rs index ae57cadf..8bc50e12 100644 --- a/crates/libduckdb-sys/src/lib.rs +++ b/crates/libduckdb-sys/src/lib.rs @@ -14,6 +14,24 @@ pub use bindings::*; pub const DuckDBError: duckdb_state = duckdb_state_DuckDBError; pub const DuckDBSuccess: duckdb_state = duckdb_state_DuckDBSuccess; +use std::slice; +impl From<&duckdb_string_t> for String { + fn from(source: &duckdb_string_t) -> String { + unsafe { + let s = if source.value.inlined.length <= 12 { + source.value.inlined.inlined.as_ptr() + } else { + source.value.pointer.ptr + }; + return std::str::from_utf8_unchecked(slice::from_raw_parts( + s as *const u8, + source.value.inlined.length as usize, + )) + .to_string(); + } + } +} + pub use self::error::*; mod error;