Files
rag_agent/rag-web-ui/backend/app/services/document_processor.py

628 lines
23 KiB
Python
Raw Normal View History

2026-04-13 11:34:23 +08:00
import logging
import os
import hashlib
import tempfile
import traceback
import json
from app.db.session import SessionLocal
from io import BytesIO
from types import SimpleNamespace
2026-04-13 11:34:23 +08:00
from typing import Optional, List, Dict, Any
from fastapi import UploadFile
from langchain_community.document_loaders import (
PyPDFLoader,
Docx2txtLoader,
UnstructuredMarkdownLoader,
TextLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangchainDocument
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.minio import get_minio_client
from app.models.knowledge import ProcessingTask, Document, DocumentChunk
from app.services.chunk_record import ChunkRecord
from minio.error import MinioException
from minio.commonconfig import CopySource
from app.services.vector_store import VectorStoreFactory
from app.services.embedding.embedding_factory import EmbeddingsFactory
from app.services.model_config import ModelConfigService
2026-04-13 11:34:23 +08:00
class UploadResult(BaseModel):
file_path: str
file_name: str
file_size: int
content_type: str
file_hash: str
class TextChunk(BaseModel):
content: str
metadata: Optional[Dict] = None
class PreviewResult(BaseModel):
chunks: List[TextChunk]
total_chunks: int
def _estimate_token_count(text: str) -> int:
# Lightweight estimation without adding tokenizer dependencies.
return len(text)
def _build_enriched_chunk_metadata(
*,
source_metadata: Optional[Dict[str, Any]],
chunk_id: str,
file_name: str,
file_path: str,
kb_id: int,
document_id: int,
chunk_index: int,
chunk_text: str,
) -> Dict[str, Any]:
source_metadata = source_metadata or {}
token_count = _estimate_token_count(chunk_text)
return {
**source_metadata,
"source": file_name,
"chunk_id": chunk_id,
"file_name": file_name,
"file_path": file_path,
"kb_id": kb_id,
"document_id": document_id,
"chunk_index": chunk_index,
"chunk_text": chunk_text,
"token_count": token_count,
"language": source_metadata.get("language", "zh"),
"source_type": "document",
"mission_phase": source_metadata.get("mission_phase"),
"section_title": source_metadata.get("section_title"),
"publish_time": source_metadata.get("publish_time"),
# Keep graph-linked fields for future graph/vector federation.
"extracted_entities": source_metadata.get("extracted_entities", []),
"extracted_entity_types": source_metadata.get("extracted_entity_types", []),
"extracted_relations": source_metadata.get("extracted_relations", []),
"graph_node_ids": source_metadata.get("graph_node_ids", []),
"graph_edge_ids": source_metadata.get("graph_edge_ids", []),
"community_ids": source_metadata.get("community_ids", []),
}
def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Normalize metadata to satisfy Chroma's strict metadata constraints."""
if not metadata:
return {}
sanitized: Dict[str, Any] = {}
scalar_types = (str, int, float, bool)
for key, value in metadata.items():
if value is None:
continue
if isinstance(value, scalar_types):
sanitized[key] = value
continue
if isinstance(value, list):
primitive_items = [item for item in value if isinstance(item, scalar_types)]
if primitive_items:
sanitized[key] = primitive_items
elif value:
sanitized[key] = json.dumps(value, ensure_ascii=False)
continue
if isinstance(value, dict):
sanitized[key] = json.dumps(value, ensure_ascii=False)
continue
sanitized[key] = str(value)
return sanitized
def _resolve_model_profile(db: Session, user_id: Optional[int]) -> Any:
if user_id is None:
return None
return ModelConfigService.require_active_config(db, user_id)
def _model_profile_snapshot(model_profile: Any) -> Any:
if model_profile is None:
return None
return SimpleNamespace(
provider=model_profile.provider,
api_key=model_profile.api_key,
api_base=model_profile.api_base,
chat_model=model_profile.chat_model,
embedding_model=model_profile.embedding_model,
)
def _load_model_profile_for_user(user_id: Optional[int]) -> Any:
if user_id is None:
return None
db = SessionLocal()
try:
model_profile = ModelConfigService.require_active_config(db, user_id)
ModelConfigService.touch_last_used(db, model_profile)
return _model_profile_snapshot(model_profile)
finally:
db.close()
async def process_document(
file_path: str,
file_name: str,
kb_id: int,
document_id: int,
chunk_size: int = 1000,
chunk_overlap: int = 200,
user_id: Optional[int] = None,
) -> None:
2026-04-13 11:34:23 +08:00
"""Process document and store in vector database with incremental updates"""
logger = logging.getLogger(__name__)
try:
preview_result = await preview_document(file_path, chunk_size, chunk_overlap)
# Initialize embeddings
logger.info("Initializing OpenAI embeddings...")
model_profile = _load_model_profile_for_user(user_id)
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
2026-04-13 11:34:23 +08:00
logger.info(f"Initializing vector store with collection: kb_{kb_id}")
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb_id}",
embedding_function=embeddings,
)
# Initialize chunk record manager
chunk_manager = ChunkRecord(kb_id)
# Get existing chunk hashes for this file
existing_hashes = chunk_manager.list_chunks(file_name)
# Prepare new chunks
new_chunks = []
current_hashes = set()
documents_to_update = []
for i, chunk in enumerate(preview_result.chunks):
# Calculate chunk hash
chunk_hash = hashlib.sha256(
(chunk.content + str(chunk.metadata)).encode()
).hexdigest()
current_hashes.add(chunk_hash)
# Skip if chunk hasn't changed
if chunk_hash in existing_hashes:
continue
# Create unique ID for the chunk
chunk_id = hashlib.sha256(
f"{kb_id}:{file_name}:{chunk_hash}".encode()
).hexdigest()
metadata = _build_enriched_chunk_metadata(
source_metadata=chunk.metadata,
chunk_id=chunk_id,
file_name=file_name,
file_path=file_path,
kb_id=kb_id,
document_id=document_id,
chunk_index=i,
chunk_text=chunk.content,
)
vector_metadata = _sanitize_metadata_for_vector_store(metadata)
new_chunks.append({
"id": chunk_id,
"kb_id": kb_id,
"document_id": document_id,
"file_name": file_name,
"metadata": metadata,
"hash": chunk_hash
})
# Prepare document for vector store
doc = LangchainDocument(
page_content=chunk.content,
metadata=vector_metadata
)
documents_to_update.append(doc)
# Add new chunks to database and vector store
if new_chunks:
logger.info(f"Adding {len(new_chunks)} new/updated chunks")
chunk_manager.add_chunks(new_chunks)
vector_store.add_documents(documents_to_update)
if settings.GRAPHRAG_ENABLED:
try:
from app.services.graph.graphrag_adapter import GraphRAGAdapter
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
2026-04-13 11:34:23 +08:00
source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info("GraphRAG ingestion completed in incremental processing")
except Exception as graph_exc:
logger.error(f"GraphRAG ingestion failed in incremental processing: {graph_exc}")
# Delete removed chunks
chunks_to_delete = chunk_manager.get_deleted_chunks(current_hashes, file_name)
if chunks_to_delete:
logger.info(f"Removing {len(chunks_to_delete)} deleted chunks")
chunk_manager.delete_chunks(chunks_to_delete)
vector_store.delete(chunks_to_delete)
logger.info("Document processing completed successfully")
except Exception as e:
logger.error(f"Error processing document: {str(e)}")
raise
async def upload_document(file: UploadFile, kb_id: int) -> UploadResult:
"""Step 1: Upload document to MinIO"""
content = await file.read()
file_size = len(content)
file_hash = hashlib.sha256(content).hexdigest()
# Clean and normalize filename
file_name = "".join(c for c in file.filename if c.isalnum() or c in ('-', '_', '.')).strip()
object_path = f"kb_{kb_id}/{file_name}"
content_types = {
".pdf": "application/pdf",
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
".md": "text/markdown",
".txt": "text/plain"
}
_, ext = os.path.splitext(file_name)
content_type = content_types.get(ext.lower(), "application/octet-stream")
# Upload to MinIO
minio_client = get_minio_client()
try:
minio_client.put_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=object_path,
data=BytesIO(content),
length=file_size,
content_type=content_type
)
except Exception as e:
logging.error(f"Failed to upload file to MinIO: {str(e)}")
raise
return UploadResult(
file_path=object_path,
file_name=file_name,
file_size=file_size,
content_type=content_type,
file_hash=file_hash
)
async def preview_document(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> PreviewResult:
"""Step 2: Generate preview chunks"""
# Get file from MinIO
minio_client = get_minio_client()
_, ext = os.path.splitext(file_path)
ext = ext.lower()
# Download to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
minio_client.fget_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=file_path,
file_path=temp_file.name
)
temp_path = temp_file.name
try:
# Select appropriate loader
if ext == ".pdf":
loader = PyPDFLoader(temp_path)
elif ext == ".docx":
loader = Docx2txtLoader(temp_path)
elif ext == ".md":
loader = UnstructuredMarkdownLoader(temp_path)
else: # Default to text loader
loader = TextLoader(temp_path)
# Load and split the document
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
chunks = text_splitter.split_documents(documents)
# Convert to preview format
preview_chunks = [
TextChunk(
content=chunk.page_content,
metadata=chunk.metadata
)
for chunk in chunks
]
return PreviewResult(
chunks=preview_chunks,
total_chunks=len(chunks)
)
finally:
os.unlink(temp_path)
async def process_document_background(
temp_path: str,
file_name: str,
kb_id: int,
task_id: int,
db: Session = None,
chunk_size: int = 1000,
chunk_overlap: int = 200,
user_id: Optional[int] = None,
2026-04-13 11:34:23 +08:00
) -> None:
"""Process document in background"""
logger = logging.getLogger(__name__)
logger.info(f"Starting background processing for task {task_id}, file: {file_name}")
# if we don't pass in db, create a new database session
if db is None:
db = SessionLocal()
should_close_db = True
else:
should_close_db = False
task = db.query(ProcessingTask).get(task_id)
if not task:
logger.error(f"Task {task_id} not found")
return
minio_client = None
local_temp_path = None
try:
logger.info(f"Task {task_id}: Setting status to processing")
task.status = "processing"
db.commit()
model_profile = _resolve_model_profile(db, user_id)
if model_profile is not None:
ModelConfigService.touch_last_used(db, model_profile)
2026-04-13 11:34:23 +08:00
# 1. 从临时目录下载文件
minio_client = get_minio_client()
try:
local_temp_path = f"/tmp/temp_{task_id}_{file_name}" # 使用系统临时目录
logger.info(f"Task {task_id}: Downloading file from MinIO: {temp_path} to {local_temp_path}")
minio_client.fget_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=temp_path,
file_path=local_temp_path
)
logger.info(f"Task {task_id}: File downloaded successfully")
except MinioException as e:
# Idempotent fallback: temp object may already be consumed by another task.
# If the final document is already created, treat current task as completed.
if "NoSuchKey" in str(e) and task.document_upload:
existing_document = db.query(Document).filter(
Document.knowledge_base_id == kb_id,
Document.file_name == file_name,
Document.file_hash == task.document_upload.file_hash,
).first()
if existing_document:
logger.warning(
f"Task {task_id}: Temp object missing but document already exists, "
f"marking task as completed (document_id={existing_document.id})"
)
task.status = "completed"
task.document_id = existing_document.id
task.error_message = None
task.document_upload.status = "completed"
task.document_upload.error_message = None
db.commit()
return
error_msg = f"Failed to download temp file: {str(e)}"
logger.error(f"Task {task_id}: {error_msg}")
raise Exception(error_msg)
try:
# 2. 加载和分块文档
_, ext = os.path.splitext(file_name)
ext = ext.lower()
logger.info(f"Task {task_id}: Loading document with extension {ext}")
# 选择合适的加载器
if ext == ".pdf":
loader = PyPDFLoader(local_temp_path)
elif ext == ".docx":
loader = Docx2txtLoader(local_temp_path)
elif ext == ".md":
loader = UnstructuredMarkdownLoader(local_temp_path)
else: # 默认使用文本加载器
loader = TextLoader(local_temp_path)
logger.info(f"Task {task_id}: Loading document content")
documents = loader.load()
logger.info(f"Task {task_id}: Document loaded successfully")
logger.info(f"Task {task_id}: Splitting document into chunks")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
chunks = text_splitter.split_documents(documents)
logger.info(f"Task {task_id}: Document split into {len(chunks)} chunks")
# 3. 创建向量存储
logger.info(f"Task {task_id}: Initializing vector store")
embeddings = EmbeddingsFactory.create(model_profile=model_profile)
2026-04-13 11:34:23 +08:00
vector_store = VectorStoreFactory.create(
store_type=settings.VECTOR_STORE_TYPE,
collection_name=f"kb_{kb_id}",
embedding_function=embeddings,
)
# 4. 将临时文件移动到永久目录
permanent_path = f"kb_{kb_id}/{file_name}"
try:
logger.info(f"Task {task_id}: Moving file to permanent storage")
# 复制到永久目录
source = CopySource(settings.MINIO_BUCKET_NAME, temp_path)
minio_client.copy_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=permanent_path,
source=source
)
logger.info(f"Task {task_id}: File moved to permanent storage")
# 删除临时文件
logger.info(f"Task {task_id}: Removing temporary file from MinIO")
minio_client.remove_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=temp_path
)
logger.info(f"Task {task_id}: Temporary file removed")
except MinioException as e:
error_msg = f"Failed to move file to permanent storage: {str(e)}"
logger.error(f"Task {task_id}: {error_msg}")
raise Exception(error_msg)
# 5. 创建文档记录
logger.info(f"Task {task_id}: Creating document record")
document = Document(
file_name=file_name,
file_path=permanent_path,
file_hash=task.document_upload.file_hash,
file_size=task.document_upload.file_size,
content_type=task.document_upload.content_type,
knowledge_base_id=kb_id
)
db.add(document)
db.flush()
db.refresh(document)
logger.info(f"Task {task_id}: Document record created with ID {document.id}")
# 6. 存储文档块
logger.info(f"Task {task_id}: Storing document chunks")
for i, chunk in enumerate(chunks):
# 为每个 chunk 生成唯一的 ID
chunk_id = hashlib.sha256(
f"{kb_id}:{file_name}:{chunk.page_content}".encode()
).hexdigest()
metadata = _build_enriched_chunk_metadata(
source_metadata=chunk.metadata,
chunk_id=chunk_id,
file_name=file_name,
file_path=permanent_path,
kb_id=kb_id,
document_id=document.id,
chunk_index=i,
chunk_text=chunk.page_content,
)
chunk.metadata = metadata
doc_chunk = DocumentChunk(
id=chunk_id, # 添加 ID 字段
document_id=document.id,
kb_id=kb_id,
file_name=file_name,
chunk_metadata={
"page_content": chunk.page_content,
**metadata
},
hash=hashlib.sha256(
(chunk.page_content + str(metadata)).encode()
).hexdigest()
)
db.add(doc_chunk)
if i > 0 and i % 100 == 0:
logger.info(f"Task {task_id}: Stored {i} chunks")
db.flush()
# 7. 添加到向量存储
logger.info(f"Task {task_id}: Adding chunks to vector store")
vector_chunks = [
LangchainDocument(
page_content=chunk.page_content,
metadata=_sanitize_metadata_for_vector_store(chunk.metadata),
)
for chunk in chunks
]
vector_store.add_documents(vector_chunks)
# 移除 persist() 调用,因为新版本不需要
logger.info(f"Task {task_id}: Chunks added to vector store")
if settings.GRAPHRAG_ENABLED:
try:
from app.services.graph.graphrag_adapter import GraphRAGAdapter
logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
graph_adapter = GraphRAGAdapter(model_profile=model_profile)
2026-04-13 11:34:23 +08:00
source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
await graph_adapter.ingest_texts(kb_id, source_texts)
logger.info(f"Task {task_id}: GraphRAG ingestion completed")
except Exception as graph_exc:
logger.error(f"Task {task_id}: GraphRAG ingestion failed: {graph_exc}")
# 8. 更新任务状态
logger.info(f"Task {task_id}: Updating task status to completed")
task.status = "completed"
task.document_id = document.id # 更新为新创建的文档ID
# 9. 更新上传记录状态
upload = task.document_upload # 直接通过关系获取
if upload:
logger.info(f"Task {task_id}: Updating upload record status to completed")
upload.status = "completed"
db.commit()
logger.info(f"Task {task_id}: Processing completed successfully")
finally:
# 清理本地临时文件
try:
if os.path.exists(local_temp_path):
logger.info(f"Task {task_id}: Cleaning up local temp file")
os.remove(local_temp_path)
logger.info(f"Task {task_id}: Local temp file cleaned up")
except Exception as e:
logger.warning(f"Task {task_id}: Failed to clean up local temp file: {str(e)}")
except Exception as e:
logger.error(f"Task {task_id}: Error processing document: {str(e)}")
logger.error(f"Task {task_id}: Stack trace: {traceback.format_exc()}")
db.rollback()
failed_task = db.query(ProcessingTask).get(task_id)
if failed_task:
failed_task.status = "failed"
failed_task.error_message = str(e)
if failed_task.document_upload:
failed_task.document_upload.status = "failed"
failed_task.document_upload.error_message = str(e)
db.commit()
# 清理临时文件
try:
logger.info(f"Task {task_id}: Cleaning up temporary file after error")
if minio_client is not None:
minio_client.remove_object(
bucket_name=settings.MINIO_BUCKET_NAME,
object_name=temp_path
)
logger.info(f"Task {task_id}: Temporary file cleaned up after error")
except:
logger.warning(f"Task {task_id}: Failed to clean up temporary file after error")
finally:
# if we create the db session, we need to close it
if should_close_db and db:
db.close()