-
Notifications
You must be signed in to change notification settings - Fork 190
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into introduce-replace_index
- Loading branch information
Showing
21 changed files
with
1,029 additions
and
140 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
""" | ||
This module provides a CommentFinder class for finding comments in an abstract syntax tree (AST) and their associated | ||
replacement pairs, as well as removing partial nodes from the pairs. | ||
""" | ||
|
||
from collections import defaultdict, deque | ||
from typing import List, Dict, Deque, Tuple | ||
|
||
import attr | ||
from tree_sitter import Tree, Node | ||
|
||
from experimental.rule_inference.node_utils import NodeUtils | ||
|
||
|
||
@attr.s
class CommentFinder:
    """
    Pairs up regions of a source tree and a target tree that are delimited by the same marker comment.

    Both trees are walked looking for ``line_comment`` nodes. A comment opens a region labelled with the
    comment's full text; a later comment containing ``end`` closes the region. A comment containing ``->``
    does not open or close anything: it records an edge between two labels in ``edges``.
    """

    # The two trees whose marker comments are matched against each other.
    source_tree = attr.ib(type=Tree)
    target_tree = attr.ib(type=Tree)
    # Declared for callers; not read or written by the methods of this class.
    replacement_source = attr.ib(
        type=Dict[str, List[Tree]], default=attr.Factory(lambda: defaultdict(list))
    )
    replacement_target = attr.ib(
        type=Dict[str, List[Tree]], default=attr.Factory(lambda: defaultdict(list))
    )
    # Label -> list of labels, filled from "x -> y" comments while traversing.
    edges = attr.ib(
        type=Dict[str, List[str]], default=attr.Factory(lambda: defaultdict(list))
    )

    def find_replacement_pairs(self) -> Dict[str, Tuple[List[Node], List[Node]]]:
        """
        Collect the labelled regions of both trees and keep the labels present in both.

        :return: mapping from comment text to a ``(source_nodes, target_nodes)`` tuple, in the order the
            labels were found in the source tree.
        """
        from_source = self.find_replacement_pair(self.source_tree)
        from_target = self.find_replacement_pair(self.target_tree)

        # Only labels that occur in both trees form a pair
        return {
            label: (from_source[label], from_target[label])
            for label in from_source
            if label in from_target
        }

    def find_replacement_pair(self, tree: Tree):
        """
        Walk ``tree`` depth-first (children in document order) and group nodes by the active marker comment.

        As a side effect, every ``x -> y`` comment appends ``y`` to ``self.edges[x]``, where ``x`` is the text
        before the arrow with its first two characters dropped.

        :param tree: the tree to traverse.
        :return: defaultdict mapping comment text to the nodes collected under it, after partial nodes have
            been removed and the list reduced to a non-overlapping set.
        """
        pending: Deque = deque([tree.root_node])
        active_label = None
        collected = defaultdict(list)

        while pending:
            current = pending.pop()

            if node_is_comment := (current.type == "line_comment"):
                text = current.text.decode("utf8")
                if "->" in text:
                    # Edge declaration: the active label is left untouched
                    x, y = text.split("->")
                    self.edges[x[2:].strip()].append(y.strip())
                elif "end" in text:
                    # Closing marker: stop collecting
                    active_label = None
                else:
                    # Opening marker: start collecting under this comment's text
                    active_label = text

            if not node_is_comment and active_label:
                collected[active_label].append(current)

            # Push children right-to-left so the leftmost child is popped first
            pending.extend(reversed(current.children))

        # Post-process each group; only existing keys are reassigned
        for label in list(collected):
            complete = NodeUtils.remove_partial_nodes(collected[label])
            collected[label] = NodeUtils.get_smallest_nonoverlapping_set(complete)

        return collected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
