增加代码知识库;修复文档处理内容;增加API设置
This commit is contained in:
300
RAG-TEST-TOOLS/code_parser.py
Normal file
300
RAG-TEST-TOOLS/code_parser.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
try:
|
||||
import tree_sitter
|
||||
import tree_sitter_c
|
||||
import tree_sitter_cpp
|
||||
from tree_sitter import Language, Parser
|
||||
except ImportError: # pragma: no cover - depends on deployment environment
|
||||
tree_sitter = None
|
||||
tree_sitter_c = None
|
||||
tree_sitter_cpp = None
|
||||
Language = None
|
||||
Parser = None
|
||||
|
||||
# Constant definitions
# File suffixes recognised as C/C++ sources.
CPP_EXTENSIONS = {'.c', '.cpp', '.h', '.hpp', '.tcc'}
# Directory names to ignore (presumably skipped while walking a source tree;
# the walker is not visible here — confirm against the caller).
IGNORE_DIRS = {'build', 'cmake-build', '.git', 'vendor', 'lib', 'external', 'Debug'}
# tree-sitter C++ grammar handle; stays None when the optional tree-sitter
# imports at the top of the file failed.
if Language and tree_sitter_cpp:
    CPP_LANGUAGE = Language(tree_sitter_cpp.language())
else:
    CPP_LANGUAGE = None
|
||||
|
||||
class RegexNode:
    """Minimal node object produced by the regex fallback parser.

    It exposes the attributes the rest of this module reads from a parsed
    node: ``type``, ``text`` (always bytes when a str was supplied),
    ``start_point`` / ``end_point`` as ``(line, 0)`` tuples, ``children``,
    ``named_children`` and ``child_by_field_name``.
    """

    def __init__(self, node_type, text="", start_line=0, end_line=0, fields=None, children=None):
        # Store text as bytes so callers can uniformly call ``.decode``.
        if isinstance(text, str):
            text = text.encode("utf-8", errors="ignore")
        child_nodes = children if children else []
        field_map = fields if fields else {}

        self.type = node_type
        self.text = text
        # Column information is not tracked; it is always reported as 0.
        self.start_point = (start_line, 0)
        self.end_point = (end_line, 0)
        self._fields = field_map
        self.children = child_nodes
        # Every child is treated as "named"; both attributes share one list.
        self.named_children = child_nodes

    def child_by_field_name(self, name):
        """Return the child registered under ``name``, or None if absent."""
        return self._fields.get(name)
|
||||
|
||||
|
||||
class RegexTree:
    """Container for a parse result; holds the top-level node as ``root_node``."""

    def __init__(self, root_node):
        # Kept as a plain attribute so callers can read ``tree.root_node``.
        self.root_node = root_node
