You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

335 lines
12 KiB

from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.requests import Request
from typing import List, Optional, Union, Dict, Any, Tuple
from contextlib import asynccontextmanager
import os
import torch
from gliner import GLiNER
from .models import *
from .helpers import *
# Global model instance. Populated by `lifespan` at startup and read by the
# endpoint handlers; `None` means no model is currently loaded.
model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GLiNER model on startup and release it on shutdown.

    The checkpoint name is read from the ``MODEL_NAME`` environment variable
    (default ``knowledgator/gliner-multitask-large-v0.5``). The model is moved
    to CUDA when ``torch.cuda.is_available()`` reports a GPU, otherwise to CPU.

    :param app: The FastAPI application this lifespan is attached to (unused).
    """
    global model
    print("Loading GLiNER model...")
    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER.from_pretrained(model_name).to(device)
    print(f"Model loaded on {device}")
    try:
        yield
    finally:
        # Run the shutdown path even when an exception is thrown into the
        # context manager, and actually drop the model so /health stops
        # reporting it as loaded and its memory can be reclaimed.
        print("Shutting down...")
        model = None
        if device == "cuda":
            torch.cuda.empty_cache()
# Application object. Startup work (loading the model) is delegated to the
# `lifespan` context manager; the remaining arguments are OpenAPI metadata.
app = FastAPI(
    lifespan=lifespan,
    title="GLiNER Inference Server",
    version="1.0.0",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Handles FastAPI's HTTPException to enforce a consistent, flat error
    response structure (removing the default 'detail' wrapping).

    The handler expects 'exc.detail' to be a dictionary containing the
    desired error structure with the key 'error' (e.g., {"error": {...}}).
    Any other detail value is wrapped into a GENERIC_ERROR body.

    :param request: The incoming request object (unused).
    :param exc: The HTTPException instance raised (e.g., in validate_single_or_batch).
    :return: A JSONResponse object with the custom error body, the status
        code of the exception and any headers attached to the exception.
    """
    detail = exc.detail
    if isinstance(detail, dict) and "error" in detail:
        # Detail already has the custom shape: return it unchanged.
        response_body = detail
    else:
        # Fallback for standard FastAPI error formats (e.g., detail is a
        # string) so the response always has the same structure.
        response_body = {
            "error": {
                "code": "GENERIC_ERROR",
                "message": str(detail),
                "details": {},
            }
        }
    return JSONResponse(
        status_code=exc.status_code,
        content=response_body,
        # Forward headers set on the exception (e.g. WWW-Authenticate,
        # Retry-After); overriding the default handler would otherwise
        # silently drop them.
        headers=getattr(exc, "headers", None),
    )
# ==================== Endpoints ====================
@app.get("/health")
async def health_check():
    """Report liveness and whether the global model instance is set."""
    model_ready = model is not None
    return {"status": "healthy", "model_loaded": model_ready}
