583 lines
22 KiB
Python
583 lines
22 KiB
Python
import logging
|
|
import os
|
|
import hashlib
|
|
import tempfile
|
|
import traceback
|
|
import json
|
|
from app.db.session import SessionLocal
|
|
from io import BytesIO
|
|
from typing import Optional, List, Dict, Any
|
|
from fastapi import UploadFile
|
|
from langchain_community.document_loaders import (
|
|
PyPDFLoader,
|
|
Docx2txtLoader,
|
|
UnstructuredMarkdownLoader,
|
|
TextLoader
|
|
)
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain_core.documents import Document as LangchainDocument
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
from app.core.config import settings
|
|
from app.core.minio import get_minio_client
|
|
from app.models.knowledge import ProcessingTask, Document, DocumentChunk
|
|
from app.services.chunk_record import ChunkRecord
|
|
from minio.error import MinioException
|
|
from minio.commonconfig import CopySource
|
|
from app.services.vector_store import VectorStoreFactory
|
|
from app.services.embedding.embedding_factory import EmbeddingsFactory
|
|
|
|
class UploadResult(BaseModel):
    """Result of uploading a document to MinIO (step 1 of ingestion)."""

    file_path: str     # object key inside the MinIO bucket, e.g. "kb_1/report.pdf"
    file_name: str     # sanitized file name (non-alphanumeric chars stripped)
    file_size: int     # size of the uploaded content in bytes
    content_type: str  # MIME type inferred from the file extension
    file_hash: str     # SHA-256 hex digest of the raw file content
|
|
|
|
class TextChunk(BaseModel):
    """A single chunk of document text plus its loader-supplied metadata."""

    content: str                     # the chunk's text
    metadata: Optional[Dict] = None  # loader metadata (e.g. page number, source path)
|
|
|
|
class PreviewResult(BaseModel):
    """Chunked preview of a document (step 2 of ingestion)."""

    chunks: List[TextChunk]  # preview chunks in document order
    total_chunks: int        # number of chunks (kept explicit for API consumers)
