
Commit

Merge branch 'master' into introduce-replace_index
ketkarameya committed Jun 21, 2023
2 parents dc8564d + 4baecbd commit ec6c79b
Showing 15 changed files with 990 additions and 140 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -45,8 +45,8 @@ repos:
           - /*| | */
 
   - repo: https://github.com/pre-commit/mirrors-autopep8
-    rev: 'v2.0.1'
+    rev: "v2.0.1"
     hooks:
-      - id: autopep8
+      - id: autopep8
         files: demo/
         exclude: only_lists.py
84 changes: 84 additions & 0 deletions experimental/rule_inference/comment_finder.py
@@ -0,0 +1,84 @@
"""
This module provides a CommentFinder class for finding comments in an abstract syntax tree (AST) and their associated
replacement pairs, as well as removing partial nodes from the pairs.
"""

from collections import defaultdict, deque
from typing import List, Dict, Deque, Tuple

import attr
from tree_sitter import Tree, Node

from experimental.rule_inference.node_utils import NodeUtils


@attr.s
class CommentFinder:
"""
CommentFinder traverses an AST to find nodes of type comment. Each comment will be associated with a number, and
it tells us when we should start collecting nodes for the replacement pair. When we find // 1 end, we stop collecting.
"""

source_tree = attr.ib(type=Tree)
target_tree = attr.ib(type=Tree)
replacement_source = attr.ib(
type=Dict[str, List[Tree]], default=attr.Factory(lambda: defaultdict(list))
)
replacement_target = attr.ib(
type=Dict[str, List[Tree]], default=attr.Factory(lambda: defaultdict(list))
)
edges = attr.ib(
type=Dict[str, List[str]], default=attr.Factory(lambda: defaultdict(list))
)

def find_replacement_pairs(self) -> Dict[str, Tuple[List[Node], List[Node]]]:
"""
Invokes find_replacement_pair on each tree. Finds matching pairs using the comments, returns those pairs.
"""

source_dict = self.find_replacement_pair(self.source_tree)
target_dict = self.find_replacement_pair(self.target_tree)

# Collect matching pairs using the comments
matching_pairs = {}

for comment in source_dict:
if comment in target_dict:
matching_pairs[comment] = (source_dict[comment], target_dict[comment])

return matching_pairs

def find_replacement_pair(self, tree: Tree):
"""
Traverses the AST to find nodes of type comment. Each comment will be associated with a number, and
it tells us when we should start collecting nodes for the replacement pair. When we find // 1 end, we stop collecting.
"""

root = tree.root_node
stack: Deque = deque([root])
comment = None
replacement_dict = defaultdict(list)
while stack:
node = stack.pop()

if node.type == "line_comment":
prev_comment = comment
comment = node.text.decode("utf8")
if "->" in comment:
x, y = comment.split("->")
self.edges[x[2:].strip()].append(y.strip())
comment = prev_comment
elif "end" in comment:
comment = None

elif comment:
replacement_dict[comment].append(node)
for child in reversed(node.children):
stack.append(child)

for comment, nodes in replacement_dict.items():
nodes = NodeUtils.remove_partial_nodes(nodes)
nodes = NodeUtils.get_smallest_nonoverlapping_set(nodes)
replacement_dict[comment] = nodes

return replacement_dict
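
For context, a minimal usage sketch of CommentFinder, assuming a tree-sitter Java grammar compiled to build/languages.so (that path and the annotated snippets are illustrative, not part of this commit):

from tree_sitter import Language, Parser

from experimental.rule_inference.comment_finder import CommentFinder

JAVA = Language("build/languages.so", "java")  # assumption: pre-built grammar
parser = Parser()
parser.set_language(JAVA)

# "// 1" opens replacement scope 1; "// 1 end" closes it.
source = b"""
class Example {
    void run() {
        // 1
        boolean flag = isEnabled();
        // 1 end
    }
}
"""
target = b"""
class Example {
    void run() {
        // 1
        boolean flag = true;
        // 1 end
    }
}
"""

finder = CommentFinder(parser.parse(source), parser.parse(target))
pairs = finder.find_replacement_pairs()
# pairs maps each opening comment to its (source nodes, target nodes) tuple.
print(pairs["// 1"])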
70 changes: 58 additions & 12 deletions experimental/rule_inference/local.py
@@ -1,34 +1,80 @@
 import os
 
-from flask import Flask, request, jsonify
+import attr
+from flask import Flask, request, jsonify, session
 import openai
 from flask import Flask, render_template
 import logging
 from piranha_agent import PiranhaAgent
+from flask_socketio import SocketIO, join_room
 
 
-logging.getLogger("Flask").setLevel(logging.DEBUG)
+logger = logging.getLogger("Flask")
+logger.setLevel(logging.DEBUG)
 app = Flask(__name__)
+socketio = SocketIO(app)
+
+
+# Define data validation classes
+@attr.s
+class InferData:
+    source_code = attr.ib(validator=attr.validators.instance_of(str))
+    target_code = attr.ib(validator=attr.validators.instance_of(str))
+    language = attr.ib(validator=attr.validators.in_(["python", "java"]))
+    hints = attr.ib(validator=attr.validators.instance_of(str))
+
+
+@attr.s
+class FolderData:
+    folder_path = attr.ib(validator=attr.validators.instance_of(str))
 
 
 @app.route("/")
 def home():
     return render_template("index.html")
 
 
-@app.route("/api/infer_piranha", methods=["POST"])
-def infer_from_example():
-    data = request.get_json()
+@socketio.on("infer_piranha")
+def infer_from_example(data):
+    # Validate the data
+    data = InferData(**data)
     openai.api_key = os.getenv("OPENAI_API_KEY")
     agent = PiranhaAgent(
-        data["source_code"],
-        data["target_code"],
-        language=data["language"],
-        hints=data["hints"],
+        data.source_code,
+        data.target_code,
+        language=data.language,
+        hints=data.hints,
     )
 
+    room = session.get("room")
+    join_room(room)
+
+    rule_name, rule = agent.infer_rules(
+        lambda intermediate_result: socketio.emit(
+            "infer_progress",
+            {"rule": intermediate_result},
+            room=room,
+        )
+    )
+    socketio.emit("infer_result", {"rule_name": rule_name, "rule": rule}, room=room)
+    session["last_inference_result"] = {
+        "rule_name": rule_name,
+        "rule": rule,
+    }
+
+    return jsonify({"message": f"Received source code: {data.source_code}"}), 200
+
+
+@app.route("/api/process_folder", methods=["POST"])
+def process_folder():
+    data = request.get_json()
+    data = FolderData(**data)
+    folder_path = data.folder_path
+
+    # Use the folder_path variable to process the folder.
+    # Note: This assumes your server has the appropriate permissions to access and read the directory.
+
-    rule_name, rule = agent.infer_rules()
-    return jsonify(rule_name, rule), 200
+    # Let's just return a message for this example
+    return jsonify({"message": f"Received folder path: {folder_path}"}), 200
 
 
 if __name__ == "__main__":
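
For context, a minimal sketch of how a client might drive the new "infer_piranha" Socket.IO event, assuming the server above listens on localhost:5000 and the python-socketio client package is installed (the address, port, and code snippets are illustrative, not part of this commit):

import socketio  # assumption: the python-socketio client package

sio = socketio.Client()


@sio.on("infer_progress")
def on_progress(data):
    # Intermediate rules stream in while inference is still running.
    print("progress:", data["rule"])


@sio.on("infer_result")
def on_result(data):
    # The final rule arrives once inference completes.
    print("final rule:", data["rule_name"])
    sio.disconnect()


sio.connect("http://localhost:5000")  # assumption: where the Flask-SocketIO app runs
sio.emit(
    "infer_piranha",
    {
        "source_code": "boolean flag = isEnabled();",
        "target_code": "boolean flag = true;",
        "language": "java",
        "hints": "",
    },
)
sio.wait()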
149 changes: 149 additions & 0 deletions experimental/rule_inference/node_utils.py
@@ -0,0 +1,149 @@
from tree_sitter import TreeCursor, Node
from typing import List, Tuple, Dict
import re


class NodeUtils:
    @staticmethod
    def generate_sexpr(node, depth=0, prefix=""):
        indent = " " * depth
        cursor: TreeCursor = node.walk()
        s_exp = indent + f"{prefix}({node.type} "
        next_child = cursor.goto_first_child()

        while next_child:
            child_node: Node = cursor.node
            if child_node.is_named:
                s_exp += "\n"
                prefix = ""
                if cursor.current_field_name():
                    prefix = f"{cursor.current_field_name()}: "
                s_exp += NodeUtils.generate_sexpr(child_node, depth + 1, prefix)
            elif cursor.current_field_name():
                s_exp += "\n" + " " * (depth + 1)
                s_exp += f'{cursor.current_field_name()}: ("{child_node.type}")'
            next_child = cursor.goto_next_sibling()
        return s_exp + ")"

    @staticmethod
    def convert_to_source(node, depth=0, exclude=None):
        if exclude is None:
            exclude = []
        for to_exclude in exclude:
            if NodeUtils.contains(to_exclude, node):
                return "{placeholder}"

        cursor: TreeCursor = node.walk()
        s_exp = ""
        has_next_child = cursor.goto_first_child()
        if not has_next_child:
            s_exp += node.text.decode("utf8")
            return s_exp

        while has_next_child:
            nxt = NodeUtils.convert_to_source(cursor.node, depth + 1, exclude)
            s_exp += nxt + " "
            has_next_child = cursor.goto_next_sibling()
        return s_exp.strip()

    @staticmethod
    def get_smallest_nonoverlapping_set(nodes: List[Node]):
        """
        Get the smallest non-overlapping set of nodes from the given list.
        :param nodes: list of candidate nodes, possibly overlapping.
        :return: the outermost nodes, with any node contained in an earlier one dropped.
        """
        # sort the nodes by their start position;
        # if the start positions are equal, sort by end position in reverse order
        nodes = sorted(
            nodes, key=lambda x: (x.start_point, tuple(map(lambda n: -n, x.end_point)))
        )
        # get the smallest non-overlapping set of nodes
        smallest_non_overlapping_set = []
        for node in nodes:
            if not smallest_non_overlapping_set:
                smallest_non_overlapping_set.append(node)
            else:
                if node.start_point > smallest_non_overlapping_set[-1].end_point:
                    smallest_non_overlapping_set.append(node)
        return smallest_non_overlapping_set

    @staticmethod
    def remove_partial_nodes(nodes: List[Node]) -> List[Node]:
        """
        Remove nodes whose children are not all contained in the replacement pair,
        repeating until a fixed point is reached where no more nodes can be removed.
        """
        while True:
            new_nodes = [
                node for node in nodes if all(child in nodes for child in node.children)
            ]
            if len(new_nodes) == len(nodes):
                break
            nodes = new_nodes
        return new_nodes

    @staticmethod
    def normalize_code(code: str) -> str:
        """Eliminates unnecessary spaces and newline characters from code.
        This function is a preprocessing step before comparing the refactored code with the target code.
        :param code: str, Code to normalize.
        :return: str, Normalized code.
        """

        # remove all whitespace, including newlines
        code = re.sub(r"\s+", "", code)
        # the next two substitutions are no-ops after the line above; kept for safety
        code = re.sub(r"\n+", "", code)
        code = re.sub(r" ?\n ?", "", code)
        # remove any leading or trailing whitespace
        code = code.strip()
        return code

    @staticmethod
    def contains(node: Node, other: Node) -> bool:
        """Checks if the given node contains the other node.
        :param node: Node, the candidate container.
        :param other: Node, the node that may be contained.
        :return: bool, True if `node` spans `other`, False otherwise.
        """
        return (
            node.start_point <= other.start_point and node.end_point >= other.end_point
        )

    @staticmethod
    def find_lowest_common_ancestor(nodes: List[Node]) -> Node:
        """
        Find the lowest common ancestor of the provided nodes.
        :param nodes: list of nodes for which to find the lowest common ancestor.
        :return: Node which is the lowest common ancestor.
        """
        # Ensure the list of nodes isn't empty
        assert len(nodes) > 0

        # Prepare a dictionary to map a node's id to the node object
        ids_to_nodes = {node.id: node for node in nodes}

        # For each node, follow its parent chain, adding each ancestor to the
        # node's ancestor set and to the ids_to_nodes map
        ancestor_ids = [set() for _ in nodes]
        for i, node in enumerate(nodes):
            while node is not None:
                ancestor_ids[i].add(node.id)
                ids_to_nodes[node.id] = node
                node = node.parent

        # Get the intersection of all ancestor sets
        common_ancestors_ids = set.intersection(*ancestor_ids)

        # If there are no common ancestors, there's a problem with the input tree
        if not common_ancestors_ids:
            raise ValueError("Nodes have no common ancestor")

        # The LCA is the deepest common ancestor, i.e. the one with the maximum start_byte
        max_start_byte_id = max(
            common_ancestors_ids, key=lambda node_id: ids_to_nodes[node_id].start_byte
        )

        return ids_to_nodes[max_start_byte_id]
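
To make the helpers above concrete, a small sketch exercising the pure-string part of NodeUtils; no parser is needed, and the snippets are illustrative:

from experimental.rule_inference.node_utils import NodeUtils

# normalize_code strips all whitespace, so formatting differences vanish
# before refactored code is compared against the target code.
refactored = "boolean flag =  true ;\n"
target = "boolean flag = true;"
assert NodeUtils.normalize_code(refactored) == NodeUtils.normalize_code(target)
print(NodeUtils.normalize_code(target))  # booleanflag=true;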