@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint.

    Runs ``model.predict_entities`` with the labels in ``request.entities``
    over each input text. When ``validate_single_or_batch`` yields a list the
    request is treated as a batch and ``outputs`` is a list of per-text entity
    lists; otherwise ``outputs`` is the entity list of the single text.

    Any exception raised while processing is passed to ``handle_error``.
    """
    try:
        payload = validate_single_or_batch(request.text, request.texts)
        batch_mode = isinstance(payload, list)
        documents = payload if batch_mode else [payload]

        def _extract(document):
            # One model call per text; map GLiNER's keys to the API's keys.
            found = model.predict_entities(
                document,
                request.entities,
                threshold=request.threshold,
                flat_ner=True,
            )
            formatted = []
            for item in found:
                formatted.append(
                    {
                        "entity": item["label"],
                        "span": item["text"],
                        "start": item["start"],
                        "end": item["end"],
                        "score": float(item["score"]),
                    }
                )
            return formatted

        per_document = [_extract(document) for document in documents]

        # Mirror the input mode: list of lists for a batch, flat list otherwise.
        return {"outputs": per_document if batch_mode else per_document[0]}
    except Exception as e:
        handle_error(e)
def _build_relation_labels(request) -> List[str]:
    """Build the "<source type> <> <relation>" prompt labels for a request.

    The result depends only on the request, so it is computed once per call
    rather than once per text. Duplicates are removed while keeping first-seen
    order, so the label order sent to the model is the same on every run.
    """
    relations = request.relations or []
    labels: List[str] = []
    if request.entities and request.rules:
        # Rules drive the labels: one label per allowed source type, or a
        # source-less label when the rule has no pairs_filter.
        for rule in request.rules:
            if rule.pairs_filter:
                labels.extend(
                    f"{pair[0]} <> {rule.relation}" for pair in rule.pairs_filter
                )
            else:
                labels.append(f"<> {rule.relation}")
    elif request.entities:
        # Entities but no rules: every (relation, entity type) combination.
        labels.extend(
            f"{entity_type} <> {relation}"
            for relation in relations
            for entity_type in request.entities
        )
    else:
        # No entities: search for the bare relations.
        labels.extend(f"<> {relation}" for relation in relations)
    return list(dict.fromkeys(labels))


def _split_relation_label(label: str) -> Tuple[str, str]:
    """Split a predicted label into (source entity type, relation type).

    A label that does not contain exactly one "<>" separator is returned
    whole as the relation type with an empty source type.
    """
    parts = label.split("<>")
    if len(parts) == 2:
        return parts[0].strip(), parts[1].strip()
    return "", label


def _nearest_preceding_entity(
    entities_dict: Dict[Tuple[int, int], Dict[str, Any]],
    entity_type: str,
    before: int,
) -> Optional[Dict[str, Any]]:
    """Return the entity of ``entity_type`` ending closest before ``before``.

    Type comparison is case-insensitive. Returns None when no entity of that
    type ends at or before ``before``.
    """
    wanted = entity_type.lower()
    candidates = [
        entity
        for entity in entities_dict.values()
        if entity["label"].lower() == wanted and entity["end"] <= before
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda entity: entity["end"])


def _rejected_by_rules(
    rules,
    relation_type: str,
    source_entity: Optional[Dict[str, Any]],
    target_entity_type: str,
    target_start: int,
) -> bool:
    """Return True when any rule for ``relation_type`` rules the relation out.

    Both checks only apply when a source entity was resolved: the character
    distance between source end and target start must not exceed
    ``rule.distance``, and the (source type, target type) pair must appear in
    ``rule.pairs_filter`` (case-insensitive).
    """
    for rule in rules:
        if rule.relation != relation_type:
            continue
        if rule.distance and source_entity:
            if target_start - source_entity["end"] > rule.distance:
                return True
        if rule.pairs_filter and source_entity:
            source_type = source_entity["label"].lower()
            target_type = target_entity_type.lower()
            pair_allowed = any(
                source_type == pair[0].lower() and target_type == pair[1].lower()
                for pair in rule.pairs_filter
            )
            if not pair_allowed:
                return True
    return False


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint

    Example usage:
    {
    "text": "Microsoft was founded by Bill Gates",
    "relations": ["founder", "inception date"],
    "entities": ["organization", "person", "date"],
    "threshold": 0.5
    }

    For every text, entities are optionally extracted first, then relation
    labels of the form "entity_type <> relation_type" are sent to
    ``model.predict_entities``. Each returned span is reported as the relation
    target; the source is the nearest preceding entity of the label's entity
    type, when one exists. ``request.rules`` can filter results by distance
    and by allowed (source type, target type) pairs.

    Returns ``{"outputs": [...]}``: a list of relation objects for a single
    text, or a list of such lists for a batch. Any exception raised while
    processing is passed to ``handle_error``.

    NOTE(review): labels here use entity *types* on the left of "<>"; confirm
    against the model card whether the model expects the source entity text
    (e.g. "Microsoft <> founder") instead.
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        # Labels depend only on the request, so build them once.
        relation_labels = _build_relation_labels(request)

        results = []
        for text in texts_to_process:
            if not relation_labels:
                # Nothing to ask the model for; avoid calling it with an
                # empty label list.
                results.append([])
                continue

            # Extract entities if specified, indexed by (start, end) so a
            # target span can be matched back to its entity type.
            entities_dict: Dict[Tuple[int, int], Dict[str, Any]] = {}
            if request.entities:
                entities = model.predict_entities(
                    text,
                    request.entities,
                    threshold=request.threshold,
                    flat_ner=True,
                )
                for entity in entities:
                    entities_dict[(entity["start"], entity["end"])] = entity

            # Extract relations using GLiNER.
            relations_output = model.predict_entities(
                text,
                relation_labels,
                threshold=request.threshold,
                flat_ner=True,
            )

            output = []
            for rel in relations_output:
                source_entity_type, relation_type = _split_relation_label(rel["label"])

                # The extracted span is treated as the target of the relation.
                target_span = rel["text"]
                target_start = rel["start"]
                target_end = rel["end"]

                # Resolve the source as the closest entity of the requested
                # type that precedes the target.
                source_entity = None
                if source_entity_type and entities_dict:
                    source_entity = _nearest_preceding_entity(
                        entities_dict, source_entity_type, target_start
                    )

                # The target type is known only when its span exactly matches
                # an extracted entity.
                target_entity = entities_dict.get((target_start, target_end))
                target_entity_type = target_entity["label"] if target_entity else ""

                if request.rules and _rejected_by_rules(
                    request.rules,
                    relation_type,
                    source_entity,
                    target_entity_type,
                    target_start,
                ):
                    continue

                relation_score = float(rel.get("score", 0.0))
                if source_entity:
                    source_payload = {
                        "entity": source_entity["label"],
                        "span": source_entity["text"],
                        "start": source_entity["start"],
                        "end": source_entity["end"],
                        "score": float(source_entity.get("score", 0.0)),
                    }
                else:
                    # No source resolved: report the requested type with an
                    # empty span and zeroed offsets.
                    source_payload = {
                        "entity": source_entity_type,
                        "span": "",
                        "start": 0,
                        "end": 0,
                        "score": 0.0,
                    }

                output.append(
                    {
                        "source": source_payload,
                        "relation": relation_type,
                        "target": {
                            "entity": target_entity_type,
                            "span": target_span,
                            "start": target_start,
                            "end": target_end,
                            "score": relation_score,
                        },
                        "score": relation_score,
                    }
                )
            results.append(output)

        if is_batch:
            return {"outputs": results}
        return {"outputs": results[0]}
    except Exception as e:
        handle_error(e)
@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint.

    Asks the model for spans labelled "summary", "key point" or
    "important information" in each input text and returns them with their
    offsets and scores. ``outputs`` is a list of span lists for a batch, or a
    single span list otherwise. Any exception raised while processing is
    passed to ``handle_error``.
    """
    try:
        payload = validate_single_or_batch(request.text, request.texts)
        batch_mode = isinstance(payload, list)
        documents = payload if batch_mode else [payload]

        def _summarize(document):
            # Key phrases/sentences are obtained through entity prediction
            # with a fixed set of summary-like labels.
            spans = model.predict_entities(
                document,
                ["summary", "key point", "important information"],
                threshold=request.threshold,
                flat_ner=True,
            )
            formatted = []
            for span in spans:
                formatted.append(
                    {
                        "text": span["text"],
                        "start": span["start"],
                        "end": span["end"],
                        "score": float(span["score"]),
                    }
                )
            return formatted

        per_document = [_summarize(document) for document in documents]

        # Mirror the input mode: list of lists for a batch, flat list otherwise.
        return {"outputs": per_document if batch_mode else per_document[0]}
    except Exception as e:
        handle_error(e)
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, using the PORT environment variable (default 8000).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))