"""C/C++ source-parsing helpers.

Parses with tree-sitter when the optional ``tree_sitter``/``tree_sitter_cpp``
packages are usable, and otherwise with ``RegexParser``, a heuristic fallback
that builds a small tree of ``RegexNode`` objects mimicking the subset of the
tree-sitter node API used by the extraction helpers below.
"""

import os
import re
import warnings
from pathlib import Path

try:
    import tree_sitter
    import tree_sitter_c
    import tree_sitter_cpp
    from tree_sitter import Language, Parser
except ImportError:  # pragma: no cover - depends on deployment environment
    tree_sitter = None
    tree_sitter_c = None
    tree_sitter_cpp = None
    Language = None
    Parser = None

# Constants
CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'}
IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external', 'Debug'}


def _load_cpp_language():
    """Return the tree-sitter C++ ``Language``, or ``None`` when unusable.

    ``None`` is returned when the optional packages are missing, or when the
    installed binding rejects the grammar handle (e.g. an incompatible
    py-tree-sitter version raising ``TypeError``/``ValueError``). The module
    then falls back to ``RegexParser`` instead of failing at import time.
    """
    if Language is None or tree_sitter_cpp is None:
        return None
    try:
        return Language(tree_sitter_cpp.language())
    except (TypeError, ValueError) as exc:  # pragma: no cover - version-dependent
        warnings.warn(
            f"tree-sitter C++ grammar could not be loaded ({exc}); "
            "falling back to the regex parser",
            RuntimeWarning,
        )
        return None


CPP_LANGUAGE = _load_cpp_language()


class RegexNode:
    """Minimal stand-in for a tree-sitter ``Node``, produced by ``RegexParser``.

    Exposes only what this module reads: ``type``, ``text`` (bytes, like
    tree-sitter), ``start_point``/``end_point`` as ``(row, 0)`` tuples,
    ``children``, ``named_children`` and ``child_by_field_name``.
    """

    def __init__(self, node_type, text="", start_line=0, end_line=0, fields=None, children=None):
        self.type = node_type
        # tree-sitter exposes node text as bytes, so mirror that here.
        self.text = text.encode("utf-8", errors="ignore") if isinstance(text, str) else text
        self.start_point = (start_line, 0)
        self.end_point = (end_line, 0)
        self._fields = fields or {}
        self.children = children or []
        # Every synthesized node counts as "named", so both views share one list.
        self.named_children = self.children

    def child_by_field_name(self, name):
        """Return the child registered under field *name*, or ``None``."""
        return self._fields.get(name)

    def __repr__(self):
        return f"{type(self).__name__}({self.type!r}, lines {self.start_point[0]}-{self.end_point[0]})"


class RegexTree:
    """Minimal stand-in for a tree-sitter ``Tree``: only ``root_node``."""

    def __init__(self, root_node):
        self.root_node = root_node


