Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unittest gen wip #1107

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/celery_app/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import seer.app # noqa: F401
from celery_app.app import celery_app as celery # noqa: F401
from celery_app.config import CeleryQueues
from seer.automation.autofix.tasks import (
check_and_mark_recent_autofix_runs,
delete_old_autofix_runs,
)
from seer.automation.autofix.tasks import check_and_mark_recent_autofix_runs
from seer.automation.tasks import delete_old_automation_runs


@celery.on_after_finalize.connect
Expand All @@ -21,6 +19,6 @@ def setup_periodic_tasks(sender, **kwargs):

sender.add_periodic_task(
crontab(minute="0", hour="0"), # run once a day
delete_old_autofix_runs.signature(kwargs={}, queue=CeleryQueues.DEFAULT),
name="Delete old Autofix runs for 90 day time-to-live",
delete_old_automation_runs.signature(kwargs={}, queue=CeleryQueues.DEFAULT),
name="Delete old Automation runs for 90 day time-to-live",
)
34 changes: 34 additions & 0 deletions src/migrations/versions/2df77dc3a696_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Migration

Revision ID: 2df77dc3a696
Revises: 6caab02760db
Create Date: 2024-08-20 17:18:33.399435

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "2df77dc3a696"
down_revision = "6caab02760db"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("run_state", schema=None) as batch_op:
batch_op.add_column(
sa.Column("type", sa.String(), nullable=False, server_default="autofix")
)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("run_state", schema=None) as batch_op:
batch_op.drop_column("type")

# ### end Alembic commands ###
28 changes: 28 additions & 0 deletions src/seer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
RepoAccessCheckResponse,
)
from seer.automation.codebase.repo_client import RepoClient
from seer.automation.codegen.models import (
CodegenUnitTestsRequest,
CodegenUnitTestsResponse,
CodegenUnitTestsStateRequest,
CodegenUnitTestsStateResponse,
)
from seer.automation.codegen.tasks import codegen_unittest, get_unittest_state
from seer.automation.summarize.issue import run_summarize_issue
from seer.automation.summarize.models import SummarizeIssueRequest, SummarizeIssueResponse
from seer.automation.utils import raise_if_no_genai_consent
Expand Down Expand Up @@ -214,6 +221,27 @@ def autofix_evaluation_start_endpoint(data: AutofixEvaluationRequest) -> Autofix
return AutofixEndpointResponse(started=True, run_id=-1)


@json_api(blueprint, "/v1/automation/codegen/unit-tests")
def codegen_unit_tests_endpoint(data: CodegenUnitTestsRequest) -> CodegenUnitTestsResponse:
return codegen_unittest(data)


@json_api(blueprint, "/v1/automation/codegen/unit-tests/state")
def codegen_unit_tests_state_endpoint(
data: CodegenUnitTestsStateRequest,
) -> CodegenUnitTestsStateResponse:
state = get_unittest_state(data)

return CodegenUnitTestsStateResponse(
run_id=state.run_id,
status=state.status,
changes=state.file_changes,
triggered_at=state.last_triggered_at,
updated_at=state.updated_at,
completed_at=state.completed_at,
)


@json_api(blueprint, "/v1/automation/summarize/issue")
def summarize_issue_endpoint(data: SummarizeIssueRequest) -> SummarizeIssueResponse:
return run_summarize_issue(data)
Expand Down
44 changes: 5 additions & 39 deletions src/seer/automation/autofix/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Literal, cast

import sentry_sdk
import sqlalchemy.sql as sql
from langfuse import Langfuse

