From ded25757dc3627a6b7ecb886857906e51411970c Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Thu, 25 Jul 2024 00:05:17 +0200 Subject: [PATCH 01/14] data sources: use dust credentials for search (#6495) * data sources: use dust credentials for search * lint --- .../[name]/documents/[documentId]/index.ts | 2 +- .../v1/w/[wId]/data_sources/[name]/search.ts | 21 +++---------------- .../[name]/documents/[documentId]/index.ts | 2 +- front/temporal/upsert_queue/activities.ts | 2 +- 4 files changed, 6 insertions(+), 21 deletions(-) diff --git a/front/pages/api/v1/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts b/front/pages/api/v1/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts index e58d1d32c8d2..d41f0270d4bf 100644 --- a/front/pages/api/v1/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts +++ b/front/pages/api/v1/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts @@ -443,7 +443,7 @@ async function handler( }, }); } else { - // Dust managed credentials: all data sources. + // Data source operations are performed with our credentials. const credentials = dustManagedCredentials(); // Create document with the Dust internal API. diff --git a/front/pages/api/v1/w/[wId]/data_sources/[name]/search.ts b/front/pages/api/v1/w/[wId]/data_sources/[name]/search.ts index 3b355a3795fa..8271ee7f4161 100644 --- a/front/pages/api/v1/w/[wId]/data_sources/[name]/search.ts +++ b/front/pages/api/v1/w/[wId]/data_sources/[name]/search.ts @@ -1,9 +1,5 @@ import type { DocumentType, WithAPIErrorResponse } from "@dust-tt/types"; -import type { CredentialsType } from "@dust-tt/types"; -import { - credentialsFromProviders, - dustManagedCredentials, -} from "@dust-tt/types"; +import { dustManagedCredentials } from "@dust-tt/types"; import { CoreAPI } from "@dust-tt/types"; import type { JSONSchemaType } from "ajv"; import type { NextApiRequest, NextApiResponse } from "next"; @@ -12,7 +8,6 @@ import config from "@app/lib/api/config"; import { getDataSource } from "@app/lib/api/data_sources"; import { Authenticator, getAPIKey } from "@app/lib/auth"; import { parse_payload } from "@app/lib/http_utils"; -import { Provider } from "@app/lib/models/apps"; import logger from "@app/logger/logger"; import { apiError, withLogging } from "@app/logger/withlogging"; @@ -224,18 +219,8 @@ async function handler( req.query.tags_not = [req.query.tags_not]; } - let credentials: CredentialsType | null = null; - if (keyRes.value.isSystem) { - // Dust managed credentials: system API key (managed data source). - credentials = dustManagedCredentials(); - } else { - const providers = await Provider.findAll({ - where: { - workspaceId: keyRes.value.workspaceId, - }, - }); - credentials = credentialsFromProviders(providers); - } + // Data source operations are performed with our credentials. + const credentials = dustManagedCredentials(); const queryRes = parse_payload(searchQuerySchema, req.query); if (queryRes.isErr()) { diff --git a/front/pages/api/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts b/front/pages/api/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts index 1dc6d4bda3d8..c5b0464a62a2 100644 --- a/front/pages/api/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts +++ b/front/pages/api/w/[wId]/data_sources/[name]/documents/[documentId]/index.ts @@ -211,7 +211,7 @@ async function handler( }); } - // Dust managed credentials: all data sources. + // Data source operations are performed with our credentials. 
const credentials = dustManagedCredentials(); // Create document with the Dust internal API. diff --git a/front/temporal/upsert_queue/activities.ts b/front/temporal/upsert_queue/activities.ts index 03880585df1c..d0f82189e820 100644 --- a/front/temporal/upsert_queue/activities.ts +++ b/front/temporal/upsert_queue/activities.ts @@ -76,7 +76,7 @@ export async function upsertDocumentActivity( `workspace_id:${upsertQueueItem.workspaceId}`, ]; - // Dust managed credentials: all data sources. + // Data source operations are performed with our credentials. const credentials = dustManagedCredentials(); const coreAPI = new CoreAPI(config.getCoreAPIConfig(), logger); From a0fa5b9ed57842104d48650c6be0b4e1f87e86b7 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Thu, 25 Jul 2024 00:24:43 +0200 Subject: [PATCH 02/14] core embedder/llm DATA_SOURCES api keys (#6496) --- core/src/providers/mistral.rs | 48 +++++++++++++++++------------------ core/src/providers/openai.rs | 48 +++++++++++++++++------------------ 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/core/src/providers/mistral.rs b/core/src/providers/mistral.rs index 951165da41ac..5179f957265a 100644 --- a/core/src/providers/mistral.rs +++ b/core/src/providers/mistral.rs @@ -882,24 +882,20 @@ impl LLM for MistralAILLM { self.id.clone() } async fn initialize(&mut self, credentials: Credentials) -> Result<()> { - match std::env::var("CORE_DATA_SOURCES_MISTRAL_API_KEY") { - Ok(key) => { - self.api_key = Some(key); + match credentials.get("MISTRAL_API_KEY") { + Some(api_key) => { + self.api_key = Some(api_key.clone()); } - Err(_) => match credentials.get("MISTRAL_API_KEY") { - Some(api_key) => { - self.api_key = Some(api_key.clone()); + None => match tokio::task::spawn_blocking(|| std::env::var("MISTRAL_API_KEY")).await? { + Ok(key) => { + self.api_key = Some(key); } - None => match std::env::var("MISTRAL_API_KEY") { - Ok(key) => { - self.api_key = Some(key); - } - Err(_) => Err(anyhow!( - "Credentials or environment variable `MISTRAL_API_KEY` is not set." - ))?, - }, + Err(_) => Err(anyhow!( + "Credentials or environment variable `MISTRAL_API_KEY` is not set." + ))?, }, } + Ok(()) } @@ -1109,20 +1105,24 @@ impl Embedder for MistralEmbedder { )); } - match credentials.get("MISTRAL_API_KEY") { - Some(api_key) => { - self.api_key = Some(api_key.clone()); + match std::env::var("CORE_DATA_SOURCES_MISTRAL_API_KEY") { + Ok(key) => { + self.api_key = Some(key); } - None => match tokio::task::spawn_blocking(|| std::env::var("MISTRAL_API_KEY")).await? { - Ok(key) => { - self.api_key = Some(key); + Err(_) => match credentials.get("MISTRAL_API_KEY") { + Some(api_key) => { + self.api_key = Some(api_key.clone()); } - Err(_) => Err(anyhow!( - "Credentials or environment variable `MISTRAL_API_KEY` is not set." - ))?, + None => match std::env::var("MISTRAL_API_KEY") { + Ok(key) => { + self.api_key = Some(key); + } + Err(_) => Err(anyhow!( + "Credentials or environment variable `MISTRAL_API_KEY` is not set." 
+ ))?, + }, }, } - Ok(()) } diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index 10375e6016ff..5c2ff696f988 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -1721,23 +1721,17 @@ impl LLM for OpenAILLM { } async fn initialize(&mut self, credentials: Credentials) -> Result<()> { - // Give priority to `CORE_DATA_SOURCES_OPENAI_API_KEY` env variable - match std::env::var("CORE_DATA_SOURCES_OPENAI_API_KEY") { - Ok(key) => { - self.api_key = Some(key); + match credentials.get("OPENAI_API_KEY") { + Some(api_key) => { + self.api_key = Some(api_key.clone()); } - Err(_) => match credentials.get("OPENAI_API_KEY") { - Some(api_key) => { - self.api_key = Some(api_key.clone()); + None => match tokio::task::spawn_blocking(|| std::env::var("OPENAI_API_KEY")).await? { + Ok(key) => { + self.api_key = Some(key); } - None => match std::env::var("OPENAI_API_KEY") { - Ok(key) => { - self.api_key = Some(key); - } - Err(_) => Err(anyhow!( - "Credentials or environment variable `OPENAI_API_KEY` is not set." - ))?, - }, + Err(_) => Err(anyhow!( + "Credentials or environment variable `OPENAI_API_KEY` is not set." + ))?, }, } Ok(()) @@ -2154,17 +2148,23 @@ impl Embedder for OpenAIEmbedder { )); } - match credentials.get("OPENAI_API_KEY") { - Some(api_key) => { - self.api_key = Some(api_key.clone()); + // Give priority to `CORE_DATA_SOURCES_OPENAI_API_KEY` env variable + match std::env::var("CORE_DATA_SOURCES_OPENAI_API_KEY") { + Ok(key) => { + self.api_key = Some(key); } - None => match tokio::task::spawn_blocking(|| std::env::var("OPENAI_API_KEY")).await? { - Ok(key) => { - self.api_key = Some(key); + Err(_) => match credentials.get("OPENAI_API_KEY") { + Some(api_key) => { + self.api_key = Some(api_key.clone()); } - Err(_) => Err(anyhow!( - "Credentials or environment variable `OPENAI_API_KEY` is not set." - ))?, + None => match std::env::var("OPENAI_API_KEY") { + Ok(key) => { + self.api_key = Some(key); + } + Err(_) => Err(anyhow!( + "Credentials or environment variable `OPENAI_API_KEY` is not set." + ))?, + }, }, } Ok(()) From 0810b4efe58dac8b393f6ecde7ffe732b909cfa1 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Thu, 25 Jul 2024 00:47:46 +0200 Subject: [PATCH 03/14] Groups: `core` view_filter support in `search` (#6490) * down to retrieve_chunks_without_query * all the way to qdrant_filter constructor: --- core/bin/dust_api.rs | 9 +- core/bin/qdrant/migrate_embedder.rs | 1 + core/src/blocks/data_source.rs | 2 + core/src/data_sources/data_source.rs | 173 ++++++++++++++++----------- core/src/stores/postgres.rs | 7 ++ core/src/stores/store.rs | 1 + 6 files changed, 120 insertions(+), 73 deletions(-) diff --git a/core/bin/dust_api.rs b/core/bin/dust_api.rs index 1fba517708db..156268e853e3 100644 --- a/core/bin/dust_api.rs +++ b/core/bin/dust_api.rs @@ -1300,11 +1300,10 @@ async fn data_sources_search( Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)), None => None, }, - // TODO(spolu): follow_up PR. 
- // match payload.view_filter { - // Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)), - // None => None, - // }, + match payload.view_filter { + Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)), + None => None, + }, payload.full_text, payload.target_document_tokens, ) diff --git a/core/bin/qdrant/migrate_embedder.rs b/core/bin/qdrant/migrate_embedder.rs index fa70f8fd6d8f..f095de86a734 100644 --- a/core/bin/qdrant/migrate_embedder.rs +++ b/core/bin/qdrant/migrate_embedder.rs @@ -643,6 +643,7 @@ async fn refresh_chunk_count_for_updated_documents( ds.project(), ds.data_source_id(), &Some(filter.clone()), + &None, Some((batch_size, offset)), ) .await?; diff --git a/core/src/blocks/data_source.rs b/core/src/blocks/data_source.rs index fbd3dadb5775..e86208d75d1a 100644 --- a/core/src/blocks/data_source.rs +++ b/core/src/blocks/data_source.rs @@ -110,6 +110,8 @@ impl DataSource { Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)), None => None, }, + // TODO(spolu): add in subsequent PR (data_source block view_filter support). + None, self.full_text, target_document_tokens, ) diff --git a/core/src/data_sources/data_source.rs b/core/src/data_sources/data_source.rs index 64023d3772d9..55343330cfef 100644 --- a/core/src/data_sources/data_source.rs +++ b/core/src/data_sources/data_source.rs @@ -1312,6 +1312,7 @@ impl DataSource { query: &Option, top_k: usize, filter: Option, + view_filter: Option, full_text: bool, target_document_tokens: Option, ) -> Result> { @@ -1319,8 +1320,14 @@ impl DataSource { let qdrant_client = self.main_qdrant_client(&qdrant_clients); - // We ensure that we have not left a `parents.is_in_map`` in the filter. - match filter.as_ref() { + // We ensure that we have not left a `parents.is_in_map`` in the filters. + match &filter { + Some(filter) => { + filter.ensure_postprocessed()?; + } + None => (), + } + match &view_filter { Some(filter) => { filter.ensure_postprocessed()?; } @@ -1352,7 +1359,13 @@ impl DataSource { ))?; } let chunks = self - .retrieve_chunks_without_query(store, qdrant_client.clone(), top_k, &filter) + .retrieve_chunks_without_query( + store, + qdrant_client.clone(), + top_k, + &filter, + &view_filter, + ) .await?; qdrant_search_duration = utils::now() - time_qdrant_start; chunks @@ -1370,7 +1383,7 @@ impl DataSource { embedding_duration = utils::now() - time_qdrant_start; // Construct the filters for the search query if specified. - let f = build_qdrant_filter(&filter); + let f = build_qdrant_filter(&filter, &view_filter); let time_qdrant_search_start = utils::now(); let results = qdrant_client @@ -1687,6 +1700,7 @@ impl DataSource { qdrant_client: DustQdrantClient, top_k: usize, filter: &Option, + view_filter: &Option, ) -> Result> { let store = store.clone(); @@ -1695,6 +1709,7 @@ impl DataSource { &self.project, self.data_source_id(), filter, + view_filter, // With top_k documents, we should be guaranteed to have at least top_k chunks, if // we make the assumption that each document has at least one chunk. Some((top_k, 0)), @@ -2012,7 +2027,10 @@ impl DataSource { } } -fn build_qdrant_filter(filter: &Option) -> Option { +fn build_qdrant_filter( + filter: &Option, + view_filter: &Option, +) -> Option { fn qdrant_match_field_condition(key: &str, v: Vec) -> qdrant::Condition { qdrant::FieldCondition { key: key.to_string(), @@ -2026,81 +2044,100 @@ fn build_qdrant_filter(filter: &Option) -> Option .into() } - // Construct the filters for the search query if specified. 
- match filter { - Some(f) => { - let mut must_filter: Vec = vec![]; - let mut must_not_filter: Vec = vec![]; + let mut must_filter: Vec = vec![]; + let mut must_not_filter: Vec = vec![]; - match &f.tags { - Some(tags) => { - match tags.is_in.clone() { - Some(v) => must_filter.push(qdrant_match_field_condition("tags", v)), - None => (), - }; - match tags.is_not.clone() { - Some(v) => must_not_filter.push(qdrant_match_field_condition("tags", v)), - None => (), - }; - } - None => (), - }; + let mut process_filter = |f: &SearchFilter| { + match &f.tags { + Some(tags) => { + match tags.is_in.clone() { + Some(v) => must_filter.push(qdrant_match_field_condition("tags", v)), + None => (), + }; + match tags.is_not.clone() { + Some(v) => must_not_filter.push(qdrant_match_field_condition("tags", v)), + None => (), + }; + } + None => (), + }; - match &f.parents { - Some(parents) => { - match parents.is_in.clone() { - Some(v) => must_filter.push(qdrant_match_field_condition("parents", v)), - None => (), - }; - match parents.is_not.clone() { - Some(v) => must_not_filter.push(qdrant_match_field_condition("parents", v)), - None => (), - }; - } - None => (), - }; + match &f.parents { + Some(parents) => { + match parents.is_in.clone() { + Some(v) => must_filter.push(qdrant_match_field_condition("parents", v)), + None => (), + }; + match parents.is_not.clone() { + Some(v) => must_not_filter.push(qdrant_match_field_condition("parents", v)), + None => (), + }; + } + None => (), + }; - match &f.timestamp { - Some(timestamp) => { - match timestamp.gt.clone() { - Some(v) => must_filter.push( - qdrant::FieldCondition { - key: "timestamp".to_string(), - range: Some(qdrant::Range { - gte: Some(v as f64), - ..Default::default() - }), + match &f.timestamp { + Some(timestamp) => { + match timestamp.gt.clone() { + Some(v) => must_filter.push( + qdrant::FieldCondition { + key: "timestamp".to_string(), + range: Some(qdrant::Range { + gte: Some(v as f64), ..Default::default() - } - .into(), - ), - None => (), - }; - match timestamp.lt.clone() { - Some(v) => must_filter.push( - qdrant::FieldCondition { - key: "timestamp".to_string(), - range: Some(qdrant::Range { - lte: Some(v as f64), - ..Default::default() - }), + }), + ..Default::default() + } + .into(), + ), + None => (), + }; + match timestamp.lt.clone() { + Some(v) => must_filter.push( + qdrant::FieldCondition { + key: "timestamp".to_string(), + range: Some(qdrant::Range { + lte: Some(v as f64), ..Default::default() - } - .into(), - ), - None => (), - }; - } - None => (), - }; + }), + ..Default::default() + } + .into(), + ), + None => (), + }; + } + None => (), + }; + }; + match (filter, view_filter) { + (Some(f), Some(vf)) => { + process_filter(f); + process_filter(vf); + Some(qdrant::Filter { + must: must_filter, + must_not: must_not_filter, + ..Default::default() + }) + } + (Some(f), None) => { + process_filter(f); + Some(qdrant::Filter { + must: must_filter, + must_not: must_not_filter, + ..Default::default() + }) + } + (None, Some(vf)) => { + process_filter(vf); Some(qdrant::Filter { must: must_filter, must_not: must_not_filter, ..Default::default() }) } - None => None, + (None, None) => None, } } diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs index 4dc04a0ec478..530c67949813 100644 --- a/core/src/stores/postgres.rs +++ b/core/src/stores/postgres.rs @@ -1595,6 +1595,7 @@ impl Store for PostgresStore { project: &Project, data_source_id: &str, filter: &Option, + view_filter: &Option, limit_offset: Option<(usize, usize)>, ) -> Result<(Vec, 
usize)> { let pool = self.pool.clone(); @@ -1624,6 +1625,12 @@ impl Store for PostgresStore { where_clauses.extend(filter_clauses); params.extend(filter_params); + let (view_filter_clauses, view_filter_params, p_idx) = + Self::where_clauses_and_params_for_filter(view_filter, p_idx); + + where_clauses.extend(view_filter_clauses); + params.extend(view_filter_params); + // compute the total count let count_query = format!( "SELECT COUNT(*) FROM data_sources_documents WHERE {}", diff --git a/core/src/stores/store.rs b/core/src/stores/store.rs index 2aa34026eee4..aed74460e9c2 100644 --- a/core/src/stores/store.rs +++ b/core/src/stores/store.rs @@ -118,6 +118,7 @@ pub trait Store { project: &Project, data_source_id: &str, filter: &Option, + view_filter: &Option, limit_offset: Option<(usize, usize)>, ) -> Result<(Vec, usize)>; async fn upsert_data_source_document( From 92e5a16b2c515341baaafb4124b863a9c44d8e83 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Thu, 25 Jul 2024 01:03:50 +0200 Subject: [PATCH 04/14] core: fix version endpoints SQL query (#6497) --- core/src/stores/postgres.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs index 530c67949813..8c3d31f82585 100644 --- a/core/src/stores/postgres.rs +++ b/core/src/stores/postgres.rs @@ -1525,7 +1525,7 @@ impl Store for PostgresStore { params.push(&data_source_row_id); where_clauses.push("document_id = $2".to_string()); params.push(&document_id); - where_clauses.push("created <= $3'".to_string()); + where_clauses.push("created <= $3".to_string()); params.push(&latest_hash_created); let (filter_clauses, filter_params, p_idx) = From bf0575d7df77ad482a5df8250a0d39b453233043 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Thu, 25 Jul 2024 01:31:11 +0200 Subject: [PATCH 05/14] oauth: remove entirely nango from front (#6498) * oauth: remove entirely nango from front * comment deprecated migration * lint --- .../20230606_deprecate_nango_connection_id.ts | 60 ++-- front/lib/connector_connection_id.ts | 2 +- front/lib/labs/config.ts | 25 -- front/lib/labs/transcripts/utils/helpers.ts | 112 ++------ .../resources/labs_transcripts_resource.ts | 20 -- .../20240722_migrate_labs_nango_connection.ts | 267 ------------------ front/package-lock.json | 16 -- front/package.json | 2 - front/pages/nango-oauth-callback.ts | 10 - types/src/front/lib/labs.ts | 2 - types/src/index.ts | 1 - types/src/shared/nango_errors.ts | 22 -- 12 files changed, 51 insertions(+), 488 deletions(-) delete mode 100644 front/lib/labs/config.ts delete mode 100644 front/migrations/20240722_migrate_labs_nango_connection.ts delete mode 100644 front/pages/nango-oauth-callback.ts delete mode 100644 types/src/shared/nango_errors.ts diff --git a/connectors/migrations/20230606_deprecate_nango_connection_id.ts b/connectors/migrations/20230606_deprecate_nango_connection_id.ts index 4331df3c8212..5691063cdb76 100644 --- a/connectors/migrations/20230606_deprecate_nango_connection_id.ts +++ b/connectors/migrations/20230606_deprecate_nango_connection_id.ts @@ -1,30 +1,30 @@ -import { Op } from "sequelize"; - -import { sequelizeConnection } from "@connectors/resources/storage"; -import { ConnectorModel } from "@connectors/resources/storage/models/connector_model"; - -async function main() { - await ConnectorModel.update( - { - connectionId: sequelizeConnection.col("nangoConnectionId"), - }, - { - // @ts-expect-error `connectionId` has been made non-nullable - where: { - connectionId: { - [Op.eq]: null, - }, - }, - } 
- ); -} - -main() - .then(() => { - console.log("Done"); - process.exit(0); - }) - .catch((e) => { - console.error(e); - process.exit(1); - }); +// import { Op } from "sequelize"; +// +// import { sequelizeConnection } from "@connectors/resources/storage"; +// import { ConnectorModel } from "@connectors/resources/storage/models/connector_model"; +// +// async function main() { +// await ConnectorModel.update( +// { +// connectionId: sequelizeConnection.col("nangoConnectionId"), +// }, +// { +// // @ts-expect-error `connectionId` has been made non-nullable +// where: { +// connectionId: { +// [Op.eq]: null, +// }, +// }, +// } +// ); +// } +// +// main() +// .then(() => { +// console.log("Done"); +// process.exit(0); +// }) +// .catch((e) => { +// console.error(e); +// process.exit(1); +// }); diff --git a/front/lib/connector_connection_id.ts b/front/lib/connector_connection_id.ts index 3368f81ce983..11c452e39988 100644 --- a/front/lib/connector_connection_id.ts +++ b/front/lib/connector_connection_id.ts @@ -12,7 +12,7 @@ export function buildConnectionId( return connectionName; } -// Labs adds nango connections that are not necessarily made available to the rest of the product. +// Labs adds connections that are not necessarily made available to the rest of the product. export function buildLabsConnectionId( wId: string, provider: LabsConnectorProvider diff --git a/front/lib/labs/config.ts b/front/lib/labs/config.ts deleted file mode 100644 index 5151b3dc7a5e..000000000000 --- a/front/lib/labs/config.ts +++ /dev/null @@ -1,25 +0,0 @@ -import type { LabsConnectorProvider } from "@dust-tt/types"; -import { assertNever, EnvironmentConfig } from "@dust-tt/types"; - -const config = { - getNangoPublicKey: (): string => { - return EnvironmentConfig.getEnvVariable("NANGO_PUBLIC_KEY"); - }, - getNangoSecretKey: (): string => { - return EnvironmentConfig.getEnvVariable("NANGO_SECRET_KEY"); - }, - getNangoConnectorIdForProvider: (provider: LabsConnectorProvider): string => { - switch (provider) { - case "google_drive": - return EnvironmentConfig.getEnvVariable( - "NANGO_GOOGLE_DRIVE_CONNECTOR_ID" - ); - case "gong": - return EnvironmentConfig.getEnvVariable("NANGO_GONG_CONNECTOR_ID"); - default: - assertNever(provider); - } - }, -}; - -export default config; diff --git a/front/lib/labs/transcripts/utils/helpers.ts b/front/lib/labs/transcripts/utils/helpers.ts index d7d87a0a3b10..6a8724b60233 100644 --- a/front/lib/labs/transcripts/utils/helpers.ts +++ b/front/lib/labs/transcripts/utils/helpers.ts @@ -1,26 +1,12 @@ -import type { - ModelId, - NangoIntegrationId, - OAuthProvider, - Result, -} from "@dust-tt/types"; -import { Err, getOAuthConnectionAccessToken, Ok } from "@dust-tt/types"; -import { Nango } from "@nangohq/node"; +import type { ModelId, OAuthProvider } from "@dust-tt/types"; +import { getOAuthConnectionAccessToken } from "@dust-tt/types"; import { google } from "googleapis"; import apiConfig from "@app/lib/api/config"; import type { Authenticator } from "@app/lib/auth"; -import config from "@app/lib/labs/config"; import { LabsTranscriptsConfigurationResource } from "@app/lib/resources/labs_transcripts_resource"; import logger from "@app/logger/logger"; -const nango = new Nango({ secretKey: config.getNangoSecretKey() }); - -export function isDualUseOAuthConnectionId(connectionId: string): boolean { - // TODO(spolu): make sure this function is removed once fully migrated. 
- return connectionId.startsWith("con_"); -} - // Google Auth export async function getTranscriptsGoogleAuth( auth: Authenticator, @@ -45,86 +31,28 @@ export async function getTranscriptsGoogleAuth( const oauth2Client = new google.auth.OAuth2(); - if (isDualUseOAuthConnectionId(connectionId)) { - const tokRes = await getOAuthConnectionAccessToken({ - config: apiConfig.getOAuthAPIConfig(), - logger, - provider, - connectionId, - }); - - if (tokRes.isErr()) { - logger.error( - { connectionId, error: tokRes.error, provider }, - "Error retrieving access token" - ); - throw new Error(`Error retrieving access token from ${provider}`); - } + const tokRes = await getOAuthConnectionAccessToken({ + config: apiConfig.getOAuthAPIConfig(), + logger, + provider, + connectionId, + }); - oauth2Client.setCredentials({ - access_token: tokRes.value.access_token, - scope: (tokRes.value.scrubbed_raw_json as { scope: string }).scope, - token_type: (tokRes.value.scrubbed_raw_json as { token_type: string }) - .token_type, - expiry_date: tokRes.value.access_token_expiry, - }); - } else { - const res = await nango.getConnection( - config.getNangoConnectorIdForProvider("google_drive"), - connectionId + if (tokRes.isErr()) { + logger.error( + { connectionId, error: tokRes.error, provider }, + "Error retrieving access token" ); - - oauth2Client.setCredentials({ - access_token: res.credentials.raw.access_token, - scope: res.credentials.raw.scope, - token_type: res.credentials.raw.token_type, - expiry_date: new Date(res.credentials.raw.expires_at).getTime(), - }); + throw new Error(`Error retrieving access token from ${provider}`); } - return oauth2Client; -} - -export async function getAccessTokenFromNango( - nangoIntegrationId: NangoIntegrationId, - nangoConnectionId: string -): Promise { - const res = await nango.getConnection(nangoIntegrationId, nangoConnectionId); - - return res.credentials.raw.access_token; -} - -export async function nangoDeleteConnection( - connectionId: string, - providerConfigKey: string -): Promise> { - const url = `${nango.serverUrl}/connection/${connectionId}?provider_config_key=${providerConfigKey}`; - const headers = { - "Content-Type": "application/json", - "Accept-Encoding": "application/json", - Authorization: `Bearer ${nango.secretKey}`, - }; - const res = await fetch(url, { - method: "DELETE", - headers, + oauth2Client.setCredentials({ + access_token: tokRes.value.access_token, + scope: (tokRes.value.scrubbed_raw_json as { scope: string }).scope, + token_type: (tokRes.value.scrubbed_raw_json as { token_type: string }) + .token_type, + expiry_date: tokRes.value.access_token_expiry, }); - if (res.ok) { - return new Ok(undefined); - } else { - logger.error({ connectionId }, "Could not delete Nango connection."); - if (res) { - if (res.status === 404) { - logger.error({ connectionId }, "Connection not found on Nango."); - return new Ok(undefined); - } - - return new Err( - new Error( - `Could not delete connection. 
${res.statusText}, ${await res.text()}` - ) - ); - } - return new Err(new Error(`Could not delete connection.`)); - } + return oauth2Client; } diff --git a/front/lib/resources/labs_transcripts_resource.ts b/front/lib/resources/labs_transcripts_resource.ts index c09f8eee507a..bf3d252a7559 100644 --- a/front/lib/resources/labs_transcripts_resource.ts +++ b/front/lib/resources/labs_transcripts_resource.ts @@ -13,11 +13,6 @@ import type { import type { CreationAttributes } from "sequelize"; import type { Authenticator } from "@app/lib/auth"; -import config from "@app/lib/labs/config"; -import { - isDualUseOAuthConnectionId, - nangoDeleteConnection, -} from "@app/lib/labs/transcripts/utils/helpers"; import { BaseResource } from "@app/lib/resources/base_resource"; import { LabsTranscriptsConfigurationModel } from "@app/lib/resources/storage/models/labs_transcripts"; import { LabsTranscriptsHistoryModel } from "@app/lib/resources/storage/models/labs_transcripts"; @@ -187,21 +182,6 @@ export class LabsTranscriptsConfigurationResource extends BaseResource = { - google_drive: NANGO_GOOGLE_DRIVE_CONNECTOR_ID, - gong: NANGO_GONG_CONNECTOR_ID, -}; - -const CONNECTORS_WITH_REFRESH_TOKENS = ["google_drive"]; - -async function appendRollbackCommand( - provider: LabsTranscriptsProviderType, - labsTranscriptConfigurationId: ModelId, - oldConnectionId: string -) { - const sql = `UPDATE labs_transcripts_configurations SET "connectionId" = '${oldConnectionId}' WHERE id = ${labsTranscriptConfigurationId};\n`; - await fs.appendFile(`${provider}_rollback_commands.sql`, sql); -} - -function getRedirectUri(provider: LabsTranscriptsProviderType): string { - return `${config.getDustAPIConfig().url}/oauth/${provider}/finalize`; -} - -async function migrateConfigurationId( - api: OAuthAPI, - provider: LabsTranscriptsProviderType, - configuration: LabsTranscriptsConfigurationResource, - logger: Logger, - execute: boolean -): Promise> { - logger.info( - `Migrating configuration id ${configuration.id}, current connectionId ${configuration.connectionId}.` - ); - - const user = await configuration.getUser(); - const workspace = await Workspace.findOne({ - where: { id: configuration.workspaceId }, - }); - - if (!user || !workspace) { - return new Err(new Error("User or workspace not found")); - } - - const integrationId = NANGO_CONNECTOR_IDS[provider]; - if (!integrationId) { - return new Err(new Error("Nango integration ID not found for provider")); - } - - // Retrieve connection from nango. - let connection: any | null = null; - try { - connection = await nango.getConnection( - integrationId, - configuration.connectionId, - true, // forceRefresh - true // returnRefreshToksn - ); - } catch (e) { - return new Err(new Error(`Nango error: ${e}`)); - } - - console.log( - ">>>>>>>>>>>>>>>>>>>>>>>>>>> BEG CONNECTION <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" - ); - console.log(connection); - console.log( - ">>>>>>>>>>>>>>>>>>>>>>>>>>> END CONNECTION <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" - ); - - if (!connection.credentials.access_token) { - return new Err(new Error("Could not retrieve `access_token` from Nango")); - } - - // We don't have authorization codes from Nango - const migratedCredentials: MigratedCredentialsType = { - redirect_uri: getRedirectUri(provider), - access_token: connection.credentials.access_token, - raw_json: connection.credentials.raw, - }; - - // If provider supports refresh tokens, migrate them. 
- if (CONNECTORS_WITH_REFRESH_TOKENS.includes(provider)) { - const thirtyMinutesFromNow = new Date(new Date().getTime() + 30 * 60000); - - if ( - !connection.credentials.expires_at || - new Date(connection.credentials.expires_at).getTime() < - thirtyMinutesFromNow.getTime() - ) { - return new Err( - new Error( - "Expires at is not set or is less than 30 minutes from now. Skipping migration." - ) - ); - } - - if (connection.credentials.expires_at) { - migratedCredentials.access_token_expiry = Date.parse( - connection.credentials.expires_at - ); - } - if (connection.credentials.refresh_token) { - migratedCredentials.refresh_token = connection.credentials.refresh_token; - } - } - - console.log( - ">>>>>>>>>>>>>>>>>>>>>>>>>>> BEG MIGRATED_CREDENTIALS <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" - ); - console.log(migratedCredentials); - console.log( - ">>>>>>>>>>>>>>>>>>>>>>>>>>> END MIGRATED_CREDENTIALS <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<" - ); - - if (!execute) { - return new Ok(undefined); - } - - // Save the old connectionId for rollback. - const oldConnectionId = configuration.connectionId; - - // Create the connection with migratedCredentials. - const cRes = await api.createConnection({ - // TOOD(alban): remove the as once gong is an OAuthProvider. - provider: provider as OAuthProvider, - metadata: { - use_case: USE_CASE, - workspace_id: workspace.sId, - user_id: user.sId, - origin: "migrated", - }, - migratedCredentials, - }); - - if (cRes.isErr()) { - return cRes; - } - - const newConnectionId = cRes.value.connection.connection_id; - - // Append rollback command after successful update. - await appendRollbackCommand(provider, configuration.id, oldConnectionId); - - await configuration.updateConnectionId(newConnectionId); - - logger.info( - `Successfully migrated connection id for connector ${configuration.id}, new connectionId ${newConnectionId}.` - ); - - return new Ok(undefined); -} - -async function migrateAllConfigurations( - provider: LabsTranscriptsProviderType, - configurationId: ModelId | undefined, - logger: Logger, - execute: boolean -) { - const api = new OAuthAPI(config.getOAuthAPIConfig(), logger); - - const configurations = configurationId - ? removeNulls([ - await LabsTranscriptsConfigurationResource.fetchByModelId( - configurationId - ), - ]) - : await LabsTranscriptsConfigurationResource.listByProvider({ - provider, - }); - - logger.info( - `Found ${configurations.length} ${provider} configurations to migrate.` - ); - - for (const configuration of configurations) { - const localLogger = logger.child({ - configurationId: configuration.id, - workspaceId: configuration.workspaceId, - }); - - if (isDualUseOAuthConnectionId(configuration.connectionId)) { - localLogger.info("Skipping alreaydy migrated configuration"); - continue; - } - - const migrationRes = await migrateConfigurationId( - api, - provider, - configuration, - localLogger, - execute - ); - if (migrationRes.isErr()) { - localLogger.error( - { - error: migrationRes.error, - }, - "Failed to migrate configuration. Exiting." 
- ); - } - } - - logger.info(`Done migrating configurations.`); -} - -makeScript( - { - connectorId: { - alias: "c", - describe: "Connector ID", - type: "number", - }, - provider: { - alias: "p", - describe: "OAuth provider to migrate", - type: "string", - }, - }, - async ({ provider, connectorId, execute }, logger) => { - if (isOAuthProvider(provider)) { - await migrateAllConfigurations( - provider as LabsTranscriptsProviderType, - connectorId, - logger, - execute - ); - } else { - logger.error( - { - provider, - }, - "Invalid provider provided" - ); - } - } -); diff --git a/front/package-lock.json b/front/package-lock.json index add3d06731f3..e08386f6c055 100644 --- a/front/package-lock.json +++ b/front/package-lock.json @@ -16,8 +16,6 @@ "@headlessui/react": "^1.7.7", "@heroicons/react": "^2.0.11", "@hookform/resolvers": "^3.3.4", - "@nangohq/frontend": "^0.16.1", - "@nangohq/node": "^0.36.37", "@radix-ui/react-checkbox": "^1.0.4", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-label": "^2.0.2", @@ -13653,20 +13651,6 @@ } } }, - "node_modules/@nangohq/frontend": { - "version": "0.16.1", - "license": "SEE LICENSE IN LICENSE FILE IN GIT REPOSITORY" - }, - "node_modules/@nangohq/node": { - "version": "0.36.37", - "license": "SEE LICENSE IN LICENSE FILE IN GIT REPOSITORY", - "dependencies": { - "axios": "^1.2.0" - }, - "engines": { - "node": ">=16.7" - } - }, "node_modules/@next/env": { "version": "14.2.3", "resolved": "https://registry.npmjs.org/@next/env/-/env-14.2.3.tgz", diff --git a/front/package.json b/front/package.json index 1c89667cdf0f..02a569fc3e68 100644 --- a/front/package.json +++ b/front/package.json @@ -28,8 +28,6 @@ "@headlessui/react": "^1.7.7", "@heroicons/react": "^2.0.11", "@hookform/resolvers": "^3.3.4", - "@nangohq/frontend": "^0.16.1", - "@nangohq/node": "^0.36.37", "@radix-ui/react-checkbox": "^1.0.4", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-label": "^2.0.2", diff --git a/front/pages/nango-oauth-callback.ts b/front/pages/nango-oauth-callback.ts deleted file mode 100644 index d3df1c324f5b..000000000000 --- a/front/pages/nango-oauth-callback.ts +++ /dev/null @@ -1,10 +0,0 @@ -import { useEffect } from "react"; - -export default function NangoRedirect() { - useEffect(() => { - const nangoURL = `https://api.nango.dev/oauth/callback${window.location.search}`; - window.location.replace(nangoURL); - }, []); - - return null; // Render nothing. 
-} diff --git a/types/src/front/lib/labs.ts b/types/src/front/lib/labs.ts index d9a04c2e1ad9..42443ce4f905 100644 --- a/types/src/front/lib/labs.ts +++ b/types/src/front/lib/labs.ts @@ -2,6 +2,4 @@ export const labsTranscriptsProviders = ["google_drive", "gong"] as const; export type LabsTranscriptsProviderType = (typeof labsTranscriptsProviders)[number]; -export type NangoConnectionId = string; -export type NangoIntegrationId = string; export const minTranscriptsSize = 200; diff --git a/types/src/index.ts b/types/src/index.ts index cfd9549301eb..417883295fe3 100644 --- a/types/src/index.ts +++ b/types/src/index.ts @@ -79,7 +79,6 @@ export * from "./shared/env"; export * from "./shared/feature_flags"; export * from "./shared/message_classification"; export * from "./shared/model_id"; -export * from "./shared/nango_errors"; export * from "./shared/rate_limiter"; export * from "./shared/result"; export * from "./shared/text_extraction"; diff --git a/types/src/shared/nango_errors.ts b/types/src/shared/nango_errors.ts deleted file mode 100644 index 6e6fb56ee65b..000000000000 --- a/types/src/shared/nango_errors.ts +++ /dev/null @@ -1,22 +0,0 @@ -export interface NangoError extends Error { - code: string; - status: number; - config?: { - url?: string; - }; -} - -export function isNangoError(err: unknown): err is NangoError { - const isError = err instanceof Error; - const hasStatus = isError && "status" in err; - const hasNangoCode = - isError && "code" in err && err.code === "ERR_BAD_RESPONSE"; - const hasConfig = isError && "config" in err; - const hasConfigUrl = - hasConfig && - err.config !== null && - typeof err.config === "object" && - "url" in err.config; - - return hasStatus && hasNangoCode && hasConfig && hasConfigUrl; -} From bca0c064d3e944421c3a0d4d6c74aecd7b1d9c97 Mon Sep 17 00:00:00 2001 From: Alban Dumouilla Date: Thu, 25 Jul 2024 10:01:23 +0200 Subject: [PATCH 06/14] Add workspace name in datasource poke view (#6500) --- front/pages/poke/[wId]/data_sources/[name]/index.tsx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/front/pages/poke/[wId]/data_sources/[name]/index.tsx b/front/pages/poke/[wId]/data_sources/[name]/index.tsx index 4930f6c8b20d..9df51b002e75 100644 --- a/front/pages/poke/[wId]/data_sources/[name]/index.tsx +++ b/front/pages/poke/[wId]/data_sources/[name]/index.tsx @@ -277,11 +277,11 @@ const DataSourcePage = ({
[hunk bodies garbled in extraction; recoverable changes: the « workspace back link is updated to include the workspace name, and the data source search link label changes as below]
@@ -296,7 +296,7 @@
-              🔒search
+              🔒 search data
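[The JSX for the hunks above did not survive extraction. Going by the commit title, "Add workspace name in datasource poke view", the back-link change plausibly has the shape sketched below; the Link wrapper, class names, and the owner prop are assumptions, not recovered code.]

// Hypothetical sketch, not the actual diff.
// Assumes the poke page receives `owner: WorkspaceType` as a prop.
<Link href={`/poke/${owner.sId}`} className="text-sm text-action-500">
  &laquo; workspace: {owner.name}
</Link>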
From ab67055a82465b6520a0785af4171e28d672af68 Mon Sep 17 00:00:00 2001 From: Jules Belveze <32683010+JulesBelveze@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:05:32 +0200 Subject: [PATCH 07/14] [front] - fix(conversation): action chip state (#6479) * [front/components/assistant/conversation/actions] - refactor: optimize AgentMessageActions to assess action completeness - Implement a new function `isActionComplete` to determine if an agent action is complete based on its type and output - Use `isActionComplete` within `useMemo` to more accurately decide if the agent is thinking or acting by checking each action's status * fix: lint/format * [front/components/assistant/conversation/actions] - refactor: streamline action completion check in AgentMessageActions - Remove redundant isActionBootComplete() function and simplify thinking or acting detection logic - Update ActionDetails onClick to conditionally set cursor-pointer class based on action presence - Improve code clarity by utilizing classNames helper function for conditional styling in ActionDetails component --------- Co-authored-by: Jules --- .../assistant/conversation/actions/AgentMessageActions.tsx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/front/components/assistant/conversation/actions/AgentMessageActions.tsx b/front/components/assistant/conversation/actions/AgentMessageActions.tsx index 57a09b3eb923..4d36093a7a40 100644 --- a/front/components/assistant/conversation/actions/AgentMessageActions.tsx +++ b/front/components/assistant/conversation/actions/AgentMessageActions.tsx @@ -5,6 +5,7 @@ import { useEffect, useMemo, useState } from "react"; import { getActionSpecification } from "@app/components/actions/types"; import { AgentMessageActionsDrawer } from "@app/components/assistant/conversation/actions/AgentMessageActionsDrawer"; import type { MessageSizeType } from "@app/components/assistant/conversation/ConversationMessage"; +import { classNames } from "@app/lib/utils"; interface AgentMessageActionsProps { agentMessage: AgentMessageType; @@ -80,7 +81,10 @@ function ActionDetails({
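[The body of the hunk above was stripped in extraction. From the commit message (cursor-pointer is set conditionally on action presence, using the classNames helper), the change plausibly looks like this sketch; the base classes, drawer handler, and rendered children are assumptions, not the recovered diff.]

// Hypothetical sketch, not the recovered diff.
const hasActions = agentMessage.actions.length > 0;

return (
  <div
    className={classNames(
      "flex flex-row items-center gap-2", // assumed base classes
      hasActions ? "cursor-pointer" : "" // pointer only when actions can be inspected
    )}
    // assumed handler name: opens the actions drawer only when there are actions
    onClick={hasActions ? () => setIsActionDrawerOpened(true) : undefined}
  >
    <span>Tools inspection</span>
  </div>
);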
From 64a995bee4648b29337ebe2119397b0ab474c68d Mon Sep 17 00:00:00 2001 From: Alban Dumouilla Date: Thu, 25 Jul 2024 10:05:41 +0200 Subject: [PATCH 08/14] Reactivate gong labs provider (#6499) * Reactivate gong * Lint --- .../assistant/labs/transcripts/index.tsx | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx b/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx index 6e853436f906..58c29abe3e38 100644 --- a/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx +++ b/front/pages/w/[wId]/assistant/labs/transcripts/index.tsx @@ -6,7 +6,6 @@ import { Page, SliderToggle, Spinner, - Tooltip, XMarkIcon, } from "@dust-tt/sparkle"; import type { SubscriptionType } from "@dust-tt/types"; @@ -443,21 +442,19 @@ export default function LabsTranscriptsIndex({ style={{ maxHeight: "35px" }} />
[hunk body garbled in extraction; recoverable change: the Tooltip wrapper around the Gong provider toggle is removed (its import is dropped above) and the onClick={() => handleProviderChange("gong")} handler is kept active]
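[As above, the toggle's JSX was lost. A plausible shape of the re-enabled Gong toggle follows; only handleProviderChange("gong") and the maxHeight style are attested by the surviving fragments, the rest is assumed.]

// Hypothetical sketch, not the recovered diff.
<div
  className="cursor-pointer"
  onClick={() => handleProviderChange("gong")}
>
  <img
    src="/static/labs/transcripts/gong.png" // assumed asset path
    style={{ maxHeight: "35px" }}
  />
</div>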
)} From 1a36d2d800b7b36512a338bb074e2a7c6c3011ff Mon Sep 17 00:00:00 2001 From: Alban Dumouilla Date: Thu, 25 Jul 2024 10:17:37 +0200 Subject: [PATCH 09/14] Add gradle and xml files in github connector (#6439) --- connectors/src/connectors/github/lib/github_api.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/connectors/src/connectors/github/lib/github_api.ts b/connectors/src/connectors/github/lib/github_api.ts index c7d68441de09..706afd32f21e 100644 --- a/connectors/src/connectors/github/lib/github_api.ts +++ b/connectors/src/connectors/github/lib/github_api.ts @@ -584,6 +584,8 @@ const EXTENSION_WHITELIST = [ ".sql", ".kt", ".kts", + ".gradle", + ".xml", ]; const SUFFIX_BLACKLIST = [".min.js", ".min.css"]; From afddc25005809c71520beff13ed304f5111f86ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daphn=C3=A9=20Popin?= Date: Thu, 25 Jul 2024 10:20:28 +0200 Subject: [PATCH 10/14] Create group when we create a workspace + backfill the past (#6483) * Create group when we create a workspace + backfill the past * Also create Workspace group * Feedback --- front/admin/cli.ts | 13 +++++ front/lib/iam/workspaces.ts | 14 +++++ front/lib/resources/storage/models/groups.ts | 4 +- .../20240724_workspaces_groups_backfill.ts | 57 +++++++++++++++++++ 4 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 front/migrations/20240724_workspaces_groups_backfill.ts diff --git a/front/admin/cli.ts b/front/admin/cli.ts index 3e9e57e67d63..8a35bc1bc3b9 100644 --- a/front/admin/cli.ts +++ b/front/admin/cli.ts @@ -21,6 +21,7 @@ import { internalSubscribeWorkspaceToFreeNoPlan, internalSubscribeWorkspaceToFreePlan, } from "@app/lib/plans/subscription"; +import { GroupResource } from "@app/lib/resources/group_resource"; import { LabsTranscriptsConfigurationResource } from "@app/lib/resources/labs_transcripts_resource"; import { MembershipResource } from "@app/lib/resources/membership_resource"; import { generateLegacyModelSId } from "@app/lib/resources/string_ids"; @@ -43,6 +44,18 @@ const workspace = async (command: string, args: parseArgs.ParsedArgs) => { sId: generateLegacyModelSId(), name: args.name, }); + await Promise.all([ + GroupResource.makeNew({ + name: "System", + type: "system", + workspaceId: w.id, + }), + GroupResource.makeNew({ + name: "Workspace", + type: "workspace", + workspaceId: w.id, + }), + ]); args.wId = w.sId; await workspace("show", args); diff --git a/front/lib/iam/workspaces.ts b/front/lib/iam/workspaces.ts index 9f13c1ceae49..8a0a37bc8816 100644 --- a/front/lib/iam/workspaces.ts +++ b/front/lib/iam/workspaces.ts @@ -2,6 +2,7 @@ import { sendUserOperationMessage } from "@dust-tt/types"; import type { SessionWithUser } from "@app/lib/iam/provider"; import { Workspace, WorkspaceHasDomain } from "@app/lib/models/workspace"; +import { GroupResource } from "@app/lib/resources/group_resource"; import { generateLegacyModelSId } from "@app/lib/resources/string_ids"; import { isDisposableEmailDomain } from "@app/lib/utils/disposable_email_domains"; import logger from "@app/logger/logger"; @@ -22,6 +23,19 @@ export async function createWorkspace(session: SessionWithUser) { name: externalUser.nickname, }); + await Promise.all([ + GroupResource.makeNew({ + name: "System", + type: "system", + workspaceId: workspace.id, + }), + GroupResource.makeNew({ + name: "Workspace", + type: "workspace", + workspaceId: workspace.id, + }), + ]); + sendUserOperationMessage({ message: `<@U055XEGPR4L> +signupRadar User ${externalUser.email} has created a new workspace.`, logger, diff --git 
a/front/lib/resources/storage/models/groups.ts b/front/lib/resources/storage/models/groups.ts index 6d0e87764f14..585e554fbb64 100644 --- a/front/lib/resources/storage/models/groups.ts +++ b/front/lib/resources/storage/models/groups.ts @@ -72,7 +72,9 @@ GroupModel.addHook( }); if (existingSystemGroupType) { - throw new Error("A system group exists for this workspace."); + throw new Error("A system group exists for this workspace.", { + cause: "enforce_one_system_group_per_workspace", + }); } } } diff --git a/front/migrations/20240724_workspaces_groups_backfill.ts b/front/migrations/20240724_workspaces_groups_backfill.ts new file mode 100644 index 000000000000..d2d7d16b7ece --- /dev/null +++ b/front/migrations/20240724_workspaces_groups_backfill.ts @@ -0,0 +1,57 @@ +import _ from "lodash"; + +import { Workspace } from "@app/lib/models/workspace"; +import { GroupResource } from "@app/lib/resources/group_resource"; +import { makeScript } from "@app/scripts/helpers"; + +async function backfillWorkspacesGroup(execute: boolean) { + const workspaces = await Workspace.findAll(); + + const chunks = _.chunk(workspaces, 16); + for (const [i, c] of chunks.entries()) { + console.log( + `[execute=${execute}] Processing chunk of ${c.length} workspaces... (${ + i + 1 + }/${chunks.length})` + ); + if (execute) { + await Promise.all( + c.map((w) => + (async () => { + try { + await GroupResource.makeNew({ + name: "System", + type: "system", + workspaceId: w.id, + }); + await GroupResource.makeNew({ + name: "Workspace", + type: "workspace", + workspaceId: w.id, + }); + console.log(`System group created for workspace ${w.id}`); + } catch (error) { + if ( + error instanceof Error && + error.cause && + error.cause === "enforce_one_system_group_per_workspace" + ) { + console.log( + `System group already exists for workspace ${w.id}` + ); + } else { + console.error(error); + } + } + })() + ) + ); + } + } + + console.log(`Done.`); +} + +makeScript({}, async ({ execute }) => { + await backfillWorkspacesGroup(execute); +}); From 7265d5ea7bd29f270616e7d68969b3ade3bdaad0 Mon Sep 17 00:00:00 2001 From: Henry Fontanier Date: Thu, 25 Jul 2024 10:21:38 +0200 Subject: [PATCH 11/14] fix: whitelist prodbox in core network policy (#6501) Co-authored-by: Henry Fontanier --- k8s/network-policies/core-network-policy.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/k8s/network-policies/core-network-policy.yaml b/k8s/network-policies/core-network-policy.yaml index 675902ee6248..1c6f4a23f8ab 100644 --- a/k8s/network-policies/core-network-policy.yaml +++ b/k8s/network-policies/core-network-policy.yaml @@ -22,6 +22,9 @@ spec: - podSelector: matchLabels: app: core-sqlite-worker + - podSelector: + matchLabels: + app: prodbox ports: - protocol: TCP port: 3001 From 9caefabc56e744be6596be167f04b6ce98f188fc Mon Sep 17 00:00:00 2001 From: Thomas Draier Date: Thu, 25 Jul 2024 10:57:21 +0200 Subject: [PATCH 12/14] [microsoft] feature: Use delta on drive (#6487) --- connectors/src/connectors/microsoft/index.ts | 10 +- .../connectors/microsoft/lib/content_nodes.ts | 4 +- .../src/connectors/microsoft/lib/graph_api.ts | 51 +---- .../src/connectors/microsoft/lib/utils.ts | 59 ++++++ .../microsoft/temporal/activities.ts | 184 +++++++++--------- .../src/connectors/microsoft/temporal/file.ts | 24 +-- .../microsoft/temporal/workflows.ts | 31 +-- 7 files changed, 201 insertions(+), 162 deletions(-) create mode 100644 connectors/src/connectors/microsoft/lib/utils.ts diff --git a/connectors/src/connectors/microsoft/index.ts 
b/connectors/src/connectors/microsoft/index.ts index 98a55b065e18..0cbb225e5121 100644 --- a/connectors/src/connectors/microsoft/index.ts +++ b/connectors/src/connectors/microsoft/index.ts @@ -24,12 +24,14 @@ import { getFilesAndFolders, getSites, getTeams, - internalIdFromTypeAndPath, - typeAndPathFromInternalId, } from "@connectors/connectors/microsoft/lib/graph_api"; import type { MicrosoftNodeType } from "@connectors/connectors/microsoft/lib/types"; import { - getSiteNodesToSync, + internalIdFromTypeAndPath, + typeAndPathFromInternalId, +} from "@connectors/connectors/microsoft/lib/utils"; +import { + getRootNodesToSync, populateDeltas, } from "@connectors/connectors/microsoft/temporal/activities"; import { @@ -309,7 +311,7 @@ export class MicrosoftConnectorManager extends BaseConnectorManager { await MicrosoftRootResource.batchMakeNew(newResourcesBlobs); - const nodesToSync = await getSiteNodesToSync(this.connectorId); + const nodesToSync = await getRootNodesToSync(this.connectorId); // poupulates deltas for the nodes so that if incremental sync starts before // fullsync populated, there's no error diff --git a/connectors/src/connectors/microsoft/lib/content_nodes.ts b/connectors/src/connectors/microsoft/lib/content_nodes.ts index 0cd653a076a6..b34bb77d0173 100644 --- a/connectors/src/connectors/microsoft/lib/content_nodes.ts +++ b/connectors/src/connectors/microsoft/lib/content_nodes.ts @@ -4,9 +4,11 @@ import { getDriveInternalId, getDriveItemInternalId, getSiteAPIPath, +} from "@connectors/connectors/microsoft/lib/graph_api"; +import { internalIdFromTypeAndPath, typeAndPathFromInternalId, -} from "@connectors/connectors/microsoft/lib/graph_api"; +} from "@connectors/connectors/microsoft/lib/utils"; import type { MicrosoftNodeResource } from "@connectors/resources/microsoft_resource"; export function getRootNodes(): ContentNode[] { diff --git a/connectors/src/connectors/microsoft/lib/graph_api.ts b/connectors/src/connectors/microsoft/lib/graph_api.ts index bcb588fade04..57fec064ebba 100644 --- a/connectors/src/connectors/microsoft/lib/graph_api.ts +++ b/connectors/src/connectors/microsoft/lib/graph_api.ts @@ -2,9 +2,11 @@ import { assertNever } from "@dust-tt/types"; import type { Client } from "@microsoft/microsoft-graph-client"; import type * as MicrosoftGraph from "@microsoft/microsoft-graph-types"; -import type { MicrosoftNodeType } from "@connectors/connectors/microsoft/lib/types"; import type { MicrosoftNode } from "@connectors/connectors/microsoft/lib/types"; -import { isValidNodeType } from "@connectors/connectors/microsoft/lib/types"; +import { + internalIdFromTypeAndPath, + typeAndPathFromInternalId, +} from "@connectors/connectors/microsoft/lib/utils"; export async function getSites( client: Client, @@ -408,51 +410,6 @@ export function itemToMicrosoftNode( } } -export function internalIdFromTypeAndPath({ - nodeType, - itemAPIPath, -}: { - nodeType: MicrosoftNodeType; - itemAPIPath: string; -}): string { - let stringId = ""; - if (nodeType === "sites-root" || nodeType === "teams-root") { - stringId = nodeType; - } else { - stringId = `${nodeType}/${itemAPIPath}`; - } - // encode to base64url so the internal id is URL-friendly - return "microsoft-" + Buffer.from(stringId).toString("base64url"); -} - -export function typeAndPathFromInternalId(internalId: string): { - nodeType: MicrosoftNodeType; - itemAPIPath: string; -} { - if (!internalId.startsWith("microsoft-")) { - throw new Error(`Invalid internal id: ${internalId}`); - } - - // decode from base64url - const decodedId 
= Buffer.from( - internalId.slice("microsoft-".length), - "base64url" - ).toString(); - - if (decodedId === "sites-root" || decodedId === "teams-root") { - return { nodeType: decodedId, itemAPIPath: "" }; - } - - const [nodeType, ...resourcePathArr] = decodedId.split("/"); - if (!nodeType || !isValidNodeType(nodeType)) { - throw new Error( - `Invalid internal id: ${decodedId} with nodeType: ${nodeType}` - ); - } - - return { nodeType, itemAPIPath: resourcePathArr.join("/") }; -} - export function getDriveItemInternalId(item: MicrosoftGraph.DriveItem) { const { parentReference } = item; diff --git a/connectors/src/connectors/microsoft/lib/utils.ts b/connectors/src/connectors/microsoft/lib/utils.ts new file mode 100644 index 000000000000..5c891ed8423d --- /dev/null +++ b/connectors/src/connectors/microsoft/lib/utils.ts @@ -0,0 +1,59 @@ +import type { MicrosoftNodeType } from "./types"; +import { isValidNodeType } from "./types"; + +export function internalIdFromTypeAndPath({ + nodeType, + itemAPIPath, +}: { + nodeType: MicrosoftNodeType; + itemAPIPath: string; +}): string { + let stringId = ""; + if (nodeType === "sites-root" || nodeType === "teams-root") { + stringId = nodeType; + } else { + stringId = `${nodeType}/${itemAPIPath}`; + } + // encode to base64url so the internal id is URL-friendly + return "microsoft-" + Buffer.from(stringId).toString("base64url"); +} + +export function typeAndPathFromInternalId(internalId: string): { + nodeType: MicrosoftNodeType; + itemAPIPath: string; +} { + if (!internalId.startsWith("microsoft-")) { + throw new Error(`Invalid internal id: ${internalId}`); + } + + // decode from base64url + const decodedId = Buffer.from( + internalId.slice("microsoft-".length), + "base64url" + ).toString(); + + if (decodedId === "sites-root" || decodedId === "teams-root") { + return { nodeType: decodedId, itemAPIPath: "" }; + } + + const [nodeType, ...resourcePathArr] = decodedId.split("/"); + if (!nodeType || !isValidNodeType(nodeType)) { + throw new Error( + `Invalid internal id: ${decodedId} with nodeType: ${nodeType}` + ); + } + + return { nodeType, itemAPIPath: resourcePathArr.join("/") }; +} + +export function getDriveInternalIdFromItemId(itemId: string) { + const { itemAPIPath } = typeAndPathFromInternalId(itemId); + if (!itemAPIPath.startsWith("/drives/")) { + throw new Error("Unexpected: no drive id for item"); + } + const parts = itemAPIPath.split("/"); + return internalIdFromTypeAndPath({ + nodeType: "drive", + itemAPIPath: `/drives/${parts[2]}`, + }); +} diff --git a/connectors/src/connectors/microsoft/temporal/activities.ts b/connectors/src/connectors/microsoft/temporal/activities.ts index 34a6c54e9d2f..62afc1870f7e 100644 --- a/connectors/src/connectors/microsoft/temporal/activities.ts +++ b/connectors/src/connectors/microsoft/temporal/activities.ts @@ -18,11 +18,14 @@ import { getParentReferenceInternalId, getSiteAPIPath, getSites, - internalIdFromTypeAndPath, itemToMicrosoftNode, - typeAndPathFromInternalId, } from "@connectors/connectors/microsoft/lib/graph_api"; import type { MicrosoftNode } from "@connectors/connectors/microsoft/lib/types"; +import { + getDriveInternalIdFromItemId, + internalIdFromTypeAndPath, + typeAndPathFromInternalId, +} from "@connectors/connectors/microsoft/lib/utils"; import { deleteFile, deleteFolder, @@ -47,7 +50,7 @@ import type { DataSourceConfig } from "@connectors/types/data_source_config"; const FILES_SYNC_CONCURRENCY = 10; const DELETE_CONCURRENCY = 5; -export async function getSiteNodesToSync( +export async function 
getRootNodesToSync( connectorId: ModelId ): Promise { const connector = await ConnectorResource.fetchById(connectorId); @@ -158,7 +161,22 @@ export async function getSiteNodesToSync( return nodeResources.map((r) => r.internalId); } +export async function groupRootItemsByDriveId(nodeIds: string[]) { + const itemsWithDrive = nodeIds.map((id) => ({ + drive: getDriveInternalIdFromItemId(id), + folder: id, + })); + return itemsWithDrive.reduce( + (acc, current) => ({ + ...acc, + [current.drive]: [...(acc[current.drive] || []), current.folder], + }), + {} as { [key: string]: string[] } + ); +} + export async function populateDeltas(connectorId: ModelId, nodeIds: string[]) { + const groupedItems = groupRootItemsByDriveId(nodeIds); const connector = await ConnectorResource.fetchById(connectorId); if (!connector) { @@ -167,22 +185,25 @@ export async function populateDeltas(connectorId: ModelId, nodeIds: string[]) { const client = await getClient(connector.connectionId); - for (const nodeId of nodeIds) { - const node = await MicrosoftNodeResource.fetchByInternalId( - connectorId, - nodeId - ); - - if (!node) { - throw new Error(`Node ${nodeId} not found`); - } + for (const [driveId, nodeIds] of Object.entries(groupedItems)) { const { deltaLink } = await getDeltaResults({ client, - parentInternalId: nodeId, + parentInternalId: driveId, token: "latest", }); - await node.update({ deltaLink }); + for (const nodeId of nodeIds) { + const node = await MicrosoftNodeResource.fetchByInternalId( + connectorId, + nodeId + ); + + if (!node) { + throw new Error(`Node ${nodeId} not found`); + } + + await node.update({ deltaLink }); + } } } @@ -369,7 +390,6 @@ export async function syncFiles({ }); const alreadySeenResources = Object.values(alreadySeenResourcesById); - const createdOrUpdatedResources = await MicrosoftNodeResource.batchUpdateOrCreate( connectorId, @@ -400,13 +420,15 @@ export async function syncFiles({ }; } -export async function syncDeltaForRootNode({ +export async function syncDeltaForRootNodesInDrive({ connectorId, - rootNodeId, + driveId, + rootNodeIds, startSyncTs, }: { connectorId: ModelId; - rootNodeId: string; + driveId: string; + rootNodeIds: string[]; startSyncTs: number; }) { const connector = await ConnectorResource.fetchById(connectorId); @@ -423,24 +445,30 @@ export async function syncDeltaForRootNode({ const dataSourceConfig = dataSourceConfigFromConnector(connector); - const { nodeType } = typeAndPathFromInternalId(rootNodeId); + const nodeTypes = rootNodeIds.map( + (nodeId) => typeAndPathFromInternalId(nodeId).nodeType + ); - if (nodeType !== "drive" && nodeType !== "folder") { - throw new Error(`Node ${rootNodeId} is not a drive or folder`); + if ( + nodeTypes.some((nodeType) => nodeType !== "drive" && nodeType !== "folder") + ) { + throw new Error(`Some of ${rootNodeIds} are not a drive or folder`); } - const node = await MicrosoftNodeResource.fetchByInternalId( + const nodes = await MicrosoftNodeResource.fetchByInternalIds( connectorId, - rootNodeId + rootNodeIds ); - if (!node) { - throw new Error(`Root or node resource ${rootNodeId} not found`); + const node = nodes[0]; + + if (nodes.length !== rootNodeIds.length || !node) { + throw new Error(`Root or node resource ${nodes} not found`); } const client = await getClient(connector.connectionId); - logger.info({ connectorId, node }, "Syncing delta for node"); + logger.info({ connectorId, rootNodeIds }, "Syncing delta for node"); // Goes through pagination to return all delta results. 
@@ -369,7 +390,6 @@
   });

   const alreadySeenResources = Object.values(alreadySeenResourcesById);
-
   const createdOrUpdatedResources =
     await MicrosoftNodeResource.batchUpdateOrCreate(
       connectorId,
@@ -400,13 +420,15 @@
   };
 }

-export async function syncDeltaForRootNode({
+export async function syncDeltaForRootNodesInDrive({
   connectorId,
-  rootNodeId,
+  driveId,
+  rootNodeIds,
   startSyncTs,
 }: {
   connectorId: ModelId;
-  rootNodeId: string;
+  driveId: string;
+  rootNodeIds: string[];
   startSyncTs: number;
 }) {
   const connector = await ConnectorResource.fetchById(connectorId);
@@ -423,24 +445,30 @@

   const dataSourceConfig = dataSourceConfigFromConnector(connector);

-  const { nodeType } = typeAndPathFromInternalId(rootNodeId);
+  const nodeTypes = rootNodeIds.map(
+    (nodeId) => typeAndPathFromInternalId(nodeId).nodeType
+  );

-  if (nodeType !== "drive" && nodeType !== "folder") {
-    throw new Error(`Node ${rootNodeId} is not a drive or folder`);
+  if (
+    nodeTypes.some((nodeType) => nodeType !== "drive" && nodeType !== "folder")
+  ) {
+    throw new Error(`Some of ${rootNodeIds} are not drives or folders`);
   }

-  const node = await MicrosoftNodeResource.fetchByInternalId(
+  const nodes = await MicrosoftNodeResource.fetchByInternalIds(
     connectorId,
-    rootNodeId
+    rootNodeIds
   );

-  if (!node) {
-    throw new Error(`Root or node resource ${rootNodeId} not found`);
+  const node = nodes[0];
+
+  if (nodes.length !== rootNodeIds.length || !node) {
+    throw new Error(`Root or node resources ${rootNodeIds} not found`);
   }

   const client = await getClient(connector.connectionId);

-  logger.info({ connectorId, node }, "Syncing delta for node");
+  logger.info({ connectorId, rootNodeIds }, "Syncing delta for nodes");

   // Goes through pagination to return all delta results. This is because delta
   // list can include same item more than once and api recommendation is to
@@ -458,11 +486,28 @@
     node,
   });
   const uniqueChangedItems = removeAllButLastOccurences(results);
-  const sortedChangedItems = sortForIncrementalUpdate(uniqueChangedItems);
+
+  const microsoftNodes = await concurrentExecutor(
+    rootNodeIds,
+    async (rootNodeId) =>
+      getItem(client, typeAndPathFromInternalId(rootNodeId).itemAPIPath),
+    { concurrency: 5 }
+  );
+
+  const sortedChangedItems = microsoftNodes.flatMap((rootNode) =>
+    sortForIncrementalUpdate(uniqueChangedItems, rootNode.id)
+  );
+
+  // Finally, add all removed items, which may not have been included even if they are in the selected roots.
+  sortedChangedItems.push(
+    ...uniqueChangedItems.filter(
+      (item) =>
+        !sortedChangedItems.includes(item) && item.deleted?.state === "deleted"
+    )
+  );

   for (const driveItem of sortedChangedItems) {
     heartbeat();
-
     if (!driveItem.parentReference) {
       throw new Error(`Unexpected: parent reference missing: ${driveItem}`);
     }
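Restated outside the diff, the changed-item pipeline above now runs in three steps; unique, sorted, and roots below are shortened stand-ins for the variables in the patch:

    // 1. Deduplicate: a Graph delta listing can return the same item several
    //    times; only the last occurrence is authoritative.
    const unique = removeAllButLastOccurences(results);

    // 2. Order per selected root so that parents are always processed before
    //    their children.
    const sorted = roots.flatMap((root) =>
      sortForIncrementalUpdate(unique, root.id)
    );

    // 3. Re-append deletions that did not land under any selected root, so
    //    removals are never silently dropped.
    sorted.push(
      ...unique.filter(
        (item) => !sorted.includes(item) && item.deleted?.state === "deleted"
      )
    );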
@@ -471,17 +516,7 @@
     if (driveItem.file) {
       if (driveItem.deleted) {
-        // if file was just moved from a toplevel folder to another in the same drive, it's marked
-        // as deleted but we don't want to delete it
-        // internally means "in the same Drive" here
-        if (
-          !(await isFileMovedInSameDrive({
-            toplevelNode: node,
-            fileInternalId: internalId,
-          }))
-        ) {
-          await deleteFile({ connectorId, internalId, dataSourceConfig });
-        }
+        await deleteFile({ connectorId, internalId, dataSourceConfig });
       } else {
         await syncOneFile({
           connectorId,
@@ -512,8 +547,10 @@
       // add parent information to new node resource. for the toplevel folder,
       // parent is null
+      // TODO: check filter
       const parentInternalId =
-        resource.internalId === rootNodeId
+        resource.internalId === driveId ||
+        rootNodeIds.indexOf(resource.internalId) !== -1
           ? null
           : getParentReferenceInternalId(driveItem.parentReference);
@@ -535,12 +572,13 @@
     }
   }

-  await node.update({ deltaLink });
-
-  logger.info(
-    { connectorId, nodeId: node.internalId, name: node.name },
-    "Delta sync complete"
+  await concurrentExecutor(
+    nodes,
+    (node) => node && node.update({ deltaLink }),
+    { concurrency: 5 }
   );
+
+  logger.info({ connectorId, driveId, rootNodeIds }, "Delta sync complete");
 }

 /**
@@ -565,37 +603,6 @@ function removeAllButLastOccurences(deltaList: microsoftgraph.DriveItem[]) {
   return resultList;
 }

-/**
- * This function checks whether a file marked as deleted from a toplevel folder
- * is actually just moved to another toplevel folder in the same drive (in which
- * case we should not delete it)
- *
- * Note: this concerns toplevel folders, not drives; it's fine to delete files
- * that move from a drive to another because they change id
- */
-async function isFileMovedInSameDrive({
-  toplevelNode,
-  fileInternalId,
-}: {
-  toplevelNode: MicrosoftNodeResource;
-  fileInternalId: string;
-}) {
-  if (toplevelNode.nodeType === "drive") {
-    // if the toplevel node is a drive, then the deletion must happen
-    return false;
-  }
-  // check that the file's parents array does not contain the toplevel folder, in
-  // which case it's a file movement; otherwise it's a file deletion
-  return !(
-    await getParents({
-      connectorId: toplevelNode.connectorId,
-      internalId: fileInternalId,
-      parentInternalId: toplevelNode.internalId,
-      startSyncTs: new Date().getTime(),
-    })
-  ).includes(toplevelNode.internalId);
-}
-
 /**
  * Order items as follows:
  * - first those whose parentInternalId is not in the changedList, or the root drive
@@ -609,29 +616,29 @@
  * The function makes the assumption that there is no circular parent
  * relationship
  */
-function sortForIncrementalUpdate(changedList: DriveItem[]) {
+function sortForIncrementalUpdate(changedList: DriveItem[], rootId: string) {
   if (changedList.length === 0) {
     return [];
   }

   const internalIds = changedList.map((item) => getDriveItemInternalId(item));

-  const sortedDriveItemList = changedList.filter((item) => {
-    if (!item.parentReference) {
+  const sortedItemList = changedList.filter((item) => {
+    if (item.id === rootId) {
       return true;
     }

-    if (item.root) {
-      return true;
+    if (!item.parentReference) {
+      return false;
     }

     const parentInternalId = getParentReferenceInternalId(item.parentReference);

     return !internalIds.includes(parentInternalId);
   });

-  while (sortedDriveItemList.length < changedList.length) {
+  for (;;) {
     const nextLevel = changedList.filter((item) => {
-      if (sortedDriveItemList.includes(item)) {
+      if (sortedItemList.includes(item)) {
         return false;
       }
@@ -643,15 +650,18 @@
       const parentInternalId = getParentReferenceInternalId(
         item.parentReference
       );
-      return sortedDriveItemList.some(
+
+      return sortedItemList.some(
         (sortedItem) => getDriveItemInternalId(sortedItem) === parentInternalId
       );
     });

-    sortedDriveItemList.push(...nextLevel);
-  }
+    if (nextLevel.length === 0) {
+      return sortedItemList;
+    }

-  return sortedDriveItemList;
+    sortedItemList.push(...nextLevel);
+  }
 }

 async function getDeltaData({
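To see what the rewritten sort guarantees, it helps to walk a toy changed list; the tree and ids below are invented for illustration:

    // Tree: root -> A -> B, plus C whose parent did not appear in the delta.
    // changedList = [B, A, C, root]
    //
    // Initial filter: root (item.id === rootId) and C (its parent is not in
    //   the changed list) are emitted first.
    // Pass 1: A is added, because its parent (root) is already sorted.
    // Pass 2: B is added, because its parent (A) is already sorted.
    // Pass 3: nextLevel is empty, so the loop returns [root, C, A, B] —
    //   every parent precedes its children.
    //
    // Unlike the old `while (sorted.length < changedList.length)` condition,
    // the `for (;;)` loop with an empty-pass exit also terminates when some
    // items are unreachable from this root (those deletions are re-appended
    // separately, as shown earlier in the patch).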
diff --git a/connectors/src/connectors/microsoft/temporal/file.ts b/connectors/src/connectors/microsoft/temporal/file.ts
index 6b9230ffc57d..e7852f48007f 100644
--- a/connectors/src/connectors/microsoft/temporal/file.ts
+++ b/connectors/src/connectors/microsoft/temporal/file.ts
@@ -21,8 +21,8 @@ import { getClient } from "@connectors/connectors/microsoft";
 import {
   getDriveItemInternalId,
   getFileDownloadURL,
-  typeAndPathFromInternalId,
 } from "@connectors/connectors/microsoft/lib/graph_api";
+import { typeAndPathFromInternalId } from "@connectors/connectors/microsoft/lib/utils";
 import { getMimeTypesToSync } from "@connectors/connectors/microsoft/temporal/mime_types";
 import {
   deleteAllSheets,
@@ -363,15 +363,17 @@ export async function getParents({
     startSyncTs
   );

-  return [
-    internalId,
-    ...(await getParents({
-      connectorId,
-      internalId: parentInternalId,
-      parentInternalId: parentParentInternalId,
-      startSyncTs,
-    })),
-  ];
+  return parentParentInternalId
+    ? [
+        internalId,
+        ...(await getParents({
+          connectorId,
+          internalId: parentInternalId,
+          parentInternalId: parentParentInternalId,
+          startSyncTs,
+        })),
+      ]
+    : [internalId];
 }

 /* Fetching parent's parent id queries the db for a resource; since those
@@ -385,7 +387,7 @@ const getParentParentId = cacheWithRedis(
       parentInternalId
     );
     if (!parent) {
-      throw new Error(`Parent node not found: ${parentInternalId}`);
+      return "";
     }

     return parent.parentInternalId;
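Taken together, the two getParents changes mean a parent resource missing from the database no longer fails the whole activity, assuming the code above the hunk already handles the base of the recursion; the chain is simply cut short. A rough sketch with shortened ids:

    // Hypothetical chain file -> folder -> drive:
    // - all parent rows present: getParents(...) walks up as before and
    //   returns the full chain of internal ids.
    // - a parent row missing: getParentParentId now returns "" (falsy)
    //   instead of throwing `Parent node not found`, so getParents stops
    //   recursing at that point and returns the chain built so far, e.g.
    //   ["file-internal-id"] for that step.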
diff --git a/connectors/src/connectors/microsoft/temporal/workflows.ts b/connectors/src/connectors/microsoft/temporal/workflows.ts
index c0bd9931c50a..c37b988d3a56 100644
--- a/connectors/src/connectors/microsoft/temporal/workflows.ts
+++ b/connectors/src/connectors/microsoft/temporal/workflows.ts
@@ -9,18 +9,21 @@ import {
 import type * as activities from "@connectors/connectors/microsoft/temporal/activities";
 import type * as sync_status from "@connectors/lib/sync_status";

-const { getSiteNodesToSync, syncFiles, markNodeAsSeen, populateDeltas } =
-  proxyActivities<typeof activities>({
-    startToCloseTimeout: "30 minutes",
-  });
+const {
+  getRootNodesToSync,
+  syncFiles,
+  markNodeAsSeen,
+  populateDeltas,
+  groupRootItemsByDriveId,
+} = proxyActivities<typeof activities>({
+  startToCloseTimeout: "30 minutes",
+});

 const { microsoftDeletionActivity } = proxyActivities<typeof activities>({
   startToCloseTimeout: "15 minutes",
 });

-const { syncDeltaForRootNode: syncDeltaForNode } = proxyActivities<
-  typeof activities
->({
+const { syncDeltaForRootNodesInDrive } = proxyActivities<typeof activities>({
   startToCloseTimeout: "120 minutes",
   heartbeatTimeout: "5 minutes",
 });
@@ -43,7 +46,7 @@ export async function fullSyncWorkflow({
   totalCount?: number;
 }) {
   if (nodeIdsToSync === undefined) {
-    nodeIdsToSync = await getSiteNodesToSync(connectorId);
+    nodeIdsToSync = await getRootNodesToSync(connectorId);
   }

   if (startSyncTs === undefined) {
@@ -98,12 +101,16 @@ export async function incrementalSyncWorkflow({
 }: {
   connectorId: ModelId;
 }) {
-  const nodeIdsToSync = await getSiteNodesToSync(connectorId);
+  const nodeIdsToSync = await getRootNodesToSync(connectorId);
+
+  const groupedItems = await groupRootItemsByDriveId(nodeIdsToSync);
+
   const startSyncTs = new Date().getTime();

-  for (const nodeId of nodeIdsToSync) {
-    await syncDeltaForNode({
+  for (const nodeId of Object.keys(groupedItems)) {
+    await syncDeltaForRootNodesInDrive({
       connectorId,
-      rootNodeId: nodeId,
+      driveId: nodeId,
+      rootNodeIds: groupedItems[nodeId] as string[],
       startSyncTs,
     });
   }

From e0992ce09e46ceefe33a8e7c8c2c7901b0c282a2 Mon Sep 17 00:00:00 2001
From: Henry Fontanier
Date: Thu, 25 Jul 2024 11:13:13 +0200
Subject: [PATCH 13/14] fix: actually yield the viz success event (#6502)

Co-authored-by: Henry Fontanier
---
 front/lib/api/assistant/actions/visualization.ts | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/front/lib/api/assistant/actions/visualization.ts b/front/lib/api/assistant/actions/visualization.ts
index a3c7f30714ac..1ac26cec0aa0 100644
--- a/front/lib/api/assistant/actions/visualization.ts
+++ b/front/lib/api/assistant/actions/visualization.ts
@@ -370,6 +370,21 @@ export class VisualizationConfigurationServerRunner extends BaseActionConfigurat
     }
   }

+  yield {
+    type: "visualization_success",
+    created: Date.now(),
+    configurationId: agentConfiguration.sId,
+    messageId: agentMessage.sId,
+    action: new VisualizationAction({
+      id: action.id,
+      agentMessageId: action.agentMessageId,
+      generation,
+      functionCallId: action.functionCallId,
+      functionCallName: action.functionCallName,
+      step: action.step,
+    }),
+  };
+
   logger.info(
     {
       workspaceId: conversation.owner.sId,

From 83edeb40481478a21bfeb0b26ebf04e441ceb203 Mon Sep 17 00:00:00 2001
From: Henry Fontanier
Date: Thu, 25 Jul 2024 11:14:52 +0200
Subject: [PATCH 14/14] enh(viz): use claude 3.5 sonnet unless forbidden (#6485)

Co-authored-by: Henry Fontanier
---
 front/lib/api/assistant/actions/visualization.ts | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/front/lib/api/assistant/actions/visualization.ts b/front/lib/api/assistant/actions/visualization.ts
index 1ac26cec0aa0..30e313652ea0 100644
--- a/front/lib/api/assistant/actions/visualization.ts
+++ b/front/lib/api/assistant/actions/visualization.ts
@@ -14,8 +14,10 @@ import type {
 } from "@dust-tt/types";
 import {
   BaseAction,
+  CLAUDE_3_5_SONNET_DEFAULT_MODEL_CONFIG,
   cloneBaseConfig,
   DustProdActionRegistry,
+  isProviderWhitelisted,
   Ok,
   VisualizationActionOutputSchema,
 } from "@dust-tt/types";
@@ -245,10 +247,19 @@ export class VisualizationConfigurationServerRunner extends BaseActionConfigurat
     const config = cloneBaseConfig(
       DustProdActionRegistry["assistant-v2-visualization"].config
     );
-    const model = agentConfiguration.model;
+
+    // If we can use Sonnet 3.5, we use it.
+    // Otherwise, we use the model from the agent configuration.
+    const model =
+      auth.isUpgraded() && isProviderWhitelisted(owner, "anthropic")
+        ? CLAUDE_3_5_SONNET_DEFAULT_MODEL_CONFIG
+        : agentConfiguration.model;
+
     config.MODEL.provider_id = model.providerId;
     config.MODEL.model_id = model.modelId;
-    config.MODEL.temperature = model.temperature;
+
+    // Preserve the temperature from the agent configuration.
+    config.MODEL.temperature = agentConfiguration.model.temperature;

     // Execute the Visualization Dust App.
     const visualizationRes = await runActionStreamed(