Add helper methods and endpoint implementations.

main
KKlochko 1 month ago
parent eab1d4ffd3
commit f2d1ec9957
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -0,0 +1,72 @@
from fastapi import HTTPException, status
from typing import List, Optional, Union, Dict, Any, Tuple
def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
"""Validate that either text or texts is provided, but not both"""
if text is None and texts is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": {
"code": "INVALID_INPUT",
"message": "Either 'text' or 'texts' must be provided",
"details": {}
}
}
)
if text is not None and texts is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": {
"code": "INVALID_INPUT",
"message": "Provide either 'text' or 'texts', not both",
"details": {}
}
}
)
if texts is not None and len(texts) == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": {
"code": "INVALID_INPUT",
"message": "texts list cannot be empty",
"details": {}
}
}
)
return text if text is not None else texts
def handle_error(e: Exception):
    """Translate an exception into the API's uniform error response.

    ``HTTPException`` instances pass through untouched, ``ValueError``
    becomes a 400 with the original message, and anything else becomes an
    opaque 500 (internal details are never leaked to the client). This
    function always raises; it never returns.
    """
    if isinstance(e, HTTPException):
        raise e

    if isinstance(e, ValueError):
        http_status = status.HTTP_400_BAD_REQUEST
        code = "INVALID_INPUT"
        message = str(e)
    else:
        http_status = status.HTTP_500_INTERNAL_SERVER_ERROR
        code = "INTERNAL_ERROR"
        message = "An internal error occurred while processing your request"

    raise HTTPException(
        status_code=http_status,
        detail={"error": {"code": code, "message": message, "details": {}}},
    )

@ -5,6 +5,7 @@ import os
import torch import torch
from gliner import GLiNER from gliner import GLiNER
from .models import * from .models import *
from .helpers import *
# Global model instance # Global model instance
model = None model = None
@ -30,7 +31,7 @@ app = FastAPI(
lifespan=lifespan lifespan=lifespan
) )
# ==================== Endpoints ====================
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint"""
@ -38,21 +39,145 @@ async def health_check():
@app.post("/general")
async def general_extraction(request: GeneralRequest):
    """Named Entity Recognition endpoint.

    Accepts either a single ``text`` or a batch ``texts`` together with the
    entity labels to look for. Returns matched spans with character offsets
    and confidence scores: a flat list for single input, a list of lists for
    batch input, both under the ``outputs`` key.
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)

        # Normalize to a list so single and batch share one code path.
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            entities = model.predict_entities(
                text,
                request.entities,
                threshold=request.threshold,
                flat_ner=True,
            )
            results.append([
                {
                    "entity": ent["label"],
                    "span": ent["text"],
                    "start": ent["start"],
                    "end": ent["end"],
                    "score": float(ent["score"]),
                }
                for ent in entities
            ])

        # Mirror the input mode in the response shape.
        return {"outputs": results if is_batch else results[0]}
    except Exception as e:
        # Boundary handler: converts any failure into a structured HTTPException.
        handle_error(e)
@app.post("/relation-extraction")
async def relation_extraction(request: RelationExtractionRequest):
    """Relation Extraction endpoint.

    Accepts a single ``text`` or a batch ``texts`` plus the relation labels
    to extract. Each result item carries a source span, a relation label, a
    target span, and an overall confidence score.
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        results = []
        for text in texts_to_process:
            # NOTE(review): these entities are extracted but never used below —
            # presumably meant to seed predict_relations; confirm with the
            # GLiNER API before removing the call.
            if request.entities:
                entities = model.predict_entities(
                    text,
                    request.entities,
                    threshold=request.threshold,
                    flat_ner=True,
                )
            else:
                entities = []

            relations = model.predict_relations(
                text,
                request.relations,
                threshold=request.threshold,
            )
            results.append([
                {
                    "source": {
                        "entity": rel.get("head_type", ""),
                        "span": rel.get("head_text", ""),
                        "start": rel.get("head_start", 0),
                        "end": rel.get("head_end", 0),
                        "score": float(rel.get("head_score", 0.0)),
                    },
                    "relation": rel["label"],
                    "target": {
                        "entity": rel.get("tail_type", ""),
                        "span": rel.get("tail_text", ""),
                        "start": rel.get("tail_start", 0),
                        "end": rel.get("tail_end", 0),
                        "score": float(rel.get("tail_score", 0.0)),
                    },
                    "score": float(rel["score"]),
                }
                for rel in relations
            ])

        # Mirror the input mode in the response shape.
        return {"outputs": results if is_batch else results[0]}
    except Exception as e:
        # Boundary handler: converts any failure into a structured HTTPException.
        handle_error(e)
@app.post("/summarization")
async def summarization(request: SummarizationRequest):
    """Summarization endpoint.

    Accepts a single ``text`` or a batch ``texts``. Implemented as
    extractive summarization: the model is queried for spans matching the
    fixed labels "summary" / "key point" / "important information", and the
    matched spans are returned with offsets and scores.
    """
    try:
        input_data = validate_single_or_batch(request.text, request.texts)
        is_batch = isinstance(input_data, list)
        texts_to_process = input_data if is_batch else [input_data]

        # Fixed label set used to coax extractive "summary" spans from the
        # NER model; callers cannot override these.
        summary_labels = ["summary", "key point", "important information"]

        results = []
        for text in texts_to_process:
            summaries = model.predict_entities(
                text,
                summary_labels,
                threshold=request.threshold,
                flat_ner=True,
            )
            results.append([
                {
                    "text": summ["text"],
                    "start": summ["start"],
                    "end": summ["end"],
                    "score": float(summ["score"]),
                }
                for summ in summaries
            ])

        # Mirror the input mode in the response shape.
        return {"outputs": results if is_batch else results[0]}
    except Exception as e:
        # Boundary handler: converts any failure into a structured HTTPException.
        handle_error(e)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn

Loading…
Cancel
Save