from fastapi import FastAPI, HTTPException, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.requests import Request
from contextlib import asynccontextmanager
import os
import torch
import logging
from gliner import GLiNER
# Wildcard imports provide the request models (GeneralRequest,
# RelationExtractionRequest, SummarizationRequest) and the helpers
# (validate_single_or_batch, validate_prompt_and_text_length,
# add_prompt_to_text, handle_error) used below.
from .models import *
from .helpers import *
# Global model instance
model = None
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup, cleanup on shutdown"""
global model
logger.info("Loading GLiNER model...")
model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GLiNER.from_pretrained(
model_name,
local_files_only=True  # no download is attempted; weights must already be in the local cache
).to(device)
logger.info(f"Model loaded on {device}")
yield
logger.info("Shutting down...")
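# NOTE: local_files_only=True means from_pretrained() never downloads weights.
# Assuming the standard Hugging Face cache layout, they can be fetched ahead
# of time, e.g.:
#   python -c "from gliner import GLiNER; GLiNER.from_pretrained('knowledgator/gliner-multitask-large-v0.5')"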
app = FastAPI(
title="GLiNER Inference Server",
description="Named Entity Recognition, Relation Extraction, and Summarization API",
version="1.0.0",
lifespan=lifespan
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""
Handles FastAPI's HTTPException to enforce a consistent, flat error
response structure (removing the default 'detail' wrapping).
The handler expects 'exc.detail' to be a dictionary containing the
desired error structure with the key 'error' (e.g., {"error": {...}}).
:param request: The incoming request object.
:param exc: The HTTPException instance raised (e.g., in validate_single_or_batch).
:return: A JSONResponse object with the custom error body and the correct status code.
"""
# Determine the content to be returned in the response body
if isinstance(exc.detail, dict) and "error" in exc.detail:
# If the detail matches the expected custom format (e.g., {"error": {...}}),
# use the entire dictionary as the response content.
response_body = exc.detail
else:
# Fallback for unexpected or standard FastAPI error formats (e.g., detail is a string).
# Ensures the response always adheres to the expected structure.
response_body = {
"error": {
"code": "GENERIC_ERROR",
"message": str(exc.detail),
"details": {}
}
}
return JSONResponse(
status_code=exc.status_code,
content=response_body
)
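# Illustrative sketch (not part of the original code) of raising an error in
# the custom format this handler expects; the code/message values are made up:
#
#   raise HTTPException(
#       status_code=status.HTTP_400_BAD_REQUEST,
#       detail={"error": {"code": "INVALID_INPUT", "message": "text is empty", "details": {}}},
#   )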
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""
Handles Pydantic validation errors (normally 422) and returns them as a
400 Bad Request with a consistent error structure. Only the first
validation error is reported.
"""
# Extract validation error details
errors = exc.errors()
# Get first error:
location = " -> ".join(str(l) for l in errors[0]["loc"])
error_message = f"{location}: {errors[0]['msg']}"
response_body = {
"error": {
"code": "INVALID_INPUT",
"message": error_message,
"details": {}
}
}
return JSONResponse(
status_code=400,  # return 400 rather than FastAPI's default 422
content=response_body
)
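# Example (illustrative): posting {"threshold": "high"} to /general yields
# HTTP 400 with a body along the lines of:
#   {"error": {"code": "INVALID_INPUT",
#              "message": "body -> threshold: Input should be a valid number",
#              "details": {}}}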
# ==================== Endpoints ====================
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy", "model_loaded": model is not None}
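# Example (assuming the default port):
#   curl http://localhost:8000/health
#   -> {"status": "healthy", "model_loaded": true}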
@app.post("/general")
async def general_extraction(request: GeneralRequest):
"""General Named Entity Recognition endpoint with prompt support
Example usage:
{
"text": "Apple Inc. is located in Cupertino, California.",
"entities": ["organization", "location"],
"prompt": "Extract business entities:\n",
"threshold": 0.5
}
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list)
# Process batch or single
texts_to_process = input_data if is_batch else [input_data]
results = []
for text in texts_to_process:
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# Extract entities with prompt
entities = model.predict_entities(
text_with_prompt,
request.entities,
threshold=request.threshold,
flat_ner=True
)
output = []
for ent in entities:
# Skip entities found in the prompt itself
if ent["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text (without prompt)
output.append({
"entity": ent["label"],
"span": ent["text"],
"start": ent["start"] - prompt_offset,
"end": ent["end"] - prompt_offset,
"score": float(ent["score"])
})
results.append(output)
# Return based on input mode
if is_batch:
return {"outputs": results}
else:
return {"outputs": results[0]}
except Exception as e:
logger.error(f"Error in general: {str(e)}", exc_info=True)
handle_error(e)  # assumed to re-raise as an HTTPException with the custom error body
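# Illustrative request/response for /general (scores are made up; note that
# the returned offsets are relative to the original text, not the prompt):
#   curl -X POST http://localhost:8000/general -H "Content-Type: application/json" \
#     -d '{"text": "Apple Inc. is located in Cupertino.", "entities": ["organization", "location"], "threshold": 0.5}'
#   -> {"outputs": [{"entity": "organization", "span": "Apple Inc.", "start": 0, "end": 10, "score": 0.97},
#                   {"entity": "location", "span": "Cupertino", "start": 25, "end": 34, "score": 0.93}]}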
@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
"""Relation Extraction endpoint
Example usage:
{
"text": "Microsoft was founded by Bill Gates",
"relations": ["founder", "inception date"],
"entities": ["organization", "person", "date"],
"prompt": "Extract business relationships:\n",
"threshold": 0.5
}
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list)
texts_to_process = input_data if is_batch else [input_data]
results = []
for text in texts_to_process:
output = []
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# Extract entities if specified
entities_dict = {}
if request.entities:
entities = model.predict_entities(
text_with_prompt,
request.entities,
threshold=request.threshold,
flat_ner=True
)
# Index entities by position for quick lookup
# Adjust positions to be relative to original text (without prompt)
for entity in entities:
# Only include entities that are in the actual text, not in the prompt
if entity["start"] >= prompt_offset:
adjusted_start = entity["start"] - prompt_offset
adjusted_end = entity["end"] - prompt_offset
key = (adjusted_start, adjusted_end)
entities_dict[key] = {
"label": entity["label"],
"text": entity["text"],
"start": adjusted_start,
"end": adjusted_end,
"score": entity.get("score", 0.0)
}
# Form relation labels according to GLiNER documentation
# Format: "entity_type <> relation_type"
relation_labels = []
if request.entities and request.rules:
# If there are rules with pairs_filter, create specific labels
for rule in request.rules:
if rule.pairs_filter:
for pair in rule.pairs_filter:
# Format: "source_entity <> relation"
label = f"{pair[0]} <> {rule.relation}"
relation_labels.append(label)
else:
# Without pairs filter - general relation
relation_labels.append(f"<> {rule.relation}")
elif request.entities:
# If there are entities but no rules, create combinations
for relation in request.relations:
for entity_type in request.entities:
label = f"{entity_type} <> {relation}"
relation_labels.append(label)
else:
# If there are no entities, just search for relations
for relation in request.relations:
relation_labels.append(f"<> {relation}")
# Remove duplicates while preserving insertion order
relation_labels = list(dict.fromkeys(relation_labels))
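# Worked example of the label scheme (illustrative): with
# entities=["organization", "person"], relations=["founder"] and no rules,
# relation_labels becomes ["organization <> founder", "person <> founder"].
# A hit labeled "organization <> founder" whose span is "Bill Gates" reads as:
# some organization stands in the "founder" relation to "Bill Gates".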
# Extract relations using GLiNER with prompt
relations_output = model.predict_entities(
text_with_prompt,
relation_labels,
threshold=request.threshold,
flat_ner=True
)
# Format the results
for rel in relations_output:
# GLiNER returns a text span that represents the relation
# E.g.: "Bill Gates" for label "Microsoft <> founder"
# Skip relations found in the prompt itself
if rel["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text
target_start = rel["start"] - prompt_offset
target_end = rel["end"] - prompt_offset
target_span = rel["text"]
# Parse the label to get entity type and relation
label_parts = rel["label"].split("<>")
if len(label_parts) == 2:
source_entity_type = label_parts[0].strip()
relation_type = label_parts[1].strip()
else:
source_entity_type = ""
relation_type = rel["label"]
# Attempt to find the source entity in context
# Find the closest entity of the specified type before the target
source_entity = None
if source_entity_type and entities_dict:
# Find entities of the specified type that precede the target
candidates = [
e for e in entities_dict.values()
if e["label"].lower() == source_entity_type.lower()
and e["end"] <= target_start
]
# Take the nearest one
if candidates:
source_entity = max(candidates, key=lambda x: x["end"])
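# e.g. in "Microsoft was founded by Bill Gates", with target span "Bill Gates"
# and source type "organization", the nearest preceding organization entity
# ("Microsoft") is chosen as the source.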
# Determine target entity type
target_entity_type = ""
target_key = (target_start, target_end)
if target_key in entities_dict:
target_entity_type = entities_dict[target_key]["label"]
# Apply rules if present
if request.rules:
skip = False
for rule in request.rules:
if rule.relation == relation_type:
# Check distance (a distance of 0 is still a valid constraint)
if rule.distance is not None and source_entity:
distance = target_start - source_entity["end"]
if distance > rule.distance:
skip = True
break
# Check pairs_filter
if rule.pairs_filter and source_entity:
valid_pair = any(
source_entity["label"].lower() == pair[0].lower() and
target_entity_type.lower() == pair[1].lower()
for pair in rule.pairs_filter
)
if not valid_pair:
skip = True
break
if skip:
continue
# Form the output object
relation_obj = {
"source": {
"entity": source_entity["label"] if source_entity else source_entity_type,
"span": source_entity["text"] if source_entity else "",
"start": source_entity["start"] if source_entity else 0,
"end": source_entity["end"] if source_entity else 0,
"score": float(source_entity.get("score", 0.0)) if source_entity else 0.0
},
"relation": relation_type,
"target": {
"entity": target_entity_type,
"span": target_span,
"start": target_start,
"end": target_end,
"score": float(rel.get("score", 0.0))
},
"score": float(rel.get("score", 0.0))
}
output.append(relation_obj)
results.append(output)
if is_batch:
return {"outputs": results}
else:
return {"outputs": results[0]}
except Exception as e:
logger.error(f"Error in relation extraction: {str(e)}", exc_info=True)
handle_error(e)
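# Illustrative call for the docstring example above (all scores are made up):
#   POST /relation-extraction
#     {"text": "Microsoft was founded by Bill Gates", "relations": ["founder"], "entities": ["organization", "person"]}
#   -> {"outputs": [{"source": {"entity": "organization", "span": "Microsoft", "start": 0, "end": 9, "score": 0.95},
#                    "relation": "founder",
#                    "target": {"entity": "person", "span": "Bill Gates", "start": 25, "end": 35, "score": 0.92},
#                    "score": 0.92}]}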
@app.post("/summarization")
async def summarization(request: SummarizationRequest):
"""Summarization endpoint with prompt support
Example usage:
{
"text": "Long article text here...",
"prompt": "Extract the most important points:\n",
"threshold": 0.5
}
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list)
texts_to_process = input_data if is_batch else [input_data]
results = []
for text in texts_to_process:
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# For summarization, extract key phrases/sentences
summaries = model.predict_entities(
text_with_prompt,
["summary", "key point", "important information"],
threshold=request.threshold,
flat_ner=True
)
output = []
for summ in summaries:
# Skip summaries found in the prompt itself
if summ["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text (without prompt)
output.append({
"text": summ["text"],
"start": summ["start"] - prompt_offset,
"end": summ["end"] - prompt_offset,
"score": float(summ["score"])
})
results.append(output)
if is_batch:
return {"outputs": results}
else:
return {"outputs": results[0]}
except Exception as e:
logger.error(f"Error in summarization: {str(e)}", exc_info=True)
handle_error(e)
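# Note: /summarization reuses the NER head with the fixed labels
# ["summary", "key point", "important information"], so it returns extracted
# spans rather than generated text, e.g. (values are made up):
#   {"outputs": [{"text": "Revenue grew 20% in Q3.", "start": 104, "end": 127, "score": 0.81}]}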
if __name__ == "__main__":
import uvicorn
port = int(os.getenv("PORT", "8000"))
uvicorn.run(app, host="0.0.0.0", port=port)
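# NOTE: because of the relative imports above, run this as a module from the
# package root (package name "app" is illustrative):
#   python -m app.server
# or serve it with uvicorn directly:
#   uvicorn app.server:app --host 0.0.0.0 --port 8000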