# Source listing header (captured from the repository web viewer):
# 301 lines, 11 KiB, Python

import os
import re
from pathlib import Path
# tree-sitter is optional. When the bindings cannot be imported, every related
# name is bound to None so later code can test for their absence.
try:
    import tree_sitter
    import tree_sitter_c
    import tree_sitter_cpp
    from tree_sitter import Language, Parser
except ImportError:  # pragma: no cover - depends on deployment environment
    tree_sitter = tree_sitter_c = tree_sitter_cpp = None
    Language = Parser = None

# Constant definitions
# File suffixes and directory names; their consumers are outside this excerpt.
CPP_EXTENSIONS = {
    '.c',
    '.cpp',
    '.h',
    '.hpp',
    '.tcc',
}
IGNORE_DIRS = {
    'build',
    'cmake-build',
    '.git',
    'vendor',
    'lib',
    'external',
    'Debug',
}
# Grammar handle for the tree-sitter parser, or None without the bindings.
if Language and tree_sitter_cpp:
    CPP_LANGUAGE = Language(tree_sitter_cpp.language())
else:
    CPP_LANGUAGE = None
class RegexNode:
    """Lightweight syntax-tree node built by the regex fallback parser.

    It exposes the attributes the rest of this module reads from a node:
    ``type``, ``text`` (always bytes), ``start_point`` / ``end_point`` as
    ``(row, 0)`` tuples, ``children`` / ``named_children`` and
    ``child_by_field_name``.
    """

    def __init__(self, node_type, text="", start_line=0, end_line=0, fields=None, children=None):
        """Store the node data.

        Args:
            node_type: Node kind, e.g. ``"identifier"``.
            text: Source text; ``str`` is encoded to UTF-8 (unencodable
                characters are dropped), anything else is stored as given.
            start_line: Zero-based first row of the node.
            end_line: Zero-based last row of the node.
            fields: Optional mapping of field name to child node.
            children: Optional list of child nodes.
        """
        if isinstance(text, str):
            text = text.encode("utf-8", errors="ignore")
        self.type = node_type
        self.text = text
        # Only rows are tracked; the column is always reported as 0.
        self.start_point = (start_line, 0)
        self.end_point = (end_line, 0)
        self._fields = fields if fields else {}
        self.children = children if children else []
        # Same list object: every regex-built child counts as a named child.
        self.named_children = self.children

    def child_by_field_name(self, name):
        """Return the child registered under ``name``, or None if absent."""
        return self._fields.get(name)
class RegexTree:
    """Container exposing a parse result through a ``root_node`` attribute."""

    def __init__(self, root_node):
        """Remember ``root_node`` as the root of this tree."""
        self.root_node = root_node
