You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

148 lines
5.1 KiB

import pytest
import requests
from typing import Dict, Any
from .conftest import BASE_URL
class TestRelationExtractionEndpoint:
    """Tests the /relation-extraction endpoint.

    Each test POSTs a JSON payload through the ``api_client`` fixture and
    checks the HTTP status code and, where relevant, the response body.
    """

    @staticmethod
    def _post(api_client, payload: Dict[str, Any]):
        """POST *payload* as JSON to the relation-extraction endpoint.

        Returns whatever ``api_client.post`` returns (the tests use its
        ``status_code`` attribute and ``json()`` method).
        """
        # The URL is built at call time so BASE_URL is not needed when the
        # class is defined.
        return api_client.post(f"{BASE_URL}/relation-extraction", json=payload)

    @staticmethod
    def _assert_invalid_input(response) -> None:
        """Assert that *response* is a 400 carrying an ``INVALID_INPUT`` error."""
        assert response.status_code == 400
        error = response.json()
        assert "error" in error
        assert error["error"]["code"] == "INVALID_INPUT"

    def test_single_text_with_entities(self, api_client):
        """Test with a single text and entities"""
        payload = {
            "text": "Steve Jobs founded Apple Inc. which is located in Cupertino.",
            "relations": ["founded", "located_in"],
            "entities": ["person", "organization", "location"],
            "threshold": 0.3,
        }
        response = self._post(api_client, payload)
        assert response.status_code == 200
        data = response.json()
        assert "outputs" in data
        assert isinstance(data["outputs"], list)

    def test_single_text_without_entities(self, api_client):
        """Test without explicit entities"""
        payload = {
            "text": "Bill Gates founded Microsoft.",
            "relations": ["founded"],
            "threshold": 0.3,
        }
        response = self._post(api_client, payload)
        assert response.status_code == 200

    def test_batch_texts(self, api_client):
        """Test with a batch of texts"""
        payload = {
            "texts": [
                "Bill Gates founded Microsoft.",
                "Tesla was founded by Elon Musk.",
                "Apple is based in California.",
            ],
            "relations": ["founded", "founded_by", "based_in"],
            "threshold": 0.3,
        }
        response = self._post(api_client, payload)
        assert response.status_code == 200
        data = response.json()
        # Check the key first so a missing key fails as an assertion,
        # not as a KeyError.
        assert "outputs" in data
        # One output entry is expected per input text.
        assert len(data["outputs"]) == 3

    def test_with_optional_rules(self, api_client):
        """Test with optional rules"""
        payload = {
            "text": "Steve Jobs founded Apple Inc.",
            "relations": ["founded"],
            "rules": [
                {
                    "relation": "founded",
                    "pairs_filter": [["person", "organization"]],
                    "distance": 10,
                }
            ],
            "threshold": 0.3,
        }
        response = self._post(api_client, payload)
        assert response.status_code == 200

    def test_error_empty_relations(self, api_client):
        """Test error with an empty relations list"""
        payload = {
            "text": "Some text",
            "relations": [],
        }
        response = self._post(api_client, payload)
        # NOTE(review): only the status is checked here; the error code for
        # this case is not pinned down in this file.
        assert response.status_code == 400

    def test_error_both_text_and_texts(self, api_client):
        """Test error when both text and texts are present"""
        payload = {
            "text": "Some text",
            "texts": ["Another text"],
            "relations": ["founded"],
        }
        response = self._post(api_client, payload)
        assert response.status_code == 400

    def test_error_no_text_provided(self, api_client):
        """Test error when neither text nor texts is provided"""
        # Neither a "text" nor a "texts" key is present in the payload.
        payload = {
            "relations": ["founded"],
        }
        self._assert_invalid_input(self._post(api_client, payload))

    def test_error_empty_text_provided(self, api_client):
        """Test error when the text is empty"""
        # The "text" key is present but holds an empty string.
        payload = {
            "text": "",
            "relations": ["founded"],
        }
        self._assert_invalid_input(self._post(api_client, payload))

    def test_error_texts_has_empty_value(self, api_client):
        """Test error when texts has an empty value"""
        # A single empty entry among valid texts must reject the whole batch.
        payload = {
            "texts": [
                "Bill Gates founded Microsoft.",
                "",
                "Apple is based in California.",
            ],
            "relations": ["founded", "founded_by", "based_in"],
            "threshold": 0.3,
        }
        self._assert_invalid_input(self._post(api_client, payload))

    def test_error_texts_has_empty_values(self, api_client):
        """Test error when texts has empty values"""
        payload = {
            "texts": [
                "",
                "",
                "",
            ],
            "relations": ["founded", "founded_by", "based_in"],
            "threshold": 0.3,
        }
        self._assert_invalid_input(self._post(api_client, payload))