Fix the implentation for relation extraction endpoint.

main
KKlochko 1 month ago
parent 295b9368f9
commit 9b1661c9ef
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -118,7 +118,16 @@ async def general_extraction(request: GeneralRequest):
@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
"""Relation Extraction endpoint"""
"""Relation Extraction endpoint
Example usage:
{
"text": "Microsoft was founded by Bill Gates",
"relations": ["founder", "inception date"],
"entities": ["organization", "person", "date"],
"threshold": 0.5
}
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
is_batch = isinstance(input_data, list)
@ -127,7 +136,10 @@ async def relation_extraction(request: RelationExtractionRequest):
results = []
for text in texts_to_process:
# First extract entities if provided
output = []
# Extract entities if specified
entities_dict = {}
if request.entities:
entities = model.predict_entities(
text,
@ -135,37 +147,136 @@ async def relation_extraction(request: RelationExtractionRequest):
threshold=request.threshold,
flat_ner=True
)
# Index entities by position for quick lookup
for entity in entities:
key = (entity["start"], entity["end"])
entities_dict[key] = entity
# Form relation labels according to GLiNER documentation
# Format: "entity_type <> relation_type"
relation_labels = []
if request.entities and request.rules:
# If there are rules with pairs_filter, create specific labels
for rule in request.rules:
if rule.pairs_filter:
for pair in rule.pairs_filter:
# Format: "source_entity <> relation"
label = f"{pair[0]} <> {rule.relation}"
relation_labels.append(label)
else:
# Without pairs filter - general relation
relation_labels.append(f"<> {rule.relation}")
elif request.entities:
# If there are entities but no rules, create combinations
for relation in request.relations:
for entity_type in request.entities:
label = f"{entity_type} <> {relation}"
relation_labels.append(label)
else:
entities = []
# If there are no entities, just search for relations
for relation in request.relations:
relation_labels.append(f"<> {relation}")
# Extract relations
relations = model.predict_relations(
# Remove duplicates
relation_labels = list(set(relation_labels))
# Extract relations using GLiNER
relations_output = model.predict_entities(
text,
request.relations,
threshold=request.threshold
relation_labels,
threshold=request.threshold,
flat_ner=True
)
output = [
{
# Format the results
for rel in relations_output:
# GLiNER returns a text span that represents the relation
# E.g.: "Bill Gates" for label "Microsoft <> founder"
# Parse the label to get entity type and relation
label_parts = rel["label"].split("<>")
if len(label_parts) == 2:
source_entity_type = label_parts[0].strip()
relation_type = label_parts[1].strip()
else:
source_entity_type = ""
relation_type = rel["label"]
# The extracted span is typically the target entity
target_span = rel["text"]
target_start = rel["start"]
target_end = rel["end"]
# Attempt to find the source entity in context
# Find the closest entity of the specified type before the target
source_entity = None
if source_entity_type and entities_dict:
# Find entities of the specified type that precede the target
candidates = [
e for e in entities_dict.values()
if e["label"].lower() == source_entity_type.lower()
and e["end"] <= target_start
]
# Take the nearest one
if candidates:
source_entity = max(candidates, key=lambda x: x["end"])
# Determine target entity type
target_entity_type = ""
target_key = (target_start, target_end)
if target_key in entities_dict:
target_entity_type = entities_dict[target_key]["label"]
# Apply rules if present
if request.rules:
skip = False
for rule in request.rules:
if rule.relation == relation_type:
# Check distance
if rule.distance and source_entity:
distance = target_start - source_entity["end"]
if distance > rule.distance:
skip = True
break
# Check pairs_filter
if rule.pairs_filter and source_entity:
valid_pair = any(
source_entity["label"].lower() == pair[0].lower() and
target_entity_type.lower() == pair[1].lower()
for pair in rule.pairs_filter
)
if not valid_pair:
skip = True
break
if skip:
continue
# Form the output object
relation_obj = {
"source": {
"entity": rel.get("head_type", ""),
"span": rel.get("head_text", ""),
"start": rel.get("head_start", 0),
"end": rel.get("head_end", 0),
"score": float(rel.get("head_score", 0.0))
"entity": source_entity["label"] if source_entity else source_entity_type,
"span": source_entity["text"] if source_entity else "",
"start": source_entity["start"] if source_entity else 0,
"end": source_entity["end"] if source_entity else 0,
"score": float(source_entity.get("score", 0.0)) if source_entity else 0.0
},
"relation": rel["label"],
"relation": relation_type,
"target": {
"entity": rel.get("tail_type", ""),
"span": rel.get("tail_text", ""),
"start": rel.get("tail_start", 0),
"end": rel.get("tail_end", 0),
"score": float(rel.get("tail_score", 0.0))
"entity": target_entity_type,
"span": target_span,
"start": target_start,
"end": target_end,
"score": float(rel.get("score", 0.0))
},
"score": float(rel["score"])
"score": float(rel.get("score", 0.0))
}
for rel in relations
]
output.append(relation_obj)
results.append(output)
if is_batch:

Loading…
Cancel
Save