from celery_app.app import celery_app
Expand Down Expand Up @@ -94,38 +93,14 @@ def get_autofix_state(

def get_all_autofix_runs_after(after: datetime.datetime):
with Session() as session:
runs = session.query(DbRunState).filter(DbRunState.last_triggered_at > after).all()
runs = (
session.query(DbRunState)
.filter(DbRunState.last_triggered_at > after, DbRunState.type == "autofix")
.all()
)
return [ContinuationState.from_id(run.id, AutofixContinuation) for run in runs]


def delete_all_runs_before(before: datetime.datetime, batch_size=1000):
deleted_count = 0
while True:
with Session() as session:
subquery = (
session.query(DbRunState.id)
.filter(DbRunState.last_triggered_at < before)
.limit(batch_size)
.subquery()
)
count = (
session.query(DbRunState)
.filter(sql.exists().where(DbRunState.id == subquery.c.id))
.delete()
)
session.commit()

deleted_count += count
if count == 0:
break
sentry_sdk.metrics.incr(
key="autofix_state_TTL_deletion",
value=count,
)

return deleted_count


@celery_app.task(time_limit=15)
def check_and_mark_recent_autofix_runs():
logger.info("Checking and marking recent autofix runs")
Expand All @@ -137,15 +112,6 @@ def check_and_mark_recent_autofix_runs():
check_and_mark_if_timed_out(run)


@celery_app.task(time_limit=30)
def delete_old_autofix_runs():
logger.info("Deleting old Autofix runs for 90 day time-to-live")
before = datetime.datetime.now() - datetime.timedelta(days=90) # over 90 days old
deleted_count = delete_all_runs_before(before)
print(deleted_count)
logger.info(f"Deleted {deleted_count} runs")


def check_and_mark_if_timed_out(state: ContinuationState):
with state.update() as cur:
if cur.is_running and cur.has_timed_out:
Expand Down
85 changes: 79 additions & 6 deletions src/seer/automation/autofix/tools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import fnmatch
import logging
import os
import textwrap

from langfuse.decorators import observe
Expand All @@ -9,15 +11,16 @@
from seer.automation.codebase.code_search import CodeSearcher
from seer.automation.codebase.models import MatchXml
from seer.automation.codebase.utils import cleanup_dir
from seer.automation.codegen.codegen_context import CodegenContext

logger = logging.getLogger(__name__)


class BaseTools:
context: AutofixContext
context: AutofixContext | CodegenContext
retrieval_top_k: int

def __init__(self, context: AutofixContext, retrieval_top_k: int = 8):
def __init__(self, context: AutofixContext | CodegenContext, retrieval_top_k: int = 8):
self.context = context
self.retrieval_top_k = retrieval_top_k

Expand All @@ -30,7 +33,8 @@ def expand_document(self, input: str, repo_name: str | None = None):
client = self.context.get_repo_client(repo_name)
repo_name = client.repo_name

self.context.event_manager.add_log(f"Looked at `{input}` in `{repo_name}`")
if isinstance(self.context, AutofixContext):
self.context.event_manager.add_log(f"Looked at `{input}` in `{repo_name}`")

if file_contents:
return file_contents
Expand Down Expand Up @@ -170,12 +174,47 @@ def keyword_search(
file_names.append(f"`{result.relative_path}`")
result_str += f"{match_xml.to_prompt_str()}\n\n"

self.context.event_manager.add_log(
f"Searched codebase for `{keyword}`, found {len(file_names)} result(s) in {', '.join(file_names)}"
)
if isinstance(self.context, AutofixContext):
self.context.event_manager.add_log(
f"Searched codebase for `{keyword}`, found {len(file_names)} result(s) in {', '.join(file_names)}"
)

return result_str

@observe(name="File Search")
@ai_track(description="File Search")
def file_search(
self,
filename: str,
repo_name: str | None = None,
):
"""
Given a filename with extension returns the list of locations where a file with the name is found.
"""
repo_client = self.context.get_repo_client(repo_name=repo_name)
all_paths = repo_client.get_index_file_set()
found = [path for path in all_paths if os.path.basename(path) == filename]
if len(found) == 0:
return f"no file with name {filename} found in repository"
return ",".join(found)

@observe(name="File Search Wildcard")
@ai_track(description="File Search Wildcard")
def file_search_wildcard(
self,
pattern: str,
repo_name: str | None = None,
):
"""
Given a filename pattern with wildcards, returns the list of file paths that match the pattern.
"""
repo_client = self.context.get_repo_client(repo_name=repo_name)
all_paths = repo_client.get_index_file_set()
found = [path for path in all_paths if fnmatch.fnmatch(path, pattern)]
if len(found) == 0:
return f"No files matching pattern '{pattern}' found in repository"
return "\n".join(found)

def get_tools(self):
tools = [
FunctionTool(
Expand Down Expand Up @@ -245,6 +284,40 @@ def get_tools(self):
},
],
),
FunctionTool(
name="file_search",
fn=self.file_search,
description="Searches for a file in the codebase.",
parameters=[
{
"name": "filename",
"type": "string",
"description": "The file to search for.",
},
{
"name": "repo_name",
"type": "string",
"description": "Optional name of the repository to search in if you know it.",
},
],
),
FunctionTool(
name="file_search_wildcard",
fn=self.file_search_wildcard,
description="Searches for files in a folder using a wildcard pattern.",
parameters=[
{
"name": "pattern",
"type": "string",
"description": "The wildcard pattern to match files.",
},
{
"name": "repo_name",
"type": "string",
"description": "Optional name of the repository to search in if you know it.",
},
],
),
]

