You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
434 lines
16 KiB
434 lines
16 KiB
from fastapi import FastAPI, HTTPException, status
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.requests import Request
|
|
from typing import List, Optional, Union, Dict, Any, Tuple
|
|
from contextlib import asynccontextmanager
|
|
import os
|
|
import torch
|
|
import logging
|
|
from gliner import GLiNER
|
|
from .models import *
|
|
from .helpers import *
|
|
|
|
# Root logging is configured once at import time; this module logs through
# its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared GLiNER instance; stays None until the lifespan handler assigns it.
model = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GLiNER model on startup and release it on shutdown.

    The checkpoint name is read from the ``MODEL_NAME`` environment variable
    (default ``knowledgator/gliner-multitask-large-v0.5``). The model is moved
    to CUDA when ``torch.cuda.is_available()`` is true, otherwise to the CPU.

    :param app: The FastAPI application the lifespan is attached to (unused).
    """
    global model

    logger.info("Loading GLiNER model...")
    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER.from_pretrained(model_name).to(device)
    logger.info("Model loaded on %s", device)

    try:
        yield
    finally:
        # Runs even when the server stops because of an error, so the model
        # reference is always dropped and /health stops reporting it as loaded.
        logger.info("Shutting down...")
        model = None
        if device == "cuda":
            # Hand cached GPU memory back to the driver once the model is gone.
            torch.cuda.empty_cache()
|
|
|
|
|
|
# ASGI application object; `lifespan` is run by FastAPI around the server's
# lifetime and is where the model gets loaded.
app = FastAPI(
    title="GLiNER Inference Server",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
    version="1.0.0",
    lifespan=lifespan,
)
|
|
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """
    Handles FastAPI's HTTPException to enforce a consistent, flat error
    response structure (removing the default 'detail' wrapping).

    The handler expects 'exc.detail' to be a dictionary containing the
    desired error structure with the key 'error' (e.g., {"error": {...}}).
    Any other detail value is wrapped in a GENERIC_ERROR envelope.

    :param request: The incoming request object (unused).
    :param exc: The HTTPException instance raised.
    :return: A JSONResponse with the custom error body, the exception's
        status code and any headers attached to the exception.
    """
    detail = exc.detail

    if isinstance(detail, dict) and "error" in detail:
        # Already in the expected custom format: return it unchanged.
        response_body = detail
    else:
        # Fallback for standard FastAPI error formats (e.g., detail is a
        # string) so the response always has the same structure.
        response_body = {
            "error": {
                "code": "GENERIC_ERROR",
                "message": str(detail),
                "details": {}
            }
        }

    return JSONResponse(
        status_code=exc.status_code,
        content=response_body,
        # Forward headers set on the exception (e.g. WWW-Authenticate);
        # without this the custom handler would silently drop them.
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
|
|
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """
    Handles request validation errors (normally 422) and returns them as
    400 Bad Request with a consistent error structure.

    Only the first validation error is reported, formatted as
    "<location joined by ' -> '>: <message>".

    :param request: The incoming request object (unused).
    :param exc: The RequestValidationError raised during request parsing.
    :return: A JSONResponse with status 400 and an INVALID_INPUT error body.
    """
    errors = exc.errors()

    if errors:
        first_error = errors[0]
        location = " -> ".join(str(part) for part in first_error["loc"])
        error_message = f"{location}: {first_error['msg']}"
    else:
        # Defensive: an empty error list must not turn this handler itself
        # into an IndexError (which would surface as a 500).
        error_message = "Invalid request"

    response_body = {
        "error": {
            "code": "INVALID_INPUT",
            "message": error_message,
            "details": {}
        }
    }

    return JSONResponse(
        # Deliberately 400 instead of FastAPI's default 422.
        status_code=400,
        content=response_body
    )
|
|
|
|
# ==================== Endpoints ====================
|
|
@app.get("/health")
async def health_check():
    """Report service liveness and whether the global model has been set."""
    model_ready = model is not None
    return {"status": "healthy", "model_loaded": model_ready}
|
|
|
|
|
|
@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """General Named Entity Recognition endpoint with prompt support

    Accepts either a single ``text`` or a batch of ``texts``; the response's
    ``outputs`` is one list of entities for a single text, or a list of such
    lists for a batch.

    Example usage:
    {
        "text": "Apple Inc. is located in Cupertino, California.",
        "entities": ["organization", "location"],
        "prompt": "Extract business entities:\\n",
        "threshold": 0.5
    }
    """
    try:
        payload = validate_single_or_batch(request.text, request.texts)
        batch_mode = isinstance(payload, list)
        pending_texts = payload if batch_mode else [payload]

        # Length of the prompt, used to map spans back onto the original text.
        # NOTE(review): assumes add_prompt_to_text prepends the prompt verbatim
        # with no extra separator - confirm against helpers.
        shift = len(request.prompt) if request.prompt else 0

        results = []
        for raw_text in pending_texts:
            predictions = model.predict_entities(
                add_prompt_to_text(raw_text, request.prompt),
                request.entities,
                threshold=request.threshold,
                flat_ner=True,
            )

            # Keep only spans that start inside the real text (not in the
            # prompt) and re-base their offsets onto that text.
            results.append([
                {
                    "entity": hit["label"],
                    "span": hit["text"],
                    "start": hit["start"] - shift,
                    "end": hit["end"] - shift,
                    "score": float(hit["score"]),
                }
                for hit in predictions
                if hit["start"] >= shift
            ])

        # A single text yields its own list; a batch yields a list of lists.
        return {"outputs": results if batch_mode else results[0]}

    except Exception as e:
        logger.error(f"Error in general: {str(e)}", exc_info=True)
        # presumably converts the exception into an HTTP error response -
        # verify against helpers.handle_error
        handle_error(e)
|
|
|
|
|
|
def _relation_labels_for(request) -> List[str]:
    """Build the de-duplicated GLiNER labels ("<entity type> <> <relation>")."""
    labels = []

    if request.entities and request.rules:
        # Rules drive the labels: one per (source type, relation) pair, or a
        # general "<> relation" label when a rule has no pairs_filter.
        for rule in request.rules:
            if rule.pairs_filter:
                labels.extend(
                    f"{pair[0]} <> {rule.relation}" for pair in rule.pairs_filter
                )
            else:
                labels.append(f"<> {rule.relation}")
    elif request.entities:
        # Entities but no rules: every relation combined with every type.
        labels.extend(
            f"{entity_type} <> {relation}"
            for relation in request.relations
            for entity_type in request.entities
        )
    else:
        # No entities at all: search for the bare relations.
        labels.extend(f"<> {relation}" for relation in request.relations)

    # Order-preserving de-duplication: unlike list(set(...)), the label order
    # handed to the model does not vary with string hash randomisation.
    return list(dict.fromkeys(labels))


def _index_text_entities(
    entities: List[Dict[str, Any]], prompt_offset: int
) -> Dict[Tuple[int, int], Dict[str, Any]]:
    """Index entities found in the text (not the prompt) by re-based (start, end)."""
    indexed = {}
    for entity in entities:
        if entity["start"] < prompt_offset:
            continue
        start = entity["start"] - prompt_offset
        end = entity["end"] - prompt_offset
        indexed[(start, end)] = {
            "label": entity["label"],
            "text": entity["text"],
            "start": start,
            "end": end,
            "score": entity.get("score", 0.0),
        }
    return indexed


def _split_relation_label(label: str) -> Tuple[str, str]:
    """Split "<source type> <> <relation>" into (source type, relation)."""
    parts = label.split("<>")
    if len(parts) == 2:
        return parts[0].strip(), parts[1].strip()
    # Unexpected shape: no source type, the whole label is the relation.
    return "", label


def _nearest_preceding_entity(
    entities_by_span: Dict[Tuple[int, int], Dict[str, Any]],
    entity_type: str,
    target_start: int,
) -> Optional[Dict[str, Any]]:
    """Return the entity of ``entity_type`` ending closest before ``target_start``."""
    if not entity_type or not entities_by_span:
        return None
    wanted = entity_type.lower()
    candidates = [
        entity for entity in entities_by_span.values()
        if entity["label"].lower() == wanted and entity["end"] <= target_start
    ]
    if not candidates:
        return None
    return max(candidates, key=lambda entity: entity["end"])


def _rejected_by_rules(
    rules,
    relation_type: str,
    source_entity: Optional[Dict[str, Any]],
    target_entity_type: str,
    target_start: int,
) -> bool:
    """Tell whether any rule for ``relation_type`` rejects this relation."""
    for rule in rules or ():
        if rule.relation != relation_type:
            continue

        # Distance between the end of the source span and the start of the
        # target span; only checked when a source entity was located.
        if rule.distance and source_entity:
            if target_start - source_entity["end"] > rule.distance:
                return True

        # The (source type, target type) pair must be one of the allowed pairs.
        if rule.pairs_filter and source_entity:
            source_label = source_entity["label"].lower()
            target_label = target_entity_type.lower()
            pair_allowed = any(
                source_label == pair[0].lower() and target_label == pair[1].lower()
                for pair in rule.pairs_filter
            )
            if not pair_allowed:
                return True

    return False


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint

    Accepts either a single ``text`` or a batch of ``texts``. For each text,
    entities are predicted first (when ``entities`` is given), then relation
    labels are predicted as spans; each relation span becomes the target and
    the nearest preceding entity of the label's source type becomes the source.
    When no source entity is found, the source is reported with an empty span
    and zero offsets/score.

    Example usage:
    {
        "text": "Microsoft was founded by Bill Gates",
        "relations": ["founder", "inception date"],
        "entities": ["organization", "person", "date"],
        "prompt": "Extract business relationships:\\n",
        "threshold": 0.5
    }
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        # Both values depend only on the request, so compute them once
        # instead of once per text.
        # NOTE(review): the offset assumes add_prompt_to_text prepends the
        # prompt verbatim with no extra separator - confirm against helpers.
        prompt_offset = len(request.prompt) if request.prompt else 0
        relation_labels = _relation_labels_for(request)

        results = []
        for text in texts_to_process:
            text_with_prompt = add_prompt_to_text(text, request.prompt)

            # Entities are optional; without them no source/target typing
            # can be resolved and only the relation spans are reported.
            entities_by_span = {}
            if request.entities:
                entities = model.predict_entities(
                    text_with_prompt,
                    request.entities,
                    threshold=request.threshold,
                    flat_ner=True
                )
                entities_by_span = _index_text_entities(entities, prompt_offset)

            relations_output = model.predict_entities(
                text_with_prompt,
                relation_labels,
                threshold=request.threshold,
                flat_ner=True
            )

            output = []
            for rel in relations_output:
                # Skip relations found in the prompt itself.
                if rel["start"] < prompt_offset:
                    continue

                # Positions relative to the original text (without prompt).
                target_start = rel["start"] - prompt_offset
                target_end = rel["end"] - prompt_offset

                source_entity_type, relation_type = _split_relation_label(rel["label"])
                source_entity = _nearest_preceding_entity(
                    entities_by_span, source_entity_type, target_start
                )

                # The target is typed only if an entity with exactly the
                # same span was predicted.
                target_entity = entities_by_span.get((target_start, target_end))
                target_entity_type = target_entity["label"] if target_entity is not None else ""

                if _rejected_by_rules(
                    request.rules,
                    relation_type,
                    source_entity,
                    target_entity_type,
                    target_start,
                ):
                    continue

                if source_entity:
                    source = {
                        "entity": source_entity["label"],
                        "span": source_entity["text"],
                        "start": source_entity["start"],
                        "end": source_entity["end"],
                        "score": float(source_entity.get("score", 0.0)),
                    }
                else:
                    # No source located: report the requested type with
                    # placeholder span/offsets.
                    source = {
                        "entity": source_entity_type,
                        "span": "",
                        "start": 0,
                        "end": 0,
                        "score": 0.0,
                    }

                relation_score = float(rel.get("score", 0.0))
                output.append({
                    "source": source,
                    "relation": relation_type,
                    "target": {
                        "entity": target_entity_type,
                        "span": rel["text"],
                        "start": target_start,
                        "end": target_end,
                        "score": relation_score,
                    },
                    "score": relation_score,
                })

            results.append(output)

        if is_batch:
            return {"outputs": results}
        return {"outputs": results[0]}

    except Exception as e:
        logger.error(f"Error in relation extraction: {str(e)}", exc_info=True)
        # presumably converts the exception into an HTTP error response -
        # verify against helpers.handle_error
        handle_error(e)
|
|
|
|
|
|
@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint with prompt support

    Accepts either a single ``text`` or a batch of ``texts`` and returns the
    spans the model labels as "summary", "key point" or
    "important information".

    Example usage:
    {
        "text": "Long article text here...",
        "prompt": "Extract the most important points:\\n",
        "threshold": 0.5
    }
    """
    try:
        payload = validate_single_or_batch(request.text, request.texts)
        batch_mode = isinstance(payload, list)
        pending_texts = payload if batch_mode else [payload]

        # Length of the prompt, used to map spans back onto the original text.
        # NOTE(review): assumes add_prompt_to_text prepends the prompt verbatim
        # with no extra separator - confirm against helpers.
        shift = len(request.prompt) if request.prompt else 0

        results = []
        for raw_text in pending_texts:
            key_spans = model.predict_entities(
                add_prompt_to_text(raw_text, request.prompt),
                ["summary", "key point", "important information"],
                threshold=request.threshold,
                flat_ner=True,
            )

            # Drop spans that start inside the prompt and re-base the rest
            # onto the original text.
            results.append([
                {
                    "text": span["text"],
                    "start": span["start"] - shift,
                    "end": span["end"] - shift,
                    "score": float(span["score"]),
                }
                for span in key_spans
                if span["start"] >= shift
            ])

        # A single text yields its own list; a batch yields a list of lists.
        return {"outputs": results if batch_mode else results[0]}

    except Exception as e:
        logger.error(f"Error in summarization: {str(e)}", exc_info=True)
        # presumably converts the exception into an HTTP error response -
        # verify against helpers.handle_error
        handle_error(e)
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; the port comes from PORT (default 8000).
    listen_port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|