Add tests for general endpoint.

main
KKlochko 1 month ago
parent 487e5d00a1
commit ac34ef0740
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -0,0 +1,123 @@
import pytest
import requests
from typing import Dict, Any
from .conftest import BASE_URL
class TestGeneralEndpoint:
"""Tests the /general endpoint for Named Entity Recognition"""
def test_single_text(self, api_client):
"""Test with a single text"""
payload = {
"text": "Apple Inc. was founded by Steve Jobs in Cupertino, California.",
"entities": ["organization", "person", "location"],
"threshold": 0.5
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 200
data = response.json()
assert "outputs" in data
assert isinstance(data["outputs"], list)
def test_batch_texts(self, api_client):
"""Test with a batch of texts"""
payload = {
"texts": [
"Apple Inc. was founded by Steve Jobs in California.",
"Microsoft is located in Redmond, Washington.",
"Elon Musk leads Tesla and SpaceX."
],
"entities": ["organization", "person", "location"],
"threshold": 0.5
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 200
data = response.json()
assert "outputs" in data
assert isinstance(data["outputs"], list)
assert len(data["outputs"]) == 3
# Check the structure of each result
for result in data["outputs"]:
assert isinstance(result, list)
def test_with_optional_prompt(self, api_client):
"""Test with an optional prompt"""
payload = {
"text": "Google was founded in 1998.",
"entities": ["organization", "date"],
"prompt": "Extract entities from the text",
"threshold": 0.3
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 200
def test_error_both_text_and_texts(self, api_client):
"""Test error when both text and texts are present"""
payload = {
"text": "Some text",
"texts": ["Another text"],
"entities": ["organization"]
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 400
error = response.json()
assert "error" in error
assert error["error"]["code"] == "INVALID_INPUT"
assert "both" in error["error"]["message"].lower()
def test_error_no_text_provided(self, api_client):
"""Test error when neither text nor texts is provided"""
payload = {
"entities": ["organization"]
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 400
error = response.json()
assert "error" in error
assert error["error"]["code"] == "INVALID_INPUT"
def test_error_empty_entities_list(self, api_client):
"""Test error with an empty entities list"""
payload = {
"text": "Some text",
"entities": []
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 422 # Pydantic validation error
def test_error_empty_texts_list(self, api_client):
"""Test error with an empty texts list"""
payload = {
"texts": [],
"entities": ["organization"]
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 400
@pytest.mark.parametrize("threshold", [-0.1, 1.5, 2.0])
def test_error_invalid_threshold(self, api_client, threshold):
"""Test error with an invalid threshold"""
payload = {
"text": "Some text",
"entities": ["organization"],
"threshold": threshold
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 422
@pytest.mark.parametrize("threshold", [0.0, 0.5, 1.0])
def test_valid_threshold_range(self, api_client, threshold):
"""Test valid threshold values"""
payload = {
"text": "Apple Inc.",
"entities": ["organization"],
"threshold": threshold
}
response = api_client.post(f"{BASE_URL}/general", json=payload)
assert response.status_code == 200
Loading…
Cancel
Save