"""Tests for the /general endpoint (Named Entity Recognition)."""
from typing import Dict, Any

import pytest
import requests

from .conftest import BASE_URL


class TestGeneralEndpoint:
    """Exercise the /general endpoint used for Named Entity Recognition."""

    @staticmethod
    def _submit(api_client, body: Dict[str, Any]):
        """POST ``body`` as JSON to ``{BASE_URL}/general`` and return the response.

        Not collected by pytest (the name does not start with ``test``); it
        only factors out the request every test in this class performs.
        """
        return api_client.post(f"{BASE_URL}/general", json=body)

    def test_single_text(self, api_client):
        """A request carrying one ``text`` returns 200 and a list under ``outputs``."""
        request_body = {
            "text": "Apple Inc. was founded by Steve Jobs in Cupertino, California.",
            "entities": ["organization", "person", "location"],
            "threshold": 0.5,
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 200

        body = reply.json()
        assert "outputs" in body
        assert isinstance(body["outputs"], list)

    def test_batch_texts(self, api_client):
        """A request carrying three ``texts`` returns 200 and three list results."""
        sentences = [
            "Apple Inc. was founded by Steve Jobs in California.",
            "Microsoft is located in Redmond, Washington.",
            "Elon Musk leads Tesla and SpaceX.",
        ]
        request_body = {
            "texts": sentences,
            "entities": ["organization", "person", "location"],
            "threshold": 0.5,
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 200

        body = reply.json()
        assert "outputs" in body
        outputs = body["outputs"]
        assert isinstance(outputs, list)
        assert len(outputs) == 3

        # Every per-text result must itself be a list.
        assert all(isinstance(item, list) for item in outputs)

    def test_with_optional_prompt(self, api_client):
        """A request that also supplies ``prompt`` returns 200."""
        request_body = {
            "text": "Google was founded in 1998.",
            "entities": ["organization", "date"],
            "prompt": "Extract entities from the text",
            "threshold": 0.3,
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 200

    def test_error_both_text_and_texts(self, api_client):
        """Sending ``text`` and ``texts`` together yields 400 / INVALID_INPUT."""
        request_body = {
            "text": "Some text",
            "texts": ["Another text"],
            "entities": ["organization"],
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 400

        body = reply.json()
        assert "error" in body
        detail = body["error"]
        assert detail["code"] == "INVALID_INPUT"
        # The message is expected to mention that both fields were supplied.
        assert "both" in detail["message"].lower()

    def test_error_no_text_provided(self, api_client):
        """Omitting both ``text`` and ``texts`` yields 400 / INVALID_INPUT."""
        request_body = {
            "entities": ["organization"],
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 400

        body = reply.json()
        assert "error" in body
        assert body["error"]["code"] == "INVALID_INPUT"

    def test_error_empty_entities_list(self, api_client):
        """An empty ``entities`` list yields 422."""
        request_body = {
            "text": "Some text",
            "entities": [],
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 422  # Pydantic validation error

    def test_error_empty_texts_list(self, api_client):
        """An empty ``texts`` list yields 400."""
        request_body = {
            "texts": [],
            "entities": ["organization"],
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 400

    @pytest.mark.parametrize("threshold", [-0.1, 1.5, 2.0])
    def test_error_invalid_threshold(self, api_client, threshold):
        """Each parametrized threshold (-0.1, 1.5, 2.0) yields 422."""
        request_body = {
            "text": "Some text",
            "entities": ["organization"],
            "threshold": threshold,
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 422

    @pytest.mark.parametrize("threshold", [0.0, 0.5, 1.0])
    def test_valid_threshold_range(self, api_client, threshold):
        """Each parametrized threshold (0.0, 0.5, 1.0) yields 200."""
        request_body = {
            "text": "Apple Inc.",
            "entities": ["organization"],
            "threshold": threshold,
        }
        reply = self._submit(api_client, request_body)
        assert reply.status_code == 200