You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
124 lines
4.3 KiB
124 lines
4.3 KiB
import pytest
|
|
import requests
|
|
from typing import Dict, Any
|
|
from .conftest import BASE_URL
|
|
|
|
|
|
class TestGeneralEndpoint:
    """Tests the /general endpoint for Named Entity Recognition.

    Every test sends a JSON payload to ``{BASE_URL}/general`` through the
    ``api_client`` fixture (defined outside this module, presumably in
    conftest). Success cases expect HTTP 200 with an ``outputs`` list; input
    validation failures expect HTTP 400.
    """

    @staticmethod
    def _post_general(api_client, payload: Dict[str, Any]):
        """POST *payload* as JSON to the /general endpoint and return the response."""
        return api_client.post(f"{BASE_URL}/general", json=payload)

    @staticmethod
    def _assert_outputs_list(response) -> Dict[str, Any]:
        """Assert a 200 response whose JSON body has an ``outputs`` list.

        Returns:
            The decoded JSON body, for further test-specific checks.
        """
        assert response.status_code == 200
        data = response.json()
        assert "outputs" in data
        assert isinstance(data["outputs"], list)
        return data

    @staticmethod
    def _assert_invalid_input(response) -> Dict[str, Any]:
        """Assert a 400 response carrying an ``INVALID_INPUT`` error object.

        Returns:
            The decoded JSON error body, for further test-specific checks.
        """
        assert response.status_code == 400
        error = response.json()
        assert "error" in error
        assert error["error"]["code"] == "INVALID_INPUT"
        return error

    def test_single_text(self, api_client):
        """Test with a single text"""
        payload = {
            "text": "Apple Inc. was founded by Steve Jobs in Cupertino, California.",
            "entities": ["organization", "person", "location"],
            "threshold": 0.5,
        }
        response = self._post_general(api_client, payload)
        self._assert_outputs_list(response)

    def test_batch_texts(self, api_client):
        """Test with a batch of texts"""
        texts = [
            "Apple Inc. was founded by Steve Jobs in California.",
            "Microsoft is located in Redmond, Washington.",
            "Elon Musk leads Tesla and SpaceX.",
        ]
        payload = {
            "texts": texts,
            "entities": ["organization", "person", "location"],
            "threshold": 0.5,
        }
        response = self._post_general(api_client, payload)
        data = self._assert_outputs_list(response)

        # Expect exactly one result per input text; derive the count from the
        # payload so the assertion stays correct if the batch is edited.
        assert len(data["outputs"]) == len(texts)

        # Check the structure of each result
        for result in data["outputs"]:
            assert isinstance(result, list)

    def test_with_optional_prompt(self, api_client):
        """Test with an optional prompt"""
        payload = {
            "text": "Google was founded in 1998.",
            "entities": ["organization", "date"],
            "prompt": "Extract entities from the text",
            "threshold": 0.3,
        }
        response = self._post_general(api_client, payload)
        # The prompt is optional, so the response should have the same shape
        # as a prompt-less single-text call (see test_single_text).
        self._assert_outputs_list(response)

    def test_error_both_text_and_texts(self, api_client):
        """Test error when both text and texts are present"""
        payload = {
            "text": "Some text",
            "texts": ["Another text"],
            "entities": ["organization"],
        }
        response = self._post_general(api_client, payload)
        error = self._assert_invalid_input(response)
        assert "both" in error["error"]["message"].lower()

    def test_error_no_text_provided(self, api_client):
        """Test error when neither text nor texts is provided"""
        payload = {
            "entities": ["organization"],
        }
        response = self._post_general(api_client, payload)
        self._assert_invalid_input(response)

    def test_error_empty_entities_list(self, api_client):
        """Test error with an empty entities list"""
        payload = {
            "text": "Some text",
            "entities": [],
        }
        response = self._post_general(api_client, payload)
        assert response.status_code == 400

    def test_error_empty_texts_list(self, api_client):
        """Test error with an empty texts list"""
        payload = {
            "texts": [],
            "entities": ["organization"],
        }
        response = self._post_general(api_client, payload)
        assert response.status_code == 400

    @pytest.mark.parametrize("threshold", [-0.1, 1.5, 2.0])
    def test_error_invalid_threshold(self, api_client, threshold):
        """Test error with an invalid threshold"""
        payload = {
            "text": "Some text",
            "entities": ["organization"],
            "threshold": threshold,
        }
        response = self._post_general(api_client, payload)
        assert response.status_code == 400

    # Together with test_error_invalid_threshold, this pins the accepted
    # threshold range as [0.0, 1.0] with both endpoints included.
    @pytest.mark.parametrize("threshold", [0.0, 0.5, 1.0])
    def test_valid_threshold_range(self, api_client, threshold):
        """Test valid threshold values"""
        payload = {
            "text": "Apple Inc.",
            "entities": ["organization"],
            "threshold": threshold,
        }
        response = self._post_general(api_client, payload)
        self._assert_outputs_list(response)