-
Notifications
You must be signed in to change notification settings - Fork 571
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move agent tool implementations to submodules
- Loading branch information
1 parent
d501b55
commit 4abb676
Showing
9 changed files
with
975 additions
and
897 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use anyhow::Result; | ||
use futures::TryStreamExt; | ||
use tracing::info; | ||
|
||
use crate::{ | ||
agent::{ | ||
exchange::{CodeChunk, SearchStep, Update}, | ||
prompts, Agent, | ||
}, | ||
analytics::EventData, | ||
llm_gateway, | ||
}; | ||
|
||
impl Agent { | ||
pub async fn code_search(&mut self, query: &String) -> Result<String> { | ||
const CODE_SEARCH_LIMIT: u64 = 10; | ||
self.update(Update::StartStep(SearchStep::Code { | ||
query: query.clone(), | ||
response: String::new(), | ||
})) | ||
.await?; | ||
|
||
let mut results = self | ||
.semantic_search(query.into(), CODE_SEARCH_LIMIT, 0, 0.0, true) | ||
.await?; | ||
|
||
let hyde_docs = self.hyde(query).await?; | ||
if !hyde_docs.is_empty() { | ||
let hyde_doc = hyde_docs.first().unwrap().into(); | ||
let hyde_results = self | ||
.semantic_search(hyde_doc, CODE_SEARCH_LIMIT, 0, 0.3, true) | ||
.await?; | ||
results.extend(hyde_results); | ||
} | ||
|
||
let chunks = results | ||
.into_iter() | ||
.map(|chunk| { | ||
let relative_path = chunk.relative_path; | ||
|
||
CodeChunk { | ||
path: relative_path.clone(), | ||
alias: self.get_path_alias(&relative_path), | ||
snippet: chunk.text, | ||
start_line: (chunk.start_line as usize).saturating_add(1), | ||
end_line: (chunk.end_line as usize).saturating_add(1), | ||
} | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
for chunk in chunks.iter().filter(|c| !c.is_empty()) { | ||
self.exchanges | ||
.last_mut() | ||
.unwrap() | ||
.code_chunks | ||
.push(chunk.clone()) | ||
} | ||
|
||
let response = serde_json::to_string(&chunks).unwrap(); | ||
|
||
self.update(Update::ReplaceStep(SearchStep::Code { | ||
query: query.clone(), | ||
response: response.clone(), | ||
})) | ||
.await?; | ||
|
||
self.track_query( | ||
EventData::input_stage("semantic code search") | ||
.with_payload("query", query) | ||
.with_payload("hyde_queries", &hyde_docs) | ||
.with_payload("chunks", &chunks) | ||
.with_payload("raw_prompt", &response), | ||
); | ||
|
||
Ok(response) | ||
} | ||
|
||
/// Hypothetical Document Embedding (HyDE): https://arxiv.org/abs/2212.10496 | ||
/// | ||
/// This method generates synthetic documents based on the query. These are then | ||
/// parsed and code is extracted. This has been shown to improve semantic search recall. | ||
async fn hyde(&self, query: &str) -> Result<Vec<String>> { | ||
let prompt = vec![llm_gateway::api::Message::system( | ||
&prompts::hypothetical_document_prompt(query), | ||
)]; | ||
|
||
tracing::trace!(?query, "generating hyde docs"); | ||
|
||
let response = self | ||
.llm_gateway | ||
.clone() | ||
.model("gpt-3.5-turbo-0613") | ||
.chat(&prompt, None) | ||
.await? | ||
.try_collect::<String>() | ||
.await?; | ||
|
||
tracing::trace!("parsing hyde response"); | ||
|
||
let documents = prompts::try_parse_hypothetical_documents(&response); | ||
|
||
for doc in documents.iter() { | ||
info!(?doc, "got hyde doc"); | ||
} | ||
|
||
Ok(documents) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use std::collections::HashSet; | ||
|
||
use anyhow::Result; | ||
|
||
use crate::{ | ||
agent::{ | ||
exchange::{SearchStep, Update}, | ||
Agent, | ||
}, | ||
analytics::EventData, | ||
}; | ||
|
||
impl Agent { | ||
pub async fn path_search(&mut self, query: &String) -> Result<String> { | ||
self.update(Update::StartStep(SearchStep::Path { | ||
query: query.clone(), | ||
response: String::new(), | ||
})) | ||
.await?; | ||
|
||
// First, perform a lexical search for the path | ||
let mut paths = self | ||
.fuzzy_path_search(query) | ||
.await | ||
.map(|c| c.relative_path) | ||
.collect::<HashSet<_>>() // TODO: This shouldn't be necessary. Path search should return unique results. | ||
.into_iter() | ||
.collect::<Vec<_>>(); | ||
|
||
let is_semantic = paths.is_empty(); | ||
|
||
// If there are no lexical results, perform a semantic search. | ||
if paths.is_empty() { | ||
let semantic_paths = self | ||
.semantic_search(query.into(), 30, 0, 0.0, true) | ||
.await? | ||
.into_iter() | ||
.map(|chunk| chunk.relative_path) | ||
.collect::<HashSet<_>>() | ||
.into_iter() | ||
.collect(); | ||
|
||
paths = semantic_paths; | ||
} | ||
|
||
let formatted_paths = paths | ||
.iter() | ||
.map(|p| (p.to_string(), self.get_path_alias(p))) | ||
.collect::<Vec<_>>(); | ||
|
||
let response = serde_json::to_string(&formatted_paths).unwrap(); | ||
|
||
self.update(Update::ReplaceStep(SearchStep::Path { | ||
query: query.clone(), | ||
response: response.clone(), | ||
})) | ||
.await?; | ||
|
||
self.track_query( | ||
EventData::input_stage("path search") | ||
.with_payload("query", query) | ||
.with_payload("is_semantic", is_semantic) | ||
.with_payload("results", &paths) | ||
.with_payload("raw_prompt", &response), | ||
); | ||
|
||
Ok(response) | ||
} | ||
} |
Oops, something went wrong.