Skip to content

Commit

Permalink
feat: add bindings for scalar functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rustyconover committed Jun 23, 2024
1 parent 693f1b7 commit 004c7c5
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 3 deletions.
96 changes: 95 additions & 1 deletion crates/duckdb/src/vtab/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{
duckdb_table_function_add_named_parameter, duckdb_table_function_add_parameter, duckdb_table_function_init_t,
duckdb_table_function_set_bind, duckdb_table_function_set_extra_info, duckdb_table_function_set_function,
duckdb_table_function_set_init, duckdb_table_function_set_local_init, duckdb_table_function_set_name,
duckdb_table_function_supports_projection_pushdown, idx_t,
duckdb_table_function_supports_projection_pushdown, duckdb_vector, idx_t,
},
LogicalType, Value,
};
Expand Down Expand Up @@ -385,3 +385,97 @@ impl From<duckdb_function_info> for FunctionInfo {
Self(ptr)
}
}

use super::ffi::{
duckdb_create_scalar_function, duckdb_destroy_scalar_function, duckdb_scalar_function,

Check failure on line 390 in crates/duckdb/src/vtab/function.rs

View workflow job for this annotation

GitHub Actions / Test x86_64-pc-windows-msvc

unresolved imports `super::ffi::duckdb_create_scalar_function`, `super::ffi::duckdb_destroy_scalar_function`, `super::ffi::duckdb_scalar_function`, `super::ffi::duckdb_scalar_function_add_parameter`, `super::ffi::duckdb_scalar_function_set_extra_info`, `super::ffi::duckdb_scalar_function_set_function`, `super::ffi::duckdb_scalar_function_set_name`, `super::ffi::duckdb_scalar_function_set_return_type`
duckdb_scalar_function_add_parameter, duckdb_scalar_function_set_extra_info, duckdb_scalar_function_set_function,
duckdb_scalar_function_set_name, duckdb_scalar_function_set_return_type,
};

/// A function that returns a queryable scalar function
#[derive(Debug)]
pub struct ScalarFunction {
pub(crate) ptr: duckdb_scalar_function,
}

impl Drop for ScalarFunction {
fn drop(&mut self) {
unsafe {
duckdb_destroy_scalar_function(&mut self.ptr);
}
}
}

impl ScalarFunction {
/// Adds a parameter to the scalar function.
///
/// # Arguments
/// * `logical_type`: The type of the parameter to add.
pub fn add_parameter(&self, logical_type: &LogicalType) -> &Self {
unsafe {
duckdb_scalar_function_add_parameter(self.ptr, logical_type.ptr);
}
self
}

/// Sets the return type of the scalar function.
///
/// # Arguments
/// * `logical_type`: The return type of the scalar function.
pub fn set_return_type(&self, logical_type: &LogicalType) -> &Self {
unsafe {
duckdb_scalar_function_set_return_type(self.ptr, logical_type.ptr);
}
self
}

/// Sets the main function of the scalar function
///
/// # Arguments
/// * `function`: The function
pub fn set_function(
&self,
func: Option<unsafe extern "C" fn(info: duckdb_function_info, input: duckdb_data_chunk, output: duckdb_vector)>,
) -> &Self {
unsafe {
duckdb_scalar_function_set_function(self.ptr, func);
}
self
}

/// Creates a new empty scalar function.
pub fn new() -> Self {
Self {
ptr: unsafe { duckdb_create_scalar_function() },
}
}

/// Sets the name of the given scalar function.
///
/// # Arguments
/// * `name`: The name of the scalar function
pub fn set_name(&self, name: &str) -> &ScalarFunction {
unsafe {
let string = CString::from_vec_unchecked(name.as_bytes().into());
duckdb_scalar_function_set_name(self.ptr, string.as_ptr());
}
self
}

/// Assigns extra information to the scalar function that can be fetched during binding, etc.
///
/// # Arguments
/// * `extra_info`: The extra information
/// * `destroy`: The callback that will be called to destroy the bind data (if any)
///
/// # Safety
pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) {
duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy);
}
}

