init. project
This commit is contained in:
330
rag-web-ui/backend/nano_graphrag/entity_extraction/module.py
Normal file
330
rag-web-ui/backend/nano_graphrag/entity_extraction/module.py
Normal file
@@ -0,0 +1,330 @@
|
||||
import dspy
|
||||
from pydantic import BaseModel, Field
|
||||
from nano_graphrag._utils import clean_str
|
||||
from nano_graphrag._utils import logger
|
||||
|
||||
|
||||
"""
|
||||
Obtained from:
|
||||
https://github.com/SciPhi-AI/R2R/blob/6e958d1e451c1cb10b6fc868572659785d1091cb/r2r/providers/prompts/defaults.jsonl
|
||||
"""
|
||||
ENTITY_TYPES = [
|
||||
"PERSON",
|
||||
"ORGANIZATION",
|
||||
"LOCATION",
|
||||
"DATE",
|
||||
"TIME",
|
||||
"MONEY",
|
||||
"PERCENTAGE",
|
||||
"PRODUCT",
|
||||
"EVENT",
|
||||
"LANGUAGE",
|
||||
"NATIONALITY",
|
||||
"RELIGION",
|
||||
"TITLE",
|
||||
"PROFESSION",
|
||||
"ANIMAL",
|
||||
"PLANT",
|
||||
"DISEASE",
|
||||
"MEDICATION",
|
||||
"CHEMICAL",
|
||||
"MATERIAL",
|
||||
"COLOR",
|
||||
"SHAPE",
|
||||
"MEASUREMENT",
|
||||
"WEATHER",
|
||||
"NATURAL_DISASTER",
|
||||
"AWARD",
|
||||
"LAW",
|
||||
"CRIME",
|
||||
"TECHNOLOGY",
|
||||
"SOFTWARE",
|
||||
"HARDWARE",
|
||||
"VEHICLE",
|
||||
"FOOD",
|
||||
"DRINK",
|
||||
"SPORT",
|
||||
"MUSIC_GENRE",
|
||||
"INSTRUMENT",
|
||||
"ARTWORK",
|
||||
"BOOK",
|
||||
"MOVIE",
|
||||
"TV_SHOW",
|
||||
"ACADEMIC_SUBJECT",
|
||||
"SCIENTIFIC_THEORY",
|
||||
"POLITICAL_PARTY",
|
||||
"CURRENCY",
|
||||
"STOCK_SYMBOL",
|
||||
"FILE_TYPE",
|
||||
"PROGRAMMING_LANGUAGE",
|
||||
"MEDICAL_PROCEDURE",
|
||||
"CELESTIAL_BODY",
|
||||
]
|
||||
|
||||
|
||||
class Entity(BaseModel):
|
||||
entity_name: str = Field(..., description="The name of the entity.")
|
||||
entity_type: str = Field(..., description="The type of the entity.")
|
||||
description: str = Field(
|
||||
..., description="The description of the entity, in details and comprehensive."
|
||||
)
|
||||
importance_score: float = Field(
|
||||
...,
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.",
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"entity_name": clean_str(self.entity_name.upper()),
|
||||
"entity_type": clean_str(self.entity_type.upper()),
|
||||
"description": clean_str(self.description),
|
||||
"importance_score": float(self.importance_score),
|
||||
}
|
||||
|
||||
|
||||
class Relationship(BaseModel):
|
||||
src_id: str = Field(..., description="The name of the source entity.")
|
||||
tgt_id: str = Field(..., description="The name of the target entity.")
|
||||
description: str = Field(
|
||||
...,
|
||||
description="The description of the relationship between the source and target entity, in details and comprehensive.",
|
||||
)
|
||||
weight: float = Field(
|
||||
...,
|
||||
ge=0,
|
||||
le=1,
|
||||
description="The weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.",
|
||||
)
|
||||
order: int = Field(
|
||||
...,
|
||||
ge=1,
|
||||
le=3,
|
||||
description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.",
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"src_id": clean_str(self.src_id.upper()),
|
||||
"tgt_id": clean_str(self.tgt_id.upper()),
|
||||
"description": clean_str(self.description),
|
||||
"weight": float(self.weight),
|
||||
"order": int(self.order),
|
||||
}
|
||||
|
||||
|
||||
class CombinedExtraction(dspy.Signature):
|
||||
"""
|
||||
Given a text document that is potentially relevant to this activity and a list of entity types,
|
||||
identify all entities of those types from the text and all relationships among the identified entities.
|
||||
|
||||
Entity Guidelines:
|
||||
1. Each entity name should be an actual atomic word from the input text.
|
||||
2. Avoid duplicates and generic terms.
|
||||
3. Make sure descriptions are detailed and comprehensive. Use multiple complete sentences for each point below:
|
||||
a). The entity's role or significance in the context
|
||||
b). Key attributes or characteristics
|
||||
c). Relationships to other entities (if applicable)
|
||||
d). Historical or cultural relevance (if applicable)
|
||||
e). Any notable actions or events associated with the entity
|
||||
4. All entity types from the text must be included.
|
||||
5. IMPORTANT: Only use entity types from the provided 'entity_types' list. Do not introduce new entity types.
|
||||
|
||||
Relationship Guidelines:
|
||||
1. Make sure relationship descriptions are detailed and comprehensive. Use multiple complete sentences for each point below:
|
||||
a). The nature of the relationship (e.g., familial, professional, causal)
|
||||
b). The impact or significance of the relationship on both entities
|
||||
c). Any historical or contextual information relevant to the relationship
|
||||
d). How the relationship evolved over time (if applicable)
|
||||
e). Any notable events or actions that resulted from this relationship
|
||||
2. Include direct relationships (order 1) as well as higher-order relationships (order 2 and 3):
|
||||
a). Direct relationships: Immediate connections between entities.
|
||||
b). Second-order relationships: Indirect effects or connections that result from direct relationships.
|
||||
c). Third-order relationships: Further indirect effects that result from second-order relationships.
|
||||
3. The "src_id" and "tgt_id" fields must exactly match entity names from the extracted entities list.
|
||||
"""
|
||||
|
||||
input_text: str = dspy.InputField(
|
||||
desc="The text to extract entities and relationships from."
|
||||
)
|
||||
entity_types: list[str] = dspy.InputField(
|
||||
desc="List of entity types used for extraction."
|
||||
)
|
||||
entities: list[Entity] = dspy.OutputField(
|
||||
desc="List of entities extracted from the text and the entity types."
|
||||
)
|
||||
relationships: list[Relationship] = dspy.OutputField(
|
||||
desc="List of relationships extracted from the text and the entity types."
|
||||
)
|
||||
|
||||
|
||||
class CritiqueCombinedExtraction(dspy.Signature):
|
||||
"""
|
||||
Critique the current extraction of entities and relationships from a given text.
|
||||
Focus on completeness, accuracy, and adherence to the provided entity types and extraction guidelines.
|
||||
|
||||
Critique Guidelines:
|
||||
1. Evaluate if all relevant entities from the input text are captured and correctly typed.
|
||||
2. Check if entity descriptions are comprehensive and follow the provided guidelines.
|
||||
3. Assess the completeness of relationship extractions, including higher-order relationships.
|
||||
4. Verify that relationship descriptions are detailed and follow the provided guidelines.
|
||||
5. Identify any inconsistencies, errors, or missed opportunities in the current extraction.
|
||||
6. Suggest specific improvements or additions to enhance the quality of the extraction.
|
||||
"""
|
||||
|
||||
input_text: str = dspy.InputField(
|
||||
desc="The original text from which entities and relationships were extracted."
|
||||
)
|
||||
entity_types: list[str] = dspy.InputField(
|
||||
desc="List of valid entity types for this extraction task."
|
||||
)
|
||||
current_entities: list[Entity] = dspy.InputField(
|
||||
desc="List of currently extracted entities to be critiqued."
|
||||
)
|
||||
current_relationships: list[Relationship] = dspy.InputField(
|
||||
desc="List of currently extracted relationships to be critiqued."
|
||||
)
|
||||
entity_critique: str = dspy.OutputField(
|
||||
desc="Detailed critique of the current entities, highlighting areas for improvement for completeness and accuracy.."
|
||||
)
|
||||
relationship_critique: str = dspy.OutputField(
|
||||
desc="Detailed critique of the current relationships, highlighting areas for improvement for completeness and accuracy.."
|
||||
)
|
||||
|
||||
|
||||
class RefineCombinedExtraction(dspy.Signature):
|
||||
"""
|
||||
Refine the current extraction of entities and relationships based on the provided critique.
|
||||
Improve completeness, accuracy, and adherence to the extraction guidelines.
|
||||
|
||||
Refinement Guidelines:
|
||||
1. Address all points raised in the entity and relationship critiques.
|
||||
2. Add missing entities and relationships identified in the critique.
|
||||
3. Improve entity and relationship descriptions as suggested.
|
||||
4. Ensure all refinements still adhere to the original extraction guidelines.
|
||||
5. Maintain consistency between entities and relationships during refinement.
|
||||
6. Focus on enhancing the overall quality and comprehensiveness of the extraction.
|
||||
"""
|
||||
|
||||
input_text: str = dspy.InputField(
|
||||
desc="The original text from which entities and relationships were extracted."
|
||||
)
|
||||
entity_types: list[str] = dspy.InputField(
|
||||
desc="List of valid entity types for this extraction task."
|
||||
)
|
||||
current_entities: list[Entity] = dspy.InputField(
|
||||
desc="List of currently extracted entities to be refined."
|
||||
)
|
||||
current_relationships: list[Relationship] = dspy.InputField(
|
||||
desc="List of currently extracted relationships to be refined."
|
||||
)
|
||||
entity_critique: str = dspy.InputField(
|
||||
desc="Detailed critique of the current entities to guide refinement."
|
||||
)
|
||||
relationship_critique: str = dspy.InputField(
|
||||
desc="Detailed critique of the current relationships to guide refinement."
|
||||
)
|
||||
refined_entities: list[Entity] = dspy.OutputField(
|
||||
desc="List of refined entities, addressing the entity critique and improving upon the current entities."
|
||||
)
|
||||
refined_relationships: list[Relationship] = dspy.OutputField(
|
||||
desc="List of refined relationships, addressing the relationship critique and improving upon the current relationships."
|
||||
)
|
||||
|
||||
|
||||
class TypedEntityRelationshipExtractorException(dspy.Module):
|
||||
def __init__(
|
||||
self,
|
||||
predictor: dspy.Module,
|
||||
exception_types: tuple[type[Exception]] = (Exception,),
|
||||
):
|
||||
super().__init__()
|
||||
self.predictor = predictor
|
||||
self.exception_types = exception_types
|
||||
|
||||
def copy(self):
|
||||
return TypedEntityRelationshipExtractorException(self.predictor)
|
||||
|
||||
def forward(self, **kwargs):
|
||||
try:
|
||||
prediction = self.predictor(**kwargs)
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, self.exception_types):
|
||||
return dspy.Prediction(entities=[], relationships=[])
|
||||
|
||||
raise e
|
||||
|
||||
|
||||
class TypedEntityRelationshipExtractor(dspy.Module):
|
||||
def __init__(
|
||||
self,
|
||||
lm: dspy.LM = None,
|
||||
max_retries: int = 3,
|
||||
entity_types: list[str] = ENTITY_TYPES,
|
||||
self_refine: bool = False,
|
||||
num_refine_turns: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.lm = lm
|
||||
self.entity_types = entity_types
|
||||
self.self_refine = self_refine
|
||||
self.num_refine_turns = num_refine_turns
|
||||
|
||||
self.extractor = dspy.ChainOfThought(
|
||||
signature=CombinedExtraction, max_retries=max_retries
|
||||
)
|
||||
self.extractor = TypedEntityRelationshipExtractorException(
|
||||
self.extractor, exception_types=(ValueError,)
|
||||
)
|
||||
|
||||
if self.self_refine:
|
||||
self.critique = dspy.ChainOfThought(
|
||||
signature=CritiqueCombinedExtraction, max_retries=max_retries
|
||||
)
|
||||
self.refine = dspy.ChainOfThought(
|
||||
signature=RefineCombinedExtraction, max_retries=max_retries
|
||||
)
|
||||
|
||||
def forward(self, input_text: str) -> dspy.Prediction:
|
||||
with dspy.context(lm=self.lm if self.lm is not None else dspy.settings.lm):
|
||||
extraction_result = self.extractor(
|
||||
input_text=input_text, entity_types=self.entity_types
|
||||
)
|
||||
|
||||
current_entities: list[Entity] = extraction_result.entities
|
||||
current_relationships: list[Relationship] = extraction_result.relationships
|
||||
|
||||
if self.self_refine:
|
||||
for _ in range(self.num_refine_turns):
|
||||
critique_result = self.critique(
|
||||
input_text=input_text,
|
||||
entity_types=self.entity_types,
|
||||
current_entities=current_entities,
|
||||
current_relationships=current_relationships,
|
||||
)
|
||||
refined_result = self.refine(
|
||||
input_text=input_text,
|
||||
entity_types=self.entity_types,
|
||||
current_entities=current_entities,
|
||||
current_relationships=current_relationships,
|
||||
entity_critique=critique_result.entity_critique,
|
||||
relationship_critique=critique_result.relationship_critique,
|
||||
)
|
||||
logger.debug(
|
||||
f"entities: {len(current_entities)} | refined_entities: {len(refined_result.refined_entities)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"relationships: {len(current_relationships)} | refined_relationships: {len(refined_result.refined_relationships)}"
|
||||
)
|
||||
current_entities = refined_result.refined_entities
|
||||
current_relationships = refined_result.refined_relationships
|
||||
|
||||
entities = [entity.to_dict() for entity in current_entities]
|
||||
relationships = [
|
||||
relationship.to_dict() for relationship in current_relationships
|
||||
]
|
||||
|
||||
return dspy.Prediction(entities=entities, relationships=relationships)
|
||||
Reference in New Issue
Block a user