Skip to content

Commit

Permalink
feat(query): Support accessing MySQL data from dictionaries via the `dic…
Browse files Browse the repository at this point in the history
…t_get` function. (#16444)

* feat: access data from mysql via dict_get

* fix

* update: operator & transform

* update: dict_get mysql

* update: cargo & transform

* update : add Date & Timestamp

* update : Date & Timestamp.

* fix: cancel date & timestamp feat: test.

* fix

* fix.

* update:test & mysql_source.

* fix: transform dictionary.

* fix: cargo

* fix: cargo.lock
  • Loading branch information
Winnie-Hong0927 committed Sep 13, 2024
1 parent ab5af93 commit b4e0a2b
Show file tree
Hide file tree
Showing 12 changed files with 978 additions and 71 deletions.
419 changes: 412 additions & 7 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/common/exception/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ prost = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = "0.8"
tantivy = { workspace = true }
thiserror = { workspace = true }
tonic = { workspace = true }
Expand Down
6 changes: 6 additions & 0 deletions src/common/exception/src/exception_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,9 @@ impl From<ErrorCode> for tonic::Status {
}
}
}

/// Lets `?` turn any `sqlx::Error` into an `ErrorCode::DictionarySourceError`
/// whose message embeds the `Display` text of the underlying sqlx error.
impl From<sqlx::Error> for ErrorCode {
    fn from(error: sqlx::Error) -> Self {
        let message = format!("Dictionary Sqlx Error, cause: {}", error);
        Self::DictionarySourceError(message)
    }
}
1 change: 1 addition & 0 deletions src/query/service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ serde_stacker = { workspace = true }
serde_urlencoded = "0.7.1"
sha2 = { workspace = true }
socket2 = "0.5.3"
sqlx = { version = "0.8", features = ["mysql", "runtime-tokio"] }
strength_reduce = "0.2.4"
sysinfo = "0.30"
tempfile = "3.4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,24 @@ use databend_common_meta_app::schema::GetSequenceNextValueReq;
use databend_common_meta_app::schema::SequenceIdent;
use databend_common_pipeline_transforms::processors::AsyncTransform;
use databend_common_storages_fuse::TableContext;
use opendal::Operator;

use crate::pipelines::processors::transforms::transform_dictionary::DictionaryOperator;
use crate::sessions::QueryContext;
use crate::sql::executor::physical_plans::AsyncFunctionDesc;
use crate::sql::plans::AsyncFunctionArgument;

pub struct TransformAsyncFunction {
ctx: Arc<QueryContext>,
// key is the index of async_func_desc
pub(crate) operators: BTreeMap<usize, Arc<Operator>>,
pub(crate) operators: BTreeMap<usize, Arc<DictionaryOperator>>,
async_func_descs: Vec<AsyncFunctionDesc>,
}

