from fastapi import FastAPI
from contextlib import asynccontextmanager
import os

import torch
from gliner import GLiNER

from .models import *
from .helpers import *

# Global model instance, populated by the lifespan handler below
model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup, clean up on shutdown."""
    global model
    print("Loading GLiNER model...")
    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GLiNER.from_pretrained(model_name).to(device)
    print(f"Model loaded on {device}")
    yield
    print("Shutting down...")


app = FastAPI(
    title="GLiNER Inference Server",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
    version="1.0.0",
    lifespan=lifespan
)


# ==================== Endpoints ====================

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        # Process batch or single
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            entities = model.predict_entities(
                text,
                request.entities,
                threshold=request.threshold,
                flat_ner=True
            )

            output = [
                {
                    "entity": ent["label"],
                    "span": ent["text"],
                    "start": ent["start"],
                    "end": ent["end"],
                    "score": float(ent["score"])
                }
                for ent in entities
            ]
            results.append(output)

        # Return based on input mode
        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}

    except Exception as e:
        handle_error(e)
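

# Illustrative request/response for /general. The field names mirror the
# attributes read from GeneralRequest above (text/texts, entities, threshold);
# the values, including the scores, are made up for the example:
#
#   POST /general
#   {"text": "Steve Jobs founded Apple in Cupertino.",
#    "entities": ["person", "organization", "location"],
#    "threshold": 0.5}
#
#   -> {"outputs": [
#        {"entity": "person", "span": "Steve Jobs",
#         "start": 0, "end": 10, "score": 0.98},
#        ...]}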


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # First extract entities if provided
            if request.entities:
                entities = model.predict_entities(
                    text,
                    request.entities,
                    threshold=request.threshold,
                    flat_ner=True
                )
            else:
                entities = []

            # Extract relations
            relations = model.predict_relations(
                text,
                request.relations,
                threshold=request.threshold
            )

            output = [
                {
                    "source": {
                        "entity": rel.get("head_type", ""),
                        "span": rel.get("head_text", ""),
                        "start": rel.get("head_start", 0),
                        "end": rel.get("head_end", 0),
                        "score": float(rel.get("head_score", 0.0))
                    },
                    "relation": rel["label"],
                    "target": {
                        "entity": rel.get("tail_type", ""),
                        "span": rel.get("tail_text", ""),
                        "start": rel.get("tail_start", 0),
                        "end": rel.get("tail_end", 0),
                        "score": float(rel.get("tail_score", 0.0))
                    },
                    "score": float(rel["score"])
                }
                for rel in relations
            ]
            results.append(output)

        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}

    except Exception as e:
        handle_error(e)
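

# Illustrative payload for /relation-extraction (values are made up; send
# "texts" instead of "text" to get batch mode via validate_single_or_batch):
#
#   POST /relation-extraction
#   {"text": "Steve Jobs founded Apple.",
#    "entities": ["person", "organization"],
#    "relations": ["founded"],
#    "threshold": 0.5}
#
# Each element of "outputs" pairs a "source" and a "target" span with a
# "relation" label and an overall "score", as assembled above.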


@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint"""
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # For summarization, extract key phrases/sentences
            summaries = model.predict_entities(
                text,
                ["summary", "key point", "important information"],
                threshold=request.threshold,
                flat_ner=True
            )

            output = [
                {
                    "text": summ["text"],
                    "start": summ["start"],
                    "end": summ["end"],
                    "score": float(summ["score"])
                }
                for summ in summaries
            ]
            results.append(output)

        if is_batch:
            return {"outputs": results}
        else:
            return {"outputs": results[0]}

    except Exception as e:
        handle_error(e)
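

# Illustrative payload for /summarization. Note that no label list is sent:
# the endpoint prompts the model with the fixed labels hard-coded above
# ("summary", "key point", "important information") and returns matching spans:
#
#   POST /summarization
#   {"text": "<a longer document>", "threshold": 0.3}
#
#   -> {"outputs": [{"text": "...", "start": 0, "end": 42, "score": 0.91}, ...]}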


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)
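

# Typical ways to launch the server locally (assuming this file lives at
# app/main.py inside a package so the relative imports of .models and .helpers
# resolve; adjust the module path to your layout):
#
#   PORT=8080 python -m app.main
#   uvicorn app.main:app --host 0.0.0.0 --port 8000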