|
|
|
|
|
|
def _estimate_token_count(text: str) -> int:
|
|
# Lightweight estimation without adding tokenizer dependencies.
|
|
return len(text)
|
|
|
|
|
|
def _build_enriched_chunk_metadata(
|
|
*,
|
|
source_metadata: Optional[Dict[str, Any]],
|
|
chunk_id: str,
|
|
file_name: str,
|
|
file_path: str,
|
|
kb_id: int,
|
|
document_id: int,
|
|
chunk_index: int,
|
|
chunk_text: str,
|
|
) -> Dict[str, Any]:
|
|
source_metadata = source_metadata or {}
|
|
token_count = _estimate_token_count(chunk_text)
|
|
|
|
return {
|
|
**source_metadata,
|
|
"source": file_name,
|
|
"chunk_id": chunk_id,
|
|
"file_name": file_name,
|
|
"file_path": file_path,
|
|
"kb_id": kb_id,
|
|
"document_id": document_id,
|
|
"chunk_index": chunk_index,
|
|
"chunk_text": chunk_text,
|
|
"token_count": token_count,
|
|
"language": source_metadata.get("language", "zh"),
|
|
"source_type": "document",
|
|
"mission_phase": source_metadata.get("mission_phase"),
|
|
"section_title": source_metadata.get("section_title"),
|
|
"publish_time": source_metadata.get("publish_time"),
|
|
# Keep graph-linked fields for future graph/vector federation.
|
|
"extracted_entities": source_metadata.get("extracted_entities", []),
|
|
"extracted_entity_types": source_metadata.get("extracted_entity_types", []),
|
|
"extracted_relations": source_metadata.get("extracted_relations", []),
|
|
"graph_node_ids": source_metadata.get("graph_node_ids", []),
|
|
"graph_edge_ids": source_metadata.get("graph_edge_ids", []),
|
|
"community_ids": source_metadata.get("community_ids", []),
|
|
}
|
|
|
|
|
|
def _sanitize_metadata_for_vector_store(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
|
"""Normalize metadata to satisfy Chroma's strict metadata constraints."""
|
|
if not metadata:
|
|
return {}
|
|
|
|
sanitized: Dict[str, Any] = {}
|
|
scalar_types = (str, int, float, bool)
|
|
|
|
for key, value in metadata.items():
|
|
if value is None:
|
|
continue
|
|
|
|
if isinstance(value, scalar_types):
|
|
sanitized[key] = value
|
|
continue
|
|
|
|
if isinstance(value, list):
|
|
primitive_items = [item for item in value if isinstance(item, scalar_types)]
|
|
if primitive_items:
|
|
sanitized[key] = primitive_items
|
|
elif value:
|
|
sanitized[key] = json.dumps(value, ensure_ascii=False)
|
|
continue
|
|
|
|
if isinstance(value, dict):
|
|
sanitized[key] = json.dumps(value, ensure_ascii=False)
|
|
continue
|
|
|
|
sanitized[key] = str(value)
|
|
|
|
return sanitized
|
|
|
|
async def process_document(file_path: str, file_name: str, kb_id: int, document_id: int, chunk_size: int = 1000, chunk_overlap: int = 200) -> None:
    """Process document and store in vector database with incremental updates.

    Re-chunks the file at *file_path* and diffs the chunk hashes against the
    hashes already recorded for *file_name*, so only new/changed chunks are
    embedded and stored, while chunks that disappeared are deleted from both
    the chunk record and the vector store.

    Args:
        file_path: MinIO object key of the document.
        file_name: File name used as the chunk-record key.
        kb_id: Knowledge base ID (also selects the "kb_{kb_id}" collection).
        document_id: ID of the Document row the chunks belong to.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.

    Raises:
        Exception: re-raised after logging if any processing step fails.
    """
    logger = logging.getLogger(__name__)

    try:
        preview_result = await preview_document(file_path, chunk_size, chunk_overlap)

        # Initialize embeddings
        logger.info("Initializing OpenAI embeddings...")
        embeddings = EmbeddingsFactory.create()

        logger.info(f"Initializing vector store with collection: kb_{kb_id}")
        vector_store = VectorStoreFactory.create(
            store_type=settings.VECTOR_STORE_TYPE,
            collection_name=f"kb_{kb_id}",
            embedding_function=embeddings,
        )

        # Initialize chunk record manager
        chunk_manager = ChunkRecord(kb_id)

        # Get existing chunk hashes for this file
        existing_hashes = chunk_manager.list_chunks(file_name)

        # Prepare new chunks
        new_chunks = []
        current_hashes = set()
        documents_to_update = []

        for i, chunk in enumerate(preview_result.chunks):
            # Calculate chunk hash over content + source metadata so a
            # change to either is detected.
            chunk_hash = hashlib.sha256(
                (chunk.content + str(chunk.metadata)).encode()
            ).hexdigest()
            current_hashes.add(chunk_hash)

            # Skip if chunk hasn't changed
            if chunk_hash in existing_hashes:
                continue

            # Create unique ID for the chunk, scoped by kb and file name.
            chunk_id = hashlib.sha256(
                f"{kb_id}:{file_name}:{chunk_hash}".encode()
            ).hexdigest()

            metadata = _build_enriched_chunk_metadata(
                source_metadata=chunk.metadata,
                chunk_id=chunk_id,
                file_name=file_name,
                file_path=file_path,
                kb_id=kb_id,
                document_id=document_id,
                chunk_index=i,
                chunk_text=chunk.content,
            )
            # The vector store receives a sanitized copy of the metadata;
            # the full enriched metadata goes into the chunk record below.
            vector_metadata = _sanitize_metadata_for_vector_store(metadata)

            new_chunks.append({
                "id": chunk_id,
                "kb_id": kb_id,
                "document_id": document_id,
                "file_name": file_name,
                "metadata": metadata,
                "hash": chunk_hash
            })

            # Prepare document for vector store
            doc = LangchainDocument(
                page_content=chunk.content,
                metadata=vector_metadata
            )
            documents_to_update.append(doc)

        # Add new chunks to database and vector store
        if new_chunks:
            logger.info(f"Adding {len(new_chunks)} new/updated chunks")
            chunk_manager.add_chunks(new_chunks)
            vector_store.add_documents(documents_to_update)
            if settings.GRAPHRAG_ENABLED:
                try:
                    from app.services.graph.graphrag_adapter import GraphRAGAdapter

                    graph_adapter = GraphRAGAdapter()
                    source_texts = [doc.page_content for doc in documents_to_update if doc.page_content.strip()]
                    await graph_adapter.ingest_texts(kb_id, source_texts)
                    logger.info("GraphRAG ingestion completed in incremental processing")
                except Exception as graph_exc:
                    # GraphRAG ingestion is best-effort: a failure here is
                    # logged but does not fail the overall processing.
                    logger.error(f"GraphRAG ingestion failed in incremental processing: {graph_exc}")

        # Delete removed chunks
        chunks_to_delete = chunk_manager.get_deleted_chunks(current_hashes, file_name)
        if chunks_to_delete:
            logger.info(f"Removing {len(chunks_to_delete)} deleted chunks")
            chunk_manager.delete_chunks(chunks_to_delete)
            vector_store.delete(chunks_to_delete)

        logger.info("Document processing completed successfully")

    except Exception as e:
        logger.error(f"Error processing document: {str(e)}")
        raise
|
|
|
|
async def upload_document(file: UploadFile, kb_id: int) -> UploadResult:
    """Step 1: Upload document to MinIO"""
    # Read the whole upload into memory so it can be sized and hashed.
    raw = await file.read()
    size = len(raw)
    digest = hashlib.sha256(raw).hexdigest()

    # Keep only filesystem-safe characters in the filename.
    allowed_extra = ('-', '_', '.')
    safe_name = "".join(ch for ch in file.filename if ch.isalnum() or ch in allowed_extra).strip()
    object_path = f"kb_{kb_id}/{safe_name}"

    # Map the extension to a MIME type, defaulting to a generic binary type.
    known_types = {
        ".pdf": "application/pdf",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".md": "text/markdown",
        ".txt": "text/plain",
    }
    extension = os.path.splitext(safe_name)[1].lower()
    mime_type = known_types.get(extension, "application/octet-stream")

    # Push the bytes to MinIO under the knowledge base's prefix.
    client = get_minio_client()
    try:
        client.put_object(
            bucket_name=settings.MINIO_BUCKET_NAME,
            object_name=object_path,
            data=BytesIO(raw),
            length=size,
            content_type=mime_type,
        )
    except Exception as exc:
        logging.error(f"Failed to upload file to MinIO: {str(exc)}")
        raise

    return UploadResult(
        file_path=object_path,
        file_name=safe_name,
        file_size=size,
        content_type=mime_type,
        file_hash=digest,
    )
|
|
|
|
async def preview_document(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> PreviewResult:
    """Step 2: Generate preview chunks"""
    client = get_minio_client()
    extension = os.path.splitext(file_path)[1].lower()

    # Pull the object from MinIO into a local temp file, because the
    # document loaders expect a filesystem path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as handle:
        client.fget_object(
            bucket_name=settings.MINIO_BUCKET_NAME,
            object_name=file_path,
            file_path=handle.name,
        )
        local_path = handle.name

    try:
        # Dispatch on extension; anything unknown falls back to plain text.
        loader_classes = {
            ".pdf": PyPDFLoader,
            ".docx": Docx2txtLoader,
            ".md": UnstructuredMarkdownLoader,
        }
        loader = loader_classes.get(extension, TextLoader)(local_path)

        # Load the document, then split it into overlapping chunks.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        split_docs = splitter.split_documents(loader.load())

        # Convert to the API preview representation.
        preview = [
            TextChunk(content=doc.page_content, metadata=doc.metadata)
            for doc in split_docs
        ]
        return PreviewResult(chunks=preview, total_chunks=len(split_docs))
    finally:
        # Always remove the temp file, even when loading/splitting fails.
        os.unlink(local_path)
|
|
|
|
async def process_document_background(
    temp_path: str,
    file_name: str,
    kb_id: int,
    task_id: int,
    db: Session = None,
    chunk_size: int = 1000,
    chunk_overlap: int = 200
) -> None:
    """Process an uploaded document in the background.

    Downloads the temporary object from MinIO, loads and chunks it, writes
    the chunks to the SQL database and the vector store, moves the object
    to its permanent location, and updates the task/upload status records.

    Args:
        temp_path: MinIO object key of the temporary upload.
        file_name: Sanitized name of the uploaded file.
        kb_id: Knowledge base the document belongs to.
        task_id: Primary key of the ProcessingTask row tracking this job.
        db: Optional existing SQLAlchemy session; a new one is created
            (and closed here) when omitted.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.
    """
    logger = logging.getLogger(__name__)
    logger.info(f"Starting background processing for task {task_id}, file: {file_name}")

    # If no session was passed in, create one; we are then responsible for
    # closing it. (Previously a locally created session leaked when the
    # task row was missing, because the early return skipped the finally.)
    if db is None:
        db = SessionLocal()
        should_close_db = True
    else:
        should_close_db = False

    try:
        task = db.query(ProcessingTask).get(task_id)
        if not task:
            logger.error(f"Task {task_id} not found")
            return

        minio_client = None
        local_temp_path = None

        try:
            logger.info(f"Task {task_id}: Setting status to processing")
            task.status = "processing"
            db.commit()

            # 1. Download the file from the temporary MinIO location.
            minio_client = get_minio_client()
            try:
                local_temp_path = f"/tmp/temp_{task_id}_{file_name}"  # system temp dir
                logger.info(f"Task {task_id}: Downloading file from MinIO: {temp_path} to {local_temp_path}")
                minio_client.fget_object(
                    bucket_name=settings.MINIO_BUCKET_NAME,
                    object_name=temp_path,
                    file_path=local_temp_path
                )
                logger.info(f"Task {task_id}: File downloaded successfully")
            except MinioException as e:
                # Idempotent fallback: temp object may already be consumed by another task.
                # If the final document is already created, treat current task as completed.
                if "NoSuchKey" in str(e) and task.document_upload:
                    existing_document = db.query(Document).filter(
                        Document.knowledge_base_id == kb_id,
                        Document.file_name == file_name,
                        Document.file_hash == task.document_upload.file_hash,
                    ).first()
                    if existing_document:
                        logger.warning(
                            f"Task {task_id}: Temp object missing but document already exists, "
                            f"marking task as completed (document_id={existing_document.id})"
                        )
                        task.status = "completed"
                        task.document_id = existing_document.id
                        task.error_message = None
                        task.document_upload.status = "completed"
                        task.document_upload.error_message = None
                        db.commit()
                        return

                error_msg = f"Failed to download temp file: {str(e)}"
                logger.error(f"Task {task_id}: {error_msg}")
                raise Exception(error_msg)

            try:
                # 2. Load and chunk the document.
                _, ext = os.path.splitext(file_name)
                ext = ext.lower()

                logger.info(f"Task {task_id}: Loading document with extension {ext}")
                # Pick the loader that matches the file extension.
                if ext == ".pdf":
                    loader = PyPDFLoader(local_temp_path)
                elif ext == ".docx":
                    loader = Docx2txtLoader(local_temp_path)
                elif ext == ".md":
                    loader = UnstructuredMarkdownLoader(local_temp_path)
                else:  # fall back to the plain text loader
                    loader = TextLoader(local_temp_path)

                logger.info(f"Task {task_id}: Loading document content")
                documents = loader.load()
                logger.info(f"Task {task_id}: Document loaded successfully")

                logger.info(f"Task {task_id}: Splitting document into chunks")
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
                chunks = text_splitter.split_documents(documents)
                logger.info(f"Task {task_id}: Document split into {len(chunks)} chunks")

                # 3. Create the vector store for this knowledge base.
                logger.info(f"Task {task_id}: Initializing vector store")
                embeddings = EmbeddingsFactory.create()

                vector_store = VectorStoreFactory.create(
                    store_type=settings.VECTOR_STORE_TYPE,
                    collection_name=f"kb_{kb_id}",
                    embedding_function=embeddings,
                )

                # 4. Move the temporary object to its permanent location.
                permanent_path = f"kb_{kb_id}/{file_name}"
                try:
                    logger.info(f"Task {task_id}: Moving file to permanent storage")
                    # Copy to the permanent prefix first...
                    source = CopySource(settings.MINIO_BUCKET_NAME, temp_path)
                    minio_client.copy_object(
                        bucket_name=settings.MINIO_BUCKET_NAME,
                        object_name=permanent_path,
                        source=source
                    )
                    logger.info(f"Task {task_id}: File moved to permanent storage")

                    # ...then delete the temporary object.
                    logger.info(f"Task {task_id}: Removing temporary file from MinIO")
                    minio_client.remove_object(
                        bucket_name=settings.MINIO_BUCKET_NAME,
                        object_name=temp_path
                    )
                    logger.info(f"Task {task_id}: Temporary file removed")
                except MinioException as e:
                    error_msg = f"Failed to move file to permanent storage: {str(e)}"
                    logger.error(f"Task {task_id}: {error_msg}")
                    raise Exception(error_msg)

                # 5. Create the document record.
                logger.info(f"Task {task_id}: Creating document record")
                document = Document(
                    file_name=file_name,
                    file_path=permanent_path,
                    file_hash=task.document_upload.file_hash,
                    file_size=task.document_upload.file_size,
                    content_type=task.document_upload.content_type,
                    knowledge_base_id=kb_id
                )
                db.add(document)
                db.flush()
                db.refresh(document)
                logger.info(f"Task {task_id}: Document record created with ID {document.id}")

                # 6. Persist the document chunks.
                logger.info(f"Task {task_id}: Storing document chunks")
                for i, chunk in enumerate(chunks):
                    # Derive a deterministic chunk ID from kb, file and content.
                    chunk_id = hashlib.sha256(
                        f"{kb_id}:{file_name}:{chunk.page_content}".encode()
                    ).hexdigest()

                    metadata = _build_enriched_chunk_metadata(
                        source_metadata=chunk.metadata,
                        chunk_id=chunk_id,
                        file_name=file_name,
                        file_path=permanent_path,
                        kb_id=kb_id,
                        document_id=document.id,
                        chunk_index=i,
                        chunk_text=chunk.page_content,
                    )
                    chunk.metadata = metadata

                    doc_chunk = DocumentChunk(
                        id=chunk_id,
                        document_id=document.id,
                        kb_id=kb_id,
                        file_name=file_name,
                        chunk_metadata={
                            "page_content": chunk.page_content,
                            **metadata
                        },
                        hash=hashlib.sha256(
                            (chunk.page_content + str(metadata)).encode()
                        ).hexdigest()
                    )
                    db.add(doc_chunk)
                    # Flush periodically so very large documents don't hold
                    # every pending row in the session at once.
                    if i > 0 and i % 100 == 0:
                        logger.info(f"Task {task_id}: Stored {i} chunks")
                        db.flush()

                # 7. Add the chunks to the vector store.
                logger.info(f"Task {task_id}: Adding chunks to vector store")
                vector_chunks = [
                    LangchainDocument(
                        page_content=chunk.page_content,
                        metadata=_sanitize_metadata_for_vector_store(chunk.metadata),
                    )
                    for chunk in chunks
                ]
                vector_store.add_documents(vector_chunks)
                # No explicit persist() call is needed with current vector store versions.
                logger.info(f"Task {task_id}: Chunks added to vector store")

                if settings.GRAPHRAG_ENABLED:
                    try:
                        from app.services.graph.graphrag_adapter import GraphRAGAdapter

                        logger.info(f"Task {task_id}: Starting GraphRAG ingestion")
                        graph_adapter = GraphRAGAdapter()
                        source_texts = [doc.page_content for doc in documents if doc.page_content.strip()]
                        await graph_adapter.ingest_texts(kb_id, source_texts)
                        logger.info(f"Task {task_id}: GraphRAG ingestion completed")
                    except Exception as graph_exc:
                        # GraphRAG is best-effort: its failure must not fail the task.
                        logger.error(f"Task {task_id}: GraphRAG ingestion failed: {graph_exc}")

                # 8. Mark the task as completed.
                logger.info(f"Task {task_id}: Updating task status to completed")
                task.status = "completed"
                task.document_id = document.id  # point at the newly created document

                # 9. Mark the upload record as completed.
                upload = task.document_upload  # fetched via the ORM relationship
                if upload:
                    logger.info(f"Task {task_id}: Updating upload record status to completed")
                    upload.status = "completed"

                db.commit()
                logger.info(f"Task {task_id}: Processing completed successfully")

            finally:
                # Clean up the local temp file; guard against the download
                # never having happened (local_temp_path still None).
                try:
                    if local_temp_path and os.path.exists(local_temp_path):
                        logger.info(f"Task {task_id}: Cleaning up local temp file")
                        os.remove(local_temp_path)
                        logger.info(f"Task {task_id}: Local temp file cleaned up")
                except Exception as e:
                    logger.warning(f"Task {task_id}: Failed to clean up local temp file: {str(e)}")

        except Exception as e:
            logger.error(f"Task {task_id}: Error processing document: {str(e)}")
            logger.error(f"Task {task_id}: Stack trace: {traceback.format_exc()}")
            db.rollback()

            # Re-fetch the task after rollback and record the failure.
            failed_task = db.query(ProcessingTask).get(task_id)
            if failed_task:
                failed_task.status = "failed"
                failed_task.error_message = str(e)
                if failed_task.document_upload:
                    failed_task.document_upload.status = "failed"
                    failed_task.document_upload.error_message = str(e)
                db.commit()

            # Best-effort removal of the temp object in MinIO.
            try:
                logger.info(f"Task {task_id}: Cleaning up temporary file after error")
                if minio_client is not None:
                    minio_client.remove_object(
                        bucket_name=settings.MINIO_BUCKET_NAME,
                        object_name=temp_path
                    )
                    logger.info(f"Task {task_id}: Temporary file cleaned up after error")
            except Exception:
                # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt.
                logger.warning(f"Task {task_id}: Failed to clean up temporary file after error")
    finally:
        # Close the session only if we created it ourselves.
        if should_close_db and db:
            db.close()
|