Skip to content

Commit

Permalink
added vector indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
CommanderStorm committed Sep 1, 2024
1 parent e61d822 commit efab81f
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 20 deletions.
3 changes: 2 additions & 1 deletion server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ sqlx = { version = "0.8.0", features = ['chrono', 'json', 'macros', 'migrate', '
chrono = { version = "0.4.38", default-features = false, features = ["serde"] }

# search
meilisearch-sdk = "0.27.1"
# Temporary fork: upstream has not yet merged my vector-search embedder PR (open for 4 months now).
# Once it is merged, switch back to the upstream `meilisearch-sdk` release.
# The key must stay `meilisearch-sdk` so the crate is still imported as `meilisearch_sdk` in the code.
meilisearch-sdk = { git = "https://github.com/commanderstorm/meilisearch-rust.git", branch = "vector-search-embedder" }
logos = "0.14.1"
regex = "1.10.6"

Expand Down
4 changes: 2 additions & 2 deletions server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ async fn run_maintenance_work(
let ms_url =
std::env::var("MIELI_URL").unwrap_or_else(|_| "http://localhost:7700".to_string());
let client = Client::new(ms_url, std::env::var("MEILI_MASTER_KEY").ok()).unwrap();
setup::meilisearch::setup(&client).await.unwrap();
setup::meilisearch::load_data(&client).await.unwrap();
setup::meilisearch::setup(&client, true).await.unwrap();
setup::meilisearch::load_data(&client, true).await.unwrap();
} else {
info!("skipping the database setup as SKIP_MS_SETUP=true");
initalisation_started.wait().await;
Expand Down
8 changes: 4 additions & 4 deletions server/src/search/search_executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ mod test {
#[tokio::test]
#[tracing_test::traced_test]
async fn test_good_queries() {
let ms = MeiliSearchTestContainer::new().await;
crate::setup::meilisearch::load_data(&ms.client)
let ms = MeiliSearchTestContainer::new(true).await;
crate::setup::meilisearch::load_data(&ms.client, true)
.await
.unwrap();
for query in TestQuery::load_good() {
Expand All @@ -155,8 +155,8 @@ mod test {
#[tokio::test]
#[tracing_test::traced_test]
async fn test_bad_queries() {
let ms = MeiliSearchTestContainer::new().await;
crate::setup::meilisearch::load_data(&ms.client)
let ms = MeiliSearchTestContainer::new(true).await;
crate::setup::meilisearch::load_data(&ms.client, true)
.await
.unwrap();
for query in TestQuery::load_bad() {
Expand Down
2 changes: 2 additions & 0 deletions server/src/search/search_executor/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ impl GeoEntryQuery {
.with_highlight_pre_tag(&self.highlighting.pre)
.with_highlight_post_tag(&self.highlighting.post)
.with_attributes_to_highlight(Selectors::Some(&["name"]))
.with_hybrid("default", 0.1)
.build()
}

Expand Down Expand Up @@ -211,6 +212,7 @@ impl GeoEntryQuery {
.with_query(query)
.with_limit(2 * self.limits.buildings_count) // we might do reordering later
.with_filter(&self.filters.buildings)
.with_hybrid("default", 0.1)
.build()
}

Expand Down
61 changes: 52 additions & 9 deletions server/src/setup/meilisearch.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::collections::HashMap;
use std::time::Duration;

use meilisearch_sdk::client::Client;
use meilisearch_sdk::settings::Settings;
use meilisearch_sdk::settings::{Embedder, HuggingFaceEmbedderSettings, Settings};
use meilisearch_sdk::tasks::Task;
use serde_json::Value;
use std::collections::HashMap;
use std::time::Duration;
use tracing::{debug, error, info};

/// Default upper bound (20 seconds) handed to `wait_for_completion` for Meilisearch tasks.
const TIMEOUT: Option<Duration> = Some(Duration::from_secs(20));
/// Extended upper bound (10 minutes) used instead of [`TIMEOUT`] when `vector_search` is enabled.
// NOTE(review): presumably needed because generating embeddings makes these tasks slow - confirm.
const TIMEOUT_SETUP: Option<Duration> = Some(Duration::from_secs(10 * 60));
/// Interval (250 ms) handed to `wait_for_completion` alongside the timeout.
const POLLING_RATE: Option<Duration> = Some(Duration::from_millis(250));

#[derive(serde::Deserialize)]
Expand All @@ -18,6 +18,7 @@ impl Synonyms {
serde_yaml::from_str(include_str!("search_synonyms.yaml"))
}
}

#[tracing::instrument(skip(client))]
async fn wait_for_healthy(client: &Client) {
let mut counter = 0;
Expand All @@ -43,20 +44,40 @@ async fn wait_for_healthy(client: &Client) {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}

#[tracing::instrument(skip(client))]
pub async fn setup(client: &Client) -> anyhow::Result<()> {
pub async fn setup(client: &Client, vector_search: bool) -> anyhow::Result<()> {
debug!("waiting for Meilisearch to be healthy");
wait_for_healthy(client).await;
info!("Meilisearch is healthy");
meilisearch_sdk::features::ExperimentalFeatures::new(client)
.set_vector_store(true)
.update()
.await?;

meilisearch_sdk::features::ExperimentalFeatures::new(client)
.set_vector_store(true)
.update()
.await?;

client
.create_index("entries", Some("ms_id"))
.await?
.wait_for_completion(client, POLLING_RATE, TIMEOUT)
.await?;
let entries = client.index("entries");
let en_embedder =Embedder::HuggingFace(HuggingFaceEmbedderSettings{
model: Some("BAAI/bge-base-en-v1.5".to_string()),
document_template: Some("A room titled '{{doc.name}}' with type '{{doc.type_common_name}}' used as '{{doc.usage}}'".to_string()),
..Default::default()
});
let _de_embedder=Embedder::HuggingFace(HuggingFaceEmbedderSettings{
model: Some("google-bert/bert-base-german-cased".to_string()),
document_template: Some("Ein Raum '{{doc.name}}' vom typ '{{doc.type_common_name}}' benutzt als '{{doc.usage}}'".to_string()),
..Default::default()
});

let settings = Settings::new()
let mut settings = Settings::new()
.with_filterable_attributes([
"facet",
"parent_keywords",
Expand Down Expand Up @@ -89,18 +110,32 @@ pub async fn setup(client: &Client) -> anyhow::Result<()> {
])
.with_synonyms(Synonyms::try_load()?.0);

if vector_search {
settings = settings.with_embedders(HashMap::from([("default", en_embedder)]))
}

let res = entries
.set_settings(&settings)
.await?
.wait_for_completion(client, POLLING_RATE, TIMEOUT)
.wait_for_completion(
client,
POLLING_RATE,
if vector_search {
TIMEOUT_SETUP
} else {
TIMEOUT
},
)
.await?;
if let Task::Failed { content } = res {
panic!("Failed to add settings to Meilisearch: {content:?}");
}

Ok(())
}

#[tracing::instrument(skip(client))]
pub async fn load_data(client: &Client) -> anyhow::Result<()> {
pub async fn load_data(client: &Client, vector_search: bool) -> anyhow::Result<()> {
let entries = client.index("entries");
let cdn_url = std::env::var("CDN_URL").unwrap_or_else(|_| "https://nav.tum.de/cdn".to_string());
let documents = reqwest::get(format!("{cdn_url}/search_data.json"))
Expand All @@ -111,7 +146,15 @@ pub async fn load_data(client: &Client) -> anyhow::Result<()> {
let res = entries
.add_documents(&documents, Some("ms_id"))
.await?
.wait_for_completion(client, POLLING_RATE, TIMEOUT)
.wait_for_completion(
client,
POLLING_RATE,
if vector_search {
TIMEOUT_SETUP
} else {
TIMEOUT
},
)
.await?;
if let Task::Failed { content } = res {
panic!("Failed to add documents to Meilisearch: {content:?}");
Expand Down
10 changes: 6 additions & 4 deletions server/src/setup/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub struct MeiliSearchTestContainer {

impl MeiliSearchTestContainer {
/// Create a meilisearch instance for testing against
pub async fn new() -> Self {
pub async fn new(vector_search: bool) -> Self {
let container = meilisearch::Meilisearch::default()
.with_tag("v1.9.0")
.start()
Expand All @@ -53,7 +53,9 @@ impl MeiliSearchTestContainer {
);

let client = Client::new(meili_url.clone(), None::<String>).unwrap();
super::meilisearch::setup(&client).await.unwrap();
super::meilisearch::setup(&client, vector_search)
.await
.unwrap();
Self {
_container: container,
client,
Expand All @@ -80,8 +82,8 @@ async fn test_db_setup() {
#[tokio::test]
#[tracing_test::traced_test]
async fn test_meilisearch_setup() {
let ms = MeiliSearchTestContainer::new().await;
crate::setup::meilisearch::load_data(&ms.client)
let ms = MeiliSearchTestContainer::new(false).await;
crate::setup::meilisearch::load_data(&ms.client, false)
.await
.unwrap();
}

0 comments on commit efab81f

Please sign in to comment.