class RegexParser:
    """Regex-based fallback parser producing a shallow ``RegexTree``.

    ``parse`` returns a tree whose ``translation_unit`` root holds all
    ``preproc_include`` nodes followed by all ``function_definition`` nodes.
    The scan is heuristic: it recognises ``name(...) {`` headers and
    ``name(`` call sites by pattern only, without understanding C/C++ syntax.
    """

    # Function header: optional return type / qualifiers, a (possibly
    # ``Class::``-qualified) name captured as group 1, a parameter list that
    # contains no ``;``, ``{`` or ``}``, then the opening brace.
    _signature_re = re.compile(
        r"(?m)^[ \t]*(?:[A-Za-z_][\w\s\*\&:<>]*?[ \t]+)?([A-Za-z_]\w*(?:::[A-Za-z_]\w*)?)"
        r"\s*\([^;{}]*\)\s*\{"
    )
    # Any identifier directly followed by ``(``.
    _call_re = re.compile(r"\b([A-Za-z_]\w*)\s*\(")
    # ``#include <...>`` / ``#include "..."``; group 1 keeps the delimiters.
    _include_re = re.compile(r"(?m)^\s*#\s*include\s*([<\"][^>\"]+[>\"])")
    # Tokens relevant to brace matching. Comments and string/char literals are
    # matched first so that braces inside them are consumed without counting.
    _brace_scan_re = re.compile(
        r"//[^\n]*"
        r"|/\*.*?(?:\*/|\Z)"
        r'|"(?:\\.|[^"\\\n])*(?:"|$)'
        r"|'(?:\\.|[^'\\\n])+'"
        r"|(?P<brace>[{}])",
        re.DOTALL | re.MULTILINE,
    )
    # Keywords that may be followed by ``(`` but are neither calls nor
    # function names.
    _control_words = {
        "if", "for", "while", "switch", "return", "sizeof",
        "catch", "else", "do", "case", "throw", "new", "delete",
        "alignof", "decltype", "typeid", "noexcept",
    }

    def parse(self, source_bytes):
        """Decode ``source_bytes`` and return a ``RegexTree`` for it."""
        source_text = safe_decode(source_bytes)
        children = self._parse_includes(source_text)
        children.extend(self._parse_functions(source_text))
        return RegexTree(RegexNode("translation_unit", source_text, 0, source_text.count("\n"), children=children))

    def _parse_includes(self, source_text):
        """Return one ``preproc_include`` node per ``#include`` directive."""
        nodes = []
        for match in self._include_re.finditer(source_text):
            line = source_text.count("\n", 0, match.start())
            path_node = RegexNode("string_literal", match.group(1), line, line)
            nodes.append(
                RegexNode(
                    "preproc_include",
                    match.group(0),
                    line,
                    line,
                    fields={"path": path_node},
                    children=[path_node],
                )
            )
        return nodes

    def _parse_functions(self, source_text):
        """Return one ``function_definition`` node per recognised header.

        Headers whose name is a control word, or whose opening brace has no
        matching closing brace, are skipped.
        """
        nodes = []
        for match in self._signature_re.finditer(source_text):
            name = match.group(1)
            if name in self._control_words:
                continue
            body_start = match.end() - 1  # index of the opening brace
            body_end = self._find_matching_brace(source_text, body_start)
            if body_end <= body_start:
                continue
            start_line = source_text.count("\n", 0, match.start())
            # The brace may sit on a later line than the signature start
            # (multi-line parameter lists, brace on its own line), so the body
            # and the calls inside it are positioned relative to the brace.
            body_line = source_text.count("\n", 0, body_start)
            end_line = source_text.count("\n", 0, body_end)
            function_text = source_text[match.start() : body_end + 1]
            body_text = source_text[body_start : body_end + 1]
            identifier = RegexNode("identifier", name, start_line, start_line)
            declarator = RegexNode(
                "function_declarator",
                source_text[match.start() : body_start],
                start_line,
                body_line,
                fields={"declarator": identifier},
                children=[identifier],
            )
            body = RegexNode(
                "compound_statement",
                body_text,
                body_line,
                end_line,
                children=self._parse_calls(body_text, body_line),
            )
            nodes.append(
                RegexNode(
                    "function_definition",
                    function_text,
                    start_line,
                    end_line,
                    fields={"declarator": declarator, "body": body},
                    children=[declarator, body],
                )
            )
        return nodes

    def _find_matching_brace(self, source_text, start_index):
        """Return the index of the ``}`` closing the brace at ``start_index``.

        Braces inside ``//`` and ``/* */`` comments, string literals and
        character literals are ignored. Returns -1 when no closing brace
        brings the nesting depth back to zero.
        """
        depth = 0
        for token in self._brace_scan_re.finditer(source_text, start_index):
            brace = token.group("brace")
            if brace is None:
                continue  # comment or literal: skipped wholesale
            if brace == "{":
                depth += 1
            else:
                depth -= 1
                if depth == 0:
                    return token.start()
        return -1

    def _parse_calls(self, body_text, start_line):
        """Return a ``call_expression`` node for every ``name(`` in the body.

        Args:
            body_text: Text of the function body, starting at its ``{``.
            start_line: Zero-based row on which ``body_text`` begins.
        """
        calls = []
        for match in self._call_re.finditer(body_text):
            name = match.group(1)
            if name in self._control_words:
                continue
            line = start_line + body_text.count("\n", 0, match.start())
            function_node = RegexNode("identifier", name, line, line)
            calls.append(
                RegexNode(
                    "call_expression",
                    match.group(0),
                    line,
                    line,
                    fields={"function": function_node},
                    children=[function_node],
                )
            )
        return calls
