63 lines
2.5 KiB
Python
63 lines
2.5 KiB
Python
import dspy
|
|
from nano_graphrag.entity_extraction.module import Relationship
|
|
|
|
|
|
# dspy.Signature whose docstring is sent to the LM as the task instructions.
# NOTE: the docstring below is runtime prompt text, not mere documentation —
# editing it changes the judge's behavior.
class AssessRelationships(dspy.Signature):
    """
    Assess the similarity between gold and predicted relationships:
    1. Match relationships based on src_id and tgt_id pairs, allowing for slight variations in entity names.
    2. For matched pairs, compare:
       a) Description similarity (semantic meaning)
       b) Weight similarity
       c) Order similarity
    3. Consider unmatched relationships as penalties.
    4. Aggregate scores, accounting for precision and recall.
    5. Return a final similarity score between 0 (no similarity) and 1 (perfect match).

    Key considerations:
    - Prioritize matching based on entity pairs over exact string matches.
    - Use semantic similarity for descriptions rather than exact matches.
    - Weight the importance of different aspects (e.g., entity matching, description, weight, order).
    - Balance the impact of matched and unmatched relationships in the final score.
    """

    # Reference relationships the prediction is judged against.
    gold_relationships: list[Relationship] = dspy.InputField(
        desc="The gold-standard relationships to compare against."
    )
    # Relationships produced by the extraction module under evaluation.
    predicted_relationships: list[Relationship] = dspy.InputField(
        desc="The predicted relationships to compare against the gold-standard relationships."
    )
    # LM-estimated similarity; instructed (but not enforced) to lie in [0, 1].
    similarity_score: float = dspy.OutputField(
        desc="Similarity score between 0 and 1, with 1 being the highest similarity."
    )
|
|
|
|
|
|
def relationships_similarity_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """Score predicted relationships against gold ones with an LM judge.

    Builds ``Relationship`` objects from the ``"relationships"`` entries of
    both ``gold`` and ``pred``, then asks a ``ChainOfThought`` program over
    :class:`AssessRelationships` for a similarity score.

    Args:
        gold: Example carrying the gold-standard ``"relationships"`` list of dicts.
        pred: Prediction carrying the predicted ``"relationships"`` list of dicts.
        trace: Unused; present to satisfy the dspy metric signature.

    Returns:
        Similarity score clamped to [0.0, 1.0].

    Raises:
        ValueError: If the LM's ``similarity_score`` is not parseable as a float.
    """
    model = dspy.ChainOfThought(AssessRelationships)
    gold_relationships = [Relationship(**item) for item in gold["relationships"]]
    predicted_relationships = [Relationship(**item) for item in pred["relationships"]]
    similarity_score = float(
        model(
            gold_relationships=gold_relationships,
            predicted_relationships=predicted_relationships,
        ).similarity_score
    )
    # The signature instructs the LM to answer in [0, 1], but nothing enforces
    # it — clamp defensively so one stray answer cannot skew metric aggregation.
    return min(1.0, max(0.0, similarity_score))
|
|
|
|
|
|
def entity_recall_metric(
    gold: dspy.Example, pred: dspy.Prediction, trace=None
) -> float:
    """Compute entity-name recall of ``pred`` against ``gold``.

    Recall = |gold ∩ predicted| / |gold|, matching on exact
    ``"entity_name"`` values. Returns 0 when there are no gold entities.

    Args:
        gold: Example carrying the gold ``"entities"`` list of dicts.
        pred: Prediction carrying the predicted ``"entities"`` list of dicts.
        trace: Unused; present to satisfy the dspy metric signature.

    Returns:
        Recall in [0, 1].
    """
    gold_names = {entity["entity_name"] for entity in gold["entities"]}
    pred_names = {entity["entity_name"] for entity in pred["entities"]}

    # |gold ∩ pred| + |gold - pred| == |gold|, so recall is hits over gold size.
    if not gold_names:
        return 0
    hits = len(gold_names & pred_names)
    return hits / len(gold_names)
|