|
||||
|
||||
|
||||
class RegexParser:
    """Regex-based fallback parser used when tree-sitter is unavailable.

    ``parse`` returns a ``RegexTree`` whose root has one ``preproc_include``
    node per ``#include`` line and one ``function_definition`` node per
    detected function body. The detection is heuristic: it is driven by
    regular expressions and brace matching, not by a real C/C++ grammar.

    Known limitations: raw string literals (``R"(...)"``) are not understood
    by the brace matcher, and call detection does not skip comments or
    string literals inside a function body.
    """

    # ``#include`` on a single line; horizontal whitespace only, so a match
    # always starts on the line that holds the directive.
    _include_re = re.compile(r"(?m)^[ \t]*#[ \t]*include[ \t]*([<\"][^>\"\n]+[>\"])")
    # Optional return-type prefix, a (possibly ``A::b`` qualified) name, a
    # parameter list, then the opening brace of the body.
    _signature_re = re.compile(
        r"(?m)^[ \t]*(?:[A-Za-z_][\w\s\*\&:<>]*?[ \t]+)?([A-Za-z_]\w*(?:::[A-Za-z_]\w*)?)"
        r"\s*\([^;{}]*\)\s*\{"
    )
    _call_re = re.compile(r"\b([A-Za-z_]\w*)\s*\(")
    # Keywords that look like ``name(`` but are neither functions nor calls.
    _control_words = {"if", "for", "while", "switch", "return", "sizeof", "catch"}

    def parse(self, source_bytes):
        """Parse raw source bytes and return a ``RegexTree``."""
        source_text = safe_decode(source_bytes)
        children = self._parse_includes(source_text)
        children.extend(self._parse_functions(source_text))
        return RegexTree(RegexNode("translation_unit", source_text, 0, source_text.count("\n"), children=children))

    def _parse_includes(self, source_text):
        """Return one ``preproc_include`` node per ``#include`` directive."""
        nodes = []
        for match in self._include_re.finditer(source_text):
            line = source_text.count("\n", 0, match.start())
            path_node = RegexNode("string_literal", match.group(1), line, line)
            nodes.append(
                RegexNode(
                    "preproc_include",
                    match.group(0),
                    line,
                    line,
                    fields={"path": path_node},
                    children=[path_node],
                )
            )
        return nodes

    def _parse_functions(self, source_text):
        """Return one ``function_definition`` node per detected function."""
        nodes = []
        # End offset of the last accepted body; signature-like text inside it
        # (e.g. a call taking a lambda) is nested code, not a new function.
        consumed_until = -1
        for match in self._signature_re.finditer(source_text):
            if match.start() < consumed_until:
                continue
            name = match.group(1)
            if name in self._control_words:
                continue
            body_start = match.end() - 1  # index of the opening "{"
            body_end = self._find_matching_brace(source_text, body_start)
            if body_end <= body_start:
                continue
            consumed_until = body_end

            start_line = source_text.count("\n", 0, match.start())
            name_line = source_text.count("\n", 0, match.start(1))
            # The brace may sit on a later line than the signature start, so
            # body-relative line numbers must be anchored at the brace itself.
            body_line = source_text.count("\n", 0, body_start)
            end_line = source_text.count("\n", 0, body_end)
            function_text = source_text[match.start() : body_end + 1]
            body_text = source_text[body_start : body_end + 1]

            identifier = RegexNode("identifier", name, name_line, name_line)
            declarator = RegexNode(
                "function_declarator",
                source_text[match.start() : body_start],
                start_line,
                body_line,
                fields={"declarator": identifier},
                children=[identifier],
            )
            body = RegexNode(
                "compound_statement",
                body_text,
                body_line,
                end_line,
                children=self._parse_calls(body_text, body_line),
            )
            nodes.append(
                RegexNode(
                    "function_definition",
                    function_text,
                    start_line,
                    end_line,
                    fields={"declarator": declarator, "body": body},
                    children=[declarator, body],
                )
            )
        return nodes

    @staticmethod
    def _skip_literal(source_text, quote_index):
        """Return the index just past the quoted literal opening at ``quote_index``.

        If the literal is not closed on the same logical line (for example a
        C++14 digit separator such as ``1'000``), only the quote character
        itself is skipped so normal scanning resumes right after it.
        """
        quote = source_text[quote_index]
        index = quote_index + 1
        length = len(source_text)
        while index < length:
            char = source_text[index]
            if char == "\\":
                index += 2  # skip the escaped character as well
                continue
            if char == quote:
                return index + 1
            if char == "\n":
                break
            index += 1
        return quote_index + 1

    def _find_matching_brace(self, source_text, start_index):
        """Return the index of the ``}`` closing the ``{`` at ``start_index``.

        Braces inside ``//`` comments, ``/* */`` comments and quoted literals
        are ignored. Returns -1 when no matching brace exists.
        """
        depth = 0
        index = start_index
        length = len(source_text)
        while index < length:
            char = source_text[index]
            if char == "/":
                pair = source_text[index : index + 2]
                if pair == "//":
                    newline = source_text.find("\n", index)
                    if newline == -1:
                        break
                    index = newline + 1
                    continue
                if pair == "/*":
                    close = source_text.find("*/", index + 2)
                    if close == -1:
                        break
                    index = close + 2
                    continue
            elif char == '"' or char == "'":
                index = self._skip_literal(source_text, index)
                continue
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return -1

    def _parse_calls(self, body_text, start_line):
        """Return ``call_expression`` nodes for every ``name(`` in ``body_text``.

        ``start_line`` is the source line on which ``body_text`` begins; it is
        added to the line offset of each call inside the body.
        """
        calls = []
        for match in self._call_re.finditer(body_text):
            name = match.group(1)
            if name in self._control_words:
                continue
            line = start_line + body_text.count("\n", 0, match.start())
            function_node = RegexNode("identifier", name, line, line)
            calls.append(
                RegexNode(
                    "call_expression",
                    match.group(0),
                    line,
                    line,
                    fields={"function": function_node},
                    children=[function_node],
                )
            )
        return calls