# Global parser: the regex fallback is used unless both the tree-sitter Parser
# class and the compiled C++ language are available.
if not (Parser and CPP_LANGUAGE):
    parser = RegexParser()
else:
    parser = Parser()
    parser.language = CPP_LANGUAGE
# Global graph data structures. All three start out empty; the code that fills
# them is not part of this excerpt.
FUNCTION_CALL_GRAPH: dict = dict()
FILE_DEPENDENCIES: dict = dict()
CALLED_BY_GRAPH: dict = dict()
def safe_decode(b: bytes) -> str:
    """Decode ``b`` to text, trying several encodings in order.

    The encodings tried are UTF-8, GBK, GB2312 and finally Latin-1. Input that
    is already ``str`` is returned unchanged so callers holding pre-decoded
    text do not fail.

    Args:
        b: Raw bytes (or an already decoded string).

    Returns:
        The decoded text.
    """
    if isinstance(b, str):
        return b
    for encoding in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
        try:
            return b.decode(encoding)
        except UnicodeDecodeError:
            continue
    # Latin-1 maps every byte value, so the loop above always returns; this
    # line is kept purely as a defensive fallback.
    return b.decode('utf-8', errors='replace')
def extract_comment_before(node, source_lines):
    """Collect comment lines found shortly before ``node``.

    Up to ten lines directly above the node's start row are inspected,
    including the first line of the file. Every line whose stripped text
    starts with ``//`` or ``/*`` is kept; the kept lines are returned in
    source order joined by single spaces.

    Args:
        node: Object whose ``start_point[0]`` is the zero-based start row.
        source_lines: Sequence of source lines, as bytes or already decoded
            ``str``.

    Returns:
        The joined comment text, or ``"无注释"`` when no comment line is found.
    """
    start_line = node.start_point[0]
    # Clamp to the available lines so a start row beyond the end of
    # source_lines cannot raise IndexError.
    nearest = min(start_line, len(source_lines)) - 1
    farthest = max(0, start_line - 10)
    comment_lines = []
    # range() excludes its stop value, hence ``farthest - 1`` to make sure the
    # farthest row (row 0 near the top of the file) is inspected too.
    for i in range(nearest, farthest - 1, -1):
        raw = source_lines[i]
        line = (raw if isinstance(raw, str) else safe_decode(raw)).strip()
        if line.startswith(('//', '/*')):
            comment_lines.append(line)
    return " ".join(reversed(comment_lines)) or "无注释"
def extract_code_snippet(tree, node, source_lines):
    """Return the source text covered by ``node``.

    Rows from ``node.start_point[0]`` through ``node.end_point[0]`` that exist
    in ``source_lines`` are decoded and joined with newlines.

    Args:
        tree: Unused; kept for interface compatibility.
        node: Object providing ``start_point`` and ``end_point`` row tuples.
        source_lines: Sequence of raw source lines.
    """
    first_row = node.start_point[0]
    last_row = node.end_point[0]
    # Rows at or past the end of source_lines are left out.
    stop = min(last_row + 1, len(source_lines))
    return "\n".join(safe_decode(source_lines[row]) for row in range(first_row, stop))