class RegexParser:
    """Heuristic C/C++ parser used when tree-sitter is unavailable.

    Recognises ``#include`` directives, function definitions (top level and
    in class bodies, but not nested inside other function bodies) and
    ``name(`` call sites inside function bodies. Braces inside string/char
    literals and comments are ignored when matching bodies; raw string
    literals (``R"(...)"``) are not understood.
    """

    _signature_re = re.compile(
        r"(?m)^[ \t]*"
        # Optional return type / specifiers on one line, ending either in
        # blanks/'*'/'&' before the name or in a single line break (GNU style
        # "int\nfoo(void)"). A prefix ending in ':' (e.g. "public:") may not
        # run onto the next line.
        r"(?:[A-Za-z_][\w \t\*\&:<>,]*?(?:[ \t\*\&]+|(?<![:\s])[ \t]*\r?\n[ \t]*))?"
        # Name, optionally scope-qualified (any depth) and/or a destructor.
        r"(~?[A-Za-z_]\w*(?:::~?[A-Za-z_]\w*)*)"
        # Parameter list; a constructor initializer list is absorbed here too.
        r"\s*\([^;{}]*\)"
        # Trailing qualifiers such as "const", "noexcept" or "override".
        r"(?:\s*(?:const|volatile|noexcept|override|final|&&|&))*"
        r"\s*\{"
    )
    _call_re = re.compile(r"\b([A-Za-z_]\w*)\s*\(")
    # Group 1: the directive text from '#'; group 2: the <...> or "..." path.
    _include_re = re.compile(r'(?m)^[ \t]*(#[ \t]*include[ \t]*([<"][^>"\n]+[>"]))')
    # Keywords that look like "name(" but are never function or callee names.
    _control_words = frozenset({
        "if", "for", "while", "switch", "return", "sizeof",
        "catch", "throw", "alignof", "decltype", "noexcept", "static_assert", "typeid",
    })

    def parse(self, source_bytes):
        """Parse *source_bytes* into a ``RegexTree`` (includes first, then functions)."""
        source_text = safe_decode(source_bytes)
        children = self._parse_includes(source_text)
        children.extend(self._parse_functions(source_text))
        return RegexTree(RegexNode("translation_unit", source_text, 0, source_text.count("\n"), children=children))

    def _parse_includes(self, source_text):
        """Return one ``preproc_include`` node per ``#include`` directive."""
        nodes = []
        for match in self._include_re.finditer(source_text):
            # The pattern cannot cross line breaks, so match.start() is on the directive's line.
            line = source_text.count("\n", 0, match.start())
            path_node = RegexNode("string_literal", match.group(2), line, line)
            nodes.append(
                RegexNode(
                    "preproc_include",
                    match.group(1),
                    line,
                    line,
                    fields={"path": path_node},
                    children=[path_node],
                )
            )
        return nodes

    def _parse_functions(self, source_text):
        """Return ``function_definition`` nodes for every matched definition."""
        nodes = []
        # End of the last accepted body: matches inside it (control blocks such
        # as "catch (...) {", macros, ...) are not separate definitions.
        resume_at = 0
        for match in self._signature_re.finditer(source_text):
            if match.start() < resume_at:
                continue
            if match.group(1) in self._control_words:
                continue
            body_start = match.end() - 1  # the pattern ends with the opening '{'
            body_end = self._find_matching_brace(source_text, body_start)
            if body_end <= body_start:
                continue
            resume_at = body_end + 1
            nodes.append(self._build_function_node(source_text, match, body_start, body_end))
        return nodes

    def _build_function_node(self, source_text, match, body_start, body_end):
        """Build the ``function_definition`` subtree for one signature match."""
        start_line = source_text.count("\n", 0, match.start())
        name_line = source_text.count("\n", 0, match.start(1))
        body_line = source_text.count("\n", 0, body_start)
        end_line = source_text.count("\n", 0, body_end)
        header_text = source_text[match.start():body_start]
        header_end_line = start_line + header_text.rstrip().count("\n")
        body_text = source_text[body_start:body_end + 1]

        identifier = RegexNode("identifier", match.group(1), name_line, name_line)
        declarator = RegexNode(
            "function_declarator",
            header_text,
            start_line,
            header_end_line,
            fields={"declarator": identifier},
            children=[identifier],
        )
        body = RegexNode(
            "compound_statement",
            body_text,
            body_line,
            end_line,
            # Call lines are relative to the line holding the body's '{'.
            children=self._parse_calls(body_text, body_line),
        )
        return RegexNode(
            "function_definition",
            source_text[match.start():body_end + 1],
            start_line,
            end_line,
            fields={"declarator": declarator, "body": body},
            children=[declarator, body],
        )

    def _find_matching_brace(self, source_text, start_index):
        """Return the index of the '}' closing the '{' at *start_index*, or -1.

        Braces inside string/char literals and ``//`` / ``/* */`` comments are
        ignored.
        """
        depth = 0
        index = start_index
        length = len(source_text)
        while index < length:
            char = source_text[index]
            if char == '"' or (char == "'" and not self._is_digit_separator(source_text, index)):
                index = self._skip_quoted(source_text, index)
                continue
            if source_text.startswith("//", index):
                newline = source_text.find("\n", index)
                index = length if newline == -1 else newline
                continue
            if source_text.startswith("/*", index):
                close = source_text.find("*/", index + 2)
                index = length if close == -1 else close + 2
                continue
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return -1

    @staticmethod
    def _skip_quoted(source_text, index):
        """Return the index just past the literal opened by the quote at *index*.

        Backslash escapes are honoured; an unterminated literal ends at the
        end of its line.
        """
        quote = source_text[index]
        index += 1
        length = len(source_text)
        while index < length:
            char = source_text[index]
            if char == "\\":
                index += 2
            elif char == quote or char == "\n":
                return index + 1
            else:
                index += 1
        return length

    @staticmethod
    def _is_digit_separator(source_text, index):
        """Return True if the quote at *index* is a C++14 digit separator (``1'000``).

        It is one when the token it sits in starts with a digit; prefixed
        character literals such as ``u8'a'`` or ``L'a'`` start with a letter.
        """
        start = index
        while start > 0 and (source_text[start - 1].isalnum() or source_text[start - 1] in "_'."):
            start -= 1
        return start < index and source_text[start].isdigit()

    def _parse_calls(self, body_text, start_line):
        """Return ``call_expression`` nodes for ``name(`` sites in *body_text*.

        *start_line* is the source line on which *body_text* begins.
        """
        calls = []
        for match in self._call_re.finditer(body_text):
            name = match.group(1)
            if name in self._control_words:
                continue
            line = start_line + body_text.count("\n", 0, match.start())
            function_node = RegexNode("identifier", name, line, line)
            calls.append(
                RegexNode(
                    "call_expression",
                    match.group(0),
                    line,
                    line,
                    fields={"function": function_node},
                    children=[function_node],
                )
            )
        return calls