|
||||
|
||||
|
||||
# Global parser: the regex fallback is used unless both the tree-sitter
# Parser class and the C++ grammar were loaded successfully.
if not (Parser and CPP_LANGUAGE):
    parser = RegexParser()
else:
    parser = Parser()
    parser.language = CPP_LANGUAGE
|
||||
|
||||
# Global graph data structures (all start empty).
# NOTE(review): nothing visible in this part of the file fills these dicts;
# the per-line descriptions are inferred from the names — confirm against
# the code that populates them.
FUNCTION_CALL_GRAPH = {}  # presumably caller -> callees
FILE_DEPENDENCIES = {}  # presumably file -> files it depends on
CALLED_BY_GRAPH = {}  # presumably callee -> callers
|
||||
|
||||
def safe_decode(b: bytes) -> str:
    """Decode ``b`` with the first codec that accepts it.

    Codecs are tried in order: utf-8, gbk, gb2312, latin-1. If none of them
    succeeds, the bytes are decoded as utf-8 with undecodable bytes replaced.
    """
    for codec in ('utf-8', 'gbk', 'gb2312', 'latin-1'):
        try:
            decoded = b.decode(codec)
        except UnicodeDecodeError:
            # This codec rejected the data; fall through to the next one.
            pass
        else:
            return decoded
    return b.decode('utf-8', errors='replace')
|
||||
|
||||
def extract_comment_before(node, source_lines):
    """Collect comment lines found directly above ``node``.

    Looks at up to nine lines preceding the node's start line and keeps
    those that start with ``//`` or ``/*`` (after stripping whitespace),
    returned in source order and joined by single spaces.

    Args:
        node: object whose ``start_point[0]`` is the node's 0-based start line.
        source_lines: sequence of byte strings, one per source line.

    Returns:
        The joined comment text, or "无注释" when no comment line was found.
    """
    start_line = node.start_point[0]
    # Never index past the end of the file if the node's line is out of range.
    first = min(start_line, len(source_lines)) - 1
    # ``range`` excludes its stop value, so the stop must be -1 (not 0) for
    # line 0 to be inspected when the node sits near the top of the file.
    stop = max(-1, start_line - 10)
    comment_lines = []
    for i in range(first, stop, -1):
        line = safe_decode(source_lines[i]).strip()
        if line.startswith(('//', '/*')):
            comment_lines.append(line)
    # Lines were gathered bottom-up; restore top-down order.
    return " ".join(comment_lines[::-1]) or "无注释"
|
||||
|
||||
def extract_code_snippet(tree, node, source_lines):
    """Return the source text covered by ``node``.

    Lines ``node.start_point[0]`` through ``node.end_point[0]`` (inclusive)
    are decoded and joined with newlines; line numbers beyond the end of
    ``source_lines`` are skipped. ``tree`` is accepted but not used.
    """
    first = node.start_point[0]
    last = node.end_point[0]
    total = len(source_lines)
    return "\n".join(
        safe_decode(source_lines[line_no])
        for line_no in range(first, last + 1)
        if line_no < total
    )
