From f2d1ec9957acd1216f5825b1c3a86cbe16838c2c Mon Sep 17 00:00:00 2001
From: KKlochko
Date: Fri, 14 Nov 2025 15:16:25 +0200
Subject: [PATCH] Add helper methods and endpoint implementations.

---
 gliner_inference_server/helpers.py |  72 +++++++++++++++
 gliner_inference_server/main.py    | 139 +++++++++++++++++++++++++++--
 2 files changed, 204 insertions(+), 7 deletions(-)
 create mode 100644 gliner_inference_server/helpers.py

diff --git a/gliner_inference_server/helpers.py b/gliner_inference_server/helpers.py
new file mode 100644
index 0000000..3b1dedd
--- /dev/null
+++ b/gliner_inference_server/helpers.py
@@ -0,0 +1,72 @@
+from fastapi import HTTPException, status
+from typing import List, Optional, Union, Dict, Any, Tuple
+
+
+def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
+    """Validate that either text or texts is provided, but not both"""
+    if text is None and texts is None:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "error": {
+                    "code": "INVALID_INPUT",
+                    "message": "Either 'text' or 'texts' must be provided",
+                    "details": {}
+                }
+            }
+        )
+    if text is not None and texts is not None:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "error": {
+                    "code": "INVALID_INPUT",
+                    "message": "Provide either 'text' or 'texts', not both",
+                    "details": {}
+                }
+            }
+        )
+
+    if texts is not None and len(texts) == 0:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "error": {
+                    "code": "INVALID_INPUT",
+                    "message": "texts list cannot be empty",
+                    "details": {}
+                }
+            }
+        )
+
+    return text if text is not None else texts
+
+
+def handle_error(e: Exception):
+    """Convert exceptions to consistent error format"""
+    if isinstance(e, HTTPException):
+        raise e
+
+    if isinstance(e, ValueError):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "error": {
+                    "code": "INVALID_INPUT",
+                    "message": str(e),
+                    "details": {}
+                }
+            }
+        )
+
+    # Internal server error
+    raise HTTPException(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        detail={
+            "error": {
+                "code": "INTERNAL_ERROR",
+                "message": "An internal error occurred while processing your request",
+                "details": {}
+            }
+        }
+    )
diff --git a/gliner_inference_server/main.py b/gliner_inference_server/main.py
index f22a45e..f4f34de 100644
--- a/gliner_inference_server/main.py
+++ b/gliner_inference_server/main.py
@@ -5,6 +5,7 @@ import os
 import torch
 from gliner import GLiNER
 from .models import *
+from .helpers import *
 
 # Global model instance
 model = None
@@ -30,7 +31,7 @@ app = FastAPI(
     lifespan=lifespan
 )
 
-
+# ==================== Endpoints ====================
 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
@@ -38,21 +39,145 @@
 
 
 @app.post("/general")
-async def general_extraction(request):
+async def general_extraction(request: GeneralRequest):
     """Named Entity Recognition endpoint"""
-    pass
+    try:
+        input_data = validate_single_or_batch(request.text, request.texts)
+        is_batch = isinstance(input_data, list)
+
+        # Process batch or single
+        texts_to_process = input_data if is_batch else [input_data]
+
+        results = []
+        for text in texts_to_process:
+            entities = model.predict_entities(
+                text,
+                request.entities,
+                threshold=request.threshold,
+                flat_ner=True
+            )
+
+            output = [
+                {
+                    "entity": ent["label"],
+                    "span": ent["text"],
+                    "start": ent["start"],
+                    "end": ent["end"],
+                    "score": float(ent["score"])
+                }
+                for ent in entities
+            ]
+            results.append(output)
+
+        # Return based on input mode
+        if is_batch:
+            return {"outputs": results}
+        else:
+            return {"outputs": results[0]}
+
+    except Exception as e:
+        handle_error(e)
 
 
 @app.post("/relation-extraction")
-async def relation_extraction(request):
+async def relation_extraction(request: RelationExtractionRequest):
     """Relation Extraction endpoint"""
-    pass
+    try:
+        input_data = validate_single_or_batch(request.text, request.texts)
+        is_batch = isinstance(input_data, list)
+
+        texts_to_process = input_data if is_batch else [input_data]
+
+        results = []
+        for text in texts_to_process:
+            # First extract entities if provided
+            if request.entities:
+                entities = model.predict_entities(
+                    text,
+                    request.entities,
+                    threshold=request.threshold,
+                    flat_ner=True
+                )
+            else:
+                entities = []
+
+            # Extract relations
+            relations = model.predict_relations(
+                text,
+                request.relations,
+                threshold=request.threshold
+            )
+
+            output = [
+                {
+                    "source": {
+                        "entity": rel.get("head_type", ""),
+                        "span": rel.get("head_text", ""),
+                        "start": rel.get("head_start", 0),
+                        "end": rel.get("head_end", 0),
+                        "score": float(rel.get("head_score", 0.0))
+                    },
+                    "relation": rel["label"],
+                    "target": {
+                        "entity": rel.get("tail_type", ""),
+                        "span": rel.get("tail_text", ""),
+                        "start": rel.get("tail_start", 0),
+                        "end": rel.get("tail_end", 0),
+                        "score": float(rel.get("tail_score", 0.0))
+                    },
+                    "score": float(rel["score"])
+                }
+                for rel in relations
+            ]
+            results.append(output)
+
+        if is_batch:
+            return {"outputs": results}
+        else:
+            return {"outputs": results[0]}
+
+    except Exception as e:
+        handle_error(e)
 
 
 @app.post("/summarization")
-async def summarization(request):
+async def summarization(request: SummarizationRequest):
     """Summarization endpoint"""
-    pass
+    try:
+        input_data = validate_single_or_batch(request.text, request.texts)
+        is_batch = isinstance(input_data, list)
+
+        texts_to_process = input_data if is_batch else [input_data]
+
+        results = []
+        for text in texts_to_process:
+            # For summarization, extract key phrases/sentences
+            summaries = model.predict_entities(
+                text,
+                ["summary", "key point", "important information"],
+                threshold=request.threshold,
+                flat_ner=True
+            )
+
+            output = [
+                {
+                    "text": summ["text"],
+                    "start": summ["start"],
+                    "end": summ["end"],
+                    "score": float(summ["score"])
+                }
+                for summ in summaries
+            ]
+            results.append(output)
+
+        if is_batch:
+            return {"outputs": results}
+        else:
+            return {"outputs": results[0]}
+
+    except Exception as e:
+        handle_error(e)
+
 
 if __name__ == "__main__":
     import uvicorn
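
A minimal client-side sketch of the request contract the endpoints above implement: single "text" vs. batch "texts" input, the shared "outputs" response shape, and the INVALID_INPUT error envelope from helpers.py. The base URL, entity labels, and sample sentences are illustrative assumptions, not part of this patch:

# Illustrative usage sketch (assumes the server runs locally on port 8000 and
# that GeneralRequest exposes text/texts, entities, and threshold fields, as
# this patch suggests).
import requests

BASE = "http://localhost:8000"  # assumed host/port

# Single-text mode: "outputs" is one list of entity dicts.
resp = requests.post(f"{BASE}/general", json={
    "text": "Ada Lovelace worked with Charles Babbage in London.",
    "entities": ["person", "location"],
    "threshold": 0.5,
})
print(resp.json())

# Batch mode: "outputs" is a list of per-text lists, in input order.
resp = requests.post(f"{BASE}/general", json={
    "texts": ["Ada Lovelace was born in London.",
              "Charles Babbage designed the Analytical Engine."],
    "entities": ["person", "location"],
    "threshold": 0.5,
})
print(resp.json())

# Supplying both "text" and "texts" is rejected by validate_single_or_batch
# with the structured HTTP 400 error envelope defined in helpers.py.
resp = requests.post(f"{BASE}/general", json={
    "text": "x", "texts": ["y"], "entities": ["person"],
})
print(resp.status_code, resp.json())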