# Global parser: tree-sitter when the C++ grammar loaded, otherwise the regex fallback.
if Parser is not None and CPP_LANGUAGE is not None:
    parser = Parser()
    parser.language = CPP_LANGUAGE
else:
    parser = RegexParser()

# Global graph data structures
FUNCTION_CALL_GRAPH = {}
FILE_DEPENDENCIES = {}
CALLED_BY_GRAPH = {}


def safe_decode(b: bytes) -> str:
    """Decode *b* trying UTF-8, then GBK/GB2312, then Latin-1.

    Latin-1 maps every byte, so the loop always returns by then; the final
    ``errors='replace'`` decode is only a defensive fallback. ``str`` input
    is returned unchanged.
    """
    if isinstance(b, str):
        return b
    for encoding in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
        try:
            return b.decode(encoding)
        except UnicodeDecodeError:
            continue
    return b.decode('utf-8', errors='replace')


def _node_text(node):
    """Return *node*'s source text as ``str`` ("" for a missing node or text)."""
    if node is None or node.text is None:
        return ""
    return safe_decode(node.text)


def extract_comment_before(node, source_lines, max_lines=10):
    """Return the comment block directly above *node*, or "无注释" if none.

    Scans upward through at most *max_lines* lines of *source_lines* (a list
    of ``bytes`` lines). Blank lines and ``template ...`` lines are skipped.
    The scan stops at the first other non-comment line, so comments that
    belong to earlier code are not picked up. ``//`` lines and ``/* ... */``
    blocks, including their middle lines, are collected and joined with
    spaces in source order.
    """
    start_line = node.start_point[0]
    lowest = max(start_line - max_lines, 0)
    comment_lines = []
    in_block_comment = False  # scanning upward inside a "/* ... */" block
    for i in range(min(start_line, len(source_lines)) - 1, lowest - 1, -1):
        line = safe_decode(source_lines[i]).strip()
        if in_block_comment:
            if line:
                comment_lines.append(line)
            if "/*" in line:
                in_block_comment = False
            continue
        if not line or line.startswith("template"):
            continue
        if line.startswith("//"):
            comment_lines.append(line)
        elif line.endswith("*/"):
            comment_lines.append(line)
            # A one-line "/* ... */" is complete; otherwise keep reading upward.
            in_block_comment = "/*" not in line
        else:
            break
    return " ".join(reversed(comment_lines)) or "无注释"


def extract_code_snippet(tree, node, source_lines):
    """Return the decoded source lines spanned by *node*, joined with "\\n".

    *tree* is unused and kept for interface compatibility; lines beyond the
    end of *source_lines* are ignored.
    """
    first, last = node.start_point[0], node.end_point[0]
    return "\n".join(safe_decode(line) for line in source_lines[first:last + 1])