|
||||
|
||||
def extract_full_function_name(func_def_node):
    """Return the name of the function defined by ``func_def_node``.

    The node's ``declarator`` field is inspected. For a
    ``function_declarator`` the inner declarator supplies the name (an
    identifier directly, a ``field_expression`` via
    ``extract_qualified_name_from_field``). Any other shape falls back to
    the raw declarator text before the first ``(``.

    Returns:
        The function name, or "<unknown>" when no name can be determined.
    """
    declarator = func_def_node.child_by_field_name("declarator")
    if not declarator:
        return "<unknown>"

    target = declarator
    if declarator.type == "function_declarator":
        target = declarator.child_by_field_name("declarator")
        if not target:
            return "<unknown>"
        if target.type == "field_expression":
            return extract_qualified_name_from_field(target)

    if target.type == "identifier":
        return target.text.decode('utf-8')

    # Fallback on raw text. Pointer/reference-returning functions wrap the
    # declarator, so the text starts with "*" or "&" (e.g. "*foo(void)");
    # those markers belong to the return type, not the name.
    raw = target.text.decode('utf-8')
    name = raw.split('(')[0].strip().lstrip('*& \t\r\n')
    return name or "<unknown>"
|
||||
|
||||
def extract_qualified_name_from_field(field_node):
    """Build a ``left::right`` style name from a (possibly nested) node.

    - ``identifier``: its text.
    - ``field_expression``: both sides are resolved recursively (fields
      ``argument`` / ``field``, falling back to the first / last child) and
      joined with ``::``; if one side is empty the other is returned alone.
    - ``function_declarator``: the name of its inner ``declarator``, or the
      node's full text when there is none.
    - anything else: the node text before the first ``(``, stripped.
    """
    kind = field_node.type

    if kind == "identifier":
        return field_node.text.decode('utf-8')

    if kind == "function_declarator":
        inner = field_node.child_by_field_name("declarator")
        if inner:
            return extract_qualified_name_from_field(inner)
        return field_node.text.decode('utf-8')

    if kind != "field_expression":
        return field_node.text.decode('utf-8').split('(')[0].strip()

    # field_expression: prefer the named fields, else use positional children.
    left = field_node.child_by_field_name("argument") or (
        field_node.children[0] if field_node.children else None
    )
    right = field_node.child_by_field_name("field") or (
        field_node.children[-1] if field_node.children else None
    )
    left_str = extract_qualified_name_from_field(left) if left else ""
    right_str = extract_qualified_name_from_field(right) if right else ""
    if left_str and right_str:
        return f"{left_str}::{right_str}"
    return left_str or right_str
|
||||
|
||||
def collect_call_expressions(func_ast_node):
    """Return the distinct callee names invoked inside a function body.

    Every ``call_expression`` under the node's ``body`` field is visited and
    its ``function`` child is reduced to a simple name with
    ``_simplified_extract_callee_name``; names that resolve to "<unknown>"
    are dropped.

    Returns:
        A sorted list of unique callee names (empty when there is no body).
    """
    body_node = func_ast_node.child_by_field_name("body")
    if not body_node:
        return []

    calls = set()
    # Explicit stack instead of recursion: deeply nested expressions would
    # otherwise exceed Python's recursion limit.
    pending = [body_node]
    while pending:
        node = pending.pop()
        if node.type == "call_expression":
            func_node = node.child_by_field_name("function")
            if func_node:
                callee_name = _simplified_extract_callee_name(func_node)
                if callee_name != "<unknown>":
                    calls.add(callee_name)
        pending.extend(node.children)
    # Sorted so the result does not depend on set iteration order.
    return sorted(calls)