from tree_sitter import TreeCursor, Node | ||
from typing import List, Tuple, Dict | ||
import re | ||
|
||
|
||
class NodeUtils:
    """Stateless helpers for rendering, filtering and relating tree-sitter nodes."""

    @staticmethod
    def generate_sexpr(node, depth=0, prefix=""):
        """Render ``node`` and its descendants as an s-expression string.

        Named children are rendered recursively, each on its own line. Anonymous
        children are rendered only when they carry a field name.

        :param node: node to render.
        :param depth: nesting level; the line is indented by ``depth`` spaces.
        :param prefix: text placed before the opening parenthesis (the ``field: ``
            label when called recursively).
        :return: str, the s-expression.
        """
        indent = " " * depth
        cursor: TreeCursor = node.walk()
        s_exp = indent + f"{prefix}({node.type} "
        next_child = cursor.goto_first_child()

        while next_child:
            child_node: Node = cursor.node
            field_name = cursor.current_field_name()
            if child_node.is_named:
                s_exp += "\n"
                child_prefix = f"{field_name}: " if field_name else ""
                s_exp += NodeUtils.generate_sexpr(child_node, depth + 1, child_prefix)
            elif field_name:
                s_exp += "\n" + " " * (depth + 1)
                s_exp += f'{field_name}: ("{child_node.type}")'
            next_child = cursor.goto_next_sibling()
        return s_exp + ")"

    @staticmethod
    def convert_to_source(node, depth=0, exclude=None):
        """Rebuild source text for ``node`` by joining the text of its leaves with spaces.

        :param node: node to convert.
        :param depth: recursion depth; passed along but not otherwise used.
        :param exclude: optional list of nodes; any node contained in one of them
            is rendered as the literal ``{placeholder}``.
        :return: str, the space-joined text with surrounding whitespace stripped.
        """
        if exclude is None:
            exclude = []
        if any(NodeUtils.contains(to_exclude, node) for to_exclude in exclude):
            return "{placeholder}"

        cursor: TreeCursor = node.walk()
        has_next_child = cursor.goto_first_child()
        if not has_next_child:
            # Leaf: its own text is the source
            return node.text.decode("utf8")

        pieces = []
        while has_next_child:
            pieces.append(NodeUtils.convert_to_source(cursor.node, depth + 1, exclude))
            has_next_child = cursor.goto_next_sibling()
        return " ".join(pieces).strip()

    @staticmethod
    def get_smallest_nonoverlapping_set(nodes: List[Node]):
        """
        Reduce ``nodes`` to a list of non-overlapping nodes, preferring the outermost ones.

        Nodes are visited by start position, widest first, and a node is kept only
        when it begins at or after the end of the previously kept node.

        :param nodes: candidate nodes; each needs ``start_point`` and ``end_point``.
        :return: list of kept nodes, ordered by start position.
        """
        # Sort by start position; for equal starts the node ending last comes first
        ordered = sorted(
            nodes, key=lambda n: (n.start_point, tuple(-c for c in n.end_point))
        )
        selected = []
        for node in ordered:
            if not selected:
                selected.append(node)
                continue
            last_end = selected[-1].end_point
            # End points are exclusive, so a node starting exactly where the previous
            # one ends is adjacent, not overlapping. Requiring it to extend past
            # last_end keeps zero-width nodes on the boundary excluded.
            if node.start_point >= last_end and node.end_point > last_end:
                selected.append(node)
        return selected

    @staticmethod
    def remove_partial_nodes(nodes: List[Node]) -> List[Node]:
        """
        Drop every node that has a child missing from ``nodes``.

        Removing a node can make its parent partial, so the filter is repeated
        until a pass removes nothing.

        :param nodes: candidate nodes.
        :return: the nodes whose children are all present in the result.
        """
        while True:
            new_nodes = [
                node for node in nodes if all(child in nodes for child in node.children)
            ]
            if len(new_nodes) == len(nodes):
                break
            nodes = new_nodes
        return new_nodes

    @staticmethod
    def normalize_code(code: str) -> str:
        """Strip every whitespace character from ``code``.

        Used as a preprocessing step before comparing refactored code with target
        code, so that formatting differences do not matter. Note that whitespace
        between tokens is removed too (``int x`` becomes ``intx``).

        :param code: str, Code to normalize.
        :return: str, Normalized code.
        """
        # Once all whitespace is gone there are no newlines or spaces left to
        # collapse, so a single substitution is sufficient.
        return re.sub(r"\s+", "", code).strip()

    @staticmethod
    def contains(node: Node, other: Node) -> bool:
        """Checks if the given node contains the other node.

        :param node: Node, Node to check if it contains the other node.
        :param other: Node, Node to check if it is contained by the other node.
        :return: bool, True if ``node`` spans ``other`` (equal spans count), False otherwise.
        """
        return (
            node.start_point <= other.start_point and node.end_point >= other.end_point
        )

    @staticmethod
    def find_lowest_common_ancestor(nodes: List[Node]) -> Node:
        """
        Find the lowest common ancestor of the provided nodes.

        A node counts as its own ancestor, so the result for a single node is that
        node, and the result for a node and one of its descendants is the node.

        :param nodes: non-empty list of nodes for which to find the lowest common ancestor.
        :return: Node which is the lowest common ancestor.
        :raises ValueError: if the nodes share no ancestor.
        """
        # Ensure the list of nodes isn't empty
        assert len(nodes) > 0

        # For each node, collect the ids on its parent chain (the node included)
        ancestor_ids = []
        for node in nodes:
            chain = set()
            while node is not None:
                chain.add(node.id)
                node = node.parent
            ancestor_ids.append(chain)

        # Get the intersection of all ancestor sets
        common_ancestors_ids = set.intersection(*ancestor_ids)

        # If there are no common ancestors, there's a problem with the input tree
        if not common_ancestors_ids:
            raise ValueError("Nodes have no common ancestor")

        # Walk up from the first node: the first common ancestor met is the lowest.
        # Comparing start bytes cannot decide this, because a parent and its first
        # child usually start at the same byte.
        candidate = nodes[0]
        while candidate.id not in common_ancestors_ids:
            candidate = candidate.parent
        return candidate
Oops, something went wrong.