# Declarator wrappers around the function declarator of e.g. "int *f()" / "T &f()".
_WRAPPER_DECLARATOR_TYPES = frozenset({"pointer_declarator", "reference_declarator", "parenthesized_declarator"})


def _unwrap_declarator(declarator):
    """Strip pointer/reference/parenthesized declarator wrappers from *declarator*."""
    while declarator is not None and declarator.type in _WRAPPER_DECLARATOR_TYPES:
        inner = declarator.child_by_field_name("declarator")
        if inner is None:
            named = declarator.named_children
            inner = named[-1] if named else None
        declarator = inner
    return declarator


def extract_full_function_name(func_def_node):
    """Return the (possibly qualified) name of a ``function_definition`` node, or ""."""
    declarator = _unwrap_declarator(func_def_node.child_by_field_name("declarator"))
    if declarator is None:
        return ""
    if declarator.type == "function_declarator":
        inner_declarator = declarator.child_by_field_name("declarator")
        if inner_declarator is None:
            return ""
        if inner_declarator.type == "identifier":
            return _node_text(inner_declarator)
        if inner_declarator.type == "field_expression":
            return extract_qualified_name_from_field(inner_declarator)
        return _node_text(inner_declarator).split('(')[0].strip()
    if declarator.type == "identifier":
        return _node_text(declarator)
    return _node_text(declarator).split('(')[0].strip()


def extract_qualified_name_from_field(field_node):
    """Flatten a ``field_expression`` (``a.b`` / ``a->b``) into ``"a::b"``.

    Nested field expressions are flattened recursively; ``function_declarator``
    nodes are unwrapped to their inner declarator, and any other node
    contributes its text up to the first ``(``.
    """
    def _extract(n):
        if n.type == "identifier":
            return _node_text(n)
        if n.type == "field_expression":
            left = n.child_by_field_name("argument") or (n.children[0] if n.children else None)
            right = n.child_by_field_name("field") or (n.children[-1] if n.children else None)
            left_str = _extract(left) if left else ""
            right_str = _extract(right) if right else ""
            if left_str and right_str:
                return f"{left_str}::{right_str}"
            return left_str or right_str
        if n.type == "function_declarator":
            inner = n.child_by_field_name("declarator")
            return _extract(inner) if inner else _node_text(n)
        return _node_text(n).split('(')[0].strip()

    return _extract(field_node)


def collect_call_expressions(func_ast_node):
    """Return the distinct callee names in the body of *func_ast_node*.

    Names are listed in first-seen source order. The walk is iterative, so
    deeply nested ASTs cannot exhaust the recursion limit.
    """
    body_node = func_ast_node.child_by_field_name("body")
    if body_node is None:
        return []
    callees = {}  # insertion-ordered set
    stack = [body_node]
    while stack:
        node = stack.pop()
        if node.type == "call_expression":
            callee_name = _simplified_extract_callee_name(node.child_by_field_name("function"))
            if callee_name:
                callees[callee_name] = None
        # Reversed so the children are popped in source order.
        stack.extend(reversed(node.children))
    return list(callees)


_IDENTIFIER_RE = re.compile(r"[^\W\d]\w*")


def _strip_template_args(text):
    """Remove balanced ``<...>`` groups; a ``>`` outside any group (as in ``->``) is kept."""
    kept = []
    depth = 0
    for char in text:
        if char == '<':
            depth += 1
        elif char == '>' and depth:
            depth -= 1
        elif not depth:
            kept.append(char)
    return ''.join(kept)


def _simplified_extract_callee_name(func_node):
    """Return the bare callee identifier of a call's ``function`` node, or "".

    Template arguments are dropped and only the part after the last ``::``,
    ``->`` or ``.`` is kept (``ns::foo<T>`` -> ``foo``, ``obj->run`` -> ``run``).
    Anything that is not a plain identifier (``funcs[i]``, ``(*fp)``) yields "".
    """
    if func_node is None:
        return ""
    raw_text = (func_node.text or b"").decode('utf-8', errors='ignore')
    name_part = _strip_template_args(raw_text.split('(')[0]).strip()
    name_part = name_part.rstrip('*&')
    for sep in ('::', '->', '.'):
        if sep in name_part:
            name_part = name_part.rsplit(sep, 1)[-1]
    name_part = name_part.strip()
    return name_part if _IDENTIFIER_RE.fullmatch(name_part) else ""


