-
Notifications
You must be signed in to change notification settings - Fork 650
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fc4cc97
commit a9278be
Showing
1 changed file
with
97 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import typing as t | ||
from dataclasses import dataclass, field | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from ragas.dataset_schema import SingleTurnSample | ||
from ragas.experimental.llms.prompt import PydanticPrompt | ||
from ragas.metrics.base import MetricType, MetricWithLLM, SingleTurnMetric | ||
|
||
if t.TYPE_CHECKING: | ||
from langchain_core.callbacks import Callbacks | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# Input payload for the SQL-equivalence prompt: the two queries to compare
# plus the schema text they should be interpreted against.
class EquivalenceInput(BaseModel):
    # Ground-truth query.
    reference: str = Field(
        ...,
        description="Reference SQL",
    )
    # Query under evaluation.
    response: str = Field(
        ...,
        description="Generated SQL",
    )
    # Schema text handed to the LLM alongside both queries.
    database_schema: str = Field(
        ...,
        description="Reference SQL schema",
    )
|
||
|
||
# Structured answer expected back from the SQL-equivalence prompt.
# NOTE: the "explaination" spelling is kept as-is; the field names are part of
# the output schema, so renaming them would change the model's interface.
class EquivalenceOutput(BaseModel):
    # Explanation of the generated query.
    response_query_explaination: str = Field(
        ...,
        description="Explanation of the generated SQL",
    )
    # Explanation of the reference query.
    reference_query_explaination: str = Field(
        ...,
        description="Explanation of the reference SQL",
    )
    # Final verdict on whether the two queries are equivalent.
    equivalence: bool = Field(
        ...,
        description="Whether the generated SQL is equivalent to the reference SQL",
    )
|
||
|
||
# Prompt definition that asks for an explanation of a reference query and a
# generated query, followed by a boolean equivalence verdict.
class EquivalencePrompt(PydanticPrompt[EquivalenceInput, EquivalenceOutput]):
    # Task description for the prompt.
    instruction = """
    Explain and compare two SQL queries (Q1 and Q2) based on the provided database schema. First, explain each query, then determine if they have significant logical differences.
    """

    # Models describing the prompt's input and its expected output.
    input_model = EquivalenceInput
    output_model = EquivalenceOutput

    # One (input, expected output) example: two queries that differ only in
    # how the boolean literal is written (1 vs. true), labelled equivalent.
    examples = [
        (
            EquivalenceInput(
                reference="SELECT id, name FROM users WHERE active = 1;",
                response="SELECT id, name FROM users WHERE active = true;",
                database_schema="""
                    Table users:
                    - id: INT
                    - name: VARCHAR
                    - active: BOOLEAN
                """,
            ),
            EquivalenceOutput(
                response_query_explaination="The generated SQL query retrieves the id and name of users where the active field is true.",
                reference_query_explaination="The reference SQL query retrieves the id and name of users where the active field equals 1.",
                equivalence=True,
            ),
        )
    ]
|
||
|
||
@dataclass
class LLMSqlEquivalenceWithReference(MetricWithLLM, SingleTurnMetric):
    """Score whether a generated SQL query is equivalent to a reference query.

    The sample's ``reference_contexts`` are joined with single spaces and sent
    to ``equivalence_prompt`` as the database schema, together with the
    ``reference`` and ``response`` SQL strings. The score is ``1.0`` when the
    prompt output marks the queries as equivalent and ``0.0`` otherwise.
    """

    name: str = "llm_sql_equivalence_with_reference"  # type: ignore
    _required_columns: t.Dict[MetricType, t.Set[str]] = field(
        default_factory=lambda: {
            MetricType.SINGLE_TURN: {"response", "reference", "reference_contexts"}
        }
    )
    # Built per instance: a plain ``EquivalencePrompt()`` default would be
    # created once at class-definition time and shared by every metric object,
    # so changing one metric's prompt would silently change all of them.
    equivalence_prompt: PydanticPrompt = field(default_factory=EquivalencePrompt)

    async def _single_turn_ascore(
        self, sample: SingleTurnSample, callbacks: Callbacks
    ) -> float:
        """Return ``1.0`` if the LLM judges the two queries equivalent, else ``0.0``.

        Parameters
        ----------
        sample
            Sample whose ``reference`` and ``response`` are SQL strings and
            whose ``reference_contexts`` is a list of schema text fragments.
        callbacks
            Callbacks forwarded to the prompt's ``generate`` call.

        Raises
        ------
        AssertionError
            If the LLM is not set, or the sample fields have unexpected types.
        """
        # The asserts also narrow the Optional field types for type checkers.
        assert self.llm is not None, "LLM is not initialized"
        assert isinstance(sample.reference, str), "Sample reference must be a string"
        assert isinstance(sample.response, str), "Sample response must be a string"
        assert isinstance(
            sample.reference_contexts, list
        ), "Sample reference_contexts must be a List"

        database_schema = " ".join(sample.reference_contexts)
        input_data = EquivalenceInput(
            reference=sample.reference,
            response=sample.response,
            database_schema=database_schema,
        )
        response = await self.equivalence_prompt.generate(
            data=input_data, llm=self.llm, callbacks=callbacks
        )
        # The prompt yields a bool; convert so the declared float return type
        # holds (True -> 1.0, False -> 0.0).
        return float(response.equivalence)

    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
        """Build a ``SingleTurnSample`` from ``row`` and score it."""
        return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)