diff --git a/gliner_inference_server/main.py b/gliner_inference_server/main.py index bf1abec..0cfd97a 100644 --- a/gliner_inference_server/main.py +++ b/gliner_inference_server/main.py @@ -118,7 +118,16 @@ async def general_extraction(request: GeneralRequest): @app.post("/relation-extraction") async def relation_extraction(request: RelationExtractionRequest): - """Relation Extraction endpoint""" + """Relation Extraction endpoint + + Example usage: + { + "text": "Microsoft was founded by Bill Gates", + "relations": ["founder", "inception date"], + "entities": ["organization", "person", "date"], + "threshold": 0.5 + } + """ try: input_data = validate_single_or_batch(request.text, request.texts) is_batch = isinstance(input_data, list) @@ -127,7 +136,10 @@ async def relation_extraction(request: RelationExtractionRequest): results = [] for text in texts_to_process: - # First extract entities if provided + output = [] + + # Extract entities if specified + entities_dict = {} if request.entities: entities = model.predict_entities( text, @@ -135,37 +147,136 @@ async def relation_extraction(request: RelationExtractionRequest): threshold=request.threshold, flat_ner=True ) + # Index entities by position for quick lookup + for entity in entities: + key = (entity["start"], entity["end"]) + entities_dict[key] = entity + + # Form relation labels according to GLiNER documentation + # Format: "entity_type <> relation_type" + relation_labels = [] + + if request.entities and request.rules: + # If there are rules with pairs_filter, create specific labels + for rule in request.rules: + if rule.pairs_filter: + for pair in rule.pairs_filter: + # Format: "source_entity <> relation" + label = f"{pair[0]} <> {rule.relation}" + relation_labels.append(label) + else: + # Without pairs filter - general relation + relation_labels.append(f"<> {rule.relation}") + elif request.entities: + # If there are entities but no rules, create combinations + for relation in request.relations: + for entity_type in request.entities: + label = f"{entity_type} <> {relation}" + relation_labels.append(label) else: - entities = [] + # If there are no entities, just search for relations + for relation in request.relations: + relation_labels.append(f"<> {relation}") + + # Remove duplicates + relation_labels = list(set(relation_labels)) - # Extract relations - relations = model.predict_relations( + # Extract relations using GLiNER + relations_output = model.predict_entities( text, - request.relations, - threshold=request.threshold + relation_labels, + threshold=request.threshold, + flat_ner=True ) - output = [ - { + # Format the results + for rel in relations_output: + # GLiNER returns a text span that represents the relation + # E.g.: "Bill Gates" for label "Microsoft <> founder" + + # Parse the label to get entity type and relation + label_parts = rel["label"].split("<>") + + if len(label_parts) == 2: + source_entity_type = label_parts[0].strip() + relation_type = label_parts[1].strip() + else: + source_entity_type = "" + relation_type = rel["label"] + + # The extracted span is typically the target entity + target_span = rel["text"] + target_start = rel["start"] + target_end = rel["end"] + + # Attempt to find the source entity in context + # Find the closest entity of the specified type before the target + source_entity = None + if source_entity_type and entities_dict: + # Find entities of the specified type that precede the target + candidates = [ + e for e in entities_dict.values() + if e["label"].lower() == source_entity_type.lower() + and e["end"] <= target_start + ] + # Take the nearest one + if candidates: + source_entity = max(candidates, key=lambda x: x["end"]) + + # Determine target entity type + target_entity_type = "" + target_key = (target_start, target_end) + if target_key in entities_dict: + target_entity_type = entities_dict[target_key]["label"] + + # Apply rules if present + if request.rules: + skip = False + for rule in request.rules: + if rule.relation == relation_type: + # Check distance + if rule.distance and source_entity: + distance = target_start - source_entity["end"] + if distance > rule.distance: + skip = True + break + + # Check pairs_filter + if rule.pairs_filter and source_entity: + valid_pair = any( + source_entity["label"].lower() == pair[0].lower() and + target_entity_type.lower() == pair[1].lower() + for pair in rule.pairs_filter + ) + if not valid_pair: + skip = True + break + + if skip: + continue + + # Form the output object + relation_obj = { "source": { - "entity": rel.get("head_type", ""), - "span": rel.get("head_text", ""), - "start": rel.get("head_start", 0), - "end": rel.get("head_end", 0), - "score": float(rel.get("head_score", 0.0)) + "entity": source_entity["label"] if source_entity else source_entity_type, + "span": source_entity["text"] if source_entity else "", + "start": source_entity["start"] if source_entity else 0, + "end": source_entity["end"] if source_entity else 0, + "score": float(source_entity.get("score", 0.0)) if source_entity else 0.0 }, - "relation": rel["label"], + "relation": relation_type, "target": { - "entity": rel.get("tail_type", ""), - "span": rel.get("tail_text", ""), - "start": rel.get("tail_start", 0), - "end": rel.get("tail_end", 0), - "score": float(rel.get("tail_score", 0.0)) + "entity": target_entity_type, + "span": target_span, + "start": target_start, + "end": target_end, + "score": float(rel.get("score", 0.0)) }, - "score": float(rel["score"]) + "score": float(rel.get("score", 0.0)) } - for rel in relations - ] + + output.append(relation_obj) + results.append(output) if is_batch: