Skip to content

Commit

Permalink
Merge branch 'main' into flav/viz-draft
Browse files Browse the repository at this point in the history
  • Loading branch information
flvndvd committed Jul 24, 2024
2 parents efe11e4 + 42d85e4 commit 0b36eec
Show file tree
Hide file tree
Showing 15 changed files with 520 additions and 139 deletions.
65 changes: 62 additions & 3 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,7 @@ struct DatasourceSearchPayload {
query: Option<String>,
top_k: usize,
filter: Option<SearchFilter>,
view_filter: Option<SearchFilter>,
full_text: bool,
credentials: run::Credentials,
target_document_tokens: Option<usize>,
Expand Down Expand Up @@ -1299,6 +1300,11 @@ async fn data_sources_search(
Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
None => None,
},
// TODO(spolu): follow_up PR.
// match payload.view_filter {
// Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
// None => None,
// },
payload.full_text,
payload.target_document_tokens,
)
Expand Down Expand Up @@ -1480,15 +1486,32 @@ async fn data_sources_documents_update_parents(
// Query parameters for listing the versions of a single document.
// Fix: the diff residue left both the pre-change and post-change `latest_hash`
// lines in place, yielding a duplicate field; this is the clean struct.
struct DataSourcesDocumentsVersionsListQuery {
    offset: usize,               // Pagination offset, forwarded to the store query.
    limit: usize,                // Page size, forwarded to the store query.
    latest_hash: Option<String>, // Hash of the latest version to retrieve.
    view_filter: Option<String>, // Parsed as JSON into a `SearchFilter` by the handler.
}

async fn data_sources_documents_versions_list(
Path((project_id, data_source_id, document_id)): Path<(i64, String, String)>,
State(state): State<Arc<APIState>>,
Query(query): Query<DataSourcesDocumentsVersionsListQuery>,
) -> (StatusCode, Json<APIResponse>) {
let view_filter: Option<SearchFilter> = match query
.view_filter
.as_ref()
.and_then(|f| Some(serde_json::from_str(f)))
{
Some(Ok(f)) => Some(f),
None => None,
Some(Err(e)) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_view_filter",
"Failed to parse view_filter query parameter",
Some(e.into()),
)
}
};

let project = project::Project::new_from_id(project_id);
match state
.store
Expand All @@ -1497,6 +1520,10 @@ async fn data_sources_documents_versions_list(
&data_source_id,
&document_id,
Some((query.limit, query.offset)),
&match view_filter {
Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
None => None,
},
&query.latest_hash,
)
.await
Expand Down Expand Up @@ -1626,20 +1653,42 @@ async fn data_sources_documents_upsert(
// Query parameters for listing a data source's documents.
struct DataSourcesListQuery {
    offset: usize, // Pagination offset, forwarded to `list_data_source_documents`.
    limit: usize,  // Page size, forwarded to `list_data_source_documents`.
    view_filter: Option<String>, // Parsed as JSON.
}

async fn data_sources_documents_list(
Path((project_id, data_source_id)): Path<(i64, String)>,
State(state): State<Arc<APIState>>,
Query(query): Query<DataSourcesListQuery>,
) -> (StatusCode, Json<APIResponse>) {
let view_filter: Option<SearchFilter> = match query
.view_filter
.as_ref()
.and_then(|f| Some(serde_json::from_str(f)))
{
Some(Ok(f)) => Some(f),
None => None,
Some(Err(e)) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_view_filter",
"Failed to parse view_filter query parameter",
Some(e.into()),
)
}
};

