From fc286b34edeff5d9ad92abdd438f749e0aca3fc6 Mon Sep 17 00:00:00 2001 From: KKlochko Date: Sun, 16 Nov 2025 15:39:24 +0200 Subject: [PATCH] Add the prompt support for endpoints. --- gliner_inference_server/helpers.py | 5 ++ gliner_inference_server/main.py | 118 ++++++++++++++++++++++------- 2 files changed, 95 insertions(+), 28 deletions(-) diff --git a/gliner_inference_server/helpers.py b/gliner_inference_server/helpers.py index 8a659d0..0eaf3a6 100644 --- a/gliner_inference_server/helpers.py +++ b/gliner_inference_server/helpers.py @@ -97,3 +97,8 @@ def handle_error(e: Exception): } } ) + +def add_prompt_to_text(text: str, prompt: Optional[str]): + if prompt in [None, ""]: + return text + return prompt + text \ No newline at end of file diff --git a/gliner_inference_server/main.py b/gliner_inference_server/main.py index 2790c2b..8c31301 100644 --- a/gliner_inference_server/main.py +++ b/gliner_inference_server/main.py @@ -111,7 +111,16 @@ async def health_check(): @app.post("/general") async def general_extraction(request: GeneralRequest): - """Named Entity Recognition endpoint""" + """General Named Entity Recognition endpoint with prompt support + + Example usage: + { + "text": "Apple Inc. is located in Cupertino, California.", + "entities": ["organization", "location"], + "prompt": "Extract business entities:\n", + "threshold": 0.5 + } + """ try: input_data = validate_single_or_batch(request.text, request.texts) is_batch = isinstance(input_data, list) @@ -121,23 +130,35 @@ async def general_extraction(request: GeneralRequest): results = [] for text in texts_to_process: + # Calculate prompt offset for position adjustment + prompt_offset = len(request.prompt) if request.prompt else 0 + + # Apply prompt to text + text_with_prompt = add_prompt_to_text(text, request.prompt) + + # Extract entities with prompt entities = model.predict_entities( - text, + text_with_prompt, request.entities, threshold=request.threshold, flat_ner=True ) - output = [ - { + output = [] + for ent in entities: + # Skip entities found in the prompt itself + if ent["start"] < prompt_offset: + continue + + # Adjust positions to be relative to original text (without prompt) + output.append({ "entity": ent["label"], "span": ent["text"], - "start": ent["start"], - "end": ent["end"], + "start": ent["start"] - prompt_offset, + "end": ent["end"] - prompt_offset, "score": float(ent["score"]) - } - for ent in entities - ] + }) + results.append(output) # Return based on input mode @@ -160,6 +181,7 @@ async def relation_extraction(request: RelationExtractionRequest): "text": "Microsoft was founded by Bill Gates", "relations": ["founder", "inception date"], "entities": ["organization", "person", "date"], + "prompt": "Extract business relationships:\n", "threshold": 0.5 } """ @@ -173,19 +195,36 @@ async def relation_extraction(request: RelationExtractionRequest): for text in texts_to_process: output = [] + # Calculate prompt offset for position adjustment + prompt_offset = len(request.prompt) if request.prompt else 0 + + # Apply prompt to text + text_with_prompt = add_prompt_to_text(text, request.prompt) + # Extract entities if specified entities_dict = {} if request.entities: entities = model.predict_entities( - text, + text_with_prompt, request.entities, threshold=request.threshold, flat_ner=True ) # Index entities by position for quick lookup + # Adjust positions to be relative to original text (without prompt) for entity in entities: - key = (entity["start"], entity["end"]) - entities_dict[key] = entity + # Only include entities that are in the actual text, not in the prompt + if entity["start"] >= prompt_offset: + adjusted_start = entity["start"] - prompt_offset + adjusted_end = entity["end"] - prompt_offset + key = (adjusted_start, adjusted_end) + entities_dict[key] = { + "label": entity["label"], + "text": entity["text"], + "start": adjusted_start, + "end": adjusted_end, + "score": entity.get("score", 0.0) + } # Form relation labels according to GLiNER documentation # Format: "entity_type <> relation_type" @@ -216,9 +255,9 @@ async def relation_extraction(request: RelationExtractionRequest): # Remove duplicates relation_labels = list(set(relation_labels)) - # Extract relations using GLiNER + # Extract relations using GLiNER with prompt relations_output = model.predict_entities( - text, + text_with_prompt, relation_labels, threshold=request.threshold, flat_ner=True @@ -229,6 +268,15 @@ async def relation_extraction(request: RelationExtractionRequest): # GLiNER returns a text span that represents the relation # E.g.: "Bill Gates" for label "Microsoft <> founder" + # Skip relations found in the prompt itself + if rel["start"] < prompt_offset: + continue + + # Adjust positions to be relative to original text + target_start = rel["start"] - prompt_offset + target_end = rel["end"] - prompt_offset + target_span = rel["text"] + # Parse the label to get entity type and relation label_parts = rel["label"].split("<>") @@ -239,11 +287,6 @@ async def relation_extraction(request: RelationExtractionRequest): source_entity_type = "" relation_type = rel["label"] - # The extracted span is typically the target entity - target_span = rel["text"] - target_start = rel["start"] - target_end = rel["end"] - # Attempt to find the source entity in context # Find the closest entity of the specified type before the target source_entity = None @@ -326,7 +369,15 @@ async def relation_extraction(request: RelationExtractionRequest): @app.post("/summarization") async def summarization(request: SummarizationRequest): - """Summarization endpoint""" + """Summarization endpoint with prompt support + + Example usage: + { + "text": "Long article text here...", + "prompt": "Extract the most important points:\n", + "threshold": 0.5 + } + """ try: input_data = validate_single_or_batch(request.text, request.texts) is_batch = isinstance(input_data, list) @@ -335,23 +386,34 @@ async def summarization(request: SummarizationRequest): results = [] for text in texts_to_process: + # Calculate prompt offset for position adjustment + prompt_offset = len(request.prompt) if request.prompt else 0 + + # Apply prompt to text + text_with_prompt = add_prompt_to_text(text, request.prompt) + # For summarization, extract key phrases/sentences summaries = model.predict_entities( - text, + text_with_prompt, ["summary", "key point", "important information"], threshold=request.threshold, flat_ner=True ) - output = [ - { + output = [] + for summ in summaries: + # Skip summaries found in the prompt itself + if summ["start"] < prompt_offset: + continue + + # Adjust positions to be relative to original text (without prompt) + output.append({ "text": summ["text"], - "start": summ["start"], - "end": summ["end"], + "start": summ["start"] - prompt_offset, + "end": summ["end"] - prompt_offset, "score": float(summ["score"]) - } - for summ in summaries - ] + }) + results.append(output) if is_batch: