diff --git a/gliner_inference_server/helpers.py b/gliner_inference_server/helpers.py
index 0eaf3a6..40d0f87 100644
--- a/gliner_inference_server/helpers.py
+++ b/gliner_inference_server/helpers.py
@@ -1,5 +1,9 @@
 from fastapi import HTTPException, status
 from typing import List, Optional, Union, Dict, Any, Tuple
+import os
+
+
+MAX_TEXT_LENGTH = int(os.getenv("MAX_TEXT_LENGTH", "1000"))
 
 
 def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
@@ -69,6 +73,48 @@ def validate_single_or_batch(text: Optional[str], texts: Optional[List[str]]):
     return text if text is not None else texts
 
 
+def validate_prompt_and_text_length(prompt: Optional[str], text: Optional[str], texts: Optional[List[str]]):
+    """Validate that combined prompt and text length doesn't exceed maximum"""
+    prompt_length = len(prompt) if prompt else 0
+
+    if text:
+        combined_length = prompt_length + len(text)
+        if combined_length > MAX_TEXT_LENGTH:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "error": {
+                        "code": "INVALID_INPUT",
+                        "message": f"Combined prompt and text length ({combined_length}) exceeds maximum length of {MAX_TEXT_LENGTH}",
+                        "details": {
+                            "prompt_length": prompt_length,
+                            "text_length": len(text),
+                            "max_length": MAX_TEXT_LENGTH
+                        }
+                    }
+                }
+            )
+
+    if texts:
+        for i, txt in enumerate(texts):
+            combined_length = prompt_length + len(txt)
+            if combined_length > MAX_TEXT_LENGTH:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail={
+                        "error": {
+                            "code": "INVALID_INPUT",
+                            "message": f"Combined prompt and texts[{i}] length ({combined_length}) exceeds maximum length of {MAX_TEXT_LENGTH}",
+                            "details": {
+                                "prompt_length": prompt_length,
+                                "text_index": i,
+                                "text_length": len(txt),
+                                "max_length": MAX_TEXT_LENGTH
+                            }
+                        }
+                    }
+                )
+
 def handle_error(e: Exception):
     """Convert exceptions to consistent error format"""
     if isinstance(e, HTTPException):
diff --git a/gliner_inference_server/main.py b/gliner_inference_server/main.py
index fcac631..eb412f1 100644
--- a/gliner_inference_server/main.py
+++ b/gliner_inference_server/main.py
@@ -126,6 +126,7 @@ async def general_extraction(request: GeneralRequest):
     """
     try:
         input_data = validate_single_or_batch(request.text, request.texts)
+        validate_prompt_and_text_length(request.prompt, request.text, request.texts)
         is_batch = isinstance(input_data, list)
 
         # Process batch or single
@@ -190,6 +191,7 @@ async def relation_extraction(request: RelationExtractionRequest):
     """
     try:
         input_data = validate_single_or_batch(request.text, request.texts)
+        validate_prompt_and_text_length(request.prompt, request.text, request.texts)
         is_batch = isinstance(input_data, list)
 
         texts_to_process = input_data if is_batch else [input_data]
@@ -383,6 +385,7 @@ async def summarization(request: SummarizationRequest):
     """
     try:
         input_data = validate_single_or_batch(request.text, request.texts)
+        validate_prompt_and_text_length(request.prompt, request.text, request.texts)
         is_batch = isinstance(input_data, list)
 
         texts_to_process = input_data if is_batch else [input_data]
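
For reviewers, below is a minimal pytest sketch of the behavior this change introduces. It assumes the default limit of 1000 characters (i.e., the MAX_TEXT_LENGTH environment variable is unset when helpers is imported) and that gliner_inference_server.helpers is importable from the test environment; the test names themselves are illustrative and not part of this diff.

```python
# Illustrative tests for validate_prompt_and_text_length, assuming the
# default MAX_TEXT_LENGTH of 1000. Not part of this change.
import pytest
from fastapi import HTTPException

from gliner_inference_server.helpers import validate_prompt_and_text_length


def test_single_text_over_limit():
    # 1001 characters with no prompt exceeds the default limit of 1000.
    with pytest.raises(HTTPException) as exc_info:
        validate_prompt_and_text_length(prompt=None, text="x" * 1001, texts=None)
    assert exc_info.value.status_code == 400
    assert exc_info.value.detail["error"]["code"] == "INVALID_INPUT"


def test_prompt_counts_toward_limit():
    # A 100-character prompt plus a 950-character text is 1050 combined.
    with pytest.raises(HTTPException):
        validate_prompt_and_text_length(prompt="p" * 100, text="t" * 950, texts=None)


def test_batch_reports_offending_index():
    # In batch mode the error payload identifies which entry was too long.
    texts = ["ok", "y" * 1001, "ok"]
    with pytest.raises(HTTPException) as exc_info:
        validate_prompt_and_text_length(prompt=None, text=None, texts=texts)
    assert exc_info.value.detail["error"]["details"]["text_index"] == 1


def test_under_limit_passes():
    # Within the limit the validator returns None without raising.
    assert validate_prompt_and_text_length(prompt="p", text="x" * 10, texts=None) is None
```

One hedged observation on the main.py hunks: they call validate_prompt_and_text_length, but no change to main.py's import block is visible in this diff. If main.py imports helper functions by name rather than via a module or star import, the new helper will also need to be added to that import list.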