Skip to content

Commit

Permalink
refactor: combined the various registry objects into a single extensi…
Browse files Browse the repository at this point in the history
…ons registry
  • Loading branch information
westonpace committed Jan 13, 2024
1 parent b77f92e commit 7161c14
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 317 deletions.
14 changes: 5 additions & 9 deletions substrait-expr/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,13 @@ use crate::error::{Result, SubstraitExprError};
use crate::helpers::expr::ExpressionExt;
use crate::helpers::schema::SchemaInfo;
use crate::helpers::types::TypeExt;
use crate::helpers::UriRegistry;

use self::functions::{FunctionsBuilder, FunctionsRegistry};
use self::functions::FunctionsBuilder;
use self::schema::RefBuilder;

pub mod functions;
pub mod schema;
pub mod types;

pub struct BuilderParams {
pub allow_late_name_lookup: bool,
Expand Down Expand Up @@ -260,7 +260,6 @@ impl NamedExpression {
/// ExtendedExpression, which holds a collection of expressions. If you only need to serialize
/// a single expression then you can create an ExtendedExpression that contains a single expression.
pub struct ExpressionsBuilder {
functions: FunctionsRegistry,
schema: SchemaInfo,
params: BuilderParams,
expressions: RefCell<Vec<NamedExpression>>,
Expand Down Expand Up @@ -291,7 +290,6 @@ impl IntoExprOutputNames for Vec<String> {
impl ExpressionsBuilder {
pub fn new(schema: SchemaInfo, params: BuilderParams) -> Self {
Self {
functions: FunctionsRegistry::new(),
schema,
params,
expressions: RefCell::new(Vec::new()),
Expand All @@ -303,7 +301,7 @@ impl ExpressionsBuilder {
}

pub fn functions(&self) -> FunctionsBuilder {
FunctionsBuilder::new(&self.functions, &self.schema)
FunctionsBuilder::new(&self.schema)
}

pub fn add_expression(
Expand All @@ -321,9 +319,7 @@ impl ExpressionsBuilder {
}

pub fn build(self) -> ExtendedExpression {
let mut uris = UriRegistry::new();
let mut extensions = Vec::new();
self.functions.add_to_extensions(&mut uris, &mut extensions);
let (extension_uris, extensions) = self.schema.extensions_registry().to_substrait();
let referred_expr = self
.expressions
.into_inner()
Expand All @@ -335,7 +331,7 @@ impl ExpressionsBuilder {
.collect::<Vec<_>>();
ExtendedExpression {
version: Some(substrait::version::version_with_producer("substrait-expr")),
extension_uris: uris.to_substrait(),
extension_uris,
extensions,
advanced_extensions: None,
expected_type_urls: Vec::new(),
Expand Down
123 changes: 25 additions & 98 deletions substrait-expr/src/builder/functions.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
use std::{collections::BTreeMap, sync::RwLock};
use std::collections::BTreeMap;

use substrait::proto::{
expression::{RexType, ScalarFunction},
extensions::{
simple_extension_declaration::{ExtensionFunction, MappingType},
SimpleExtensionDeclaration,
},
function_argument::ArgType,
Expression, FunctionArgument, FunctionOption, Type,
};

use crate::{
error::{Result, SubstraitExprError},
helpers::{
registry::ExtensionsRegistry,
schema::SchemaInfo,
types::{self, unknown, TypeExt},
UriRegistry,
types::{self, TypeExt},
},
};

Expand Down Expand Up @@ -67,8 +63,8 @@ impl ImplementationArg {
/// Returns true if an expression of the given type could be used as this argument
///
/// There is no "enum" type so enum arguments will only recognize the string type
pub fn matches(&self, arg_type: &Type) -> Result<bool> {
if arg_type.is_unknown() {
pub fn matches(&self, arg_type: &Type, registry: &ExtensionsRegistry) -> Result<bool> {
if arg_type.is_unknown(registry) {
Ok(true)
} else {
match &self.arg_type {
Expand All @@ -90,18 +86,22 @@ pub struct FunctionImplementation {

impl FunctionImplementation {
/// Returns true if expressions with types specified by `arg_types` would match this implementation
pub fn matches(&self, arg_types: &[Type]) -> bool {
pub fn matches(&self, arg_types: &[Type], registry: &ExtensionsRegistry) -> bool {
if arg_types.len() != self.args.len() {
false
} else {
self.args
.iter()
.zip(arg_types)
.all(|(imp_arg, arg_type)| imp_arg.matches(arg_type).unwrap_or(false))
.all(|(imp_arg, arg_type)| imp_arg.matches(arg_type, registry).unwrap_or(false))
}
}

fn relax(&self, types: Vec<Type>) -> Result<FunctionImplementation> {
fn relax(
&self,
types: Vec<Type>,
registry: &ExtensionsRegistry,
) -> Result<FunctionImplementation> {
if self.args.len() != types.len() {
Err(SubstraitExprError::InvalidInput(format!(
"Attempt to relax implementation with {} args using {} types",
Expand All @@ -114,7 +114,7 @@ impl FunctionImplementation {
.iter()
.zip(types.iter())
.map(|(arg, typ)| {
if typ.is_unknown() {
if typ.is_unknown(registry) {
ImplementationArg {
name: arg.name.clone(),
arg_type: ImplementationArgType::Value(typ.clone()),
Expand All @@ -124,9 +124,9 @@ impl FunctionImplementation {
}
})
.collect::<Vec<_>>();
let has_unknown = types.iter().any(|typ| typ.is_unknown());
let has_unknown = types.iter().any(|typ| typ.is_unknown(registry));
let output_type = if has_unknown {
types::unknown()
super::types::unknown(registry)
} else {
self.output_type.clone()
};
Expand All @@ -153,14 +153,15 @@ impl FunctionDefinition {
args: &[Expression],
schema: &SchemaInfo,
) -> Result<Option<FunctionImplementation>> {
let registry = schema.extensions_registry();
let types = args
.iter()
.map(|arg| arg.output_type(schema))
.collect::<Result<Vec<_>>>()?;
self.implementations
.iter()
.find(|imp| imp.matches(&types))
.map(|imp| imp.relax(types))
.find(|imp| imp.matches(&types, registry))
.map(|imp| imp.relax(types, registry))
.transpose()
}
}
Expand All @@ -176,13 +177,12 @@ pub const LOOKUP_BY_NAME_FUNC_NAME: &'static str = "lookup_by_name";

/// A builder that can create scalar function expressions
pub struct FunctionsBuilder<'a> {
registry: &'a FunctionsRegistry,
schema: &'a SchemaInfo,
}

impl<'a> FunctionsBuilder<'a> {
pub(crate) fn new(registry: &'a FunctionsRegistry, schema: &'a SchemaInfo) -> Self {
Self { registry, schema }
pub(crate) fn new(schema: &'a SchemaInfo) -> Self {
Self { schema }
}

/// Creates a new [FunctionBuilder] based on a given function definition.
Expand All @@ -197,7 +197,7 @@ impl<'a> FunctionsBuilder<'a> {
func: &'static FunctionDefinition,
args: Vec<Expression>,
) -> FunctionBuilder {
let func_reference = self.registry.register(func);
let func_reference = self.schema.extensions_registry().register_function(func);
FunctionBuilder {
func: func,
func_reference,
Expand All @@ -217,15 +217,15 @@ impl<'a> FunctionsBuilder<'a> {
let arg = FunctionArgument {
arg_type: Some(ArgType::Enum(name.into())),
};
let function_reference = self
.registry
.register_by_name(LOOKUP_BY_NAME_FUNC_URI, LOOKUP_BY_NAME_FUNC_NAME);
let registry = self.schema.extensions_registry();
let function_reference =
registry.register_function_by_name(LOOKUP_BY_NAME_FUNC_URI, LOOKUP_BY_NAME_FUNC_NAME);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
arguments: vec![arg],
function_reference,
// TODO: Use the proper unknown type
output_type: Some(unknown()),
output_type: Some(super::types::unknown(registry)),
options: vec![],
..Default::default()
})),
Expand Down Expand Up @@ -299,76 +299,3 @@ impl<'a> FunctionBuilder<'a> {
})
}
}

pub struct FunctionsRegistryRecord {
uri: String,
name: String,
anchor: u32,
}

struct FunctionRegistryInternal {
function_map: BTreeMap<String, FunctionsRegistryRecord>,
counter: u32,
}

impl FunctionRegistryInternal {
fn register(&mut self, uri: &str, name: &str) -> u32 {
let key = uri.to_string() + name;
let entry = self.function_map.entry(key);
entry
.or_insert_with(|| {
let counter = self.counter;
self.counter += 1;
FunctionsRegistryRecord {
uri: uri.to_string(),
name: name.to_string(),
anchor: counter,
}
})
.anchor
}
}

pub struct FunctionsRegistry {
internal: RwLock<FunctionRegistryInternal>,
}

impl FunctionsRegistry {
pub fn new() -> Self {
Self {
internal: RwLock::new(FunctionRegistryInternal {
function_map: BTreeMap::new(),
counter: 1,
}),
}
}

pub fn register(&self, function: &FunctionDefinition) -> u32 {
let mut internal = self.internal.write().unwrap();
internal.register(&function.uri, &function.name)
}

pub(crate) fn register_by_name(&self, uri: &str, name: &str) -> u32 {
let mut internal = self.internal.write().unwrap();
internal.register(uri, name)
}

pub(crate) fn add_to_extensions(
&self,
uris: &mut UriRegistry,
extensions: &mut Vec<SimpleExtensionDeclaration>,
) {
let internal = self.internal.read().unwrap();
for record in internal.function_map.values() {
let uri_ref = uris.register(record.uri.clone());
let declaration = SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
extension_uri_reference: uri_ref,
function_anchor: record.anchor,
name: record.name.clone(),
})),
};
extensions.push(declaration);
}
}
}
Loading

0 comments on commit 7161c14

Please sign in to comment.