From efab81f6eabad2f9a091e97cbca0d6b89efd74e2 Mon Sep 17 00:00:00 2001 From: Frank Elsinga Date: Mon, 26 Feb 2024 22:28:34 +0100 Subject: [PATCH] added vector indexing --- server/Cargo.toml | 3 +- server/src/main.rs | 4 +- server/src/search/search_executor/mod.rs | 8 +-- server/src/search/search_executor/query.rs | 2 + server/src/setup/meilisearch.rs | 61 ++++++++++++++++++---- server/src/setup/tests.rs | 10 ++-- 6 files changed, 68 insertions(+), 20 deletions(-) diff --git a/server/Cargo.toml b/server/Cargo.toml index 5de5c37f8..bda9b7f60 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -47,7 +47,8 @@ sqlx = { version = "0.8.0", features = ['chrono', 'json', 'macros', 'migrate', ' chrono = { version = "0.4.38", default-features = false, features = ["serde"] } # search -meilisearch-sdk = "0.27.1" +# Temporary fork: once upstream meilisearch merges my PR (open for 4 months now), change back to the main release +meilisearch = { git = "https://github.com/commanderstorm/meilisearch-rust.git", branch = "vector-search-embedder" } logos = "0.14.1" regex = "1.10.6" diff --git a/server/src/main.rs b/server/src/main.rs index 9457b8d4e..70517a920 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -148,8 +148,8 @@ async fn run_maintenance_work( let ms_url = std::env::var("MIELI_URL").unwrap_or_else(|_| "http://localhost:7700".to_string()); let client = Client::new(ms_url, std::env::var("MEILI_MASTER_KEY").ok()).unwrap(); - setup::meilisearch::setup(&client).await.unwrap(); - setup::meilisearch::load_data(&client).await.unwrap(); + setup::meilisearch::setup(&client, true).await.unwrap(); + setup::meilisearch::load_data(&client, true).await.unwrap(); } else { info!("skipping the database setup as SKIP_MS_SETUP=true"); initalisation_started.wait().await; diff --git a/server/src/search/search_executor/mod.rs b/server/src/search/search_executor/mod.rs index c6b5689e6..6fdcb1a9f 100644 --- a/server/src/search/search_executor/mod.rs +++ b/server/src/search/search_executor/mod.rs @@ -130,8 
+130,8 @@ mod test { #[tokio::test] #[tracing_test::traced_test] async fn test_good_queries() { - let ms = MeiliSearchTestContainer::new().await; - crate::setup::meilisearch::load_data(&ms.client) + let ms = MeiliSearchTestContainer::new(true).await; + crate::setup::meilisearch::load_data(&ms.client, true) .await .unwrap(); for query in TestQuery::load_good() { @@ -155,8 +155,8 @@ mod test { #[tokio::test] #[tracing_test::traced_test] async fn test_bad_queries() { - let ms = MeiliSearchTestContainer::new().await; - crate::setup::meilisearch::load_data(&ms.client) + let ms = MeiliSearchTestContainer::new(true).await; + crate::setup::meilisearch::load_data(&ms.client, true) .await .unwrap(); for query in TestQuery::load_bad() { diff --git a/server/src/search/search_executor/query.rs b/server/src/search/search_executor/query.rs index e205eb134..552ef27b8 100644 --- a/server/src/search/search_executor/query.rs +++ b/server/src/search/search_executor/query.rs @@ -183,6 +183,7 @@ impl GeoEntryQuery { .with_highlight_pre_tag(&self.highlighting.pre) .with_highlight_post_tag(&self.highlighting.post) .with_attributes_to_highlight(Selectors::Some(&["name"])) + .with_hybrid("default", 0.1) .build() } @@ -211,6 +212,7 @@ impl GeoEntryQuery { .with_query(query) .with_limit(2 * self.limits.buildings_count) // we might do reordering later .with_filter(&self.filters.buildings) + .with_hybrid("default", 0.1) .build() } diff --git a/server/src/setup/meilisearch.rs b/server/src/setup/meilisearch.rs index 0c2151f6e..f4a0f7a09 100644 --- a/server/src/setup/meilisearch.rs +++ b/server/src/setup/meilisearch.rs @@ -1,13 +1,13 @@ -use std::collections::HashMap; -use std::time::Duration; - use meilisearch_sdk::client::Client; -use meilisearch_sdk::settings::Settings; +use meilisearch_sdk::settings::{Embedder, HuggingFaceEmbedderSettings, Settings}; use meilisearch_sdk::tasks::Task; use serde_json::Value; +use std::collections::HashMap; +use std::time::Duration; use tracing::{debug, error, 
info}; const TIMEOUT: Option = Some(Duration::from_secs(20)); +const TIMEOUT_SETUP: Option = Some(Duration::from_secs(10 * 60)); const POLLING_RATE: Option = Some(Duration::from_millis(250)); #[derive(serde::Deserialize)] @@ -18,6 +18,7 @@ impl Synonyms { serde_yaml::from_str(include_str!("search_synonyms.yaml")) } } + #[tracing::instrument(skip(client))] async fn wait_for_healthy(client: &Client) { let mut counter = 0; @@ -43,11 +44,21 @@ async fn wait_for_healthy(client: &Client) { tokio::time::sleep(Duration::from_secs(1)).await; } } + #[tracing::instrument(skip(client))] -pub async fn setup(client: &Client) -> anyhow::Result<()> { +pub async fn setup(client: &Client, vector_search: bool) -> anyhow::Result<()> { debug!("waiting for Meilisearch to be healthy"); wait_for_healthy(client).await; info!("Meilisearch is healthy"); + meilisearch_sdk::features::ExperimentalFeatures::new(client) + .set_vector_store(true) + .update() + .await?; + + meilisearch_sdk::features::ExperimentalFeatures::new(client) + .set_vector_store(true) + .update() + .await?; client .create_index("entries", Some("ms_id")) @@ -55,8 +66,18 @@ pub async fn setup(client: &Client) -> anyhow::Result<()> { .wait_for_completion(client, POLLING_RATE, TIMEOUT) .await?; let entries = client.index("entries"); + let en_embedder =Embedder::HuggingFace(HuggingFaceEmbedderSettings{ + model: Some("BAAI/bge-base-en-v1.5".to_string()), + document_template: Some("A room titled '{{doc.name}}' with type '{{doc.type_common_name}}' used as '{{doc.usage}}'".to_string()), + ..Default::default() + }); + let _de_embedder=Embedder::HuggingFace(HuggingFaceEmbedderSettings{ + model: Some("google-bert/bert-base-german-cased".to_string()), + document_template: Some("Ein Raum '{{doc.name}}' vom typ '{{doc.type_common_name}}' benutzt als '{{doc.usage}}'".to_string()), + ..Default::default() + }); - let settings = Settings::new() + let mut settings = Settings::new() .with_filterable_attributes([ "facet", "parent_keywords", @@ 
-89,18 +110,32 @@ pub async fn setup(client: &Client) -> anyhow::Result<()> { ]) .with_synonyms(Synonyms::try_load()?.0); + if vector_search { + settings = settings.with_embedders(HashMap::from([("default", en_embedder)])) + } + let res = entries .set_settings(&settings) .await? - .wait_for_completion(client, POLLING_RATE, TIMEOUT) + .wait_for_completion( + client, + POLLING_RATE, + if vector_search { + TIMEOUT_SETUP + } else { + TIMEOUT + }, + ) .await?; if let Task::Failed { content } = res { panic!("Failed to add settings to Meilisearch: {content:?}"); } + Ok(()) } + #[tracing::instrument(skip(client))] -pub async fn load_data(client: &Client) -> anyhow::Result<()> { +pub async fn load_data(client: &Client, vector_search: bool) -> anyhow::Result<()> { let entries = client.index("entries"); let cdn_url = std::env::var("CDN_URL").unwrap_or_else(|_| "https://nav.tum.de/cdn".to_string()); let documents = reqwest::get(format!("{cdn_url}/search_data.json")) @@ -111,7 +146,15 @@ pub async fn load_data(client: &Client) -> anyhow::Result<()> { let res = entries .add_documents(&documents, Some("ms_id")) .await? 
- .wait_for_completion(client, POLLING_RATE, TIMEOUT) + .wait_for_completion( + client, + POLLING_RATE, + if vector_search { + TIMEOUT_SETUP + } else { + TIMEOUT + }, + ) .await?; if let Task::Failed { content } = res { panic!("Failed to add documents to Meilisearch: {content:?}"); diff --git a/server/src/setup/tests.rs b/server/src/setup/tests.rs index 1decd930e..fe559b155 100644 --- a/server/src/setup/tests.rs +++ b/server/src/setup/tests.rs @@ -40,7 +40,7 @@ pub struct MeiliSearchTestContainer { impl MeiliSearchTestContainer { /// Create a meilisearch instance for testing against - pub async fn new() -> Self { + pub async fn new(vector_search: bool) -> Self { let container = meilisearch::Meilisearch::default() .with_tag("v1.9.0") .start() @@ -53,7 +53,9 @@ impl MeiliSearchTestContainer { ); let client = Client::new(meili_url.clone(), None::).unwrap(); - super::meilisearch::setup(&client).await.unwrap(); + super::meilisearch::setup(&client, vector_search) + .await + .unwrap(); Self { _container: container, client, @@ -80,8 +82,8 @@ async fn test_db_setup() { #[tokio::test] #[tracing_test::traced_test] async fn test_meilisearch_setup() { - let ms = MeiliSearchTestContainer::new().await; - crate::setup::meilisearch::load_data(&ms.client) + let ms = MeiliSearchTestContainer::new(false).await; + crate::setup::meilisearch::load_data(&ms.client, false) .await .unwrap(); }