impl TransformAsyncFunction {
pub fn new(
ctx: Arc<QueryContext>,
async_func_descs: Vec<AsyncFunctionDesc>,
operators: BTreeMap<usize, Arc<Operator>>,
operators: BTreeMap<usize, Arc<DictionaryOperator>>,
) -> Self {
Self {
ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
use std::collections::BTreeMap;
use std::sync::Arc;

use chrono_tz::Tz;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::types::date::date_to_string;
use databend_common_expression::types::timestamp::timestamp_to_string;
use databend_common_expression::types::DataType;
use databend_common_expression::types::Number;
use databend_common_expression::types::NumberDataType;
use databend_common_expression::types::NumberScalar;
use databend_common_expression::with_integer_mapped_type;
use databend_common_expression::BlockEntry;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::DataBlock;
Expand All @@ -27,6 +34,7 @@ use databend_common_expression::Value;
use databend_common_storage::build_operator;
use opendal::services::Redis;
use opendal::Operator;
use sqlx::MySqlPool;

use crate::pipelines::processors::transforms::TransformAsyncFunction;
use crate::sql::executor::physical_plans::AsyncFunctionDesc;
Expand All @@ -35,10 +43,101 @@ use crate::sql::plans::DictGetFunctionArgument;
use crate::sql::plans::DictionarySource;
use crate::sql::IndexType;

/// Backend handle used to resolve `dict_get` lookups for a single dictionary.
pub(crate) enum DictionaryOperator {
/// Source that is read through an OpenDAL `Operator`; the lookup key is
/// passed to `Operator::read` as the path.
Operator(Operator),
/// MySQL source: a connection pool paired with the SQL text of the lookup
/// query. The query is expected to take the key as its single bound
/// parameter and to yield one scalar value.
Mysql((MySqlPool, String)),
}

impl DictionaryOperator {
    /// Renders a lookup key as the text bound to the MySQL query parameter.
    ///
    /// String keys are passed through unchanged, dates and timestamps are
    /// rendered in UTC, and every other scalar uses its `Display` form.
    fn format_key(&self, key: ScalarRef<'_>) -> String {
        match key {
            ScalarRef::String(s) => s.to_string(),
            ScalarRef::Date(d) => date_to_string(d as i64, Tz::UTC).to_string(),
            ScalarRef::Timestamp(t) => timestamp_to_string(t, Tz::UTC).to_string(),
            _ => key.to_string(),
        }
    }

    /// Looks up `key` in the dictionary source.
    ///
    /// Returns:
    /// - `Ok(Some(value))` when the source holds a value for the key;
    /// - `Ok(None)` when the key is `NULL`, when the source has no entry for
    ///   it, or (OpenDAL sources only) when the key is not a string — the
    ///   caller substitutes the dictionary's default value in that case.
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::DictionarySourceError` when the source fails for a
    /// reason other than "not found", when an OpenDAL value is not valid
    /// UTF-8, when the MySQL query fails, or when `data_type` is not one of
    /// boolean, string or number.
    async fn dict_get(&self, key: ScalarRef<'_>, data_type: &DataType) -> Result<Option<Scalar>> {
        if key == ScalarRef::Null {
            return Ok(None);
        }
        match self {
            DictionaryOperator::Operator(op) => {
                if let ScalarRef::String(key) = key {
                    match op.read(key).await {
                        Ok(res) => {
                            // The bytes come from an external system, so they must be
                            // validated: building a `String` from unchecked bytes would
                            // be undefined behaviour if the stored value is not UTF-8.
                            let value =
                                String::from_utf8(res.current().to_vec()).map_err(|e| {
                                    ErrorCode::DictionarySourceError(format!(
                                        "dictionary source error: value is not valid UTF-8: {e}"
                                    ))
                                })?;
                            Ok(Some(Scalar::String(value)))
                        }
                        // A missing key is not an error; the caller falls back to the default.
                        Err(e) if e.kind() == opendal::ErrorKind::NotFound => Ok(None),
                        Err(e) => Err(ErrorCode::DictionarySourceError(format!(
                            "dictionary source error: {e}"
                        ))),
                    }
                } else {
                    Ok(None)
                }
            }
            // The key is always bound as text; the decoded Rust type is chosen from the
            // (non-nullable) value type requested by the caller.
            DictionaryOperator::Mysql((pool, sql)) => match data_type.remove_nullable() {
                DataType::Boolean => {
                    let value: Option<bool> = sqlx::query_scalar(sql)
                        .bind(self.format_key(key))
                        .fetch_optional(pool)
                        .await?;
                    Ok(value.map(Scalar::Boolean))
                }
                DataType::String => {
                    let value: Option<String> = sqlx::query_scalar(sql)
                        .bind(self.format_key(key))
                        .fetch_optional(pool)
                        .await?;
                    Ok(value.map(Scalar::String))
                }
                DataType::Number(num_ty) => {
                    with_integer_mapped_type!(|NUM_TYPE| match num_ty {
                        NumberDataType::NUM_TYPE => {
                            let value: Option<NUM_TYPE> = sqlx::query_scalar(sql)
                                .bind(self.format_key(key))
                                .fetch_optional(pool)
                                .await?;
                            Ok(value.map(|v| Scalar::Number(NUM_TYPE::upcast_scalar(v))))
                        }
                        NumberDataType::Float32 => {
                            let value: Option<f32> = sqlx::query_scalar(sql)
                                .bind(self.format_key(key))
                                .fetch_optional(pool)
                                .await?;
                            Ok(value.map(|v| Scalar::Number(NumberScalar::Float32(v.into()))))
                        }
                        NumberDataType::Float64 => {
                            let value: Option<f64> = sqlx::query_scalar(sql)
                                .bind(self.format_key(key))
                                .fetch_optional(pool)
                                .await?;
                            Ok(value.map(|v| Scalar::Number(NumberScalar::Float64(v.into()))))
                        }
                    })
                }
                _ => Err(ErrorCode::DictionarySourceError(format!(
                    "unsupported value type {data_type}"
                ))),
            },
        }
    }
}

impl TransformAsyncFunction {
pub fn init_operators(
pub(crate) fn init_operators(
async_func_descs: &[AsyncFunctionDesc],
) -> Result<BTreeMap<usize, Arc<Operator>>> {
) -> Result<BTreeMap<usize, Arc<DictionaryOperator>>> {
let mut operators = BTreeMap::new();
for (i, async_func_desc) in async_func_descs.iter().enumerate() {
if let AsyncFunctionArgument::DictGetFunction(dict_arg) = &async_func_desc.func_arg {
Expand All @@ -55,10 +154,17 @@ impl TransformAsyncFunction {
builder = builder.db(db_index);
}
let op = build_operator(builder)?;
operators.insert(i, Arc::new(op));
operators.insert(i, Arc::new(DictionaryOperator::Operator(op)));
}
DictionarySource::Mysql(_) => {
return Err(ErrorCode::Unimplemented("Mysql source is unsupported"));
DictionarySource::Mysql(sql_source) => {
let mysql_pool = databend_common_base::runtime::block_on(
sqlx::MySqlPool::connect(&sql_source.connection_url),
)?;
let sql = format!(
"SELECT {} FROM {} WHERE {} = ? LIMIT 1",
&sql_source.value_field, &sql_source.table, &sql_source.key_field
);
operators.insert(i, Arc::new(DictionaryOperator::Mysql((mysql_pool, sql))));
}
}
}
Expand All @@ -75,59 +181,26 @@ impl TransformAsyncFunction {
arg_indices: &[IndexType],
data_type: &DataType,
) -> Result<()> {
let op = self.operators.get(&i).unwrap().clone();

let op: &Arc<DictionaryOperator> = self.operators.get(&i).unwrap();
// only support one key field.
let arg_index = arg_indices[0];
let entry = data_block.get_by_offset(arg_index);
let value = match &entry.value {
Value::Scalar(scalar) => {
if let Scalar::String(key) = scalar {
let buffer = op.read(key).await;
match buffer {
Ok(res) => {
let value =
unsafe { String::from_utf8_unchecked(res.current().to_vec()) };
Value::Scalar(Scalar::String(value))
}
Err(e) => {
if e.kind() == opendal::ErrorKind::NotFound {
Value::Scalar(dict_arg.default_value.clone())
} else {
return Err(ErrorCode::DictionarySourceError(format!(
"dictionary source error: {e}"
)));
}
}
}
} else {
Value::Scalar(dict_arg.default_value.clone())
}
let value = op
.dict_get(scalar.as_ref(), data_type)
.await?
.unwrap_or(dict_arg.default_value.clone());
Value::Scalar(value)
}
Value::Column(column) => {
let mut builder = ColumnBuilder::with_capacity(data_type, column.len());
for scalar in column.iter() {
if let ScalarRef::String(key) = scalar {
let buffer = op.read(key).await;
match buffer {
Ok(res) => {
let value =
unsafe { String::from_utf8_unchecked(res.current().to_vec()) };
builder.push(ScalarRef::String(value.as_str()));
}
Err(e) => {
if e.kind() == opendal::ErrorKind::NotFound {
builder.push(dict_arg.default_value.as_ref());
} else {
return Err(ErrorCode::DictionarySourceError(format!(
"dictionary source error: {e}"
)));
}
}
};
} else {
builder.push(dict_arg.default_value.as_ref());
}
for scalar_ref in column.iter() {
let value = op
.dict_get(scalar_ref, data_type)
.await?
.unwrap_or(dict_arg.default_value.clone());
builder.push(value.as_ref());
}
Value::Column(builder.build())
}
Expand Down
8 changes: 2 additions & 6 deletions src/query/sql/src/planner/binder/ddl/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,10 @@ fn validate_mysql_fields(schema: &TableSchema) -> Result<()> {
for field in schema.fields() {
if !matches!(
field.data_type().remove_nullable(),
TableDataType::Boolean
| TableDataType::String
| TableDataType::Number(_)
| TableDataType::Date
| TableDataType::Timestamp
TableDataType::Boolean | TableDataType::String | TableDataType::Number(_)
) {
return Err(ErrorCode::BadArguments(
"The type of Mysql field must be in [`boolean`, `string`, `number`, `timestamp`, `date`]",
"The type of Mysql field must be in [`boolean`, `string`, `number`]",
));
}
}
Expand Down
3 changes: 3 additions & 0 deletions tests/sqllogictests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ databend-common-base = { workspace = true }
databend-common-exception = { workspace = true }
env_logger = "0.10.0"
futures-util = { workspace = true }
msql-srv = "0.11.0"
mysql_async = { workspace = true }
mysql_common = "0.32.4"
rand = { workspace = true }
regex = { workspace = true }
reqwest = { workspace = true }
serde = "1.0.150"
serde_json = { workspace = true }
sqllogictest = "0.21.0"
sqlparser = "0.50.0"
thiserror = { workspace = true }
tokio = { workspace = true }
walkdir = { workspace = true }
Expand Down
18 changes: 14 additions & 4 deletions tests/sqllogictests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::path::Path;
use std::time::Instant;

use clap::Parser;
use databend_sqllogictests::mock_source::run_mysql_source;
use databend_sqllogictests::mock_source::run_redis_source;
use futures_util::stream;
use futures_util::StreamExt;
Expand Down Expand Up @@ -76,10 +77,8 @@ impl sqllogictest::AsyncDB for Databend {
pub async fn main() -> Result<()> {
env_logger::init();

// Run a mock Redis server for dictionary tests.
databend_common_base::runtime::spawn(async move {
run_redis_source().await;
});
// Run mock sources for dictionary tests.
run_mock_sources();

let args = SqlLogicTestArgs::parse();
let handlers = match &args.handlers {
Expand All @@ -103,6 +102,17 @@ pub async fn main() -> Result<()> {
Ok(())
}

/// Starts the mock dictionary backends used by the tests.
///
/// Both sources are launched in the background and their handles are dropped,
/// so this function returns immediately.
fn run_mock_sources() {
    // Mock Redis server: an async task handed to the runtime's `spawn`.
    let redis_task = async move {
        run_redis_source().await;
    };
    databend_common_base::runtime::spawn(redis_task);

    // Mock MySQL server: a synchronous entry point started via `Thread::spawn`.
    let mysql_worker = move || {
        run_mysql_source();
    };
    databend_common_base::runtime::Thread::spawn(mysql_worker);
}

async fn run_mysql_client() -> Result<()> {
println!(
"MySQL client starts to run with: {:?}",
Expand Down
2 changes: 2 additions & 0 deletions tests/sqllogictests/src/mock_source/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod mysql_source;
mod redis_source;
pub use mysql_source::run_mysql_source;
pub use redis_source::run_redis_source;
Loading

0 comments on commit b4e0a2b

Please sign in to comment.