You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

88 lines
3.0 KiB

import pytest
import requests
from typing import Dict, Any
from .conftest import BASE_URL
class TestRelationExtractionEndpoint:
    """Tests the /relation-extraction endpoint.

    Each test posts a JSON payload through the ``api_client`` fixture and
    checks the HTTP status code. The explicit-entities and batch tests
    additionally inspect the ``outputs`` field of the JSON response body.
    """

    # Path appended to BASE_URL for every request made by this class.
    ENDPOINT_PATH = "/relation-extraction"

    def _post(self, api_client, payload: Dict[str, Any]):
        """POST ``payload`` as JSON to the endpoint and return the response."""
        return api_client.post(f"{BASE_URL}{self.ENDPOINT_PATH}", json=payload)

    @staticmethod
    def _assert_status(response, expected: int) -> None:
        """Assert the response status code, reporting the body on failure."""
        # NOTE(review): assumes the response exposes ``.text`` the way
        # requests.Response does -- confirm against the api_client fixture.
        assert response.status_code == expected, (
            f"expected HTTP {expected}, got {response.status_code}: "
            f"{response.text}"
        )

    def test_single_text_with_entities(self, api_client):
        """Test with a single text and entities."""
        payload = {
            "text": "Steve Jobs founded Apple Inc. which is located in Cupertino.",
            "relations": ["founded", "located_in"],
            "entities": ["person", "organization", "location"],
            "threshold": 0.3,
        }

        response = self._post(api_client, payload)

        self._assert_status(response, 200)
        data = response.json()
        assert "outputs" in data
        assert isinstance(data["outputs"], list)

    def test_single_text_without_entities(self, api_client):
        """Test without explicit entities."""
        payload = {
            "text": "Bill Gates founded Microsoft.",
            "relations": ["founded"],
            "threshold": 0.3,
        }

        response = self._post(api_client, payload)

        self._assert_status(response, 200)

    def test_batch_texts(self, api_client):
        """Test with a batch of texts: one output is expected per text."""
        payload = {
            "texts": [
                "Bill Gates founded Microsoft.",
                "Tesla was founded by Elon Musk.",
                "Apple is based in California.",
            ],
            "relations": ["founded", "founded_by", "based_in"],
            "threshold": 0.3,
        }

        response = self._post(api_client, payload)

        self._assert_status(response, 200)
        data = response.json()
        # Check the key first so a missing field fails as an assertion
        # instead of a KeyError, as test_single_text_with_entities does.
        assert "outputs" in data
        # Tie the expected count to the number of texts actually sent.
        assert len(data["outputs"]) == len(payload["texts"])

    def test_with_optional_rules(self, api_client):
        """Test with optional rules."""
        payload = {
            "text": "Steve Jobs founded Apple Inc.",
            "relations": ["founded"],
            "rules": [
                {
                    "relation": "founded",
                    "pairs_filter": [["person", "organization"]],
                    "distance": 10,
                }
            ],
            "threshold": 0.3,
        }

        response = self._post(api_client, payload)

        self._assert_status(response, 200)

    def test_error_empty_relations(self, api_client):
        """Test error with an empty relations list (expects HTTP 422)."""
        payload = {
            "text": "Some text",
            "relations": [],
        }

        response = self._post(api_client, payload)

        self._assert_status(response, 422)

    def test_error_both_text_and_texts(self, api_client):
        """Test error when both text and texts are present (expects HTTP 400)."""
        # NOTE(review): this conflict is expected as 400 while the empty
        # relations case is expected as 422 -- confirm the split is intended.
        payload = {
            "text": "Some text",
            "texts": ["Another text"],
            "relations": ["founded"],
        }

        response = self._post(api_client, payload)

        self._assert_status(response, 400)