from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.requests import Request
from typing import List, Optional, Union, Dict, Any, Tuple
from contextlib import asynccontextmanager
import os
import torch
import logging

from gliner import GLiNER

from .models import *
from .helpers import *

# Global model instance
model = None

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model on startup, cleanup on shutdown"""
    global model
    print("Loading GLiNER model...")
    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER.from_pretrained(model_name).to(device)
    print(f"Model loaded on {device}")
    yield
    print("Shutting down...")


app = FastAPI(
    title="GLiNER Inference Server",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
    version="1.0.0",
    lifespan=lifespan
)


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Handles FastAPI's HTTPException to enforce a consistent, flat error response
    structure (removing the default 'detail' wrapping).

    The handler expects 'exc.detail' to be a dictionary containing the desired
    error structure with the key 'error' (e.g., {"error": {...}}).

    :param request: The incoming request object.
    :param exc: The HTTPException instance raised (e.g., in validate_single_or_batch).
    :return: A JSONResponse object with the custom error body and the correct status code.
    """
    # Determine the content to be returned in the response body
    if isinstance(exc.detail, dict) and "error" in exc.detail:
        # If the detail matches the expected custom format (e.g., {"error": {...}}),
        # use the entire dictionary as the response content.
        response_body = exc.detail
    else:
        # Fallback for unexpected or standard FastAPI error formats (e.g., detail is a string).
        # Ensures the response always adheres to the expected structure.
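        # For example, HTTPException(status_code=404, detail="Not found") is returned
        # to the client as {"error": {"code": "GENERIC_ERROR", "message": "Not found",
        # "details": {}}} with status code 404.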
        response_body = {
            "error": {
                "code": "GENERIC_ERROR",
                "message": str(exc.detail),
                "details": {}
            }
        }

    return JSONResponse(
        status_code=exc.status_code,
        content=response_body
    )


# ==================== Endpoints ====================

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        # Process batch or single
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            entities = model.predict_entities(
                text,
                request.entities,
                threshold=request.threshold,
                flat_ner=True
            )

            output = [
                {
                    "entity": ent["label"],
                    "span": ent["text"],
                    "start": ent["start"],
                    "end": ent["end"],
                    "score": float(ent["score"])
                }
                for ent in entities
            ]
            results.append(output)

        # Return based on input mode
        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}

    except Exception as e:
        logger.error(f"Error in general: {str(e)}", exc_info=True)
        handle_error(e)


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint

    Example usage:
    {
        "text": "Microsoft was founded by Bill Gates",
        "relations": ["founder", "inception date"],
        "entities": ["organization", "person", "date"],
        "threshold": 0.5
    }
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            output = []

            # Extract entities if specified
            entities_dict = {}
            if request.entities:
                entities = model.predict_entities(
                    text,
                    request.entities,
                    threshold=request.threshold,
                    flat_ner=True
                )
                # Index entities by position for quick lookup
                for entity in entities:
                    key = (entity["start"], entity["end"])
                    entities_dict[key] = entity

            # Form relation labels according to GLiNER documentation
            # Format: "entity_type <> relation_type"
            relation_labels = []

            if request.entities and request.rules:
                # If there are rules with pairs_filter, create specific labels
                for rule in request.rules:
                    if rule.pairs_filter:
                        for pair in rule.pairs_filter:
                            # Format: "source_entity <> relation"
                            label = f"{pair[0]} <> {rule.relation}"
                            relation_labels.append(label)
                    else:
                        # Without pairs filter - general relation
                        relation_labels.append(f"<> {rule.relation}")
            elif request.entities:
                # If there are entities but no rules, create combinations
                for relation in request.relations:
                    for entity_type in request.entities:
                        label = f"{entity_type} <> {relation}"
                        relation_labels.append(label)
            else:
                # If there are no entities, just search for relations
                for relation in request.relations:
                    relation_labels.append(f"<> {relation}")

            # Remove duplicates
            relation_labels = list(set(relation_labels))

            # Extract relations using GLiNER
            relations_output = model.predict_entities(
                text,
                relation_labels,
                threshold=request.threshold,
                flat_ner=True
            )

            # Format the results
            for rel in relations_output:
                # GLiNER returns a text span that represents the relation
                # E.g.: "Bill Gates" for label "Microsoft <> founder"

                # Parse the label to get entity type and relation
                label_parts = rel["label"].split("<>")
                if len(label_parts) == 2:
                    source_entity_type = label_parts[0].strip()
                    relation_type = label_parts[1].strip()
                else:
                    source_entity_type = ""
                    relation_type = rel["label"]
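                # e.g. "organization <> founder" -> source type "organization",
                # relation "founder"; a bare "<> founder" label still splits into
                # two parts and leaves the source type empty.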
                # The extracted span is typically the target entity
                target_span = rel["text"]
                target_start = rel["start"]
                target_end = rel["end"]

                # Attempt to find the source entity in context
                # Find the closest entity of the specified type before the target
                source_entity = None
                if source_entity_type and entities_dict:
                    # Find entities of the specified type that precede the target
                    candidates = [
                        e for e in entities_dict.values()
                        if e["label"].lower() == source_entity_type.lower()
                        and e["end"] <= target_start
                    ]
                    # Take the nearest one
                    if candidates:
                        source_entity = max(candidates, key=lambda x: x["end"])

                # Determine target entity type
                target_entity_type = ""
                target_key = (target_start, target_end)
                if target_key in entities_dict:
                    target_entity_type = entities_dict[target_key]["label"]

                # Apply rules if present
                if request.rules:
                    skip = False
                    for rule in request.rules:
                        if rule.relation == relation_type:
                            # Check distance
                            if rule.distance and source_entity:
                                distance = target_start - source_entity["end"]
                                if distance > rule.distance:
                                    skip = True
                                    break
                            # Check pairs_filter
                            if rule.pairs_filter and source_entity:
                                valid_pair = any(
                                    source_entity["label"].lower() == pair[0].lower()
                                    and target_entity_type.lower() == pair[1].lower()
                                    for pair in rule.pairs_filter
                                )
                                if not valid_pair:
                                    skip = True
                                    break
                    if skip:
                        continue

                # Form the output object
                relation_obj = {
                    "source": {
                        "entity": source_entity["label"] if source_entity else source_entity_type,
                        "span": source_entity["text"] if source_entity else "",
                        "start": source_entity["start"] if source_entity else 0,
                        "end": source_entity["end"] if source_entity else 0,
                        "score": float(source_entity.get("score", 0.0)) if source_entity else 0.0
                    },
                    "relation": relation_type,
                    "target": {
                        "entity": target_entity_type,
                        "span": target_span,
                        "start": target_start,
                        "end": target_end,
                        "score": float(rel.get("score", 0.0))
                    },
                    "score": float(rel.get("score", 0.0))
                }
                output.append(relation_obj)

            results.append(output)

        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}

    except Exception as e:
        logger.error(f"Error in relation extraction: {str(e)}", exc_info=True)
        handle_error(e)


@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # For summarization, extract key phrases/sentences
            summaries = model.predict_entities(
                text,
                ["summary", "key point", "important information"],
                threshold=request.threshold,
                flat_ner=True
            )

            output = [
                {
                    "text": summ["text"],
                    "start": summ["start"],
                    "end": summ["end"],
                    "score": float(summ["score"])
                }
                for summ in summaries
            ]
            results.append(output)

        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}

    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}", exc_info=True)
        handle_error(e)


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
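
# ---------------------------------------------------------------------------
# Illustrative client example (a minimal sketch, not executed by the server).
# It assumes the server is running locally on the port configured above and that
# GeneralRequest (defined in .models) exposes the "text" / "texts", "entities"
# and "threshold" fields used by the /general handler; adjust the payload if the
# actual Pydantic schema differs. Requires the third-party "requests" package.
#
#   import requests
#
#   response = requests.post(
#       "http://localhost:8000/general",
#       json={
#           "text": "Microsoft was founded by Bill Gates",
#           "entities": ["organization", "person"],
#           "threshold": 0.5,
#       },
#   )
#   # Single-text mode returns {"outputs": [...]} where each item has
#   # "entity", "span", "start", "end" and "score" keys.
#   print(response.json())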