def extract_class_name(class_node):
    """Return the ``type_identifier`` name of a class/struct node, or ""."""
    name_node = class_node.child_by_field_name("name")
    if name_node and name_node.type == "type_identifier":
        return _node_text(name_node)
    for child in class_node.children:
        if child.type == "type_identifier":
            return _node_text(child)
    return ""


def extract_member_function_name(func_node, parent_class_name):
    """Return the name of a member-function definition node, or "".

    *parent_class_name* is not used and is kept for interface compatibility.
    When the inner declarator is neither an identifier nor a field
    expression, the declarator text up to the first ``(`` is used, keeping
    only its last whitespace-separated word.
    """
    declarator = _unwrap_declarator(func_node.child_by_field_name("declarator"))
    if declarator is None:
        return ""
    if declarator.type == "function_declarator":
        inner_declarator = declarator.child_by_field_name("declarator")
        if inner_declarator is None:
            return ""
        if inner_declarator.type == "field_expression":
            return extract_qualified_name_from_field(inner_declarator)
        if inner_declarator.type == "identifier":
            return _node_text(inner_declarator)
    func_name = _node_text(declarator).split('(')[0].strip()
    return func_name.split()[-1] if ' ' in func_name else func_name


_ACCESS_KEYWORDS = ('public', 'protected', 'private')
# Node types that may carry an access keyword ("virtual" is recognised but ignored).
_ACCESS_NODE_TYPES = frozenset({'access_specifier', 'public', 'protected', 'private', 'virtual'})
# Node types naming a base class inside a tree-sitter-cpp base_class_clause.
_BASE_TYPE_NODE_TYPES = frozenset({'type_identifier', 'qualified_type_identifier', 'template_type'})


def _last_scope_segment(name):
    """Return the part of *name* after its last ``::`` outside ``<...>``."""
    depth = 0
    cut = 0
    index = 0
    while index < len(name):
        char = name[index]
        if char == '<':
            depth += 1
        elif char == '>' and depth:
            depth -= 1
        elif not depth and name.startswith('::', index):
            cut = index + 2
            index += 2
            continue
        index += 1
    return name[cut:]


def _bases_from_clause(clause):
    """Return ``{'name', 'access'}`` dicts for every base listed in *clause*.

    Supports clauses exposing a single ``type`` field, as well as the
    tree-sitter-cpp layout, where access specifiers and base types are plain
    children separated by ``,``. Access defaults to ``'public'``.
    """
    type_node = clause.child_by_field_name("type")
    if type_node is not None:
        access = 'public'
        for child in clause.children:
            if child.type in _ACCESS_NODE_TYPES:
                keyword = _node_text(child).strip()
                if keyword in _ACCESS_KEYWORDS:
                    access = keyword
        return [{'name': _last_scope_segment(_node_text(type_node)), 'access': access}]

    bases = []
    access = 'public'
    for child in clause.children:
        if child.type == ',':
            access = 'public'  # each base has its own specifier
        elif child.type in _ACCESS_NODE_TYPES:
            keyword = _node_text(child).strip()
            if keyword in _ACCESS_KEYWORDS:
                access = keyword
        elif child.type in _BASE_TYPE_NODE_TYPES:
            bases.append({'name': _last_scope_segment(_node_text(child)), 'access': access})
    return bases


def extract_base_classes(base_clause_node):
    """Return base classes as ``[{'name': ..., 'access': ...}, ...]``.

    *base_clause_node* may be a class node whose named children include
    ``base_class_clause`` nodes, or a ``base_class_clause`` itself. Names
    lose their namespace qualification, but template arguments are kept
    (``ns::Base<T>`` -> ``Base<T>``).
    """
    if base_clause_node.type == "base_class_clause":
        clauses = [base_clause_node]
    else:
        clauses = [child for child in base_clause_node.named_children if child.type == "base_class_clause"]
    base_classes = []
    for clause in clauses:
        base_classes.extend(_bases_from_clause(clause))
    return base_classes