def extract_full_function_name(func_def_node):
    """Return the declared name of a function-definition node.

    Pointer and reference declarators wrapping the function declarator (as
    expected for functions returning pointers or references; see the
    tree-sitter C/C++ grammars) are unwrapped first, so the name does not
    carry a leading ``*`` or ``&``.

    Args:
        func_def_node: Node exposing ``child_by_field_name``; its
            ``declarator`` field is inspected.

    Returns:
        The function name, or ``"<unknown>"`` when it cannot be determined.
    """
    declarator = func_def_node.child_by_field_name("declarator")
    if not declarator:
        return "<unknown>"

    while declarator.type in ("pointer_declarator", "reference_declarator"):
        inner = declarator.child_by_field_name("declarator")
        if not inner:
            # Some wrappers may expose the wrapped declarator only as their
            # last named child rather than through a field.
            named = getattr(declarator, "named_children", None) or []
            inner = named[-1] if named else None
        if not inner:
            break
        declarator = inner

    if declarator.type == "function_declarator":
        inner_declarator = declarator.child_by_field_name("declarator")
        if not inner_declarator:
            return "<unknown>"
        if inner_declarator.type == "identifier":
            return inner_declarator.text.decode('utf-8', errors='ignore')
        if inner_declarator.type == "field_expression":
            return extract_qualified_name_from_field(inner_declarator)
        name_source = inner_declarator
    elif declarator.type == "identifier":
        return declarator.text.decode('utf-8', errors='ignore')
    else:
        name_source = declarator

    # Textual fallback: keep what precedes the parameter list and drop any
    # pointer/reference markers left over from an unresolved wrapper.
    raw = name_source.text.decode('utf-8', errors='ignore')
    name = raw.split('(')[0].strip().lstrip('*& \t')
    return name or "<unknown>"
def extract_qualified_name_from_field(field_node):
    """Build a ``::``-joined name from a (possibly nested) field expression.

    Identifiers contribute their text. A ``field_expression`` contributes its
    ``argument`` and ``field`` parts (falling back to the first and last
    child) joined by ``::``. A ``function_declarator`` defers to its inner
    declarator. Any other node contributes the text before its first ``(``.
    """

    def _resolve(current):
        kind = current.type
        if kind == "identifier":
            return current.text.decode('utf-8')

        if kind == "field_expression":
            lhs = current.child_by_field_name("argument") or (current.children[0] if current.children else None)
            rhs = current.child_by_field_name("field") or (current.children[-1] if current.children else None)
            lhs_name = _resolve(lhs) if lhs else ""
            rhs_name = _resolve(rhs) if rhs else ""
            if not (lhs_name and rhs_name):
                # At most one side resolved: return whichever is non-empty.
                return lhs_name or rhs_name
            return f"{lhs_name}::{rhs_name}"

        if kind == "function_declarator":
            nested = current.child_by_field_name("declarator")
            if nested:
                return _resolve(nested)
            return current.text.decode('utf-8')

        decoded = current.text.decode('utf-8')
        return decoded.split('(')[0].strip()

    return _resolve(field_node)
def collect_call_expressions(func_ast_node):
    """Return the distinct callee names found in a function's body.

    The ``body`` field of ``func_ast_node`` is walked depth-first with an
    explicit stack, so deeply nested syntax trees cannot exhaust Python's
    recursion limit. Each ``call_expression`` contributes the simplified name
    of its ``function`` field unless that name is ``"<unknown>"``.

    Args:
        func_ast_node: Node exposing ``child_by_field_name``.

    Returns:
        A list of unique callee names in order of first appearance, or an
        empty list when the node has no body.
    """
    body_node = func_ast_node.child_by_field_name("body")
    if not body_node:
        return []

    # A dict keeps insertion order, giving a deterministic result where a set
    # would yield an arbitrary order.
    seen = {}
    pending = [body_node]
    while pending:
        node = pending.pop()
        if node.type == "call_expression":
            func_node = node.child_by_field_name("function")
            if func_node:
                callee_name = _simplified_extract_callee_name(func_node)
                if callee_name != "<unknown>":
                    seen.setdefault(callee_name, None)
        # Push children reversed so they are popped in source order.
        pending.extend(reversed(node.children))
    return list(seen)
