import os
from contextlib import asynccontextmanager
from typing import List, Optional, Union, Dict, Any, Tuple

import torch
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
from gliner import GLiNER
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import Request

from .models import *
from .helpers import *

# Global model instance; set by ``lifespan`` at startup, released at shutdown.
model = None

# Labels the summarization endpoint asks GLiNER to tag; the tagged spans are
# returned as the (extractive) summary.
_SUMMARY_LABELS = ("summary", "key point", "important information")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GLiNER model on startup and release it on shutdown.

    The checkpoint name is read from the ``MODEL_NAME`` environment variable
    (default ``knowledgator/gliner-multitask-large-v0.5``). The model is moved
    to CUDA when ``torch.cuda.is_available()``, otherwise to CPU.
    """
    global model
    print("Loading GLiNER model...")
    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER.from_pretrained(model_name).to(device)
    print(f"Model loaded on {device}")
    yield
    print("Shutting down...")
    # Drop the module-level reference so the weights can be garbage collected.
    model = None


app = FastAPI(
    title="GLiNER Inference Server",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
    version="1.0.0",
    lifespan=lifespan,
)


# Registered on Starlette's HTTPException so that both FastAPI's HTTPException
# (a subclass) and Starlette-raised routing errors (e.g. 404 / 405) go through
# the same handler and share one response shape.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """
    Render HTTP exceptions with a consistent, flat error body (no 'detail' wrapper).

    If 'exc.detail' is already a dictionary with an 'error' key
    (e.g. {"error": {...}}), it is returned as-is. Any other detail
    (e.g. a plain string such as "Not Found") is wrapped into the
    {"error": {"code": "GENERIC_ERROR", "message": ..., "details": {}}} shape.

    :param request: The incoming request object.
    :param exc: The HTTPException raised (e.g. in validate_single_or_batch, or
        by routing for unknown paths / methods).
    :return: A JSONResponse with the error body, the exception's status code
        and any headers the exception carries (e.g. 'Allow' on a 405).
    """
    if isinstance(exc.detail, dict) and "error" in exc.detail:
        # Already in the custom format; use it verbatim.
        response_body = exc.detail
    else:
        # Fallback for string / unexpected details so clients always see the
        # same top-level structure.
        response_body = {
            "error": {
                "code": "GENERIC_ERROR",
                "message": str(exc.detail),
                "details": {},
            }
        }
    return JSONResponse(
        status_code=exc.status_code,
        content=response_body,
        headers=getattr(exc, "headers", None),
    )


# ==================== Helpers ====================


def _run_per_text(request, predict_one):
    """Validate single/batch input, run ``predict_one`` on each text, shape the reply.

    NOTE(review): inference runs synchronously inside ``async def`` endpoints,
    so it blocks the event loop (including /health) while it runs. Consider
    ``def`` endpoints or ``run_in_threadpool`` once concurrent use of the model
    has been confirmed safe.

    :param request: Request model exposing ``text`` and ``texts``; both are
        forwarded to ``validate_single_or_batch``.
    :param predict_one: Callable mapping one input text to its list of output dicts.
    :return: ``{"outputs": [result, ...]}`` when ``validate_single_or_batch``
        returns a list (batch mode), otherwise ``{"outputs": result}``.
    :raises HTTPException: Re-raised unchanged (e.g. validation errors), or a
        503 when the model has not been loaded.
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        if model is None:
            # e.g. request served before lifespan startup completed.
            raise HTTPException(
                status_code=503,
                detail={
                    "error": {
                        "code": "MODEL_NOT_LOADED",
                        "message": "The model is not loaded yet.",
                        "details": {},
                    }
                },
            )
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]
        results = [predict_one(text) for text in texts_to_process]
    except HTTPException:
        # Already shaped for the client (see http_exception_handler); do not let
        # the generic error path below reinterpret it.
        raise
    except Exception as e:
        handle_error(e)
        # NOTE(review): handle_error comes from .helpers and presumably raises.
        # If it ever returns, re-raise instead of answering 200 with a null body.
        raise
    return {"outputs": results if is_batch else results[0]}


def _relation_endpoint(rel, prefix):
    """Format the ``head`` or ``tail`` side of a relation dict as an entity dict.

    Missing keys default to empty strings / zeros, matching the response schema.
    """
    return {
        "entity": rel.get(f"{prefix}_type", ""),
        "span": rel.get(f"{prefix}_text", ""),
        "start": rel.get(f"{prefix}_start", 0),
        "end": rel.get(f"{prefix}_end", 0),
        "score": float(rel.get(f"{prefix}_score", 0.0)),
    }


# ==================== Endpoints ====================


@app.get("/health")
async def health_check():
    """Health check endpoint; reports whether the model is currently loaded."""
    return {"status": "healthy", "model_loaded": model is not None}


@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint.

    For each text, returns the entities GLiNER finds for the labels in
    ``request.entities`` at ``request.threshold`` (flat, non-nested spans).
    """

    def predict(text):
        entities = model.predict_entities(
            text, request.entities, threshold=request.threshold, flat_ner=True
        )
        return [
            {
                "entity": ent["label"],
                "span": ent["text"],
                "start": ent["start"],
                "end": ent["end"],
                "score": float(ent["score"]),
            }
            for ent in entities
        ]

    return _run_per_text(request, predict)


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint.

    For each text, returns the relations from ``model.predict_relations`` for the
    labels in ``request.relations`` at ``request.threshold``. Each relation holds
    its source (head) and target (tail) spans.
    """
    # NOTE(review): request.entities is not used for relation extraction (an
    # earlier version ran predict_entities here but discarded the result).
    # Confirm whether it should be passed to predict_relations.

    def predict(text):
        relations = model.predict_relations(
            text, request.relations, threshold=request.threshold
        )
        return [
            {
                "source": _relation_endpoint(rel, "head"),
                "relation": rel["label"],
                "target": _relation_endpoint(rel, "tail"),
                "score": float(rel["score"]),
            }
            for rel in relations
        ]

    return _run_per_text(request, predict)


@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint (extractive).

    Tags spans with the labels in ``_SUMMARY_LABELS`` at ``request.threshold``
    and returns them as the summary pieces for each text.
    """

    def predict(text):
        summaries = model.predict_entities(
            text, list(_SUMMARY_LABELS), threshold=request.threshold, flat_ner=True
        )
        return [
            {
                "text": summ["text"],
                "start": summ["start"],
                "end": summ["end"],
                "score": float(summ["score"]),
            }
            for summ in summaries
        ]

    return _run_per_text(request, predict)


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)