You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

187 lines
5.5 KiB

from fastapi import FastAPI, HTTPException, status
from typing import List, Optional, Union, Dict, Any, Tuple
from contextlib import asynccontextmanager
import os
import torch
from gliner import GLiNER
from .models import *
from .helpers import *
# Global model instance shared by every endpoint in this module.
# Starts as None; lifespan() rebinds it to the loaded GLiNER model at startup,
# and /health reports whether that has happened yet.
model = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load the GLiNER model before serving.

    Before the ``yield`` the checkpoint named by the ``MODEL_NAME``
    environment variable (falling back to
    ``knowledgator/gliner-multitask-large-v0.5``) is loaded, moved to CUDA
    when ``torch.cuda.is_available()`` and to CPU otherwise, and stored in
    the module-level ``model``.

    After the ``yield`` only a shutdown message is printed; the model
    reference is left in place.
    """
    global model

    print("Loading GLiNER model...")
    checkpoint = os.environ.get(
        "MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5"
    )

    # Prefer the GPU whenever torch can see one.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    loaded = GLiNER.from_pretrained(checkpoint)
    model = loaded.to(target_device)
    print(f"Model loaded on {target_device}")

    yield

    print("Shutting down...")
# ASGI application object. The lifespan hook is what populates the global
# ``model`` before any request handler runs.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="GLiNER Inference Server",
    description="Named Entity Recognition, Relation Extraction, and Summarization API",
)
# ==================== Endpoints ====================
@app.get("/health")
async def health_check():
    """Report service health.

    Always answers with status ``"healthy"``; ``model_loaded`` tells the
    caller whether the module-level ``model`` has been assigned yet.
    """
    model_ready = model is not None
    return {"status": "healthy", "model_loaded": model_ready}
@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint.

    Calls ``model.predict_entities`` (with ``flat_ner=True``) once per input
    text, using ``request.entities`` as the labels and ``request.threshold``
    as the cut-off.

    The input is whatever ``validate_single_or_batch(request.text,
    request.texts)`` returns: a list is treated as batch mode, anything else
    as a single text.

    Returns:
        ``{"outputs": ...}`` where the value is a list of
        ``{"entity", "span", "start", "end", "score"}`` dicts for a single
        text, or a list of such lists (one per text) in batch mode.

    Any exception raised along the way is passed to ``handle_error``.
    """
    # NOTE(review): predict_entities is invoked synchronously inside an async
    # handler; presumably this holds the event loop for the duration of the
    # inference -- confirm whether that is acceptable for this deployment.
    try:
        payload = validate_single_or_batch(request.text, request.texts)
        batched = isinstance(payload, list)
        documents = payload if batched else [payload]

        outputs = []
        for document in documents:
            predictions = model.predict_entities(
                document,
                request.entities,
                threshold=request.threshold,
                flat_ner=True,
            )
            outputs.append(
                [
                    {
                        "entity": hit["label"],
                        "span": hit["text"],
                        "start": hit["start"],
                        "end": hit["end"],
                        "score": float(hit["score"]),
                    }
                    for hit in predictions
                ]
            )

        # Single-text requests get their one result unwrapped.
        return {"outputs": outputs if batched else outputs[0]}
    except Exception as exc:
        handle_error(exc)
def _relation_side(rel, prefix):
    """Build the source/target dict from the ``<prefix>_*`` keys of ``rel``.

    Missing keys fall back to an empty string (type/text), ``0`` (offsets)
    or ``0.0`` (score), exactly as the endpoint did inline before.
    """
    return {
        "entity": rel.get(f"{prefix}_type", ""),
        "span": rel.get(f"{prefix}_text", ""),
        "start": rel.get(f"{prefix}_start", 0),
        "end": rel.get(f"{prefix}_end", 0),
        "score": float(rel.get(f"{prefix}_score", 0.0)),
    }


@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint.

    Calls ``model.predict_relations`` once per input text with
    ``request.relations`` and ``request.threshold``.

    The input is whatever ``validate_single_or_batch(request.text,
    request.texts)`` returns: a list is treated as batch mode, anything else
    as a single text.

    Returns:
        ``{"outputs": ...}`` where the value is a list of
        ``{"source", "relation", "target", "score"}`` dicts for a single
        text, or a list of such lists (one per text) in batch mode.
        ``source`` is built from the ``head_*`` keys of each relation and
        ``target`` from the ``tail_*`` keys.

    Any exception raised along the way is passed to ``handle_error``.
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # The earlier version also ran model.predict_entities here when
            # request.entities was set, but never used the result, costing a
            # discarded model pass per text. That call is removed; the
            # response is unchanged.
            # NOTE(review): request.entities therefore has no influence on
            # the output -- confirm whether it was meant to constrain the
            # relation extraction.
            relations = model.predict_relations(
                text,
                request.relations,
                threshold=request.threshold,
            )
            results.append(
                [
                    {
                        "source": _relation_side(rel, "head"),
                        "relation": rel["label"],
                        "target": _relation_side(rel, "tail"),
                        "score": float(rel["score"]),
                    }
                    for rel in relations
                ]
            )

        # Single-text requests get their one result unwrapped.
        if is_batch:
            return {"outputs": results}
        return {"outputs": results[0]}
    except Exception as e:
        handle_error(e)
@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint.

    Implemented as span extraction: ``model.predict_entities`` is called
    (with ``flat_ner=True``) once per input text using the fixed labels
    ``"summary"``, ``"key point"`` and ``"important information"`` and
    ``request.threshold`` as the cut-off.

    The input is whatever ``validate_single_or_batch(request.text,
    request.texts)`` returns: a list is treated as batch mode, anything else
    as a single text.

    Returns:
        ``{"outputs": ...}`` where the value is a list of
        ``{"text", "start", "end", "score"}`` dicts for a single text, or a
        list of such lists (one per text) in batch mode. The matched label
        is not included in the output.

    Any exception raised along the way is passed to ``handle_error``.
    """
    try:
        payload = validate_single_or_batch(request.text, request.texts)
        batched = isinstance(payload, list)
        documents = payload if batched else [payload]

        outputs = []
        for document in documents:
            spans = model.predict_entities(
                document,
                ["summary", "key point", "important information"],
                threshold=request.threshold,
                flat_ner=True,
            )
            outputs.append(
                [
                    {
                        "text": span["text"],
                        "start": span["start"],
                        "end": span["end"],
                        "score": float(span["score"]),
                    }
                    for span in spans
                ]
            )

        # Single-text requests get their one result unwrapped.
        return {"outputs": outputs if batched else outputs[0]}
    except Exception as exc:
        handle_error(exc)
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, on the
    # port given by the PORT environment variable (8000 when unset).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))