impl Default for ScalarFunction {
fn default() -> Self {
Self::new()
}
}
72 changes: 70 additions & 2 deletions crates/duckdb/src/vtab/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ pub use self::arrow::{
mod excel;

pub use data_chunk::DataChunk;
pub use function::{BindInfo, FunctionInfo, InitInfo, TableFunction};
pub use function::{BindInfo, FunctionInfo, InitInfo, ScalarFunction, TableFunction};
pub use logical_type::{LogicalType, LogicalTypeId};
pub use value::Value;
pub use vector::{FlatVector, Inserter, ListVector, StructVector, Vector};

use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info};
use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info, duckdb_vector};

use ffi::duckdb_malloc;
use std::mem::size_of;
Expand Down Expand Up @@ -161,6 +161,48 @@ where
}
}

/// Duckdb scalar function trait
///
pub trait VScalar: Sized {
/// The actual function
///
/// # Safety
///
/// This function is unsafe because it:
///
/// - Dereferences multiple raw pointers (`func``).
///
unsafe fn func(
func: &FunctionInfo,
input: &mut DataChunk,
output: &mut FlatVector,
) -> Result<(), Box<dyn std::error::Error>>;
/// The parameters of the table function
/// default is None
fn parameters() -> Option<Vec<LogicalType>> {
None
}

/// The return type of the scalar function
/// default is None
fn return_type() -> LogicalType {
panic!("return_type not implemented")
}
}

unsafe extern "C" fn scalar_func<T>(info: duckdb_function_info, input: duckdb_data_chunk, output: duckdb_vector)
where
T: VScalar,
{
let info = FunctionInfo::from(info);
let mut input = DataChunk::from(input);
let mut output_vector = FlatVector::from(output);
let result = T::func(&info, &mut input, &mut output_vector);
if result.is_err() {
info.set_error(&result.err().unwrap().to_string());
}
}

impl Connection {
/// Register the given TableFunction with the current db
#[inline]
Expand All @@ -180,6 +222,21 @@ impl Connection {
}
self.db.borrow_mut().register_table_function(table_function)
}

/// Register the given ScalarFunction with the current db
#[inline]
pub fn register_scalar_function<S: VScalar>(&self, name: &str) -> Result<()> {
let scalar_function = ScalarFunction::default();
scalar_function
.set_name(name)
.set_return_type(&S::return_type())
//.set_extra_info()
.set_function(Some(scalar_func::<S>));
for ty in S::parameters().unwrap_or_default() {
scalar_function.add_parameter(&ty);
}
self.db.borrow_mut().register_scalar_function(scalar_function)
}
}

impl InnerConnection {
Expand All @@ -193,6 +250,17 @@ impl InnerConnection {
}
Ok(())
}

/// Register the given ScalarFunction with the current db
pub fn register_scalar_function(&mut self, scalar_function: ScalarFunction) -> Result<()> {
unsafe {
let rc = ffi::duckdb_register_scalar_function(self.con, scalar_function.ptr);

Check failure on line 257 in crates/duckdb/src/vtab/mod.rs

View workflow job for this annotation

GitHub Actions / Test x86_64-pc-windows-msvc

cannot find function `duckdb_register_scalar_function` in crate `ffi`
if rc != ffi::DuckDBSuccess {
return Err(Error::DuckDBFailure(ffi::Error::new(rc), None));
}
}
Ok(())
}
}

#[cfg(test)]
Expand Down
18 changes: 18 additions & 0 deletions crates/libduckdb-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@ pub use bindings::*;
pub const DuckDBError: duckdb_state = duckdb_state_DuckDBError;
pub const DuckDBSuccess: duckdb_state = duckdb_state_DuckDBSuccess;

use std::slice;
impl From<&duckdb_string_t> for String {
fn from(source: &duckdb_string_t) -> String {
unsafe {
let s = if source.value.inlined.length <= 12 {
source.value.inlined.inlined.as_ptr()
} else {
source.value.pointer.ptr
};
return std::str::from_utf8_unchecked(slice::from_raw_parts(
s as *const u8,
source.value.inlined.length as usize,
))
.to_string();
}
}
}

pub use self::error::*;
mod error;

Expand Down

0 comments on commit 004c7c5

Please sign in to comment.