@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return request-validation failures as 400 Bad Request.

    FastAPI answers a ``RequestValidationError`` with 422 by default; this
    handler replies with 400 instead and wraps the failure in the
    ``{"error": {"code", "message", "details"}}`` body built below.

    Only the first validation error is reported. Its message has the form
    ``"<loc part> -> <loc part>: <msg>"``.

    Args:
        request: The incoming request. Unused, but part of the exception
            handler signature.
        exc: The validation error raised while parsing the request.

    Returns:
        A ``JSONResponse`` with status 400 and the error body.
    """
    errors = exc.errors()

    if errors:
        first_error = errors[0]
        location = " -> ".join(str(part) for part in first_error["loc"])
        message = first_error["msg"]
        # Avoid a dangling ": " prefix when the error carries no location.
        error_message = f"{location}: {message}" if location else message
    else:
        # An empty error list would otherwise raise IndexError here and turn
        # a client error into a 500.
        error_message = "Request validation failed"

    response_body = {
        "error": {
            "code": "INVALID_INPUT",
            "message": error_message,
            "details": {},
        }
    }

    return JSONResponse(
        # Deliberately 400 rather than FastAPI's default 422.
        status_code=status.HTTP_400_BAD_REQUEST,
        content=response_body,
    )
api_client.post(f"{BASE_URL}/general", json=payload) - assert response.status_code == 422 + assert response.status_code == 400 @pytest.mark.parametrize("threshold", [0.0, 0.5, 1.0]) def test_valid_threshold_range(self, api_client, threshold): diff --git a/tests/test_relation_extraction_endpoint.py b/tests/test_relation_extraction_endpoint.py index 2e7eb88..52932bd 100644 --- a/tests/test_relation_extraction_endpoint.py +++ b/tests/test_relation_extraction_endpoint.py @@ -73,7 +73,7 @@ class TestRelationExtractionEndpoint: "relations": [] } response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) - assert response.status_code == 422 + assert response.status_code == 400 def test_error_both_text_and_texts(self, api_client): """Test error when both text and texts are present"""