Add validations to test if combined prompt and text length exceed max.

main
KKlochko 1 month ago
parent dfab9d3df0
commit 09d88ef949
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -1,5 +1,9 @@
from fastapi import HTTPException, status from fastapi import HTTPException, status
from typing import List, Optional, Union, Dict, Any, Tuple from typing import List, Optional, Union, Dict, Any, Tuple
import os
MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "1000"))
def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]): def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
@ -69,6 +73,48 @@ def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
return text if text is not None else texts return text if text is not None else texts
def validate_prompt_and_text_length(prompt: Optional[str], text: Optional[str], texts: Optional[List[str]]):
"""Validate that combined prompt and text length doesn't exceed maximum"""
prompt_length = len(prompt) if prompt else 0
if text:
combined_length = prompt_length + len(text)
if combined_length > MAX_TEXT_LENGTH:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": {
"code": "INVALID_INPUT",
"message": f"Combined prompt and text length ({combined_length}) exceeds maximum length of {MAX_TEXT_LENGTH}",
"details": {
"prompt_length": prompt_length,
"text_length": len(text),
"max_length": MAX_TEXT_LENGTH
}
}
}
)
if texts:
for i, txt in enumerate(texts):
combined_length = prompt_length + len(txt)
if combined_length > MAX_TEXT_LENGTH:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": {
"code": "INVALID_INPUT",
"message": f"Combined prompt and texts[{i}] length ({combined_length}) exceeds maximum length of {MAX_TEXT_LENGTH}",
"details": {
"prompt_length": prompt_length,
"text_index": i,
"text_length": len(txt),
"max_length": MAX_TEXT_LENGTH
}
}
}
)
def handle_error(e: Exception): def handle_error(e: Exception):
"""Convert exceptions to consistent error format""" """Convert exceptions to consistent error format"""
if isinstance(e, HTTPException): if isinstance(e, HTTPException):

@ -126,6 +126,7 @@ async def general_extraction(request: GeneralRequest):
""" """
try: try:
input_data = validate_single_or_batch(request.text, request.texts) input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list) is_batch = isinstance(input_data, list)
# Process batch or single # Process batch or single
@ -190,6 +191,7 @@ async def relation_extraction(request: RelationExtractionRequest):
""" """
try: try:
input_data = validate_single_or_batch(request.text, request.texts) input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list) is_batch = isinstance(input_data, list)
texts_to_process = input_data if is_batch else [input_data] texts_to_process = input_data if is_batch else [input_data]
@ -383,6 +385,7 @@ async def summarization(request: SummarizationRequest):
""" """
try: try:
input_data = validate_single_or_batch(request.text, request.texts) input_data = validate_single_or_batch(request.text, request.texts)
validate_prompt_and_text_length(request.prompt, request.text, request.texts)
is_batch = isinstance(input_data, list) is_batch = isinstance(input_data, list)
texts_to_process = input_data if is_batch else [input_data] texts_to_process = input_data if is_batch else [input_data]

Loading…
Cancel
Save