Skip to content

Commit

Permalink
core implementation of X-Dust-Group-Ids and retrieval of view_filter …
Browse files Browse the repository at this point in the history
…from registry
  • Loading branch information
spolu committed Jul 25, 2024
1 parent 69749b2 commit 3b0a7f5
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 30 deletions.
18 changes: 18 additions & 0 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,15 @@ async fn runs_create(
},
None => (),
};
match headers.get("X-Dust-Group-Ids") {
Some(v) => match v.to_str() {
Ok(v) => {
credentials.insert("DUST_GROUP_IDS".to_string(), v.to_string());
}
_ => (),
},
None => (),
};

match run_helper(project_id, payload.clone(), state.clone()).await {
Ok(app) => {
Expand Down Expand Up @@ -834,6 +843,15 @@ async fn runs_create_stream(
},
None => (),
};
match headers.get("X-Dust-Group-Ids") {
Some(v) => match v.to_str() {
Ok(v) => {
credentials.insert("DUST_GROUP_IDS".to_string(), v.to_string());
}
_ => (),
},
None => (),
};

// create unbounded channel to pass as stream to Sse::new
let (tx, mut rx) = unbounded_channel::<Value>();
Expand Down
19 changes: 7 additions & 12 deletions core/src/blocks/data_source.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::blocks::block::{
parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env,
};
use crate::blocks::helpers::get_data_source_project;
use crate::blocks::helpers::get_data_source_project_and_view_filter;
use crate::data_sources::data_source::{Document, SearchFilter};
use crate::deno::js_executor::JSExecutor;
use crate::Rule;
Expand Down Expand Up @@ -75,18 +75,14 @@ impl DataSource {
async fn search_data_source(
&self,
env: &Env,
workspace_id: Option<String>,
workspace_id: String,
data_source_id: String,
top_k: usize,
filter: Option<SearchFilter>,
target_document_tokens: Option<usize>,
) -> Result<Vec<Document>> {
let data_source_project = match workspace_id {
Some(workspace_id) => {
get_data_source_project(&workspace_id, &data_source_id, env).await?
}
None => env.project.clone(),
};
let (data_source_project, view_filter) =
get_data_source_project_and_view_filter(&workspace_id, &data_source_id, env).await?;

let ds = match env
.store
Expand All @@ -110,8 +106,7 @@ impl DataSource {
Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
None => None,
},
// TODO(spolu): add in subsequent PR (data_source block view_filter support).
None,
view_filter,
self.full_text,
target_document_tokens,
)
Expand Down Expand Up @@ -200,8 +195,8 @@ impl Block for DataSource {
.iter()
.map(|v| {
let workspace_id = match v.get("workspace_id") {
Some(Value::String(p)) => Some(p.clone()),
_ => None,
Some(Value::String(p)) => p.clone(),
_ => Err(anyhow!(err_msg.clone()))?,
};
let data_source_id = match v.get("data_source_id") {
Some(Value::String(i)) => i.clone(),
Expand Down
12 changes: 7 additions & 5 deletions core/src/blocks/database_schema.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::helpers::get_data_source_project;
use super::helpers::get_data_source_project_and_view_filter;
use crate::blocks::block::{Block, BlockResult, BlockType, Env};
use crate::databases::database::{get_unique_table_names_for_database, Table};
use crate::Rule;
Expand Down Expand Up @@ -126,18 +126,20 @@ pub async fn load_tables_from_identifiers(
.collect::<Vec<_>>();

// Get a vec of the corresponding project ids for each (workspace_id, data_source_id) pair.
let project_ids = try_join_all(
let project_ids_view_filters = try_join_all(
data_source_identifiers
.iter()
.map(|(w, d)| get_data_source_project(w, d, env)),
.map(|(w, d)| get_data_source_project_and_view_filter(w, d, env)),
)
.await?;

// TODO(GROUPS_INFRA): enforce view_filter as returned above.

// Create a hashmap of (workspace_id, data_source_id) -> project_id.
let project_by_data_source = data_source_identifiers
.iter()
.zip(project_ids.iter())
.map(|((w, d), p)| ((*w, *d), p.clone()))
.zip(project_ids_view_filters.iter())
.map(|((w, d), p)| ((*w, *d), p.0.clone()))
.collect::<std::collections::HashMap<_, _>>();

let store = env.store.clone();
Expand Down
38 changes: 26 additions & 12 deletions core/src/blocks/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,40 @@
use super::block::Env;
use crate::project::Project;
use crate::{data_sources::data_source::SearchFilter, project::Project};
use anyhow::{anyhow, Result};
use hyper::body::Buf;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io::prelude::*;
use url::Url;
use urlencoding::encode;

pub async fn get_data_source_project(
#[derive(Debug, Serialize, Deserialize, Clone)]
struct FrontRegistryPayload {
data_source_id: String,
project_id: i64,
view_filter: Option<SearchFilter>,
}

pub async fn get_data_source_project_and_view_filter(
workspace_id: &String,
data_source_id: &String,
env: &Env,
) -> Result<Project> {
) -> Result<(Project, Option<SearchFilter>)> {
let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") {
None => Err(anyhow!(
"DUST_WORKSPACE_ID credentials missing, but `workspace_id` \
is set in `data_source` block config"
))?,
Some(v) => v.clone(),
};
let dust_group_ids = match env.credentials.get("DUST_GROUP_IDS") {
Some(v) => v.clone(),
// We default to the empty string if not set which will default to the workspace global
// group in front registry.
None => "".to_string(),
};

let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") {
Ok(key) => key,
Err(_) => Err(anyhow!(
Expand All @@ -46,6 +61,7 @@ pub async fn get_data_source_project(
format!("Bearer {}", registry_secret.as_str()),
)
.header("X-Dust-Workspace-Id", dust_workspace_id)
.header("X-Dust-Group-Ids", dust_group_ids)
.send()
.await?;

Expand All @@ -65,16 +81,14 @@ pub async fn get_data_source_project(

let response_body = String::from_utf8_lossy(&b).into_owned();

let body = match serde_json::from_str::<serde_json::Value>(&response_body) {
Ok(body) => body,
// parse body into FrontRegistryPayload
let payload: FrontRegistryPayload = match serde_json::from_str(&response_body) {
Ok(payload) => payload,
Err(_) => Err(anyhow!("Failed to parse registry response"))?,
};

match body.get("project_id") {
Some(Value::Number(p)) => match p.as_i64() {
Some(p) => Ok(Project::new_from_id(p)),
None => Err(anyhow!("Failed to parse registry response")),
},
_ => Err(anyhow!("Failed to parse registry response")),
}
Ok((
Project::new_from_id(payload.project_id),
payload.view_filter,
))
}
4 changes: 3 additions & 1 deletion front/pages/api/registry/[type]/lookup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ async function handler(
return;
}

// TODO(GROUPS_INFRA): Add x-dust-group-ids header retrieval + checks
// TODO(GROUPS_INFRA): Add x-dust-group-ids header retrieval
// - If not set default to the global workspace group
// - Enforce checks for access to data sources and data sources view below

const dustWorkspaceId = req.headers["x-dust-workspace-id"] as string;

Expand Down

0 comments on commit 3b0a7f5

Please sign in to comment.