from fastapi import FastAPI, HTTPException, status
from typing import List, Optional, Union, Dict, Any, Tuple
from contextlib import asynccontextmanager
import os

import torch
from gliner import GLiNER

from .models import *
from .helpers import *

# Global model instance
model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model on startup, cleanup on shutdown"""
    global model
    print("Loading GLiNER model...")
    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER.from_pretrained(model_name).to(device)
    print(f"Model loaded on {device}")
    yield
    print("Shutting down...")


app = FastAPI(
    title="GLiNER Inference Server",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
    version="1.0.0",
    lifespan=lifespan,
)


# ==================== Endpoints ====================

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        # Process batch or single
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            entities = model.predict_entities(
                text,
                request.entities,
                threshold=request.threshold,
                flat_ner=True,
            )
            output = [
                {
                    "entity": ent["label"],
                    "span": ent["text"],
                    "start": ent["start"],
                    "end": ent["end"],
                    "score": float(ent["score"]),
                }
                for ent in entities
            ]
            results.append(output)

        # Return based on input mode
        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}
    except Exception as e:
        handle_error(e)


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # First extract entities if provided
            if request.entities:
                entities = model.predict_entities(
                    text,
                    request.entities,
                    threshold=request.threshold,
                    flat_ner=True,
                )
            else:
                entities = []

            # Extract relations
            relations = model.predict_relations(
                text,
                request.relations,
                threshold=request.threshold,
            )
            output = [
                {
                    "source": {
                        "entity": rel.get("head_type", ""),
                        "span": rel.get("head_text", ""),
                        "start": rel.get("head_start", 0),
                        "end": rel.get("head_end", 0),
                        "score": float(rel.get("head_score", 0.0)),
                    },
                    "relation": rel["label"],
                    "target": {
                        "entity": rel.get("tail_type", ""),
                        "span": rel.get("tail_text", ""),
                        "start": rel.get("tail_start", 0),
                        "end": rel.get("tail_end", 0),
                        "score": float(rel.get("tail_score", 0.0)),
                    },
                    "score": float(rel["score"]),
                }
                for rel in relations
            ]
            results.append(output)

        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}
    except Exception as e:
        handle_error(e)


@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # For summarization, extract key phrases/sentences
            summaries = model.predict_entities(
                text,
                ["summary", "key point", "important information"],
                threshold=request.threshold,
                flat_ner=True,
            )
            output = [
                {
                    "text": summ["text"],
                    "start": summ["start"],
                    "end": summ["end"],
                    "score": float(summ["score"]),
                }
                for summ in summaries
            ]
            results.append(output)

        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}
    except Exception as e:
        handle_error(e)


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)