|
||||
|
||||
def _simplified_extract_callee_name(func_node):
|
||||
if not func_node:
|
||||
return "<unknown>"
|
||||
raw_text = func_node.text.decode('utf-8', errors='ignore')
|
||||
name_part = raw_text.split('(')[0].strip()
|
||||
name_part = name_part.rstrip('*&')
|
||||
for sep in ['::', '->', '.']:
|
||||
if sep in name_part:
|
||||
parts = name_part.rsplit(sep, 1)
|
||||
name_part = parts[-1] if parts else name_part
|
||||
if name_part and (name_part[0].isalpha() or name_part[0] == '_'):
|
||||
return name_part
|
||||
return "<unknown>"
|
||||
|
||||
def extract_class_name(class_node):
    """Return the class name found on ``class_node``.

    The ``name`` field is used when it is a ``type_identifier``; otherwise
    the first direct child of type ``type_identifier`` is used. Returns
    "<unknown>" when neither exists.
    """
    name_node = class_node.child_by_field_name("name")
    if not (name_node and name_node.type == "type_identifier"):
        # No usable ``name`` field: scan the direct children instead.
        name_node = next(
            (child for child in class_node.children if child.type == "type_identifier"),
            None,
        )
        if name_node is None:
            return "<unknown>"
    return name_node.text.decode('utf-8')
|
||||
|
||||
def extract_member_function_name(func_node, parent_class_name):
    """Return the name of a member function definition.

    For a ``function_declarator`` whose inner declarator is an identifier or
    a ``field_expression`` the name comes from that node. Every other shape
    falls back to the raw declarator text: the last whitespace-separated
    token before the first ``(``.

    Args:
        func_node: node exposing ``child_by_field_name("declarator")``.
        parent_class_name: accepted for interface compatibility; not used.

    Returns:
        The function name, or "<unknown>" when no name can be determined.
    """
    declarator = func_node.child_by_field_name("declarator")
    if not declarator:
        return "<unknown>"

    if declarator.type == "function_declarator":
        inner_declarator = declarator.child_by_field_name("declarator")
        if not inner_declarator:
            return "<unknown>"
        if inner_declarator.type == "field_expression":
            return extract_qualified_name_from_field(inner_declarator)
        if inner_declarator.type == "identifier":
            return inner_declarator.text.decode('utf-8')

    raw_text = declarator.text.decode('utf-8', errors='ignore')
    # split() handles tabs/newlines as well as spaces between tokens.
    tokens = raw_text.split('(')[0].split()
    # A leading "*"/"&" comes from a pointer/reference return type wrapping
    # the declarator (e.g. "*data()"); it is not part of the name.
    func_name = tokens[-1].lstrip('*&') if tokens else ""
    return func_name or "<unknown>"
|
||||
|
||||
def extract_base_classes(base_clause_node):
    """List the base classes described under ``base_clause_node``.

    Each named child of type ``base_class_clause`` that has a ``type`` field
    contributes one ``{'name', 'access'}`` dict. The name is the type text
    with any ``ns::`` qualification removed. ``access`` is the last
    public/protected/private keyword among the clause's children and
    defaults to 'public' when none is present.
    """
    base_classes = []
    for clause in base_clause_node.named_children:
        if clause.type != "base_class_clause":
            continue
        type_node = clause.child_by_field_name("type")
        if not type_node:
            continue

        # Keep only the part after the last "::" (whole text if unqualified).
        base_name = type_node.text.decode('utf-8').rpartition('::')[2]

        access = 'public'
        for part in clause.children:
            if part.type not in ("public", "protected", "private", "virtual"):
                continue
            access_keyword = part.text.decode('utf-8')
            # "virtual" is scanned but is not an access level.
            if access_keyword in ('public', 'protected', 'private'):
                access = access_keyword

        base_classes.append({'name': base_name, 'access': access})
    return base_classes
|
||||
Reference in New Issue
Block a user