return tools
13 changes: 13 additions & 0 deletions src/seer/automation/codebase/repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,16 @@ def get_index_file_set(
file_set.add(file.path)

return file_set

def get_pr_diff_content(self, pr_url: str) -> str:
requester = self.repo._requester
headers = {
"Authorization": f"{requester.auth.token_type} {requester.auth.token}", # type: ignore
"Accept": "application/vnd.github.diff",
}

data = requests.get(pr_url, headers=headers)

data.raise_for_status() # Raise an exception for HTTP errors

return data.text
76 changes: 76 additions & 0 deletions src/seer/automation/codegen/codegen_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging

from seer.automation.codebase.repo_client import RepoClient
from seer.automation.codegen.codegen_event_manager import CodegenEventManager
from seer.automation.codegen.models import CodegenContinuation
from seer.automation.codegen.state import CodegenContinuationState
from seer.automation.models import RepoDefinition
from seer.automation.pipeline import PipelineContext

logger = logging.getLogger(__name__)

RepoExternalId = str
RepoInternalId = int
RepoKey = RepoExternalId | RepoInternalId
RepoIdentifiers = tuple[RepoExternalId, RepoInternalId]


class CodegenContext(PipelineContext):
state: CodegenContinuationState
event_manager: CodegenEventManager
repo: RepoDefinition

def __init__(
self,
state: CodegenContinuationState,
):
request = state.get().request

self.repo = request.repo
self.state = state
self.event_manager = CodegenEventManager(state)

logger.info(f"CodegenContext initialized with run_id {self.run_id}")

@classmethod
def from_run_id(cls, run_id: int):
state = CodegenContinuationState.from_id(run_id, model=CodegenContinuation)
with state.update() as cur:
cur.mark_triggered()

return cls(state)

@property
def run_id(self) -> int:
return self.state.get().run_id

@property
def signals(self) -> list[str]:
return self.state.get().signals

@signals.setter
def signals(self, value: list[str]):
with self.state.update() as state:
state.signals = value

def get_repo_client(self, repo_name: str | None = None):
"""
Gets a repo client for the current single repo or for a given repo name.
If there are more than 1 repos, a repo name must be provided.
"""
return RepoClient.from_repo_definition(self.repo, "read")

def get_file_contents(
self, path: str, repo_name: str | None = None, ignore_local_changes: bool = False
) -> str | None:
repo_client = self.get_repo_client()

file_contents = repo_client.get_file_content(path)

if not ignore_local_changes:
cur_state = self.state.get()
current_file_changes = list(filter(lambda x: x.path == path, cur_state.file_changes))
for file_change in current_file_changes:
file_contents = file_change.apply(file_contents)

return file_contents
Loading