Skip to content

Commit

Permalink
chore: formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CommanderStorm committed Jul 8, 2024
1 parent 980e714 commit cbac495
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
29 changes: 20 additions & 9 deletions src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ pub struct SearchQuery<'a, Http: HttpClient> {
///
/// **Default: `false`**
#[serde(skip_serializing_if = "Option::is_none")]
retrieve_vectors: Option<bool>
retrieve_vectors: Option<bool>,
}

#[allow(missing_docs)]
Expand Down Expand Up @@ -734,18 +734,24 @@ mod tests {
_vectors: Option<Vectors>,
}


#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Vector {
embeddings: Vec<Vec<f32>>,
regenerate: bool,
}

#[derive(Debug, Serialize, Deserialize, PartialEq)]
struct Vectors(HashMap<String, Vector>);

impl From<&[f32; 1]> for Vectors {
fn from(value: &[f32;1]) -> Self {
Vectors(HashMap::from([(S("default"), Vector { embeddings: Vec::from([value.to_vec()]), regenerate:false })]))
fn from(value: &[f32; 1]) -> Self {
Vectors(HashMap::from([(
S("default"),
Vector {
embeddings: Vec::from([value.to_vec()]),
regenerate: false,
},
)]))
}
}

Expand Down Expand Up @@ -1376,7 +1382,8 @@ mod tests {
.await
.expect("could not enable the vector store");
assert_eq!(features.vector_store, true);
let embedder_setting = Embedder::UserProvided(UserProvidedEmbedderSettings { dimensions: 1 });
let embedder_setting =
Embedder::UserProvided(UserProvidedEmbedderSettings { dimensions: 1 });
let t3 = index
.set_settings(&crate::settings::Settings {
embedders: Some(HashMap::from([("default".to_string(), embedder_setting)])),
Expand All @@ -1392,18 +1399,22 @@ mod tests {
setup_hybrid_searching(&client, &index).await?;
setup_test_index(&client, &index).await?;

let results: SearchResults<Document> = index.search()
let results: SearchResults<Document> = index
.search()
.with_query("lorem ipsum")
.with_retrieve_vectors(true)
.execute().await?;
.execute()
.await?;
assert_eq!(results.hits.len(), 1);
let expected = Vectors::from(&[1000.0]);
assert_eq!(results.hits[0].result._vectors, Some(expected));

let results: SearchResults<Document> = index.search()
let results: SearchResults<Document> = index
.search()
.with_query("lorem ipsum")
.with_retrieve_vectors(false)
.execute().await?;
.execute()
.await?;
assert_eq!(results.hits.len(), 1);
assert_eq!(results.hits[0].result._vectors, None);
Ok(())
Expand Down
14 changes: 7 additions & 7 deletions src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub enum Embedder {
/// ..Default::default()
/// };
/// # let expected = r#"{"model":"BAAI/bge-base-en-v1.5","documentTemplate":"A document titled {{doc.title}} whose description starts with {{doc.overview|truncatewords: 20}}"}"#;
/// # let expected: HuggingFaceEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # let expected: HuggingFaceEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # assert_eq!(embedder_setting, expected);
/// ```
#[derive(Serialize, Deserialize, Default, Debug, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -109,7 +109,7 @@ pub struct HuggingFaceEmbedderSettings {
/// ..Default::default()
/// };
/// # let expected = r#"{"apiKey":"anOpenAiApiKey","model":"text-embedding-3-small","documentTemplate":"A document titled {{doc.title}} whose description starts with {{doc.overview|truncatewords: 20}}","dimensions": 1536"}"#;
/// # let expected: OpenapiEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # let expected: OpenapiEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # assert_eq!(embedder_setting, expected);
/// ```
#[derive(Serialize, Deserialize, Default, Debug, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -150,7 +150,7 @@ pub struct OpenapiEmbedderSettings {
/// document_template: Some("A document titled {{doc.title}} whose description starts with {{doc.overview|truncatewords: 20}}".to_string()),
/// };
/// # let expected = r#"{"url":"http://localhost:11434/api/embeddings","apiKey":"foobarbaz","model":"nomic-embed-text","documentTemplate":"A document titled {{doc.title}} whose description starts with {{doc.overview|truncatewords: 20}}"}"#;
/// # let expected: OllamaEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # let expected: OllamaEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # assert_eq!(embedder_setting, expected);
/// ```
#[derive(Serialize, Deserialize, Default, Debug, Clone, Eq, PartialEq)]
Expand All @@ -170,10 +170,10 @@ pub struct OllamaEmbedderSettings {
///
/// # Example embedding models
///
/// | Model | Parameter | Size |
/// | Model | Parameter | Size |
/// |--------------------------|--------------|-----------------------------------------------------------------|
/// | `mxbai-embed-large` | `334M` | [View model](https://ollama.com/library/mxbai-embed-large) |
/// | `nomic-embed-text` | `137M` | [View model](https://ollama.com/library/nomic-embed-text) |
/// | `mxbai-embed-large` | `334M` | [View model](https://ollama.com/library/mxbai-embed-large) |
/// | `nomic-embed-text` | `137M` | [View model](https://ollama.com/library/nomic-embed-text) |
/// | `all-minilm` | `23M`,`33M` | [View model](https://ollama.com/library/all-minilm) |
/// | `snowflake-arctic-embed` | varies | [View model](https://ollama.com/library/snowflake-arctic-embed) |
pub model: String,
Expand Down Expand Up @@ -208,7 +208,7 @@ pub struct OllamaEmbedderSettings {
/// embedding_object: vec!["embedding".to_string()],
/// };
/// # let expected = r#"{"url":"http://localhost:12345/api/v1/embed","apiKey":"SOURCE_API_KEY","dimensions":512,"documentTemplate":"A document titled {{doc.title}} whose description starts with {{doc.overview|truncatewords: 20}}","inputField":["data","text"],"inputType":"text","query":{"dimensions":512,"model":"MODEL_NAME"},"pathToEmbeddings":["data"],"embeddingObject":["embedding"]}"#;
/// # let expected: GenericRestEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # let expected: GenericRestEmbedderSettings = serde_json::from_str(expected).unwrap();
/// # assert_eq!(embedder_setting, expected);
/// ```
#[derive(Serialize, Deserialize, Default, Debug, Clone, Eq, PartialEq)]
Expand Down

0 comments on commit cbac495

Please sign in to comment.