import pytest
import requests
from typing import Dict, Any

from .conftest import BASE_URL


class TestGeneralEndpoint:
    """Tests the /general endpoint for Named Entity Recognition."""

    @staticmethod
    def _post(api_client, payload: Dict[str, Any]):
        """POST *payload* as JSON to the /general endpoint and return the response."""
        return api_client.post(f"{BASE_URL}/general", json=payload)

    @staticmethod
    def _assert_ok_outputs(response) -> Dict[str, Any]:
        """Assert a 200 response whose JSON body has an ``outputs`` list.

        Returns:
            The decoded JSON body, for further test-specific checks.
        """
        assert response.status_code == 200
        data = response.json()
        assert "outputs" in data
        assert isinstance(data["outputs"], list)
        return data

    @staticmethod
    def _assert_invalid_input(response) -> Dict[str, Any]:
        """Assert a 400 response carrying an ``INVALID_INPUT`` error envelope.

        Returns:
            The decoded JSON error body, for further test-specific checks.
        """
        assert response.status_code == 400
        error = response.json()
        assert "error" in error
        assert error["error"]["code"] == "INVALID_INPUT"
        return error

    def test_single_text(self, api_client):
        """Test with a single text"""
        payload = {
            "text": "Apple Inc. was founded by Steve Jobs in Cupertino, California.",
            "entities": ["organization", "person", "location"],
            "threshold": 0.5,
        }
        response = self._post(api_client, payload)
        self._assert_ok_outputs(response)

    def test_batch_texts(self, api_client):
        """Test with a batch of texts"""
        payload = {
            "texts": [
                "Apple Inc. was founded by Steve Jobs in California.",
                "Microsoft is located in Redmond, Washington.",
                "Elon Musk leads Tesla and SpaceX.",
            ],
            "entities": ["organization", "person", "location"],
            "threshold": 0.5,
        }
        response = self._post(api_client, payload)
        data = self._assert_ok_outputs(response)
        # One result per submitted text; derived from the payload so the
        # expectation stays correct if the batch is edited.
        assert len(data["outputs"]) == len(payload["texts"])

        # Check the structure of each result
        for result in data["outputs"]:
            assert isinstance(result, list)

    def test_with_optional_prompt(self, api_client):
        """Test with an optional prompt"""
        payload = {
            "text": "Google was founded in 1998.",
            "entities": ["organization", "date"],
            "prompt": "Extract entities from the text",
            "threshold": 0.3,
        }
        response = self._post(api_client, payload)
        # Check the body shape too, not just the status code, so a 200 with a
        # malformed body does not pass.
        self._assert_ok_outputs(response)

    def test_error_both_text_and_texts(self, api_client):
        """Test error when both text and texts are present"""
        payload = {
            "text": "Some text",
            "texts": ["Another text"],
            "entities": ["organization"],
        }
        response = self._post(api_client, payload)
        error = self._assert_invalid_input(response)
        assert "both" in error["error"]["message"].lower()

    def test_error_no_text_provided(self, api_client):
        """Test error when neither text nor texts is provided"""
        payload = {
            "entities": ["organization"],
        }
        response = self._post(api_client, payload)
        self._assert_invalid_input(response)

    def test_error_empty_entities_list(self, api_client):
        """Test error with an empty entities list"""
        payload = {
            "text": "Some text",
            "entities": [],
        }
        response = self._post(api_client, payload)
        assert response.status_code == 400

    def test_error_empty_texts_list(self, api_client):
        """Test error with an empty texts list"""
        payload = {
            "texts": [],
            "entities": ["organization"],
        }
        response = self._post(api_client, payload)
        assert response.status_code == 400

    @pytest.mark.parametrize("threshold", [-0.1, 1.5, 2.0])
    def test_error_invalid_threshold(self, api_client, threshold):
        """Test error with an invalid threshold"""
        payload = {
            "text": "Some text",
            "entities": ["organization"],
            "threshold": threshold,
        }
        response = self._post(api_client, payload)
        assert response.status_code == 400

    @pytest.mark.parametrize("threshold", [0.0, 0.5, 1.0])
    def test_valid_threshold_range(self, api_client, threshold):
        """Test valid threshold values"""
        payload = {
            "text": "Apple Inc.",
            "entities": ["organization"],
            "threshold": threshold,
        }
        response = self._post(api_client, payload)
        # Boundary thresholds must yield a well-formed body, not just a 200.
        self._assert_ok_outputs(response)