Changes to the prompt
danieltrt committed Jun 30, 2023
1 parent 4ec9d7d commit 5686b2e
Showing 7 changed files with 170 additions and 116 deletions.
4 changes: 2 additions & 2 deletions experimental/rule_inference/local.py
@@ -91,7 +91,7 @@ def infer_from_example(data):
{
"rule_name": rule_name,
"rule": rule,
"gpt_output": agent.get_explanation(rule),
"gpt_output": agent.get_explanation(),
},
room=room,
)
@@ -110,7 +110,7 @@ def improve_rules(data):
{
"rule_name": rule_name,
"rule": rule,
"gpt_output": agent.get_explanation(rule),
"gpt_output": agent.get_explanation(),
},
room=room,
)
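Aside: both handler edits above drop the `rule` argument because, after this commit, the agent records the explanation while it infers the rule, and `get_explanation()` is a plain accessor (see the `piranha_agent.py` changes below). A minimal runnable sketch of that pattern, using a stub class in place of the real `PiranhaAgent`:

```python
import attr


@attr.s
class AgentSketch:
    # Mirrors the new PiranhaAgent attribute: the explanation is captured
    # once, during inference, instead of being requested from the model again.
    explanation = attr.ib(default=None)

    def iterate_inference_stub(self):
        # Stand-in for iterate_inference: store the explanation that arrived
        # alongside the validated rule, then return the rule itself.
        self.explanation = "Replaces every call to foo() with bar()."
        return "rule.toml", "[[rules]] ..."

    def get_explanation(self):
        # Plain accessor, matching the updated PiranhaAgent.get_explanation.
        return self.explanation


agent = AgentSketch()
file_name, rule = agent.iterate_inference_stub()
assert agent.get_explanation() == "Replaces every call to foo() with bar()."
```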
101 changes: 61 additions & 40 deletions experimental/rule_inference/piranha_agent.py
@@ -50,6 +50,7 @@ class PiranhaAgent:
chat = attr.ib(default=None)
tree_sitter_language = attr.ib(default=None)
tree_sitter_parser = attr.ib(default=None)
+ explanation = attr.ib(default=None)
language_mappings = {
"java": "java",
"kt": "kotlin",
@@ -90,6 +91,10 @@ def infer_rules(self, callback=None) -> Optional[Tuple[str, str]]:
rules[from_name].name: [rules[to_name].name for to_name in to_names]
for from_name, to_names in finder.edges.items()
}
+ # Keep only the edges that point to at least one rule
+ edges = [
+     {"from": k, "to": v, "scope": "File"} for k, v in edges.items() if v != []
+ ]
graph = RawRuleGraph(list(rules.values()), edges)
rules = graph.to_toml()
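
Aside: a small worked example of the edge-filtering step added above, using invented rule names. The name-keyed adjacency dict is reshaped into the list of edge records handed to `RawRuleGraph`, and rules with no outgoing edges are dropped:

```python
# Invented example input: each rule name mapped to the rules it points to.
edges = {
    "delete_flag_check": ["clean_up_imports"],
    "clean_up_imports": [],  # no outgoing edges, so it is filtered out below
}

# The same comprehension as in infer_rules above.
edges = [
    {"from": k, "to": v, "scope": "File"} for k, v in edges.items() if v != []
]

print(edges)
# [{'from': 'delete_flag_check', 'to': ['clean_up_imports'], 'scope': 'File'}]
```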

@@ -158,20 +163,20 @@ def create_chats(self, rules):
chat_interactions[i].append_system_message(response)
return chat_interactions

- def get_explanation(self, rules):
-     self.chat.append_explanation_request(rules)
-     response = self.chat.get_model_response()
-     response = re.sub(r"```md(.*)```", r"\1", response)
-     return response
+ def get_explanation(self):
+     return self.explanation

def iterate_inference(self, chat_interactions):
"""BFS for a rule that transforms the source code into the target code."""
max_rounds = 10
for i in range(max_rounds):
for chat in chat_interactions:
try:
- file_name, toml_block = self.validate_rule_wrapper(chat)
+ file_name, toml_block, explanation = self.validate_rule_wrapper(
+     chat
+ )
self.chat = chat
+ self.explanation = explanation
return file_name, toml_block
except PiranhaAgentError as e:
logger.debug(
@@ -182,29 +187,17 @@ def iterate_inference(self, chat_interactions):
f"Failed to generate a rule after {max_rounds} rounds of interaction with GPT-4."
)

- def append_diff_information(self, diff, source_tree, target_tree):
-     patches: List[Patch] = Patch.from_diffs(diff)
-     # Append to the diff the information about the deleted lines for each patch
-     diff += "\n=== Draft queries to represent deleted lines ===\n\n"
-     for patch in patches:
-         node_pairs = patch.get_nodes_from_patch(source_tree, target_tree)
-         for nodes_before, nodes_after in node_pairs:
-             for line, node in nodes_before.items():
-                 q = QueryWriter([node])
-                 diff += f"\n\n--------\n\nDelete Line: {line} \n\nCorresponding query:\n{q.write()}"
-     return diff

- def validate_rule_wrapper(self, chat):
+ def validate_rule_wrapper(self, chat) -> Tuple[str, str, str]:
# with Pool(processes=1) as pool:
completion = chat.get_model_response()
# result = pool.apply_async(self.validate_rule, (completion,))
try:
- file_name, toml_block = self.validate_rule(completion)
+ file_name, toml_block, explanation = self.validate_rule(completion)
# file_name, toml_block = result.get(
# timeout=5
# ) # Add a timeout of 5 seconds
if file_name and toml_block:
- return file_name, toml_block
+ return file_name, toml_block, explanation

except multiprocessing.context.TimeoutError:
raise PiranhaAgentError(
@@ -213,15 +206,25 @@ def validate_rule_wrapper(self, chat):
"Otherwise you need to use a [[rules.filters]] with contains or not_contains."
)

- def validate_rule(self, completion):
+ def validate_rule(self, completion) -> Tuple[str, str, str]:
# Define regex pattern for ```toml block
pattern = r"```toml(.*?)```"
pattern = r"```toml(?!md)(.*?)```"
logger.debug(f"Completion\n: {completion}")
# Extract all toml block contents
toml_blocks = re.findall(pattern, completion, re.DOTALL)
if not toml_blocks:
raise PiranhaAgentError(
"Could not create Piranha rule. There is no TOML block. "
"Please create a rule to refactor the code."
"No TOML block provided in the expected output format. "
"Please provide a TOML block with the rule. ```toml ... ```"
)

pattern = r"```md(.*?)```"
explanation = re.findall(pattern, completion, re.DOTALL)

if not explanation:
raise PiranhaAgentError(
"No explanation provided in the expected output format. "
"Please provide an explanation as a markdown block. ```md ... ```"
)

try:
@@ -251,7 +254,7 @@ def validate_rule(self, completion):
pattern = r"<file_name_start>(.*?)<file_name_end>"
file_names = re.findall(pattern, completion, re.DOTALL)
file_name = file_names[0] if file_names else "rule.toml"
- return file_name, toml_block
+ return file_name, toml_block, explanation[0]

def run_piranha(self, toml_dict):
"""Runs the inferred rule by applying it to the source code using the Piranha.
@@ -298,56 +301,74 @@ def improve_rule(self, task: str, rules: str):
try:
controller = Controller(chat)
updated_rules = []
+ explanations = []
for rule in rules.get("rules", []):
rule_str = toml.dumps(rule, encoder=PrettyTOML())
should_improve = controller.should_improve_rule(task, rule_str)
if should_improve:
option = controller.get_option_for_improvement(rule_str)
if option == "add filter":
- rule = self.add_filter(task, rule, chat)
+ rule, explanation = self.add_filter(task, rule, chat)
updated_rules.append(rule)
+ explanations.append(explanation)
continue
updated_rules.append(rule)
rule_block = "\n".join(
[toml.dumps(rule, encoder=PrettyTOML()) for rule in updated_rules]
)
+ explanation_block = "\n".join(explanations)
validation = self.validate_rule(
f"<file_name_start>rules.toml<file_name_end> ```toml\n{rule_block}\n```"
f"<file_name_start>rules.toml<file_name_end> ```toml\n{rule_block}\n``` ```md\n{explanation_block}\n```"
)

self.chat = chat
- return validation
+ self.explanation = "\n".join(explanations)
+ return validation[:-1]
except Exception as e:
logger.debug(
f"GPT-4 failed to generate a rule. Following up the next round with {e}. Trying again...\n"
)
chat.append_user_followup(str(e))

- def add_filter(self, desc, rule, chat):
+ def add_filter(self, desc, rule, chat) -> Tuple[dict, str]:
"""Adds a filter to the rule that encloses the nodes of the rule."""

query = rule.get("query")
source_tree = self.get_tree_from_code(self.source_code)
tree_sitter_q = self.tree_sitter_language.query(query)
captures = tree_sitter_q.captures(source_tree.root_node)
- captured_nodes = NodeUtils.get_smallest_nonoverlapping_set(
-     [c[0] for c in captures]
- )
- enclosing_nodes = []
- # Get the nodes that can be used as enclosing node for the rules
- for node in captured_nodes:
+ captures = NodeUtils.get_smallest_nonoverlapping_set([c[0] for c in captures])

+ parents = []
+ for node in captures:
while node:
- enclosing_nodes.append(node.sexp())
+ parents.append(node)
node = node.parent

+ enclosing_nodes = parents
+ enclosing_options = ""

+ for i, node in enumerate(enclosing_nodes):
+     qw = QueryWriter([node])
+     query = qw.write(simplify=True)
+     enclosing_options += f"\n\n=== Option {i} ===\n\n"
+     enclosing_options += f'enclosing_node = """{query}"""\n'

# Get the nodes that can be used as enclosing node for the rules
chat.append_improve_request(
desc,
toml.dumps(rule, encoder=PrettyTOML()),
- enclosing_nodes,
+ enclosing_options,
)
completion = chat.get_model_response()
pattern = r"```toml(.*?)```"
pattern = r"```toml(?!md)(.*?)```"
# Extract all toml block contents
toml_blocks = re.findall(pattern, completion, re.DOTALL)
- return toml.loads(toml_blocks[0])

pattern = r"```md(.*?)```"
explanation = re.findall(pattern, completion, re.DOTALL)

return toml.loads(toml_blocks[0]), explanation[0]

@staticmethod
def normalize_code(code: str) -> str:
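Aside: a runnable sketch of the completion parsing that the updated `validate_rule` performs, applied to an invented GPT-4 completion. The rule is pulled from the TOML fence, the explanation from the newly required md fence, and the file name from the `<file_name_start>` markers:

```python
import re

FENCE = "`" * 3  # literal triple backtick, built here so this sketch nests cleanly

# Invented completion, following the output format the prompt now requests.
completion = (
    "<file_name_start>rules.toml<file_name_end>\n"
    f'{FENCE}toml\n[[rules]]\nname = "delete_foo"\n{FENCE}\n'
    f"{FENCE}md\nDeletes every call to foo().\n{FENCE}\n"
)

# The same patterns validate_rule uses.
toml_blocks = re.findall(r"```toml(?!md)(.*?)```", completion, re.DOTALL)
explanations = re.findall(r"```md(.*?)```", completion, re.DOTALL)
file_names = re.findall(r"<file_name_start>(.*?)<file_name_end>", completion, re.DOTALL)

print(file_names[0])            # rules.toml
print(toml_blocks[0].strip())   # [[rules]] and the rule name
print(explanations[0].strip())  # Deletes every call to foo().
```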