Add validations to test if combined prompt and text length exceed max.

main
KKlochko 1 month ago
parent dfab9d3df0
commit 09d88ef949
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -1,5 +1,9 @@
from fastapi import HTTPException, status
from typing import List, Optional, Union, Dict, Any, Tuple
import os
# Upper bound used by the length validators, compared against len() of the
# input strings (i.e. counted in characters). Read once at import time from
# the MAX_TEXT_LENGTH environment variable; falls back to 1000 when unset.
MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "1000"))
def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
@ -69,6 +73,48 @@ def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
return text if text is not None else texts
def validate_prompt_and_text_length(prompt: Optional[str], text: Optional[str], texts: Optional[List[str]]):
    """Ensure the prompt plus each input text fits within MAX_TEXT_LENGTH.

    Lengths are measured with ``len()``. A missing or empty ``prompt``
    contributes 0. A missing or empty ``text`` is not checked; every entry
    of a non-empty ``texts`` list is checked individually against the same
    prompt, and the first entry that is too long aborts validation.

    Raises:
        HTTPException: 400 with an ``INVALID_INPUT`` error payload when a
            combined length is greater than MAX_TEXT_LENGTH.
    """
    prompt_length = 0 if not prompt else len(prompt)

    def _length_error(message, details):
        # Both the single and the batch case share this response envelope;
        # only the message and the details payload differ.
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "error": {
                    "code": "INVALID_INPUT",
                    "message": message,
                    "details": details
                }
            }
        )

    # Single-text request.
    if text:
        combined_length = prompt_length + len(text)
        if combined_length > MAX_TEXT_LENGTH:
            raise _length_error(
                f"Combined prompt and text length ({combined_length}) exceeds maximum length of {MAX_TEXT_LENGTH}",
                {
                    "prompt_length": prompt_length,
                    "text_length": len(text),
                    "max_length": MAX_TEXT_LENGTH
                },
            )

    # Batch request: report the index of the first offending entry.
    for i, entry in enumerate(texts or []):
        combined_length = prompt_length + len(entry)
        if combined_length <= MAX_TEXT_LENGTH:
            continue
        raise _length_error(
            f"Combined prompt and texts[{i}] length ({combined_length}) exceeds maximum length of {MAX_TEXT_LENGTH}",
            {
                "prompt_length": prompt_length,
                "text_index": i,
                "text_length": len(entry),
                "max_length": MAX_TEXT_LENGTH
            },
        )
def handle_error(e: Exception):
"""Convert exceptions to consistent error format"""
if isinstance(e, HTTPException):

@ -126,6 +126,7 @@ async def general_extraction(request: GeneralRequest):
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list)
# Process batch or single
@ -190,6 +191,7 @@ async def relation_extraction(request: RelationExtractionRequest):
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list)
texts_to_process = input_data if is_batch else [input_data]
@ -383,6 +385,7 @@ async def summarization(request: SummarizationRequest):
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list)
texts_to_process = input_data if is_batch else [input_data]

Loading…
Cancel
Save