|
|
|
|
@ -118,7 +118,16 @@ async def general_extraction(request: GeneralRequest):
|
|
|
|
|
|
|
|
|
|
@app.post("/relation-extraction")
|
|
|
|
|
async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
"""Relation Extraction endpoint"""
|
|
|
|
|
"""Relation Extraction endpoint
|
|
|
|
|
|
|
|
|
|
Example usage:
|
|
|
|
|
{
|
|
|
|
|
"text": "Microsoft was founded by Bill Gates",
|
|
|
|
|
"relations": ["founder", "inception date"],
|
|
|
|
|
"entities": ["organization", "person", "date"],
|
|
|
|
|
"threshold": 0.5
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
input_data = validate_single_or_batch(request.text, request.texts)
|
|
|
|
|
is_batch = isinstance(input_data, list)
|
|
|
|
|
@ -127,7 +136,10 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
|
|
|
|
|
results = []
|
|
|
|
|
for text in texts_to_process:
|
|
|
|
|
# First extract entities if provided
|
|
|
|
|
output = []
|
|
|
|
|
|
|
|
|
|
# Extract entities if specified
|
|
|
|
|
entities_dict = {}
|
|
|
|
|
if request.entities:
|
|
|
|
|
entities = model.predict_entities(
|
|
|
|
|
text,
|
|
|
|
|
@ -135,37 +147,136 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
threshold=request.threshold,
|
|
|
|
|
flat_ner=True
|
|
|
|
|
)
|
|
|
|
|
# Index entities by position for quick lookup
|
|
|
|
|
for entity in entities:
|
|
|
|
|
key = (entity["start"], entity["end"])
|
|
|
|
|
entities_dict[key] = entity
|
|
|
|
|
|
|
|
|
|
# Form relation labels according to GLiNER documentation
|
|
|
|
|
# Format: "entity_type <> relation_type"
|
|
|
|
|
relation_labels = []
|
|
|
|
|
|
|
|
|
|
if request.entities and request.rules:
|
|
|
|
|
# If there are rules with pairs_filter, create specific labels
|
|
|
|
|
for rule in request.rules:
|
|
|
|
|
if rule.pairs_filter:
|
|
|
|
|
for pair in rule.pairs_filter:
|
|
|
|
|
# Format: "source_entity <> relation"
|
|
|
|
|
label = f"{pair[0]} <> {rule.relation}"
|
|
|
|
|
relation_labels.append(label)
|
|
|
|
|
else:
|
|
|
|
|
# Without pairs filter - general relation
|
|
|
|
|
relation_labels.append(f"<> {rule.relation}")
|
|
|
|
|
elif request.entities:
|
|
|
|
|
# If there are entities but no rules, create combinations
|
|
|
|
|
for relation in request.relations:
|
|
|
|
|
for entity_type in request.entities:
|
|
|
|
|
label = f"{entity_type} <> {relation}"
|
|
|
|
|
relation_labels.append(label)
|
|
|
|
|
else:
|
|
|
|
|
entities = []
|
|
|
|
|
# If there are no entities, just search for relations
|
|
|
|
|
for relation in request.relations:
|
|
|
|
|
relation_labels.append(f"<> {relation}")
|
|
|
|
|
|
|
|
|
|
# Extract relations
|
|
|
|
|
relations = model.predict_relations(
|
|
|
|
|
# Remove duplicates
|
|
|
|
|
relation_labels = list(set(relation_labels))
|
|
|
|
|
|
|
|
|
|
# Extract relations using GLiNER
|
|
|
|
|
relations_output = model.predict_entities(
|
|
|
|
|
text,
|
|
|
|
|
request.relations,
|
|
|
|
|
threshold=request.threshold
|
|
|
|
|
relation_labels,
|
|
|
|
|
threshold=request.threshold,
|
|
|
|
|
flat_ner=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
output = [
|
|
|
|
|
{
|
|
|
|
|
# Format the results
|
|
|
|
|
for rel in relations_output:
|
|
|
|
|
# GLiNER returns a text span that represents the relation
|
|
|
|
|
# E.g.: "Bill Gates" for label "Microsoft <> founder"
|
|
|
|
|
|
|
|
|
|
# Parse the label to get entity type and relation
|
|
|
|
|
label_parts = rel["label"].split("<>")
|
|
|
|
|
|
|
|
|
|
if len(label_parts) == 2:
|
|
|
|
|
source_entity_type = label_parts[0].strip()
|
|
|
|
|
relation_type = label_parts[1].strip()
|
|
|
|
|
else:
|
|
|
|
|
source_entity_type = ""
|
|
|
|
|
relation_type = rel["label"]
|
|
|
|
|
|
|
|
|
|
# The extracted span is typically the target entity
|
|
|
|
|
target_span = rel["text"]
|
|
|
|
|
target_start = rel["start"]
|
|
|
|
|
target_end = rel["end"]
|
|
|
|
|
|
|
|
|
|
# Attempt to find the source entity in context
|
|
|
|
|
# Find the closest entity of the specified type before the target
|
|
|
|
|
source_entity = None
|
|
|
|
|
if source_entity_type and entities_dict:
|
|
|
|
|
# Find entities of the specified type that precede the target
|
|
|
|
|
candidates = [
|
|
|
|
|
e for e in entities_dict.values()
|
|
|
|
|
if e["label"].lower() == source_entity_type.lower()
|
|
|
|
|
and e["end"] <= target_start
|
|
|
|
|
]
|
|
|
|
|
# Take the nearest one
|
|
|
|
|
if candidates:
|
|
|
|
|
source_entity = max(candidates, key=lambda x: x["end"])
|
|
|
|
|
|
|
|
|
|
# Determine target entity type
|
|
|
|
|
target_entity_type = ""
|
|
|
|
|
target_key = (target_start, target_end)
|
|
|
|
|
if target_key in entities_dict:
|
|
|
|
|
target_entity_type = entities_dict[target_key]["label"]
|
|
|
|
|
|
|
|
|
|
# Apply rules if present
|
|
|
|
|
if request.rules:
|
|
|
|
|
skip = False
|
|
|
|
|
for rule in request.rules:
|
|
|
|
|
if rule.relation == relation_type:
|
|
|
|
|
# Check distance
|
|
|
|
|
if rule.distance and source_entity:
|
|
|
|
|
distance = target_start - source_entity["end"]
|
|
|
|
|
if distance > rule.distance:
|
|
|
|
|
skip = True
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# Check pairs_filter
|
|
|
|
|
if rule.pairs_filter and source_entity:
|
|
|
|
|
valid_pair = any(
|
|
|
|
|
source_entity["label"].lower() == pair[0].lower() and
|
|
|
|
|
target_entity_type.lower() == pair[1].lower()
|
|
|
|
|
for pair in rule.pairs_filter
|
|
|
|
|
)
|
|
|
|
|
if not valid_pair:
|
|
|
|
|
skip = True
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if skip:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Form the output object
|
|
|
|
|
relation_obj = {
|
|
|
|
|
"source": {
|
|
|
|
|
"entity": rel.get("head_type", ""),
|
|
|
|
|
"span": rel.get("head_text", ""),
|
|
|
|
|
"start": rel.get("head_start", 0),
|
|
|
|
|
"end": rel.get("head_end", 0),
|
|
|
|
|
"score": float(rel.get("head_score", 0.0))
|
|
|
|
|
"entity": source_entity["label"] if source_entity else source_entity_type,
|
|
|
|
|
"span": source_entity["text"] if source_entity else "",
|
|
|
|
|
"start": source_entity["start"] if source_entity else 0,
|
|
|
|
|
"end": source_entity["end"] if source_entity else 0,
|
|
|
|
|
"score": float(source_entity.get("score", 0.0)) if source_entity else 0.0
|
|
|
|
|
},
|
|
|
|
|
"relation": rel["label"],
|
|
|
|
|
"relation": relation_type,
|
|
|
|
|
"target": {
|
|
|
|
|
"entity": rel.get("tail_type", ""),
|
|
|
|
|
"span": rel.get("tail_text", ""),
|
|
|
|
|
"start": rel.get("tail_start", 0),
|
|
|
|
|
"end": rel.get("tail_end", 0),
|
|
|
|
|
"score": float(rel.get("tail_score", 0.0))
|
|
|
|
|
"entity": target_entity_type,
|
|
|
|
|
"span": target_span,
|
|
|
|
|
"start": target_start,
|
|
|
|
|
"end": target_end,
|
|
|
|
|
"score": float(rel.get("score", 0.0))
|
|
|
|
|
},
|
|
|
|
|
"score": float(rel["score"])
|
|
|
|
|
"score": float(rel.get("score", 0.0))
|
|
|
|
|
}
|
|
|
|
|
for rel in relations
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
output.append(relation_obj)
|
|
|
|
|
|
|
|
|
|
results.append(output)
|
|
|
|
|
|
|
|
|
|
if is_batch:
|
|
|
|
|
|