let project = project::Project::new_from_id(project_id);
match state
.store
.list_data_source_documents(
&project,
&data_source_id,
Some((query.limit, query.offset)),
&match view_filter {
Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
None => None,
},
true, // remove system tags
)
.await
Expand Down Expand Up @@ -1669,6 +1718,7 @@ async fn data_sources_documents_list(
// Query parameters for retrieving a single document.
//
// NOTE(review): the list/versions endpoints take `view_filter` as a JSON
// string and parse it manually; here it is deserialized directly from the
// query string — confirm the query-string deserializer supports this shape.
#[derive(serde::Deserialize)]
struct DataSourcesDocumentsRetrieveQuery {
    version_hash: Option<String>, // Specific version to fetch; presumably latest when `None` — confirm in `retrieve`.
    view_filter: Option<SearchFilter>,
}

async fn data_sources_documents_retrieve(
Expand Down Expand Up @@ -1696,7 +1746,16 @@ async fn data_sources_documents_retrieve(
None,
),
Some(ds) => match ds
.retrieve(state.store.clone(), &document_id, true, &query.version_hash)
.retrieve(
state.store.clone(),
&document_id,
&match query.view_filter {
Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
None => None,
},
true,
&query.version_hash,
)
.await
{
Err(e) => error_response(
Expand Down
4 changes: 2 additions & 2 deletions core/bin/qdrant/migrate_embedder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,8 @@ async fn refresh_chunk_count_for_updated_documents(

let filter = SearchFilter {
timestamp: Some(TimestampFilter {
gt: Some(from_timestamp),
lt: Some(now),
gt: Some(from_timestamp as i64),
lt: Some(now as i64),
}),
tags: None,
parents: None,
Expand Down
73 changes: 69 additions & 4 deletions core/src/data_sources/data_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ pub struct ParentsFilter {
/// timestamp greater than `gt` and less than `lt`.
// Fix: the diff residue kept both the old `u64` field lines and the new `i64`
// ones, yielding duplicate `gt`/`lt` fields; this is the clean post-change
// struct with the signed (`i64`) bounds.
//
// NOTE(review): despite the `gt`/`lt` names, `Document::match_filter` applies
// these bounds inclusively (`>=` / `<=`) — confirm which semantics is intended.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TimestampFilter {
    pub gt: Option<i64>, // Lower bound on the document timestamp.
    pub lt: Option<i64>, // Upper bound on the document timestamp.
}

// Custom deserializer for `TimestampFilter`
Expand All @@ -73,8 +73,8 @@ where

let f = Option::<InnerTimestampFilter>::deserialize(deserializer)?.map(|inner_filter| {
TimestampFilter {
gt: inner_filter.gt.map(|value| value as u64), // Convert f64 to u64
lt: inner_filter.lt.map(|value| value as u64), // Convert f64 to u64
gt: inner_filter.gt.map(|value| value as i64), // Convert f64 to i64
lt: inner_filter.lt.map(|value| value as i64), // Convert f64 to i64
}
});

Expand Down Expand Up @@ -403,6 +403,65 @@ impl Document {
token_count: None,
})
}

/// Returns `true` when this document satisfies `filter`.
///
/// A `None` filter matches every document. Otherwise every present clause
/// must hold:
/// - `tags.is_in`: at least one listed tag is present on the document;
///   `tags.is_not`: none of the listed tags is present.
/// - `parents.is_in` / `parents.is_not`: same logic against `self.parents`.
/// - `timestamp.gt` / `timestamp.lt`: bounds on the document timestamp.
///   NOTE(review): despite the names, both bounds are inclusive (`>=` / `<=`)
///   — preserved as-is; confirm this is the intended semantics.
pub fn match_filter(&self, filter: &Option<SearchFilter>) -> bool {
    // No filter: everything matches.
    let filter = match filter {
        Some(f) => f,
        None => return true,
    };

    if let Some(tags) = &filter.tags {
        if let Some(is_in) = &tags.is_in {
            if !is_in.iter().any(|tag| self.tags.contains(tag)) {
                return false;
            }
        }
        if let Some(is_not) = &tags.is_not {
            if is_not.iter().any(|tag| self.tags.contains(tag)) {
                return false;
            }
        }
    }

    if let Some(parents) = &filter.parents {
        if let Some(is_in) = &parents.is_in {
            if !is_in.iter().any(|parent| self.parents.contains(parent)) {
                return false;
            }
        }
        if let Some(is_not) = &parents.is_not {
            if is_not.iter().any(|parent| self.parents.contains(parent)) {
                return false;
            }
        }
    }

    if let Some(timestamp) = &filter.timestamp {
        // `self.timestamp` is unsigned; the filter bounds are signed.
        let ts = self.timestamp as i64;
        if let Some(gt) = timestamp.gt {
            if ts < gt {
                return false;
            }
        }
        if let Some(lt) = timestamp.lt {
            if ts > lt {
                return false;
            }
        }
    }

    true
}
}

pub fn make_document_id_hash(document_id: &str) -> String {
Expand Down Expand Up @@ -1756,6 +1815,7 @@ impl DataSource {
&self,
store: Box<dyn Store + Sync + Send>,
document_id: &str,
view_filter: &Option<SearchFilter>,
remove_system_tags: bool,
version_hash: &Option<String>,
) -> Result<Option<Document>> {
Expand All @@ -1776,6 +1836,11 @@ impl DataSource {
}
};

// If the view_filter does not match the document we return as if it didn't exist.
if !d.match_filter(view_filter) {
return Ok(None);
}

d.tags = if remove_system_tags {
// remove tags that are prefixed with the system tag prefix
d.tags
Expand Down
1 change: 0 additions & 1 deletion core/src/data_sources/file_storage_document.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use anyhow::{anyhow, Result};
use cloud_storage::Object;
use serde::{Deserialize, Serialize};
use tokio::try_join;
use tracing::info;

use crate::utils;
Expand Down
Loading

0 comments on commit 0b36eec

Please sign in to comment.