def _simplified_extract_callee_name(func_node):
    """Reduce a call's function node to a bare callee name.

    Template argument lists are removed first, then anything from the first
    ``(`` onward, trailing ``*`` / ``&`` markers, and every qualifier up to
    the last ``::``, ``->`` or ``.``.

    Args:
        func_node: Node whose ``text`` attribute holds the callee expression
            as bytes; may be None.

    Returns:
        The trailing name when it starts with a letter or underscore,
        otherwise ``"<unknown>"``.
    """
    if not func_node:
        return "<unknown>"
    name_part = func_node.text.decode('utf-8', errors='ignore')

    # Strip balanced <...> groups innermost-first. Otherwise a qualified
    # template argument such as ``make_shared<ns::Foo>`` would be cut at the
    # ``::`` inside the brackets and yield ``Foo>``.
    while True:
        without_templates = re.sub(r"<[^<>]*>", "", name_part)
        if without_templates == name_part:
            break
        name_part = without_templates

    name_part = name_part.split('(')[0].strip()
    name_part = name_part.rstrip('*&').strip()
    for sep in ('::', '->', '.'):
        if sep in name_part:
            name_part = name_part.rsplit(sep, 1)[-1]
    if name_part and (name_part[0].isalpha() or name_part[0] == '_'):
        return name_part
    return "<unknown>"
def extract_class_name(class_node):
    """Return the name of a class-like node.

    The ``name`` field is used when it is a ``type_identifier``; otherwise the
    first direct child of that type is used. ``"<unknown>"`` is returned when
    neither exists.
    """
    named = class_node.child_by_field_name("name")
    if named and named.type == "type_identifier":
        return named.text.decode('utf-8')

    fallback = next(
        (kid for kid in class_node.children if kid.type == "type_identifier"),
        None,
    )
    if fallback is None:
        return "<unknown>"
    return fallback.text.decode('utf-8')
def extract_member_function_name(func_node, parent_class_name):
    """Return the name of a member function from its definition node.

    A ``function_declarator`` whose inner declarator is an identifier or a
    field expression is resolved structurally. Everything else falls back to
    the declarator text: the part before the parameter list is split on
    whitespace, its last token is taken, and leading ``*`` / ``&`` markers
    (left by pointer or reference declarators) are removed.

    Args:
        func_node: Node exposing ``child_by_field_name``.
        parent_class_name: Unused; kept for interface compatibility.

    Returns:
        The member function name, or ``"<unknown>"`` when none can be found.
    """
    declarator = func_node.child_by_field_name("declarator")
    if not declarator:
        return "<unknown>"
    if declarator.type == "function_declarator":
        inner_declarator = declarator.child_by_field_name("declarator")
        if not inner_declarator:
            return "<unknown>"
        if inner_declarator.type == "field_expression":
            return extract_qualified_name_from_field(inner_declarator)
        if inner_declarator.type == "identifier":
            return inner_declarator.text.decode('utf-8', errors='ignore')
    raw_text = declarator.text.decode('utf-8', errors='ignore')
    # split() handles any whitespace (tabs, newlines), not just a space.
    tokens = raw_text.split('(')[0].split()
    func_name = tokens[-1].lstrip('*&') if tokens else ""
    # Never hand an empty name to callers; use the module's sentinel instead.
    return func_name or "<unknown>"
def extract_base_classes(base_clause_node):
    """List the base classes described below ``base_clause_node``.

    Every named child of type ``base_class_clause`` that has a ``type`` field
    yields one entry. The entry's name is the type text with any ``::``
    qualification removed; its access is the last ``public`` / ``protected`` /
    ``private`` keyword among the clause's children, defaulting to
    ``public``.

    Returns:
        A list of ``{'name': ..., 'access': ...}`` dictionaries.
    """
    access_levels = ('public', 'protected', 'private')
    collected = []
    for clause in base_clause_node.named_children:
        if clause.type != "base_class_clause":
            continue
        type_node = clause.child_by_field_name("type")
        if not type_node:
            continue

        # Keep only the unqualified trailing component of the type name.
        base_name = type_node.text.decode('utf-8').split('::')[-1]

        access = 'public'
        for part in clause.children:
            if part.type not in ("public", "protected", "private", "virtual"):
                continue
            keyword = part.text.decode('utf-8')
            # ``virtual`` is recognised but does not change the access level.
            if keyword in access_levels:
                access = keyword

        collected.append({'name': base_name, 'access': access})
    return collected